# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import argparse
import pickle
import torch
import json
import sys
import io
import random
import time
import numpy as np
import logging

from tqdm import tqdm, trange
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from pytorch_transformers.optimization import WarmupLinearSchedule
from pytorch_transformers.tokenization_bert import BertTokenizer

from biencoder import BiEncoderRanker
from utils import get_logger, accuracy, write_to_file, save_model
from optimizer import get_bert_optimizer
from data_biencoder import load_mentions, load_entities, process_m_data, process_e_data

# Module-level placeholder. In the code visible here main() creates its own
# local logger via get_logger() and passes it explicitly to the helpers, so
# this global is never reassigned and stays None.
logger = None

# The evaluate function during training uses in-batch negatives:
# for a batch of size B, the labels from the batch are used as label candidates
# B is controlled by the parameter eval_batch_size
def evaluate(biranker, eval_dls, eval_e_tensor, logger):
    """Score the bi-encoder on every domain using in-batch negatives.

    For each batch, the gold entity of every mention is looked up in the
    domain's entity token tensor, so the candidate set of a batch is exactly
    the gold entities of the mentions in that batch and the correct candidate
    for row ``i`` is column ``i``.

    Args:
        biranker: ranker exposing ``model`` and callable as
            ``biranker(context_input, candidate_input)``, returning
            ``(loss, logits)``.
        eval_dls: mapping from domain name to a dataloader yielding
            ``(context_input, golden_doc)`` batches.
        eval_e_tensor: mapping from domain name to a dict whose ``'tokens'``
            entry is indexed with ``golden_doc``.
        logger: logger used to report per-domain and overall accuracy.

    Returns:
        The accumulated ``accuracy()`` score divided by the number of
        evaluated mentions over all domains, or ``0.0`` when no mention was
        evaluated at all.
    """
    biranker.model.eval()
    total_correct = 0.0
    total_examples = 0
    for domain, dl in eval_dls.items():
        es = eval_e_tensor[domain]['tokens']
        domain_correct, domain_examples = 0.0, 0

        for context_input, golden_doc in tqdm(dl, desc="Evaluation "+domain):
            candidate_input = es[golden_doc]
            context_input = context_input.to("cuda")
            candidate_input = candidate_input.to("cuda")
            with torch.no_grad():
                _, logits = biranker(context_input, candidate_input)

            logits = logits.detach().cpu().numpy()
            # Using in-batch negatives, the label ids are diagonal
            label_ids = torch.arange(context_input.size(0)).numpy()
            # NOTE(review): assumes accuracy() returns a per-batch *count* of
            # correct predictions first (it is divided by the example count
            # below) -- confirm against utils.accuracy.
            batch_correct, _ = accuracy(logits, label_ids)

            domain_correct += batch_correct
            domain_examples += context_input.size(0)

        if domain_examples == 0:
            # Nothing was evaluated for this domain: skip the ratio instead of
            # dividing by zero.
            logger.warning("No evaluation examples for domain %s", domain)
            continue

        logger.info("Eval accuracy "+domain+" "+str(domain_correct / domain_examples))
        total_correct += domain_correct
        total_examples += domain_examples

    if total_examples == 0:
        # Empty evaluation set: report the lowest score rather than crashing.
        logger.warning("No evaluation examples found; returning accuracy 0.0")
        return 0.0

    normalized_eval_accuracy = total_correct / total_examples
    logger.info("Eval accuracy: %.5f" % normalized_eval_accuracy)

    return normalized_eval_accuracy

def get_scheduler(params, optimizer, len_train_data, logger):
    """Create the linear warmup / linear decay learning-rate schedule.

    Args:
        params: dict providing ``train_batch_size``, ``num_train_epochs`` and
            ``warmup_proportion``.
        optimizer: optimizer whose learning rate is scheduled.
        len_train_data: number of training examples.
        logger: logger used to report the computed step counts.

    Returns:
        A ``WarmupLinearSchedule`` bound to ``optimizer``.
    """
    batch_size = params["train_batch_size"]
    epochs = params["num_train_epochs"]

    # Round up: a trailing partial batch is still one optimizer step (the
    # loaders built in this file do not drop it). Flooring undercounts the
    # steps and yields a total of 0 when the data is smaller than one batch.
    steps_per_epoch = int((len_train_data + batch_size - 1) // batch_size)
    num_train_steps = steps_per_epoch * epochs
    num_warmup_steps = int(num_train_steps * params["warmup_proportion"])

    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_train_steps)
    logger.info(" Num optimization steps = %d", num_train_steps)
    logger.info(" Num warmup steps = %d", num_warmup_steps)
    return scheduler

def get_dataloader(tensor_data, bs):
    """Wrap each domain's tensors in a randomly-sampled DataLoader.

    Args:
        tensor_data: mapping from domain name to a dict holding a ``'tokens'``
            tensor and, optionally, a ``'g_doc'`` tensor.
        bs: batch size used for every loader.

    Returns:
        Dict mapping each domain name to a ``DataLoader`` whose batches are
        ``(tokens, g_doc)`` when ``'g_doc'`` is present and ``(tokens,)``
        otherwise.
    """
    loaders = {}
    for domain, fields in tensor_data.items():
        # The gold-document ids are optional; only pair them with the tokens
        # when the domain provides them.
        has_gold = 'g_doc' in fields.keys()
        tensors = [fields['tokens'], fields['g_doc']] if has_gold else [fields['tokens']]
        dataset = TensorDataset(*tensors)
        loaders[domain] = DataLoader(
            dataset,
            sampler=RandomSampler(dataset),
            batch_size=bs,
        )

    return loaders

def set_parameters(parameters, target):
    """Set ``requires_grad`` to ``target`` on every item of ``parameters``.

    Used to freeze (``target=False``) or unfreeze (``target=True``) a group
    of parameters in place; nothing is returned.
    """
    for param in parameters:
        param.requires_grad = target
def _backward_and_step(loss, model, optimizer, max_grad_norm):
    """Backpropagate ``loss``, clip gradients, apply one update and reset gradients."""
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()


def main(params):
    """Train the bi-encoder ranker and evaluate it after every epoch.

    Loads mentions and entities from ``params['data_path']``, tokenizes them,
    evaluates the untrained model on the test split, then trains for
    ``params['num_train_epochs']`` epochs. When ``params['do_adv_loss']`` is
    set, every batch runs three updates: an ascent step on the perturbation
    parameters (followed by a projection onto the norm ball of radius
    ``params['epsilon']``), a descent step on the perturbed loss, and a
    descent step on the regular loss. A checkpoint is written to
    ``<output_path>/epoch_<idx>`` after each epoch.

    Args:
        params: configuration dict. Keys read here: ``output_path``,
            ``data_path``, ``max_context_length``, ``max_cand_length``,
            ``train_batch_size``, ``eval_batch_size``, ``learning_rate``,
            ``num_train_epochs``, ``do_adv_loss``, ``max_grad_norm``,
            ``epsilon``, ``print_interval`` and ``eval_interval``.
            ``train_batch_size`` and ``eval_batch_size`` are overwritten in
            place with their value multiplied by the number of GPUs.
    """
    # log
    model_output_path = params["output_path"]
    os.makedirs(model_output_path, exist_ok=True)
    logger = get_logger(params["output_path"])

    # load mentions and entities
    train_m, val_m, test_m = load_mentions(params['data_path'], logger)
    train_e, val_e, test_e = load_entities(params['data_path'], logger)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    train_m_tensor_data, train_m_num = process_m_data(train_m, train_e, tokenizer, params['max_context_length'])
    val_m_tensor_data, _ = process_m_data(val_m, val_e, tokenizer, params['max_context_length'])
    test_m_tensor_data, _ = process_m_data(test_m, test_e, tokenizer, params['max_context_length'])

    train_e_tensor_data, _ = process_e_data(train_e, tokenizer, params['max_cand_length'])
    val_e_tensor_data, _ = process_e_data(val_e, tokenizer, params['max_cand_length'])
    test_e_tensor_data, _ = process_e_data(test_e, tokenizer, params['max_cand_length'])

    # Init model
    biranker = BiEncoderRanker(params)
    biranker.model = torch.nn.DataParallel(biranker.model)
    model = biranker.model

    device = biranker.device
    n_gpu = biranker.n_gpu

    # The configured batch sizes are per GPU; scale them to the whole machine
    # and store them back so the scheduler and the saved params see them.
    train_batch_size = params["train_batch_size"] * n_gpu
    eval_batch_size = params["eval_batch_size"] * n_gpu
    params['train_batch_size'] = train_batch_size
    params['eval_batch_size'] = eval_batch_size

    # evaluate before training
    logger.info("Evaluate before training")
    test_m_dls = get_dataloader(test_m_tensor_data, eval_batch_size)
    val_m_dls = get_dataloader(val_m_tensor_data, eval_batch_size)
    evaluate(biranker, test_m_dls, test_e_tensor_data, logger=logger)

    # start training
    write_to_file(os.path.join(model_output_path, "training_params.txt"), str(params))
    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))

    optimizer, parameters_encoder, parameters_A = get_bert_optimizer([model], params["learning_rate"])
    scheduler = get_scheduler(params, optimizer, train_m_num, logger)

    model.train()
    best_epoch_idx = -1
    best_score = -1
    step = 0
    num_train_epochs = params["num_train_epochs"]
    max_grad_norm = params["max_grad_norm"]

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):

        # train dataloader
        train_m_dls = get_dataloader(train_m_tensor_data, train_batch_size)
        tr_loss, adv_loss, tur_loss = 0.0, 0.0, 0.0
        # The loss sums above restart every epoch, so their averages must be
        # taken over the batches of this epoch, not over the global step.
        epoch_step = 0
        for domain, dl in train_m_dls.items():
            es = train_e_tensor_data[domain]['tokens']
            for context_input, golden_doc in tqdm(dl, desc="Train "+domain):
                # In-batch negatives: the candidates are the gold entities of
                # the mentions in this batch.
                candidate_input = es[golden_doc.tolist()]

                context_input = context_input.to("cuda")
                candidate_input = candidate_input.to("cuda")

                if params['do_adv_loss'] == True:
                    # maximize the loss w.r.t. the perturbation only
                    set_parameters(parameters_encoder, False)
                    set_parameters(parameters_A, True)
                    loss, _ = biranker(context_input, candidate_input, if_turbulence=True)
                    adv_loss += loss.item()
                    _backward_and_step(-loss, model, optimizer, max_grad_norm)

                    # project A back onto the ball of radius epsilon
                    for A in parameters_A:
                        if torch.norm(A) > params['epsilon']:
                            A.data = ((A / torch.norm(A)) * params['epsilon']).data

                    # minimize the perturbed loss w.r.t. the encoder only
                    set_parameters(parameters_encoder, True)
                    set_parameters(parameters_A, False)
                    loss, _ = biranker(context_input, candidate_input, if_turbulence=True)
                    tur_loss += loss.item()
                    _backward_and_step(loss, model, optimizer, max_grad_norm)

                    # minimize the normal loss
                    loss, _ = biranker(context_input, candidate_input)
                    tr_loss += loss.item()
                    _backward_and_step(loss, model, optimizer, max_grad_norm)
                    scheduler.step()

                else:
                    loss, _ = biranker(context_input, candidate_input)
                    tr_loss += loss.item()
                    _backward_and_step(loss, model, optimizer, max_grad_norm)
                    scheduler.step()

                if (step + 1) % params["print_interval"] == 0:
                    seen = epoch_step + 1
                    logger.info("Step {} - epoch {} average loss: {}, average adv loss: {}, average tur loss: {}\n".format(step, epoch_idx, tr_loss / seen, adv_loss / seen, tur_loss / seen))

                # evaluate on valid set
                if (step + 1) % params["eval_interval"] == 0:
                    logger.info("Evaluation on the development dataset")
                    evaluate(biranker, val_m_dls, val_e_tensor_data, logger=logger)
                    model.train()
                    logger.info("\n")

                step += 1
                epoch_step += 1

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path, "epoch_{}".format(epoch_idx))
        save_model(model, optimizer, scheduler, tokenizer, epoch_output_folder_path)

        # NOTE(review): the best epoch is selected on the test split; consider
        # selecting on the validation split to avoid tuning on test data.
        results = evaluate(biranker, test_m_dls, test_e_tensor_data, logger=logger)
        # Strict comparison keeps the earliest epoch on ties.
        if results > best_score:
            best_score = results
            best_epoch_idx = epoch_idx
        logger.info("best evaluation accuracy: {:.5f}".format(best_score))
        logger.info("best epoch idx: {}".format(best_epoch_idx))
        logger.info("\n")
   

if __name__ == "__main__":
    # The run configuration is provided by training_config as an `args`
    # namespace; echo it so the settings appear in the console output.
    from training_config import args
    print(args)

    # main() works on a plain dict, so hand it the namespace's attribute dict.
    params = vars(args)
    main(params)
