"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

UNITER for pretraining
"""
from collections import defaultdict

import torch
import math
import numpy as np
from torch import nn
from torch.nn import functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

from .layer import GELU, BertOnlyMLMHead
from .model import UniterModel, UniterPreTrainedModel
from .ot import optimal_transport_dist
from trainers.input_utils import get_detailed_input_feats


class RegionFeatureRegression(nn.Module):
    """Prediction head for masked region feature regression (MRFR).

    Hidden states pass through Linear -> GELU -> LayerNorm and are then
    projected to ``feat_dim`` outputs using the transpose of
    ``img_linear_weight`` (the very tensor object handed in by the caller,
    not a copy) plus a freshly initialised, zero bias.
    """
    def __init__(self, hidden_size, feat_dim, img_linear_weight):
        super().__init__()
        transform = [
            nn.Linear(hidden_size, hidden_size),
            GELU(),
            LayerNorm(hidden_size, eps=1e-12),
        ]
        self.net = nn.Sequential(*transform)

        # Shared weight. Its transpose is used as the projection matrix, so it
        # presumably has shape (hidden_size, feat_dim) -- TODO confirm.
        self.weight = img_linear_weight
        self.bias = nn.Parameter(torch.zeros(feat_dim))

    def forward(self, input_):
        return F.linear(self.net(input_), self.weight.t(), self.bias)


class RegionClassification(nn.Module):
    """Prediction head for masked region classification (MRC / MRC-KL).

    A Linear -> GELU -> LayerNorm -> Linear stack that returns ``label_dim``
    raw scores per input row (no softmax is applied here).
    """
    def __init__(self, hidden_size, label_dim):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            GELU(),
            LayerNorm(hidden_size, eps=1e-12),
            nn.Linear(hidden_size, label_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input_):
        return self.net(input_)


class UniterForPretraining(UniterPreTrainedModel):
    """UNITER pretraining model.

    Wraps a ``UniterModel`` backbone, the pretraining heads (MLM, MRFR, MRC
    and ITM) and an optional image encoder ``vision_model``. The encoder's
    ``fc`` layer is replaced by ``nn.Identity`` so that it emits one feature
    vector per image.

    NOTE: once ``__init__`` finishes, ``self.config`` is ``additional_config``
    (the experiment config read by ``forward`` and the task methods), NOT the
    UNITER model ``config``. Model sizes must therefore come from tensors or
    submodules rather than from ``self.config``.
    """

    def __init__(self, config, img_dim, img_label_dim, vision_model=None,
                 tokenizer=None, multimodal_text_part=False,
                 multimodal_img_part=False, additional_config=None):
        super().__init__(config)
        self.uniter = UniterModel(config, img_dim)
        # The MLM / MRFR heads receive existing embedding weights (word
        # embeddings / image linear projection) instead of creating new ones.
        self.cls = BertOnlyMLMHead(
            config, self.uniter.embeddings.word_embeddings.weight)
        self.feat_regress = RegionFeatureRegression(
            config.hidden_size, img_dim,
            self.uniter.img_embeddings.img_linear.weight)
        self.region_classifier = RegionClassification(
            config.hidden_size, img_label_dim)
        self.itm_output = nn.Linear(config.hidden_size, 2)
        # Runs before the vision model is attached, so init_weights never
        # touches the vision model's parameters.
        self.apply(self.init_weights)

        # Vision model (optional).
        self.vision_model = vision_model
        self.freeze_vision_model = False
        if self.vision_model is not None:
            # Remove the final FC layer. Assumes a torchvision-ResNet-style
            # module exposing ``fc`` -- TODO confirm for other backbones.
            self.num_img_dim = self.vision_model.fc.in_features
            self.vision_model.fc = nn.Identity()
        self.multimodal_text_part = multimodal_text_part
        self.multimodal_img_part = multimodal_img_part

        # Tokenizer.
        self.tokenizer = tokenizer

        # Replace the UNITER config with the experiment config (see class doc).
        self.config = additional_config

        # Parameters for objectives.
        self.itm_ot_lambda = 0.1  # weight of the OT term inside the ITM loss
        self.swapping_based_nsp_prob = 0.5  # probability of NOT swapping (ITM)

    def forward(self, batch):
        """Compute the pretraining loss for one batch.

        Args:
            batch: dict of inputs. Reads ``masked_lm_labels`` and, when a
                vision model is attached, ``images`` (if ``ndim > 3`` it is
                unpacked as bz, num_images, C, H, W). NOTE: the caller's dict
                is mutated (``txt_labels``, ``position_ids``, ``attn_masks``,
                ``gather_index``, ``img_feat``, ``img_pos_feat`` and possibly
                ``images`` / ``input_ids`` are written into it).

        Returns:
            ``(total_loss, None)``. ``total_loss`` is the MLM loss. If ``itm``
            or ``itm_ot`` is listed in
            ``self.config.multimodal_pretrain_objectives``, the ITM loss is
            added, plus the weighted OT term when OT inputs are available.
        """
        batch["txt_labels"] = batch["masked_lm_labels"]
        new_batch = get_detailed_input_feats(batch, self.tokenizer, self.config)
        for key in ("position_ids", "attn_masks", "gather_index"):
            batch[key] = new_batch[key]

        # Visual embedding.
        if self.vision_model is not None and batch["images"].ndim > 3:
            images = batch["images"]
            bz, img_len, C, H, W = images.size()
            images = torch.reshape(images, (bz * img_len, C, H, W)).float()
            images = self.vision_model(images)
            if self.freeze_vision_model:
                images = images.detach()
            images = torch.reshape(images, (bz, img_len, self.num_img_dim))
            batch["img_feat"] = images
        else:
            # Without images the batch size has to come from the text side.
            # attn_masks is batch-first (it is concatenated with (bz, n)
            # tensors below). Previously ``bz`` was undefined on this path.
            bz = batch["attn_masks"].size(0)
            batch["img_feat"] = None
        batch['img_pos_feat'] = None

        # Attention masks handling.
        # TODO Make sure the followings.
        img_feat = batch["img_feat"]
        if (not self.multimodal_text_part and not self.multimodal_img_part
                and not self.config.img_text_paired_coattention
                and img_feat is not None):
            # Single-stream input: extend the text mask with ones covering
            # every image position.
            additional_attn = torch.ones(bz, img_feat.size(1)).type_as(
                batch["attn_masks"])
            batch["attn_masks"] = torch.cat([batch["attn_masks"],
                                             additional_attn], dim=-1)
        if self.multimodal_text_part:
            batch["images"] = None
            batch["img_feat"] = None
        if self.multimodal_img_part:
            batch["input_ids"] = None
            batch["attn_masks"] = torch.ones(
                bz, batch["img_feat"].size(1)).type_as(batch["attn_masks"])

        # Move every tensor onto the attention-mask device; None and other
        # non-tensor values are passed through untouched.
        device = batch['attn_masks'].device
        batch = {key: value.to(device) if torch.is_tensor(value) else value
                 for key, value in batch.items()}

        # Change OT inputs, if they exist.
        if "ot_inputs" in new_batch:
            batch["ot_inputs"] = {
                key: value.to(device) if torch.is_tensor(value) else value
                for key, value in new_batch["ot_inputs"].items()}

        # Compute the total loss.
        total_loss = 1.0 * self._forward(batch, task="mlm")

        objectives = self.config.multimodal_pretrain_objectives
        if "itm" in objectives or "itm_ot" in objectives:
            itm_loss, ot_loss = self._forward(batch, task="itm")
            if ot_loss is not None:
                ot_pos, ot_neg = ot_loss
                # Minimising this lowers the OT distance of matched pairs and
                # raises it for mismatched ones, averaged over all pairs.
                ot_loss = (ot_pos.sum() - ot_neg.sum()
                           ) / (ot_pos.size(0) + ot_neg.size(0))
                itm_loss = itm_loss + self.itm_ot_lambda * ot_loss
            total_loss += 1.0 * itm_loss

        return total_loss, None

    def _forward(self, batch, task, compute_loss=True):
        """Dispatch ``batch`` to the task-specific forward method.

        Args:
            batch: dict of model inputs. Missing keys are read as None.
            task: ``'mlm'`` or ``'itm'``. ``'mrfr'`` and ``'mrc*'`` are not
                implemented yet.
            compute_loss: if False, the task method returns scores instead
                of a loss.

        Raises:
            NotImplementedError: for ``'mrfr'`` and ``'mrc*'``.
            ValueError: for any other unknown task.
        """
        batch = defaultdict(lambda: None, batch)
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        img_feat = batch['img_feat']
        img_pos_feat = batch['img_pos_feat']
        attention_mask = batch['attn_masks']
        gather_index = batch['gather_index']
        if task == 'mlm':
            return self.forward_mlm(input_ids, position_ids,
                                    img_feat, img_pos_feat,
                                    attention_mask, gather_index,
                                    batch['txt_labels'], compute_loss)
        if task == 'mrfr':
            # When implemented, this needs batch['img_masks'],
            # batch['img_mask_tgt'] and batch['feat_targets'] (forward_mrfr).
            raise NotImplementedError("Not done yet!")
        if task == 'itm':
            # Missing keys resolve to None through the defaultdict. (The
            # original membership test for the targets was inverted.)
            return self.forward_itm(input_ids, position_ids,
                                    img_feat, img_pos_feat,
                                    attention_mask, gather_index,
                                    batch['uniter_itm_targets'],
                                    batch['ot_inputs'], compute_loss)
        if task.startswith('mrc'):
            # When implemented, this needs batch['img_masks'],
            # batch['img_mask_tgt'] and batch['label_targets'] (forward_mrc).
            raise NotImplementedError("Not done yet!")
        raise ValueError('invalid task')

    def forward_mlm(self, input_ids, position_ids, img_feat, img_pos_feat,
                    attention_mask, gather_index,
                    txt_labels, compute_loss=True):
        """Masked language modelling.

        Only text positions whose label differs from
        ``self.config.mlm_ignore_index`` are scored. Returns the mean
        cross-entropy loss, or the prediction scores for those positions
        when ``compute_loss`` is False.
        """
        sequence_output = self.uniter(input_ids, position_ids,
                                      img_feat, img_pos_feat,
                                      attention_mask, gather_index,
                                      output_all_encoded_layers=False)
        # get only the text part
        sequence_output = sequence_output[:, :input_ids.size(1), :]
        # only compute masked tokens for better efficiency
        masked_positions = txt_labels != self.config.mlm_ignore_index
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    masked_positions)
        prediction_scores = self.cls(masked_output)

        if compute_loss:
            return F.cross_entropy(prediction_scores,
                                   txt_labels[masked_positions])
        return prediction_scores

    def _compute_masked_hidden(self, hidden, mask):
        """Return the rows of ``hidden`` selected by boolean ``mask``.

        ``mask`` is expanded over the last dimension of ``hidden``, so it is
        expected to match ``hidden.shape[:-1]``. The result has shape
        (num_selected, hidden.size(-1)).
        """
        mask = mask.unsqueeze(-1).expand_as(hidden)
        hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
        return hidden_masked

    def forward_mrfr(self, input_ids, position_ids, img_feat, img_pos_feat,
                     attention_mask, gather_index, img_masks, img_mask_tgt,
                     feat_targets, compute_loss=True):
        """Masked region feature regression.

        Returns the element-wise (unreduced) MSE against ``feat_targets``,
        or the regressed features when ``compute_loss`` is False.
        """
        sequence_output = self.uniter(input_ids, position_ids,
                                      img_feat, img_pos_feat,
                                      attention_mask, gather_index,
                                      output_all_encoded_layers=False,
                                      img_masks=img_masks)

        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)
        prediction_feat = self.feat_regress(masked_output)

        if compute_loss:
            mrfr_loss = F.mse_loss(prediction_feat, feat_targets,
                                   reduction='none')
            return mrfr_loss
        else:
            return prediction_feat

    def forward_itm(self, input_ids, position_ids, img_feat, img_pos_feat,
                    attention_mask, gather_index, targets, ot_inputs=None,
                    compute_loss=True):
        """Image-text matching with swapping-based negatives.

        The ``targets`` and ``img_pos_feat`` arguments are ignored. Both are
        replaced by the output of ``_itm_swapping_based`` (labels: 1 =
        unchanged, 0 = one image slot swapped; position features: None).

        If ``ot_inputs`` is given, the optimal-transport distance between
        the text and image halves of the output is also computed. It is
        returned split into (positive-pair, negative-pair) distances.

        Returns:
            ``(itm_loss, ot_loss)``, or ``(itm_scores, ot_loss)`` when
            ``compute_loss`` is False. ``ot_loss`` is None without
            ``ot_inputs``.
        """
        img_feat, img_pos_feat, targets = self._itm_swapping_based(
            img_feat, img_pos_feat=None, input_ids=input_ids)

        sequence_output = self.uniter(input_ids, position_ids,
                                      img_feat, img_pos_feat,
                                      attention_mask, gather_index,
                                      output_all_encoded_layers=False)
        pooled_output = self.uniter.pooler(sequence_output)
        itm_scores = self.itm_output(pooled_output)

        # OT loss
        if ot_inputs is not None:
            ot_scatter = ot_inputs['ot_scatter']

            b = sequence_output.size(0)
            # Taken from the tensor: self.config is the experiment config
            # and is not guaranteed to carry ``hidden_size``.
            hidden_size = sequence_output.size(-1)
            tl = input_ids.size(1)
            il = img_feat.size(1)
            max_l = max(ot_inputs['scatter_max'] + 1, tl + il)

            ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output)
            ctx_emb = torch.zeros(b, max_l, hidden_size,
                                  dtype=sequence_output.dtype,
                                  device=sequence_output.device
                                  ).scatter_(dim=1, index=ot_scatter,
                                             src=sequence_output)
            txt_emb = ctx_emb[:, :tl, :]
            img_emb = ctx_emb[:, tl:tl + il, :]

            txt_pad = ot_inputs['txt_pad']
            img_pad = ot_inputs['img_pad']
            # NOTE: run in fp32 for stability
            ot_dist = optimal_transport_dist(txt_emb.float(), img_emb.float(),
                                             txt_pad, img_pad).to(txt_emb)
            ot_pos_dist = ot_dist.masked_select(targets == 1)
            ot_neg_dist = ot_dist.masked_select(targets == 0)
            ot_loss = (ot_pos_dist, ot_neg_dist)
        else:
            ot_loss = None

        if compute_loss:
            itm_loss = F.cross_entropy(itm_scores, targets)
            return itm_loss, ot_loss
        return itm_scores, ot_loss

    def forward_mrc(self, input_ids, position_ids, img_feat, img_pos_feat,
                    attention_mask, gather_index, img_masks, img_mask_tgt,
                    label_targets, task, compute_loss=True):
        """Masked region classification (``'mrc'``) or its KL variant.

        If ``'kl'`` is in ``task``, returns the element-wise KL divergence
        against the soft ``label_targets``. Otherwise returns per-row
        cross-entropy against the arg-max non-background class (index 0 is
        excluded and ignored). Both are unreduced. When ``compute_loss`` is
        False, returns the raw scores.
        """
        sequence_output = self.uniter(input_ids, position_ids,
                                      img_feat, img_pos_feat,
                                      attention_mask, gather_index,
                                      output_all_encoded_layers=False,
                                      img_masks=img_masks)

        # only compute masked regions for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)
        prediction_soft_label = self.region_classifier(masked_output)

        if compute_loss:
            if "kl" in task:
                prediction_soft_label = F.log_softmax(
                    prediction_soft_label, dim=-1)
                mrc_loss = F.kl_div(
                    prediction_soft_label, label_targets, reduction='none')
            else:
                # background class should not be the target
                label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1
                mrc_loss = F.cross_entropy(
                    prediction_soft_label, label_targets,
                    ignore_index=0, reduction='none')
            return mrc_loss
        else:
            return prediction_soft_label

    @staticmethod
    def _non_padding_slots(feats):
        """Return a numpy int array of indices of rows of ``feats`` that are not all zero."""
        # A slot counts as padding/mask only if every feature is exactly zero.
        # Summing the features (as before) also flagged real slots whose
        # values happened to cancel out.
        mask = (feats != 0).any(dim=-1)
        return torch.nonzero(mask, as_tuple=True)[0].detach().cpu().numpy().astype(int)

    def _itm_swapping_based(self, images, img_pos_feat=None,
                            input_ids=None, compute_loss=True):
        """Build image-text matching examples by swapping image slots.

        For each sample ``i`` that has at least one non-padding slot: one of
        its non-padding slots is replaced, with probability
        ``1 - self.swapping_based_nsp_prob``, by a random non-padding slot of
        sample ``(i + 1) % bz``. The candidate slots exclude those already
        marked as swapped for that donor sample. Swapped samples get label 0
        and all others label 1. Every sample is kept in the output, swapped
        or not.

        Args:
            images: per-sample image features, indexed ``images[i][slot]``
                (presumably (bz, num_slots, dim) -- TODO confirm). All-zero
                slots are treated as padding/mask.
            img_pos_feat: optional per-slot features aligned with ``images``.
                When given, they are swapped in lock-step with the images.
            input_ids: unused; kept for interface compatibility.
            compute_loss: unused; kept for interface compatibility.

        Returns:
            ``(images, img_pos_feat, labels)``. ``labels`` is a long tensor
            of shape (bz,) on ``images.device``.
        """
        bz, img_len = images.size(0), images.size(1)
        images_if_swapped = torch.zeros(bz, img_len)
        src_pos_feat = img_pos_feat
        if img_pos_feat is not None:
            img_pos_feat = img_pos_feat.clone()
        swapping_based_nsp_labels = []
        new_images = []
        for i in range(bz):
            image_ = images[i].clone()  # L x D
            # Keep every sample (edited in place below if swapped) so the
            # output stays aligned with the labels and the rest of the batch.
            new_images.append(image_)

            non_zero_images = self._non_padding_slots(image_)
            if len(non_zero_images) == 0:
                swapping_based_nsp_labels.append(1)
                continue

            sample_batch_idx = (i + 1) % bz
            image_cands_ = images[sample_batch_idx]
            # Exclude slots of the donor sample that were already swapped.
            already_swapped = set(torch.nonzero(
                images_if_swapped[sample_batch_idx], as_tuple=True)[0].tolist())
            cand_slots = sorted(
                set(self._non_padding_slots(image_cands_).tolist())
                - already_swapped)
            if not cand_slots:
                swapping_based_nsp_labels.append(1)
                continue

            chose_index = np.random.choice(cand_slots)
            swapped_index = np.random.choice(non_zero_images)

            # Probability of swapping.
            if np.random.rand() > self.swapping_based_nsp_prob:
                image_[swapped_index] = image_cands_[chose_index]
                images_if_swapped[i][swapped_index] = 1
                if img_pos_feat is not None:
                    img_pos_feat[i][swapped_index] = \
                        src_pos_feat[sample_batch_idx][chose_index]
                swapping_based_nsp_labels.append(0)
            else:
                swapping_based_nsp_labels.append(1)

        if new_images:
            images = torch.stack(new_images)
        labels = torch.tensor(swapping_based_nsp_labels, dtype=torch.long,
                              device=images.device)
        return images, img_pos_feat, labels
