import os
import json
import torch
import random
import string
import argparse

import numpy as np

from Bio import pairwise2
from functools import reduce
from transformers import (MBart50TokenizerFast,
                          MBartForConditionalGeneration)

from pingouin import ttest

from eval_util import read_jsonl, get_final_index, compute_perplexity, plot_saliency_map

# Maps the two-letter language codes given via --src_lang / --tgt_lang to MBart-50 language tags
MBART_TABLE = dict(en='en_XX',
                   de='de_DE',
                   fr='fr_XX',
                   ru='ru_RU')


def get_probabilities(_src_inputs, _tgt_inputs, no_grad):
    """ Runs one teacher-forced forward pass and returns the softmax-normalised token-wise probabilities.

    Uses the module-level globals `model` and `device`. Gradients of all model parameters are reset first, so a
    subsequent backward pass (when `no_grad` is False) only accumulates gradients from this call.

    Args:
        _src_inputs: tokenizer output exposing `input_ids` and `attention_mask` tensors.
        _tgt_inputs: target token-ID tensor, passed to the model as `labels`.
        no_grad: if True, the forward pass runs under `torch.no_grad()`.

    Returns:
        The squeezed logits after a softmax over the last (vocabulary) dimension.
    """
    model.zero_grad()

    def _forward():
        return model(input_ids=_src_inputs.input_ids.to(device),
                     attention_mask=_src_inputs.attention_mask.to(device),
                     labels=_tgt_inputs.to(device))

    if no_grad:
        with torch.no_grad():
            outputs = _forward()
    else:
        outputs = _forward()

    logits = outputs.logits.squeeze()
    return logits.softmax(dim=-1)


def get_final_sublist_ids(sub, full):
    """ Finds the positions covered by the last occurrence of a sublist within a larger list.

    Args:
        sub: the sequence of items to look for (e.g. the token IDs of a word split into several pieces).
        full: the sequence to search (e.g. the token IDs of a whole sentence).

    Returns:
        The consecutive indices of `full` occupied by the final occurrence of `sub`, or an empty list if `sub`
        is empty or does not occur in `full`.
    """
    sub_list = list(sub)
    sub_length = len(sub_list)
    if sub_length == 0:
        # An empty pattern has no meaningful location (the old code raised IndexError on sub[0])
        return list()
    # Scan candidate start positions right-to-left so the first hit is the final occurrence
    for start in range(len(full) - sub_length, -1, -1):
        if list(full[start: start + sub_length]) == sub_list:
            return list(range(start, start + sub_length))
    return list()


# TODO / NOTE: Not adjusted for EN-RU
def _get_gradient_norm_scores(src_sent, tgt_sent, target_token, norm_p=1):
    """ Computes gradient-norm saliency scores of source tokens with respect to the target pronoun.

    The probability of the target pronoun (at its final position in the target sequence) is back-propagated, and
    the `norm_p`-norm of the encoder token-embedding gradient row of every source token is used as its score.
    The default `norm_p=1` (L1 norm) is what this analysis has always computed; earlier comments wrongly said L2.

    Relies on the module-level globals `tokenizer`, `model` and `device`.

    Args:
        src_sent: source sentence containing the ambiguous pronoun ('it' / 'It').
        tgt_sent: the correct translation.
        target_token: the target pronoun, looked up as the single token '▁<target_token>'.
        norm_p: order of the gradient norm (1 = L1, 2 = L2, ...).

    Returns:
        (saliency_scores, src_idx): saliency_scores holds one [source token, 0-d CPU tensor score] pair per source
        position (tensors, like the prediction_diff scores, so callers may call `.cpu()` on them); src_idx is the
        position of the final '▁it' (or, failing that, '▁It') in the source token IDs.

    NOTE(review): gradients are read from the embedding *matrix*, so repeated source tokens share one score; if the
    encoder embeddings are tied to the decoder / output embeddings (usual for MBart — confirm), the gradient rows
    also contain decoder-side contributions.
    """
    # Preprocess sequence
    src_inputs = tokenizer(src_sent, return_tensors='pt')
    src_sent_ids = torch.squeeze(src_inputs.input_ids).detach().cpu().numpy().tolist()
    with tokenizer.as_target_tokenizer():
        tgt_inputs = tokenizer(tgt_sent, return_tensors='pt').input_ids
        tgt_sent_ids = torch.squeeze(tgt_inputs).detach().cpu().numpy().tolist()

    # Identify the source pronoun position; fall back to the capitalised form, as _get_prediction_diff_scores does
    try:
        src_idx = get_final_index(src_sent_ids, tokenizer.convert_tokens_to_ids('▁it'))
    except AssertionError:
        src_idx = get_final_index(src_sent_ids, tokenizer.convert_tokens_to_ids('▁It'))

    # Identify target pronoun ID and its position in target sequence
    with tokenizer.as_target_tokenizer():
        tgt_dict_id = tokenizer.convert_tokens_to_ids('▁{:s}'.format(target_token))
        tgt_idx = get_final_index(tgt_sent_ids, tgt_dict_id)

    # Forward pass with gradients enabled; back-propagate the pronoun's probability only
    tgt_prob_matrix = get_probabilities(src_inputs, tgt_inputs, no_grad=False)
    position_prob = torch.squeeze(tgt_prob_matrix.index_select(dim=0, index=torch.tensor(tgt_idx).to(device)))
    prn_prob = position_prob.index_select(dim=0, index=tgt_inputs[0, tgt_idx].to(position_prob.device))
    prn_prob.backward()
    embedding_grads = model.model.encoder.embed_tokens.weight.grad

    # Score every source position by the norm of its token's embedding-gradient row
    saliency_scores = list()
    for src_tok_id in src_sent_ids:
        grad_norm = torch.norm(embedding_grads[src_tok_id, :], p=norm_p, dim=0).detach().cpu()
        saliency_scores.append([tokenizer.convert_ids_to_tokens(src_tok_id), grad_norm])
    return saliency_scores, src_idx


def _get_prediction_diff_scores(src_sent, true_tgt_sent, false_tgt_sent, target_token):
    """ Iteratively masks source tokens and evaluates the probability of the target token of interest conditioned on
    the masked source.

    Relies on the module-level globals `tokenizer`, `mask_index`, and (via get_probabilities) `model` / `device`.

    Args:
        src_sent: source sentence containing the ambiguous pronoun ('it' / 'It').
        true_tgt_sent: the correct translation.
        false_tgt_sent: the contrastive (incorrect) translation.
        target_token: the target-side pronoun whose probability is tracked in the correct translation.

    Returns:
        model_is_correct: True if the correct translation has a lower perplexity (per compute_perplexity) than the
            incorrect one.
        saliency_scores: one [source token, score] pair per source position. The score is the probability of the
            target pronoun (product over its pieces) given the full source, minus that probability when the
            position is replaced by `mask_index`. The reference probability is a Python float but the masked one is
            taken from an unconverted tensor, so each score is a 0-d torch tensor on `device`.
        src_idx: position of the final '▁it' (or, failing that, '▁It') in the source token IDs.
    """

    # Preprocess sequence
    src_inputs = tokenizer(src_sent, return_tensors='pt')
    src_sent_ids = torch.squeeze(src_inputs.input_ids).detach().cpu().numpy().tolist()
    # NOTE(review): src_sent_bpe, true_tgt_sent_bpe and false_tgt_sent_bpe are computed but never used
    src_sent_bpe = tokenizer.decode(src_sent_ids).split()
    with tokenizer.as_target_tokenizer():
        true_tgt_inputs = tokenizer(true_tgt_sent, return_tensors='pt').input_ids
        true_tgt_sent_ids = torch.squeeze(true_tgt_inputs).detach().cpu().numpy().tolist()
        true_tgt_sent_bpe = tokenizer.decode(true_tgt_sent_ids).split()
        false_tgt_inputs = tokenizer(false_tgt_sent, return_tensors='pt').input_ids
        false_tgt_sent_ids = torch.squeeze(false_tgt_inputs).detach().cpu().numpy().tolist()
        false_tgt_sent_bpe = tokenizer.decode(false_tgt_sent_ids).split()

        # BART50, annoyingly, splits Russian pronouns into characters, despite their high frequency
        # Drop the first and last IDs (presumably the language-tag prefix and </s> suffix — confirm for the tokenizer)
        target_token_inputs = tokenizer(target_token).input_ids[1: -1]
        true_token_locations = None
        if len(target_token_inputs) > 1:
            # Positions of the final occurrence of the pronoun pieces; an empty list if they are not found, which
            # makes the segment lookups below raise IndexError
            true_token_locations = get_final_sublist_ids(target_token_inputs, true_tgt_sent_ids)

    # Identify source pronoun ID and its position in source sequence
    # (get_final_index is assumed to raise AssertionError when the ID is absent; retry with the capitalised form)
    try:
        src_dict_id = tokenizer.convert_tokens_to_ids('▁it')
        src_idx = get_final_index(src_sent_ids, src_dict_id)
    except AssertionError:
        src_dict_id = tokenizer.convert_tokens_to_ids('▁It')
        src_idx = get_final_index(src_sent_ids, src_dict_id)

    # Perform a full forward pass for the false target
    # Probability of each gold token, assuming row i of the matrix is the distribution for target token i
    # (the model receives the targets as `labels` — confirm the internal shift)
    false_tgt_prob_matrix = \
        get_probabilities(src_inputs, false_tgt_inputs, no_grad=True).detach().cpu().numpy().tolist()
    false_model_probabilities = \
        [false_tgt_prob_matrix[probs_row][tok_id] for probs_row, tok_id in enumerate(false_tgt_sent_ids)]
    false_ppl = compute_perplexity(false_model_probabilities)

    # Perform a full forward pass for the true target
    # Keep an unmasked copy of the source IDs to restore before each masking step
    full_src = src_inputs.input_ids.clone()
    true_tgt_prob_matrix = \
        get_probabilities(src_inputs, true_tgt_inputs, no_grad=True).detach().cpu().numpy().tolist()
    true_model_probabilities = \
        [true_tgt_prob_matrix[probs_row][tok_id] for probs_row, tok_id in enumerate(true_tgt_sent_ids)]
    true_ppl = compute_perplexity(true_model_probabilities)

    # Iterate over target token segments
    # Reference pronoun probability (unmasked source): product over the pronoun's pieces, as a Python float
    ref_tok_probs = list()
    for segment_id, target_token_segment in enumerate(target_token_inputs):
        # Identify target pronoun ID and its position in target sequence
        if len(target_token_inputs) == 1:
            with tokenizer.as_target_tokenizer():
                tgt_dict_id = tokenizer.convert_tokens_to_ids('▁{:s}'.format(target_token))
                tgt_idx = get_final_index(true_tgt_sent_ids, tgt_dict_id)
        else:
            tgt_dict_id = target_token_segment
            tgt_idx = true_token_locations[segment_id]
        ref_tok_probs.append(true_model_probabilities[tgt_idx])
    ref_tok_prob = reduce(lambda x, y: x * y, ref_tok_probs, 1)

    model_is_correct = bool(true_ppl < false_ppl)
    # Compute token-wise saliency scores
    saliency_scores = list()
    for tok_id, src_tok_idx in enumerate(src_sent_ids):
        # Mask-out source token
        src_inputs.input_ids = full_src.clone()
        src_inputs.input_ids[0][tok_id] = mask_index
        # Perform a forward-pass
        # NOTE: unlike the reference passes above, this matrix is NOT converted to Python lists, so the
        # probabilities below are 0-d tensors on `device`
        true_tgt_prob_matrix = get_probabilities(src_inputs, true_tgt_inputs, no_grad=True)
        true_model_probabilities = \
            [true_tgt_prob_matrix[probs_row][tok_id] for probs_row, tok_id in enumerate(true_tgt_sent_ids)]

        tok_probs = list()
        for segment_id, target_token_segment in enumerate(target_token_inputs):
            # Identify target pronoun ID and its position in target sequence
            # (single-piece case re-uses tgt_dict_id from the reference loop above)
            if len(target_token_inputs) == 1:
                with tokenizer.as_target_tokenizer():
                    tgt_idx = get_final_index(true_tgt_sent_ids, tgt_dict_id)
            else:
                tgt_idx = true_token_locations[segment_id]

            tok_probs.append(true_model_probabilities[tgt_idx])
        tok_prob = reduce(lambda x, y: x * y, tok_probs, 1)

        # Store results
        # Positive score: masking this source position lowered the pronoun probability
        saliency_scores.append([tokenizer.convert_ids_to_tokens(src_tok_idx), ref_tok_prob - tok_prob])

    return model_is_correct, saliency_scores, src_idx


def _pair_samples(samples):
    """ Groups samples into contrastive pairs: {pair ID: {sample ID: sample}}, both parsed from 'qID'. """
    sample_pairs = dict()
    for sample in samples:
        qid_fields = sample['qID'].split('-')
        sample_pairs.setdefault(qid_fields[-2], dict())[int(qid_fields[-1])] = sample
    return sample_pairs


def _merge_subword_scores(token_scores):
    """ Merges sub-word saliency scores into word-level [word, mean score] pairs.

    A token starting with the word-boundary marker '▁' opens a new word; any other token is appended to the
    current word. The source-language tag 'en_XX' restarts the current word's score list. Scores may be Python
    floats or 0-d tensors on any device; they are converted with float().
    """
    word_scores = list()
    word, word_token_scores = '', list()
    for token, score in token_scores:
        score = float(score)
        if token.startswith('▁'):
            if len(word) > 0:
                word_scores.append([word, np.mean(word_token_scores)])
            word, word_token_scores = token, [score]
        else:
            word = word + token
            if token == 'en_XX':
                word_token_scores = [score]
            else:
                word_token_scores.append(score)
    # Flush the final word with its mean score, like every other word (it used to keep the raw score list)
    if len(word) > 0:
        word_scores.append([word, np.mean(word_token_scores)])
    return word_scores


def _build_ngram_scores(word_scores, refs):
    """ For each multi-word referent in refs, lists all word n-grams of that length with their mean score. """
    ngram_set = list()
    for ngram_len in [len(ref.split()) for ref in refs if len(ref.split()) > 1]:
        ngrams = list()
        for start in range(len(word_scores) - ngram_len + 1):
            window = word_scores[start: start + ngram_len]
            ngrams.append([' '.join(word for word, _ in window), np.mean([score for _, score in window])])
        ngram_set.append(ngrams)
    return ngram_set


def _find_referent_score(word_scores, ngram_set, ref):
    """ Returns the score of the last word (else n-gram) matching the referent, or None if it cannot be found. """
    # LM-specific adjustment: prefix every referent word with the word-boundary marker
    target = ' '.join(['▁{:s}'.format(seg) for seg in ref.split() if len(seg.strip()) > 0]).lower()

    def _matches(candidate):
        return candidate.lower().strip(string.punctuation) == target

    word_matches = [score for word, score in word_scores if _matches(word)]
    if word_matches:
        return word_matches[-1]
    found = None
    for ngrams in ngram_set:
        ngram_matches = [score for ngram, score in ngrams if _matches(ngram)]
        if ngram_matches:
            found = ngram_matches[-1]  # later n-gram lengths override earlier ones
    return found


def _score_sample(sample, qid, sid, saliency_method, pd_saliency_table):
    """ Computes the saliency entry (token, word and referent scores) of a single contrastive sample. """
    answer_is_first = sample['answer'] == 1
    src = sample['sentence']
    true_tgt = sample['translation1'] if answer_is_first else sample['translation2']
    false_tgt = sample['translation2'] if answer_is_first else sample['translation1']
    true_ref = sample['referent1_en'] if answer_is_first else sample['referent2_en']
    false_ref = sample['referent2_en'] if answer_is_first else sample['referent1_en']
    tgt_pron = sample['pronoun1'] if answer_is_first else sample['pronoun2']

    if saliency_method == 'prediction_diff':
        model_is_correct, saliency_scores, src_prn_id = \
            _get_prediction_diff_scores(src, true_tgt, false_tgt, tgt_pron)
    else:
        # gradient_norm re-uses the model correctness determined by a previous prediction_diff run
        model_is_correct = pd_saliency_table[qid][str(sid)]['model_is_correct']
        saliency_scores, src_prn_id = _get_gradient_norm_scores(src, true_tgt, tgt_pron)

    word_scores = _merge_subword_scores(saliency_scores)
    ngram_set = _build_ngram_scores(word_scores, [true_ref, false_ref])
    return {'model_is_correct': model_is_correct,
            'saliency_scores': saliency_scores,
            'src_prn_id': src_prn_id,
            'word_saliency_scores': word_scores,
            'true_ref_score': _find_referent_score(word_scores, ngram_set, true_ref),
            'false_ref_score': _find_referent_score(word_scores, ngram_set, false_ref)}


def _split_trigger_and_shared(entry_a, entry_b, gap='<GAP>'):
    """ Splits the token scores of two contrastive samples into trigger and shared token scores.

    The source token sequences are globally aligned; aligned identical tokens are 'shared', every other non-gap
    token (gap-aligned or mismatched) is a 'trigger' of its own sample. Each sample's ambiguous pronoun
    (at its own 'src_prn_id') is ignored.
    """
    for entry in (entry_a, entry_b):
        entry['trigger_saliency_scores'] = list()
        entry['shared_saliency_scores'] = list()
    scores_a, scores_b = entry_a['saliency_scores'], entry_b['saliency_scores']
    alignment = pairwise2.align.globalxx([tpl[0] for tpl in scores_a],
                                         [tpl[0] for tpl in scores_b],
                                         gap_char=[gap])[0]
    # Running indices into the un-aligned token sequences, advanced only on non-gap columns
    pos_a, pos_b = -1, -1
    for tok_a, tok_b in zip(alignment.seqA, alignment.seqB):
        keep_a = keep_b = False
        if tok_a != gap:
            pos_a += 1
            keep_a = pos_a != entry_a['src_prn_id']
        if tok_b != gap:
            pos_b += 1
            keep_b = pos_b != entry_b['src_prn_id']
        if tok_a != gap and tok_a == tok_b:  # shared tokens
            if keep_a and keep_b:
                entry_a['shared_saliency_scores'].append(scores_a[pos_a])
                entry_b['shared_saliency_scores'].append(scores_b[pos_b])
        else:  # trigger tokens; a mismatched column is a trigger for both samples
            if keep_a:
                entry_a['trigger_saliency_scores'].append(scores_a[pos_a])
            if keep_b:
                entry_b['trigger_saliency_scores'].append(scores_b[pos_b])


def _print_mean_std(template, scores):
    """ Prints the mean and standard deviation of scores through a two-slot format template. """
    print(template.format(np.mean(scores), np.std(scores)))


def _report_referent_saliency(saliency_table):
    """ 1. Which referent is more salient when the model is correct / incorrect? """
    true_correct, false_correct, true_incorrect, false_incorrect = list(), list(), list(), list()
    num_correct, num_incorrect, num_missing = 0, 0, 0
    for pair in saliency_table.values():
        for entry in pair.values():
            is_correct = entry['model_is_correct']
            if is_correct:
                num_correct += 1
            else:
                num_incorrect += 1
            true_score, false_score = entry['true_ref_score'], entry['false_ref_score']
            if true_score is None or false_score is None:
                num_missing += 1  # referent not located among the source words / n-grams
                continue
            if is_correct:
                true_correct.append(true_score)
                false_correct.append(false_score)
            else:
                true_incorrect.append(true_score)
                false_incorrect.append(false_score)
    true_all = true_correct + true_incorrect
    false_all = false_correct + false_incorrect

    print('-' * 20)
    print('Resolved {:d} samples correctly'.format(num_correct))
    print('Resolved {:d} samples incorrectly'.format(num_incorrect))
    if num_missing > 0:
        print('WARNING: Referents could not be located in {:d} samples; '
              'they are excluded from the referent scores'.format(num_missing))

    print('-' * 20)
    _print_mean_std('Mean (std.) saliency of the [CORRECT] referent in [CORRECTLY scored] samples: '
                    '{:.4f} ({:.4f})', true_correct)
    _print_mean_std('Mean (std.) saliency of the [INCORRECT] referent in [CORRECTLY scored] samples: '
                    '{:.4f} ({:.4f})', false_correct)
    _print_mean_std('Mean (std.) saliency of the [CORRECT] referent in [INCORRECTLY scored] samples: '
                    '{:.4f} ({:.4f})', true_incorrect)
    _print_mean_std('Mean (std.) saliency of the [INCORRECT] referent in [INCORRECTLY scored] samples: '
                    '{:.4f} ({:.4f})', false_incorrect)
    _print_mean_std('Mean (std.) saliency of the [CORRECT] referent in [ALL] samples: {:.4f} ({:.4f})', true_all)
    _print_mean_std('Mean (std.) saliency of the [INCORRECT] referent in [ALL] samples: {:.4f} ({:.4f})', false_all)

    print('TTest referents:')
    print(ttest(true_all, false_all))


def _report_trigger_saliency(saliency_table):
    """ 2. Are triggers more salient when the model is correct? """
    trigger_correct, shared_correct, trigger_incorrect, shared_incorrect = list(), list(), list(), list()
    for pair in saliency_table.values():
        for entry in pair.values():
            # float() accepts Python floats and 0-d tensors on any device alike
            triggers = [float(tpl[1]) for tpl in entry.get('trigger_saliency_scores', list())]
            shared = [float(tpl[1]) for tpl in entry.get('shared_saliency_scores', list())]
            if entry['model_is_correct']:
                trigger_correct += triggers
                shared_correct += shared
            else:
                trigger_incorrect += triggers
                shared_incorrect += shared
    trigger_all = trigger_correct + trigger_incorrect
    shared_all = shared_correct + shared_incorrect

    print('-' * 20)
    _print_mean_std('Average (std.) saliency of [TRIGGERS] in [CORRECTLY scored] samples: {:.4f} ({:.4f})',
                    trigger_correct)
    _print_mean_std('Average (std.) saliency of [SHARED TOKENS] in [CORRECTLY scored] samples: {:.4f} ({:.4f})',
                    shared_correct)
    _print_mean_std('Average (std.) saliency of [TRIGGERS] in [INCORRECTLY scored] samples: {:.4f} ({:.4f})',
                    trigger_incorrect)
    _print_mean_std('Average (std.) saliency of [SHARED TOKENS] in [INCORRECTLY scored] samples: {:.4f} ({:.4f})',
                    shared_incorrect)
    _print_mean_std('Average (std.) saliency of [TRIGGERS] in [ALL] samples: {:.4f} ({:.4f})', trigger_all)
    _print_mean_std('Average (std.) saliency of [SHARED TOKENS] in [ALL] samples: {:.4f} ({:.4f})', shared_all)

    print('TTest triggers:')
    print(ttest(trigger_all, shared_all))


def _plot_random_samples(saliency_table, out_dir, saliency_method, num_pairs=10, seed=42):
    """ Saves saliency plots for `num_pairs` randomly drawn (seeded) contrastive pairs. """
    print('-' * 20)
    random.seed(seed)
    qids = list(saliency_table.keys())
    random.shuffle(qids)
    plot_dir = '{:s}/plots_{:s}'.format(out_dir, saliency_method)
    os.makedirs(plot_dir, exist_ok=True)
    for qid in qids[:num_pairs]:
        for sid, entry in saliency_table[qid].items():
            plt_path = '{:s}/saliency_plot_{:s}-{:d}_{}.png'.format(plot_dir, qid, sid, entry['model_is_correct'])
            plot_saliency_map(entry['saliency_scores'], '▁', plt_path)
    print('Saved all plots to {:s}'.format(plot_dir))


def compute_saliency_scores(json_challenge_path, out_dir, saliency_method, pd_saliency_table_path):
    """ Identifies source tokens that are salient to the pronoun choice on the target side.

    Scores every sample with `saliency_method`, then reports (1) the saliency of the correct vs. incorrect
    referent and (2) the saliency of trigger vs. shared tokens of contrastive pairs, and plots 10 random pairs.

    Args:
        json_challenge_path: JSONL file with the contrastive samples ('pronoun1' / 'pronoun2' are target pronouns).
        out_dir: output directory; plots go to '<out_dir>/plots_<saliency_method>'.
        saliency_method: 'prediction_diff' or 'gradient_norm'.
        pd_saliency_table_path: JSON results of a 'prediction_diff' run (needed for 'gradient_norm').

    Raises:
        ValueError: if a non-'prediction_diff' method is requested without `pd_saliency_table_path`.
    """
    if saliency_method != 'prediction_diff' and pd_saliency_table_path is None:
        raise ValueError('\'prediction_diff\' results (pd_saliency_table_path) are required for '
                         '\'{}\' evaluation'.format(saliency_method))

    # Read-in samples and pair contrastive samples
    sample_pairs = _pair_samples(read_jsonl(json_challenge_path))

    # Read in the prediction_diff evaluation results, if provided
    pd_saliency_table = None
    if pd_saliency_table_path is not None:
        with open(pd_saliency_table_path, 'r', encoding='utf8') as pdp:
            pd_saliency_table = json.load(pdp)

    # Iterate over contrastive pairs
    saliency_table = dict()
    for qid_id, qid in enumerate(sample_pairs.keys()):
        print('Checking pair {:d}'.format(qid_id))
        if qid_id > 0 and (qid_id + 1) % 100 == 0:
            print('Analysed {:d} contrastive pairs'.format(qid_id + 1))

        saliency_table[qid] = {sid: _score_sample(sample, qid, sid, saliency_method, pd_saliency_table)
                               for sid, sample in sample_pairs[qid].items()}
        if 1 in saliency_table[qid] and 2 in saliency_table[qid]:
            _split_trigger_and_shared(saliency_table[qid][1], saliency_table[qid][2])
        else:
            print('Pair {:s} is incomplete; skipping trigger analysis'.format(qid))

    _report_referent_saliency(saliency_table)
    _report_trigger_saliency(saliency_table)
    _plot_random_samples(saliency_table, out_dir, saliency_method)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--json_file_path', type=str, required=True,
                        help='path to the JSON file containing the contrastive samples')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='path to the output directory')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='path to the directory containing checkpoint files of the evaluated model')
    parser.add_argument('--model_type', type=str, choices=['bart', 'mbart50', 't5'],
                        help='Model type to evaluate')
    parser.add_argument('--use_multi', action='store_true',
                        help='Whether to use a multilingual language model; if disabled, uses a monolingual target '
                             'language model')
    parser.add_argument('--use_cpu', action='store_true',
                        help='whether to use the CPU for model passes')
    parser.add_argument('--src_lang', type=str,
                        help='language code corresponding to the source language')
    parser.add_argument('--tgt_lang', type=str,
                        help='language code corresponding to the target language')
    parser.add_argument('--saliency_method', type=str, choices=['prediction_diff', 'gradient_norm'],
                        help='saliency method to be used for the performed analysis; NOTE: Due to the construction '
                             'of the evaluation protocol, \'gradient_norm\' evaluation can only be run after '
                             '\'prediction_diff\' has been completed once')
    parser.add_argument('--pd_saliency_table_path', type=str, default=None,
                        help='path to the JSON file containing the results of the \'prediction_diff\' evaluation')
    args = parser.parse_args()

    # 'gradient_norm' re-uses the model-correctness judgements of a previous 'prediction_diff' run
    # (parser.error instead of assert: asserts are stripped under python -O)
    if args.saliency_method == 'gradient_norm' and args.pd_saliency_table_path is None:
        parser.error('\'prediction_diff\' evaluation must be completed before running \'gradient_norm\' evaluation')

    # Create output directory, if necessary
    os.makedirs(args.out_dir, exist_ok=True)

    # Assign checkpoints (change model size by changing checkpoint names)
    # NOTE: every --model_type is currently loaded with the MBart model / MBart-50 tokenizer classes
    model_type = MBartForConditionalGeneration
    tokenizer_type = MBart50TokenizerFast

    # Load models and tokenizers (module-level globals used by the scoring helpers)
    if args.model_type != 'mbart50':
        if args.checkpoint_dir is None:
            parser.error('Model checkpoint must be specified for models other than MBART50')
        tokenizer = tokenizer_type.from_pretrained(args.checkpoint_dir)
        model = model_type.from_pretrained(args.checkpoint_dir)
    else:
        for lang_code in (args.src_lang, args.tgt_lang):
            if lang_code not in MBART_TABLE:
                parser.error('MBART50 language codes must be one of {}, got {}'.format(sorted(MBART_TABLE),
                                                                                       lang_code))
        model = model_type.from_pretrained('facebook/mbart-large-50-one-to-many-mmt')
        tokenizer = tokenizer_type.from_pretrained('facebook/mbart-large-50-one-to-many-mmt',
                                                   src_lang=MBART_TABLE[args.src_lang],
                                                   tgt_lang=MBART_TABLE[args.tgt_lang])

    use_cuda = not args.use_cpu and torch.cuda.is_available()
    if not args.use_cpu and not use_cuda:
        print('CUDA is not available; falling back to the CPU')
    device = 'cuda:0' if use_cuda else 'cpu'
    model.to(device)
    # Make sure dropout is disabled so that forward passes are deterministic
    model.eval()

    # Zero-out an unused embedding in the translation model; its ID is what prediction_diff masks source tokens with
    # NOTE(review): if this embedding matrix is tied to the output projection (usual for MBart — confirm), the
    # <mask> logit on the target side is affected as well
    with torch.no_grad():
        mask_index = tokenizer.mask_token_id
        model.model.encoder.embed_tokens.weight[mask_index, :] *= 0  # zero-out embedding

    compute_saliency_scores(args.json_file_path, args.out_dir, args.saliency_method, args.pd_saliency_table_path)
