from pathlib import Path
import json
from torch import load
import time
from omegaconf.omegaconf import OmegaConf
from requests.models import HTTPError

from tokenizers import SentencePieceBPETokenizer, Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import Lowercase
from tokenizers.processors import TemplateProcessing
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoModelForQuestionAnswering,
    AutoModelForImageClassification,
    AutoModelForAudioClassification,
    AutoFeatureExtractor,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    set_seed,
)

from adapters import ElectraAdapterModel, AutoAdapterModel
from adapters import BnConfig, SeqBnConfig, LoRAConfig, DoubleSeqBnConfig, UniPELTConfig

from torch.nn.utils import spectral_norm


def create_transformers_model_tokenizer(
    model_cfg,
    id2label: dict = None,
    seed: int = 42,
    cache_dir=None,
    embeddings=None,
    word2idx=None,
):
    """Build the model and its tokenizer (or feature extractor) described by ``model_cfg``.

    Two families of models are handled:

    * ``model_cfg.exists_in_repo`` is truthy: the model is loaded with
      ``from_pretrained`` using the Auto class selected by ``model_cfg.type``
      (or ``AutoAdapterModel`` when ``model_cfg.adapter`` is set), optionally
      with an adapter attached and activated for training.
    * otherwise: a local ResNet (``cv_cls``), a text CNN (``cls`` with name
      ``"cnn"``) or an FNet checkpoint read from ``model_cfg.path_to_pretrained``.

    Args:
        model_cfg: Config object supporting attribute access and ``.get``.
            Fields read here: ``name``, ``type``, ``num_labels``,
            ``classifier_dropout``, ``exists_in_repo``, ``tokenizer_max_length``,
            ``adapter``, ``adapter_type`` and, for models outside the repo,
            ``embeddings_path``, ``embeddings_cache_dir``, ``freeze_embedding``,
            ``vocab_size``, ``embed_dim``, ``filter_sizes``, ``num_filters``,
            ``path_to_pretrained``.
        id2label: Optional label mapping. Only its length is used: when given,
            it overrides ``model_cfg.num_labels``.
        seed: Seed passed to ``set_seed``.
        cache_dir: Optional base directory; the ``model`` and ``tokenizer``
            sub-directories are used as download caches.
        embeddings: Optional pretrained embeddings for the CNN model.
        word2idx: Optional vocabulary for the CNN word-level tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple. ``tokenizer`` is a feature extractor for
        repo models of type ``cv_cls`` / ``sp_cls`` and ``None`` for the local
        ``cv_cls`` model.

    Raises:
        ValueError: If ``model_cfg.type`` has no matching Auto class (and no
            adapter is requested), or if ``model_cfg.adapter_type`` is not one
            of the supported adapter types.
        AssertionError: If a model outside the repo is neither the local
            ``cv_cls`` model, the CNN classifier, nor an FNet ``cls``/``ner`` model.
    """
    set_seed(seed)
    pretrained_model_name = model_cfg.name
    num_labels = model_cfg.num_labels if id2label is None else len(id2label)
    classifier_dropout = model_cfg.classifier_dropout

    model_cache_dir = Path(cache_dir) / "model" if cache_dir is not None else None
    tokenizer_cache_dir = (
        Path(cache_dir) / "tokenizer" if cache_dir is not None else None
    )

    # Whether an adapter should be attached and trained.
    use_adapter = model_cfg.get("adapter", False)

    if model_cfg.exists_in_repo:
        model_mapping = {
            "cls": AutoModelForSequenceClassification,
            "ner": AutoModelForTokenClassification,
            "abs-sum": AutoModelForSeq2SeqLM,
            "qa": AutoModelForQuestionAnswering,
            "cv_cls": AutoModelForImageClassification,
            "sp_cls": AutoModelForAudioClassification,
        }
        # cv_cls / sp_cls are loaded with AutoFeatureExtractor instead of a
        # tokenizer and receive no classifier-dropout kwarg.
        is_feature_task = model_cfg.type in ("cv_cls", "sp_cls")

        if use_adapter:
            model_class = AutoAdapterModel
        else:
            model_class = model_mapping.get(model_cfg.type)
            if model_class is None:
                # Fail early with a clear message instead of an AttributeError
                # on ``None.from_pretrained`` further down.
                raise ValueError(
                    "Unsupported model type {!r}; expected one of {}".format(
                        model_cfg.type, sorted(model_mapping)
                    )
                )

        if is_feature_task:
            kwargs = {}
        else:
            kwargs = get_classifier_dropout_kwargs(
                pretrained_model_name, classifier_dropout
            )
        if is_feature_task or model_cfg.type == "cls":
            kwargs["ignore_mismatched_sizes"] = True
        if num_labels is not None:
            kwargs["num_labels"] = num_labels
        tokenizer_kwargs = get_tokenizer_kwargs(pretrained_model_name, model_cfg.type)

        # Kristofari sometimes returns a connection error, so an HTTPError
        # triggers a retry restricted to local files.
        # NOTE(review): id2label / label2id are not forwarded to from_pretrained
        # (they were commented out in the original); only num_labels is.
        model_load_kwargs = dict(kwargs, cache_dir=model_cache_dir)
        model = _from_pretrained_with_offline_fallback(
            model_class, pretrained_model_name, model_load_kwargs
        )

        # Specific use case for ELECTRA with adapter: add the "cls" head explicitly.
        if "ElectraAdapterModel" in model.__class__.__name__:
            model.add_classification_head("cls", num_labels=num_labels)

        if use_adapter:
            config = _make_adapter_config(model_cfg.get("adapter_type"))
            # The adapter is named after the task type.
            adapt_name = model_cfg.get("type")
            model.add_adapter(adapt_name, config=config)
            model.train_adapter(adapt_name)
            model.set_active_adapters(adapt_name)

        if "xlnet" in pretrained_model_name:
            model.config.use_mems_eval = False

        # Use the same loader class for the online attempt and the offline
        # retry, so cv_cls / sp_cls never fall back to AutoTokenizer.
        if is_feature_task:
            tokenizer_loader = AutoFeatureExtractor
            tokenizer_load_kwargs = {"cache_dir": tokenizer_cache_dir}
        else:
            tokenizer_loader = AutoTokenizer
            tokenizer_load_kwargs = dict(
                tokenizer_kwargs, cache_dir=tokenizer_cache_dir
            )
        tokenizer = _from_pretrained_with_offline_fallback(
            tokenizer_loader, pretrained_model_name, tokenizer_load_kwargs
        )
        assert tokenizer is not None, "Failed to load tokenizer"
        if model_cfg.tokenizer_max_length is not None:
            tokenizer.model_max_length = model_cfg.tokenizer_max_length
    else:
        # NOTE(review): resnet18, load_embeddings, TextClassificationCNN,
        # FNetForSequenceClassification and FNetForTokenClassification are not
        # imported in the visible part of this module - confirm they are in scope.
        # CV
        if model_cfg.type == "cv_cls":
            model = resnet18(mnist=True)
            tokenizer = None
            return model, tokenizer
        elif model_cfg.type == "cls" and model_cfg.name == "cnn":
            # Build CNN for text classification.
            if embeddings is None and model_cfg.embeddings_path is not None:
                embeddings, word2idx = load_embeddings(
                    model_cfg.embeddings_path, model_cfg.embeddings_cache_dir
                )
            model = TextClassificationCNN(
                pretrained_embedding=embeddings,
                freeze_embedding=model_cfg.freeze_embedding,
                vocab_size=model_cfg.vocab_size,
                embed_dim=model_cfg.embed_dim,
                filter_sizes=model_cfg.filter_sizes,
                num_filters=model_cfg.num_filters,
                num_classes=num_labels,
                dropout=model_cfg.classifier_dropout,
            )
            # Word-level tokenizer over the embedding vocabulary:
            # lowercase the text, then split on whitespace.
            tokenizer_model = WordLevel(word2idx, "[UNK]")
            tokenizer = Tokenizer(tokenizer_model)
            tokenizer.normalizer = Lowercase()
            tokenizer.pre_tokenizer = Whitespace()
            hf_tokenizer = PreTrainedTokenizerFast(
                tokenizer_object=tokenizer, pad_token="[PAD]", unk_token="[UNK]"
            )
            return model, hf_tokenizer

        # Implemented only for FNet
        assert model_cfg.name.startswith(
            "fnet"
        ), "Only FNet is supported among models out of HF repo!"
        assert model_cfg.type in [
            "cls",
            "ner",
        ], "Models not from HF repo are currently supported only for NER and classification tasks"

        path_to_pretrained = Path(model_cfg.path_to_pretrained)
        with open(path_to_pretrained / "config.json", encoding="utf-8") as f:
            pretrained_model_cfg = json.load(f)

        model_class = (
            FNetForSequenceClassification
            if model_cfg.type == "cls"
            else FNetForTokenClassification
        )
        model = model_class(pretrained_model_cfg, num_labels)
        # NOTE: torch.load unpickles the file, so path_to_pretrained must point
        # at a trusted checkpoint. strict=False tolerates missing/unexpected keys.
        model.load_state_dict(
            load(path_to_pretrained / "fnet.statedict.pt"), strict=False
        )

        orig_tokenizer = SentencePieceBPETokenizer.from_file(
            str(path_to_pretrained / "vocab.json"),
            str(path_to_pretrained / "merges.txt"),
        )
        orig_tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A [SEP] $B:1 </s>:1",
            special_tokens=[("<s>", 1), ("</s>", 2), ("[MASK]", 6), ("[SEP]", 5)],
        )
        tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=orig_tokenizer,
            model_max_length=model_cfg.tokenizer_max_length,
            bos_token="<s>",
            eos_token="</s>",
            pad_token="<pad>",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]",
        )

    return model, tokenizer


def _from_pretrained_with_offline_fallback(loader, pretrained_model_name, load_kwargs):
    """Call ``loader.from_pretrained``; on ``HTTPError`` retry with ``local_files_only=True``."""
    try:
        return loader.from_pretrained(pretrained_model_name, **load_kwargs)
    except HTTPError:
        return loader.from_pretrained(
            pretrained_model_name, local_files_only=True, **load_kwargs
        )


def _make_adapter_config(adapter_type):
    """Return the adapter config object for ``adapter_type``; raise ``ValueError`` if unknown."""
    if adapter_type == "lora":
        return LoRAConfig()
    if adapter_type == "unipelt":
        return UniPELTConfig()
    if adapter_type == "bottleneck":
        return BnConfig(
            mh_adapter=True,
            output_adapter=True,
            reduction_factor=16,
            non_linearity="relu",
        )
    if adapter_type == "pfeiffer":
        return SeqBnConfig()
    if adapter_type == "double_seq_bn":
        return DoubleSeqBnConfig()
    raise ValueError(
        "Unsupported adapter_type {!r}; expected one of 'lora', 'unipelt', "
        "'bottleneck', 'pfeiffer', 'double_seq_bn'".format(adapter_type)
    )


# TODO: add all the other models
def get_classifier_dropout_kwargs(
    pretrained_model_name: str, classifier_dropout: float
):
    """Return a one-item dict mapping the model's dropout kwarg name to ``classifier_dropout``.

    The kwarg name is chosen by substring match on ``pretrained_model_name``;
    names matching none of the known markers use ``"classifier_dropout"``.
    """
    # Checked in order: the first marker found in the model name wins.
    marker_to_key = (
        ("distilbert", "seq_classif_dropout"),
        ("deberta", "pooler_dropout"),
        ("xlnet", "summary_last_dropout"),
        ("distilrubert", "dropout"),
        ("rubert-base", "hidden_dropout_prob"),
        ("albert", "classifier_dropout_prob"),
    )
    for marker, dropout_key in marker_to_key:
        if marker in pretrained_model_name:
            return {dropout_key: classifier_dropout}
    return {"classifier_dropout": classifier_dropout}


def get_tokenizer_kwargs(pretrained_model_name: str, task: str = "cls"):
    """Return extra tokenizer kwargs for the given model name and task.

    Only names containing ``"roberta"`` on the ``"ner"`` task get
    ``add_prefix_space=True``; every other combination gets an empty dict.
    """
    # presumably needed because NER inputs arrive pre-split into words - TODO confirm
    if "roberta" not in pretrained_model_name or task != "ner":
        return {}
    return {"add_prefix_space": True}


# import numpy as np
# import logging
# import os
# import torch
# import random
# import torch.nn.functional as F
# from collections import defaultdict, Counter
# from torch import nn
# from torch.nn import CrossEntropyLoss
# from transformers import BertPreTrainedModel, BertModel
# from transformers.modeling_outputs import TokenClassifierOutput

# class BertForTokenClassification(BertPreTrainedModel): # modified the original huggingface BertForTokenClassification to incorporate gaussian
#     def __init__(self, config):
#         super().__init__(config)
#         self.num_labels = config.num_labels
#         self.embedding_dimension = 100

#         self.bert = BertModel(config)
#         self.dropout = nn.Dropout(config.hidden_dropout_prob)
#         self.projection = nn.Sequential(
#             nn.Linear(config.hidden_size, self.embedding_dimension + (config.hidden_size - self.embedding_dimension) // 2)
#         )

#         self.output_embedder_mu = nn.Sequential(
#             nn.ReLU(),
#             nn.Linear(config.hidden_size,
#                       self.embedding_dimension)
#         )

#         self.output_embedder_sigma = nn.Sequential(
#             nn.ReLU(),
#             nn.Linear(config.hidden_size,
#                       self.embedding_dimension)
#         )


#         self.init_weights()

#     def forward(
#             self,
#             input_ids=None,
#             attention_mask=None,
#             token_type_ids=None,
#             position_ids=None,
#             head_mask=None,
#             inputs_embeds=None,
#             labels=None,
#             loss_type="KL",
#             consider_mutual_O=True
#     ):
#         outputs = self.bert(
#             input_ids,
#             attention_mask=attention_mask,
#             token_type_ids=token_type_ids,
#             position_ids=position_ids,
#             head_mask=head_mask,
#             inputs_embeds=inputs_embeds,
#         )

#         sequence_output = self.dropout(outputs[0])
#         original_embedding_mu = ((self.output_embedder_mu((sequence_output))))
#         original_embedding_sigma = (F.elu(self.output_embedder_sigma((sequence_output)))) + 1 + 1e-14
#         outputs = (original_embedding_mu, original_embedding_sigma,) + (outputs[0],) + outputs[2:]

#         if labels is not None:
#             loss = calculate_KL_or_euclidean(self, attention_mask, original_embedding_mu,
#                                                      original_embedding_sigma, labels, consider_mutual_O,
#                                                      loss_type=loss_type)
#             outputs = (loss,) + outputs
#         # return outputs  # (loss), output_mus, output_sigmas, (hidden_states), (attentions)
    
#         return TokenClassifierOutput(
#             loss=loss,
#             logits=logits,
#             hidden_states=outputs[0],
#             attentions= outputs[2:],
#         )
    
    
# def nt_xent(loss, num, denom, temperature = 1):

#     loss = torch.exp(loss/temperature)
#     cnts = torch.sum(num, dim = 1)
#     loss_num = torch.sum(loss * num, dim = 1)
#     loss_denom = torch.sum(loss * denom, dim = 1)
#     # sanity check
#     nonzero_indexes = torch.where(cnts > 0)
#     loss_num, loss_denom, cnts = loss_num[nonzero_indexes], loss_denom[nonzero_indexes], cnts[nonzero_indexes]

#     loss_final = -torch.log2(loss_num) + torch.log2(loss_denom) + torch.log2(cnts)
#     return loss_final

# def loss_kl(mu_i, sigma_i, mu_j, sigma_j, embed_dimension):
#     '''
#     Calculates KL-divergence between two DIAGONAL Gaussians.
#     Reference: https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians.
#     Note: We calculated both directions of KL-divergence.
#     '''
#     sigma_ratio = sigma_j / sigma_i
#     trace_fac = torch.sum(sigma_ratio, 1)
#     log_det = torch.sum(torch.log(sigma_ratio + 1e-14), axis=1)
#     mu_diff_sq = torch.sum((mu_i - mu_j) ** 2 / sigma_i, axis=1)
#     ij_kl = 0.5 * (trace_fac + mu_diff_sq - embed_dimension - log_det)
#     sigma_ratio = sigma_i / sigma_j
#     trace_fac = torch.sum(sigma_ratio, 1)
#     log_det = torch.sum(torch.log(sigma_ratio + 1e-14), axis=1)
#     mu_diff_sq = torch.sum((mu_j - mu_i) ** 2 / sigma_j, axis=1)
#     ji_kl = 0.5 * (trace_fac + mu_diff_sq - embed_dimension - log_det)
#     kl_d = 0.5 * (ij_kl + ji_kl)
#     return kl_d

# def euclidean_distance(a, b, normalize=False):
#     if normalize:
#         a = F.normalize(a)
#         b = F.normalize(b)
#     logits = ((a - b) ** 2).sum(dim=1)
#     return logits


# def remove_irrelevant_tokens_for_loss(self, attention_mask, original_embedding_mu, original_embedding_sigma, labels):
#     active_indices = attention_mask.view(-1) == 1
#     active_indices = torch.where(active_indices == True)[0]

#     output_embedding_mu = original_embedding_mu.view(-1, self.embedding_dimension)[active_indices]
#     output_embedding_sigma = original_embedding_sigma.view(-1, self.embedding_dimension)[active_indices]
#     labels_straightened = labels.view(-1)[active_indices]

#     # remove indices with negative labels only

#     nonneg_indices = torch.where(labels_straightened >= 0)[0]
#     output_embedding_mu = output_embedding_mu[nonneg_indices]
#     output_embedding_sigma = output_embedding_sigma[nonneg_indices]
#     labels_straightened = labels_straightened[nonneg_indices]

#     return output_embedding_mu, output_embedding_sigma, labels_straightened


# def calculate_KL_or_euclidean(self, attention_mask, original_embedding_mu, original_embedding_sigma, labels,
#                               consider_mutual_O=True, loss_type="euclidean"):

#     # we will create embedding pairs in following manner
#     # filtered_embedding | embedding ||| filtered_labels | labels
#     # repeat_interleave |            ||| repeat_interleave |
#     #                   | repeat     |||                   | repeat
#     # extract only active parts that does not contain any paddings

#     output_embedding_mu, output_embedding_sigma, labels_straightened = remove_irrelevant_tokens_for_loss(self, attention_mask,original_embedding_mu, original_embedding_sigma, labels)

#     # remove indices with zero labels, that is "O" classes
#     if not consider_mutual_O:
#         filter_indices = torch.where(labels_straightened > 0)[0]
#         filtered_embedding_mu = output_embedding_mu[filter_indices]
#         filtered_embedding_sigma = output_embedding_sigma[filter_indices]
#         filtered_labels = labels_straightened[filter_indices]
#     else:
#         filtered_embedding_mu = output_embedding_mu
#         filtered_embedding_sigma = output_embedding_sigma
#         filtered_labels = labels_straightened

#     filtered_instances_nos = len(filtered_labels)

#     # repeat interleave
#     filtered_embedding_mu = torch.repeat_interleave(filtered_embedding_mu, len(output_embedding_mu), dim=0)
#     filtered_embedding_sigma = torch.repeat_interleave(filtered_embedding_sigma, len(output_embedding_sigma),dim=0)
#     filtered_labels = torch.repeat_interleave(filtered_labels, len(output_embedding_mu), dim=0)

#     # only repeat
#     repeated_output_embeddings_mu = output_embedding_mu.repeat(filtered_instances_nos, 1)
#     repeated_output_embeddings_sigma = output_embedding_sigma.repeat(filtered_instances_nos, 1)
#     repeated_labels = labels_straightened.repeat(filtered_instances_nos)

#     # avoid losses with own self
#     loss_mask = torch.all(filtered_embedding_mu != repeated_output_embeddings_mu, dim=-1).int()
#     loss_weights = (filtered_labels == repeated_labels).int()
#     loss_weights = loss_weights * loss_mask

#     #ensure that the vector sizes are of filtered_instances_nos * filtered_instances_nos
#     # assert len(repeated_labels) == (filtered_instances_nos * filtered_instances_nos), "dimension is not of square shape."

#     if loss_type == "euclidean":
#         loss = -euclidean_distance(filtered_embedding_mu, repeated_output_embeddings_mu, normalize=True)

#     elif loss_type == "KL":  # KL_divergence
#         loss = -loss_kl(filtered_embedding_mu, filtered_embedding_sigma,
#                             repeated_output_embeddings_mu, repeated_output_embeddings_sigma,
#                             embed_dimension=self.embedding_dimension)

#     else:
#         raise Exception("unknown loss")

#     # reshape the loss, loss_weight, and loss_mask
#     loss = loss.view(filtered_instances_nos, filtered_instances_nos)
#     loss_mask = loss_mask.view(filtered_instances_nos, filtered_instances_nos)
#     loss_weights = loss_weights.view(filtered_instances_nos, filtered_instances_nos)

#     loss_final = nt_xent(loss, loss_weights, loss_mask, temperature = 1)
#     return torch.mean(loss_final)