"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Uniter for VQA model
"""
from collections import defaultdict
import logging

import torch
from torch import nn
from torch.nn import functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm

from .layer import GELU
from .model import UniterPreTrainedModel, UniterModel
from trainers.input_utils import get_detailed_input_feats
from models.heatmap_module import HeatMapOutput


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class UniterForVisualQuestionAnswering(UniterPreTrainedModel):
    """Finetune UNITER for VQA.

    Combines a ``UniterModel`` encoder with an answer-classification head
    (``vqa_output``). Optionally holds an image backbone whose final ``fc``
    layer is replaced by ``nn.Identity`` so it emits feature vectors, and a
    ``HeatMapOutput`` head that takes over the output whenever
    ``additional_config.hierarchical_version`` is not ``"v0"``.

    Args:
        config: UNITER model config. Passed to the parent constructor and to
            ``UniterModel``; ``hidden_size`` is read from it.
        img_dim: Image feature dimension forwarded to ``UniterModel``.
        num_answer: Output size of the answer-classification head.
        vision_model: Optional image backbone. It must expose an ``fc``
            layer with an ``in_features`` attribute.
        tokenizer: Tokenizer forwarded to ``get_detailed_input_feats``.
        multimodal_text_part: Run on the text inputs only (images are
            dropped in ``forward``). Mutually exclusive with
            ``multimodal_img_part``.
        multimodal_img_part: Run on the image inputs only (``input_ids`` are
            dropped in ``forward``). Mutually exclusive with
            ``multimodal_text_part``.
        additional_config: Experiment-level config. Required despite the
            ``None`` default: ``hierarchical_version`` and
            ``img_text_paired_coattention`` are read from it, and it
            replaces ``self.config`` after construction.

    Raises:
        ValueError: If ``additional_config`` is ``None``.
    """
    def __init__(self, config, img_dim, num_answer,
                 vision_model=None, tokenizer=None,
                 multimodal_text_part=False, multimodal_img_part=False,
                 additional_config=None):
        # Fail early with a clear message instead of an AttributeError on
        # ``None.hierarchical_version`` after the whole model has been built.
        if additional_config is None:
            raise ValueError(
                "additional_config is required: it must provide "
                "`hierarchical_version`.")

        super().__init__(config)
        self.uniter = UniterModel(config, img_dim)
        self.vqa_output = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size*2),
            GELU(),
            LayerNorm(config.hidden_size*2, eps=1e-12),
            nn.Linear(config.hidden_size*2, num_answer)
        )
        # Applied before the vision model and heatmap head are attached, so
        # only the modules registered above are visited by this call.
        self.apply(self.init_weights)

        self.num_answer = num_answer

        # Vision model.
        self.vision_model = vision_model
        # Always defined so the attribute exists even without a backbone.
        self.freeze_vision_model = False
        if self.vision_model is not None:
            # Remove the final FC layer, remembering its input width as the
            # per-image feature dimension.
            self.num_img_dim = self.vision_model.fc.in_features
            self.vision_model.fc = nn.Identity()

        self.multimodal_text_part = multimodal_text_part
        self.multimodal_img_part = multimodal_img_part

        if self.multimodal_text_part:
            assert self.multimodal_img_part is False
            logger.info("[UNITER] Unimodal only text part!")
        if self.multimodal_img_part:
            assert self.multimodal_text_part is False
            logger.info("[UNITER] Unimodal only image part!")

        # Tokenizer.
        self.tokenizer = tokenizer

        # Loss.
        self.sort_loss = nn.CrossEntropyLoss()

        # Heatmap predictions.
        self.hierarchical_version = additional_config.hierarchical_version
        if self.hierarchical_version != "v0":
            self.heatmap = HeatMapOutput(additional_config)

        # From here on ``self.config`` refers to the experiment-level config
        # rather than the UNITER model config.
        self.config = additional_config

    def forward(self, batch, compute_loss=True):
        """Run the model on one batch.

        Note that ``batch`` is modified in place: ``img_feat``,
        ``img_pos_feat``, ``position_ids``, ``attn_masks`` and
        ``gather_index`` are written, and ``images`` / ``input_ids`` may be
        set to ``None`` in the unimodal modes.

        Args:
            batch: Dict of inputs. ``images`` is read when a vision model is
                present; the remaining UNITER inputs come from
                ``get_detailed_input_feats``. ``labels`` (or ``targets``)
                is used for the loss.
            compute_loss: When true and labels are available, also return
                the cross-entropy loss.

        Returns:
            The output of the heatmap head when ``hierarchical_version`` is
            not ``"v0"``; otherwise ``(loss, answer_scores)`` when a loss is
            computed, else the 1-tuple ``(answer_scores,)``.

        Raises:
            ValueError: In image-only mode when no image features could be
                computed.
        """
        # Only known up front when the vision branch runs; otherwise it is
        # derived lazily from the attention mask where needed.
        batch_size = None

        # Visual embedding.
        if self.vision_model is not None and batch["images"].ndim > 3:
            images = batch["images"]
            # The unpacking requires a 5-D tensor; presumably
            # (batch, num_images, C, H, W) -- TODO confirm.
            batch_size, img_len, C, H, W = images.size()
            images = torch.reshape(
                images, (batch_size*img_len, C, H, W)).float()
            images = self.vision_model(images)
            if self.freeze_vision_model:
                images = images.detach()
            images = torch.reshape(
                images, (batch_size, img_len, self.num_img_dim))
            batch["img_feat"] = images
        else:
            batch["img_feat"] = None
        batch['img_pos_feat'] = None

        # UNITER inputs
        new_batch = get_detailed_input_feats(batch, self.tokenizer, self.config)
        batch["position_ids"] = new_batch["position_ids"]
        batch["attn_masks"] = new_batch["attn_masks"]
        batch["gather_index"] = new_batch["gather_index"]

        # Attention masks handling.
        # TODO Make sure the followings.
        if (not self.multimodal_text_part and not self.multimodal_img_part
                and not self.config.img_text_paired_coattention):
            if batch_size is None:
                # Assumes the leading dim of attn_masks is the batch size
                # -- TODO confirm.
                batch_size = batch["attn_masks"].size(0)
            additional_attn = torch.ones(
                batch_size, self.num_answer).type_as(batch["attn_masks"])
            batch["attn_masks"] = torch.cat([batch["attn_masks"],
                                             additional_attn], dim=-1)

        if self.multimodal_text_part:
            batch["images"] = None
            batch["img_feat"] = None
        if self.multimodal_img_part:
            if batch["img_feat"] is None:
                raise ValueError(
                    "Image-only mode requires image features, but none were "
                    "computed (no vision model, or `images` has ndim <= 3).")
            batch["input_ids"] = None
            if batch_size is None:
                batch_size = batch["img_feat"].size(0)
            batch["attn_masks"] = torch.ones(
                batch_size, batch["img_feat"].size(1)).type_as(
                    batch["attn_masks"])

        # Move every non-None entry to the attention mask's device.
        device = batch['attn_masks'].device
        batch = {key: value.to(device) if value is not None else value
                 for key, value in batch.items()}

        # Missing keys read as None from here on.
        batch = defaultdict(lambda: None, batch)
        input_ids = batch["input_ids"]
        position_ids = batch["position_ids"]
        img_feat = batch["img_feat"]
        img_pos_feat = batch["img_pos_feat"]
        attn_masks = batch["attn_masks"]
        gather_index = batch["gather_index"]
        sequence_output = self.uniter(input_ids, position_ids,
                                      img_feat, img_pos_feat,
                                      attn_masks, gather_index,
                                      output_all_encoded_layers=False)

        pooled_output = self.uniter.pooler(sequence_output)

        # Heatmap handlings.
        if self.hierarchical_version != "v0":
            return self.heatmap(batch, sequence_output)

        answer_scores = self.vqa_output(pooled_output)

        # Prefer "labels"; fall back to "targets" so that a batch carrying
        # only "targets" no longer feeds None into the loss.
        # NOTE(review): "targets" must be in a form CrossEntropyLoss accepts
        # -- confirm against the data pipeline.
        labels = batch.get("labels")
        if labels is None:
            labels = batch.get("targets")

        if compute_loss and labels is not None:
            vqa_loss = self.sort_loss(answer_scores, labels)
            return vqa_loss, answer_scores
        return (answer_scores, )
