import copy
from typing import Any, Dict, Optional, Tuple

from dataclasses import dataclass


import torch.nn.functional as F

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model, GPT2Block, GPT2Attention, GPT2MLP


from model_utils import AverageSelfAttention, Conv1D, entropy, joint_entropy

# Small positive constant. Presumably a numerical-stability epsilon (e.g. inside
# logs or divisions) -- its uses are not visible in this part of the file.
EPS = 1e-12

class SharedEncoder(GPT2Model):
    """Shallow GPT-2 encoder used for the posterior/prior latent distributions.

    The parent constructor builds a regular ``GPT2Model``; the embeddings,
    dropout, transformer blocks and final layer norm are then replaced so the
    stack holds ``num_layers`` blocks instead of ``config.n_layer``.

    Args:
        config: GPT-2 configuration (reads ``hidden_size``, ``vocab_size``,
            ``max_position_embeddings``, ``embd_pdrop`` and
            ``layer_norm_epsilon``).
        num_layers: Number of ``GPT2Block`` layers in the encoder. Defaults to
            3, the value that was previously hard-coded.
    """

    def __init__(self, config, num_layers: int = 3):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        # NOTE(review): the inherited forward presumably sizes the head mask
        # from config.n_layer, so num_layers > config.n_layer may fail there
        # -- confirm before using a deeper encoder than the config declares.
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Re-initialise because the modules above replaced the parent's.
        self.init_weights()


class PosteriorContextSensitiveEncoder(nn.Module):
    """Posterior over the dialogue-level ("context sensitive") latent variable.

    Token states are pooled in two stages: first within each utterance
    (``local_self_attention``), then across the utterances of a dialogue
    (``global_self_attention``). The pooled vector is projected to the mean
    and log-variance of the latent distribution.

    Args:
        max_utterance_num: Maximum number of utterances per dialogue (stored,
            not read by ``forward``).
        max_utterance_len: Number of token positions per utterance.
        hidden_size: Width of the incoming hidden states and of the latent.
    """

    def __init__(self, max_utterance_num: int, max_utterance_len: int, hidden_size: int):
        super().__init__()

        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len

        self.embed_dim = hidden_size
        self.local_self_attention = AverageSelfAttention(self.embed_dim)
        self.global_self_attention = AverageSelfAttention(self.embed_dim)
        nx = hidden_size
        nz = hidden_size
        self.mean = Conv1D(nz, nx)
        self.logvar = Conv1D(nz, nx)

    def forward(self, hidden_states, attention_mask):
        """Return ``(mean, logvar)`` of the dialogue-level posterior.

        Args:
            hidden_states: ``[batch, utterance_num * max_utterance_len, hidden_size]``.
            attention_mask: ``[batch, utterance_num * max_utterance_len]``;
                non-zero marks real tokens.

        Raises:
            AssertionError: If the sequence length is not a multiple of
                ``max_utterance_len``.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        assert sequence_length % self.max_utterance_len == 0
        utterance_num = sequence_length // self.max_utterance_len

        # reshape (not view): merging the batch and utterance axes must also
        # work for non-contiguous inputs such as a slice along the sequence.
        local_states = hidden_states.reshape(
            batch_size * utterance_num, self.max_utterance_len, self.embed_dim
        )
        local_mask = attention_mask.reshape(batch_size * utterance_num, self.max_utterance_len)

        # Pool tokens inside every utterance -> one vector per utterance.
        utterance_states, _ = self.local_self_attention(local_states, local_mask)
        global_states = utterance_states.reshape(batch_size, utterance_num, self.embed_dim)

        # An utterance takes part in the global pooling iff it has any real token.
        global_mask = local_mask.reshape(batch_size, utterance_num, self.max_utterance_len).sum(dim=-1)
        global_mask = global_mask.to(torch.bool).to(attention_mask.dtype)

        # Pool utterances -> one vector per dialogue.
        representations, _ = self.global_self_attention(global_states, global_mask)
        mean = self.mean(representations)
        logvar = self.logvar(representations)

        return (mean, logvar)


class PosteriorContextIndependentEncoder(nn.Module):
    """Posterior over the per-utterance ("context independent") latent variables.

    Token states are pooled within each utterance and the pooled vector is
    projected to a mean and log-variance, giving one Gaussian per utterance.

    Args:
        max_utterance_num: Maximum number of utterances per dialogue (stored,
            not read by ``forward``).
        max_utterance_len: Number of token positions per utterance.
        hidden_size: Width of the incoming hidden states and of the latent.
    """

    def __init__(self, max_utterance_num: int, max_utterance_len: int, hidden_size: int):
        super().__init__()

        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len

        self.embed_dim = hidden_size
        self.local_self_attention = AverageSelfAttention(self.embed_dim)
        nx = hidden_size
        nz = hidden_size
        self.mean = Conv1D(nz, nx)
        self.logvar = Conv1D(nz, nx)

    def forward(self, hidden_states, attention_mask):
        """Return ``(mean, logvar)``, each ``[batch, utterance_num, hidden_size]``.

        Args:
            hidden_states: ``[batch, utterance_num * max_utterance_len, hidden_size]``.
            attention_mask: ``[batch, utterance_num * max_utterance_len]``.

        Raises:
            AssertionError: If the sequence length is not a multiple of
                ``max_utterance_len``.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        assert sequence_length % self.max_utterance_len == 0
        utterance_num = sequence_length // self.max_utterance_len

        # reshape (not view): merging the batch and utterance axes must also
        # work for non-contiguous inputs such as a slice along the sequence.
        local_states = hidden_states.reshape(
            batch_size * utterance_num, self.max_utterance_len, self.embed_dim
        )
        local_mask = attention_mask.reshape(batch_size * utterance_num, self.max_utterance_len)

        # One pooled vector per utterance: [batch * utterance_num, hidden_size].
        representations, _ = self.local_self_attention(local_states, local_mask)

        latent_shape = (batch_size, utterance_num, self.embed_dim)
        mean = self.mean(representations).reshape(latent_shape)
        logvar = self.logvar(representations).reshape(latent_shape)

        return (mean, logvar)


class PosteriorCategoryEncoder(nn.Module):
    """Posterior over a per-utterance categorical latent variable.

    Token states are pooled within each utterance, projected to
    ``num_category`` scores by ``fc_alpha`` and normalised with a softmax.

    Args:
        max_utterance_num: Maximum number of utterances per dialogue (stored,
            not read by ``forward``).
        max_utterance_len: Number of token positions per utterance.
        hidden_size: Width of the incoming hidden states.
        num_category: Number of categories of the latent variable.
    """

    def __init__(self, max_utterance_num: int, max_utterance_len: int, hidden_size: int, num_category: int):
        super().__init__()

        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len
        self.num_category = num_category

        self.embed_dim = hidden_size
        self.local_self_attention = AverageSelfAttention(self.embed_dim)
        self.fc_alpha = Conv1D(num_category, hidden_size)

    def forward(self, hidden_states, attention_mask):
        """Return category probabilities of shape ``[batch, utterance_num, num_category]``.

        Args:
            hidden_states: ``[batch, utterance_num * max_utterance_len, hidden_size]``.
            attention_mask: ``[batch, utterance_num * max_utterance_len]``.

        Raises:
            AssertionError: If the sequence length is not a multiple of
                ``max_utterance_len``.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        assert sequence_length % self.max_utterance_len == 0
        utterance_num = sequence_length // self.max_utterance_len

        # reshape (not view): merging the batch and utterance axes must also
        # work for non-contiguous inputs such as a slice along the sequence.
        local_states = hidden_states.reshape(
            batch_size * utterance_num, self.max_utterance_len, self.embed_dim
        )
        local_mask = attention_mask.reshape(batch_size * utterance_num, self.max_utterance_len)

        # One pooled vector per utterance: [batch * utterance_num, hidden_size].
        representations, _ = self.local_self_attention(local_states, local_mask)
        per_utterance = representations.reshape(batch_size, utterance_num, self.embed_dim)
        alpha = F.softmax(self.fc_alpha(per_utterance), dim=-1)

        return alpha


class PriorContextSensitiveEncoder(nn.Module):
    """Fixed standard-normal prior for the dialogue-level latent variable.

    The module has no parameters: it only emits zero tensors sized after the
    input (zero mean, zero log-variance, i.e. unit variance).
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, attention_mask=None):
        """Return ``(prior_mean, prior_logvar)``, both all-zero.

        Each tensor has shape ``[hidden_states.size(0), hidden_states.size(-1)]``
        and shares the dtype and device of ``hidden_states``.
        ``attention_mask`` is accepted for interface symmetry and ignored.
        """
        latent_shape = (hidden_states.size(0), hidden_states.size(-1))
        # new_zeros inherits dtype and device from hidden_states.
        prior_mean = hidden_states.new_zeros(latent_shape)
        # log(1) == 0; a separate tensor so the two outputs never alias.
        prior_logvar = torch.zeros_like(prior_mean)

        return (prior_mean, prior_logvar)


class PriorContextIndependentEncoder(nn.Module):
    """Prior over the per-utterance latent variables, conditioned on categories.

    The first slot of the prior depends only on the first category embedding
    (``init_mean`` / ``init_logvar``). Every later slot combines a pooled
    utterance representation with the matching category embedding
    (``mean`` / ``logvar``).

    Args:
        max_utterance_num: Maximum number of utterances per dialogue (stored,
            not read by ``forward``).
        max_utterance_len: Number of token positions per utterance.
        hidden_size: Width of the hidden states and of the latent.
        num_category: Size of the category vectors fed to ``category_embed``.
    """

    def __init__(self, max_utterance_num: int, max_utterance_len: int, hidden_size: int, num_category: int):
        super().__init__()

        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len
        self.embed_dim = hidden_size

        self.local_self_attention = AverageSelfAttention(self.embed_dim)
        self.category_embed = nn.Linear(num_category, hidden_size)

        latent_width = hidden_size
        fused_width = 2 * hidden_size  # pooled utterance + category embedding
        self.mean = Conv1D(latent_width, fused_width)
        self.logvar = Conv1D(latent_width, fused_width)
        self.init_mean = Conv1D(latent_width, fused_width // 2)
        self.init_logvar = Conv1D(latent_width, fused_width // 2)

    def forward(self, hidden_states, attention_mask, categories):
        """Return ``(mean, logvar)``, each ``[batch, utterance_num + 1, hidden_size]``.

        Args:
            hidden_states: ``[batch, utterance_num * max_utterance_len, hidden_size]``.
            attention_mask: ``[batch, utterance_num * max_utterance_len]``.
            categories: ``[batch, utterance_num + 1, num_category]``; entry 0
                feeds the initial slot, the rest pair with the utterances.

        Raises:
            AssertionError: If the sequence length is not a multiple of
                ``max_utterance_len``.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        assert sequence_length % self.max_utterance_len == 0
        utterance_num = sequence_length // self.max_utterance_len
        flat_utterances = batch_size * utterance_num

        # Fold every utterance into the batch axis before pooling its tokens.
        local_states = hidden_states.view(
            batch_size, utterance_num, self.max_utterance_len, self.embed_dim
        ).reshape(flat_utterances, self.max_utterance_len, self.embed_dim)
        local_mask = attention_mask.view(
            batch_size, utterance_num * self.max_utterance_len
        ).reshape(flat_utterances, self.max_utterance_len)
        pooled, _ = self.local_self_attention(local_states, local_mask)  # [batch * utterance_num, hidden_size]

        category_states = self.category_embed(categories)
        first_category = category_states[:, :1].reshape(batch_size, self.embed_dim)  # [batch, hidden_size]
        later_categories = category_states[:, 1:].reshape(flat_utterances, self.embed_dim)  # [batch * utterance_num, hidden_size]

        # Slots 1..utterance_num: pooled utterance joined with its category.
        fused = torch.cat([pooled, later_categories], dim=-1)
        later_mean = self.mean(fused).view(batch_size, utterance_num, self.embed_dim)
        later_logvar = self.logvar(fused).view(batch_size, utterance_num, self.embed_dim)

        # Slot 0: derived from the first category alone.
        first_mean = self.init_mean(first_category).unsqueeze(1)
        first_logvar = self.init_logvar(first_category).unsqueeze(1)

        mean = torch.cat([first_mean, later_mean], dim=1)
        logvar = torch.cat([first_logvar, later_logvar], dim=1)

        return (mean, logvar)


class PriorCategoryEncoder(nn.Module):
    """Prior over the categorical latent, computed from the category sequence.

    A ``SharedEncoder`` is built from a shallow copy of ``config`` whose
    ``vocab_size`` is set to ``num_category``, so its token-embedding table has
    one row per category. That table embeds the (soft) input categories and is
    reused, transposed, as the output projection.
    """

    def __init__(self, config, num_category):
        super().__init__()
        self.num_category = num_category

        # Copy first so the caller's config keeps its own vocab_size.
        category_config = copy.copy(config)
        category_config.vocab_size = num_category
        self.encoder = SharedEncoder(category_config)

    def forward(self, categories, attention_mask):
        """Return prior category probabilities.

        Args:
            categories: Category distributions, ``[batch, n, num_category]``.
            attention_mask: Mask over the ``n`` category positions, ``[batch, n]``.

        Returns:
            Softmax-normalised tensor of shape ``[batch, n + 1, num_category]``;
            the extra leading position comes from the prepended uniform
            distribution.
        """
        embedding_table = self.encoder.wte.weight  # [num_category, hidden_size]

        # Prepend a uniform distribution as the start-of-sequence category.
        uniform_start = torch.ones_like(categories[:, :1]) / self.num_category
        categories = torch.cat([uniform_start, categories], dim=1)
        category_embeds = torch.einsum("bnc,cd->bnd", categories, embedding_table)

        # The prepended position is always attended to.
        start_mask = torch.ones_like(attention_mask[:, :1])
        category_mask = torch.cat([start_mask, attention_mask], dim=1)

        encoder_outputs = self.encoder(
            attention_mask=category_mask,
            inputs_embeds=category_embeds,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        category_hidden_states = encoder_outputs["last_hidden_state"]

        # Tied output projection: score every position against each category row.
        category_scores = torch.einsum(
            "bnd,dc->bnc", category_hidden_states, embedding_table.transpose(0, 1)
        )
        prior_category_alpha = F.softmax(category_scores, dim=-1)
        return prior_category_alpha


class GPT2CondAttention(GPT2Attention):
    """GPT-2 self-attention with two extra latent key/value slots per utterance.

    The two latent inputs (``context_sensitive``, rank 2, and
    ``context_independent``, rank 3) are each projected to a key and a value
    and spliced into the key/value sequence in front of every utterance's
    tokens. Queries come from the token hidden states only, so the key axis is
    longer than the query axis by ``2 * max_utterance_num``.

    The key/value layout depends on ``self.training``:

    * training: ``[cs_0, ci_0, utt_0 tokens, cs_1, ci_1, utt_1 tokens, ...]``
      with every utterance exactly ``max_utterance_len`` tokens long;
    * eval: the first ``max_utterance_num - 1`` utterances use the layout
      above, then the last latent pair, then all remaining token keys.
    """

    def __init__(self, config, max_utterance_num, max_utterance_len):
        super().__init__(config)

        # The attributes below re-create what the parent constructor sets up,
        # using this file's Conv1D for the projections.
        max_positions = config.max_position_embeddings
        # Lower-triangular causal mask, sliced per call in _attn.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        # Score written into causally masked positions before the softmax.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = False

        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

        # Latent conditioning: each projection yields a key and a value
        # (its output is split into two chunks of embed_dim in forward).
        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len
        self.context_sensitive_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
        self.context_independent_attn = Conv1D(2 * self.embed_dim, self.embed_dim)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention over keys that include the latent slots.

        Args:
            query: ``[batch, head, query_len, head_dim]``.
            key, value: ``[batch, head, key_len, head_dim]`` where ``key_len``
                already counts the ``2 * max_utterance_num`` latent slots.
            attention_mask: Additive mask over token positions only,
                ``[batch, 1, 1, key_len - 2 * max_utterance_num]``; it is
                widened here with zeros (never masked) at the latent slots.
            head_mask: Optional multiplicative mask on the attention weights.

        Returns:
            ``(attn_output, attn_weights)``.

        Raises:
            AssertionError: If ``attention_mask`` does not have the expected
                token-only length.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # if only "normal" attention layer implements causal mask
        # NOTE(review): key_length exceeds query_length by the number of latent
        # slots (plus any cached past), so this slice lets query i see key
        # indices up to i + (key_length - query_length). Because the latent
        # slots are interleaved rather than all placed first, tokens of early
        # utterances may be able to see a few later token keys -- verify that
        # this is intended.
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Widen the token mask to the key layout built in forward().
            assert attention_mask.size()[-1] == attn_weights.size()[-1] - 2 * self.max_utterance_num
            batch, seq_length = attention_mask.size()[0], attention_mask.size()[-1]
            # Two zero entries (i.e. "attend") per utterance for its latent slots.
            zeros = torch.zeros((batch, self.max_utterance_num, 2), device=attention_mask.device, dtype=attention_mask.dtype)
            if self.training:
                # Every utterance: [0, 0, token mask of max_utterance_len].
                attention_mask = torch.cat([
                    zeros.unsqueeze(1).unsqueeze(2),
                    attention_mask.reshape(batch, 1, 1, self.max_utterance_num, self.max_utterance_len)
                ], dim=-1).reshape(batch, 1, 1, self.max_utterance_num * (2 + self.max_utterance_len))
            else:
                # Context utterances are interleaved as in training; the last
                # latent pair is followed by the rest of the token mask.
                context_length = (self.max_utterance_num - 1) * self.max_utterance_len
                attention_mask_first = torch.cat([
                    zeros[:,:-1].unsqueeze(1).unsqueeze(2), # batch, 1, 1, max_num - 1, 2
                    attention_mask[:,:,:,:context_length].reshape(batch, 1, 1, self.max_utterance_num - 1, self.max_utterance_len)
                ], dim=-1).reshape(batch, 1, 1, (self.max_utterance_num - 1) * (self.max_utterance_len + 2))
                attention_mask = torch.cat([
                    attention_mask_first,
                    zeros[:,-1].unsqueeze(1).unsqueeze(2), # batch, 1, 1, 2
                    attention_mask[:, :, :, context_length:]
                ], dim=-1).reshape(batch, 1, 1, seq_length + 2 * self.max_utterance_num)

            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        """Self-attention over the tokens plus the injected latent slots.

        Args:
            hidden_states: Token states, ``[batch, seq, embed_dim]``.
            layer_past: Optional cached ``(key, value)`` of earlier tokens.
            attention_mask: Additive token mask forwarded to ``_attn``.
            head_mask: Optional per-head mask forwarded to ``_attn``.
            encoder_hidden_states, encoder_attention_mask: Accepted for
                signature compatibility; not used.
            use_cache: When exactly ``True``, return the token-only
                ``(key, value)`` as ``present``.
            output_attentions: Append the attention weights to the output.
            **kwargs: Must contain ``context_sensitive`` (rank 2) and
                ``context_independent`` (rank 3).

        Returns:
            ``(attn_output, present)`` plus ``attn_weights`` when
            ``output_attentions`` is set.

        Raises:
            AssertionError: If a latent input is missing or has the wrong
                rank, or (training) if the key length is not
                ``max_utterance_num * max_utterance_len``.
        """
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        # The cache is taken before the latent slots are spliced in, so it
        # holds token keys/values only.
        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Latent inputs arrive through kwargs.
        context_sensitive = kwargs.get("context_sensitive", None)
        context_independent = kwargs.get("context_independent", None)
        assert context_sensitive is not None and context_independent is not None
        assert len(context_sensitive.size()) == 2 and len(context_independent.size()) == 3

        # Repeat the dialogue-level latent once per slot of the per-utterance
        # latent, then project both to (key, value) and split into heads.
        context_sensitive = context_sensitive.unsqueeze(1).expand_as(context_independent)
        key_context_sensitive, value_context_sensitive = self.context_sensitive_attn(context_sensitive).split(self.split_size, dim=2)
        key_context_sensitive = self._split_heads(key_context_sensitive, self.num_heads, self.head_dim)
        value_context_sensitive = self._split_heads(value_context_sensitive, self.num_heads, self.head_dim)
        key_context_independent, value_context_independent = self.context_independent_attn(context_independent).split(self.split_size, dim=2)
        key_context_independent = self._split_heads(key_context_independent, self.num_heads, self.head_dim)
        value_context_independent = self._split_heads(value_context_independent, self.num_heads, self.head_dim)

        batch, head, seq_length, head_features = key.size()
        if self.training: # how to calculate ppl?
            # Full dialogue: put the two latent slots in front of every utterance.
            assert seq_length == self.max_utterance_num * self.max_utterance_len
            key = torch.cat([
                key_context_sensitive.unsqueeze(3), # batch, head, max_num, 1, head_features
                key_context_independent.unsqueeze(3),
                key.view(batch, head, self.max_utterance_num, self.max_utterance_len, head_features),
            ], dim=3).view(batch, head, self.max_utterance_num * (2 + self.max_utterance_len), head_features)
            value = torch.cat([
                value_context_sensitive.unsqueeze(3),
                value_context_independent.unsqueeze(3),
                value.view(batch, head, self.max_utterance_num, self.max_utterance_len, head_features),
            ], dim=3).view(batch, head, self.max_utterance_num * (2 + self.max_utterance_len), head_features)
        else:
            # Eval: the first max_num - 1 utterances are fixed-length context;
            # whatever follows them is preceded by the last latent pair.
            context_length = (self.max_utterance_num - 1) * self.max_utterance_len
            key_first = torch.cat([
                key_context_sensitive[:,:,:-1].unsqueeze(3), # batch, head, max_num - 1, 1, head_features
                key_context_independent[:,:,:-1].unsqueeze(3),
                key[:,:,:context_length].view(batch, head, self.max_utterance_num - 1, self.max_utterance_len, head_features),
            ], dim=3).view(batch, head, context_length + 2 * (self.max_utterance_num - 1), head_features)
            key = torch.cat([
                key_first,
                key_context_sensitive[:,:,-1].unsqueeze(2), # batch, head, 1, head_features
                key_context_independent[:,:,-1].unsqueeze(2),
                key[:,:,context_length:]
            ], dim=2).view(batch, head, seq_length + 2 * (self.max_utterance_num), head_features)
            value_first = torch.cat([
                value_context_sensitive[:, :, :-1].unsqueeze(3),  # batch, head, max_num - 1, 1, head_features
                value_context_independent[:, :, :-1].unsqueeze(3),
                value[:, :, :context_length].view(batch, head, self.max_utterance_num - 1, self.max_utterance_len, head_features),
            ], dim=3).view(batch, head, context_length + 2 * (self.max_utterance_num - 1), head_features)
            value = torch.cat([
                value_first,
                value_context_sensitive[:, :, -1].unsqueeze(2),  # batch, head, 1, head_features
                value_context_independent[:, :, -1].unsqueeze(2),
                value[:, :, context_length:]
            ], dim=2).view(batch, head, seq_length + 2 * (self.max_utterance_num), head_features)

        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


class GPT2CondBlock(GPT2Block):
    """GPT-2 transformer block whose attention layer is ``GPT2CondAttention``.

    It has the same pre-LayerNorm residual structure as a regular block, but
    the two latent inputs must be passed as keyword arguments and are
    forwarded to the attention layer.
    """

    def __init__(self, config, max_utterance_num, max_utterance_len):
        super().__init__(config)
        width = config.hidden_size
        if config.n_inner is not None:
            mlp_width = config.n_inner
        else:
            mlp_width = 4 * width
        norm_eps = config.layer_norm_epsilon

        # Replace the parent's sub-modules (same creation order as before).
        self.ln_1 = nn.LayerNorm(width, eps=norm_eps)
        self.attn = GPT2CondAttention(config, max_utterance_num, max_utterance_len)
        self.ln_2 = nn.LayerNorm(width, eps=norm_eps)
        self.mlp = GPT2MLP(mlp_width, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        **kwargs
    ):
        """Run conditioned attention and the MLP, each with a residual connection.

        ``kwargs`` must contain ``context_sensitive`` and
        ``context_independent``; an ``AssertionError`` is raised otherwise.
        ``encoder_hidden_states`` and ``encoder_attention_mask`` are accepted
        for signature compatibility and ignored.

        Returns:
            ``(hidden_states, present, attentions)`` when ``use_cache`` is
            truthy, else ``(hidden_states, attentions)``; ``attentions`` is
            present only if ``output_attentions`` is set.
        """
        context_sensitive = kwargs.get("context_sensitive")
        context_independent = kwargs.get("context_independent")
        assert context_sensitive is not None and context_independent is not None

        # Attention sub-layer: (attn_output, present[, attn_weights]).
        attn_result = self.attn(
            self.ln_1(hidden_states),
            context_sensitive=context_sensitive,
            context_independent=context_independent,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        after_attention = attn_result[0] + hidden_states

        # Feed-forward sub-layer.
        block_output = after_attention + self.mlp(self.ln_2(after_attention))

        # Keep `present` only when caching; the attention weights always follow.
        trailing = attn_result[1:] if use_cache else attn_result[2:]
        return (block_output,) + trailing  # hidden_states, present, (attentions, cross_attentions)


class Decoder(GPT2Model):
    """GPT-2 decoder that can be conditioned on two latent variables.

    Conditioning is optional and comes in two independent flavours:

    * ``add_input``: linearly project both latents and add them to the token
      embeddings, each utterance receiving its own latent.
    * ``add_attn``: build the stack from ``GPT2CondBlock`` and hand the
      projected latents to every block as extra attention keys/values. With
      ``attn_proj_vary`` each layer gets its own slice of one wide projection.

    The latents are read from ``kwargs`` in ``forward`` as
    ``context_sensitive`` and ``context_independent``.
    """

    def __init__(self, config, max_utterance_num, max_utterance_len, add_input=False, add_attn=False, attn_proj_vary=False):
        super().__init__(config)

        self.add_input = add_input
        self.add_attn = add_attn
        self.attn_proj_vary = attn_proj_vary

        # Replace the embeddings created by the parent constructor.
        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len

        if self.add_input:
            # Latent -> embedding-space projections, added to the token embeddings.
            nz = self.embed_dim
            nx = self.embed_dim
            self.input_proj_context_sensitive = nn.Linear(nz, nx, bias=False)
            self.input_proj_context_independent = nn.Linear(nz, nx, bias=False)

        if self.add_attn:
            nz = self.embed_dim
            nx = self.embed_dim
            n = config.num_hidden_layers

            if self.attn_proj_vary:
                # One wide projection, split into n per-layer chunks in forward.
                self.attn_proj_context_sensitive = nn.Linear(nz, nx * n, bias=False)
                self.attn_proj_context_independent = nn.Linear(nz, nx * n, bias=False)
            else:
                # A single projection shared by all layers.
                self.attn_proj_context_sensitive = nn.Linear(nz, nx, bias=False)
                self.attn_proj_context_independent = nn.Linear(nz, nx, bias=False)

            self.h = nn.ModuleList([GPT2CondBlock(config, max_utterance_num, max_utterance_len) for _ in range(config.num_hidden_layers)])
        else:
            self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Re-initialise because the modules above replaced the parent's.
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the decoder stack, optionally conditioned on the latents.

        Args:
            input_ids / inputs_embeds: Exactly one must be given.
            past_key_values: Per-layer cached ``(key, value)``; its key length
                offsets the default position ids.
            attention_mask: 1 for positions to attend, 0 for masked ones;
                converted here to an additive ``[batch, 1, 1, seq]`` mask.
            token_type_ids: Optional ids embedded with ``wte`` and added.
            position_ids: Optional explicit positions.
            head_mask: Optional head mask, expanded with ``get_head_mask``.
            encoder_hidden_states, encoder_attention_mask: Accepted for
                signature compatibility; not used.
            use_cache, output_attentions, output_hidden_states, return_dict:
                Fall back to the config values when ``None``.
            **kwargs: ``context_sensitive`` (rank 2 -- it is unsqueezed at
                dim 1 and expanded to the other latent's shape) and
                ``context_independent`` (rank 3). Required whenever
                ``add_input`` or ``add_attn`` is enabled.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` when ``return_dict``
            is truthy, otherwise a tuple of the non-``None`` entries of
            ``(hidden_states, presents, all_hidden_states, all_self_attentions)``.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
            AssertionError: If required latents are missing or the sequence
                length does not match the expected utterance layout.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        context_sensitive = kwargs.get("context_sensitive", None)
        context_independent = kwargs.get("context_independent", None)

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            # Cached key length = number of tokens already processed.
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            # Default positions continue after the cached tokens.
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # GPT2Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        # Input-level conditioning: add the projected latents to the embeddings.
        if self.add_input:
            assert context_sensitive is not None and context_independent is not None
            # The dialogue-level latent is repeated once per utterance slot.
            input_proj_context_sensitive = self.input_proj_context_sensitive(context_sensitive.unsqueeze(1).expand_as(context_independent))
            input_proj_context_independent = self.input_proj_context_independent(context_independent)
            batch, seq_length, embed_dim = hidden_states.size()
            if self.training:
                # Full dialogue: repeat each utterance's latent over its max_utterance_len tokens.
                assert seq_length == self.max_utterance_num * self.max_utterance_len
                input_proj_context_sensitive = input_proj_context_sensitive.unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, seq_length, embed_dim)
                input_proj_context_independent = input_proj_context_independent.unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, seq_length, embed_dim)
                hidden_states = hidden_states + input_proj_context_sensitive + input_proj_context_independent
            elif seq_length == 1:
                # Single-token step in eval: use the last utterance's latents.
                hidden_states = hidden_states + input_proj_context_sensitive[:,-1].unsqueeze(1) + input_proj_context_independent[:,-1].unsqueeze(1) # [batch, 1, embed_dim]
            else:
                # Eval with the whole context plus exactly one further token:
                # context tokens get their utterance's latents, the extra
                # token gets the last utterance's latents.
                context_length = (self.max_utterance_num - 1) * self.max_utterance_len
                assert seq_length == context_length + 1
                input_proj_context_sensitive_first = input_proj_context_sensitive[:,:-1].unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, context_length, embed_dim)
                input_proj_context_independent_first = input_proj_context_independent[:,:-1].unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, context_length, embed_dim)
                input_proj_first = input_proj_context_sensitive_first + input_proj_context_independent_first # [batch, context_length, embed_dim]
                input_proj_context_sensitive_last = input_proj_context_sensitive[:, -1].unsqueeze(1).expand(batch, seq_length - context_length, embed_dim)
                input_proj_context_independent_last = input_proj_context_independent[:, -1].unsqueeze(1).expand(batch, seq_length - context_length, embed_dim)
                input_proj_last = input_proj_context_sensitive_last + input_proj_context_independent_last
                input_proj = torch.cat([input_proj_first, input_proj_last], dim=1) # [batch, seq_length, embed_dim]
                hidden_states = hidden_states + input_proj

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        # Attention-level conditioning: project the latents once, up front.
        if self.add_attn:
            assert context_sensitive is not None and context_independent is not None
            attn_proj_context_sensitive = self.attn_proj_context_sensitive(context_sensitive) # [batch, embed_dim]
            attn_proj_context_independent = self.attn_proj_context_independent(context_independent) # [batch, max_num, embed_dim]
            if self.attn_proj_vary:
                # Split the wide projection into one hidden-size chunk per layer.
                attn_proj_context_sensitive = attn_proj_context_sensitive.split(hidden_states.size(-1), dim=-1)
                attn_proj_context_independent = attn_proj_context_independent.split(hidden_states.size(-1), dim=-1)
                assert len(attn_proj_context_sensitive) == len(self.h) and len(attn_proj_context_independent) == len(self.h)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.add_attn:
                # Pick this layer's latent projection (per-layer or shared).
                if self.attn_proj_vary:
                    latent_context_sensitive = attn_proj_context_sensitive[i]
                    latent_context_independent = attn_proj_context_independent[i]
                else:
                    latent_context_sensitive = attn_proj_context_sensitive
                    latent_context_independent = attn_proj_context_independent
                outputs = block(
                    hidden_states,
                    context_sensitive=latent_context_sensitive,
                    context_independent=latent_context_independent,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            # The attention weights sit after `present` only when caching.
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """Output container built by ``VAEModel.forward`` on its ``return_dict`` path.

    Every field defaults to ``None``. ``VAEModel.forward`` computes the loss
    fields only while the model is in training mode and passes ``None`` for
    all of them otherwise.
    """
    # VAE
    lm_loss: Optional[torch.FloatTensor] = None  # cross-entropy of the LM logits against the labels
    context_sensitive_kl_loss: Optional[torch.FloatTensor] = None  # Gaussian KL(posterior || prior), batch mean
    context_independent_kl_loss: Optional[torch.FloatTensor] = None  # Gaussian KL summed over unmasked utterances, batch mean
    category_kl_loss: Optional[torch.FloatTensor] = None  # categorical KL summed over unmasked utterances, batch mean

    # Disentangle
    SCC_loss: Optional[torch.FloatTensor] = None  # from the similarity of shuffled vs. original context-sensitive latents
    DFP_loss: Optional[torch.FloatTensor] = None  # cross-entropy of the bag-of-words head fed with context-independent latents
    MI_loss: Optional[torch.FloatTensor] = None  # H_1 + H_2 - H_3 averaged over utterance slots

    # Other
    context_sensitive: Optional[torch.FloatTensor] = None  # latent handed to the decoder (echoed back for generation)
    context_independent: Optional[torch.FloatTensor] = None  # latent handed to the decoder (echoed back for generation)
    logits: Optional[torch.FloatTensor] = None  # output of the LM head
    # The remaining fields are copied from the decoder's output object.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class VAEModel(GPT2LMHeadModel):
    def __init__(
        self,
        config,
        max_utterance_num,
        max_utterance_len,
        num_category,
        add_input=False,
        add_attn=False,
        add_softmax=False,
        attn_proj_vary=False,
        learn_prior=False,
    ):
        """Build the decoder, the shared encoder and the posterior/prior heads.

        Parameters
        ----------
        config :
            GPT-2 style configuration; ``n_embd``, ``hidden_size`` and
            ``vocab_size`` are read here.
        max_utterance_num, max_utterance_len : int
            Utterance slots per sample and tokens per slot.
        num_category : int
            Number of discrete categories.
        add_input, add_attn, attn_proj_vary : bool
            Forwarded unchanged to ``Decoder``.
        add_softmax : bool
            When true, create the two latent-to-hidden projections that are
            applied in front of the LM head.
        learn_prior : bool
            When true, create the three prior encoders.
        """
        super().__init__(config)

        # Plain configuration values.
        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len
        self.num_category = num_category
        self.add_input = add_input
        self.add_attn = add_attn
        self.add_softmax = add_softmax
        self.attn_proj_vary = attn_proj_vary
        self.learn_prior = learn_prior

        # NOTE: submodules are created in the same order as before so that
        # parameter registration order and init-time RNG use stay unchanged.
        # Bag-of-words head.
        self.bow_proj = nn.Linear(config.n_embd, 64, bias=True)
        self.bow_head = nn.Linear(64, config.vocab_size, bias=False)

        # Replace the decoder and LM head created by the parent constructor.
        self.transformer = Decoder(
            config,
            max_utterance_num=max_utterance_num,
            max_utterance_len=max_utterance_len,
            add_input=add_input,
            add_attn=add_attn,
            attn_proj_vary=attn_proj_vary,
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Keyword arguments shared by the utterance-level encoders.
        utterance_kwargs = dict(
            max_utterance_num=max_utterance_num,
            max_utterance_len=max_utterance_len,
            hidden_size=config.hidden_size,
        )

        self.shared_encoder = SharedEncoder(config)
        self.posterior_context_sensitive_encoder = PosteriorContextSensitiveEncoder(**utterance_kwargs)
        self.posterior_context_independent_encoder = PosteriorContextIndependentEncoder(**utterance_kwargs)
        self.posterior_category_encoder = PosteriorCategoryEncoder(num_category=num_category, **utterance_kwargs)

        if learn_prior:
            self.prior_context_sensitive_encoder = PriorContextSensitiveEncoder()
            self.prior_context_independent_encoder = PriorContextIndependentEncoder(num_category=num_category, **utterance_kwargs)
            self.prior_category_encoder = PriorCategoryEncoder(config, num_category=num_category)

        if add_softmax:
            self.head_proj_context_sensitive = Conv1D(config.n_embd, config.n_embd)
            self.head_proj_context_independent = Conv1D(config.n_embd, config.n_embd)

        self.init_weights()

    def sample_gumbel_softmax(self, alpha):
        """Sample from a categorical distribution over the last axis of ``alpha``.

        In training mode a relaxed Gumbel-softmax sample (fixed temperature
        0.67) is returned; otherwise the hard one-hot vector of the most
        likely category is returned.

        The categories are always taken to lie on the LAST axis. For 2-D
        input this is the same as the previous ``dim=1`` behaviour; for input
        with extra leading axes (``forward`` passes the category posterior
        without flattening it) the softmax / arg-max no longer runs over the
        wrong axis.

        Parameters
        ----------
        alpha : torch.Tensor
            Category probabilities, shape ``(..., num_category)``.

        Returns
        -------
        torch.Tensor
            Tensor with the shape, dtype and device of ``alpha``.
        """
        if self.training:
            # Gumbel(0, 1) noise through the inverse CDF; EPS guards log(0).
            unif = torch.rand(alpha.size()).type_as(alpha)
            gumbel = -torch.log(-torch.log(unif + EPS) + EPS)
            # Reparameterize to create the Gumbel-softmax sample.
            log_alpha = torch.log(alpha + EPS)
            logit = (log_alpha + gumbel) / .67
            return F.softmax(logit, dim=-1)

        # Reconstruction mode: one-hot of the most likely category, built
        # directly on alpha's device (no CPU round trip, any leading axes).
        max_alpha = alpha.argmax(dim=-1, keepdim=True)
        one_hot_samples = torch.zeros_like(alpha)
        one_hot_samples.scatter_(-1, max_alpha, 1)
        return one_hot_samples

    def sample_normal(self, mean, logvar):
        """
        Draw from a diagonal Gaussian with the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution.
        logvar : torch.Tensor
            Diagonal log variance, same shape as ``mean``.

        Returns
        -------
        torch.Tensor
            ``mean + exp(0.5 * logvar) * noise`` in training mode, where
            ``noise`` is standard normal; ``mean`` itself otherwise.
        """
        if not self.training:
            # Reconstruction mode is deterministic.
            return mean

        scale = (0.5 * logvar).exp()
        # Noise is drawn as a default-dtype CPU tensor, then cast like `mean`.
        noise = torch.zeros(scale.size()).normal_().type_as(mean)
        return mean + scale * noise

    # def reparameterize(self, mean, logvar, z=None):
    #     std = logvar.mul(0.5).exp()
    #     if z is None:
    #         z = torch.randn(std.size(), device=mean.device, dtype=mean.dtype)
    #     return z.mul(std) + mean

    def kl_loss(self, mean1, logvar1, mean2, logvar2):
        """KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

        Both Gaussians are diagonal. The per-element terms are summed over
        every axis except the first, so the result has one value per row of
        the leading axis.
        """
        logvar_gap = logvar1 - logvar2
        scaled_sq_gap = torch.pow(mean1 - mean2, 2) / logvar2.exp()
        per_element = logvar_gap - scaled_sq_gap - torch.exp(logvar_gap) + 1
        trailing_axes = tuple(range(1, per_element.dim()))
        return -0.5 * torch.sum(per_element, trailing_axes)  # [batch]

    def kl_discrete_loss(self, alpha1, alpha2):
        """Categorical KL(alpha1 || alpha2), summed over axis 1.

        ``EPS`` is added inside both logarithms so zero probabilities do not
        produce ``log(0)``.
        """
        log_ratio = torch.log(alpha1 + EPS) - torch.log(alpha2 + EPS)
        weighted = alpha1 * log_ratio
        return weighted.sum(dim=1)  # [batch]

    def get_utterance_mask(self, attention_mask):
        """Collapse a token-level mask into an utterance-level mask.

        The second axis of ``attention_mask`` is cut into consecutive chunks
        of ``self.max_utterance_len`` tokens. A chunk maps to 1 when the sum
        of its mask values is non-zero and to 0 otherwise; the result keeps
        the dtype of ``attention_mask``.
        """
        total_tokens = attention_mask.size()[1]
        num_chunks = total_tokens // self.max_utterance_len
        per_chunk = attention_mask.reshape(-1, num_chunks, self.max_utterance_len)
        chunk_is_used = per_chunk.sum(dim=-1).to(torch.bool)
        return chunk_is_used.to(attention_mask.dtype)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        from_mean=False,
        **kwargs,
    ):
        """Encode the latents (training) or reuse given ones, then decode.

        Training mode (``self.training``):
            The shared encoder and the posterior encoders produce the
            context-sensitive, context-independent and category posteriors.
            The three KL terms and the SCC / DFP / MI losses are computed,
            the decoder is run with the sampled latents and the LM loss is
            computed against ``labels``. ``kwargs`` must contain
            ``shuffled_input_ids`` and ``shuffled_attention_mask``.

        Inference mode:
            No encoder runs. ``kwargs`` must contain ``context_sensitive``
            and ``context_independent``; every loss is ``None``.

        Parameters
        ----------
        input_ids, past_key_values, attention_mask, token_type_ids,
        position_ids, head_mask, inputs_embeds, use_cache, output_attentions,
        output_hidden_states, return_dict :
            Forwarded to the decoder (``self.transformer``).
        encoder_hidden_states, encoder_attention_mask :
            Accepted but not used by this method.
        labels : torch.Tensor
            Targets for the LM and bag-of-words losses; ``-1`` entries are
            ignored. Read only in training mode.
        from_mean : bool
            Training only: use the posterior means instead of sampling.

        Returns
        -------
        tuple or CausalLMOutputWithCrossAttentions
            When ``return_dict`` is falsy (including ``None``) a tuple:
            ``(lm_loss, context_sensitive_kl_loss,
            context_independent_kl_loss, category_kl_loss, lm_logits,
            *decoder_outputs[1:])`` in training, or ``(lm_logits,
            *decoder_outputs[1:])`` otherwise. The SCC / DFP / MI losses are
            only available through the ``return_dict`` path.

        Raises
        ------
        AssertionError
            If the shuffled inputs are missing in training mode, a latent
            contains NaN, or the sequence length does not match the expected
            utterance layout.
        ValueError
            If the latents are missing from ``kwargs`` in inference mode.
        """
        if self.training:
            # ---- Posterior encoders -------------------------------------
            utterance_mask = self.get_utterance_mask(attention_mask)
            shared_hidden_states = self.shared_encoder(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )["last_hidden_state"]
            posterior_context_sensitive_mean, posterior_context_sensitive_logvar = self.posterior_context_sensitive_encoder(shared_hidden_states, attention_mask)
            posterior_context_independent_mean, posterior_context_independent_logvar = self.posterior_context_independent_encoder(shared_hidden_states, attention_mask)
            posterior_category_alpha = self.posterior_category_encoder(shared_hidden_states, attention_mask)
            category = self.sample_gumbel_softmax(posterior_category_alpha)

            # ---- Priors: learned, or standard normal / uniform ----------
            if self.learn_prior:
                prior_context_sensitive_mean, prior_context_sensitive_logvar = self.prior_context_sensitive_encoder(shared_hidden_states, attention_mask)
                # The learned priors only see everything before the last utterance slot.
                prior_context_independent_mean, prior_context_independent_logvar = self.prior_context_independent_encoder(shared_hidden_states[:, :-self.max_utterance_len], attention_mask[:, :-self.max_utterance_len], category)
                prior_category_alpha = self.prior_category_encoder(categories=category[:,:-1], attention_mask=utterance_mask[:,:-1])
            else:
                # Zero mean and zero log-variance, i.e. N(0, I).
                prior_context_sensitive_mean = prior_context_sensitive_logvar = torch.zeros([input_ids.size(0), self.config.hidden_size], device=input_ids.device)
                prior_context_independent_mean = prior_context_independent_logvar = torch.zeros([input_ids.size(0), self.max_utterance_num, self.config.hidden_size], device=input_ids.device)
                prior_category_alpha = torch.ones([input_ids.size(0), self.max_utterance_num, self.num_category], device=input_ids.device) / self.num_category

            # The decoder is always conditioned on the posterior during training.
            latent_context_sensitive_mean, latent_context_sensitive_logvar = posterior_context_sensitive_mean, posterior_context_sensitive_logvar
            latent_context_independent_mean, latent_context_independent_logvar = posterior_context_independent_mean, posterior_context_independent_logvar

            if from_mean:
                context_sensitive = latent_context_sensitive_mean
                context_independent = latent_context_independent_mean
            else:
                context_sensitive = self.sample_normal(latent_context_sensitive_mean, latent_context_sensitive_logvar)
                context_independent = self.sample_normal(latent_context_independent_mean, latent_context_independent_logvar)
            assert not torch.isnan(context_sensitive).any(), "training get nan context sensitive variable"
            assert not torch.isnan(context_independent).any(), "training get nan context independent variable"

            # ---- KL terms -----------------------------------------------
            context_sensitive_kl_loss = self.kl_loss(posterior_context_sensitive_mean, posterior_context_sensitive_logvar, prior_context_sensitive_mean, prior_context_sensitive_logvar).mean()
            context_independent_kl_loss = self.kl_loss(
                posterior_context_independent_mean.reshape(-1, self.config.hidden_size),
                posterior_context_independent_logvar.reshape(-1, self.config.hidden_size),
                prior_context_independent_mean.reshape(-1, self.config.hidden_size),
                prior_context_independent_logvar.reshape(-1, self.config.hidden_size)
            )
            # Zero out padded utterance slots, sum per sample, average over the batch.
            context_independent_kl_loss = context_independent_kl_loss.reshape(-1, self.max_utterance_num).mul(utterance_mask)
            context_independent_kl_loss = context_independent_kl_loss.sum(dim=-1).mean()

            category_kl_loss = self.kl_discrete_loss(posterior_category_alpha.reshape(-1, self.num_category), prior_category_alpha.reshape(-1, self.num_category))
            category_kl_loss = category_kl_loss.reshape(-1, self.max_utterance_num).mul(utterance_mask)
            category_kl_loss = category_kl_loss.sum(dim=-1).mean()

            # ---- SCC loss -----------------------------------------------
            shuffled_input_ids = kwargs.get("shuffled_input_ids", None)
            shuffled_attention_mask = kwargs.get("shuffled_attention_mask", None)
            assert shuffled_input_ids is not None and shuffled_attention_mask is not None
            shuffled_hidden_states = self.shared_encoder(
                shuffled_input_ids,
                attention_mask=shuffled_attention_mask,
                position_ids=position_ids,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True
            )["last_hidden_state"]
            shuffled_context_sensitive_mean, shuffled_context_sensitive_logvar = self.posterior_context_sensitive_encoder(shuffled_hidden_states, shuffled_attention_mask)
            shuffled_context_sensitive = self.sample_normal(shuffled_context_sensitive_mean, shuffled_context_sensitive_logvar)
            similarity_mat = torch.matmul(shuffled_context_sensitive, context_sensitive.transpose(-1, -2)) / (self.config.n_embd ** 0.5) # [batch, batch]
            # log_softmax equals log(softmax(.)) but cannot underflow to log(0) = -inf.
            SCC_loss = -torch.diagonal(F.log_softmax(similarity_mat, dim=-1)).mean()

            # ---- DFP loss -----------------------------------------------
            # Repeat each utterance latent for every token position of its slot.
            bow_hidden_states = context_independent.unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(-1, self.max_utterance_num * self.max_utterance_len, self.config.n_embd)
            bow_logits = self.bow_head(ACT2FN["gelu"](self.bow_proj(bow_hidden_states)))
            DFP_loss = nn.CrossEntropyLoss(ignore_index=-1)(bow_logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))

            # ---- MI loss ------------------------------------------------
            MI_loss = 0
            for t in range(self.max_utterance_num):
                H_1 = entropy(x=context_sensitive, mu=latent_context_sensitive_mean, logvar=latent_context_sensitive_logvar, batch_mask=utterance_mask[:,t])
                H_2 = entropy(x=context_independent[:,t], mu=latent_context_independent_mean[:,t], logvar=latent_context_independent_logvar[:,t], batch_mask=utterance_mask[:,t])
                H_3 = joint_entropy(x_1=context_sensitive, mu_1=latent_context_sensitive_mean, logvar_1=latent_context_sensitive_logvar, x_2=context_independent[:,t], mu_2=latent_context_independent_mean[:,t], logvar_2=latent_context_independent_logvar[:,t], batch_mask=utterance_mask[:,t])
                MI_loss += (H_1 + H_2 - H_3) / self.max_utterance_num

        else:
            context_sensitive_kl_loss = None
            context_independent_kl_loss = None
            category_kl_loss = None

            SCC_loss = None
            DFP_loss = None
            MI_loss = None

            # At inference the latents are supplied by the caller
            # (see prepare_inputs_for_generation).
            context_sensitive = kwargs.get("context_sensitive", None)
            context_independent = kwargs.get("context_independent", None)
            if context_sensitive is None or context_independent is None:
                # Previously this surfaced as an opaque TypeError from torch.isnan(None).
                raise ValueError(
                    "context_sensitive and context_independent must be passed as keyword arguments "
                    "when the model is not in training mode"
                )
            assert not torch.isnan(context_sensitive).any(), "training get nan context sensitive variable"
            assert not torch.isnan(context_independent).any(), "training get nan context independent variable"

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            context_sensitive=context_sensitive,
            context_independent=context_independent,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if self.add_softmax:
            # Project both latents and add them to the decoder states in
            # front of the LM head. There is one latent row per utterance
            # slot, so it is repeated for every token of that slot.
            context_sensitive_hidden_states = self.head_proj_context_sensitive(context_sensitive.unsqueeze(1).expand_as(context_independent))
            context_independent_hidden_states = self.head_proj_context_independent(context_independent)
            other_hidden_states = context_sensitive_hidden_states + context_independent_hidden_states
            batch, seq_length, hidden_size = hidden_states.size()
            if self.training:
                assert seq_length == self.max_utterance_num * self.max_utterance_len
                other_hidden_states = other_hidden_states.unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, seq_length, hidden_size)
                lm_logits = self.lm_head(hidden_states + other_hidden_states)
            elif seq_length == 1:
                # Incremental decoding step: only the last slot's latent applies.
                lm_logits = self.lm_head(hidden_states + other_hidden_states[:,-1:])
            else:
                # First generation step: full context slots plus one token of the last slot.
                context_length = (self.max_utterance_num - 1) * self.max_utterance_len
                assert context_length + 1 == seq_length
                other_hidden_states_first = other_hidden_states[:,:-1].unsqueeze(2).expand(-1, -1, self.max_utterance_len, -1).reshape(batch, context_length, hidden_size)
                other_hidden_states_last = other_hidden_states[:,-1:].expand(batch, seq_length - context_length, hidden_size)
                lm_logits = self.lm_head(hidden_states + torch.cat([other_hidden_states_first, other_hidden_states_last], dim=1))
        else:
            # Without the latent-to-softmax path the logits are the plain LM
            # head projection. This assignment used to be missing, leaving
            # `lm_logits` unbound below.
            lm_logits = self.lm_head(hidden_states)

        if self.training:
            lm_loss = CrossEntropyLoss(ignore_index=-1)(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
        else:
            lm_loss = None

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((lm_loss, context_sensitive_kl_loss, context_independent_kl_loss, category_kl_loss) + output) if self.training else output
        else:
            return CausalLMOutputWithCrossAttentions(
                lm_loss=lm_loss,
                context_sensitive_kl_loss=context_sensitive_kl_loss,
                context_independent_kl_loss=context_independent_kl_loss,
                category_kl_loss=category_kl_loss,
                SCC_loss=SCC_loss,
                DFP_loss=DFP_loss,
                MI_loss=MI_loss,
                context_sensitive=context_sensitive,
                context_independent=context_independent,
                logits=lm_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
                cross_attentions=transformer_outputs.cross_attentions,
            )

    def get_sequential_states(self, input_ids, attention_mask, position_ids):
        """Return discrete category states for the context and the next utterance.

        All three inputs are truncated by their last position before being
        encoded.

        Returns
        -------
        context_state : torch.Tensor
            Arg-max category of every context utterance, [batch, max_utterance_num - 1].
        context_mask : torch.Tensor
            Utterance-level mask of the truncated attention mask.
        predict_state : torch.Tensor
            Arg-max category of the last slot of the prior, [batch].
        """
        context_attention_mask = attention_mask[:, :-1]
        encoder_outputs = self.shared_encoder(
            input_ids[:, :-1],
            attention_mask=context_attention_mask,
            position_ids=position_ids[:, :-1],
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )
        shared_hidden_states = encoder_outputs["last_hidden_state"]

        # Posterior categories of the context utterances.
        posterior_alpha = self.posterior_category_encoder(shared_hidden_states, context_attention_mask)
        posterior_category = self.sample_gumbel_softmax(posterior_alpha.reshape(-1, self.num_category))
        posterior_category = posterior_category.reshape(-1, self.max_utterance_num - 1, self.num_category)

        # Prior over categories, conditioned on the context categories.
        context_mask = self.get_utterance_mask(context_attention_mask)
        prior_alpha = self.prior_category_encoder(categories=posterior_category, attention_mask=context_mask)
        prior_category = self.sample_gumbel_softmax(prior_alpha.reshape(-1, self.num_category))
        prior_category = prior_category.reshape(-1, self.max_utterance_num, self.num_category)

        context_state = torch.max(posterior_category, dim=-1)[1]  # [batch, max_utterance_num - 1]
        predict_state = torch.max(prior_category[:, -1], dim=-1)[1]  # [batch]
        return context_state, context_mask, predict_state


    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build the keyword arguments for one decoding step of ``generate``.

        First step (``past is None``): everything except the last position of
        ``input_ids`` is encoded. The context-sensitive latent and the
        context-independent latents of the context utterances come from the
        posterior encoders; the latent of the utterance being generated comes
        from the prior (learned when ``self.learn_prior``, otherwise zero mean
        and zero log-variance).

        Later steps: only the last token is fed and the latents are taken
        from ``kwargs``.

        Parameters
        ----------
        input_ids : torch.Tensor
            Token ids generated so far.
        past :
            Cached key/values from the previous step, or ``None``.
        **kwargs :
            ``attention_mask`` and ``position_ids`` (both are sliced on the
            first step, so they must be tensors there); on later steps also
            ``context_sensitive`` and ``context_independent``.

        Returns
        -------
        dict
            Keyword arguments for ``forward``.
        """
        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if past is not None:
            # Incremental step: reuse the latents, feed only the newest token.
            context_sensitive = kwargs.get("context_sensitive", None)
            context_independent = kwargs.get("context_independent", None)
            assert context_sensitive is not None and context_independent is not None
            input_ids = input_ids[:, -1].unsqueeze(-1)
            return {
                "input_ids": input_ids,
                "context_sensitive": context_sensitive,
                "context_independent": context_independent,
                "past_key_values": past,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "use_cache": True,
            }

        context_attention_mask = attention_mask[:, :-1]
        shared_hidden_states = self.shared_encoder(
            input_ids[:, :-1],
            attention_mask=context_attention_mask,
            position_ids=position_ids[:, :-1],
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True
        )["last_hidden_state"]

        # The category sequence is consumed only by the learned prior, and
        # `prior_category_encoder` exists only when `learn_prior` is set.
        # Computing it unconditionally made generation fail with an
        # AttributeError when learn_prior is False.
        category = None
        if self.learn_prior:
            posterior_category = self.posterior_category_encoder(shared_hidden_states, context_attention_mask)
            posterior_category = self.sample_gumbel_softmax(posterior_category.reshape(-1, self.num_category)).reshape(-1, self.max_utterance_num - 1, self.num_category)
            utterance_mask = self.get_utterance_mask(context_attention_mask)
            prior_category = self.prior_category_encoder(categories=posterior_category, attention_mask=utterance_mask)
            prior_category = self.sample_gumbel_softmax(prior_category.reshape(-1, self.num_category)).reshape(-1, self.max_utterance_num, self.num_category)
            category = torch.cat([posterior_category, prior_category[:, -1:]], dim=1)  # [batch, max_num, c]

        context_sensitive_mean, context_sensitive_logvar = self.posterior_context_sensitive_encoder(shared_hidden_states, context_attention_mask)
        posterior_context_independent_mean, posterior_context_independent_logvar = self.posterior_context_independent_encoder(shared_hidden_states, context_attention_mask)

        # Only the last utterance slot of the prior is used.
        if self.learn_prior:
            prior_context_independent_mean, prior_context_independent_logvar = self.prior_context_independent_encoder(shared_hidden_states, context_attention_mask, category)
            last_prior_mean = prior_context_independent_mean[:, -1:]
            last_prior_logvar = prior_context_independent_logvar[:, -1:]
        else:
            # Zero mean / zero log-variance for the last slot, created with
            # the posterior's dtype and device so the concatenation below
            # needs no conversion.
            last_prior_mean = last_prior_logvar = posterior_context_independent_mean.new_zeros(
                (input_ids.size()[0], 1, self.config.hidden_size)
            )

        context_independent_mean = torch.cat([posterior_context_independent_mean, last_prior_mean], dim=1)
        context_independent_logvar = torch.cat([posterior_context_independent_logvar, last_prior_logvar], dim=1)
        context_sensitive = self.sample_normal(context_sensitive_mean, context_sensitive_logvar)
        context_independent = self.sample_normal(context_independent_mean, context_independent_logvar)
        return {
            "input_ids": input_ids,
            "context_sensitive": context_sensitive,
            "context_independent": context_independent,
            "past_key_values": None,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": True,
        }

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the :obj:`past_key_values` cache if
        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        # update past
        if "past_key_values" in outputs:
            model_kwargs["past"] = outputs.past_key_values

        if "context_sensitive" in outputs:
            model_kwargs["context_sensitive"] = outputs.context_sensitive

        if "context_independent" in outputs:
            model_kwargs["context_independent"] = outputs.context_independent

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # update position_ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            model_kwargs["position_ids"] = torch.add(position_ids[:, -1:], position_ids.new_ones((position_ids.shape[0], 1)))
            # print("update position_ids")

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids[:, -1].unsqueeze(-1)

        return model_kwargs

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs=None,
        **model_kwargs,
    ):
        expanded_return_idx = (
            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            model_kwargs["position_ids"] = position_ids.index_select(0, expanded_return_idx)

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
            )
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs

    def fix_pretrained_parameters(self):
        """Freeze every parameter that does not belong to a newly added module.

        A parameter is left untouched when its name contains one of the
        substrings below; for every other parameter ``requires_grad`` is set
        to ``False``.
        """
        new_module_markers = (
            # add_attn
            "context_sensitive_attn", "context_independent_attn",
            "attn_proj_context_sensitive", "attn_proj_context_independent",
            # add_input
            "input_proj_context_sensitive", "input_proj_context_independent",
            # add_softmax
            "head_proj_context_sensitive", "head_proj_context_independent",
            # posterior / prior encoders
            "posterior_context_sensitive_encoder", "posterior_context_independent_encoder", "posterior_category_encoder",
            "prior_context_sensitive_encoder", "prior_context_independent_encoder", "prior_category_encoder",
            # bag-of-words head
            "bow_proj", "bow_head",
        )
        for param_name, parameter in self.named_parameters():
            if any(marker in param_name for marker in new_module_markers):
                continue
            parameter.requires_grad = False

    def tune_all_parameters(self):
        """Mark every parameter of the model as trainable."""
        for parameter in self.parameters():
            parameter.requires_grad = True
