import json
import argparse
import os
import random
from copy import deepcopy
import logging
import pickle
from tqdm import tqdm, trange
import timeit

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import transformers
from transformers import AutoModelForMaskedLM, AutoConfig, AutoTokenizer
from transformers.trainer_utils import is_main_process

# Module-level logger; its format and level are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

class InputFeatures:
    """Plain container for one encoded example.

    Every constructor argument is stored unchanged on the instance under
    the attribute of the same name:

    * ``input_ids``, ``token_type_ids``, ``attention_mask`` - model inputs.
    * ``labels`` - label payload.
    * ``label_indexs`` - position in the sequence where label tokens start.
    * ``input_label_ids`` - ids of the label tokens to place at that position.
    """

    def __init__(self,
                 input_ids,
                 token_type_ids,
                 attention_mask,
                 labels,
                 label_indexs,
                 input_label_ids):
        # Model inputs.
        self.input_ids, self.token_type_ids, self.attention_mask = (
            input_ids,
            token_type_ids,
            attention_mask,
        )
        # Label-related fields.
        self.labels, self.label_indexs, self.input_label_ids = (
            labels,
            label_indexs,
            input_label_ids,
        )

def convert_data_to_features(args, tokenizer, context, answer, answer_start, keywords):
    """Encode a (context, answer, keywords) triple into fixed-length model inputs.

    The encoded sequence is ``[CLS] context [SEP] answer [SEP]`` padded with
    zeros up to ``args.max_seq_length``. The context segment gets token type
    0, the answer segment (including its trailing ``[SEP]``) token type 1.

    Args:
        args: namespace providing ``max_seq_length``, ``max_query_length``
            and ``eval_type``.
        tokenizer: tokenizer providing ``tokenize``, ``convert_tokens_to_ids``
            and the ``cls_token`` / ``sep_token`` / ``mask_token`` attributes.
        context: passage text.
        answer: answer text.
        answer_start: character offset of the answer in ``context``; ``0``
            means "keep the head of the context when truncating".
        keywords: iterable of keyword strings that seed the label sequence.

    Returns:
        InputFeatures with ``labels=''``, ``label_indexs`` set to the number
        of tokens before padding (the first free position after the final
        ``[SEP]``) and ``input_label_ids`` holding the ids of
        ``MASK kw_1 MASK kw_2 ... MASK``.

    Raises:
        AssertionError: if the encoded sequence is longer than
            ``args.max_seq_length``.
    """
    context_tokens = tokenizer.tokenize(context)
    answer_tokens = tokenizer.tokenize(answer)

    cls_token = tokenizer.cls_token
    sep_token = tokenizer.sep_token
    mask_token = tokenizer.mask_token

    # Room left for the context once the answer, 2 * max_query_length label
    # positions and 4 extra positions are reserved. Clamped at zero so an
    # oversized answer cannot turn the slices below into negative-index slices.
    max_context_length = max(
        args.max_seq_length - len(answer_tokens) - (args.max_query_length * 2) - 4,
        0,
    )

    if len(context_tokens) > max_context_length:
        # Locate the token at which the answer starts by walking the tokens and
        # approximating the number of characters consumed so far.
        answer_token_start = None
        if answer_start != 0 and answer_tokens:
            char_num = 0
            for i, context_token in enumerate(context_tokens):
                if '##' in context_token:
                    # Word-piece continuation: no preceding space is counted.
                    char_num += len(context_token.replace('##', ''))
                else:
                    char_num += len(context_token) + 1

                if context_token == answer_tokens[0] and char_num >= answer_start:
                    answer_token_start = i
                    break

        if answer_token_start is None:
            # answer_start == 0, empty answer, or the answer could not be
            # located: keep the head of the context instead of failing on an
            # unbound answer_token_start.
            context_tokens = context_tokens[:max_context_length]
        else:
            context_half_len = max_context_length // 2
            left_bound = answer_token_start - context_half_len
            right_bound = answer_token_start + context_half_len

            if left_bound < 0:
                context_tokens = context_tokens[:max_context_length]
            elif right_bound > len(context_tokens):
                context_tokens = context_tokens[len(context_tokens) - max_context_length:]
            else:
                # Window centred on the first answer token.
                context_tokens = context_tokens[left_bound:right_bound]

    input_tokens = [cls_token] + context_tokens + [sep_token]
    token_type_ids = [0] * len(input_tokens)

    input_tokens += answer_tokens + [sep_token]
    token_type_ids += [1] * (len(input_tokens) - len(token_type_ids))

    # First free position after the final [SEP]; label tokens are written from here.
    label_indexs = len(input_tokens)

    attention_mask = [1] * len(input_tokens)

    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)

    # Label sequence: MASK kw_1 MASK kw_2 ... MASK. Keywords are prefixed with a
    # space unless eval_type is exactly 'gen_keywords'.
    mask_tokens = tokenizer.tokenize(mask_token)
    keyword_prefix = '' if args.eval_type == 'gen_keywords' else ' '
    input_label_tokens = []
    for keyword in keywords:
        input_label_tokens += mask_tokens + tokenizer.tokenize(keyword_prefix + keyword)
    input_label_tokens += mask_tokens
    input_label_ids = tokenizer.convert_tokens_to_ids(input_label_tokens)

    # Zero-pad up to the sequence length.
    # NOTE(review): pads with the literal id 0 rather than tokenizer.pad_token_id;
    # confirm this matches how the model was trained.
    padding_length = args.max_seq_length - len(input_ids)
    if padding_length > 0:
        input_ids += [0] * padding_length
        token_type_ids += [0] * padding_length
        attention_mask += [0] * padding_length

    assert len(input_ids) == args.max_seq_length
    assert len(token_type_ids) == args.max_seq_length
    assert len(attention_mask) == args.max_seq_length

    return InputFeatures(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        labels='',
        label_indexs=label_indexs,
        input_label_ids=input_label_ids,
    )


def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    When ``args.n_gpu`` is positive, all CUDA generators are seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def predict(args, model, tokenizer, features, beam_size=1):
    """Generate text by iterative mask infilling with beam search.

    Each round writes the candidate's label tokens into the input starting at
    ``features.label_indexs``, runs the model, and reads predictions at every
    mask position. In later rounds each mask of the previous round is expanded
    to ``MASK predicted MASK``; a mask whose prediction was the separator
    token is dropped (the separator acts as the end marker). A candidate is
    finished once every mask in a round predicts the separator.

    Args:
        args: namespace providing ``device``, ``model_type`` and
            ``max_query_length`` (``lang_id`` is read if present).
        model: masked-LM model; called with keyword tensors, must return a
            mapping with a ``'logits'`` entry.
        tokenizer: tokenizer providing ``sep_token``, ``mask_token``,
            ``convert_tokens_to_ids``, ``convert_ids_to_tokens`` and
            ``convert_tokens_to_string``.
        features: object with ``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``label_indexs`` and ``input_label_ids``.
        beam_size: number of hypotheses kept per round.

    Returns:
        List of dicts ``{'prediction_text', 'score', 'error'}`` sorted by
        descending score. ``score`` is the accumulated per-round mean
        log-probability divided by the number of rounds. ``error`` is True
        when the length limit was reached before any candidate finished.
    """
    input_ids = torch.tensor([features.input_ids], dtype=torch.long).to(args.device)
    attention_mask = torch.tensor([features.attention_mask], dtype=torch.long).to(args.device)
    token_type_ids = torch.tensor([features.token_type_ids], dtype=torch.long).to(args.device)

    result = []
    model.eval()

    all_candidates = [{'prediction_ids' : [], 'input_label_ids' : features.input_label_ids, 'score' : 0, 'iter' : 0}]
    mask_token = tokenizer.mask_token
    # The separator token doubles as the end-of-generation marker.
    EOF_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    MASK_id = tokenizer.convert_tokens_to_ids(mask_token)
    error = False

    with torch.no_grad():
        while len(result) < beam_size:
            iter_candidates = []
            for all_candidate in all_candidates:

                # Work on copies so every candidate starts from the same base input.
                seq_input_ids = input_ids.clone()
                seq_token_type_ids = token_type_ids.clone()
                seq_attention_mask = attention_mask.clone()
                label_indexs = features.label_indexs

                input_label_ids = []
                prediction_num = 0
                for token_id in all_candidate['input_label_ids']:
                    if all_candidate['iter'] != 0 and token_id == MASK_id:
                        # Masks are consumed in order, one prediction per mask.
                        label_id = all_candidate['prediction_ids'][prediction_num]
                        prediction_num += 1
                        if label_id == EOF_id:
                            # The model closed this slot: drop the mask.
                            continue
                        new_ids = (MASK_id, label_id, MASK_id)
                    else:
                        new_ids = (token_id,)

                    for new_id in new_ids:
                        seq_input_ids[0][label_indexs] = new_id
                        seq_token_type_ids[0][label_indexs] = 0
                        seq_attention_mask[0][label_indexs] = 1
                        label_indexs += 1
                        input_label_ids.append(new_id)

                mask_indexs = [x.item() for x in (seq_input_ids[0] == MASK_id).nonzero()]

                inputs = {
                    "input_ids": seq_input_ids,
                    "attention_mask": seq_attention_mask,
                    "token_type_ids": seq_token_type_ids
                }

                if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart", "longformer"]:
                    del inputs["token_type_ids"]

                # The original code also passed cls_index / p_mask read from an
                # undefined `batch` here, which raised NameError; those inputs are
                # not built anywhere in this function, so only the language ids
                # are supplied for lang_id-sensitive xlm models.
                if args.model_type in ["xlnet", "xlm"]:
                    if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                        # Falls back to 0 when args carries no lang_id.
                        lang_id = getattr(args, "lang_id", 0)
                        inputs["langs"] = torch.full_like(seq_input_ids, lang_id)

                outputs = model(**inputs)

                # Beam search across the mask positions of this round.
                previous_candidates = [{'prediction_ids' : [] , 'score' : 0, 'mask_num' : 0}]
                for mask_index in mask_indexs:
                    logit_prob = F.log_softmax(outputs['logits'][0][mask_index], dim=0)
                    # Pick the beam_size + 1 best ids by repeatedly taking the
                    # argmax and knocking it out.
                    prob_result = []
                    while len(prob_result) < beam_size + 1:
                        predicted_id = torch.argmax(logit_prob).item()
                        score = logit_prob[predicted_id].item()
                        logit_prob[predicted_id] = -1000000000
                        prob_result.append((predicted_id, score))
                    prob_result = sorted(prob_result, key=lambda x: x[1], reverse=True)

                    current_candidates = []
                    for pre_candidate in previous_candidates:
                        # Only the best beam_size extensions per hypothesis are used.
                        for predicted_id, log_prob in prob_result[:beam_size]:
                            prediction_ids = pre_candidate['prediction_ids'] + [predicted_id]
                            current_candidates.append({
                                'prediction_ids' : prediction_ids,
                                'score' : pre_candidate['score'] + log_prob,
                                'mask_num' : pre_candidate['mask_num'] + 1,
                                'token': tokenizer.convert_ids_to_tokens(prediction_ids),
                            })

                    previous_candidates = sorted(current_candidates, key=lambda x: x['score'], reverse=True)[:beam_size]

                for pre_candidate in previous_candidates:
                    # A round contributes the mean log-probability over its masks.
                    score = all_candidate['score'] + (pre_candidate['score'] / pre_candidate['mask_num'])
                    iteration = all_candidate['iter'] + 1
                    iter_candidates.append({'prediction_ids' : pre_candidate['prediction_ids'], 'input_label_ids' : input_label_ids, 'score' : score, 'iter' : iteration})

            iter_result = []
            for iter_candidate in sorted(iter_candidates, key=lambda x: x['score'], reverse=True):

                # Finished when every mask of this round predicted the separator.
                if len(set(iter_candidate['prediction_ids'])) == 1 and iter_candidate['prediction_ids'][0] == EOF_id:
                    result.append(iter_candidate)
                else:
                    iter_result.append(iter_candidate)

                if len(iter_result) == beam_size:
                    break

            # `>=` because several candidates can finish in the same round and
            # overshoot beam_size; `==` would fall through to the slice below
            # with a negative bound.
            if len(result) >= beam_size:
                break

            all_candidates = iter_result

            # Length limit reached: fill the result with unfinished candidates.
            if len(all_candidates[0]['input_label_ids']) >= args.max_query_length * 2 - len(all_candidates[0]['prediction_ids']):
                if len(result) == 0:
                    error = True
                result += all_candidates[:beam_size-len(result)]
                break

    predictions = []
    for candidate in result:
        prediction_tokens = tokenizer.convert_ids_to_tokens(candidate['input_label_ids'])
        # The stored label sequence still contains the masks of its last round.
        prediction_text = tokenizer.convert_tokens_to_string(prediction_tokens).replace(mask_token,'')
        score = candidate['score'] / candidate['iter']
        predictions.append({'prediction_text' : prediction_text, 'score' : score, 'error' : error})

    return sorted(predictions, key=lambda x: x['score'], reverse=True)


def evaluate(args, model, tokenizer, beam_size=1):
    """Generate questions for every example in ``args.predict_file`` and save them.

    The input is a JSON list of examples. Each example provides ``context``
    and either ``answer`` (when the file path contains ``'race'``) or a list
    ``answers`` of ``{'text', 'answer_start'}`` dicts. When ``args.eval_type``
    is set, keywords are read from the example under that key (its first
    element if the key contains ``'gen_keywords'``).

    Two files are written, both named
    ``<output_dir><data_type>_beam_size_<beam_size>_<eval_type>``:

    * ``.json`` - the dataset with a ``gen_questions`` list added per example.
    * ``.txt``  - the top question of each example, one per line.

    The path is built by plain string concatenation, so ``args.output_dir``
    should end with a path separator.

    Args:
        args: namespace providing ``predict_file``, ``eval_type``,
            ``output_dir`` plus whatever the feature conversion and
            prediction steps read.
        model: model handed to ``predict``.
        tokenizer: tokenizer handed to ``convert_data_to_features`` and ``predict``.
        beam_size: beam width handed to ``predict``.

    Raises:
        Exception: any error raised while processing an example is logged
            together with the example index and re-raised.
    """
    start_time = timeit.default_timer()

    # Load data
    with open(args.predict_file, 'rb') as f:
        eval_dataset = json.load(f)

    num = 0
    error = []
    gen_question_lines = []
    for index, data in enumerate(tqdm(eval_dataset)):
        try:
            context = data['context']

            if args.eval_type is not None:
                if 'gen_keywords' in args.eval_type:
                    keywords = data[args.eval_type][0]
                else:
                    keywords = data[args.eval_type]
            else:
                keywords = []

            answer_text = ''
            result = []
            if 'race' in args.predict_file:
                answer_text = data['answer']
                features = convert_data_to_features(args, tokenizer, context, answer_text, 0, keywords)
                result += predict(args, model, tokenizer, features, beam_size)
            else:
                for answer in data['answers']:
                    # Skip an answer whose text repeats the previous one.
                    if answer_text == answer['text']:
                        continue
                    answer_text = answer['text']
                    answer_start = answer['answer_start']
                    features = convert_data_to_features(args, tokenizer, context, answer_text, answer_start, keywords)
                    result += predict(args, model, tokenizer, features, beam_size)

            # Predictions of several answers were concatenated: re-rank them.
            if len(result) > beam_size:
                result = sorted(result, key=lambda x: x['score'], reverse=True)

            gen_questions = [ele['prediction_text'] for ele in result]
            data['gen_questions'] = gen_questions

            if gen_questions:
                gen_question_lines.append(gen_questions[0])
                num += 1
                if result[0]['error']:
                    error.append(index)
            else:
                # Keep one line per example even when nothing was generated.
                gen_question_lines.append('')

        except Exception:
            # The original code re-raised here too (its trailing `continue`
            # was unreachable); record which example failed before aborting.
            logger.exception("Question generation failed for example %d", index)
            raise

    evalTime = timeit.default_timer() - start_time

    # Guard against an empty dataset / no generated questions.
    secs_per_example = evalTime / num if num else 0.0
    logger.info("Evaluation done %d in %f secs (%f sec per example)", num, evalTime, secs_per_example)
    logger.info("error_list: %s" % " ".join([str(x) for x in error]))

    if 'dev' in args.predict_file:
        data_type = 'dev'
    elif 'test' in args.predict_file:
        data_type = 'test'
    else:
        data_type = 'eval'

    # A missing output_dir writes into the current directory instead of
    # failing after all the generation work is done.
    output_dir = args.output_dir or ''
    output_file = output_dir+'{0}_beam_size_{1}_{2}'.format(str(data_type), str(beam_size), str(args.eval_type))

    with open(output_file + '.json', 'w') as json_file:
        json.dump(eval_dataset, json_file)
    with open(output_file + '.txt', 'w') as file:
        # NOTE(review): strip() also removes leading blank lines, so line numbers
        # can shift if the first examples produced no question.
        file.write('\n'.join(gen_question_lines).strip())

def main():
    """Command-line entry point.

    Parses the arguments, sets up the device, logging and random seed, loads
    the config, tokenizer and masked-LM model, then either

    * evaluates ``--predict_file`` via ``evaluate`` when it is given, or
    * runs an interactive loop that reads a context, an answer and
      comma-separated keywords from stdin and prints the predictions.
      The loop ends on end-of-input or Ctrl-C.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type bert",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )

    # Other parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="The input evaluation file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--eval_type", default=None, type=str, help="eval_type: None, noun_keywords, noun_verb_keywords, random_one_keywords, random_two_keywords, random_three_keywords"
    )
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument(
        "--beam_size", type=int, default=1, help="beam search size"
    )

    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    args = parser.parse_args()

    if args.doc_stride >= args.max_seq_length - args.max_query_length:
        logger.warning(
            "WARNING - You've set a doc stride which may be superior to the document length in some "
            "examples. This could result in errors when building features from the examples. Please reduce the doc "
            "stride or increase the maximum length to ensure the features are correctly built."
        )

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1

    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
        use_fast=False,  # SquadDataset is not compatible with Fast tokenizers which have a smarter overflow handling
    )
    model = AutoModelForMaskedLM.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    if args.predict_file is not None:
        evaluate(args, model=model, tokenizer=tokenizer, beam_size=args.beam_size)
    else:
        # Interactive mode: keep prompting until stdin is closed or the user interrupts.
        while True:
            try:
                context = input("context: ")
                answer_text = input("answer: ")
                keywords = input("keywords: ")
            except (EOFError, KeyboardInterrupt):
                break

            # Comma-separated keywords; surrounding blanks and empty entries are dropped.
            keywords_list = [k.strip() for k in keywords.split(',') if k.strip()]

            # Fall back to offset 0 when the answer does not occur in the context.
            answer_start = context.find(answer_text)
            if answer_start == -1:
                answer_start = 0

            features = convert_data_to_features(args, tokenizer, context, answer_text, answer_start, keywords_list)
            result = predict(args, model, tokenizer, features, args.beam_size)
            print(result)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()