import torch
import torch.nn as nn
import torch.nn.functional as f
# from layers import *
# from multiHeadAttention import *
import random
import numpy as np
from modules.transformer import *
import time

# seed = 0
# random.seed(seed)
# np.random.seed(seed)
# torch.manual_seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(seed)

# Module-wide default device: CUDA when torch reports it as available,
# otherwise CPU. The model classes below move their sub-modules onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def fill_with_neg_inf(t):
    """Return a tensor shaped and typed like ``t`` with every entry ``-inf``.

    The fill is performed on a float32 version of ``t`` and the result is
    cast back to ``t``'s type. When ``t`` is already float32 the conversion
    does not copy, so ``t`` itself is overwritten in place.
    """
    scratch = t.float()
    scratch.fill_(float('-inf'))
    return scratch.type_as(t)


def buffered_future_mask(tensor, tensor2=None):
    """Build an additive "future" mask for attention scores.

    The mask has one row per step of ``tensor`` (its dim 0) and one column
    per step of ``tensor2`` (its dim 0), or is square when ``tensor2`` is
    omitted. Entries on or above diagonal ``1 + abs(dim2 - dim1)`` are
    ``-inf``; all other entries are ``0``.

    Args:
        tensor: tensor whose first dimension gives the number of rows and
            whose device the mask is created on.
        tensor2: optional tensor whose first dimension gives the number of
            columns.

    Returns:
        A ``(dim1, dim2)`` tensor of the default floating dtype located on
        ``tensor.device``.
    """
    dim1 = tensor.size(0)
    dim2 = dim1 if tensor2 is None else tensor2.size(0)
    # Allocate directly on the input's device. A bare ``.cuda()`` would use
    # the *current* CUDA device, which may differ from ``tensor.device``
    # when more than one GPU is in use.
    mask = torch.full((dim1, dim2), float('-inf'), device=tensor.device)
    return torch.triu(mask, 1 + abs(dim2 - dim1))


class Transformer(nn.Module):
    """A stack of ``num_blocks`` ``Encoder`` layers applied in sequence.

    Args:
        d_model, hidden_size, num_heads: forwarded to every ``Encoder``.
        num_blocks: number of ``Encoder`` layers in the stack.
        activation: activation module forwarded to every ``Encoder`` (the
            same instance is shared by all layers).
        dropout: dropout value forwarded to every ``Encoder``.
        mask: mask argument forwarded to every ``Encoder``.
    """

    def __init__(self, d_model, hidden_size, num_heads, num_blocks, activation=nn.ReLU(), dropout=0.1, mask=None):
        super().__init__()
        self.num_blocks = num_blocks
        # Only the registered nn.Sequential stack is built. A second copy
        # kept in a plain Python list would never be registered, moved to a
        # device, trained or used by forward().
        self.transformer = nn.Sequential(
            *(Encoder(d_model, hidden_size, num_heads, mask, dropout, activation)
              for _ in range(num_blocks))
        )

    def forward(self, x):
        """Run ``x`` through every encoder layer in order and return the result."""
        return self.transformer(x)


class model1(nn.Module):
    """Multimodal (text / audio / video) sequence model with several fusion modes.

    Every modality is projected to ``model_size`` channels with a 1x1
    ``Conv1d`` and then fused according to ``hparams['model_type']``:

    * ``'hybrid'``: concatenates the projected streams and the
      self-encoded streams, runs a GRU over the result and classifies its
      final hidden state.
    * ``'late_fusion'``: self-encodes each stream, keeps the last time
      step of every enabled modality and classifies their concatenation.
    * ``'round_robin'`` / ``'round_robin2'`` / ``'round_robin3'``: call the
      encoders with a (query, key, value) triple taken from different
      modalities, following the routing table ``_CROSS_ROUTES``.

    ``hparams`` is read with ``.get`` for the keys ``text_shape``,
    ``audio_shape``, ``video_shape`` (no defaults), ``dataset``
    (``'iemocap'``), ``batch_size`` (32), ``model_size`` (80), ``num_heads``
    (10), ``num_blocks`` (8), ``T`` / ``A`` / ``V`` (``True``; enable the
    text / audio / video modality) and ``model_type`` (``'hybrid'``).
    """
    # TODO: change the model to use the pretrained weights for embeddings.

    # model_type -> (text, audio, video) flags -> ordered calls, each given
    # as (encoder attribute, query stream, key/value stream). The streams
    # are named 'T', 'A' and 'V'. A stream may be routed even when its
    # modality flag is off; that mirrors the branch-by-branch code this
    # table replaced.
    _CROSS_ROUTES = {
        'round_robin': {
            (True, True, True): (('trans_l_self', 'T', 'A'),
                                 ('trans_a_self', 'A', 'V'),
                                 ('trans_v_self', 'V', 'T')),
            (True, True, False): (('trans_l_self', 'T', 'A'),
                                  ('trans_a_self', 'A', 'V')),
            (True, False, True): (('trans_l_self', 'T', 'A'),
                                  ('trans_a_self', 'V', 'T')),
            (False, True, True): (('trans_l_self', 'A', 'V'),
                                  ('trans_a_self', 'V', 'T')),
            (True, False, False): (('trans_l_self', 'T', 'A'),),
            (False, True, False): (('trans_l_self', 'A', 'V'),),
            (False, False, True): (('trans_l_self', 'V', 'T'),),
        },
        'round_robin2': {
            (True, True, True): (('trans_l_self', 'T', 'V'),
                                 ('trans_a_self', 'A', 'T'),
                                 ('trans_v_self', 'V', 'A')),
            (True, True, False): (('trans_l_self', 'T', 'V'),
                                  ('trans_a_self', 'A', 'T')),
            (True, False, True): (('trans_l_self', 'T', 'V'),
                                  ('trans_a_self', 'V', 'A')),
            (False, True, True): (('trans_l_self', 'A', 'T'),
                                  ('trans_a_self', 'V', 'A')),
            (True, False, False): (('trans_l_self', 'T', 'V'),),
            (False, True, False): (('trans_l_self', 'A', 'T'),),
            (False, False, True): (('trans_l_self', 'V', 'A'),),
        },
        'round_robin3': {
            (True, True, True): (('trans_l_self', 'T', 'A'),
                                 ('trans_a_self', 'A', 'V'),
                                 ('trans_v_self', 'V', 'T')),
            (True, True, False): (('trans_l_self', 'T', 'A'),
                                  ('trans_a_self', 'A', 'T')),
            (True, False, True): (('trans_l_self', 'T', 'V'),
                                  ('trans_a_self', 'V', 'T')),
            (False, True, True): (('trans_l_self', 'A', 'V'),
                                  ('trans_a_self', 'V', 'A')),
            (True, False, False): (('trans_l_self', 'T', 'V'),),
            (False, True, False): (('trans_l_self', 'A', 'T'),),
            (False, False, True): (('trans_l_self', 'V', 'A'),),
        },
    }

    def __init__(self, hparams):
        """Build every sub-module from ``hparams`` (see the class docstring).

        Raises:
            Exception: if none of the ``T`` / ``A`` / ``V`` flags is set.
        """
        super().__init__()
        self.hparams = hparams
        self.text_shape = hparams.get('text_shape')
        self.audio_shape = self.hparams.get('audio_shape')
        self.video_shape = self.hparams.get('video_shape')
        self.dataset = self.hparams.get('dataset', 'iemocap')
        self.batch_size = self.hparams.get('batch_size', 32)
        self.model_size = self.hparams.get('model_size', 80)
        self.num_heads = self.hparams.get('num_heads', 10)
        self.num_blocks = self.hparams.get('num_blocks', 8)
        self.device = device
        self.T = self.hparams.get('T', True)
        self.A = self.hparams.get('A', True)
        self.V = self.hparams.get('V', True)
        self.model_type = self.hparams.get('model_type', 'hybrid')
        # Must test the instance attribute: there is no module-level
        # ``dataset`` name, so the bare name raised NameError.
        self.output_dim = 8 if self.dataset == 'iemocap' else 1

        layers = self.num_blocks
        size = self.model_size
        self.dim = 3 * size + self.text_shape

        def make_encoder(embed_dim):
            # All encoders share the same hyper-parameters except the width.
            return TransformerEncoder(embed_dim=embed_dim,
                                      num_heads=self.num_heads,
                                      layers=layers,
                                      attn_dropout=0.2,
                                      relu_dropout=0.1,
                                      res_dropout=0.1,
                                      embed_dropout=0.15,
                                      attn_mask=True, device=device)

        # NOTE: the creation order below is kept exactly as it was, so that
        # parameter ordering and random-initialisation order do not change.
        self.trans_l_self = make_encoder(size)
        self.trans_v_self = make_encoder(size)
        self.trans_a_self = make_encoder(size)

        self.transMerge1 = make_encoder(6 * size)
        self.transVT = make_encoder(size)

        # 1x1 convolutions projecting each modality to ``model_size`` channels.
        self.proj1 = nn.Conv1d(self.text_shape, size, kernel_size=1, padding=0, bias=False).to(device)
        self.proj2 = nn.Conv1d(self.audio_shape, size, kernel_size=1, padding=0, bias=False).to(device)
        self.proj3 = nn.Conv1d(self.video_shape, size, kernel_size=1, padding=0, bias=False).to(device)

        self.embrace = EmbraceNet(device=device,
                                  input_size_list=[size, size, size],
                                  embracement_size=3 * size)

        # NOTE(review): this ``linear2`` is replaced further down by a layer
        # of a different size; it is still created here so the registration
        # order of the sub-modules stays the same.
        self.linear2 = nn.Linear(self.audio_shape, size).to(device)
        self.linear3 = nn.Linear(self.video_shape, size).to(self.device)

        # Number of enabled modalities.
        self.count = sum(1 for flag in (self.T, self.A, self.V) if flag)
        if not self.T and not self.V and not self.A:
            raise Exception("please set atleast one modality")

        multiplier = self.count * 2
        multiplier1 = self.count * 2 + 1
        multiplier2 = self.count + 1

        self.transMerge = make_encoder(self.count * size)
        self.transMerge2 = make_encoder(multiplier2 * size)
        if self.count == 1:
            self.transMerge_ = make_encoder(self.count * size)
        else:
            self.transMerge_ = make_encoder((self.count - 1) * size)
        self.transMerge_1 = make_encoder(size)

        self.gru = nn.GRU(multiplier * size, multiplier * size).to(self.device)
        self.gru2 = nn.GRU(multiplier1 * size, multiplier1 * size).to(self.device)
        self.output = nn.Linear(multiplier * size, self.output_dim).to(device)
        self.output2 = nn.Linear(multiplier1 * size, self.output_dim).to(device)
        self.output1 = nn.Linear((multiplier // 2) * size, self.output_dim).to(device)
        self.linear = nn.Linear(multiplier * size, multiplier * size).to(device)
        self.linear2 = nn.Linear(multiplier1 * size, multiplier1 * size).to(device)
        self.linear1 = nn.Linear((multiplier // 2) * size, (multiplier // 2) * size).to(device)
        self.gru1 = nn.GRU(size, size).to(self.device)

        self.lstm = nn.LSTM(3 * size, 4 * size).to(self.device)

    def forward(self, xT, xA, xV):
        """Project the three inputs and fuse them according to ``model_type``.

        Args:
            xT, xA, xV: text, audio and video inputs. Dims 1 and 2 are
                swapped before the ``Conv1d`` projections, so each input is
                expected as ``(batch, seq_len, features)``. All three are
                always projected, whichever modality flags are set.

        Returns:
            The output of the final linear layer, reshaped with
            ``view(-1, 2)`` when ``dataset == 'iemocap'``.

        Raises:
            ValueError: if ``model_type`` is not a supported mode.
        """
        # Conv1d consumes channels on dim 1; afterwards the time axis is
        # moved to the front: (batch, size, seq) -> (seq, batch, size).
        xT = self.proj1(xT.transpose(1, 2)).permute(2, 0, 1)
        xA = self.proj2(xA.transpose(1, 2)).permute(2, 0, 1)
        xV = self.proj3(xV.transpose(1, 2)).permute(2, 0, 1)

        if self.model_type == 'hybrid':
            return self._forward_hybrid(xT, xA, xV)
        if self.model_type == 'late_fusion':
            return self._forward_late_fusion(xT, xA, xV)
        if self.model_type in self._CROSS_ROUTES:
            return self._forward_round_robin(xT, xA, xV)
        raise ValueError("please select an appropriate model type")

    def _modality_flags(self):
        """Return the (text, audio, video) enable flags as a tuple of bools."""
        return bool(self.T), bool(self.A), bool(self.V)

    def _forward_hybrid(self, xT, xA, xV):
        """Fuse projected and self-encoded streams with a GRU and classify."""
        # All three encoders run, in this order, whichever flags are set.
        xT1 = self.trans_l_self(xT)
        xA1 = self.trans_a_self(xA)
        xV1 = self.trans_v_self(xV)

        flags = self._modality_flags()
        raw = [x for use, x in zip(flags, (xT, xA, xV)) if use]
        encoded = [x for use, x in zip(flags, (xT1, xA1, xV1)) if use]
        if flags == (False, True, True):
            # NOTE(review): for audio+video the projected streams are
            # joined as [video, audio] while the encoded ones are joined as
            # [audio, video]. Kept as it was so trained weights still apply.
            raw.reverse()

        x_ = raw[0] if len(raw) == 1 else torch.cat(raw, -1)
        x2 = encoded[0] if len(encoded) == 1 else torch.cat(encoded, -1)

        x1 = self.transMerge(x_)
        x = torch.cat([x1, x2], -1)

        _, hidden = self.gru(x)
        # Drop only the leading layer axis of the final hidden state; a bare
        # squeeze() would also remove the batch axis when the batch size is 1.
        hidden = hidden.squeeze(0)
        # Dropout must follow the module's train/eval mode.
        hidden = f.dropout(hidden, 0.12, training=self.training)
        out = self.output(hidden)

        if self.dataset == 'iemocap':
            out = out.view(-1, 2)
        return out

    def _forward_late_fusion(self, xT, xA, xV):
        """Self-encode each stream and classify the joined last time steps."""
        # All three encoders run, in this order, whichever flags are set.
        xT1 = self.trans_l_self(xT)
        xA1 = self.trans_a_self(xA)
        xV1 = self.trans_v_self(xV)

        flags = self._modality_flags()
        pooled = [h[-1] for use, h in zip(flags, (xT1, xA1, xV1)) if use]
        x1 = pooled[0] if len(pooled) == 1 else torch.cat(pooled, -1)
        return self._residual_head(x1)

    def _forward_round_robin(self, xT, xA, xV):
        """Run the cross-modal encoder calls listed in ``_CROSS_ROUTES``."""
        flags = self._modality_flags()
        streams = {'T': xT, 'A': xA, 'V': xV}
        # Only 'round_robin' with fewer than three modalities passes each
        # encoder output through ``transMerge_1`` before pooling.
        refine = self.model_type == 'round_robin' and not all(flags)

        pooled = []
        for encoder_name, query, key_value in self._CROSS_ROUTES[self.model_type][flags]:
            encoder = getattr(self, encoder_name)
            h = encoder(streams[query], streams[key_value], streams[key_value])
            if refine:
                h = self.transMerge_1(h)
            pooled.append(h[-1])  # keep the last time step only

        x1 = pooled[0] if len(pooled) == 1 else torch.cat(pooled, 1)
        return self._residual_head(x1)

    def _residual_head(self, x1):
        """Apply the shared residual block and ``output1`` to pooled features."""
        x = self.linear1(x1)
        x = f.relu(x)
        # Dropout must follow the module's train/eval mode.
        x = f.dropout(x, 0.12, training=self.training)

        # The same ``linear1`` layer is applied a second time (shared weights).
        x = self.linear1(x)
        x = x + x1

        out = self.output1(x)

        if self.dataset == 'iemocap':
            out = out.view(-1, 2)
        return out

