# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""


import argparse
import glob
import logging
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from common.config import register_args, load_config_and_tokenizer

from transformers import (
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from data.custom_squad_feature import custom_squad_convert_examples_to_features, SquadResult, SquadProcessor

from data.qa_metrics import (compute_predictions_logits,hotpot_evaluate,)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

from run_qa import load_and_cache_examples, set_seed, to_list
from int_grad.ig_models import IGRobertaForQuestionAnswering
from int_grad.ig_qa_utils import compute_predictions_index_and_logits, stats_of_ig_interpretation
from vis_tools.vis_utils import visualize_attributions, merge_tokens_into_words

# Module-level logger; its format and level are configured in main() through logging.basicConfig.
logger = logging.getLogger(__name__)

# Keys of transformers' question-answering model mapping and the `model_type` attribute of each key.
# Neither constant is referenced elsewhere in the visible part of this file.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

from itertools import chain

def merge_predictions(dicts):
    """Flatten an iterable of prediction dicts into a single dict.

    Key/value pairs are inserted in iteration order, so when the same key
    appears in more than one dict the value from the last one wins.
    """
    every_pair = chain.from_iterable(partial.items() for partial in dicts)
    return dict(every_pair)


def predict_and_attribute(args, batch, model, tokenizer, batch_features, batch_examples):
    """Predict answers for one batch, then ask the model to attribute them.

    The model is called twice. The first call (inside ``torch.no_grad``)
    returns start logits, end logits and attentions. The logits are decoded
    with ``compute_predictions_index_and_logits``; the resulting start/end
    indexes are passed back to the model, together with the attentions and
    ``do_attribute=True``, to obtain the attributions.

    Args:
        args: namespace; reads ``device``, ``n_gpu``, ``model_type``,
            ``n_best_size``, ``max_answer_length``, ``do_lower_case``,
            ``dataset`` and ``ig_steps``.
        batch: sequence of tensors; items 0-3 are used as input ids,
            attention mask, token type ids and feature indices.
        model: callable model, switched to eval mode here.
        tokenizer: forwarded to the decoding helper.
        batch_features: features indexed by position within the batch.
        batch_examples: examples forwarded to the decoding helper.

    Returns:
        ``(batch_predictions, batch_prelim_results, batch_attentions,
        batch_attributions)``; ``batch_attentions`` is the stacked tensor of
        the attentions returned by the first call.
    """
    model.eval()
    tensors = tuple(t.to(args.device) for t in batch)
    is_parallel = args.n_gpu > 1

    # These model types are called without segment ids.
    uses_segment_ids = args.model_type not in ["roberta", "distilbert", "camembert", "bart"]

    # Keyword arguments common to the prediction pass and the attribution pass.
    shared_inputs = {"input_ids": tensors[0], "attention_mask": tensors[1]}
    if uses_segment_ids:
        shared_inputs["token_type_ids"] = tensors[2]
    feature_indices = tensors[3]

    # Pass 1: plain prediction, no autograd bookkeeping needed.
    with torch.no_grad():
        forward_out = model(output_attentions=True, **shared_inputs)

    batch_start_logits, batch_end_logits, batch_attentions = forward_out
    logits_only = forward_out[:-1]  # everything except the attentions

    batch_results = []
    for row, _ in enumerate(feature_indices):
        unique_id = int(batch_features[row].unique_id)
        start_logits, end_logits = [to_list(per_batch[row]) for per_batch in logits_only]
        batch_results.append(SquadResult(unique_id, start_logits, end_logits))

    batch_prelim_results, batch_predictions = compute_predictions_index_and_logits(
        batch_examples,
        batch_features,
        batch_results,
        args.n_best_size,
        args.max_answer_length,
        args.do_lower_case,
        tokenizer,
        args.dataset,
    )

    # Pass 2: attribution with respect to the predicted span boundaries.
    predicted_starts = torch.LongTensor([p.start_index for p in batch_prelim_results]).to(args.device)
    predicted_ends = torch.LongTensor([p.end_index for p in batch_prelim_results]).to(args.device)
    batch_attentions = torch.stack(batch_attentions)

    # In the multi-GPU case the first two axes of the stacked attentions are
    # swapped before the call and the result is swapped back afterwards
    # (presumably so that the batch axis comes first for DataParallel - TODO confirm).
    if is_parallel:
        attentions_for_model = torch.transpose(batch_attentions, 0, 1)
    else:
        attentions_for_model = batch_attentions

    attribution_inputs = dict(shared_inputs)
    attribution_inputs.update(
        input_attentions=attentions_for_model,
        start_indexes=predicted_starts,
        end_indexes=predicted_ends,
        final_start_logits=batch_start_logits,
        final_end_logits=batch_end_logits,
        num_steps=args.ig_steps,
        do_attribute=True,
        is_parallel=is_parallel,
    )

    batch_attributions = model(**attribution_inputs)
    if is_parallel:
        batch_attributions = torch.transpose(batch_attributions, 0, 1)

    return batch_predictions, batch_prelim_results, batch_attentions, batch_attributions


def dump_ig_info(args, examples, features, tokenizer, predictions, prelim_results, attentions, attributions):
    """Save one interpretation file per example of the batch.

    ``attentions`` and ``attributions`` are detached, moved to CPU and their
    first two axes are swapped so that unbinding along axis 0 yields one
    tensor per example (the original comment gives the incoming layout as
    N_Layer * B * N_HEAD * L * L). Each per-example tensor is cropped on its
    last two axes to ``len(feature.tokens)``.

    The file is written with ``torch.save`` to
    ``<args.interp_dir>/<feature.example_index>-<feature.qas_id>.bin`` and
    holds a dict with the keys ``example``, ``feature``, ``prediction``,
    ``prelim_result``, ``attention`` and ``attribution``.

    Args:
        args: namespace; only ``interp_dir`` is read.
        examples: objects with a ``qas_id`` used to look up ``predictions``.
        features: objects with ``tokens``, ``example_index`` and ``qas_id``.
        tokenizer: unused.
        predictions: mapping from ``qas_id`` to the prediction to store.
        prelim_results: namedtuple-like objects (``_asdict()`` is called).
        attentions: tensor, batch on axis 1.
        attributions: tensor, batch on axis 1.
    """

    def _split_by_example(stacked):
        # Off the graph, onto the CPU, batch axis first, then one tensor per example.
        on_cpu = stacked.detach().cpu().requires_grad_(False)
        return torch.unbind(torch.transpose(on_cpu, 0, 1))

    def _crop(square, size):
        # clone() gives the crop its own storage instead of a view of the padded tensor.
        return square[:, :, :size, :size].clone().detach()

    rows = zip(
        examples,
        features,
        prelim_results,
        _split_by_example(attentions),
        _split_by_example(attributions),
    )
    for example, feature, prelim_result, attention, attribution in rows:
        n_tokens = len(feature.tokens)
        cropped_attention = _crop(attention, n_tokens)
        cropped_attribution = _crop(attribution, n_tokens)
        target = os.path.join(args.interp_dir, f'{feature.example_index}-{feature.qas_id}.bin')
        record = {
            'example': example,
            'feature': feature,
            'prediction': predictions[example.qas_id],
            'prelim_result': prelim_result._asdict(),
            'attention': cropped_attention,
            'attribution': cropped_attribution,
        }
        torch.save(record, target)

def ig_interp(args, model, tokenizer, prefix=""):
    """Run prediction + attribution over the evaluation set and score it.

    For every batch, ``predict_and_attribute`` is called and its output is
    written to ``args.interp_dir`` immediately by ``dump_ig_info``; only the
    predictions are kept in memory. After the loop the per-batch prediction
    dicts are merged and passed to ``hotpot_evaluate``.

    Args:
        args: namespace; reads ``interp_dir``, ``output_dir``, ``local_rank``,
            ``per_gpu_eval_batch_size`` and ``n_gpu`` here, and sets
            ``eval_batch_size``. It is also forwarded to the helpers.
        model: model to interpret; its parameters are frozen in place and it
            is wrapped in ``torch.nn.DataParallel`` when ``args.n_gpu > 1``.
        tokenizer: forwarded to data loading and to the per-batch helpers.
        prefix: label used only in the log message.

    Returns:
        Whatever ``hotpot_evaluate`` returns for the merged predictions.

    Raises:
        AssertionError: if the number of examples and features differ.
    """
    # exist_ok=True instead of "check, then create": the directory may be
    # created by another process between the check and the call.
    os.makedirs(args.interp_dir, exist_ok=True)

    # fix the model: no parameter gradients are needed
    model.requires_grad_(False)

    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    # The batch loop indexes `examples` with feature indices, which is only
    # valid under a one-to-one example/feature mapping.
    assert len(examples) == len(features)

    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)

    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation %s *****", prefix)
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_predictions = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        feature_indices = to_list(batch[3])
        batch_features = [features[i] for i in feature_indices]
        batch_examples = [examples[i] for i in feature_indices]

        batch_predictions, batch_prelim_results, batch_attentions, batch_attributions = predict_and_attribute(
            args,
            batch,
            model,
            tokenizer,
            batch_features,
            batch_examples
        )

        # lots of info, dump to files immediately
        dump_ig_info(
            args,
            batch_examples,
            batch_features,
            tokenizer,
            batch_predictions,
            batch_prelim_results,
            batch_attentions,
            batch_attributions,
        )
        all_predictions.append(batch_predictions)

    eval_time = timeit.default_timer() - start_time
    # max(1, ...) keeps the per-example figure defined for an empty dataset.
    logger.info(
        "  Evaluation done in total %f secs (%f sec per example)",
        eval_time,
        eval_time / max(1, len(dataset)),
    )

    # Compute the F1 and exact scores on the examples that received a prediction.
    all_predictions = merge_predictions(all_predictions)
    results = hotpot_evaluate(examples[:len(all_predictions)], all_predictions)
    return results


def ig_analyze(args, tokenizer):
    """Visualize every interpretation file stored in ``args.interp_dir``.

    Files are visited in ascending order of the integer that precedes the
    first ``-`` in their name; a file name without such an integer prefix
    makes ``int()`` raise ``ValueError``. Each file is read with
    ``torch.load`` and handed to ``visualize_attributions``.

    NOTE: ``torch.load`` unpickles the file, so ``args.interp_dir`` should
    only contain files produced by a trusted run of this tool.
    """

    def _leading_index(fname):
        return int(fname.split('-')[0])

    ordered_names = sorted(os.listdir(args.interp_dir), key=_leading_index)
    for fname in tqdm(ordered_names, desc='Visualizing'):
        interp_path = os.path.join(args.interp_dir, fname)
        visualize_attributions(args, tokenizer, torch.load(interp_path))

def main():
    """Command-line entry point: visualize stored interpretations or create them.

    Arguments registered by ``register_args`` are extended with the options
    defined below. With ``--do_vis`` the files in ``--interp_dir`` are
    visualized and ``None`` is returned. Otherwise, when ``--do_eval`` is set
    and this is the main process (``local_rank`` in ``[-1, 0]``), the
    checkpoint at ``args.model_name_or_path`` is loaded and interpreted.

    Returns:
        The evaluation result from ``ig_interp``, or ``None`` when no
        evaluation was run.
    """
    parser = argparse.ArgumentParser()
    register_args(parser)

    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_vis", action="store_true", help="Whether to run vis on the dev set.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--ig_steps", type=int, default=50, help="steps for running integrated gradient")

    parser.add_argument(
        "--interp_dir",
        default=None,
        type=str,
        required=True,
        help="The directory where interpretation files are written (--do_eval) and read (--do_vis).",
    )
    parser.add_argument("--visual_dir", default=None, type=str, help="The output visualization dir.")
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
    )

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        # NOTE(review): no matching barrier for rank 0 is visible in this function;
        # confirm that a callee releases the waiting ranks, otherwise distributed runs hang here.
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    _config, tokenizer = load_config_and_tokenizer(args)

    if args.do_vis:
        ig_analyze(args, tokenizer)
        return None

    # Stays None when evaluation is skipped (no --do_eval, or a non-main
    # process), so the log line and the return below never touch an unbound name.
    result = None
    if args.do_eval and args.local_rank in [-1, 0]:
        logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
        checkpoint = args.model_name_or_path
        logger.info("Evaluate the following checkpoints: %s", checkpoint)

        # Reload the model
        model = IGRobertaForQuestionAnswering.from_pretrained(checkpoint)  # , force_download=True)
        model.to(args.device)

        # Evaluate
        result = ig_interp(args, model, tokenizer, prefix="")

    logger.info("Results: %s", result)

    return result


# Run the command-line entry point only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()