# Erik McGuire, 2021

from sklearn.model_selection import train_test_split, StratifiedKFold
from nltk.tokenize import TreebankWordTokenizer
from zuco_dataset import ZuCoDataset
from typing import Tuple, List, Union
from torchtext.legacy import datasets, data
from zuco_paths import PHRASE_PTH
from zipfile import ZipFile
from torch import Tensor
from pathlib import Path
from zuco_utils import *
from zuco_params import args
import torch.nn as nn
import pandas as pd
import torch

# Whether raw EEG scores are min-max scaled up front: always when electrodes
# are summed, and for max-pooled electrodes only with per-sample scores.
normEEG = (args.electrode_handling == 'sum'
           or (args.per_sample and args.electrode_handling == 'max'))

class BertSent:
    def __init__(self, config, tokenizer):
        """Set up data splits, encodings and ``ZuCoDataset`` objects.

        Args:
            config: stored as ``self.config``.
            tokenizer: callable used by ``tokenize`` to encode texts.

        Side effects: for ``args.task == "rel"`` this sets the shared
        ``args.zuco_splits = True``. It reads the phrase table
        (``load_rel``), for "rel" the predefined train/test splits plus
        stratified CV folds (``read_rel_split``), and tokenizes them
        (``get_encodings``).
        """
        if args.task == "rel":
            # The relation task always uses the predefined ZuCo splits.
            args.zuco_splits = True
        self.config = config
        self.is_training = True
        self.device = torch.device('cuda' if torch.cuda.is_available()
                                   else 'cpu')
        # StratifiedKFold needs at least 2 splits, so args.cv is clamped.
        self.skf = StratifiedKFold(n_splits=max(2, args.cv),
                                   shuffle=True,
                                   random_state=args.seed)
        self.tokenizer = tokenizer
        self.inputs = None
        self.embeds = None
        self.training = None
        self.current_fold = 0
        self.cv_scores = {i: [] for i in range(args.cv)}

        # CV folds; filled by read_rel_split, tokenized by get_encodings.
        self.folds = {'train': [], 'dev': []}
        self.new2old = dict()
        # Split name -> {position in split: sentence id}; fold entries are
        # added by read_rel_split.
        self.n2o_rel = {n: dict() for n in ['train', 'test']}

        self.train_labels = []
        self.val_labels = []
        self.test_labels = []

        self.ztrain_labels = []
        self.zval_labels = []
        self.ztest_labels = []

        self.ztokenizer = TreebankWordTokenizer()

        # Per-batch records appended by save_all_atts.
        self.all_att = {"eeg": [], "et": [], "model": [],
                        "input_ids": [], "preds": [], "eeg_kld": [],
                        "et_kld": [], "eeg_sim": [], "et_sim": [],
                        "incorrect": [], "labels": []}
        self.data_splits = self.load_data()
        self.rel = self.load_rel()
        if args.task == "rel":
            self.z_splits = self.read_rel_split()

        self.ztrainsets = []
        self.zdevsets = []
        self.rel_encodings = self.get_encodings(t="rel")

        if args.zuco_splits:
            if args.task == "rel":
                self.ztrainset = ZuCoDataset(self.rel_encodings[0],
                                             self.ztrain_labels, 'train', self.n2o_rel['train'])
                self.ztrainmain = self.ztrainset
                self.ztestset = ZuCoDataset(self.rel_encodings[2],
                                            self.ztest_labels, new2old=self.n2o_rel['test'])
                if args.cv > 0:
                    # One dataset per CV fold; fold tuples are
                    # (encodings, labels, original row indices).
                    for ix, train_fold in enumerate(self.folds['train']):
                        self.ztrainsets.append(ZuCoDataset(train_fold[0],
                                                     train_fold[1], 'train',
                                                     new2old=self.n2o_rel[f'train_fold_{ix}']))
                    for ix, dev_fold in enumerate(self.folds['dev']):
                        self.zdevsets.append(ZuCoDataset(dev_fold[0],
                                                   dev_fold[1],
                                                   new2old=self.n2o_rel[f'dev_fold_{ix}']))
        else:
            # Previously referenced an undefined name `encodings` (NameError);
            # the train encodings live in self.rel_encodings[0].
            # NOTE(review): self.val_dataset / self.test_dataset are not set in
            # the visible code - presumably by load_data(); confirm.
            self.ztrainset = ZuCoDataset(self.rel_encodings[0],
                                         self.ztrain_labels, 'train',
                                         self.new2old)
            self.zdevset = self.val_dataset
            self.ztestset = self.test_dataset

        if args.task == "rel":
            self.train_dataset = self.ztrainset
            self.test_dataset = self.ztestset


    def load_rel(self):
        """Return the tab-separated phrase table stored at ``PHRASE_PTH``."""
        phrase_table = pd.read_csv(PHRASE_PTH, sep="\t")
        return phrase_table

    def read_rel_split(self):
        """Build the relation-task splits and their position-to-sentence maps.

        Loads the train/test phrase tables via ``get_rel_splits`` and turns
        each into a ``(texts, label_ids, original_row_indices)`` triple. The
        training table is further split into stratified CV folds with
        ``self.skf``; the fold triples are appended to ``self.folds['train']``
        and ``self.folds['dev']``.

        For every split name ('train', 'test', 'train_fold_{i}',
        'dev_fold_{i}') ``self.n2o_rel[name]`` maps a position in that split
        to ``self.rel.loc[row, 'sent']``. A map is created even for an empty
        split, so later lookups by fold name cannot raise KeyError.

        Returns:
            ``(train_main, test)`` triples for the full train and test sets.
        """
        train_data, test_data = self.get_rel_splits()

        def get_texts(data):
            # Per-sample scores use phrases verbatim; otherwise phrases are
            # re-tokenized with the Treebank tokenizer and re-joined.
            texts = data.phrase.tolist()
            if not args.per_sample:
                texts = [" ".join(self.ztokenizer.tokenize(s))
                         for s in texts]
            return texts

        def as_split(data):
            # (texts, labels, original row indices) for one DataFrame.
            return (get_texts(data),
                    data.label_id.tolist(),
                    data.index.tolist())

        def map_to_sents(name, split):
            # setdefault creates the map on first use and otherwise updates it
            # in place; a failing .loc lookup now propagates without first
            # wiping the entries already recorded.
            mapping = self.n2o_rel.setdefault(name, dict())
            for new, old in enumerate(split[2]):
                mapping[new] = self.rel.loc[old, 'sent']

        train_main = as_split(train_data)
        test = as_split(test_data)

        for name, split in [('train', train_main), ('test', test)]:
            map_to_sents(name, split)

        folds = self.skf.split(train_data.phrase.tolist(),
                               train_data.label_id.tolist())

        for ix, (t_ix, d_ix) in enumerate(folds):
            train = as_split(train_data.iloc[t_ix])
            dev = as_split(train_data.iloc[d_ix])

            self.folds['train'].append(train)
            self.folds['dev'].append(dev)

            map_to_sents(f'train_fold_{ix}', train)
            map_to_sents(f'dev_fold_{ix}', dev)

        return train_main, test

    def get_zuco_splits(self, texts, labels, dev_size=0.1, test_size=0.1,
                        random_state=None):
        """Split texts/labels into stratified train, dev and test sets.

        The test set is carved out first (``test_size`` of all samples). The
        dev set is then taken from the remainder, with its fraction rescaled
        so that it is ``dev_size`` of *all* samples.

        Args:
            texts: sequence of input texts.
            labels: sequence of labels, also used for stratification.
            dev_size: fraction of all samples for the dev set (default 0.1).
            test_size: fraction of all samples for the test set (default 0.1).
            random_state: seed forwarded to both ``train_test_split`` calls.
                The default None keeps the previous unseeded behavior; pass a
                seed (e.g. ``args.seed``) for reproducible splits.

        Returns:
            ``(train, val, test)``, each a ``(texts, labels)`` pair.
        """
        t_train, t_test, l_train, l_test = train_test_split(
            texts, labels, test_size=test_size, stratify=labels,
            random_state=random_state)
        # Rescale so the dev split is dev_size of the whole dataset.
        dev_frac = dev_size / (1 - test_size)
        t_train, t_dev, l_train, l_dev = train_test_split(
            t_train, l_train, test_size=dev_frac, stratify=l_train,
            random_state=random_state)
        return (t_train, l_train), (t_dev, l_dev), (t_test, l_test)

    def get_rel_splits(self):
        """Load the relation-task train and test phrase tables.

        Both CSVs are read relative to the working directory, using their
        first column as the index.

        Returns:
            ``(train, test)`` DataFrames.
        """
        csv_paths = ("data/phrase_train.csv", "data/phrase_test.csv")
        train, test = [pd.read_csv(path, index_col=[0]) for path in csv_paths]
        return train, test

    def tokenize(self, dataset, t: str,
                       labels = None, split = "") -> List[Tensor]:
        """Encode one split's texts with ``self.tokenizer``.

        ``dataset`` is a split tuple whose first element is the list of
        texts. ``t``, ``labels`` and ``split`` are accepted for call
        compatibility but unused. A falsy ``dataset`` (None or empty) gives
        None. Otherwise the texts are truncated, padded to the longest one,
        and returned as PyTorch tensors with token-type ids and attention
        mask.
        """
        if not dataset:
            return None
        tokenizer_kwargs = dict(
            truncation=True,
            padding="longest",
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return self.tokenizer(dataset[0], **tokenizer_kwargs)

    def get_encodings(self, t: str = "") -> Tuple[List[Tensor]]:
        """Tokenize each split's texts and record the split labels.

        Reads ``self.z_splits`` in one of three forms:
          * a ``(train, val, test)`` triple for non-"rel" ZuCo splits;
          * a ``(train, test)`` pair for the "rel" task;
          * a single split when ``args.zuco_splits`` is off.
        Element 1 of each split (its labels) is stored on
        ``self.ztrain_labels`` / ``self.zval_labels`` / ``self.ztest_labels``.

        For "rel" with ``args.cv > 0`` and ``t == "rel"``, each fold in
        ``self.folds`` becomes ``(encodings, labels, row_indices)``.

        Args:
            t: tag forwarded to ``tokenize``; folds are only encoded for "rel".

        Returns:
            ``(train_encodings, val_encodings, test_encodings)``. An entry is
            None when that split is absent (e.g. no validation split for
            "rel").
        """
        val = None
        test = None
        if not args.zuco_splits:
            # NOTE(review): self.z_splits is only set in the visible code for
            # the "rel" task; presumably another method sets it here - confirm.
            train = self.z_splits
            self.ztrain_labels = train[1]
        elif args.task != "rel":
            train = self.z_splits[0]
            val = self.z_splits[1]
            test = self.z_splits[2]
            self.ztrain_labels = train[1]
            self.zval_labels = val[1]
            self.ztest_labels = test[1]
        else:
            train = self.z_splits[0]
            test = self.z_splits[1]
            self.ztrain_labels = train[1]
            self.ztest_labels = test[1]
            if args.cv > 0 and t == "rel":
                # Split names are passed by keyword: positionally they landed
                # in tokenize's `labels` parameter instead of `split`.
                self.folds['train'] = [(self.tokenize(trn, t, split="train"),
                                        trn[1], trn[2])
                                       for trn in self.folds['train']]
                self.folds['dev'] = [(self.tokenize(dv, t, split="val"),
                                      dv[1], dv[2])
                                     for dv in self.folds['dev']]

        train_encodings = self.tokenize(train, t, split="train")
        val_encodings = self.tokenize(val, t, split="val")
        test_encodings = self.tokenize(test, t, split="test")

        return (train_encodings, val_encodings, test_encodings)

    def prep_lang_att(self, lang_att: Union[Tuple[Tensor], Tensor],
                      layer_handling: str = "-1",
                      head_handling: str = "avg") -> Tensor:
        """
        Reduce model attentions to one [CLS]-query score per key token.

        Layer selection (``layer_handling``):
          * ``"avg"``   - stack the per-layer tensors and average over layers;
          * ``"every"`` - ``lang_att`` is already a single layer's tensor;
          * otherwise   - an integer string (e.g. ``"-1"``) indexing a layer.

        The selected tensor is indexed as (batch, heads, query, key). Query
        row 0 ([CLS]) is kept, which drops the query dim and leaves
        (batch, heads, key). The heads are then reduced by ``prep_heads``
        according to ``head_handling``.

        Returns:
            (batch, key) for "avg", "max" or a single head; (batch, heads, key)
            for "all".
        """
        if layer_handling == "avg":
            lang_att = torch.stack(lang_att).mean(dim=0)  # mean over layers
        elif layer_handling != "every":
            # int() instead of eval(): the index comes from configuration
            # and should never be executed as code.
            lang_att = lang_att[int(layer_handling)]
        lang_att = lang_att[:, :, 0]  # [CLS] query row for every head
        return self.prep_heads(lang_att, head_handling)

    def prep_heads(self,
                      lang_att: Union[Tuple[Tensor], Tensor],
                      head_handling: Union[str, int] = -1) -> Tensor:
        """
        Reduce the head dimension (dim 1) of [CLS]-row attentions.

        ``head_handling`` may be:
          * ``"avg"`` - mean over heads;
          * ``"max"`` - max over heads;
          * ``"all"`` - keep every head (returns a copy);
          * an int or integer string - select that head (default: the last).

        For input (batch, heads, key) this returns (batch, key), or
        (batch, heads, key) for ``"all"``.
        """
        if head_handling == 'avg':
            return torch.mean(lang_att, dim=1)
        if head_handling == 'max':
            return torch.max(lang_att, dim=1).values
        if head_handling == "all":
            return lang_att.clone()
        # int() accepts both the int default and strings such as "3". The
        # previous eval() raised TypeError for the int default -1 and would
        # have executed arbitrary configuration text.
        return lang_att[:, int(head_handling), :]

    def minmax(self, t: Tensor) -> Tensor:
        """Min-max scale the nonzero entries of ``t`` to [0, 1], in place.

        Zero entries are treated as padding: they are excluded from the
        min/max and stay 0. Scaling is over the whole tensor, not per row.
        ``t`` is modified in place and also returned.

        Edge cases:
          * all zeros -> ``t`` is returned unchanged;
          * a single distinct nonzero value -> those entries become 1.0.
            The previous code divided 0 by 0 here and produced NaN.
        """
        mask = t != 0
        values = t[mask]
        if values.numel() == 0:
            # Nothing but padding; explicit check instead of catching the
            # RuntimeError that min() raises on an empty tensor.
            return t
        lo, hi = values.min(), values.max()
        span = hi - lo
        if span == 0:
            t[mask] = 1.0
        else:
            t[mask] = (values - lo) / span
        return t

    def mse(self, t, z, ps=False):
        """Attention loss between model scores ``t`` and human scores ``z``.

        Despite the name, only the raw-score path (``ps`` False and
        ``args.use_weights`` off) is a mean squared error. There, exactly-zero
        errors are treated as padding and left out of the mean. The other
        paths use ``nn.KLDivLoss(reduction='batchmean')`` with ``t.log()`` as
        input and ``z`` as target, after replacing zeros by 1e-12 so that
        ``log`` stays finite.

        Neither input is modified. The zero replacement used to be done in
        place, which also altered the caller's tensors (possibly views needed
        by autograd).

        Args:
            t: model attention scores.
            z: human (EEG/ET) scores, same shape as ``t``.
            ps: per-sample mode. Unless ``args.task == "rel"``, this returns
                one KL value per row instead of a single value.

        Returns:
            A scalar tensor, or a list of scalar tensors in per-sample mode.
        """
        eps = 1e-12
        if not ps and not args.use_weights:
            se = (t - z) ** 2
            nonzero = se != 0
            if not nonzero.any():
                # Mean over an empty selection is NaN; a perfect match is
                # instead a (graph-connected) zero loss.
                return se.sum()
            return torch.mean(se[nonzero])
        # Out-of-place zero replacement (masked_fill returns a new tensor).
        t = t.masked_fill(t == 0, eps)
        z = z.masked_fill(z == 0, eps)
        kld = nn.KLDivLoss(reduction='batchmean', log_target=False)
        if not ps or args.task == "rel":
            return kld(t.log(), z)
        # Per-sample: one KL value per row; eps above also guards this path,
        # which previously took log(0) at padded positions.
        return [kld(ti, zi) for ti, zi in zip(t.log(), z.to(t.device))]

    def get_zuco_loss(self, lang_att: Union[Tuple[Tensor], Tensor],
                      inputs: dict, zuco: str,
                      layer_handling: str = "-1",
                      head_handling: str = "avg") -> Tensor:
        """
        Compute the attention loss between model attentions and EEG/ET scores.

        ``zuco`` is ``"eeg"`` or ``"et"``. Human scores are read by
        ``get_zuco_att`` from ``inputs[zuco]``, or from ``inputs["eeg_redmn"]``
        / ``inputs["et_trt"]`` when ``args.per_sample`` is set. Each pairwise
        loss comes from ``self.mse``.

        Modes:
          * ``head_handling == "every"`` (single layer): one loss per head of
            the selected layer, averaged over heads. Assumes ``lang_att`` is
            a sequence of per-layer (batch, heads, query, key) tensors.
          * ``layer_handling == "every"``: per-layer losses are summed and
            divided by the number of layers (see the review note below).
          * otherwise: one loss on the reduced attention.

        Raises:
            NotImplementedError: ``head_handling == "all"`` with a single
                layer. That path was never completed: it read an unbound
                ``zuco_att`` and returned no loss.
        """
        zin = "eeg_redmn" if zuco == "eeg" else "et_trt"
        # Key that get_zuco_att reads and min-max scales in place.
        zkey = zin if args.per_sample else zuco

        def fresh_inputs():
            # Shallow copy with a cloned human-score tensor, so calling
            # get_zuco_att repeatedly (once per head) does not keep
            # re-normalizing the caller's tensor.
            return {**inputs, zkey: inputs[zkey].clone()}

        if layer_handling != 'every':
            if head_handling == "all":
                raise NotImplementedError(
                    "head_handling='all' is not implemented for a single "
                    "layer; use 'every' to average the loss over heads.")
            if head_handling == "every":
                n_heads = lang_att[0].shape[1]
                mse = torch.tensor(0., requires_grad=True).to(self.device)
                for i in range(n_heads):
                    head_att = self.prep_lang_att(lang_att,
                                                  layer_handling=layer_handling,
                                                  head_handling=str(i))
                    # Keep lang_att intact: it was previously overwritten
                    # here, corrupting later heads and the head count.
                    zuco_att, head_att = self.get_zuco_att(fresh_inputs(),
                                                           head_att,
                                                           zuco, zin)
                    # Out-of-place add: `+=` raises on a leaf that requires
                    # grad (the case when self.device is the CPU).
                    mse = mse + self.mse(head_att, zuco_att)
                return mse / n_heads
            if not args.model_att:
                lang_att = self.prep_lang_att(lang_att,
                                              layer_handling=layer_handling,
                                              head_handling=head_handling)
            else:
                lang_att = lang_att[:, 0]
            zuco_att, lang_att = self.get_zuco_att(inputs,
                                                   lang_att,
                                                   zuco, zin)
            return self.mse(lang_att, zuco_att)

        # layer_handling == 'every': accumulate per layer, then average.
        mse = torch.tensor(0., requires_grad=True).to(self.device)
        for i in range(len(lang_att)):
            h_mse = torch.tensor(0., requires_grad=True).to(self.device)
            if not args.model_att:
                new_lang_att = lang_att[i]
            else:
                new_lang_att = lang_att[i][:, 0]
                # NOTE(review): this head loop only runs with args.model_att,
                # so without it every layer contributes 0. It also reads
                # lang_att.shape[1], which fails if lang_att is a tuple.
                # Confirm the intended nesting and head count.
                if head_handling in ["all", "every"]:
                    for j in range(lang_att.shape[1]):
                        new_lang_att_h = self.prep_lang_att(
                                            new_lang_att,
                                            layer_handling=layer_handling,
                                            head_handling=str(j))
                        zuco_att, new_lang_att_h = self.get_zuco_att(
                                            fresh_inputs(),
                                            new_lang_att_h,
                                            zuco, zin)
                        h_mse = h_mse + (self.mse(new_lang_att_h, zuco_att)
                                         / new_lang_att.shape[1])
            mse = mse + h_mse
        return mse / len(lang_att)

    def get_zuco_att(self, inputs, lang_att, zuco, zin):
        """Normalize human (ZuCo) and model scores so they can be compared.

        Uses ``inputs[zin]`` when ``args.per_sample`` is set, else
        ``inputs[zuco]``. Without per-sample scores, ET scores (and EEG
        scores when ``normEEG`` is set) are min-max scaled first.

        With ``args.use_weights``, the human scores then go through the
        masked softmax ``zuco_sm``. Otherwise the model scores are min-max
        scaled, and so are the human scores if they were not scaled above.
        ``minmax`` may modify its argument in place.

        Returns:
            ``(zuco_att, lang_att)`` after normalization.
        """
        per_sample = args.per_sample
        if per_sample:
            zuco_att = inputs[zin]
        else:
            zuco_att = inputs[zuco]

        # Large raw ZuCo values are scaled up front, whichever comparison follows.
        scale_first = not per_sample and (zuco == "et" or
                                          (zuco == "eeg" and normEEG))
        if scale_first:
            zuco_att = self.minmax(zuco_att)

        if not args.use_weights:
            # Raw-score comparison: bring model scores to [0, 1] ...
            lang_att = self.minmax(lang_att)
            # ... and the human scores too, unless already scaled above.
            if not per_sample and zuco != "et" and not normEEG:
                zuco_att = self.minmax(zuco_att)
        else:
            zuco_att = self.zuco_sm(zuco_att,
                                    inputs['attention_mask'],
                                    zuco)
        return zuco_att, lang_att

    def zuco_sm(self, att, attention_mask, zuco):
        """Softmax over dim 1 that ignores padded positions.

        Positions where ``attention_mask`` is 0 get a -10000 offset before
        the softmax, so their probability is numerically zero. ``zuco`` is
        unused.

        ``att`` is left unmodified. The offset used to be added in place,
        which corrupted the caller's tensor whenever the same inputs were
        processed more than once (e.g. once per attention head).
        """
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        return nn.Softmax(dim=1)(att + extended_attention_mask)

    def att_sim(self, h, m, k):
        """Top-k overlap ratio between human and model scores, per row.

        For each row of ``h`` and ``m``, return the number of indices shared
        by the two top-``k`` sets, divided by the size of the human top-``k``
        set. Tensors may live on different devices.
        """
        human_top = torch.topk(h, k).indices.cpu().tolist()
        model_top = torch.topk(m, k).indices.cpu().tolist()
        return [len(set(h_row) & set(m_row)) / len(set(h_row))
                for h_row, m_row in zip(human_top, model_top)]

    def save_all_atts(self, lang_att: Tuple[Tensor], inputs: dict,
                      preds: Tensor, layer_handling: str = "-1",
                      head_handling: str = "avg") -> Tensor:
        """
        Record model and human attentions for one batch in ``self.all_att``.

        Model attentions are reduced with ``prep_lang_att`` ("every" is
        treated as "avg" over layers). The EEG/ET scores are normalized with
        min-max scaling and/or the masked softmax ``zuco_sm``, depending on
        ``args.per_sample``, ``args.use_weights`` and ``normEEG``.

        Appended per batch:
          * the results of ``self.mse(..., ps=True)`` for EEG and ET;
          * the top-2 overlap ratios from ``att_sim``;
          * the indices of misclassified samples;
          * the input ids, predictions, labels and attention tensors.

        Runs under ``torch.no_grad()`` and returns None. Note that ``minmax``
        may rescale tensors held in ``inputs`` in place.
        """
        with torch.no_grad():
            if layer_handling == 'every':
                layer_handling = "avg"
            lang_att = self.prep_lang_att(lang_att,
                                          layer_handling=layer_handling,
                                          head_handling=head_handling)
            # Start from the raw human scores so both names are always bound.
            # They were previously unbound with per_sample set and use_weights
            # off, which raised UnboundLocalError.
            if args.per_sample:
                eeg_att = inputs["eeg_redmn"]
                et_att = inputs["et_trt"]
            else:
                eeg_att = inputs["eeg"]
                # ET scores are always min-max scaled here.
                et_att = self.minmax(inputs["et"])
                if normEEG:
                    # Normalize big ZuCo values whether using weights or raw scores
                    eeg_att = self.minmax(eeg_att)

            if args.use_weights:
                mask = inputs['attention_mask']
                eeg_att = self.zuco_sm(eeg_att, mask, "eeg")
                et_att = self.zuco_sm(et_att, mask, "et")
            else:
                # Normalize raw scores if not using weights
                lang_att = self.minmax(lang_att)
                if not args.per_sample and not normEEG:
                    # Norm ZuCo if using raw scores and not already normalized
                    eeg_att = self.minmax(eeg_att)

            eeg_loss = self.mse(lang_att, eeg_att.to(self.device), ps=True)
            et_loss = self.mse(lang_att, et_att.to(self.device), ps=True)
            eeg_sim = self.att_sim(eeg_att, lang_att, k=2)
            et_sim = self.att_sim(et_att, lang_att, k=2)
            incorrect = [ix
                         for ix, (p, l) in enumerate(zip(preds,
                                                         inputs["labels"]))
                         if p != l]
            self.all_att["incorrect"].append(incorrect)
            self.all_att["input_ids"].append(inputs["input_ids"])
            self.all_att["preds"].append(preds)
            self.all_att["labels"].append(inputs["labels"])
            self.all_att["model"].append(lang_att)
            self.all_att["eeg"].append(eeg_att)
            self.all_att["et"].append(et_att)
            self.all_att["eeg_kld"].append(eeg_loss)
            self.all_att["et_kld"].append(et_loss)
            self.all_att["eeg_sim"].append(eeg_sim)
            self.all_att["et_sim"].append(et_sim)
