'''
Author: your name
Date: 2021-10-14 09:41:10
LastEditTime: 2021-10-21 21:23:22
LastEditors: Please set LastEditors
Description: ViLT-based image encoder with optional text encoders and a cross-encoder ranking head for text-image retrieval.
FilePath: /open_clip-main/src/training/vilt.py
'''
import torch
import torch.nn as nn
from transformers.file_utils import CLOUDFRONT_DISTRIB_PREFIX
from transformers.utils.dummy_pt_objects import AutoModel
import training.modules.vision_transformer as vit
from transformers import BertTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
import json
import copy
from torch.utils.mobile_optimizer import optimize_for_mobile

class Pooler(nn.Module):
    """Pool a token sequence into a single vector: ``tanh(Linear(first token))``.

    Args:
        hidden_size: width of the incoming hidden states and of the pooled output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # These attribute names form the state_dict keys (``dense.*``); keep them stable
        # so existing checkpoints still load.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))`` — one vector per batch row."""
        leading_token = hidden_states.select(1, 0)
        return self.activation(self.dense(leading_token))


class ITMHead(nn.Module):
    """Linear head mapping each feature vector to two logits.

    Presumably the image-text-matching head, given its name and its use as
    ``itm_score`` in ``ViLT``.

    Args:
        hidden_size: width of the input features.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # ``fc`` is part of the checkpoint keys (``fc.weight`` / ``fc.bias``).
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        """Map ``(..., hidden_size)`` features to ``(..., 2)`` logits."""
        return self.fc(x)


class ViLT(nn.Module):
    """ViLT-style single-stream transformer for image encoding and text-image ranking.

    Image patches come from ``self.transformer.visual_embed`` and run through the
    shared ViT ``blocks``. For cross-encoder ranking (``sim_score_raw`` /
    ``sim_score_embeds``), BERT-embedded text tokens are concatenated in front of the
    image tokens, and the pooled output is scored with ``rank_output``.

    NOTE(review): ``__init__`` does not create ``query_bert``, ``query_embeddings``,
    ``token_wise_bias``, ``text_prompt_embeddings``, ``projection``, ``text_lstm`` or
    ``temperature``. The methods that read them (``forward_text``,
    ``forward_text_dual``, ``sim_score_dual``, ``encode_texts``,
    ``encode_image_with_prompt``, ``forward_prompt``) raise ``AttributeError`` unless
    the caller attaches those attributes first.
    """

    def __init__(self, ckpt_path="", tokenizer=None, student=False, *,
                 config_path="/home/roy/ViLT/vilt_config.json",
                 num_student_layers=4, query_dim=512):
        """Build the model from a JSON config and optionally load a checkpoint.

        Args:
            ckpt_path: checkpoint whose ``"state_dict"`` entry is copied into every
                matching parameter/buffer. An empty string (or None) skips loading.
            tokenizer: object providing ``cls_token_id`` and ``pad_token_id``. When
                None, ``bert-base-uncased`` is loaded here, not at import time.
            student: if True, keep only the first ``num_student_layers`` ViT blocks.
            config_path: JSON config. Keys read here include ``vocab_size``,
                ``hidden_size``, ``num_layers``, ``num_heads``, ``mlp_ratio``,
                ``max_text_len``, ``drop_rate``, ``vit`` and ``max_image_len``.
            num_student_layers: number of leading blocks kept when ``student`` is True.
            query_dim: input width of ``query_projection``.
        """
        super().__init__()
        self.student = student
        with open(config_path, "r") as f:
            config = json.load(f)
        self.config = config
        if tokenizer is None:
            # Loaded lazily. As a default-argument value, it ran (with disk/network
            # access) once at module import time.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.tokenizer = tokenizer
        self.cls_token_id = tokenizer.cls_token_id
        self.pad_token_id = tokenizer.pad_token_id

        bert_config = BertConfig(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            num_hidden_layers=config["num_layers"],
            num_attention_heads=config["num_heads"],
            intermediate_size=config["hidden_size"] * config["mlp_ratio"],
            max_position_embeddings=config["max_text_len"],
            hidden_dropout_prob=config["drop_rate"],
            attention_probs_dropout_prob=config["drop_rate"],
        )
        self.text_embeddings = BertEmbeddings(bert_config)

        # Token type 0 is used for text tokens and 1 for image tokens in this class.
        self.token_type_embeddings = nn.Embedding(2, config['hidden_size'])
        self.transformer = getattr(vit, config['vit'])(
            pretrained=False, config=config
        )
        self.pooler = Pooler(config["hidden_size"])
        self.itm_score = ITMHead(config["hidden_size"])
        self.rank_output = nn.Linear(config['hidden_size'], 1)
        # rank_output shares storage with row 1 of the ITM head. It starts out equal to
        # the ITM head's second logit, and in-place updates to one show up in the other.
        self.rank_output.weight.data = self.itm_score.fc.weight.data[1:, :]
        self.rank_output.bias.data = self.itm_score.fc.bias.data[1:]
        self.margin = 0.2
        if ckpt_path:
            print("Loading pretrained ViLT")
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
            ckpt = torch.load(ckpt_path, map_location='cpu')
            state_dict = ckpt["state_dict"]
            # Build our (detached, storage-sharing) state dict once, not once per key.
            own_state = self.state_dict()
            for key, tensor in own_state.items():
                if key in state_dict:
                    tensor.copy_(state_dict[key])
                    print(f"{key} found in ckpt")
                else:
                    print(f"{key} not found in ckpt")
            assert len(self.transformer.blocks) == 12
        if student:
            print(f"Re-initializing ViT to {num_student_layers}-layer")
            # Keep deep copies of the first ``num_student_layers`` blocks.
            # NOTE(review): the list stays registered as both ``_blocks`` and ``blocks``,
            # so state_dict carries both key prefixes; kept for checkpoint compatibility.
            self.transformer._blocks = nn.ModuleList(
                [copy.deepcopy(self.transformer.blocks[i]) for i in range(num_student_layers)]
            )
            del self.transformer.blocks
            self.transformer.blocks = self.transformer._blocks
            assert len(self.transformer.blocks) == num_student_layers

        self.query_projection = nn.Linear(query_dim, config['hidden_size'])
        # Reference for the optional attributes other methods expect to be attached:
        # self.query_embeddings = copy.deepcopy(self.text_embeddings)
        # self.token_wise_bias = nn.Parameter(torch.zeros(1).squeeze(0))
        # self.text_lstm = nn.LSTM(input_size=config['hidden_size'], hidden_size=config['hidden_size']//2, num_layers=1, bidirectional=True, batch_first=True)
        # self.text_prompt_embeddings = nn.Parameter(torch.Tensor(40, config['hidden_size']))
        # self.projection = nn.Linear(config['hidden_size'], config['hidden_size'])
        # self.temperature = nn.Parameter(torch.tensor([1.0]).float())

    def _embed_image(self, images):
        """Patch-embed ``images`` and add the image token-type embedding (index 1).

        Returns:
            ``(image_embeds, image_masks)`` from ``self.transformer.visual_embed``
            (patch index and labels are discarded), with the type embedding added.
        """
        image_embeds, image_masks, _patch_index, _image_labels = self.transformer.visual_embed(
            images,
            max_image_len=self.config["max_image_len"],
            mask_it=False,
        )
        # dtype is forced to long so the ids are valid nn.Embedding indices.
        image_type_ids = torch.ones_like(image_masks, dtype=torch.long)
        image_embeds = image_embeds + self.token_type_embeddings(image_type_ids)
        return image_embeds, image_masks

    def _run_transformer(self, x, mask):
        """Apply every ViT block with attention mask ``mask``, then the final norm."""
        for block in self.transformer.blocks:
            x, _ = block(x, mask=mask)
        return self.transformer.norm(x)

    def forward_image(self, images):
        """Encode images alone and mean-pool the final hidden states.

        The mean is weighted by ``image_masks`` (assumed 0/1, with 0 for padding
        patches), matching the masked mean used in ``forward_text``.

        Returns:
            ``(cls_features, image_embeds, image_masks)``. These are the pooled
            ``(bs, hidden_size)`` features, the type-augmented input embeddings,
            and the patch mask.
        """
        image_embeds, image_masks = self._embed_image(images)
        x = self._run_transformer(image_embeds, image_masks)
        weights = image_masks.unsqueeze(-1).to(x.dtype)  # (bs, num_tokens, 1)
        cls_features = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return cls_features, image_embeds, image_masks

    def forward_text_dual(self, texts):
        """Embed token ids ``texts`` (bs, text_len) with the attached ``query_embeddings``."""
        return self.query_embeddings(texts)  # (bs, text_len, hidden_size)

    def sim_score_dual(self, images, texts):
        """Token-wise dual-encoder score matrix of shape ``(len(texts), len(images))``.

        For each text, the score of each image is the sum over text tokens of
        ``log(relu(<token, image_feature> + token_wise_bias) + 1.1)``.
        """
        image_cls_features, *_ = self.forward_image(images)
        text_embeds = self.forward_text_dual(texts)
        image_features = image_cls_features.unsqueeze(1)  # (num_img, 1, hidden_size)
        rows = []
        for text_embed in text_embeds:
            # Broadcast (1, text_len, H) against (num_img, 1, H) -> (num_img, text_len).
            token_sims = (text_embed.unsqueeze(0) * image_features).sum(dim=-1)
            rows.append(torch.log(torch.relu(token_sims + self.token_wise_bias) + 1.1).sum(dim=-1))
        return torch.stack(rows, dim=0)

    def forward_text(self, texts):
        """Encode token ids with the attached ``query_bert`` and masked-mean-pool them.

        Returns:
            ``(cls, None, text_masks)``. ``cls`` is the pooled output passed
            through ``query_projection``. ``text_masks`` is a float mask that is
            1 for non-pad tokens.
        """
        text_masks = (texts != self.pad_token_id).float()  # (bs, text_len)
        output = self.query_bert(input_ids=texts, attention_mask=text_masks)
        hidden = output[0]
        # clamp avoids 0/0 for an all-padding row.
        token_counts = text_masks.sum(dim=-1, keepdim=True).clamp(min=1.0)
        pooled = (hidden * text_masks.unsqueeze(-1)).sum(dim=1) / token_counts
        return self.query_projection(pooled), None, text_masks

    def encode_image_with_prompt(self, images):
        """Encode images behind a ``[CLS]`` + learned-prompt prefix; return L2-normalized features."""
        bs = images.shape[0]
        device = images.device
        image_embeds, image_masks = self._embed_image(images)

        cls_ids = torch.full((bs, 1), self.cls_token_id, dtype=torch.long, device=device)
        prompt = torch.cat(
            [self.text_embeddings(cls_ids),
             self.text_prompt_embeddings.unsqueeze(0).expand(bs, -1, -1)],
            dim=1,
        )
        prompt_len = prompt.size(1)
        prompt = prompt + self.token_type_embeddings(
            torch.zeros(bs, prompt_len, dtype=torch.long, device=device)
        )
        # Match the image-mask dtype so torch.cat does not mix float with long.
        prompt_masks = torch.ones(bs, prompt_len, dtype=image_masks.dtype, device=device)

        x = self._run_transformer(
            torch.cat([prompt, image_embeds], dim=1),
            torch.cat([prompt_masks, image_masks], dim=1),
        )
        cls_features = self.projection(self.pooler(x))  # (bs, hidden_size)
        return cls_features / torch.norm(cls_features, dim=1, keepdim=True)

    def encode_texts(self, texts):
        """Embed, run the attached ``text_lstm``, masked-mean-pool and L2-normalize texts."""
        text_masks = (texts != self.pad_token_id).float()  # (bs, text_len)
        text_embeds = self.query_embeddings(texts)  # (bs, text_len, hidden_size)
        lstm_out, _ = self.text_lstm(text_embeds)
        token_counts = text_masks.sum(dim=-1, keepdim=True).clamp(min=1.0)
        text_rep = (lstm_out * text_masks.unsqueeze(-1)).sum(dim=1) / token_counts
        return text_rep / torch.norm(text_rep, dim=1, keepdim=True)

    def forward_prompt(self, images, texts):
        """Retrieval-stage forward pass: texts and images are encoded separately.

        Returns:
            ``(loss, t2i_matrix, hit1, hit5, hit10)``. ``loss`` is the symmetric
            cross-entropy over the text-to-image similarity matrix, with the
            diagonal as the positives. ``hit@k`` is the fraction of texts whose
            paired image ranks in the top k.
        """
        bs = images.shape[0]
        norm_text_rep = self.encode_texts(texts)
        norm_img_rep = self.encode_image_with_prompt(images)
        t2i_matrix = norm_text_rep @ norm_img_rep.t() / self.temperature
        ground_truth = torch.arange(bs, device=images.device)  # no label smoothing for now

        criterion = nn.CrossEntropyLoss()
        loss = criterion(t2i_matrix, ground_truth) + criterion(t2i_matrix.t(), ground_truth)

        # Full ranking of all images for every text, in a single batched call.
        ranking = torch.topk(t2i_matrix.detach().cpu(), k=bs, dim=1).indices  # (bs, bs)
        correct = ranking == torch.arange(bs).unsqueeze(1)
        hit1 = correct[:, :1].any(dim=1).sum().item() / bs
        hit5 = correct[:, :5].any(dim=1).sum().item() / bs
        hit10 = correct[:, :10].any(dim=1).sum().item() / bs
        return loss, t2i_matrix, hit1, hit5, hit10

    def forward(self, images):
        """Encode ``images`` and return the pooler output (tanh of the first token)."""
        image_embeds, image_masks = self._embed_image(images)
        x = self._run_transformer(image_embeds, image_masks)
        return self.pooler(x)

    def sim_score_raw(self, images, texts):
        """Rank all ``images`` for each text with the cross-encoder.

        Returns:
            A ``(len(texts), len(images))`` tensor of ``rank_output`` scores.
        """
        num_img = images.shape[0]
        image_embeds, image_masks = self._embed_image(images)
        total_ranking_score = []
        for text_ids in texts:
            text_ids = text_ids.unsqueeze(0)  # (1, text_len)
            # long (not bool), so concatenation with the long image masks is dtype-consistent.
            text_masks = (text_ids != self.pad_token_id).long()
            text_embeds = self.text_embeddings(text_ids)  # (1, text_len, hidden_size)
            text_embeds = text_embeds + self.token_type_embeddings(
                torch.zeros_like(text_ids, dtype=torch.long)
            )

            co_embeds = torch.cat([text_embeds.expand(num_img, -1, -1), image_embeds], dim=1)
            co_masks = torch.cat([text_masks.expand(num_img, -1), image_masks], dim=1)
            x = self._run_transformer(co_embeds, co_masks)
            rank_scores = self.rank_output(self.pooler(x)).squeeze(1)  # (num_img,)
            total_ranking_score.append(rank_scores)
        total_ranking_score = torch.stack(total_ranking_score, dim=0)
        assert total_ranking_score.size(0) == len(texts)
        assert total_ranking_score.size(1) == len(images)
        return total_ranking_score

    def sim_score_embeds(self, image_embeds, image_masks, text_embeds, text_masks):
        """Cross-encoder scores from precomputed embeddings and masks.

        Each text's embeddings/mask are broadcast over all images and prepended to
        the image tokens. Returns a ``(len(text_embeds), num_images)`` tensor.
        """
        num_img = image_embeds.shape[0]
        total_ranking_score = []
        for text_embed, text_mask in zip(text_embeds, text_masks):
            co_embeds = torch.cat([text_embed.unsqueeze(0).expand(num_img, -1, -1), image_embeds], dim=1)
            co_masks = torch.cat([text_mask.unsqueeze(0).expand(num_img, -1), image_masks], dim=1)
            x = self._run_transformer(co_embeds, co_masks)
            total_ranking_score.append(self.rank_output(self.pooler(x)).squeeze(1))
        return torch.stack(total_ranking_score, dim=0)