import torch
from torch import nn
import torchvision.models as models
from transformers import *
from torch.nn import CrossEntropyLoss
from .frozen_batch_norm import FrozenBatchNorm2d
# from torch.nn import CrossEntropyLoss
from vlpretrain.loss import paired_hinge_rank_loss2, binary_classification_loss, binary_classification_loss_with_neg, contrastive_loss, paired_hinge_rank_loss3




def get_visn_arch(arch):
    """Look up a model constructor in ``torchvision.models`` by name.

    :param arch: attribute name in ``torchvision.models`` (e.g. ``'resnet50'``).
    :return: the callable found under that name.
    :raises ValueError: if ``torchvision.models`` has no attribute ``arch``.
    """
    try:
        return getattr(models, arch)
    except AttributeError as e:
        # Fail here with a clear message instead of returning None, which would
        # only surface later as "'NoneType' object is not callable" in the caller.
        raise ValueError("There is no arch %s in torchvision." % arch) from e


class Similarity(nn.Module):
    """Temperature-scaled cosine similarity.

    ``forward(x, y)`` returns ``cos(x, y) / temp``, where the cosine is taken
    along the last dimension of the inputs.
    """

    def __init__(self, temp):
        super().__init__()
        # Parameter-free cosine similarity over the last axis.
        self.cos = nn.CosineSimilarity(dim=-1)
        self.temp = temp

    def forward(self, x, y):
        """Return the cosine similarity of ``x`` and ``y`` divided by ``self.temp``."""
        cosine = self.cos(x, y)
        scaled = cosine / self.temp
        return scaled

class SimpleBertForMaskedLM(BertForMaskedLM):
    """BERT masked-LM model whose ``forward`` returns only the token-level MLM loss.

    Unlike the stock ``BertForMaskedLM.forward``, this returns a single loss
    tensor (not a tuple), and label positions equal to ``-1`` are ignored by
    the cross-entropy loss.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            masked_lm_labels=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            lm_labels=None,
    ):
        """Compute the masked-LM cross-entropy loss.

        :param masked_lm_labels: target token ids, one per token position of the
            input; positions set to ``-1`` are ignored. Required.
        :param lm_labels: accepted for signature compatibility; unused.
        :return: the cross-entropy loss averaged over non-ignored positions.
        :raises ValueError: if ``masked_lm_labels`` is ``None``.
        """
        # Validate before running the encoder: without labels there is nothing
        # to return, and the failure would otherwise be an opaque AttributeError.
        if masked_lm_labels is None:
            raise ValueError("SimpleBertForMaskedLM.forward requires masked_lm_labels.")

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        sequence_output = outputs[0]

        prediction_scores = self.cls(sequence_output)
        # ignore_index=-1 (not the HuggingFace default -100): the label tensors
        # fed to this class mark unmasked positions with -1.
        loss_fct = CrossEntropyLoss(ignore_index=-1)
        token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))

        return token_loss

# Supported text backbones: key -> (model class, tokenizer class, pretrained weight name).
# NOTE(review): LangModel reaches the encoder through ``model.bert``, so as written
# only the BertForMaskedLM entries ('bert', 'bert-large') look usable with it.
# Earlier variants mapped the BERT keys to SimpleBertForMaskedLM (plus BertConfig).
LANG_MODELS = {
    'bert': (BertForMaskedLM, BertTokenizer, 'bert-base-uncased'),
    'bert-large': (BertForMaskedLM, BertTokenizer, 'bert-large-uncased'),
    'gpt': (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
    'gpt2': (GPT2Model, GPT2Tokenizer, 'gpt2'),
    'ctrl': (CTRLModel, CTRLTokenizer, 'ctrl'),
    'xl': (TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
    'xlnet': (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
    'xlm': (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
    'distil': (DistilBertModel, DistilBertTokenizer, 'distilbert-base-cased'),
    'roberta': (RobertaModel, RobertaTokenizer, 'roberta-base'),
    'roberta-large': (RobertaModel, RobertaTokenizer, 'roberta-large'),
    'xlm-roberta': (XLMRobertaModel, XLMRobertaTokenizer, 'xlm-roberta-base'),
}



class VisnModel(nn.Module):
    """Image encoder producing one ``dim``-wide vector per input image.

    Three modes, chosen in ``__init__``:

    * ``bertonly=True``: a fixed, randomly initialised ``[1, dim]`` tensor is
      repeated for every image (it does not require gradients).
    * ``arch == 'vector'``: a learnable ``nn.Embedding(1, dim)`` row is
      repeated for every image.
    * otherwise: the torchvision model ``arch`` with its ``fc`` head replaced
      by identity, BatchNorm converted to ``FrozenBatchNorm2d``,
      ``conv1``/``layer1`` frozen, followed by a linear projection to ``dim``.
    """

    def __init__(self, args, dim, arch='resnet50', pretrained=True, finetuning=False, bertonly=False, normalize=False):
        """
        :param args: experiment arguments; stored as ``self.args``
        :param dim: dimension of the output
        :param arch: torchvision architecture name, or ``'vector'`` for a learned vector
        :param pretrained: load pre-trained backbone weights (CNN mode only)
        :param finetuning: finetune the CNN backbone; otherwise its parameters are
            frozen and it runs under ``torch.no_grad()``
        :param bertonly: replace the image encoder by a fixed random vector
        :param normalize: stored only; output normalisation is currently disabled
        """
        super().__init__()
        self.args = args
        self.finetuning = finetuning
        self.bertonly = bertonly
        self.normalize = normalize
        if arch == 'vector' or bertonly:
            self.vector = True
            # The vector is already dim wide. Defined here so the summary print
            # at the end of __init__ works in every mode.
            backbone_dim = dim
            if bertonly:
                # Fixed random vector, registered as a non-persistent buffer rather
                # than a plain attribute so it follows .to()/.cuda() with the
                # module while staying out of state_dict.
                self.register_buffer('backbone', nn.Embedding(1, dim).weight.data, persistent=False)
            else:
                self.backbone = nn.Embedding(1, dim)
        else:
            # Setup Backbone
            self.vector = False
            resnet = get_visn_arch(arch)(pretrained=pretrained)
            backbone_dim = resnet.fc.in_features
            if not self.finetuning:
                for param in resnet.parameters():
                    param.requires_grad = False
            # Drop the classification head; the backbone now outputs the
            # backbone_dim-wide features that fc used to consume.
            resnet.fc = nn.Identity()
            self.backbone = resnet

            # Surgery on the network:
            # 1. Replace BatchNorm layers with FrozenBatchNorm2d (the conversion
            #    replaces modules in place; pattern borrowed from Detectron2).
            self.backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(
                self.backbone)
            # 2. Freeze the first (blocks of) layers.
            for module in [self.backbone.conv1,
                           self.backbone.layer1]:
                for param in module.parameters():
                    param.requires_grad = False

            print(f"Visn Model: {arch}, Finetune: {finetuning}, Pre-trained: {pretrained}")
            print(f"Visn Model: backbone dim {backbone_dim} --> output dim {dim}")

            self.mlp = nn.Sequential(
                nn.Linear(backbone_dim, dim),
            )
        print("-------------visn", backbone_dim, dim)

    def forward(self, img):
        """
        :param img: image batch whose first dimension is the batch size. In
            vector/bertonly mode only ``img.shape[0]`` is read; in CNN mode it is
            fed to the torchvision backbone (presumably channels-first
            ``[batch, C, H, W]`` — confirm against callers).
        :return: a tensor of [batch_size, dim]
        """
        if self.vector:
            batch_size = img.shape[0]
            if self.bertonly:
                x = self.backbone.repeat(batch_size, 1)
            else:
                x = self.backbone.weight.repeat(batch_size, 1)
        else:
            if not self.finetuning:
                # Frozen backbone: no autograd graph is built, so the features
                # carry no gradient history.
                with torch.no_grad():
                    x = self.backbone(img)
            else:
                x = self.backbone(img)
            x = self.mlp(x)         # [b, dim]

        return x


class LangModel(nn.Module):
    """Text encoder: a HuggingFace BERT masked-LM model plus a projection head.

    The sentence representation is built from the encoder hidden states of
    ``layers`` (concatenated along the feature axis), pooled either by taking
    the first token or by a masked mean (``args.avgvector``), and then passed
    through ``self.mlp``.
    """

    def __init__(self, args, dim, arch='BERT', layers=(-1,), pretrained=True, finetuning=False, bertonly=False, normalize=False):
        """
        :param args: experiment arguments; ``freezelayers`` and ``freezeembeddings``
            are read here and ``avgvector`` in ``forward``.
        :param dim: dimension of the output (when ``bertonly`` the head outputs 1 unit instead)
        :param arch: key into ``LANG_MODELS``; if not found as given it is looked
            up lower-cased, so the default ``'BERT'`` resolves to ``'bert'``.
        :param layers: indices into the encoder's hidden states to concatenate
        :param pretrained: if False, call ``init_weights()`` on the BERT encoder
            after loading, discarding the pre-trained weights
        :param finetuning: finetune the encoder; otherwise its parameters are
            frozen and it runs under ``torch.no_grad()``
        :param bertonly: use a dropout + 1-unit linear head instead of the ``dim``-wide projection
        :param normalize: stored only; output normalisation is currently disabled
        :raises KeyError: if neither ``arch`` nor ``arch.lower()`` is in ``LANG_MODELS``
        """
        super().__init__()
        self.args = args
        self.finetuning = finetuning
        self.bertonly = bertonly
        self.normalize = normalize

        # Setup Backbone. LANG_MODELS keys are lower-case; fall back to a
        # lower-case lookup so the default arch='BERT' does not raise KeyError.
        key = arch if arch in LANG_MODELS else arch.lower()
        Model, _tokenizer, weight = LANG_MODELS[key]
        self.model = Model.from_pretrained(
            weight,
            output_hidden_states=True
        )
        # The encoder is reached through ``.bert``, so only BERT masked-LM
        # entries of LANG_MODELS work here. A local alias avoids registering
        # the encoder twice as a submodule.
        bert = self.model.bert
        if not pretrained:
            bert.init_weights()

        if not self.finetuning:
            for param in bert.parameters():
                param.requires_grad = False
        if args.freezelayers:
            # Freeze the bottom six transformer layers.
            for layer_idx in range(6):
                for param in bert.encoder.layer[layer_idx].parameters():
                    param.requires_grad = False
        if args.freezeembeddings:
            for param in bert.embeddings.parameters():
                param.requires_grad = False
        backbone_dim = bert.config.hidden_size
        self.layers = sorted(layers)

        print(f"Language Model: {arch} with weight {weight}; Fine-tuning: {finetuning}, Pre-trained: {pretrained}.")
        print(f"Language Model: using layers {self.layers}, result in backbone dim {backbone_dim * len(self.layers)} "
              f"--> output dim {dim}.")

        if bertonly:
            self.mlp = nn.Sequential(
                nn.Dropout(0.3),
                nn.Linear(backbone_dim * len(self.layers), 1)
            )
        else:
            self.mlp = nn.Sequential(
                nn.Linear(backbone_dim * len(self.layers), dim),
            )
        print("-------------lang", backbone_dim, dim)

    def _encode(self, input_ids, attention_mask, token_type_ids):
        """Run the BERT encoder; no autograd graph is built unless fine-tuning."""
        if not self.finetuning:
            with torch.no_grad():
                return self.model.bert(
                    input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
        return self.model.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

    def _hidden_states(self, outputs):
        """Return the per-layer hidden states from the encoder output.

        Prefers the ``hidden_states`` attribute of a ``ModelOutput``, whose
        tuple positions shift when fields such as the pooler output are None.
        Falls back to tuple indexing: ``(sequence_output, pooled_output,
        hidden_states, ...)``, or ``(output, hidden_states, ...)`` for XLNet.
        """
        hidden_states = getattr(outputs, 'hidden_states', None)
        if hidden_states is not None:
            return hidden_states
        if type(self.model.bert) is XLNetModel:
            return outputs[1]
        return outputs[2]

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        :param input_ids: [batch_size, max_len]
        :param attention_mask: [batch_size, max_len]
        :param token_type_ids: [batch_size, max_len]
        :return: [batch_size, dim] sentence embeddings ([batch_size, 1] when ``bertonly``)
        """
        outputs = self._encode(input_ids, attention_mask, token_type_ids)
        hidden_states = self._hidden_states(outputs)

        # Gather the selected layers: [batch_size, max_len, hidden * len(layers)].
        x = torch.cat([hidden_states[layer] for layer in self.layers], -1)
        if self.args.avgvector:
            # Mean over non-padding tokens. Clamp the count so an all-zero mask
            # yields zeros instead of NaN.
            mask = attention_mask.unsqueeze(-1)
            x = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        else:
            # First-token ([CLS]) representation.
            x = x[:, 0, :]

        return self.mlp(x)


class JointModel(nn.Module):
    """Wraps a language and a visual encoder and computes the training losses.

    ``forward`` returns a cross-modal loss (``contrastive_loss`` with a
    temperature-scaled cosine ``Similarity`` when ``args.clloss``, otherwise
    ``paired_hinge_rank_loss3``) together with a masked-LM loss computed on
    ``general_lang_input``.
    """

    def __init__(self, args, lang_model, visn_model):
        super().__init__()
        self.lang_model = lang_model
        self.visn_model = visn_model
        if args.clloss:
            self.criterion = contrastive_loss
            self.sim = Similarity(args.temp)
        else:
            self.criterion = paired_hinge_rank_loss3
            # Needed by forward() when negative captions are supplied without
            # clloss; that call site uses this attribute.
            self.criterion2 = binary_classification_loss_with_neg
        self.args = args

    def forward(self, lang_input, visn_input, neg_lang_input=None, caption_lang_input=None, general_lang_input=None):
        """
        :param lang_input: positional arguments for ``self.lang_model``
        :param visn_input: positional arguments for ``self.visn_model``
        :param neg_lang_input: optional ``(input_ids, attention_mask, ...)`` for
            extra captions whose first two dims are ``[batch, num_neg]``
            (presumably one entry per negative caption); they are flattened to
            2-D for encoding and reshaped back to ``[batch, num_neg, -1]``.
        :param caption_lang_input: unused; the caption MLM loss is disabled and returned as 0
        :param general_lang_input: optional ``(input_ids, attention_mask, labels)``
            for the masked-LM loss; when None that loss is returned as 0
        :return: ``(loss, caption_mask_loss, general_mask_loss, lang_output, visn_output)``
        """
        lang_output = self.lang_model(*lang_input)
        visn_output = self.visn_model(*visn_input)

        if neg_lang_input is not None:
            neg_ids, neg_masks = neg_lang_input[0], neg_lang_input[1]
            batch_size, num_neg = neg_ids.shape[:2]
            neg_lang_output = self.lang_model(
                neg_ids.reshape(batch_size * num_neg, -1),
                neg_masks.reshape(batch_size * num_neg, -1),
            )
            neg_lang_output = neg_lang_output.reshape(batch_size, num_neg, -1)
            if self.args.clloss:
                loss = self.criterion(visn_output, lang_output, self.sim, neg_lang_output, weight=self.args.weight, neg_weight=self.args.neg_weight)
            else:
                loss = self.criterion2(lang_output, visn_output, neg_lang_output, self.args.with_random_neg)
        else:
            if self.args.clloss:
                loss = self.criterion(visn_output, lang_output, self.sim, weight=self.args.weight, neg_weight=self.args.neg_weight)
            else:
                loss = self.criterion(lang_output, visn_output)

        # Caption MLM loss is currently disabled.
        caption_mask_loss = 0
        if general_lang_input is None:
            general_mask_loss = 0
        else:
            # First element of the HuggingFace MLM output is the loss when labels are given.
            general_mask_loss = self.lang_model.model(
                input_ids=general_lang_input[0],
                attention_mask=general_lang_input[1],
                labels=general_lang_input[2],
            )[0]
        return loss, caption_mask_loss, general_mask_loss, lang_output, visn_output


