# Erik McGuire, 2021

from zuco_logistics import BertSent
from zuco_params import args
from zuco_utils import debug

if args.pred_zuco or args.aug or args.direct:
    from zuco_modeling import gist2

from dataclasses import field, dataclass
from typing import Optional
from tqdm import tqdm
import dataclasses
import numpy as np
import torch
import os
from torch import nn
import wandb

from transformers import (
    logging,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    TrainerCallback,
    TrainingArguments,
    Trainer,
    set_seed
    )

from sklearn.metrics import (
    precision_recall_fscore_support as score,
    roc_auc_score,
    accuracy_score
    )

# Label inventory for relation classification.
# NOTE(review): num_labels / id2label / label2id are only bound when
# args.task == "rel", yet the config construction below reads them
# unconditionally -- confirm that no other task value reaches this module.
if args.task == "rel":
    num_labels = 11
    id2label = dict(enumerate([
        'AWARDED',
        'BIRTHPLACE',
        'DEATHPLACE',
        'EDUCATION',
        'EMPLOYER',
        'FOUNDER',
        'JOBTITLE',
        'NATIONALITY',
        'POLITICALAFFILIATION',
        'VISITED',
        'WIFE',
    ]))
    label2id = {name: idx for idx, name in id2label.items()}

config = AutoConfig.from_pretrained(
    args.model_name_or_path,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)

# The tokenizer always comes from 'bert-base-uncased', extended with the
# two entity-marker tokens.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    additional_special_tokens=["<e>", "</e>"],
)
gist = BertSent(config, tokenizer)

# Choose the evaluation split from the run name and CLI flags. In every
# configuration not covered here, eval_set is left unbound.
if 'zuco' in args.run_name:
    if args.zuco_only and args.do_eval:
        eval_set = gist.ztrainset
elif args.do_predict and not args.do_train:
    if args.zuco_only and args.zuco_splits:
        eval_set = gist.ztestset
    else:
        eval_set = gist.test_dataset

@dataclass
class ZucoTrainingArguments(TrainingArguments):
    """TrainingArguments extended with the loss coefficients used in search.

    Each default is read from the parsed command-line ``args`` at the time
    this class is defined.
    """

    # Coefficient for the gaze (eye-tracking) loss.
    et_lmbda: Optional[float] = field(
        default=args.et_lmbda,
        metadata={"help": "Gaze loss coef."},
    )
    # Coefficient for the EEG loss.
    eeg_lmbda: Optional[float] = field(
        default=args.eeg_lmbda,
        metadata={"help": "EEG loss coef."},
    )
    # Trade-off weight between the main loss and the prediction loss.
    pred_lmbda: Optional[float] = field(
        default=args.pred_lmbda,
        metadata={"help": "Loss trade-off."},
    )

class LossCallback(TrainerCallback):
    """Toggle ``gist.is_training`` around each step and collect CV metrics."""

    def on_step_begin(self, arg, state, control, **kwargs):
        """Set the shared training flag when a training step begins."""
        if gist.is_training:
            return
        gist.is_training = True

    def on_step_end(self, arg, state, control, **kwargs):
        """Clear the shared training flag when a training step ends."""
        if not gist.is_training:
            return
        gist.is_training = False

    def on_evaluate(self, training_args, state, control, **kwargs):
        """Append this evaluation's metrics to the current fold's scores.

        Nothing is recorded unless cross-validation is enabled
        (``args.cv > 0``).
        """
        if not args.cv > 0:
            return
        fold_scores = gist.cv_scores[gist.current_fold]
        fold_scores.append(kwargs['metrics'])

class ZuCoTrainer(Trainer):
    """Trainer whose loss can be combined with ZuCo attention supervision.

    Depending on the command-line ``args``, the model's own loss is combined
    with EEG and/or eye-tracking terms from ``gist.get_zuco_loss``, or with
    the ``ploss`` field of the model outputs.
    """

    # Batch entries withheld from the model's forward call; the full batch is
    # still handed to ``gist.get_zuco_loss`` (and to ``gist2`` when in use).
    _ZUCO_INPUT_KEYS = ("et", "eeg", "eeg_redmn", "et_trt",
                        "eeg_embeds", "et_embeds")

    def _zuco_wandb_log(self, metrics):
        """Log ``metrics`` to wandb at the current global step, best-effort."""
        try:
            wandb.log(metrics, step=self.state.global_step)
        except Exception:
            # Logging is diagnostic only: a wandb failure (e.g. no active
            # run) must not abort the training step.
            pass

    def _zuco_attention_loss(self, outputs, inputs):
        """Return the training loss with EEG / eye-tracking terms applied.

        ``outputs.loss`` is never modified in place, so the logged
        "train/combined_loss" is exactly main + et + eeg.
        """
        attentions = outputs.attentions
        use_eeg = self.args.eeg_lmbda != 0
        use_et = self.args.et_lmbda != 0
        loss = outputs.loss
        eeg_loss = None
        et_loss = None

        if use_eeg:
            # "z" selects a per-signal head setting. It is resolved into a
            # local so the eye-tracking branch below is not affected.
            eeg_head = (args.eeg_head if args.head_handling == "z"
                        else args.head_handling)
            eeg_loss = self.args.eeg_lmbda * gist.get_zuco_loss(
                attentions, inputs, zuco="eeg",
                layer_handling=args.layer_handling,
                head_handling=eeg_head)
            loss = eeg_loss if args.att_only else loss + eeg_loss
            self._zuco_wandb_log({"train/eeg_loss": eeg_loss})

        if use_et:
            et_head = (args.et_head if args.head_handling == "z"
                       else args.head_handling)
            et_loss = self.args.et_lmbda * gist.get_zuco_loss(
                attentions, inputs, "et",
                layer_handling=args.layer_handling,
                head_handling=et_head)
            if args.att_only and not use_eeg:
                # no EEG loss, no main loss
                loss = et_loss
            else:
                # if not att_only: loss == main and/or EEG
                # if att_only and coef allows it: loss == EEG
                loss = loss + et_loss
            self._zuco_wandb_log({"train/et_loss": et_loss})
        elif args.att_only and not use_eeg:
            # Attention-only training with both coefficients at zero.
            loss = torch.zeros_like(outputs.loss, requires_grad=True)

        if use_et or use_eeg:
            combined = outputs.loss
            if use_et:
                combined = combined + et_loss
            if use_eeg:
                combined = combined + eeg_loss
            self._zuco_wandb_log({"train/combined_loss": combined})
        return loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """Joint losses w/ attention supervision.

        Args:
            model: Model to run; it is called with ``output_attentions=True``
                and ``return_dict=True``.
            inputs: Batch mapping. The ZuCo-specific entries are withheld
                from the forward call.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
            Outside training (``gist.is_training`` false) with
            ``return_outputs`` set, the model's own loss is returned
            unchanged.
        """
        if args.pred_zuco or args.aug or args.direct:
            gist2.training = True
            gist2.inputs = inputs
        minputs = {k: v for (k, v) in inputs.items()
                   if k not in self._ZUCO_INPUT_KEYS}
        outputs = model(**minputs,
                        output_attentions=True,
                        return_dict=True)
        loss = outputs.loss
        self._zuco_wandb_log({"train/loss": outputs.loss})
        if return_outputs and not gist.is_training:
            return (loss, outputs)

        if gist.is_training and args.zuco and not args.direct and not args.pred_zuco and not args.aug:
            loss = self._zuco_attention_loss(outputs, inputs)
        elif args.pred_zuco:
            ploss = outputs.ploss
            if args.pred_only:
                loss = ploss
            else:
                # convex combo trade-off
                weight = self.args.pred_lmbda
                loss = (1 - weight) * loss + weight * ploss
        if return_outputs:
            return (loss, outputs)
        return loss

def evaluate(trainer, training_args):
    """Run inference to save attentions for comparison.

    A fresh model is loaded from ``args.model_name_or_path`` and run over
    the test split with attentions enabled. Each batch is passed to
    ``gist.save_all_atts`` (presumably filling ``gist.all_att`` -- confirm
    in BertSent), and ``gist.all_att`` is then saved with ``torch.save``
    under ``training_args.output_dir``.

    Args:
        trainer: Provides the eval dataloader and input preparation.
        training_args: Only ``output_dir`` is read.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=gist.config)
    model.to(gist.device)
    model.eval()

    if args.zuco_only and args.zuco_splits:
        test_set = gist.ztestset
    else:
        test_set = gist.test_dataset
    # Only the test split is evaluated; the former args.per_sample branches
    # built this same list.
    sets = [("test", test_set)]

    skipped_keys = ("et", "eeg", "eeg_redmn", "et_trt",
                    "eeg_embeds", "et_embeds")
    for setname, dataset in sets:
        if args.random_scores:
            setname += "_rand"
        loader = trainer.get_eval_dataloader(dataset)
        # Reset the accumulator before each split.
        gist.all_att = {"eeg": [], "et": [], "model": [],
                        "input_ids": [], "preds": [], "eeg_kld": [],
                        "et_kld": [], "eeg_sim": [], "et_sim": [],
                        "incorrect": [], "labels": []}
        for inputs in tqdm(loader, desc="Evaluating"):
            # ZuCo entries are withheld from the forward call; the full
            # batch still goes to gist.save_all_atts.
            minputs = {k: v for (k, v) in inputs.items()
                       if k not in skipped_keys}
            minputs = trainer._prepare_inputs(minputs)
            with torch.no_grad():
                outputs = model(**minputs,
                                output_attentions=True,
                                return_dict=True)
                a = outputs.attentions
                preds = outputs.logits.argmax(-1)

            gist.save_all_atts(a, inputs, preds,
                               layer_handling=args.layer_handling,
                               head_handling=args.head_handling)

        torch.save(gist.all_att,
                   f'{training_args.output_dir}/all_atts_{setname}_{args.layer_handling}.pt')

def log_res(preds, labels, savedir=None, ext=None):
    """
    Save dev, test predictions, labels to
    files, for significance testing.

    Args:
        preds: Predicted values, written to ``<savedir>/preds.csv``.
        labels: Reference values, written to ``<savedir>/labels.csv``.
        savedir: Target directory. When falsy it is built as
            ``f"{args.output_dir}{ext}-{args.chkpt}"``.
        ext: Suffix used only when building the default directory. When
            falsy it is ``'_zuco'`` for runs whose name contains 'zuco',
            otherwise ``'_test'`` if ``args.do_predict`` else ``'_eval'``.
    """
    if not savedir:
        if not ext:
            if 'zuco' in args.run_name:
                ext = '_zuco'
            else:
                ext = "_eval" if not args.do_predict else "_test"
        savedir = f"{args.output_dir}{ext}-{args.chkpt}"
    # Create the directory once, tolerating one that already exists.
    os.makedirs(savedir, exist_ok=True)
    for name, values in (("preds", preds), ("labels", labels)):
        np.savetxt(f"{savedir}/{name}.csv", values, delimiter=",")

def compute_metrics(pred):
    """Override of evaluation metrics.

    Args:
        pred: Evaluation result exposing ``label_ids`` and ``predictions``.
            ``predictions`` may be a tuple/list whose first element holds
            the logits, or the logits array itself.

    Returns:
        Dict with weighted ``precision``, ``recall`` and ``f1`` plus
        ``accuracy``. When ``args.save_preds`` is set, predictions and
        labels are also written out through ``log_res``.
    """
    labels = pred.label_ids
    predictions = pred.predictions
    # With several model outputs the logits come first; a bare array is
    # already the logits and must not be indexed by row.
    if isinstance(predictions, (tuple, list)):
        logits = predictions[0]
    else:
        logits = predictions
    preds = logits.argmax(-1)
    precision, recall, f1, _ = score(labels, preds,
                                     average='weighted',
                                     zero_division=0)
    acc = accuracy_score(labels, preds)

    if args.save_preds:
        log_res(preds, labels)

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

def zuco_objective(metrics):
    """Objective value for hyperparameter search: the evaluation accuracy.

    Only the "eval_accuracy" entry of ``metrics`` is used. The search
    direction is chosen by the caller of hyperparameter_search().
    """
    accuracy = metrics["eval_accuracy"]
    return accuracy

def get_model(t:str):
    """
     Use model vs. model_init depending on whether we use search.

     Args:
         t: "model" requests a ready model instance; any other value
            requests the ``model_init`` factory.

     Returns:
         For "model": a model built by ``model_init`` when
         ``args.train_type == "train"``, else None.
         Otherwise: the ``model_init`` function when
         ``args.train_type == "grid"``, else None.
    """
    if t == "model":
        # Check the training mode first so nothing is loaded just to be
        # thrown away.
        if args.train_type != "train":
            return None
        # model_init ignores its trial argument and performs the same
        # setup (resize, freezing, ignored output keys).
        return model_init(None)
    return model_init if args.train_type == "grid" else None

def model_init(trial):
    """
        Build a sequence-classification model for hyperparameter search.

        ``trial`` is accepted but not used. The custom output keys are
        marked to be ignored at inference, the embedding matrix is resized
        when its size differs from the tokenizer's, and with ``args.freeze``
        every ``model.bert`` parameter whose name lacks "attention" is
        frozen.
    """
    checkpoint = args.model_name_or_path
    model_config = AutoConfig.from_pretrained(
        checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    )
    marker_tokenizer = AutoTokenizer.from_pretrained(
        'bert-base-uncased',
        additional_special_tokens=["<e>", "</e>"],
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        config=model_config,
    )
    model.config.keys_to_ignore_at_inference = [
        "ploss",
        "model_attentions",
        "model_scores",
    ]

    # Keep the embedding table in step with the tokenizer (it carries the
    # two extra marker tokens).
    vocab_size = len(marker_tokenizer)
    if model.bert.embeddings.word_embeddings.num_embeddings != vocab_size:
        model.resize_token_embeddings(vocab_size)

    if args.freeze:
        for param_name, param in model.bert.named_parameters():
            if 'attention' in param_name.lower():
                continue
            param.requires_grad = False
    return model

def hp_space(trial):
    """Search space: loss coefficients sampled from 0.2 to 1.0 in steps of 0.2.

    With ``args.pred_zuco`` only "pred_lmbda" is searched; otherwise
    "eeg_lmbda" and "et_lmbda" are.
    """
    def _suggest(name):
        # Every coefficient shares the same linear grid.
        return trial.suggest_float(name, 0.2, 1.0, step=0.2, log=False)

    if args.pred_zuco:
        names = ("pred_lmbda",)
    else:
        names = ("eeg_lmbda", "et_lmbda")
    return {name: _suggest(name) for name in names}
