# Based on https://github.com/Text2TCS/Transrelation/blob/master/CogALex_XLM_RoBERTa_Updated_Task_Version.ipynb
import os
import torch                                              #for training the model
import time
import datetime
import random
import numpy as np
import pandas as pd                                       #for handling the data
from typing import List
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn import preprocessing                         #for label encoding
from sklearn.metrics import classification_report         #for showing performance on validation/test sets
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import ParameterGrid         #for gridsearch
from tqdm import tqdm

from nlp_safety_prop.utils.paths import get_data_path
from nlp_safety_prop.models.embeddings.roberta import RobertaModelPreTrained
from nlp_safety_prop.models.embeddings.embedder import Embedder


def format_time(elapsed: float) -> str:
    """Return a duration in seconds as ``H:MM:SS`` (rounded to whole seconds).

    ``elapsed`` is typically a ``time.time()`` difference, i.e. a float. Hours are
    not zero-padded, and durations of a day or more are rendered by
    ``datetime.timedelta`` as e.g. ``"1 day, 1:01:01"``.
    """
    return str(datetime.timedelta(seconds=int(round(elapsed))))


class Transrelation(Embedder):
    """XLM-RoBERTa word-pair relation classifier (labels ANT, HYP, RANDOM, SYN).

    On construction the fine-tuned model is loaded from
    ``models/xlm_roberta_synonymy/`` if that path exists. Otherwise
    ``xlm-roberta-base`` is fine-tuned on the CogALex-VI zh/en/de training data
    and saved to that path. The model is then used to produce mean-pooled word
    embeddings via ``get_embeddings`` / ``get_representations``.
    """

    def __init__(self, max_len: int = 32, seed_val: int = 42, epochs: int = 7,
                 batch_size: int = 64, learning_rate: float = 2e-5):
        """Load the fine-tuned classifier, or fine-tune and save it if missing.

        Args:
            max_len: token length every input is padded/truncated to.
            seed_val: seed for Python/NumPy/torch RNGs used during training.
            epochs: number of fine-tuning epochs (training only).
            batch_size: DataLoader batch size for training and validation.
            learning_rate: AdamW learning rate (training only).
        """
        self.data_path = get_data_path("raw/synonymy/CogALex_VI/")
        self.model_path = get_data_path("models/xlm_roberta_synonymy/")
        self.max_len = max_len
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed_val = seed_val
        self.epochs = epochs
        self.labels = ["ANT", "HYP", "RANDOM", "SYN"]
        self.langs = {"zh": "chinese", "en": "english", "de": "german"}
        self.columns = ["Word1", "Word2", "Label"]
        # LabelEncoder sorts its classes: ANT=0, HYP=1, RANDOM=2, SYN=3.
        self.le = preprocessing.LabelEncoder()
        self.le.fit(self.labels)
        self.data_train = None
        self.data_valid = None
        self.data_train_all = None
        self.data_valid_all = None
        self.optimizer = None
        self.scheduler = None

        # Fall back to CPU so loading/embedding still works without a GPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print('Connected to GPU:', torch.cuda.get_device_name(0))
        else:
            self.device = torch.device('cpu')
            print('No GPU available, using CPU')

        model_exists = os.path.exists(self.model_path)
        if model_exists:
            self.roberta = RobertaModelPreTrained(self.model_path, model_args={"num_labels": len(self.labels)})
            print('Loaded Classification Model From: ', self.model_path)
        else:
            print('Training Model: ')
            # Seed before loading so the freshly initialised classification head is reproducible.
            self._set_seed()
            self.roberta = RobertaModelPreTrained("xlm-roberta-base", model_args={"num_labels": len(self.labels)})

        self.xlmr_tokenizer = self.roberta.tokenizer
        self.xlmr_model = self.roberta.model
        self.desc = self.xlmr_model.to(self.device)
        self.data_load()

        if not model_exists:
            # torch.optim.AdamW with weight_decay=0.0 reproduces the defaults of the
            # deprecated transformers.AdamW (torch's own default weight_decay is 0.01).
            # lr: 2e-5 with 5-7 epochs worked well on the combined data; 1e-5 also tried.
            self.optimizer = torch.optim.AdamW(self.xlmr_model.parameters(),
                                               lr=self.learning_rate,
                                               eps=1e-8,
                                               weight_decay=0.0)
            self.train()
            self.xlmr_model.save_pretrained(self.model_path)
            # Save the tokenizer too; presumably RobertaModelPreTrained(model_path)
            # also loads the tokenizer from this directory — confirm.
            self.xlmr_tokenizer.save_pretrained(self.model_path)

    def _set_seed(self):
        """Seed the Python, NumPy and torch (CPU and all CUDA devices) RNGs with ``self.seed_val``."""
        random.seed(self.seed_val)
        np.random.seed(self.seed_val)
        torch.manual_seed(self.seed_val)
        torch.cuda.manual_seed_all(self.seed_val)

    def _read_split(self, relative_path: str) -> pd.DataFrame:
        """Read one tab-separated split file (relative to ``self.data_path``) and label-encode it."""
        frame = pd.read_csv(os.path.join(self.data_path, relative_path),
                            sep="\t", header=None,
                            # keep words such as "null", "nan" or "None" as strings instead of NaN
                            na_filter=False)
        frame.columns = self.columns
        frame["Label"] = self.le.transform(frame["Label"])
        return frame

    def data_load(self):
        """Load the CogALex-VI train/valid splits for every language in ``self.langs``.

        Sets ``self.data_train`` / ``self.data_valid`` (language code -> DataFrame with
        columns Word1, Word2, Label, where Label holds LabelEncoder ids). Also sets
        ``self.data_train_all`` / ``self.data_valid_all``, their concatenations in
        ``self.langs`` order with a fresh RangeIndex.
        """
        self.data_train = {key: self._read_split(f'train/train_{lang}_data.txt')
                           for key, lang in self.langs.items()}
        self.data_valid = {key: self._read_split(f'valid/validgold_{lang}_data.txt')
                           for key, lang in self.langs.items()}

        self.data_train_all = pd.concat([self.data_train[key] for key in self.langs]).reset_index(drop=True)
        self.data_valid_all = pd.concat([self.data_valid[key] for key in self.langs]).reset_index(drop=True)

    def tokenizer_xlm(self, data):
        """Tokenize every (Word1, Word2) pair of ``data`` as a sentence pair.

        Args:
            data: DataFrame with columns Word1, Word2 and integer Label.

        Returns:
            (input_ids, attention_mask, labels): the first two are 2-D tensors of
            shape (len(data), max_len); labels is a 1-D tensor.
        """
        # One batched call; with padding='max_length' every row has the same length,
        # so this equals tokenizing the pairs one by one and concatenating.
        encoded = self.xlmr_tokenizer(data['Word1'].tolist(),
                                      data['Word2'].tolist(),
                                      max_length=self.max_len,
                                      padding='max_length',
                                      truncation=True,
                                      return_tensors='pt')
        labels_ = torch.tensor(data['Label'].tolist())

        print('Encoder finished. {:,} examples.'.format(len(labels_)))
        return encoded['input_ids'], encoded['attention_mask'], labels_

    def _make_loader(self, tensors, shuffle: bool) -> DataLoader:
        """Wrap (input_ids, attention_mask, labels) in a DataLoader with random or sequential sampling."""
        dataset = TensorDataset(*tensors)
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        return DataLoader(dataset, sampler=sampler, batch_size=self.batch_size)

    def prep_data(self):
        """Tokenize the loaded data and build DataLoaders.

        Returns:
            (train_dataloader, train_dataloader_all, validation_dataloader,
            validation_dataloader_all): the per-language loaders are dicts keyed by
            language code. Training loaders sample randomly; validation loaders
            sample sequentially.
        """
        train_dataloader = dict()
        validation_dataloader = dict()
        train_parts = []
        valid_parts = []

        for key in self.langs:
            print(key)
            train_tensors = self.tokenizer_xlm(self.data_train[key])
            valid_tensors = self.tokenizer_xlm(self.data_valid[key])
            train_parts.append(train_tensors)
            valid_parts.append(valid_tensors)
            train_dataloader[key] = self._make_loader(train_tensors, shuffle=True)
            validation_dataloader[key] = self._make_loader(valid_tensors, shuffle=False)

        # data_*_all are the per-language frames concatenated in self.langs order, so
        # concatenating the per-language tensors in that order gives the same result
        # without tokenizing everything a second time.
        print("All")
        train_all = [torch.cat(parts, dim=0) for parts in zip(*train_parts)]
        valid_all = [torch.cat(parts, dim=0) for parts in zip(*valid_parts)]
        train_dataloader_all = self._make_loader(train_all, shuffle=True)
        validation_dataloader_all = self._make_loader(valid_all, shuffle=False)

        return train_dataloader, train_dataloader_all, validation_dataloader, validation_dataloader_all

    def prep_model(self, train_dataloader: DataLoader):
        """Create a linear LR schedule (no warmup) spanning ``epochs`` passes over ``train_dataloader``."""
        print("Training Samples:", len(train_dataloader.dataset))

        # number of batches x epochs
        total_steps = len(train_dataloader) * self.epochs
        print("total steps:", total_steps)

        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=0,
                                                         num_training_steps=total_steps)

    def validate(self, validation_dataloader: DataLoader, verbose: bool):
        """Evaluate the model on one validation DataLoader.

        Args:
            validation_dataloader: yields (input_ids, attention_mask, labels) batches.
            verbose: also print a classification report and a confusion matrix.

        Returns:
            (average loss per batch, accuracy, macro F1, weighted F1 restricted to
            the non-RANDOM labels).
        """
        self.xlmr_model.eval()

        total_eval_loss = 0.0
        predictions, true_labels = [], []

        with torch.no_grad():
            for b_input_ids, b_input_mask, b_labels in validation_dataloader:
                b_labels = b_labels.to(self.device)
                output = self.xlmr_model(b_input_ids.to(self.device),
                                         token_type_ids=None,
                                         attention_mask=b_input_mask.to(self.device),
                                         labels=b_labels)
                total_eval_loss += output.loss.item()
                predictions.append(output.logits.cpu().numpy())
                true_labels.append(b_labels.cpu().numpy())

        flat_predictions = np.concatenate(predictions, axis=0)
        flat_true_labels = np.concatenate(true_labels, axis=0)
        predicted_labels = np.argmax(flat_predictions, axis=1)

        # Derive label ids from the encoder instead of hard-coding them.
        label_names = list(self.le.classes_)
        all_label_ids = list(range(len(label_names)))
        no_random_ids = [idx for idx, name in enumerate(label_names) if name != "RANDOM"]

        if verbose:
            # Passing labels= keeps target_names aligned even if a class is absent from this split.
            print(classification_report(flat_true_labels, predicted_labels,
                                        labels=all_label_ids, target_names=label_names))

        val_accuracy = (predicted_labels == flat_true_labels).mean()
        macroF1 = f1_score(flat_true_labels, predicted_labels, average='macro')
        weightedF1_no_random = f1_score(flat_true_labels, predicted_labels,
                                        average='weighted', labels=no_random_ids)
        print("\t Weighted F1 (no random):", weightedF1_no_random)

        avg_val_loss = total_eval_loss / len(validation_dataloader)

        if verbose:
            print(confusion_matrix(flat_true_labels, predicted_labels, labels=all_label_ids))

        return avg_val_loss, val_accuracy, macroF1, weightedF1_no_random

    def train_model(self, train_dataloader: DataLoader, validation_dataloader_set: list, verbose: bool):
        """Fine-tune for ``self.epochs`` epochs, validating after each epoch.

        Args:
            train_dataloader: training batches of (input_ids, attention_mask, labels).
            validation_dataloader_set: ``[combined_loader, {lang_code: loader}]``.
            verbose: forwarded to ``validate``.

        Returns:
            A list with one dict of statistics per epoch, including the weighted
            F1 (no RANDOM) for every language.
        """
        self._set_seed()

        training_stats = []
        total_t0 = time.time()

        print('\033[1m' + "================ Model Training ================" + '\033[0m')

        for epoch_i in range(self.epochs):
            print("")
            print('\033[1m' + '======== Epoch {:} / {:} ========'.format(epoch_i + 1, self.epochs) + '\033[0m')

            t0 = time.time()
            total_train_loss = 0.0

            # enable dropout etc.; validate() switches the model back to eval mode
            self.xlmr_model.train()

            for b_input_ids, b_input_mask, b_labels in tqdm(train_dataloader, desc="Batch"):
                self.xlmr_model.zero_grad()

                output = self.xlmr_model(b_input_ids.to(self.device),
                                         token_type_ids=None,
                                         attention_mask=b_input_mask.to(self.device),
                                         labels=b_labels.to(self.device))
                loss = output.loss
                total_train_loss += loss.item()

                loss.backward()
                # clip the global gradient norm to at most 1.0
                torch.nn.utils.clip_grad_norm_(self.xlmr_model.parameters(), 1.0)
                self.optimizer.step()
                self.scheduler.step()

            avg_train_loss = total_train_loss / len(train_dataloader)
            training_time = format_time(time.time() - t0)

            print("")
            print("  Average training loss: {0:.2f}".format(avg_train_loss))
            print("  Training epoch took: {:}".format(training_time))

            # VALIDATION: combined set first, then each language separately.
            print("evaluate on all")
            avg_val_loss_all, val_accuracy_all, macroF1_all, weightedF1_no_random_all = \
                self.validate(validation_dataloader_set[0], verbose)

            weightedF1_no_random = dict()
            for key, lang in self.langs.items():
                print(f"evaluate on {lang}")
                weightedF1_no_random[key] = self.validate(validation_dataloader_set[1][key], verbose)[3]

            print('\033[1m' + "  Validation Loss All: {0:.2f}".format(avg_val_loss_all) + '\033[0m')

            # One row per epoch (previously one partial row per language).
            epoch_stats = {
                'epoch': epoch_i + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss all': avg_val_loss_all,
                'Valid. Accur. all': val_accuracy_all,
                'Weigh_F1 all (no random)': weightedF1_no_random_all,
                'Macro F1 all': macroF1_all,
            }
            for key in self.langs:
                epoch_stats[f'Weigh_F1 {key}'] = weightedF1_no_random[key]
            epoch_stats['Training Time'] = training_time
            training_stats.append(epoch_stats)

        print("\n\nTraining complete!")
        print("Total training took {:} (h:mm:ss)".format(format_time(time.time() - total_t0)))

        return training_stats

    def train(self):
        """Fine-tune on the combined training data, validating on combined and per-language sets.

        Returns:
            The per-epoch statistics produced by ``train_model``.
        """
        _, train_dataloader_all, validation_dataloader, validation_dataloader_all = self.prep_data()
        self.prep_model(train_dataloader_all)
        validation_dataloader_set = [validation_dataloader_all, validation_dataloader]
        return self.train_model(train_dataloader=train_dataloader_all,
                                validation_dataloader_set=validation_dataloader_set,
                                verbose=True)

    def _mean_pooled_embedding(self, word: str) -> torch.Tensor:
        """Embed one word: mask-weighted mean of the last hidden layer, returned as a 1-D CPU tensor."""
        encoded = self.xlmr_tokenizer(word,
                                      max_length=self.max_len,
                                      padding='max_length',
                                      truncation=True,
                                      return_tensors='pt')
        input_ids = encoded['input_ids'].to(self.device)
        attention_mask = encoded['attention_mask'].to(self.device)

        outputs = self.xlmr_model(input_ids,
                                  token_type_ids=None,
                                  attention_mask=attention_mask,
                                  output_hidden_states=True)
        hidden = outputs.hidden_states[-1]

        # average only over real (non-padding) tokens
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        summed_emb = torch.sum(hidden * mask, 1)
        summed_mask = torch.clamp(mask.sum(1), min=1e-9)
        return (summed_emb / summed_mask).cpu()[0]

    def get_embeddings(self, test_data: List[dict]) -> torch.Tensor:
        """Embed both words of every pair.

        Args:
            test_data: non-empty list of dicts with keys "Word1" and "Word2".

        Returns:
            Tensor of shape (len(test_data), 2, hidden_size): the mean-pooled
            last-layer embeddings of Word1 and Word2.
        """
        self.xlmr_model.eval()
        embeddings = []

        # Inference only: no_grad avoids building and holding autograd graphs.
        with torch.no_grad():
            for row in test_data:
                word_embs = [self._mean_pooled_embedding(row[column]) for column in ("Word1", "Word2")]
                embeddings.append(torch.stack(word_embs))

        return torch.stack(embeddings)

    def get_representations(self, data: pd.DataFrame) -> torch.Tensor:
        """Embed the pairs in ``data`` (columns ``word_1``, ``word_2``); see ``get_embeddings``."""
        word_pairs = [{"Word1": word_1, "Word2": word_2}
                      for word_1, word_2 in zip(data['word_1'], data['word_2'])]
        return self.get_embeddings(word_pairs)

    def annotate_test_data(self, test_data, output_file_name):
        """Annotate an already-parsed test DataFrame with predicted labels.

        Not implemented yet: currently a no-op that returns None.
        TODO: classify each pair and write ``word1\\tword2\\tpredicted_label`` lines
        to ``output_file_name``.
        """
        return None
