# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import json
import logging
import os
import torch
from tqdm import tqdm

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

from pytorch_transformers.tokenization_bert import BertTokenizer

from biencoder import BiEncoderRanker
from nn_prediction import get_topk_predictions
from data_biencoder import load_mentions, load_entities, process_m_data, process_e_data
from utils import get_logger, accuracy, write_to_file, save_model

def encode_candidate(biranker, dls, logger):
    """Encode every candidate of every domain with the bi-encoder.

    Args:
        biranker: ranker object exposing ``model`` (put into eval mode here)
            and ``encode_candidate(tokens)``; its ``device`` attribute, when
            present, decides where input batches are moved (falls back to
            ``'cuda'``).
        dls: mapping from domain name to an iterable of batches whose first
            element is the candidate token tensor.
        logger: accepted for interface compatibility; not used.

    Returns:
        dict mapping each domain name to ``torch.cat`` (along dim 0) of the
        per-batch encodings, in the order the batches were yielded.
    """
    biranker.model.eval()
    # Honour the ranker's configured device; keep the old 'cuda' default for
    # rankers that do not expose one.
    device = getattr(biranker, "device", "cuda")

    cand_encode_dict = {}
    # Encodings are only consumed for retrieval. Without no_grad, the autograd
    # graph of every batch would be retained until the final concatenation.
    with torch.no_grad():
        for domain, dl in dls.items():
            encodings = []
            for batch in tqdm(dl, desc="Encode " + domain):
                cands = batch[0].to(device)
                encodings.append(biranker.encode_candidate(cands))
            cand_encode_dict[domain] = torch.cat(encodings, dim=0)

    return cand_encode_dict

def get_dataloader(tensor_data, bs):
    """Build one sequential (unshuffled) DataLoader per domain.

    Args:
        tensor_data: mapping ``domain -> {'tokens': Tensor[, 'g_doc': Tensor]}``.
            When a domain provides ``'g_doc'``, each batch yields
            ``(tokens, g_doc)``; otherwise it yields ``(tokens,)``.
        bs: batch size used for every loader.

    Returns:
        dict mapping each domain name to its DataLoader.
    """
    loaders = {}
    for name, fields in tensor_data.items():
        columns = [fields['tokens']]
        if 'g_doc' in fields:
            columns.append(fields['g_doc'])
        dataset = TensorDataset(*columns)
        loaders[name] = DataLoader(
            dataset,
            sampler=SequentialSampler(dataset),
            batch_size=bs,
        )
    return loaders


def main(params):
    """Run bi-encoder top-k candidate retrieval on one data split.

    Loads the mentions and entities of the split named by ``params["mode"]``,
    tokenizes them, restores a trained bi-encoder from
    ``params["path_to_model"]``, encodes all candidates and passes everything
    to ``get_topk_predictions``.

    Args:
        params: dict of run options. Keys read here: ``output_path``,
            ``mode`` ("train", "val" or "test"), ``data_path``,
            ``max_context_length``, ``max_cand_length``, ``path_to_model``,
            ``eval_batch_size``, ``encode_batch_size`` and ``top_k``; the whole
            dict is also handed to ``BiEncoderRanker``. ``eval_batch_size``
            and ``encode_batch_size`` are overwritten in place with their
            values multiplied by the ranker's ``n_gpu``.

    Raises:
        ValueError: if ``params["mode"]`` is not one of the known splits.
    """
    mode = params["mode"]
    valid_modes = ("train", "val", "test")
    # Fail fast, before creating directories or loading data; previously an
    # unknown mode surfaced later as an UnboundLocalError.
    if mode not in valid_modes:
        raise ValueError(
            "Unknown mode %r; expected one of %s" % (mode, ", ".join(valid_modes))
        )

    output_path = os.path.join(params["output_path"], mode)
    os.makedirs(output_path, exist_ok=True)
    logger = get_logger(output_path)

    # load mentions and entities
    train_m, val_m, test_m = load_mentions(params['data_path'], logger)
    train_e, val_e, test_e = load_entities(params['data_path'], logger)
    splits = {
        "train": (train_m, train_e),
        "val": (val_m, val_e),
        "test": (test_m, test_e),
    }
    target_m, target_e = splits[mode]

    # process m and e
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    target_m_tensor_data, target_m_num = process_m_data(target_m, target_e, tokenizer, params['max_context_length'])
    target_e_tensor_data, target_e_num = process_e_data(target_e, tokenizer, params['max_cand_length'])

    # init model, load model
    biranker = BiEncoderRanker(params)
    n_gpu = biranker.n_gpu

    # Load onto CPU first: load_state_dict copies the weights into the
    # model's own parameters, so this avoids a transient GPU copy and works
    # when the checkpoint was saved on a different device.
    saved_dict = torch.load(params['path_to_model'], map_location='cpu')
    biranker.model.load_state_dict(saved_dict['sd'])
    biranker.model = torch.nn.DataParallel(biranker.model)

    # generate dataloaders; batch sizes are per GPU in params
    eval_batch_size = params["eval_batch_size"] * n_gpu
    encode_batch_size = params["encode_batch_size"] * n_gpu
    params['eval_batch_size'] = eval_batch_size
    params['encode_batch_size'] = encode_batch_size

    target_m_dls = get_dataloader(target_m_tensor_data, eval_batch_size)
    target_e_dls = get_dataloader(target_e_tensor_data, encode_batch_size)

    # encode candidates
    candidate_encoding = encode_candidate(biranker, target_e_dls, logger=logger)

    # prediction
    get_topk_predictions(biranker, target_m_dls, candidate_encoding, logger, params["top_k"])

if __name__ == "__main__":
    # Import only the parsed command-line arguments rather than `*`, so the
    # config module cannot silently shadow names defined in this script.
    from evaluating_config import args
    print(args)

    params = vars(args)
    main(params)
 
