import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

import random
#random.seed(666)
import logging
logger = logging.getLogger(__name__)

class Attention(nn.Module):
    """Attention of a decoder state over a sequence of encoder ("listener") features.

    Modes:
      * ``'dot'``: dot product between ``tanh(psi(listener_feature))`` and
        ``tanh(phi(decoder_state))``. With ``num_heads > 1`` the projected
        decoder state is split into ``num_heads`` queries of size ``att_dim``;
        the per-head contexts are concatenated and merged by a linear layer.
      * ``'loc'``: location-aware energy that additionally convolves the
        previous step's attention weights (single head only).

    The projected encoder features, the padding mask and (in ``'loc'`` mode)
    the previous attention weights are cached between calls. Call
    :meth:`reset_enc_mem` before attending over a new batch.

    Args:
        in_dim: feature size of ``listener_feature``.
        decoder_dim: feature size of ``decoder_state``.
        att_dim: size of the attention space per head.
        num_heads: number of heads (``'loc'`` mode requires 1).
        mode: ``'dot'`` or ``'loc'``.
    """

    def __init__(self, in_dim, decoder_dim, att_dim, num_heads=1, mode='dot'):
        super(Attention, self).__init__()

        self.mode = mode
        self.num_heads = num_heads
        self.att_dim = att_dim
        self.softmax = nn.Softmax(dim=-1)

        # Per-batch caches, filled lazily by forward() and cleared by reset_enc_mem().
        self.comp_listener_feature = None
        self.state_mask = None

        self.psi = nn.Linear(in_dim, att_dim)
        self.phi = nn.Linear(decoder_dim, att_dim * num_heads, bias=False)

        if num_heads > 1:
            # Merges the concatenated per-head contexts back to in_dim.
            self.merge_head = nn.Linear(in_dim * num_heads, in_dim)

        if self.mode == 'loc':
            assert self.num_heads == 1
            # TODO: Move this to config
            C = 10   # output channels of the location convolution
            K = 100  # kernel half-width; padding=K keeps the sequence length unchanged
            self.prev_att = None
            self.loc_conv = nn.Conv1d(1, C, kernel_size=2 * K + 1, padding=K, bias=False)
            self.loc_proj = nn.Linear(C, att_dim, bias=False)
            self.gen_energy = nn.Linear(att_dim, 1)

    def reset_enc_mem(self):
        """Drop the cached encoder projection, padding mask and previous attention."""
        self.comp_listener_feature = None
        self.state_mask = None

        if self.mode == 'loc':
            self.prev_att = None

    @staticmethod
    def _padding_mask(listener_feature, state_len, device):
        """Return a (B, T) bool mask that is True where ``t >= state_len[b]``."""
        bs, ts = listener_feature.shape[0], listener_feature.shape[1]
        # reshape(bs) accepts both (B,) and (B, 1) length containers.
        lengths = torch.as_tensor(state_len, device=device).reshape(bs)
        positions = torch.arange(ts, device=device)
        return positions.unsqueeze(0) >= lengths.unsqueeze(1)

    def forward(self, decoder_state, listener_feature, state_len, scale=2.0):
        """Attend over ``listener_feature`` using ``decoder_state`` as the query.

        Args:
            decoder_state: (B, decoder_dim) query.
            listener_feature: (B, T, in_dim) encoder outputs.
            state_len: valid length of each of the B sequences (list, (B,) or
                (B, 1) tensor). Positions at or beyond the length are masked.
                Only read on the first call after :meth:`reset_enc_mem`.
            scale: multiplier applied to the energies before the softmax.

        Returns:
            ``(attention_score, context)``. ``attention_score`` is a list with
            one (B, T) weight tensor per head, and ``context`` is (B, in_dim).

        Raises:
            ValueError: if ``self.mode`` is neither ``'dot'`` nor ``'loc'``.
        """
        if self.comp_listener_feature is None:
            # Cache per-batch quantities; they do not change between decoding steps.
            # NOTE: every length must be > 0. Otherwise a row is fully masked
            # and its softmax is NaN.
            self.state_mask = self._padding_mask(listener_feature, state_len, decoder_state.device)
            self.comp_listener_feature = torch.tanh(self.psi(listener_feature))

        comp_decoder_state = torch.tanh(self.phi(decoder_state))

        if self.mode == 'dot':
            if self.num_heads == 1:
                energy = torch.bmm(self.comp_listener_feature, comp_decoder_state.unsqueeze(2)).squeeze(dim=2)
                energy.masked_fill_(self.state_mask, -float("Inf"))
                attention_score = [self.softmax(energy * scale)]
                context = torch.bmm(attention_score[0].unsqueeze(1), listener_feature).squeeze(1)
            else:
                # One query of size att_dim per head. Mask the raw energies,
                # then apply a single scaled softmax (same as the one-head path).
                attention_score = []
                for query in torch.split(comp_decoder_state, self.att_dim, dim=-1):
                    energy = torch.bmm(self.comp_listener_feature, query.unsqueeze(2)).squeeze(dim=2)
                    energy = energy.masked_fill(self.state_mask, -float("inf"))
                    attention_score.append(self.softmax(energy * scale))
                projected_src = [torch.bmm(att_s.unsqueeze(1), listener_feature).squeeze(1)
                                 for att_s in attention_score]
                context = self.merge_head(torch.cat(projected_src, dim=-1))
        elif self.mode == 'loc':
            if self.prev_att is None:
                # Initialise the previous attention uniformly over each sequence's valid frames.
                bs, ts, _ = self.comp_listener_feature.shape
                # Follow the working dtype rather than hard-coding float16,
                # so that fp32 models also run.
                self.prev_att = torch.zeros((bs, 1, ts),
                                            dtype=self.comp_listener_feature.dtype,
                                            device=self.comp_listener_feature.device)
                for idx, sl in enumerate(state_len):
                    self.prev_att[idx, :, :sl] = 1.0 / sl

            comp_decoder_state = comp_decoder_state.unsqueeze(1)
            comp_location_info = torch.tanh(self.loc_proj(self.loc_conv(self.prev_att).transpose(1, 2)))
            energy = self.gen_energy(torch.tanh(self.comp_listener_feature + comp_decoder_state + comp_location_info)).squeeze(2)
            energy.masked_fill_(self.state_mask, -float("inf"))
            attention_score = [self.softmax(energy * scale)]
            # Remember this step's weights for the next step's location features.
            self.prev_att = attention_score[0].unsqueeze(1)
            context = torch.bmm(attention_score[0].unsqueeze(1), listener_feature).squeeze(1)
        else:
            raise ValueError("Unsupported attention mode: {!r}".format(self.mode))

        return attention_score, context

class SimpleRNNDecoder(nn.Module):
    """Stack of ``torch.nn`` recurrent cells that is advanced one step per call.

    Per-layer states are kept in ``state_list`` (hidden) and ``cell_list``
    (cell). They must be initialised with :meth:`init_rnn` before the first
    :meth:`forward`. ``forward`` passes an ``(h, c)`` tuple to each cell, so
    the cells must follow the ``LSTMCell`` calling convention.

    Args:
        input_dim: feature size of the per-step input.
        dim: hidden size of every layer.
        layer: number of stacked cells.
        rnn_cell: name of a ``torch.nn`` cell class, e.g. ``'LSTMCell'``.
        dropout: dropout probability applied to the input of every layer.
    """

    def __init__(self, input_dim, dim, layer, rnn_cell, dropout):
        super(SimpleRNNDecoder, self).__init__()
        assert "Cell" in rnn_cell, 'Please use Recurrent Cell instead of layer in decoder'
        # Manually forward through Cells if using RNNCell family
        self.layer = layer
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout)

        self.layer0 = getattr(nn, rnn_cell)(input_dim, dim)
        for i in range(1, layer):
            setattr(self, 'layer' + str(i), getattr(nn, rnn_cell)(dim, dim))

        self.state_list = []
        self.cell_list = []

    def init_rnn(self, context, dtype=None):
        """Reset every layer's hidden and cell state to zeros.

        Args:
            context: tensor whose first dimension is the batch size. The new
                states are created on its device.
            dtype: dtype of the new states. Defaults to ``context.dtype``;
                pass ``torch.float16`` to force half-precision states.
        """
        bs = context.shape[0]
        dtype = context.dtype if dtype is None else dtype
        # Build a separate tensor per layer rather than repeating one aliased tensor.
        self.state_list = [torch.zeros(bs, self.dim, dtype=dtype, device=context.device)
                           for _ in range(self.layer)]
        self.cell_list = [torch.zeros(bs, self.dim, dtype=dtype, device=context.device)
                          for _ in range(self.layer)]

    @property
    def hidden_state(self):
        """Detached CPU copies of ``(hidden states, cell states)`` as two lists."""
        return [s.clone().detach().cpu() for s in self.state_list], [c.clone().detach().cpu() for c in self.cell_list]

    @hidden_state.setter
    def hidden_state(self, state):  # state is a tuple of two lists
        # Moves the given states onto the device of the current states, so
        # init_rnn() must have been called first.
        device = self.state_list[0].device
        self.state_list = [s.to(device) for s in state[0]]
        self.cell_list = [c.to(device) for c in state[1]]

    def forward(self, input_context):
        """Advance every layer by one step and return the top layer's hidden state."""
        layer_input = input_context
        for idx in range(self.layer):
            cell = getattr(self, 'layer' + str(idx))
            # Dropout is applied to each layer's input (the layer below's
            # output), never to the recurrent hidden state.
            self.state_list[idx], self.cell_list[idx] = cell(
                self.dropout(layer_input), (self.state_list[idx], self.cell_list[idx]))
            layer_input = self.state_list[idx]
        return self.state_list[-1]

class HybridDecoderWOAtt(nn.Module):
    """CTC head plus an attention-free autoregressive character decoder.

    At every decoding step the decoder receives the previous token embedding
    concatenated with the time-average of the encoder features.
    """

    def __init__(self, encoder_dim, decoder_dim, vocab_size, decoder_layer, decoder_cell_type='LSTMCell', dropout=0):
        super(HybridDecoderWOAtt, self).__init__()
        self.embed = nn.Embedding(vocab_size, decoder_dim)
        self.decoder = SimpleRNNDecoder(encoder_dim + decoder_dim, decoder_dim, decoder_layer, decoder_cell_type, dropout)
        self.char_trans = nn.Linear(decoder_dim, vocab_size)
        self.ctc_layer = nn.Linear(encoder_dim, vocab_size)

    def forward(self, encoder_feature, padding_mask, tf_rate, teacher):
        """Run the CTC head and decode a character sequence.

        Args:
            encoder_feature: encoder outputs with the batch on dim 1
                (presumably (T, B, encoder_dim) -- confirm against the encoder).
            padding_mask: not used here; returned unchanged.
            tf_rate: probability of feeding the ground-truth token (teacher
                forcing) at each training step.
            teacher: (B, L) target token ids. 0 is treated as padding when
                counting decoding steps. Required in both train and eval mode.

        Returns:
            ``(ctc_out, att_output, padding_mask)``. ``ctc_out`` has the layout
            of ``encoder_feature`` with its last dim mapped to the vocabulary.
            ``att_output`` is (B, steps, vocab_size).
        """
        # Linear acts on the last dim, so the CTC head can run before the transpose.
        ctc_out = self.ctc_layer(encoder_feature)

        encoder_feature = encoder_feature.transpose(0, 1)
        bs = encoder_feature.shape[0]

        # Targets carry no BOS token, so the longest non-padding target sets the
        # number of steps. At inference, allow up to twice that many.
        max_target_len = int(torch.max(torch.sum(teacher != 0, dim=-1)))
        decoder_step = max_target_len if self.training else max_target_len * 2

        # The time-average does not change between steps: compute it once.
        # NOTE(review): padded frames are included in the mean -- confirm intended.
        context = torch.mean(encoder_feature, dim=1)

        # Token id 0 starts decoding. Keep the embedding's own dtype so the first
        # step matches later steps (it used to be forced to float16).
        last_char = self.embed(torch.zeros(bs, dtype=torch.long, device=encoder_feature.device))
        self.decoder.init_rnn(encoder_feature)

        teacher = self.embed(teacher)
        output_char_seq = []

        for t in range(decoder_step):
            dec_out = self.decoder(torch.cat([last_char, context], dim=-1))
            cur_char = self.char_trans(dec_out)

            if self.training:
                if random.random() <= tf_rate:
                    # Teacher forcing: feed the ground-truth token.
                    last_char = teacher[:, t, :]
                else:
                    # Feed a token sampled from the model's own distribution.
                    sampled_char = Categorical(F.softmax(cur_char, dim=-1)).sample()
                    last_char = self.embed(sampled_char)
            else:
                # Greedy decoding.
                last_char = self.embed(torch.argmax(cur_char, dim=-1))

            output_char_seq.append(cur_char)

        att_output = torch.stack(output_char_seq, dim=1)
        return ctc_out, att_output, padding_mask

class HybridDecoder(nn.Module):
    """CTC head plus an attention-based autoregressive character decoder.

    At every decoding step the bottom decoder layer's hidden state queries
    :class:`Attention` over the encoder features. The resulting context,
    concatenated with the previous token embedding, drives the decoder.
    """

    def __init__(self, encoder_dim, decoder_dim, att_dim, att_head, att_mode, vocab_size, decoder_layer, decoder_cell_type='LSTMCell', dropout=0.1):
        super(HybridDecoder, self).__init__()
        self.attention_module = Attention(encoder_dim, decoder_dim, att_dim, att_head, att_mode)
        self.embed = nn.Embedding(vocab_size, decoder_dim)
        self.decoder = SimpleRNNDecoder(encoder_dim + decoder_dim, decoder_dim, decoder_layer, decoder_cell_type, dropout)
        self.char_trans = nn.Linear(decoder_dim, vocab_size)
        self.ctc_layer = nn.Linear(encoder_dim, vocab_size)

    def forward(self, encoder_feature, padding_mask, tf_rate, teacher):
        """Run the CTC head and decode a character sequence with attention.

        Args:
            encoder_feature: encoder outputs with the batch on dim 1
                (presumably (T, B, encoder_dim) -- confirm against the encoder).
            padding_mask: (B, T) bool mask, True at padded frames, or ``None``
                when no frame is padded.
                NOTE(review): assumed to already match the encoder's time resolution.
            tf_rate: probability of feeding the ground-truth token (teacher
                forcing) at each training step.
            teacher: (B, L) target token ids. 0 is treated as padding when
                counting decoding steps. Required in both train and eval mode.

        Returns:
            ``(ctc_out, att_output, padding_mask, att_scores_out)``.
            ``ctc_out`` has the layout of ``encoder_feature`` with its last dim
            mapped to the vocabulary. ``att_output`` is (B, steps, vocab_size).
            ``att_scores_out`` stacks the first head's attention weights per
            step as (B, steps, T).
        """
        # Linear acts on the last dim, so the CTC head can run before the transpose.
        ctc_out = self.ctc_layer(encoder_feature)

        encoder_feature = encoder_feature.transpose(0, 1)
        bs, ts = encoder_feature.shape[0], encoder_feature.shape[1]

        # Targets carry no BOS token, so the longest non-padding target sets the
        # number of steps. At inference, allow up to twice that many.
        max_target_len = int(torch.max(torch.sum(teacher != 0, dim=-1)))
        decoder_step = max_target_len if self.training else max_target_len * 2

        # Number of valid (non-padded) encoder frames per utterance.
        if padding_mask is None:
            encode_len = torch.full((bs,), ts, dtype=torch.long, device=encoder_feature.device)
        else:
            encode_len = (~padding_mask).long().sum(-1)

        self.attention_module.reset_enc_mem()
        # Token id 0 starts decoding. Keep the embedding's own dtype so the first
        # step matches later steps (it used to be forced to float16).
        last_char = self.embed(torch.zeros(bs, dtype=torch.long, device=encoder_feature.device))
        self.decoder.init_rnn(encoder_feature)

        teacher = self.embed(teacher)
        output_char_seq = []
        attention_scores_seq = []

        for t in range(decoder_step):
            attention_score, context = self.attention_module(self.decoder.state_list[0], encoder_feature, encode_len)
            dec_out = self.decoder(torch.cat([last_char, context], dim=-1))
            cur_char = self.char_trans(dec_out)

            if self.training:
                if random.random() <= tf_rate:
                    # Teacher forcing: feed the ground-truth token.
                    last_char = teacher[:, t, :]
                else:
                    # Feed a token sampled from the model's own distribution.
                    sampled_char = Categorical(F.softmax(cur_char, dim=-1)).sample()
                    last_char = self.embed(sampled_char)
            else:
                # Greedy decoding.
                last_char = self.embed(torch.argmax(cur_char, dim=-1))

            output_char_seq.append(cur_char)
            attention_scores_seq.append(attention_score[0])

        att_output = torch.stack(output_char_seq, dim=1)
        att_scores_out = torch.stack(attention_scores_seq, dim=1)  # (B, steps, T)
        return ctc_out, att_output, padding_mask, att_scores_out

if __name__ == "__main__":
    # Smoke test: greedily decode a few steps with location-aware attention on
    # random encoder features and print the intermediate shapes.
    listener_feature_dim = 768
    listener_feature_len = 200
    decoder_dim = 512
    att_dim = 1024
    num_heads = 1
    out_dim = 29  # vocab size
    bs = 3
    decoder_step = 10

    listener_feature = torch.rand((bs, listener_feature_len, listener_feature_dim))
    encode_len = torch.full((bs,), listener_feature_len, dtype=torch.long)

    attention_module = Attention(listener_feature_dim, decoder_dim, att_dim, num_heads=num_heads, mode="loc")
    embed = nn.Embedding(out_dim, decoder_dim)
    decoder = SimpleRNNDecoder(listener_feature_dim + decoder_dim, decoder_dim, 1, 'LSTMCell', 0.1)
    char_trans = nn.Linear(decoder_dim, out_dim)

    output_char_seq = []
    with torch.no_grad():
        attention_module.reset_enc_mem()
        last_char = embed(torch.zeros(bs, dtype=torch.long))
        decoder.init_rnn(listener_feature)

        for t in range(decoder_step):
            attention_score, context = attention_module(decoder.state_list[0], listener_feature, encode_len)
            print("decoding")
            print(context.shape)
            print(last_char.shape)
            decoder_input = torch.cat([last_char, context], dim=-1)
            dec_out = decoder(decoder_input)
            print(dec_out.shape)

            cur_char = char_trans(dec_out)
            print(cur_char.shape)

            # Feed the greedy prediction back as the next input token.
            last_char = embed(torch.argmax(cur_char, dim=-1))
            output_char_seq.append(cur_char)

    att_output = torch.stack(output_char_seq, dim=1)

    print("last")
    print(att_output.shape)