import os
import math
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch import nn
import torch.nn.functional as F

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    BertConfig, EncoderDecoderConfig, EncoderDecoderModel, BertForMaskedLM,
    RobertaForCausalLM,
)
from transformers.file_utils import is_sklearn_available, requires_sklearn
from torch.nn import LayerNorm as BertLayerNorm
from trainers.train_utils import render_order_heatmap

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class HeatMapOutput(nn.Module):
    
    def __init__(self, config):
        """Build the heat-map head and any requested auxiliary heads/losses.

        Args:
            config: Model configuration object. Attributes read here:
                ``hidden_size``, ``max_story_length``,
                ``hierarchical_version`` (a string; the text after its last
                ``"v"`` must parse as an int, e.g. ``"v1"``) and
                ``hl_include_objectives`` (``None`` or a list of objective
                names). ``swapping_based_nsp`` must additionally be present
                when ``"itm"`` is one of the objectives.

        Raises:
            ValueError: If the hierarchical version is > 1 and
                ``hidden_size`` is neither 768 nor 1024, or if ``"itm"`` is
                requested but the config has no ``swapping_based_nsp``.
            NotImplementedError: If ``"binary_cross_modal"`` or ``"mlm"`` is
                requested.
            AssertionError: If the decoder config file is missing, or if
                ``hl_include_objectives`` is neither ``None`` nor a list.
        """
        super(HeatMapOutput, self).__init__()

        self.config = config

        # Main heat-map head. SimpleClassifier is defined elsewhere; its
        # positional arguments are presumably (in_dim, hidden_dim, out_dim,
        # dropout) -- TODO confirm against its definition.
        self.heat_map_prediction = SimpleClassifier(
            config.hidden_size, config.hidden_size * 1,
            config.max_story_length, 0.5
        )

        # TODO: higher layer.
        # The version number is the integer after the last "v" of the string.
        h_version = self.config.hierarchical_version
        self.h_version_num = int(h_version.split("v")[-1])
        if self.h_version_num > 1:
            if config.hidden_size == 1024:
                decoder_config_path = "pretrained_models/roberta/large/decoder_config.json"
            elif config.hidden_size == 768:
                decoder_config_path = "pretrained_models/roberta/base/decoder_config.json"
            else:
                raise ValueError(
                    "No decoder config for hidden_size={}; expected 768 or "
                    "1024.".format(config.hidden_size))

            assert os.path.exists(decoder_config_path), decoder_config_path

            decoder_config = AutoConfig.from_pretrained(decoder_config_path)
            # We are not using the ordinary vocabulary.
            decoder_config.out_features = config.hidden_size

            # The model is built from the config only (no from_pretrained
            # call here); just its encoder stack is kept as the extra layer.
            causal_lm = RobertaForCausalLM(decoder_config)
            self.add_layer = causal_lm.roberta.encoder
            print(self.add_layer)

        # Always define the flag so it exists even when no auxiliary
        # objectives are configured; the "itm" branch below overrides it.
        self.swapping_based_nsp = False

        # TODO HL aux objectives.
        self.hl_include_objectives = config.hl_include_objectives
        # Use the module logger rather than the root logger.
        logger.info("Aux using: %s", self.hl_include_objectives)
        if self.hl_include_objectives is not None:
            assert isinstance(self.hl_include_objectives, list)

            # The two "variable_length_*" decoders are mutually exclusive;
            # the LSTM variant wins when both names are listed.
            if "variable_length_lstm" in self.hl_include_objectives:
                self.single_hm_decoder = torch.nn.LSTM(config.hidden_size,
                                                       config.hidden_size,
                                                       batch_first=True)
                self.single_hm_projector = torch.nn.Linear(config.hidden_size,
                                                           1)
            elif "variable_length_transformer" in self.hl_include_objectives:
                self.heat_map_K = SimpleClassifier(
                    config.hidden_size, config.hidden_size * 2,
                    self.config.max_story_length, 0.1
                )
                self.heat_map_V = SimpleClassifier(
                    config.hidden_size, config.hidden_size * 2,
                    self.config.max_story_length, 0.1
                )
                print(self.heat_map_K)
                print(self.heat_map_V)

                # The cross-modal K/V heads are only created together with
                # the transformer-style heads above.
                if "variable_length_cross_modal" in self.hl_include_objectives:
                    self.heat_map_v_K = SimpleClassifier(
                        config.hidden_size, config.hidden_size * 2,
                        self.config.max_story_length, 0.5
                    )
                    self.heat_map_v_V = SimpleClassifier(
                        config.hidden_size, config.hidden_size * 2,
                        self.config.max_story_length, 0.5
                    )

            if ("binary" in self.hl_include_objectives
                    or "pairwise" in self.hl_include_objectives):
                self.hl_bin_pred_layer = SimpleClassifier(
                    config.hidden_size * 1,
                    config.hidden_size,
                    1, 0.5
                )
                self.hl_bin_pred_crit = torch.nn.CrossEntropyLoss()
                self.hl_bin_sparse_prob = 1.5  # 1.0
                # The position buffer is only created when the value above
                # is set below 1.0.
                if self.hl_bin_sparse_prob < 1.0:
                    self.hl_bin_sparse_pos = []

            if "binary_cross_modal" in self.hl_include_objectives:
                raise NotImplementedError("Not done yet!")

            if "head" in self.hl_include_objectives:
                self.hl_head_pred_layer = SimpleClassifier(
                    config.hidden_size * 1,
                    config.hidden_size,
                    1, 0.5
                )
                self.hl_head_pred_crit = torch.nn.CrossEntropyLoss()

            if "mlm" in self.hl_include_objectives:
                raise NotImplementedError("Not done yet!")

            if "itm" in self.hl_include_objectives:
                if "swapping_based_nsp" not in config.__dict__:
                    raise ValueError("No `swapping_based_nsp` in config.")
                self.swapping_based_nsp = config.swapping_based_nsp
                self.swapping_based_nsp_prob = 0.5
                self.itm_loss_fct = torch.nn.CrossEntropyLoss()
                self.seq_relationship = nn.Linear(config.hidden_size, 2)

            if "cross_modal_dependence" in self.hl_include_objectives:
                self.cross_modal_dependence_prediction = SimpleClassifier(
                    config.hidden_size, config.hidden_size * 1,
                    self.config.max_story_length, 0.5
                )
                self.cross_modal_dependence_loss = torch.nn.MSELoss()

            if "heatmap_pairwise_ranking" in self.hl_include_objectives:
                self.hm_pw_ranking_loss = torch.nn.MarginRankingLoss(margin=0.2)

            if "pairwise_binary_heatmap" in self.hl_include_objectives:
                # Both heads take 2 * hidden_size features as input.
                self.hl_bin_hm_pred_layer = SimpleClassifier(
                    config.hidden_size * 2,
                    config.hidden_size * 1,
                    1, 0.1
                )
                self.hl_bin_hm_pred_crit = torch.nn.BCEWithLogitsLoss()

                self.hl_bin_hm_class_layer = SimpleClassifier(
                    config.hidden_size * 2,
                    config.hidden_size * 1,
                    2, 0.1
                )
                self.hl_bin_hm_class_loss = torch.nn.CrossEntropyLoss()

        # TODO Higher-level layers.
        self.additional_hl = False

        # TODO More random stuff.
        self.fusion_method = "mul"
        self.heatmap_late_fusion = False

        # Losses:
        # TODO KL-Div.
        self.use_kl_div = False

        # Main heat-map loss. MSE is used for every configuration, including
        # "pairwise_binary_heatmap" (BCELoss, BCEWithLogitsLoss, SmoothL1Loss
        # and SoftMarginLoss were tried as alternatives).
        self.heat_map_sort_loss = torch.nn.MSELoss()
        print(self.heat_map_sort_loss)

    def get_heatmap_labels(self, batch, logits):
        """Render the target heat map for every example in the batch.

        Args:
            batch (dict): Must contain ``"labels"``, a sequence of
                per-example order tensors.
            logits: Tensor whose type the returned labels are cast to.

        Returns:
            The per-example outputs of ``render_order_heatmap`` stacked along
            a new leading dimension and cast to the type of ``logits``.
        """
        rendered = [
            render_order_heatmap(
                args=None,  # TODO: change this.
                order_list=example_labels.detach().cpu().numpy(),
                ranking_based=False,
                soft=False,
            )
            for example_labels in batch["labels"]
        ]
        return torch.stack(rendered).type_as(logits)

    def forward(self, batch, sequence_output, itm_repr=None):
        """
            batch (dict): Dict of inputs.
        """
        if sequence_output.size(1) == self.config.max_story_length:
            bz = sequence_output.size(0)
            cls_heat_map = []
            for i in range(bz):
                cls_repr_i = sequence_output[i]
                cls_heat_map_i = self.heat_map_prediction(cls_repr_i)
                cls_heat_map.append(cls_heat_map_i)
            cls_heat_map = torch.stack(cls_heat_map)
            logits = torch.sigmoid(cls_heat_map)
            heatmap_labels = self.get_heatmap_labels(batch, logits)
            loss = self.heat_map_sort_loss(logits, heatmap_labels)
            return (loss, logits)

        input_ids = batch["input_ids"]
        bz, text_len = input_ids.size()

        cls_repr_pos = []
        for i in range(bz):
            cls_repr_pos_i = torch.nonzero(
                input_ids[i]==self.config.cls_id, as_tuple=False)
            cls_repr_pos.append(cls_repr_pos_i)

        cls_heat_map = []
        # print(cls_repr_pos[0])
        # print(batch["input_ids"][0]);raise
        # print(sequence_output.size())
        # print(self.hl_include_objectives);raise

        # TODO See if we need to use additional hl layers.
        if self.additional_hl:
            pass

        # Obtaining linguistics and visual outputs.
        if type(sequence_output) != tuple:
            sequence_output_t = sequence_output[:, :text_len]
            sequence_output_v = sequence_output[:, text_len:]
        else:
            sequence_output_t, sequence_output_v = sequence_output
            raise NotImplementedError("Not done yet!")

        # TODO: Auxiliary predictions.
        if self.hl_include_objectives is not None:
            hl_aux_predictions = [None] * len(self.hl_include_objectives)
            if "head" in self.hl_include_objectives:
                hl_aux_head_predictions = []
            if ("pairwise" in self.hl_include_objectives
                or "binary" in self.hl_include_objectives):
                hl_aux_bin_predictions = []
            if "binary_cross_modal" in self.hl_include_objectives:
                hl_aux_binx_predictions = []
            if "cross_modal_dependence" in self.hl_include_objectives:
                hl_aux_x_dep_predictions = []
            if "variable_length_cross_modal" in self.hl_include_objectives:
                cls_heat_map_xmodal = []
            if "pairwise_binary_heatmap" in self.hl_include_objectives:
                hl_bin_hm_class_obj = []

        for i in range(bz):
            if not self.additional_hl:
                cls_repr_pos_i = cls_repr_pos[i].squeeze()
                cls_repr_i = sequence_output[i][cls_repr_pos_i]
                if self.fusion_method == "img_only" or self.heatmap_late_fusion:
                    raise NotImplementedError("Not done yet!")
            else:
                raise NotImplementedError("Not done yet!")

            if self.heatmap_late_fusion:
                raise NotImplementedError("Not done yet!")

            if self.hl_include_objectives is not None:
                if len(self.hl_include_objectives) > 0:
                    for hl_aux_objective in self.hl_include_objectives:
                        if "head" == hl_aux_objective:
                            hl_aux_head_prediction_curr = self.hl_head_pred_layer(cls_repr_i)
                            hl_aux_head_prediction_curr = hl_aux_head_prediction_curr.squeeze()
                            hl_aux_head_predictions.append(hl_aux_head_prediction_curr)
                        elif ("pairwise" == hl_aux_objective
                              or "binary" == hl_aux_objective):
                            hl_aux_bin_predictions_tmp = []
                            for seq_i in range(len(cls_repr_pos_i)):
                                for seq_j in range(seq_i+1, len(cls_repr_pos_i)):
                                    cls_repr_seq_i = cls_repr_i[seq_i]
                                    cls_repr_seq_j = cls_repr_i[seq_j]
                                    cls_repr_seq_ij = torch.stack(
                                        [cls_repr_seq_i, cls_repr_seq_j])
                                    hl_aux_bin_prediction_curr = self.hl_bin_pred_layer(cls_repr_seq_ij)
                                    hl_aux_bin_prediction_curr = hl_aux_bin_prediction_curr.squeeze()
                                    hl_aux_bin_predictions_tmp.append(hl_aux_bin_prediction_curr)
                            hl_aux_bin_predictions_tmp = torch.stack(hl_aux_bin_predictions_tmp)

                            if self.hl_bin_sparse_prob < 1.0:
                                hl_bin_sparse_len = int(
                                    len(hl_aux_bin_predictions_tmp)
                                        * self.hl_bin_sparse_prob)
                                hl_bin_sparse_pos_tmp = np.random.choice(
                                    np.arange(hl_aux_bin_predictions_tmp.size(0)),
                                              hl_bin_sparse_len)
                                # TODO: Temporary!!!
                                hl_aux_bin_predictions_tmp = hl_aux_bin_predictions_tmp[
                                    hl_bin_sparse_pos_tmp]
                                # hl_aux_bin_predictions_tmp = hl_aux_bin_predictions_tmp[:3]
                                self.hl_bin_sparse_pos.append(hl_bin_sparse_pos_tmp)

                            hl_aux_bin_predictions.append(hl_aux_bin_predictions_tmp)

                        elif "binary_cross_modal" == hl_aux_objective:
                            hl_aux_binx_predictions_tmp = []
                            for seq_i in range(len(cls_repr_pos_i)):
                                for seq_j in range(seq_i+1, len(cls_repr_pos_i)):
                                    # Text modality.
                                    cls_repr_seq_i = cls_repr_i[seq_i]
                                    # Image modality.
                                    if self.config.include_num_img_regional_features  is not None:
                                        img_seq_j = (1 + self.config.include_num_img_regional_features) * seq_j
                                        raise NotImplementedError("Not debugged yet!")
                                    else:
                                        img_seq_j = seq_j
                                    cls_repr_seq_j = sequence_output_v[i][img_seq_j]
                                    # Stacking.
                                    cls_repr_seq_ij = torch.cat(
                                        [cls_repr_seq_i, cls_repr_seq_j], dim=-1)
                                    hl_aux_binx_prediction_curr = self.hl_binx_pred_layer(cls_repr_seq_ij)
                                    hl_aux_binx_prediction_curr = hl_aux_binx_prediction_curr.squeeze()
                                    hl_aux_binx_predictions_tmp.append(hl_aux_binx_prediction_curr)
                            hl_aux_binx_predictions_tmp = torch.stack(hl_aux_binx_predictions_tmp)

                            if self.hl_binx_sparse_prob < 1.0:
                                hl_binx_sparse_len = int(
                                    len(hl_aux_binx_predictions_tmp)
                                        * self.hl_binx_sparse_prob)
                                hl_binx_sparse_pos_tmp = np.random.choice(
                                    np.arange(hl_aux_binx_predictions_tmp.size(0)),
                                              hl_binx_sparse_len)
                                # TODO: Temporary!!!
                                hl_aux_binx_predictions_tmp = hl_aux_binx_predictions_tmp[
                                    hl_binx_sparse_pos_tmp]
                                # hl_aux_binx_predictions_tmp = hl_aux_binx_predictions_tmp[:3]
                                self.hl_binx_sparse_pos.append(hl_binx_sparse_pos_tmp)

                            hl_aux_binx_predictions.append(hl_aux_binx_predictions_tmp)

                        elif "cross_modal_dependence" == hl_aux_objective:
                            hl_aux_x_dep_predictions_tmp = []
                            for img_seq_i in range(len(cls_repr_pos_i)):
                                # Image modality.
                                if self.config.include_num_img_regional_features  is not None:
                                    img_seq_i_real = (1 + self.config.include_num_img_regional_features) * img_seq_i
                                    raise NotImplementedError("Not debugged yet!")
                                else:
                                    img_seq_i_real = img_seq_i
                                img_repr_seq_i = sequence_output_v[i][img_seq_i_real]
                                hl_aux_x_dep_predictions_tmp.append(img_repr_seq_i)
                            hl_aux_x_dep_predictions_tmp = torch.stack(hl_aux_x_dep_predictions_tmp)
                            hl_aux_x_dep_predictions_tmp = self.cross_modal_dependence_prediction(
                                hl_aux_x_dep_predictions_tmp)
                            hl_aux_x_dep_predictions.append(hl_aux_x_dep_predictions_tmp)

                        elif "variable_length_cross_modal" == hl_aux_objective:
                            img_repr_i = []
                            for img_seq_i in range(len(cls_repr_pos_i)):
                                # Image modality.
                                if self.config.include_num_img_regional_features  is not None:
                                    img_seq_i_real = (1 + self.config.include_num_img_regional_features) * img_seq_i
                                    raise NotImplementedError("Not debugged yet!")
                                else:
                                    img_seq_i_real = img_seq_i
                                img_repr_seq_i = sequence_output_v[i][img_seq_i_real]
                                img_repr_i.append(img_repr_seq_i)
                            img_repr_i = torch.stack(img_repr_i)

                pass  # End of auxiliary predictions.

            # If going through the additional transformer layers.
            if self.h_version_num > 1:
                cls_repr_i = cls_repr_i.unsqueeze(0)
                cls_repr_i = self.add_layer(hidden_states=cls_repr_i)[0][0]

            # Concatenate the heatmap representations.
            cls_heat_map_i = self.heat_map_prediction(cls_repr_i)

            if self.hl_include_objectives is not None:
                if "variable_length_lstm" in self.hl_include_objectives:
                    cls_heat_map_i = []
                    for cls_j in range(cls_repr_i.size(0)):
                        cls_repr_i_cls_j = cls_repr_i[cls_j]
                        cls_repr_i_cls_j = cls_repr_i_cls_j.unsqueeze(0).unsqueeze(0)
                        hm_outputs_cls_j = torch.zeros(cls_repr_i_cls_j.size(0),
                                                       self.config.max_story_length)
                        hm_dec_in = cls_repr_i_cls_j

                        cls_heat_map_i_cls_j = []

                        for t in range(self.config.max_story_length):
                            out_t, hidden = self.single_hm_decoder(hm_dec_in)
                            hm_dec_in = out_t
                            out_t = out_t.squeeze(0)
                            out_t = self.single_hm_projector(out_t)
                            cls_heat_map_i_cls_j.append(out_t.squeeze())
                        cls_heat_map_i_cls_j = torch.stack(cls_heat_map_i_cls_j)
                        cls_heat_map_i.append(cls_heat_map_i_cls_j.squeeze())
                    cls_heat_map_i = torch.stack(cls_heat_map_i)
                            
                elif "variable_length_transformer" in self.hl_include_objectives:
                    # raise NotImplementedError
                    cls_heat_map_k = self.heat_map_K(cls_repr_i)
                    cls_heat_map_v = self.heat_map_V(cls_repr_i)
                    cls_heat_map_v = torch.transpose(cls_heat_map_v, 0, 1)
                    cls_heat_map_i = torch.matmul(cls_heat_map_k, cls_heat_map_v)

                    if "variable_length_cross_modal" in self.hl_include_objectives:
                        img_heat_map_v = self.heat_map_v_V(img_repr_i)
                        img_heat_map_v = torch.transpose(img_heat_map_v, 0, 1)
                        cls_heat_map_xmodal_i = torch.matmul(cls_heat_map_k, img_heat_map_v)

                elif "pairwise_binary_heatmap" in self.hl_include_objectives:
                    cls_heat_map_i = torch.Tensor(self.config.max_story_length,
                        self.config.max_story_length).type_as(sequence_output_t)
                    for row in range(self.config.max_story_length):
                        for col in range(self.config.max_story_length):
                            cls_heat_map_rc = torch.cat([cls_repr_i[row], cls_repr_i[col]], dim=-1)
                            cls_heat_map_rc = cls_heat_map_rc.unsqueeze(0)
                            cls_heat_map_rc_val = self.hl_bin_hm_pred_layer(cls_heat_map_rc)
                            cls_heat_map_i[row][col] = cls_heat_map_rc_val

                            # FIXME: hl_bin_hm_class_obj
                            hl_bin_hm_class_obj_curr = self.hl_bin_hm_class_layer(cls_heat_map_rc)
                            hl_bin_hm_class_obj.append(hl_bin_hm_class_obj_curr)

            cls_heat_map.append(cls_heat_map_i)

            if self.hl_include_objectives is not None:
                if "variable_length_cross_modal" in self.hl_include_objectives:
                    cls_heat_map_xmodal.append(cls_heat_map_xmodal_i)

        cls_heat_map = torch.stack(cls_heat_map)
        self.cls_heat_map = cls_heat_map

        if self.hl_include_objectives is not None:
            if "variable_length_cross_modal" in self.hl_include_objectives:
                cls_heat_map_xmodal = torch.stack(cls_heat_map_xmodal)
        
        # Obtain auxiliary predictions.
        if self.hl_include_objectives is not None:
            if len(self.hl_include_objectives) > 0:
                for hl_aux_objective in self.hl_include_objectives:
                    aux_obj_list_idx = self.hl_include_objectives.index(hl_aux_objective)

                    if "head" == hl_aux_objective:
                        hl_aux_head_predictions = torch.stack(hl_aux_head_predictions)
                        hl_aux_predictions[aux_obj_list_idx] = hl_aux_head_predictions
                    elif ("pairwise" == hl_aux_objective
                          or "binary" == hl_aux_objective):
                        hl_aux_bin_predictions = torch.stack(hl_aux_bin_predictions)
                        hl_aux_predictions[aux_obj_list_idx] = hl_aux_bin_predictions
                    elif "binary_cross_modal" == hl_aux_objective:
                        hl_aux_binx_predictions = torch.stack(hl_aux_binx_predictions)
                        hl_aux_predictions[aux_obj_list_idx] = hl_aux_binx_predictions
                    elif "cross_modal_dependence" == hl_aux_objective:
                        hl_aux_x_dep_predictions = torch.stack(hl_aux_x_dep_predictions)
                        hl_aux_predictions[aux_obj_list_idx] = hl_aux_x_dep_predictions
            pass ####

        if not self.use_kl_div:
            logits = cls_heat_map
            if self.config.hl_include_objectives is not None:
                if "pairwise_binary_heatmap" in self.config.hl_include_objectives:
                    logits = torch.sigmoid(logits)
                    pass
                else:
                    logits = torch.sigmoid(logits)
            else:
                logits = torch.sigmoid(logits)

            # FIXME: transpose.
            if self.config.hl_include_objectives is not None:
                if "variable_length_transformer" in self.config.hl_include_objectives:
                    pass
                elif "pairwise_binary_heatmap" in self.config.hl_include_objectives:
                    pass
                else:
                    logits = torch.transpose(logits, 1, 2)
            else:
                logits = torch.transpose(logits, 1, 2)

        if "labels" in batch and batch["labels"] is not None:
            heatmap_labels = []
            for i in range(len(batch["labels"])):
                labels_i = batch["labels"][i].detach().cpu().numpy()
                heatmap_labels_i = render_order_heatmap(args=None,  # TODO: change this.
                                                        order_list=labels_i,
                                                        ranking_based=False,
                                                        soft=False)
                heatmap_labels.append(heatmap_labels_i)
            heatmap_labels = torch.stack(heatmap_labels)
            heatmap_labels = heatmap_labels.type_as(logits)

            if self.use_kl_div:
                # raise NotImplementedError("Not done yet!")
                logits = torch.sigmoid(cls_heat_map)
                logits_soft = F.log_softmax(logits)
                heatmap_labels_soft = heatmap_labels + 1e-12
                heatmap_labels_soft = F.softmax(heatmap_labels_soft)
                loss = torch.nn.KLDivLoss(size_average=True)(logits_soft, heatmap_labels_soft)

            else:
                # The main heatmap loss.
                # heatmap_labels = F.softmax(heatmap_labels)
                loss = self.heat_map_sort_loss(logits, heatmap_labels)
                # logits = torch.sigmoid(logits)

                # FIXME:
                if self.config.hl_include_objectives is not None:
                    if "pairwise_binary_heatmap" in self.config.hl_include_objectives:
                #         logits = torch.sigmoid(logits)
                        hl_bin_hm_class_target = []
                        for i in range(len(batch["labels"])):
                            heatmap_labels_i = heatmap_labels[i]
                            for row in range(self.config.max_story_length):
                                for col in range(self.config.max_story_length):
                                    if heatmap_labels_i[row][col] == 1:
                                        hl_bin_hm_class_target.append(1)
                                    else:
                                        hl_bin_hm_class_target.append(0)
                        hl_bin_hm_class_target = torch.Tensor(hl_bin_hm_class_target).type_as(batch["input_ids"])
                        hl_bin_hm_class_obj = torch.stack(hl_bin_hm_class_obj).squeeze()
                        hl_bin_hm_class_loss = self.hl_bin_hm_class_loss(hl_bin_hm_class_obj, hl_bin_hm_class_target)
                        loss += 0.4 * hl_bin_hm_class_loss

            # TODO: Deal with auxiliary objectives.
            if self.hl_include_objectives is not None:
                if len(self.hl_include_objectives) > 0:
                    for h in range(len(self.hl_include_objectives)):
                        hl_aux_objective = self.hl_include_objectives[h]

                        if "heatmap_pairwise_ranking" == hl_aux_objective:
                            hm_pw_target = []
                            hm_pw_input_1 = []
                            hm_pw_input_2 = []
                            for b in range(len(batch["labels"])):
                                label_ = list(batch["labels"][b].cpu().numpy())
                                # print(label_)
                                hm_pw_target_tmp = []
                                hm_pw_input_1_tmp = []
                                hm_pw_input_2_tmp = []
                                for seq_i in range(len(label_)):
                                    pos_i = label_[seq_i]
                                    if seq_i+1 >= len(label_):
                                        break
                                    pos_j = label_[seq_i+1]
                                    anchor = logits[b][pos_i][pos_j]
                                    for seq_j in range(len(label_)):
                                        if seq_j - seq_i == 1:  # Positive
                                            hm_pw_target_tmp.append(1)
                                        else:
                                            hm_pw_target_tmp.append(-1)
                                        pos_i = label_[seq_i]
                                        pos_j = label_[seq_j]
                                        # print(pos_i, pos_j, hm_pw_target_tmp[-1])
                                        heatmap_ij = logits[b][pos_i][pos_j]
                                        hm_pw_input_1_tmp.append(anchor)
                                        hm_pw_input_2_tmp.append(heatmap_ij)

                                hm_pw_target_tmp = torch.Tensor(
                                    hm_pw_target_tmp).type_as(batch["labels"])
                                hm_pw_input_1_tmp = torch.stack(hm_pw_input_1_tmp)
                                hm_pw_input_2_tmp = torch.stack(hm_pw_input_2_tmp)

                                hm_pw_target.append(hm_pw_target_tmp)
                                hm_pw_input_1.append(hm_pw_input_1_tmp)
                                hm_pw_input_2.append(hm_pw_input_2_tmp)
                            
                            hm_pw_target = torch.stack(hm_pw_target)
                            hm_pw_input_1 = torch.stack(hm_pw_input_1)
                            hm_pw_input_2 = torch.stack(hm_pw_input_2)
                            # print(hm_pw_target.size(), hm_pw_input_1.size(),
                            #       hm_pw_input_2.size())

                            heatmap_pairwise_ranking_loss = self.hm_pw_ranking_loss(
                                hm_pw_input_1, hm_pw_input_2, hm_pw_target)
                            # print(heatmap_pairwise_ranking_loss)

                            loss += heatmap_pairwise_ranking_loss

                        elif "mlm_wo_loss" == hl_aux_objective:
                            pass

                        elif "mlm" == hl_aux_objective:
                            masked_lm_labels = batch["masked_lm_labels"]
                            masked_lm_loss = self.mlm_loss_fct(
                                linguistic_prediction.view(-1,
                                    self.config.vocab_size),
                                masked_lm_labels.view(-1),
                            )
                            loss += 0.05 * masked_lm_loss

                        elif "itm" == hl_aux_objective:
                            assert itm_repr is not None, "No itm representation!"
                            pooled_output, itm_targets = itm_repr
                            seq_relationship_prediction = self.seq_relationship(
                                pooled_output)
                            swapping_based_nsp_loss = self.itm_loss_fct(
                                seq_relationship_prediction,
                                itm_targets,
                            )
                            loss += 0.1 * swapping_based_nsp_loss

                        elif "head" == hl_aux_objective:
                            head_labels = batch["labels"][:, 0]
                            hl_aux_head_loss = self.hl_head_pred_crit(
                                hl_aux_predictions[h], head_labels)
                            loss += hl_aux_head_loss

                        elif ("pairwise" == hl_aux_objective
                              or "binary" == hl_aux_objective):
                            hl_bin_predictions = hl_aux_predictions[h]
                            bz, seq_len = batch["labels"].size()
                            for b in range(bz):
                                label_curr = batch["labels"][b]
                                label_index = torch.argsort(label_curr)
                                hl_bin_label = torch.zeros(
                                    hl_bin_predictions.size()[1]).type_as(
                                        batch["labels"])

                                bin_idx = 0
                                if self.hl_bin_sparse_prob < 1.0:
                                    hl_bin_label = []
                                for seq_i in range(seq_len):
                                    for seq_j in range(seq_i+1, seq_len):
                                        if label_index[seq_i] < label_index[seq_j]:
                                            if self.hl_bin_sparse_prob < 1.0:
                                                hl_bin_label.append(1)
                                            else:
                                                hl_bin_label[bin_idx]  = 1
                                        else:
                                            if self.hl_bin_sparse_prob < 1.0:
                                                hl_bin_label.append(0)
                                        bin_idx += 1

                                if self.hl_bin_sparse_prob < 1.0:
                                    hl_bin_label = torch.Tensor(hl_bin_label).type_as(batch["labels"])
                                    # TODO: Change this!!!
                                    # hl_bin_label = hl_bin_label[:3]
                                    hl_bin_label = hl_bin_label[self.hl_bin_sparse_pos[b]]

                                hl_aux_bin_loss_curr = self.hl_bin_pred_crit(hl_bin_predictions[b], hl_bin_label)
                                # loss += hl_aux_bin_loss_curr
                                loss += 0.1 * hl_aux_bin_loss_curr
                            pass
                                    
                        elif "binary_cross_modal" == hl_aux_objective:
                            hl_binx_predictions = hl_aux_predictions[h]
                            bz, seq_len = batch["labels"].size()
                            for b in range(bz):
                                label_curr = batch["labels"][b]
                                label_index = torch.argsort(label_curr)
                                hl_binx_label = torch.zeros(
                                    hl_binx_predictions.size()[1]).type_as(
                                        batch["labels"])

                                binx_idx = 0
                                if self.hl_binx_sparse_prob < 1.0:
                                    hl_binx_label = []
                                for seq_i in range(seq_len):
                                    for seq_j in range(seq_i+1, seq_len):
                                        if label_index[seq_i] < label_index[seq_j]:
                                            if self.hl_binx_sparse_prob < 1.0:
                                                hl_binx_label.append(1)
                                            else:
                                                hl_binx_label[binx_idx]  = 1
                                        else:
                                            if self.hl_binx_sparse_prob < 1.0:
                                                hl_binx_label.append(0)
                                        binx_idx += 1

                                if self.hl_binx_sparse_prob < 1.0:
                                    hl_binx_label = torch.Tensor(hl_binx_label).type_as(batch["labels"])
                                    # TODO: Change this!!!
                                    # hl_bin_label = hl_bin_label[:3]
                                    hl_binx_label = hl_binx_label[self.hl_binx_sparse_pos[b]]

                                hl_aux_binx_loss_curr = self.hl_binx_pred_crit(hl_binx_predictions[b], hl_binx_label)
                                loss += hl_aux_binx_loss_curr
                                    
                            pass

                        elif "cross_modal_dependence" == hl_aux_objective:
                            hl_binx_predictions = hl_aux_predictions[h]
                            cross_modal_dependence_preds = torch.sigmoid(
                                hl_binx_predictions)
                            cross_modal_dependence_logits = torch.transpose(
                                cross_modal_dependence_preds, 1, 2)
                            cross_modal_dependence_loss = self.cross_modal_dependence_loss(
                                heatmap_labels, cross_modal_dependence_logits)
                            loss += cross_modal_dependence_loss

                        elif "variable_length_cross_modal" == hl_aux_objective:
                            cls_heat_map_xmodal = torch.sigmoid(cls_heat_map_xmodal)
                            var_xmodal_loss = self.heat_map_sort_loss(heatmap_labels, cls_heat_map_xmodal)
                            loss += var_xmodal_loss

                pass  # End of all auxiliary losses. 

            return loss, logits

        return (logits, )


class SimpleClassifier(nn.Module):
    """Small feed-forward head: Linear -> GeLU -> LayerNorm -> Linear.

    Args:
        in_dim: Feature size expected by the first linear layer.
        hid_dim: Width of the intermediate projection and of the layer norm.
        out_dim: Feature size produced by the final linear layer.
        dropout: Accepted so callers can keep passing it; no layer built
            here reads this value.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout):
        super().__init__()
        # Layers are created in the same order they are applied, and the
        # container keeps the name ``logit_fc`` so parameter names
        # (``logit_fc.0.*``, ``logit_fc.2.*``, ``logit_fc.3.*``) stay stable.
        input_projection = nn.Linear(in_dim, hid_dim)
        activation = GeLU()
        normalization = BertLayerNorm(hid_dim, eps=1e-12)
        output_projection = nn.Linear(hid_dim, out_dim)
        self.logit_fc = nn.Sequential(
            input_projection,
            activation,
            normalization,
            output_projection,
        )

    def forward(self, hidden_states):
        """Apply the layer stack to ``hidden_states`` and return the result."""
        logits = self.logit_fc(hidden_states)
        return logits


def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``.

    Note that OpenAI GPT uses a tanh approximation that yields slightly
    different values:
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Reference: https://arxiv.org/abs/1606.08415
    """
    # Same operation order as the one-line form: (x * 0.5) * (1 + erf(...)).
    half_input = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_input * (1.0 + erf_term)


class GeLU(nn.Module):
    """Module wrapper around the ``gelu`` function.

    The module holds no parameters or buffers of its own; it exists so the
    activation can be placed inside containers such as ``nn.Sequential``.
    See https://arxiv.org/abs/1606.08415 for the activation itself.
    """

    # No custom ``__init__``: the inherited ``nn.Module`` constructor is all
    # that is needed, since there is no state to set up.

    def forward(self, x):
        """Return ``gelu(x)``."""
        activated = gelu(x)
        return activated
