import joblib
import numpy as np
from IPython import embed
import collections
import os
from sklearn.metrics import (
    precision_recall_fscore_support,
    accuracy_score,
    classification_report,
)
import torchmetrics
import datasets
from sklearn.model_selection import train_test_split
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight
import torch
import pandas as pd
from tqdm import tqdm
from preprocess import CustomNonBinaryClassDataset
import json

from models import ProtoTEx
from models_electra import ProtoTEx_Electra
import sys

sys.path.append("../datasets")
import configs


class dce_loss(torch.nn.Module):
    """Distance-based classification head with one learnable center per class.

    Centers are stored column-wise in a ``(feat_dim, n_classes)`` parameter.
    ``forward`` scores each input row by the negative squared Euclidean
    distance to every center.

    Args:
        n_classes: Number of classes (columns of ``centers``).
        feat_dim: Feature dimension (rows of ``centers``).
        init_weight: If True, re-initialise ``centers`` with Kaiming-normal.
        device: Device the centers are created on. Defaults to ``"cuda"`` to
            keep the historical behaviour; pass ``"cpu"`` to use it without a GPU.
    """

    def __init__(self, n_classes, feat_dim, init_weight=True, device="cuda"):
        super(dce_loss, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        # Sample on the CPU generator and then move, exactly like the original
        # ``torch.randn(...).cuda()``, so un-reinitialised values are unchanged.
        self.centers = nn.Parameter(
            torch.randn(self.feat_dim, self.n_classes).to(device), requires_grad=True
        )
        if init_weight:
            self.__init_weight()

    def __init_weight(self):
        nn.init.kaiming_normal_(self.centers)

    def forward(self, x):
        """Return ``(centers, -squared_distance)``.

        Args:
            x: Input features; assumed 2-D ``(batch, feat_dim)``. The arithmetic
                below only lines up for 2-D input.

        Returns:
            The ``centers`` parameter and a ``(batch, n_classes)`` tensor of
            negative squared Euclidean distances. Larger means closer.
        """
        # ||x||^2 + ||c||^2 - 2 x.c == ||x - c||^2, computed for all pairs at once.
        features_square = torch.sum(torch.pow(x, 2), 1, keepdim=True)
        centers_square = torch.sum(torch.pow(self.centers, 2), 0, keepdim=True)
        features_into_centers = 2 * torch.matmul(x, (self.centers))
        dist = features_square + centers_square - features_into_centers

        return self.centers, -dist


def regularization(features, centers, labels):
    """Average squared L2 distance between each feature row and its class center.

    Args:
        features: Tensor with one row per sample (assumed ``(batch, feat_dim)``).
        centers: Tensor laid out ``(feat_dim, n_classes)``; its transpose is
            indexed by ``labels`` to pick each sample's center.
        labels: Integer class index for every sample.

    Returns:
        A ``(1, 1)`` tensor: the summed squared distances divided by the number
        of rows in ``features``.
    """
    matched_centers = centers.t()[labels]
    per_sample = (features - matched_centers).pow(2).sum(dim=1, keepdim=True)
    total = per_sample.sum(dim=0, keepdim=True)
    return total / features.shape[0]


# Compute class weights
def get_class_weights(train_labels):
    """Compute sklearn "balanced" class weights for ``train_labels``.

    The weights are ordered like ``np.unique(train_labels)``. They are printed
    and then returned.
    """
    present_classes = np.unique(train_labels)
    weights = compute_class_weight("balanced", classes=present_classes, y=train_labels)
    print(f"Class weight vectors: {weights}")
    return weights


def load_adv_data(dataset_info, data_dir, tokenizer, max_length=None):
    """Load every ``adv*`` CSV in ``data_dir`` as a tokenized dataset.

    The original version passed ``dataset_info`` where
    ``load_classification_dataset`` expects a DataFrame, so it could never
    succeed. The call now matches that function's ``(df, tokenizer, max_length)``
    signature.

    Args:
        dataset_info: Dataset description. It is only consulted for its
            ``max_length`` attribute when ``max_length`` is not given.
            NOTE(review): assumes such an attribute exists — confirm with callers.
        data_dir: Directory scanned for files whose name starts with ``"adv"``.
        tokenizer: Forwarded to ``load_classification_dataset``.
        max_length: Tokenizer max length; defaults to ``dataset_info.max_length``.

    Returns:
        Dict mapping each adversarial file name (including extension) to its
        tokenized dataset, in sorted file-name order.
    """
    if max_length is None:
        max_length = dataset_info.max_length
    # Sorted so the result order does not depend on filesystem listing order.
    adv_files = sorted(name for name in os.listdir(data_dir) if name.startswith("adv"))
    return {
        file_name: load_classification_dataset(
            pd.read_csv(os.path.join(data_dir, file_name)), tokenizer, max_length
        )
        for file_name in adv_files
    }


def load_one_dataset(data_dir, tokenizer, max_length, test_file, split_training_data):
    """Load a single CSV from ``data_dir`` as tokenized evaluation data.

    Args:
        data_dir: Directory containing the CSV files.
        tokenizer: Forwarded to ``load_classification_dataset``.
        max_length: Tokenizer max length.
        test_file: File name inside ``data_dir``. ``"train.csv"`` is special:
            it is reduced to at most 100 rows per label (seeded) and shuffled.
        split_training_data: If True, carve a test split out of the loaded file
            with two successive 90/10 ``train_test_split`` calls (seed 42) and
            return only that split.

    Returns:
        Dict mapping a split name (``"train"``, ``"test"``, or ``test_file`` up to
        its first ``"."``) to a tokenized dataset. Frames without both ``text``
        and ``label`` columns are dropped.
    """
    if test_file == "train.csv":
        train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

        indices_to_pick = []
        for label in train_df["label"].unique():
            sub_df = train_df[train_df["label"] == label]
            # Same seeded per-label sample as load_dataset, truncated to 100 rows.
            sub_df_sample = sub_df.sample(
                n=min(10000, sub_df.shape[0]), random_state=42, replace=False
            )
            indices_to_pick.extend(sub_df_sample.index.tolist()[:100])
        train_df = train_df.loc[indices_to_pick]
        # Seeded so the row order (and therefore prediction indices) is reproducible.
        train_df = train_df.sample(frac=1, replace=False, random_state=42).reset_index(
            drop=True
        )
        test_dfs = {"train": train_df}
    else:
        # partition() keeps the full name when there is no "." (find() would
        # return -1 and silently chop off the last character).
        split_name = test_file.partition(".")[0]
        test_dfs = {split_name: pd.read_csv(os.path.join(data_dir, test_file))}

    if split_training_data:
        print("Splitting the given one dataset!")
        df = next(iter(test_dfs.values()))
        # Both splits are needed to reach the same held-out rows; presumably this
        # mirrors the split used at training time — confirm against the trainer.
        train_texts, _val_texts, train_labels, _val_labels = train_test_split(
            df["text"].tolist(), df["label"].tolist(), test_size=0.1, random_state=42
        )
        _, test_texts, _, test_labels = train_test_split(
            train_texts, train_labels, test_size=0.1, random_state=42
        )
        test_dfs = {
            "test": pd.DataFrame(
                {"text": test_texts, "label": test_labels, "split": "test"}
            )
        }

    return {
        file_name: load_classification_dataset(df, tokenizer, max_length)
        for file_name, df in test_dfs.items()
        if "text" in df.columns and "label" in df.columns
    }


def load_only_test_data(data_dir, tokenizer, max_length):
    """Tokenize every ``test*`` CSV found in ``data_dir``.

    Each file is keyed by its name up to the first ``"."``. Files lacking a
    ``text`` or ``label`` column are skipped.

    Returns:
        Dict mapping that key to the tokenized dataset.
    """
    paths = {}
    for entry in os.listdir(data_dir):
        if entry.startswith("test"):
            paths[entry[: entry.find(".")]] = os.path.join(data_dir, entry)

    # Read every file before tokenizing any, as before.
    frames = {}
    for key, path in paths.items():
        frames[key] = pd.read_csv(path)

    tokenized = {}
    for key, frame in frames.items():
        if "text" in frame.columns and "label" in frame.columns:
            tokenized[key] = load_classification_dataset(frame, tokenizer, max_length)
    return tokenized


def load_dataset(data_dir, tokenizer, max_length):
    """Load the training split plus every evaluation CSV in ``data_dir``.

    ``train.csv`` is reduced to at most 10,000 rows per label (seed 42) and then
    shuffled, also with seed 42, so the row order is reproducible. Every file
    whose name starts with ``test``, ``adv`` or ``val`` and that has both
    ``text`` and ``label`` columns is loaded as well.

    Args:
        data_dir: Directory containing ``train.csv`` and the evaluation CSVs.
        tokenizer: Forwarded to ``load_classification_dataset``.
        max_length: Tokenizer max length.

    Returns:
        Dict with key ``"train"`` plus one key per evaluation file (its name up
        to the first ``"."``), each mapped to a tokenized dataset.
    """
    train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

    indices_to_pick = []
    for label in train_df["label"].unique():
        sub_df = train_df[train_df["label"] == label]
        sub_df_sample = sub_df.sample(
            n=min(10000, sub_df.shape[0]), random_state=42, replace=False
        )
        indices_to_pick.extend(sub_df_sample.index.tolist())
    train_df = train_df.loc[indices_to_pick]
    # Seeded like the per-label sampling above so runs are reproducible.
    train_df = train_df.sample(frac=1, replace=False, random_state=42).reset_index(
        drop=True
    )

    print("Train data shape: ", train_df.shape)

    eval_prefixes = ("test", "adv", "val")
    test_files = {
        # partition() keeps the whole name when there is no "."; find() would
        # return -1 and drop the last character.
        file.partition(".")[0]: os.path.join(data_dir, file)
        for file in os.listdir(data_dir)
        if file.startswith(eval_prefixes)
    }

    test_dfs = {
        file_name: pd.read_csv(file_path) for file_name, file_path in test_files.items()
    }
    test_dfs = {
        file_name: df
        for file_name, df in test_dfs.items()
        if "text" in df.columns and "label" in df.columns
    }

    return {
        "train": load_classification_dataset(train_df, tokenizer, max_length),
        **{
            file_name: load_classification_dataset(df, tokenizer, max_length)
            for file_name, df in test_dfs.items()
        },
    }


# def load_nli_dataset(dataset_info, df, tokenizer):
#     sentences1 = df["sentence1"].tolist()
#     sentences2 = df["sentence2"].tolist()
#     labels = df["label"].tolist()

#     sentences = (sentences1, sentences2)

#     dataset = CustomNonBinaryClassDataset(
#         sentences=sentences,
#         labels=labels,
#         tokenizer=tokenizer,
#         max_length=dataset_info.max_length,
#     )

#     return dataset


def preprocess_data(tokenizer, dataset, max_length):
    """Tokenize the ``"text"`` column of ``dataset`` with fixed-length padding.

    Args:
        tokenizer: Callable tokenizer accepting ``padding``, ``truncation`` and
            ``max_length`` keyword arguments (e.g. a Hugging Face tokenizer).
        dataset: Object with a ``map(fn, batched=True)`` method, such as a
            ``datasets.Dataset``.
        max_length: Every example is padded/truncated to this length.

    Returns:
        Whatever ``dataset.map`` returns (the tokenized dataset).

    Raises:
        Any exception raised by ``dataset.map`` or the tokenizer. These used to be
        swallowed by an interactive IPython shell followed by ``exit()``, which
        hangs or silently kills non-interactive runs.
    """

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    return dataset.map(tokenize_function, batched=True)


def load_classification_dataset(df, tokenizer, max_length):
    """Wrap a pandas DataFrame as a ``datasets.Dataset`` and tokenize its text."""
    hf_dataset = datasets.Dataset.from_pandas(df)
    return preprocess_data(tokenizer, hf_dataset, max_length)


def print_predictions(file, predictions, labels):
    """Write predictions and gold labels to the CSV at ``file``.

    Columns are ``index`` (0..n-1), ``predictions`` and ``labels``. The
    DataFrame index itself is not written.
    """
    columns = {
        "index": range(len(predictions)),
        "predictions": predictions,
        "labels": labels,
    }
    pd.DataFrame(columns).to_csv(file, index=False)


def print_logs(
    file, info, epoch, val_loss, mac_val_prec, mac_val_rec, mac_val_f1, accuracy
):
    """Print one block of epoch metrics and optionally append it to ``file``.

    Args:
        file: Path to append the log lines to, or None to only print.
        info: Prefix for every line (e.g. a split name).
        epoch: Epoch number, rendered with ``str``.
        val_loss: Loss value, formatted with 4 decimals.
        mac_val_prec, mac_val_rec, mac_val_f1, accuracy: Metric values,
            rendered with ``str``.
    """
    prefix = info + " epoch"
    logs = [" ".join((prefix, str(epoch), "Total loss %.4f" % (val_loss), "\n"))]
    for name, value in (
        ("Prec", mac_val_prec),
        ("Recall", mac_val_rec),
        ("F1", mac_val_f1),
        ("Accuracy", accuracy),
    ):
        logs.append(" ".join((prefix, str(epoch), name, str(value), "\n")))
    for line in logs:
        print(line)
    print()
    logs.append("\n")
    if file is not None:
        # Context manager guarantees the handle is closed even if the write fails.
        with open(file, "a") as f:
            f.writelines(logs)


class EarlyStopping(object):
    """Track a validation score, checkpoint on improvement, and flag when to stop.

    Higher scores are better. Calls are ignored until ``activate`` has been
    called with a positive value.
    """

    def __init__(
        self,
        score_at_min1=0,
        patience=100,
        verbose=False,
        delta=0,
        path="checkpoint.pt",
        trace_func=print,
        save_epochwise=False,
    ):
        """
        Args:
            score_at_min1 (float): Initial best score. A new score must be at least
                            ``best_score + delta`` to count as an improvement.
                            Default: 0
            patience (int): Number of consecutive non-improving calls after which
                            ``early_stop`` is set to True.
                            Default: 100
            verbose (bool): If True, reports the score passed to every call.
                            Default: False
            delta (float): Minimum increase in the monitored score to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
            save_epochwise (bool): If True, every improvement is saved to its own file
                            ``path + "_" + times_improved + "_" + epoch`` instead of
                            overwriting ``path``.
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = score_at_min1
        self.early_stop = False
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        # NOTE(review): state_dict_list is never read in this class; kept in case
        # external code relies on the attribute.
        self.state_dict_list = [None] * patience
        self.improved = 0  # 1 if the most recent call improved the score, else 0
        # NOTE(review): written in __call__ but never read in this class.
        self.save_model_counter = 0
        self.save_epochwise = save_epochwise
        self.times_improved = 0
        self.activated = False

    def activate(self, s1):
        """Enable tracking the first time ``s1`` is positive; later calls are no-ops."""
        if not self.activated and s1 > 0:
            self.activated = True

    def __call__(self, score, epoch, model):
        """Record ``score`` for ``epoch``; checkpoint ``model`` if it improved.

        Returns None. Does nothing until ``activate`` has enabled tracking.
        """
        if not self.activated:
            return None
        self.save_model_counter = (self.save_model_counter + 1) % 4

        if self.verbose:
            self.trace_func(
                f"\033[91m The val score  of epoch {epoch} is {score:.4f} \033[0m"
            )
        if score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(
                f"\033[93m EarlyStopping counter: {self.counter} out of {self.patience} \033[0m"
            )
            if self.counter >= self.patience:
                self.early_stop = True
            self.improved = 0
        else:
            # Note: with delta == 0 a score equal to the best counts as an improvement.
            self.save_checkpoint(score, model, epoch)
            self.best_score = score
            self.counter = 0
            self.improved = 1

    def save_checkpoint(self, score, model, epoch):
        """Save ``model.state_dict()`` after the validation score improved."""
        # if self.verbose:
        self.times_improved += 1
        self.trace_func(
            f"\033[92m Validation score improved ({self.best_score:.4f} --> {score:.4f}). \033[0m"
        )
        if self.save_epochwise:
            path = self.path + "_" + str(self.times_improved) + "_" + str(epoch)
        else:
            path = self.path
        torch.save(model.state_dict(), path)


def evaluate(dl, model_new=None):
    """Run ``model_new`` over ``dl`` and compute loss and classification metrics.

    Args:
        dl: Sized iterable of batches with ``input_ids``, ``attention_mask`` and
            ``label`` tensors.
        model_new: Model called as ``model_new(input_ids, attn_mask, y,
            use_classfn=1)`` returning ``(logits, loss)``, where ``loss[0]`` is
            the scalar loss. Required.

    Returns:
        Tuple ``(total_loss, prec, recall, f1, accuracy, y_true, y_pred)``.
        ``total_loss`` is the per-example average. prec/recall/f1 use sklearn
        ``average="weighted"`` (despite the historical "mac_" names). ``y_true``
        and ``y_pred`` are 1-D numpy arrays.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    assert model_new is not None
    loader = tqdm(dl, total=len(dl), unit="batches")
    total_len = 0
    model_new.eval()
    model_new = model_new.to(device)
    with torch.no_grad():
        total_loss = 0
        y_pred = []
        y_true = []
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attn_mask = batch["attention_mask"].to(device)
            y = batch["label"].to(device)
            classfn_out, loss = model_new(input_ids, attn_mask, y, use_classfn=1)
            if classfn_out.ndim == 1:
                # Single-logit output: positive logit means class 1.
                predict = torch.zeros_like(y)
                predict[classfn_out > 0] = 1
            else:
                predict = torch.argmax(classfn_out, dim=1)

            y_pred.append(predict.cpu().numpy())
            y_true.append(y.cpu().numpy())
            # Weight each batch's mean loss by its size to get a per-example mean.
            total_loss += len(input_ids) * loss[0].item()
            total_len += len(input_ids)
        total_loss = total_loss / total_len

        # Concatenate once instead of re-building the full arrays for every metric.
        y_true_all = np.concatenate(y_true)
        y_pred_all = np.concatenate(y_pred)
        mac_prec, mac_recall, mac_f1_score, _ = precision_recall_fscore_support(
            y_true_all, y_pred_all, average="weighted"
        )
        accuracy = accuracy_score(y_true_all, y_pred_all)
        present_labels = np.unique(y_true_all)
        print(f"LABELS: {present_labels}")
        print(
            f"classification_report:\n{classification_report(y_true_all, y_pred_all, labels=present_labels, digits=3)}"
        )

    return (
        total_loss,
        mac_prec,
        mac_recall,
        mac_f1_score,
        accuracy,
        y_true_all,
        y_pred_all,
    )


def load_model(modelname, num_prototypes, architecture, dataset):
    """Build a ProtoTEx model and load weights from ``Models/<modelname>``.

    Checkpoint entries whose key is missing from the model, or whose shape
    differs, are skipped with a message. All other entries are loaded.

    Args:
        modelname: Checkpoint file name inside ``Models/``.
        num_prototypes: Number of prototypes for the new model.
        architecture: ``"BART"`` (ProtoTEx) or ``"ELECTRA"`` (ProtoTEx_Electra).
        dataset: Key into ``configs`` for the number of labels and max length.

    Returns:
        The model with the compatible checkpoint weights loaded.

    Raises:
        ValueError: if ``architecture`` is not supported. Previously this path
            crashed later with ``UnboundLocalError``.
    """
    print("ProtoTEx best model: {0}".format(num_prototypes))
    if architecture == "BART":
        print(f"Using backone: {architecture}")
        torch.cuda.empty_cache()
        model_cls = ProtoTEx
    elif architecture == "ELECTRA":
        model_cls = ProtoTEx_Electra
    else:
        print(f"Invalid backbone architecture: {architecture}")
        raise ValueError(f"Invalid backbone architecture: {architecture}")

    model = model_cls(
        num_prototypes=num_prototypes,
        class_weights=None,
        n_classes=configs.dataset_to_num_labels[dataset],
        max_length=configs.dataset_to_max_length[dataset],
        bias=False,
        special_classfn=True,
        p=1,  # p=0.75,
        batchnormlp1=True,
    )
    print(f"Loading model checkpoint: Models/{modelname}")
    # Load to CPU first so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies the values onto the model's own devices.
    pretrained_dict = torch.load(f"Models/{modelname}", map_location="cpu")
    # Filter out unnecessary or shape-mismatched keys.
    model_dict = model.state_dict()
    filtered_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            filtered_dict[k] = v
        else:
            print(f"Skipping weights for: {k}")
    model_dict.update(filtered_dict)
    model.load_state_dict(model_dict)
    return model


# Functions for analyzing prototypes


def get_best_k_protos_for_batch(
    dataloader,
    model_new=None,
    topk=None,
    do_all=False,
    architecture="BART",
):
    """
    Get the ``topk`` closest prototypes for the examples in ``dataloader``.

    "Closest" means the smallest Euclidean distance, or the largest cosine
    similarity when ``model_new.use_cosine_dist`` is set. The same rule is now
    used for both values of ``do_all``; previously the ``do_all=False`` path
    always took the smallest values, which under cosine similarity selected the
    least similar prototypes.

    Args:
        dataloader: Sized iterable of batches with ``input_ids``,
            ``attention_mask`` and ``label``.
        model_new: Trained ProtoTEx-style model (required).
        topk: Number of prototypes to report per example.
        do_all: If True, report prototypes for every example. Otherwise report
            them only for examples whose prediction matches the label.
        architecture: ``"BART"``, ``"BERT"`` or ``"ELECTRA"``; selects the encoder.

    Returns:
        Dict of Python lists: ``best_protos`` (prototype indices),
        ``best_protos_dists`` (their scores), ``all_predictions`` and
        ``all_correct_labels``. The last two always cover every example; with
        ``do_all=False`` the first two cover only correctly predicted ones.

    Raises:
        ValueError: if ``architecture`` is not recognised.
    """
    assert model_new is not None

    loader = tqdm(dataloader, total=len(dataloader), unit="batches")
    model_new.eval()
    with torch.no_grad():
        # Updated for negative prototypes
        all_protos = model_new.prototypes
        flat_protos = all_protos.view(model_new.num_protos, -1)
        # Cosine similarity: larger is closer. Euclidean distance: smaller is closer.
        closest_is_largest = model_new.use_cosine_dist

        best_protos = []
        best_protos_dists = []
        all_predictions = []
        all_correct_labels = []
        for batch in loader:
            input_ids = batch["input_ids"]
            attn_mask = batch["attention_mask"]
            y = batch["label"]
            batch_size = input_ids.size(0)
            if architecture == "BART":
                last_hidden_state = model_new.bart_model.base_model.encoder(
                    input_ids.cuda(),
                    attn_mask.cuda(),
                    output_attentions=False,
                    output_hidden_states=False,
                ).last_hidden_state
            elif architecture == "BERT":
                last_hidden_state = model_new.bert_model(
                    input_ids.cuda(), attn_mask.cuda()
                ).last_hidden_state
            elif architecture == "ELECTRA":
                last_hidden_state = model_new.electra_model(
                    input_ids.cuda(), attn_mask.cuda()
                ).last_hidden_state
            else:
                raise ValueError("Invalid architecture")

            flat_hidden = last_hidden_state.view(batch_size, -1)
            if model_new.use_cosine_dist:
                input_for_classfn = torchmetrics.functional.pairwise_cosine_similarity(
                    flat_hidden, flat_protos
                )
            else:
                input_for_classfn = torch.cdist(flat_hidden, flat_protos)
            if model_new.dobatchnorm:
                # Per-example normalisation across prototypes; it is monotonic,
                # so the top-k ordering is unaffected.
                input_for_classfn = torch.nn.functional.instance_norm(
                    input_for_classfn.view(batch_size, 1, model_new.num_protos)
                ).view(batch_size, model_new.num_protos)

            predicted = torch.argmax(
                model_new.classfn_model(input_for_classfn).view(
                    batch_size, model_new.n_classes
                ),
                dim=1,
            )
            if do_all:
                scores = input_for_classfn
            else:
                concerned_idxs = torch.nonzero(predicted == y.cuda()).view(-1)
                scores = input_for_classfn[concerned_idxs]
            temp = torch.topk(scores, dim=1, k=topk, largest=closest_is_largest)

            best_protos.append(temp[1].cpu())
            best_protos_dists.append(temp[0].cpu())
            all_predictions.append(predicted.cpu())
            all_correct_labels.append(y.cpu())

        best_protos = torch.cat(best_protos, dim=0).numpy().tolist()
        best_protos_dists = torch.cat(best_protos_dists, dim=0).numpy().tolist()
        all_predictions = torch.cat(all_predictions, dim=0).numpy().tolist()
        all_correct_labels = torch.cat(all_correct_labels, dim=0).numpy().tolist()
    return {
        "best_protos": best_protos,
        "best_protos_dists": best_protos_dists,
        "all_predictions": all_predictions,
        "all_correct_labels": all_correct_labels,
    }


def get_bestk_train_data_for_every_proto(
    train_dataset_loader, model_new=None, top_k=3, architecture="BART"
):
    """
    For every prototype, find the ``top_k`` most similar training examples.

    Only examples the model classifies correctly are considered. "Most similar"
    means the smallest Euclidean distance, or the largest cosine similarity when
    ``model_new.use_cosine_dist`` is set.

    Args:
        train_dataset_loader: Sized iterable of batches with ``input_ids``,
            ``attention_mask``, ``label`` and ``text`` entries.
        model_new: Trained ProtoTEx-style model.
        top_k: Maximum number of examples to report per prototype. It is capped
            at the number of correctly classified examples; previously
            ``torch.topk`` raised in that case.
        architecture: ``"BART"``, ``"BERT"`` or ``"ELECTRA"``; selects the encoder.

    Returns:
        ``{"best_train_egs": {proto_idx: [[text, label], ...]}}``.

    Raises:
        ValueError: if ``architecture`` is not recognised.
    """
    loader = tqdm(train_dataset_loader, total=len(train_dataset_loader), unit="batches")
    model_new.eval()
    with torch.no_grad():
        all_protos = model_new.prototypes
        flat_protos = all_protos.view(model_new.num_protos, -1)
        distance_chunks = []
        all_texts = []
        all_labels = []

        for batch in loader:
            input_ids = batch["input_ids"]
            attn_mask = batch["attention_mask"]
            y = batch["label"]
            text = batch["text"]
            batch_size = input_ids.size(0)

            if architecture == "BART":
                last_hidden_state = model_new.bart_model.base_model.encoder(
                    input_ids.cuda(),
                    attn_mask.cuda(),
                    output_attentions=False,
                    output_hidden_states=False,
                ).last_hidden_state
            elif architecture == "BERT":
                last_hidden_state = model_new.bert_model(
                    input_ids.cuda(), attn_mask.cuda()
                ).last_hidden_state
            elif architecture == "ELECTRA":
                last_hidden_state = model_new.electra_model(
                    input_ids.cuda(), attn_mask.cuda()
                ).last_hidden_state
            else:
                raise ValueError("Invalid architecture")

            flat_hidden = last_hidden_state.view(batch_size, -1)
            if model_new.use_cosine_dist:
                input_for_classfn = torchmetrics.functional.pairwise_cosine_similarity(
                    flat_hidden, flat_protos
                )
            else:
                input_for_classfn = torch.cdist(flat_hidden, flat_protos)
            if model_new.dobatchnorm:
                input_for_classfn = torch.nn.functional.instance_norm(
                    input_for_classfn.view(batch_size, 1, model_new.num_protos)
                ).view(batch_size, model_new.num_protos)

            predicted = torch.argmax(
                model_new.classfn_model(input_for_classfn).view(
                    batch_size, model_new.n_classes
                ),
                dim=1,
            )
            correct = torch.nonzero(predicted == y.cuda()).view(-1)
            distance_chunks.append(input_for_classfn[correct].cpu())
            # Plain ints are safe for indexing the text list and the CPU label
            # tensor, and avoid a device sync per element.
            correct_list = correct.tolist()
            all_texts.extend(text[i] for i in correct_list)
            all_labels.extend(y[i] for i in correct_list)

    if distance_chunks:
        all_distances = torch.cat(distance_chunks, dim=0)
    else:
        all_distances = torch.empty(0, model_new.num_protos)
    k = min(top_k, all_distances.size(0))
    best_distances = torch.topk(
        all_distances, dim=0, k=k, largest=model_new.use_cosine_dist
    )
    prototypes_texts = {}
    for proto_idx in range(model_new.num_protos):
        prototypes_texts[proto_idx] = [
            [all_texts[row], all_labels[row].item()]
            for row in best_distances[1][:, proto_idx].tolist()
        ]
    return {"best_train_egs": prototypes_texts}


def best_protos_for_test(
    test_dataloader, model_new=None, top_k=5, target_label=1, max_examples=15
):
    """Find the nearest prototypes for correctly predicted examples of one class.

    Only the first batch of ``test_dataloader`` is used. Distances are plain
    Euclidean (``torch.cdist``) on the BART encoder output.

    NOTE(review): unlike get_best_k_protos_for_batch, this ignores
    ``model_new.use_cosine_dist`` and ``model_new.dobatchnorm``. Confirm it is
    only used with models where both are off.

    Args:
        test_dataloader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``label``.
        model_new: Trained BART-based ProtoTEx model.
        top_k: Number of nearest prototypes per selected example.
        target_label: Class whose correctly predicted examples are kept.
            Default 1, matching the previously hard-coded value.
        max_examples: Maximum number of examples kept. Default 15, matching the
            previously hard-coded value.

    Returns:
        Tuple ``(input_ids of the selected examples, indices of their top_k
        nearest prototypes)``.
    """
    all_protos = model_new.prototypes
    batch = next(iter(test_dataloader))
    input_ids = batch["input_ids"]
    attn_mask = batch["attention_mask"]
    y = batch["label"]
    # Derive from the batch: the old hard-coded 60 broke .view() for other sizes.
    batch_size = input_ids.size(0)
    # Match the other analysis helpers, which evaluate in eval mode.
    model_new.eval()
    with torch.no_grad():
        last_hidden_state = model_new.bart_model.base_model.encoder(
            input_ids.cuda(),
            attn_mask.cuda(),
            output_attentions=False,
            output_hidden_states=False,
        ).last_hidden_state
        input_for_classfn = torch.cdist(
            last_hidden_state.view(batch_size, -1),
            all_protos.view(model_new.num_protos, -1),
        )
        predicted = torch.argmax(model_new.classfn_model(input_for_classfn), dim=1)
        # Compare on the same device; a CUDA vs CPU comparison raises.
        y_dev = y.to(predicted.device)
        proper_idxs_pos = (
            torch.nonzero(
                torch.logical_and(predicted == y_dev, y_dev == target_label)
            ).view(-1)
        )[:max_examples]

        pos_best_protos = torch.topk(
            input_for_classfn[proper_idxs_pos], dim=1, k=top_k, largest=False
        )[1]

    # input_ids lives on the CPU, so index it with CPU indices.
    return input_ids[proper_idxs_pos.cpu()], pos_best_protos
