import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from collections import OrderedDict
# from .beam_search import beam_decode
from .batch_beam_search import beam_decode
#from torchaudio.models import Conformer
from .conformer import Conformer
#from .conformer2 import Conformer
from .ViT import VisionTransformer 
from .GNN import DynGCN, GAT
from .TCN import TemporalConvNet
from .scale_mix import ScalarMix


# Gradient Reversal Layer
class GRLayer(torch.autograd.Function):
    """Gradient reversal layer.

    The forward pass returns the input unchanged; the backward pass negates
    the incoming gradient and scales it by ``lmbd``.
    """

    @staticmethod
    def forward(ctx, x, lmbd=0.01):
        # Stash the scaling factor on the context so backward() can read it.
        ctx.lmbd = torch.tensor(lmbd)
        return x.reshape_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient w.r.t. the output of forward().
        flipped = grad_output.clone().neg()
        # One entry per forward() argument: `x` gets the reversed gradient,
        # `lmbd` gets none.
        return ctx.lmbd * flipped, None


class MINE(nn.Module):
    """Mutual-information style critic over pairs of feature vectors.

    A single linear layer scores concatenated ``(x, y)`` pairs. Matching rows
    of ``x`` and ``y`` form the "joint" samples; pairing ``x[i]`` with
    ``y[i + 1]`` (cyclically) forms the "marginal" samples.

    Args:
        hid_size1: feature width of ``x``.
        hid_size2: feature width of ``y``.
        mode: ``'mine'`` or ``'jsd'``; selects the loss expression.
        norm: if true, L2-normalise the logits along dim 1 before the loss.
    """

    def __init__(self, hid_size1, hid_size2, mode='jsd', norm=False):
        super().__init__()
        assert mode in ['mine', 'jsd']
        self.mode = mode
        self.norm = norm
        self.fc = nn.Linear(hid_size1 + hid_size2, 1)
        # Small uniform weights and zero bias for every linear layer.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.uniform_(module.weight, -0.1, 0.1)
            nn.init.zeros_(module.bias)

    def forward(self, x, y):
        """Return the scalar loss (negated estimate) for a batch of pairs."""
        n = x.size(0)
        # Deterministic "shuffle": rotate y by one row so x[i] meets y[i+1].
        rotated_y = torch.cat((y[1:], y[0].unsqueeze(0)), dim=0)
        joint = torch.cat((x, y), dim=1)
        marginal = torch.cat((x, rotated_y), dim=1)
        scores = self.fc(F.relu(torch.cat((joint, marginal), dim=0)))
        if self.norm:
            scores = F.normalize(scores, p=2, dim=1)
        pred_xy, pred_x_y = scores[:n], scores[n:]

        if self.mode == 'mine':
            # Negated so that minimising the loss maximises the estimate.
            return - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))
        # JSD variant, negated for the same reason.
        return -1 * (-F.softplus(-pred_xy).mean() - F.softplus(pred_x_y).mean())



class PositionEmbedding(nn.Module):
    """Learned embedding table, initialised with Xavier-normal weights.

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: width of each embedding vector.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Embedding(num_embeddings, embedding_dim)
        nn.init.xavier_normal_(self.weight.weight)

    def forward(self, x):
        """Look up the embedding rows selected by the index tensor ``x``."""
        return self.weight(x)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    precomputed once and stored as the non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``.

    Args:
        d_model: feature width; may be even or odd.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, max_len=200):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
        angles = position * div_term   # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)   # (max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # PE(pos, 2i)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the cosine block; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])  # PE(pos, 2i+1)
        self.register_buffer("pe", pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, L):
        """Return the encodings of the first ``L`` positions, ``(1, L, d_model)``.

        Slicing silently caps the result at ``max_len`` positions.
        """
        return self.pe[:, :L].detach()


class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention.

    Args:
        hid_dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        bias: whether the four linear projections carry a bias term.
        dropout: dropout probability applied to the attention weights.

    NOTE: ``bias`` comes *before* ``dropout`` in the signature, so a dropout
    rate should be passed by keyword (``dropout=p``), not as the third
    positional argument.
    """

    def __init__(self, hid_dim, n_heads, bias=True, dropout=0.1):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.fc_q = nn.Linear(hid_dim, hid_dim, bias)
        self.fc_k = nn.Linear(hid_dim, hid_dim, bias)
        self.fc_v = nn.Linear(hid_dim, hid_dim, bias)
        self.fc_o = nn.Linear(hid_dim, hid_dim, bias)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch):
        # [batch, len, hid dim] -> [batch, n heads, len, head dim]
        return t.reshape(batch, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q: [batch, query len, hid dim]
            k: [batch, key len, hid dim]
            v: [batch, value len, hid dim]
            mask: optional tensor broadcastable to
                [batch, n heads, query len, key len]; positions where
                ``mask == 0`` are excluded from the softmax.

        Returns:
            ``(x, attention)`` with ``x`` of shape [batch, query len, hid dim]
            and ``attention`` of shape [batch, n heads, query len, key len]
            (the weights before dropout).
        """
        batch = q.shape[0]

        queries = self._split_heads(self.fc_q(q), batch)
        keys = self._split_heads(self.fc_k(k), batch)
        values = self._split_heads(self.fc_v(v), batch)

        # [batch, n heads, query len, key len]
        energy = torch.matmul(queries, keys.transpose(-1, -2).contiguous()) * self.scale
        if mask is not None:
            # A large negative score drives the softmax weight to zero.
            energy = energy.masked_fill(mask == 0, -1e9)
        attention = torch.softmax(energy, dim=-1)

        # [batch, n heads, query len, head dim]
        context = torch.matmul(self.dropout(attention), values)
        # Merge the heads back: -> [batch, query len, hid dim]
        merged = context.transpose(1, 2).contiguous().reshape(batch, -1, self.hid_dim)
        return self.fc_o(merged), attention


class FeedforwardLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hid_dim: input and output width.
        ffn_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, hid_dim, ffn_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, ffn_dim)
        self.fc_2 = nn.Linear(ffn_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map [batch, seq len, hid dim] to a tensor of the same shape."""
        hidden = torch.relu(self.fc_1(x))      # [batch, seq len, ffn dim]
        hidden = self.dropout(hidden)
        return self.fc_2(hidden)               # [batch, seq len, hid dim]


class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer (self-attention + feed-forward).

    Args:
        hid_dim: model width.
        n_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward block.
        dropout: dropout probability used for the attention weights, the
            feed-forward block and both residual branches.
    """

    def __init__(self,
                 hid_dim,
                 n_heads,
                 ffn_dim,
                 dropout):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        self.ff_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        # `dropout` must go by keyword: the third positional parameter of
        # MultiHeadAttentionLayer is `bias`. Passing it positionally left the
        # attention dropout at its 0.1 default and disabled the projection
        # biases whenever dropout == 0.
        # NOTE: checkpoints saved from a dropout == 0 run of the old code have
        # no attention bias tensors and will not load strictly.
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout=dropout)
        self.feedforward = FeedforwardLayer(hid_dim, ffn_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Apply the layer.

        Args:
            src: [batch size, src len, hid dim]
            src_mask: [batch size, 1, 1, src len]; zeros mark masked keys.

        Returns:
            Tensor of shape [batch size, src len, hid dim].
        """
        # Self-attention sub-layer: residual add, then LayerNorm.
        _src, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(_src))

        # Feed-forward sub-layer: residual add, then LayerNorm.
        _src = self.feedforward(src)
        src = self.ff_layer_norm(src + self.dropout(_src))
        return src


class CausalConv(nn.Module):
    """Two-layer causal (left-padded) 1-D convolution block with a residual.

    The output at time ``t`` of the forward branch depends only on inputs at
    times ``<= t``. With ``bid=True`` a second branch with its own weights is
    run over the time-reversed sequence and added in, so the block also sees
    future frames. The result is ``LayerNorm(x + branches)``, which requires
    ``in_channel == out_channel``.

    Args:
        in_channel: input feature width.
        hid_dim: width between the two convolutions.
        out_channel: output feature width.
        kernel_size: convolution kernel size.
        dilation: convolution dilation.
        bid: enable the backward (anti-causal) branch.
        dropout: dropout probability applied at the end of each branch.
    """

    def __init__(self, in_channel, hid_dim, out_channel, kernel_size=3, dilation=1, bid=False, dropout=0.1):
        super().__init__()
        self.padding = (kernel_size - 1) * dilation
        self.bid = bid
        self.fwd_caus_conv = self._make_branch(
            in_channel, hid_dim, out_channel, kernel_size, dilation, self.padding, dropout)
        if self.bid:
            self.bwd_caus_conv = self._make_branch(
                in_channel, hid_dim, out_channel, kernel_size, dilation, self.padding, dropout)
        self.layer_norm = nn.LayerNorm(out_channel)

    @staticmethod
    def _make_branch(in_channel, hid_dim, out_channel, kernel_size, dilation, padding, dropout):
        # Left-only padding keeps each convolution causal and length-preserving.
        return nn.Sequential(
            nn.ConstantPad1d((padding, 0), 0),
            nn.Conv1d(in_channel, hid_dim, kernel_size, padding=0, dilation=dilation),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(),
            nn.ConstantPad1d((padding, 0), 0),
            nn.Conv1d(hid_dim, out_channel, kernel_size, padding=0, dilation=dilation),
            nn.Dropout(dropout),
        )

    def forward(self, x):  # (B, T, D)
        """Return ``LayerNorm(x + conv(x))`` with shape (B, T, D)."""
        out = self.fwd_caus_conv(x.transpose(1, 2)).transpose(1, 2)  # BTD -> BDT -> BTD
        # Previously an unconditional early return made this branch
        # unreachable, so `bwd_caus_conv` was built but never used.
        if self.bid:
            rev_x = torch.flip(x, dims=[1])  # reverse along the time axis
            rev_out = self.bwd_caus_conv(rev_x.transpose(1, 2)).transpose(1, 2)
            out = out + torch.flip(rev_out, dims=[1])  # restore time order
        return self.layer_norm(x + out)


class DecoderLayer(nn.Module):
    """Transformer decoder layer: masked self-attention, encoder attention
    and a feed-forward block, each wrapped in a residual connection.

    Args:
        hid_dim: model width.
        n_heads: number of attention heads.
        ffn_dim: hidden width of the feed-forward block.
        norm_before: if true, LayerNorm is applied to each sub-layer's input
            (pre-norm); otherwise to the residual sum (post-norm).
        dropout: dropout probability used for the attention weights, the
            feed-forward block and the residual branches.
    """

    def __init__(self,
                 hid_dim,
                 n_heads,
                 ffn_dim,
                 norm_before=True,
                 dropout=0.1):
        super().__init__()
        self.norm_before = norm_before
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        self.ff_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        # `dropout` must go by keyword: the third positional parameter of
        # MultiHeadAttentionLayer is `bias`. Passing it positionally left the
        # attention dropout at its 0.1 default and disabled the projection
        # biases whenever dropout == 0.
        # NOTE: checkpoints saved from a dropout == 0 run of the old code have
        # no attention bias tensors and will not load strictly.
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout=dropout)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout=dropout)
        self.feedforward = FeedforwardLayer(hid_dim, ffn_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_src, tgt_mask, src_mask):
        """Apply the layer.

        Args:
            tgt: [batch size, tgt len, hid dim]
            enc_src: [batch size, src len, hid dim]
            tgt_mask: [batch size, 1, tgt len, tgt len]
            src_mask: [batch size, 1, 1, src len]

        Returns:
            ``(x, attention)`` where ``x`` is [batch size, tgt len, hid dim]
            and ``attention`` is the encoder-attention weight tensor
            [batch size, n heads, tgt len, src len].
        """
        # 1) Masked multi-head self-attention.
        residual = tgt
        if self.norm_before:
            tgt = self.self_attn_layer_norm(tgt)
        _tgt, _ = self.self_attention(tgt, tgt, tgt, tgt_mask)
        x = residual + self.dropout(_tgt)
        if not self.norm_before:
            x = self.self_attn_layer_norm(x)

        # 2) Attention over the encoder output.
        residual = x
        if self.norm_before:
            x = self.enc_attn_layer_norm(x)
        _tgt, attention = self.encoder_attention(x, enc_src, enc_src, src_mask)
        x = residual + self.dropout(_tgt)
        if not self.norm_before:
            x = self.enc_attn_layer_norm(x)

        # 3) Position-wise feed-forward.
        residual = x
        if self.norm_before:
            x = self.ff_layer_norm(x)
        _tgt = self.feedforward(x)
        x = residual + self.dropout(_tgt)
        if not self.norm_before:
            x = self.ff_layer_norm(x)

        return x, attention


# Encodes landmark relative-position information
class MLP(nn.Module):
    """Per-position MLP built from two kernel-size-1 ``Conv1d`` layers.

    Layout: Conv1d(k=1) -> BatchNorm1d -> SiLU -> Conv1d(k=1), applied over
    the last axis of the input after moving it to the channel position.

    Args:
        in_channel: input feature width.
        hid_dim: hidden width.
        out_channel: output feature width.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channel=19*2, hid_dim=64, out_channel=256, bias=True):
        super().__init__()
        expand = nn.Conv1d(in_channel, hid_dim, kernel_size=1, bias=bias)
        project = nn.Conv1d(hid_dim, out_channel, kernel_size=1, bias=bias)
        self.conv = nn.Sequential(expand, nn.BatchNorm1d(hid_dim), nn.SiLU(), project)

    def forward(self, x):   # (B, L, D)
        """Map (B, L, in_channel) to (B, L, out_channel)."""
        channels_first = x.transpose(-1, -2)  # (B, D, L)
        return self.conv(channels_first).transpose(-1, -2)


# landmark motion dynamics
class MotionConv(nn.Module):
    """Small 1-D convolution stack over a sequence.

    Layout: Conv1d(k=1) -> BatchNorm1d -> SiLU -> Conv1d(k=3, padding=1).
    The sequence length is preserved.

    Args:
        in_channel: input feature width.
        hid_dim: hidden width.
        out_channel: output feature width.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channel=1, hid_dim=64, out_channel=256, bias=True):
        super().__init__()
        pointwise = nn.Conv1d(in_channel, hid_dim, kernel_size=1, padding=0, bias=bias)
        temporal = nn.Conv1d(hid_dim, out_channel, kernel_size=3, padding=1, bias=bias)
        self.conv = nn.Sequential(pointwise, nn.BatchNorm1d(hid_dim), nn.SiLU(), temporal)

    def forward(self, x):   # (B, L, D)
        """Map (B, L, in_channel) to (B, L, out_channel)."""
        channels_first = x.transpose(-1, -2)  # (B, D, L)
        return self.conv(channels_first).transpose(-1, -2)


# Replaces the original ViT tokenization, which uses a single 16x16 conv
class hMLP(nn.Module):
    """Hierarchical patch stem: three non-overlapping strided convolutions
    (4x4, 2x2, 2x2) with BatchNorm, reducing H and W by a factor of 16.

    Args:
        in_chans: number of input image channels.
        embed_dim: number of output channels.
    """

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        stem_dim = embed_dim // 4
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, stem_dim, kernel_size=4, stride=4),    # H/4 x W/4
            nn.BatchNorm2d(stem_dim),
            nn.GELU(),
            nn.Conv2d(stem_dim, stem_dim, kernel_size=2, stride=2),    # H/8 x W/8
            nn.BatchNorm2d(stem_dim),
            nn.GELU(),
            nn.Conv2d(stem_dim, embed_dim, kernel_size=2, stride=2),   # H/16 x W/16
            nn.BatchNorm2d(embed_dim),
        )

    def forward(self, x): # (B, C, H, W)
        """Return the feature map of shape (B, embed_dim, H/16, W/16)."""
        return self.proj(x)

# Non-overlapping
class PatchEmbedding(nn.Module):
    """Non-overlapping patch embedding via three strided convolutions.

    H and W are reduced by 16 overall (4x4, 2x2, 2x2 kernels with matching
    strides). The first two convolutions are bias-free and followed by
    BatchNorm + ReLU; the last one keeps its bias and has no normalisation.

    Args:
        in_chans: number of input image channels.
        embed_dim: token width.
    """

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        quarter, half = embed_dim // 4, embed_dim // 2
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, quarter, kernel_size=4, stride=4, bias=False),   # H/4 x W/4
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            nn.Conv2d(quarter, half, kernel_size=2, stride=2, bias=False),       # H/8 x W/8
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(half, embed_dim, kernel_size=2, stride=2),                 # H/16 x W/16
        )

    def forward(self, x): # (BT, C, H, W)
        """Return tokens of shape (BT, N, D), or (BT, D) when N == 1."""
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2).transpose(-1, -2)   # (BT, D, N) -> (BT, N, D)
        # A single-token axis is dropped.
        return tokens.squeeze(-2)

'''
class PatchEmbedding(nn.Module):
    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Sequential(
            *[
            nn.Conv2d(in_chans, embed_dim//8, kernel_size=3, stride=2, padding=1, bias=False),   # H/2 x W/2
            nn.BatchNorm2d(embed_dim//8),   # BN or LN
            #nn.GELU(),
            nn.ReLU(),
            nn.Conv2d(embed_dim//8, embed_dim//4, kernel_size=3, stride=2, padding=1, bias=False), # H/4 x W/4
            nn.BatchNorm2d(embed_dim//4),
            #nn.GELU(),
            nn.ReLU(),
            nn.Conv2d(embed_dim//4, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False), # H/8 x W/8
            nn.BatchNorm2d(embed_dim//2),
            #nn.GELU(),
            nn.ReLU(),
            nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1),    # H/16 x W/16
            #nn.BatchNorm2d(embed_dim),
        ])

    def forward(self, x): # (BT, C, H, W)
        x = self.proj(x)
        return x.flatten(2).transpose(-1, -2).squeeze(-2)  # (BT, D, N) -> (BT, N, D)
'''


# Overlapping
class EarlyConv(nn.Module):
    """Overlapping convolutional stem followed by global max pooling.

    Three 3x3, stride-2 convolutions (the first two with BatchNorm + ReLU)
    reduce the image, then the spatial map is max-pooled to one vector.

    Args:
        in_chans: number of input image channels.
        embed_dim: output feature width.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_chans=1, embed_dim=768, bias=False):
        super().__init__()
        quarter, half = embed_dim // 4, embed_dim // 2
        # Settings shared by all three convolutions.
        conv_kw = dict(kernel_size=3, stride=2, padding=1, bias=bias)
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_chans, quarter, **conv_kw),
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            nn.Conv2d(quarter, half, **conv_kw),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(half, embed_dim, **conv_kw),
        )
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Return the globally max-pooled features, shape (B, embed_dim)."""
        feature_map = self.conv_layers(x)
        pooled = self.gmp(feature_map)          # (B, D, 1, 1)
        return pooled.squeeze(-1).squeeze(-1)


'''
class TubeletEarlyConv(nn.Module):
    def __init__(self, in_chans=1, embed_dim=768, bias=False):
        super().__init__()
        self.conv_layers = nn.Sequential(*[
            # (B, C, T, H, W)
            nn.Conv3d(in_chans, embed_dim//2, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), bias=bias),
            nn.BatchNorm3d(embed_dim//2),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            #nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0)),
            nn.Conv3d(embed_dim//2, embed_dim, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), bias=bias),
            #nn.Conv3d(embed_dim//2, embed_dim, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=bias),
            nn.BatchNorm3d(embed_dim),
            nn.ReLU(),
        ])
        
        self.gmp = nn.AdaptiveMaxPool3d((None, 1, 1))
        #self.gmp = nn.AdaptiveAvgPool3d((None, 1, 1))

    def forward(self, x):  # (B, T, N, C, H, W)
        B, T, N = x.shape[:3]
        x = x.transpose(1, 2).flatten(0, 1).transpose(1, 2)  # (BN, T, C, H, W) -> (BN, C, T, H, W)
        x = self.conv_layers(x)
        x = self.gmp(x.transpose(1, 2)).squeeze(-1).squeeze(-1)   # (BN, C, T, H, W) -> (BN, T, C, 1, 1) -> (BN, T, C)
        x = x.reshape(B, N, T, -1).transpose(1, 2)
        return x  # (B, T, N, C)
'''


class TubeletEarlyConv(nn.Module):
    """Spatio-temporal convolutional stem applied to each patch sequence.

    One 3-D convolution with a temporal kernel of 5 is followed by a spatial
    max pool and two spatial-only (1x3x3) convolutions. Temporal stride is 1
    and temporal padding matches the kernel, so the number of frames is
    preserved. Each frame's spatial map is then average-pooled to a vector.

    Args:
        in_chans: number of channels of each input patch.
        embed_dim: output feature width.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_chans=1, embed_dim=768, bias=False):
        super().__init__()
        quarter, half = embed_dim // 4, embed_dim // 2
        # Spatial-only settings shared by the max pool and the last two convs.
        spatial = dict(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.conv_layers = nn.Sequential(
            # operates on (B, C, T, H, W)
            nn.Conv3d(in_chans, quarter, kernel_size=(5, 3, 3), stride=(1, 2, 2), padding=(2, 1, 1), bias=bias),
            nn.BatchNorm3d(quarter),
            nn.SiLU(),
            nn.MaxPool3d(**spatial),
            nn.Conv3d(quarter, half, bias=bias, **spatial),
            nn.BatchNorm3d(half),
            nn.SiLU(),
            nn.Conv3d(half, embed_dim, bias=bias, **spatial),
            nn.BatchNorm3d(embed_dim),
            nn.SiLU(),
        )
        # Despite the attribute name, this is a 2-D *average* pool.
        self.gmp = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):  # (B, T, N, C, H, W)
        """Return per-frame, per-patch features of shape (B, T, N, embed_dim)."""
        B, T, N = x.shape[:3]
        # Fold the patch axis into the batch: (B, T, N, ...) -> (BN, T, C, H, W)
        per_patch = x.transpose(1, 2).flatten(0, 1)
        # Channels first for Conv3d: (BN, C, T, H, W) -> (BN, D, T, H', W')
        feats = self.conv_layers(per_patch.transpose(1, 2))
        # Fold time into the batch so every frame is pooled on its own.
        frames = feats.transpose(1, 2).flatten(0, 1)            # (BNT, D, H', W')
        pooled = self.gmp(frames).squeeze(-1).squeeze(-1)       # (BNT, D)
        return pooled.reshape(B, N, T, -1).transpose(1, 2)      # (B, T, N, D)


# Extracts image features from each patch
class VisualFrontEnd2D(nn.Module):
    """Two-layer 2-D convolutional front-end with global max pooling.

    Args:
        in_channel: number of input image channels.
        embed_dim: output feature width.
    """

    def __init__(self, in_channel=3, embed_dim=256):
        super(VisualFrontEnd2D, self).__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(in_channel, embed_dim//2, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(embed_dim//2),
                nn.ReLU(),
                nn.Conv2d(embed_dim//2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(embed_dim),
                nn.AdaptiveMaxPool2d(1)
            )

    def forward(self, x):  # B x C x H x W
        """Return pooled features of shape (B, embed_dim)."""
        out = self.conv(x)   # (B, D, 1, 1)
        # Drop only the two pooled spatial axes. A bare squeeze() would also
        # remove the batch axis when B == 1 (or the channel axis when D == 1).
        return out.flatten(1)  # (B, D)


class VisualFrontEnd3D(nn.Module):
    """3-D convolutional front-end producing one feature vector per frame.

    Three Conv3d + BatchNorm + ReLU + spatial max-pool stages (temporal
    stride 1 throughout, so the frame count is preserved) are followed by
    dropout and a linear projection of each frame's flattened feature map.

    Args:
        in_channel: number of input video channels.
        hidden_dim: output feature width.
        fc_in_dim: size of one frame's flattened feature map
            (``hidden_dim * H' * W'`` after the CNN). It depends on the input
            resolution; the default of 4608 is the value previously
            hard-coded here.
    """

    def __init__(self, in_channel=1, hidden_dim=256, fc_in_dim=4608):
        super(VisualFrontEnd3D, self).__init__()

        self.stcnn = nn.Sequential(
            nn.Conv3d(in_channel, hidden_dim//4, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), bias=False),
            nn.BatchNorm3d(hidden_dim//4),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Conv3d(hidden_dim//4, hidden_dim//2, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=False),
            nn.BatchNorm3d(hidden_dim//2),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Conv3d(hidden_dim//2, hidden_dim, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )

        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(fc_in_dim, hidden_dim)

    def forward(self, x):   # (B, C, T, H, W)
        """Return per-frame features of shape (B, T, hidden_dim)."""
        cnn = self.stcnn(x)  # (B, D, T, H', W')
        cnn = cnn.transpose(1, 2).contiguous()  # (B, T, D, H', W')
        b, t = cnn.shape[:2]
        cnn = cnn.reshape(b, t, -1)  # (B, T, D * H' * W')
        return self.fc(self.dropout(cnn))



# Front-end pre-training
class PretrainLipModel(nn.Module):
    """Lip-video model: 3-D CNN front-end, Conformer encoder, attention
    decoder and an auxiliary CTC head.

    ``opt`` must provide ``in_channel``, ``hidden_dim``, ``head_num``,
    ``ffn_ratio``, ``enc_layers``, ``dec_layers``, ``drop_attn`` and
    ``tgt_vocab_size``.
    """

    def __init__(self, opt):
        super().__init__()
        hidden = opt.hidden_dim

        self.lipcnn = VisualFrontEnd3D(opt.in_channel, hidden)
        self.conformer = Conformer(
            input_dim=hidden,
            num_heads=opt.head_num,
            ffn_dim=opt.ffn_ratio * hidden,
            num_layers=opt.enc_layers,
            depthwise_conv_kernel_size=31,
            dropout=opt.drop_attn,
        )
        self.decoder = Decoder(opt.tgt_vocab_size,
                               hidden,
                               opt.dec_layers,
                               opt.head_num,
                               opt.ffn_ratio,
                               opt.drop_attn)
        # CTC head: one class fewer than the decoder vocabulary
        # (including blank and space, excluding bos).
        self.ctc_dec = nn.Linear(hidden, opt.tgt_vocab_size - 1)

    def get_repr(self, lips, src_lens):
        """Encode a video batch ``lips`` of shape (b, t, c, h, w)."""
        channels_first = lips.transpose(1, 2).contiguous()  # (b, c, t, h, w)
        frame_feat = self.lipcnn(channels_first)
        encoded, _ = self.conformer(frame_feat, src_lens)
        return encoded

    def get_mask_from_lengths(self, lengths, max_len=None):
        '''
         param:   lengths - [batch_size]
         return:  mask - [batch_size, max_len], True where position < length
        '''
        if max_len is None:
            max_len = torch.max(lengths).item()
        positions = torch.arange(0, max_len, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, src, tgt, src_lens, tgt_lens):
        """Return the training loss ``0.9 * attention CE + 0.1 * CTC``."""
        enc_src = self.get_repr(src, src_lens)
        src_mask = self.get_mask_from_lengths(src_lens, src.shape[1])[:, None, None, :]

        # CTC branch expects time-major log-probabilities.
        log_probs = self.ctc_dec(enc_src).transpose(0, 1).log_softmax(dim=-1)
        ctc_loss = F.ctc_loss(log_probs, tgt[:, 1:], src_lens.reshape(-1), tgt_lens.reshape(-1), zero_infinity=True)

        # Attention branch: teacher forcing on the target shifted by one.
        logits, _ = self.decoder(tgt[:, :-1], enc_src, src_mask)
        attn_loss = F.cross_entropy(logits.transpose(-1, -2).contiguous(), tgt[:, 1:].long(), ignore_index=0)

        return 0.9 * attn_loss + 0.1 * ctc_loss

    def decode(self, src, src_lens, bos_id, eos_id, max_dec_len=80, pad_id=0):
        """Beam-search decode (beam size 6); returns a NumPy array."""
        enc_src = self.get_repr(src, src_lens)
        src_mask = self.get_mask_from_lengths(src_lens, src.shape[1])[:, None, None, :]
        hyps = beam_decode(self.decoder, enc_src, src_mask, bos_id, eos_id, max_output_length=max_dec_len, beam_size=6)
        return hyps.detach().cpu().numpy()

    def save(self, path):
        """Write the state dict to ``path`` under the key ``'model'``."""
        checkpoint = {'model': self.state_dict()}
        torch.save(checkpoint, path)
        print('saved !!!')

    def load(self, path):
        """Load a checkpoint written by :meth:`save`."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location='cpu')
        self.load_state_dict(checkpoint['model'])
        print('loaded !!!')


def diff_loss(hx, hy):
    """Squared Frobenius norm of the cross-similarity between two feature sets.

    3-D inputs are flattened over their first two axes. Each set is centred
    over its rows and L2-normalised per row, then the sum of squared pairwise
    dot products is returned as a scalar tensor.
    """
    def _prepare(h):
        rows = h.flatten(0, 1) if h.ndim == 3 else h
        return F.normalize(rows - torch.mean(rows, 0))

    hx, hy = _prepare(hx), _prepare(hy)
    similarity = torch.matmul(hx, hy.transpose(0, 1))
    # The sum is already a scalar; the trailing mean() leaves it unchanged.
    return torch.sum(similarity ** 2).mean()


class ConformerEncoder(nn.Module):
    """Patch-based visual encoder: tubelet CNN + landmark MLP -> GAT -> Conformer.

    Args:
        in_channel: number of channels of each input patch.
        hid_dim: model width.
        n_layers: number of Conformer layers.
        n_heads: number of attention heads.
        ffn_ratio: feed-forward width as a multiple of ``hid_dim``.
        dropout: dropout probability.
        max_length: accepted for interface compatibility; not used here.
    """

    def __init__(self,
                 in_channel,
                 hid_dim,
                 n_layers,
                 n_heads,
                 ffn_ratio,
                 dropout,
                 max_length=100):
        super().__init__()
        self.hid_dim = hid_dim

        # Per-patch spatio-temporal features.
        self.tubelet_conv = TubeletEarlyConv(in_channel, hid_dim)
        # Graph attention over the patches of each frame.
        self.gnn = GAT(hid_dim, hid_dim)
        # Landmark relative-position embedding ((20 - 1) * 2 inputs per patch).
        self.pts_mlp = MLP((20-1)*2, hid_dim//2, hid_dim)

        self.conformer = Conformer(
            input_dim=hid_dim,
            num_heads=n_heads,
            ffn_dim=ffn_ratio * hid_dim,
            num_layers=n_layers,
            depthwise_conv_kernel_size=31,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)

    def get_mask_from_lengths(self, lengths, max_len=None):
        '''
         param:   lengths - [batch_size]
         return:  mask - [batch_size, max_len], True where position < length
        '''
        if max_len is None:
            max_len = torch.max(lengths).item()
        positions = torch.arange(0, max_len, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, lips, vid, pts, motion, src_lens):
        """Encode a batch of patch videos.

        Args:
            lips: (B, T, C, H, W); not used by the current pipeline.
            vid: (B, T, N, C, H, W) patch videos.
            pts: (B, T, N, D) landmark relative positions.
            motion: (B, T, D); not used by the current pipeline.
            src_lens: valid length of each sequence.

        Returns:
            ``(src, src_mask, feat)``: the Conformer output, the padding mask
            of shape (B, 1, 1, T) and the pre-Conformer features.
        """
        B, T = vid.shape[:2]

        # Landmark embedding, reshaped to line up with the patch features.
        landmark_feat = self.pts_mlp(pts.flatten(0, 1)).reshape(B, T, -1, self.hid_dim)
        patch_feat = self.tubelet_conv(vid) + landmark_feat
        feat = self.gnn(patch_feat)

        src, _ = self.conformer(self.dropout(feat), src_lens)
        src_mask = self.get_mask_from_lengths(src_lens, T)[:, None, None, :]
        return src, src_mask, feat

    def load_weights(self, path):
        """Placeholder; currently a no-op."""
        pass


class Encoder(nn.Module):
    """Transformer-style encoder on top of a 3D visual front-end.

    The clip is passed through ``VisualFrontEnd3D``, a positional encoding is
    added, and the result is refined by ``n_layers`` ``EncoderLayer`` blocks.

    Args:
        in_channel: number of input image channels given to the front-end.
        hid_dim: model width (front-end output channels).
        n_layers: number of ``EncoderLayer`` blocks.
        n_heads: attention heads per layer.
        ffn_ratio: feed-forward width is ``ffn_ratio * hid_dim``.
        dropout: dropout probability (layers and the input sum).
        max_length: maximum length handed to ``PositionalEncoding``.
        pad_idx: padding token index used by ``get_mask_from_idxs``.
    """

    def __init__(self,
                 in_channel,
                 hid_dim,
                 n_layers,
                 n_heads,
                 ffn_ratio,
                 dropout,
                 max_length=100,
                 pad_idx=0):
        super().__init__()
        self.hid_dim = hid_dim
        # get_mask_from_idxs reads this attribute; it used to be left unset,
        # which made that method raise AttributeError.
        self.src_pad_idx = pad_idx

        self.visual_front = VisualFrontEnd3D(in_channel, out_channel=hid_dim)

        self.pos_embedding = PositionalEncoding(hid_dim, max_length)
        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  ffn_ratio * hid_dim,
                                                  dropout)
                                     for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)

    def get_mask_from_idxs(self, src):  # src = [batch size, src len]
        """Return a [batch size, 1, 1, src len] mask that is True at non-pad tokens."""
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def get_mask_from_lengths(self, lengths, max_len=None):
        '''
         param:   lengths --- [Batch_size]
                  max_len --- mask width; defaults to the largest length
         return:  mask --- [Batch_size, max_len], True where position < length
        '''
        if max_len is None:
            max_len = torch.max(lengths).item()
        # Build the position index on the same device as `lengths` and let
        # broadcasting produce the [Batch_size, max_len] comparison.
        ids = torch.arange(0, max_len, device=lengths.device)
        return ids.unsqueeze(0) < lengths.unsqueeze(1)

    def forward(self, x, src_lens):  # (b, t, c, h, w)
        """Encode a batch of clips.

        Returns:
            (src, src_mask): the encoded sequence and the (b, 1, 1, t) length
            mask that was applied in every layer.
        """
        T = x.shape[1]
        src_mask = self.get_mask_from_lengths(src_lens, T).unsqueeze(1).unsqueeze(2)  # (b, 1, 1, t)

        x = x.transpose(1, 2).contiguous()  # (b, c, t, h, w)
        feat = self.visual_front(x)
        src = self.dropout(feat + self.pos_embedding(T))
        # src = [batch size, src len, hid dim]

        for layer in self.layers:
            src = layer(src, src_mask)
        # src = [batch size, src len, hid dim]
        return src, src_mask


class Decoder(nn.Module):
    """Autoregressive Transformer decoder over target tokens.

    Tokens are embedded (scaled by sqrt(hid_dim)), summed with a positional
    encoding, passed through ``n_layers`` ``DecoderLayer`` blocks that attend
    to the encoder output, and projected to ``num_cls - 1`` classes (the BOS
    token is never predicted).
    """

    def __init__(self,
                 num_cls,
                 hid_dim,
                 n_layers,
                 n_heads,
                 ffn_ratio,
                 dropout,
                 norm_before=False,
                 pad_idx=0,
                 max_length=100):
        super().__init__()
        self.hid_dim = hid_dim
        self.tgt_pad_idx = pad_idx  # index of the padding token
        self.norm_before = norm_before

        self.tok_embedding = nn.Embedding(num_cls, hid_dim)
        self.pos_embedding = PositionalEncoding(hid_dim, max_length)

        ffn_dim = ffn_ratio * hid_dim
        blocks = []
        for _ in range(n_layers):
            blocks.append(DecoderLayer(hid_dim, n_heads, ffn_dim, norm_before, dropout))
        self.layers = nn.ModuleList(blocks)

        if norm_before:
            # Pre-norm stacks need one final normalisation after the last layer.
            self.post_layer_norm = nn.LayerNorm(hid_dim, eps=1e-6)
        self.fc_out = nn.Linear(hid_dim, num_cls - 1)   # excluding bos
        self.dropout = nn.Dropout(dropout)

    def make_tgt_mask(self, tgt):  # [batch size, tgt len]
        """Combine the padding mask with a causal mask.

        Returns a bool tensor of shape [batch size, 1, tgt len, tgt len] that
        is True where query position i may attend to key position j, i.e.
        j <= i and token j is not padding.
        """
        not_pad = (tgt != self.tgt_pad_idx)[:, None, None, :]   # [batch size, 1, 1, tgt len]
        positions = torch.arange(tgt.shape[1], device=tgt.device)
        # Lower triangle including the diagonal.
        causal = positions[None, :] <= positions[:, None]       # [tgt len, tgt len]
        return not_pad & causal

    def forward(self, tgt, enc_src, src_mask):
        """Decode `tgt` against the encoder output.

        tgt      = [batch size, tgt len]
        enc_src  = [batch size, src len, hid dim]
        src_mask = [batch size, 1, 1, src len]

        Returns (output, attention): output = [batch size, tgt len, output dim];
        attention is whatever the last layer returned as its second value.
        """
        tgt_mask = self.make_tgt_mask(tgt)  # [batch size, 1, tgt len, tgt len]

        pos_embed = self.pos_embedding(tgt.shape[1])
        scaled_tokens = self.tok_embedding(tgt) * self.hid_dim**0.5
        hidden = self.dropout(scaled_tokens + pos_embed)    # [batch size, tgt len, hid dim]

        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, tgt_mask, src_mask)

        if self.norm_before:
            hidden = self.post_layer_norm(hidden)

        return self.fc_out(hidden), attention



# class PointNet(nn.Module):
#     def __init__(self, in_channel=3, out_channel=256):
#         super(PointNet, self).__init__()
#         self.in_feature_dim = in_channel
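#         # A 1x1 convolution changes the number of feature channels (not the spatial size); computationally it is equivalent to a fully connected layer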
#         self.conv1 = torch.nn.Conv1d(self.in_feature_dim, 64, 1)
#         self.conv2 = torch.nn.Conv1d(64, 64, 1)
#         self.conv3 = torch.nn.Conv1d(64, 64, 1)
#         self.conv4 = torch.nn.Conv1d(64, 128, 1)
#         self.conv5 = torch.nn.Conv1d(128, 1024, 1)

#         self.bn1 = nn.BatchNorm1d(64)
#         self.bn2 = nn.BatchNorm1d(64)
#         self.bn3 = nn.BatchNorm1d(64)
#         self.bn4 = nn.BatchNorm1d(128)
#         self.bn5 = nn.BatchNorm1d(1024)

#         self.fc1 = nn.Linear(1024, 512)
#         self.fc2 = nn.Linear(512, 256)
#         self.fc3 = nn.Linear(256, out_channel)

#     def forward(self, pc):   # (B, N, C)
#         pc = pc.transpose(1, 2)  # (B, C, N)
#         x = F.relu(self.bn1(self.conv1(pc)))
#         x = F.relu(self.bn2(self.conv2(x)))
#         x = F.relu(self.bn3(self.conv3(x)))
#         x = F.relu(self.bn4(self.conv4(x)))
#         x = F.relu(self.bn5(self.conv5(x)))
#         x = x.transpose(1, 2).contiguous()   # (B, N, D)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x


class PointNet(nn.Module):
    """Per-frame point-set encoder: shared point-wise convs, max-pool, MLP.

    Every point goes through five kernel-size-1 Conv1d + BatchNorm + ReLU
    stages, the features are max-pooled over the points of a frame, and a
    three-layer MLP maps the pooled 1024-d vector to ``out_channel``.
    """

    def __init__(self, in_channel=3, out_channel=256):
        super(PointNet, self).__init__()
        self.in_feature_dim = in_channel

        # Point-wise feature extractor (kernel size 1 == shared per-point linear map).
        self.conv1 = nn.Conv1d(self.in_feature_dim, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv5 = nn.Conv1d(128, 1024, kernel_size=1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

        # Head applied to the pooled per-frame descriptor.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_channel)

    def forward(self, pc):   # (B, T, N, 3)
        """Map a point sequence (B, T, N, C) to frame features (B, T, out_channel)."""
        B, T = pc.shape[:2]
        # Fold time into the batch and put channels first for Conv1d: (BT, C, N).
        x = pc.reshape(-1, *pc.shape[-2:]).transpose(-1, -2).contiguous()

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))

        # Max over the point axis, then restore the (B, T, ...) layout.
        pooled = x.max(dim=2)[0].reshape(B, T, -1)
        hidden = F.relu(self.fc1(pooled))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
        
        
        
class Seq2Seq(nn.Module):
    def __init__(self, opt):
        """Build the encoder, decoder, CTC head and the speaker / MI auxiliary heads from `opt`."""
        super().__init__()
        self.opt = opt
        hid = opt.hidden_dim
        half_hid = hid // 2
        n_out = opt.tgt_vocab_size - 1  # excluding bos; blank_label == pad_idx == 0

        # NOTE: the plain `Encoder` (same six leading constructor arguments)
        # was used here before being replaced by the Conformer encoder.
        self.encoder = ConformerEncoder(opt.in_channel,
                                        hid,
                                        opt.enc_layers,
                                        opt.head_num,
                                        opt.ffn_ratio,
                                        opt.drop_attn)

        self.decoder = Decoder(num_cls=opt.tgt_vocab_size,
                               hid_dim=hid,
                               n_layers=opt.dec_layers,
                               n_heads=opt.head_num,
                               ffn_ratio=opt.ffn_ratio,
                               dropout=opt.drop_attn)

        self.dropout = nn.Dropout(opt.drop_attn)
        self.ctc_dec = nn.Linear(hid, n_out)

        # Mutual-information estimators used by the stage != 1 auxiliary loss.
        self.max_mi = MINE(hid, hid)
        self.min_mi = MINE(half_hid, hid, norm=True)  # need

        # Speaker branch: pooled visual feature -> embedding -> speaker logits.
        self.spk_fc1 = nn.Linear(hid, half_hid)
        self.spk_fc2 = nn.Linear(half_hid, opt.num_spk)  # unseen-29, overlap-33

        # NOTE: a frozen pretrained lip encoder is currently disabled:
        #   self.hid_proj = nn.Linear(opt.hidden_dim, opt.hidden_dim)
        #   self.lip_enc = PretrainLipModel(opt)
        #   self.lip_enc.load('checkpoints/grid/ckpt30-pid3632310-loss0.0014.pt')
        #   self.lip_enc.requires_grad_(False)

    def forward(self, stage, spk_ids, lips, src_vid, src_pts, src_motion, tgt, src_lens=None, tgt_lens=None):   
        enc_src, src_mask, vid_feat = self.encoder(lips, src_vid, src_pts, src_motion, src_lens)   # [batch size, src len, hid dim]
        vid_feat = self.dropout(vid_feat)
        enc_src = self.dropout(enc_src)
        logits, attention = self.decoder(tgt[:, :-1], enc_src, src_mask)
        attn_loss = F.cross_entropy(logits.transpose(-1, -2).contiguous(), tgt[:, 1:].long(), ignore_index=self.opt.tgt_pad_idx)
        log_probs = self.ctc_dec(enc_src).transpose(0, 1).log_softmax(dim=-1)
        # (T, B, C), (B, S), (B,), (B,)
        ctc_loss = F.ctc_loss(log_probs, tgt[:, 1:], src_lens.reshape(-1), tgt_lens.reshape(-1), zero_infinity=True)
       
        '''
        with torch.no_grad():
            lip_repr = self.lip_enc.get_repr(lips, src_lens)
        aux_loss = F.l1_loss(self.hid_proj(enc_src), lip_repr)
        #aux_loss = F.kl_div(self.hid_proj(enc_src).flatten(0, 1).log_softmax(-1), lip_repr.flatten(0, 1).softmax(-1), reduction='batchmean')
        '''

        vsr_loss = 0.9 * attn_loss + 0.1 * ctc_loss 
        if stage == 1:
            spk_feat = self.spk_fc1(vid_feat.mean(dim=1))   # fc
            #spk_logit = self.spk_fc2(self.dropout(F.relu(spk_feat)))
            #spk_feat = self.gap(self.spk_fc1(vid_feat.transpose(1, 2))).squeeze(-1)   # 1x1 cnn
            spk_logit = self.spk_fc2(self.dropout(spk_feat))
            aux_loss = F.cross_entropy(spk_logit, spk_ids)
            
            #spk_feat = self.spk_fc1(vid_feat.transpose(1, 2)).transpose(1, 2)
            #spk_logit = self.spk_fc2(self.dropout(spk_feat))
            #loss = F.cross_entropy(spk_logit.transpose(1, 2), spk_ids[:, None].expand(spk_logit.shape[:2]), reduction='none')
            #mask = (torch.arange(loss.shape[1]).unsqueeze(0).to(src_lens.device) < src_lens.unsqueeze(1)).float()   # (B, N)
            #aux_loss = (loss * mask).sum() / mask.sum()
            
            #spk_loss = F.cross_entropy(spk_logit, spk_ids)
            #loss = vsr_loss + 0.2 * spk_loss
            #loss = vsr_loss + 0.2 * spk_loss + 0.2 * aux_loss 
        else:  # fix speaker weights
            #with torch.no_grad():
            spk_feat = self.spk_fc1(vid_feat.mean(dim=1))
            #spk_feat = self.gap(self.spk_fc1(vid_feat.transpose(1, 2))).squeeze(-1)
            #spk_logit = self.spk_fc2(self.dropout(spk_feat))
            #aux_loss1 = F.cross_entropy(spk_logit, spk_ids)
            
            #mi_loss = self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) 
            #mi_loss = self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) - self.min_mi(spk_feat, vid_feat.mean(1)) 
            #loss = vsr_loss + 0.2 * mi_loss  
            
            #aux_loss = self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) 
            aux_loss = F.relu(self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) + 0.2)
            #aux_loss = aux_loss1 + aux_loss2

        ''' 
        if stage == 1:
            loss = 0.9 * loss + 0.1 * ctc_loss   # 固定spk
        elif stage == 2:
            spk_feat = self.spk_fc1(vid_feat.mean(dim=1))
            spk_logit = self.spk_fc2(spk_feat)
            spk_loss = F.cross_entropy(spk_logit, spk_ids)
            loss = spk_loss    # 固定vsr encoder
        else:
            spk_feat = self.spk_fc1(vid_feat.mean(dim=1))
            mi_loss = 0.2 * self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - 0.3 * self.min_mi(spk_feat, enc_src.mean(1)) 
            #mi_loss = self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) - self.min_mi(spk_feat, vid_feat.mean(1)) 
            loss = 0.9 * loss + 0.1 * ctc_loss + mi_loss   # 固定spk
        '''

        #spk_feat = self.spk_fc1(vid_feat.mean(dim=1))
        #spk_logit = self.spk_fc2(spk_feat)
        #spk_loss = F.cross_entropy(spk_logit, spk_ids)
        #mi_loss = self.max_mi(vid_feat.mean(1), enc_src.mean(1)) - self.min_mi(spk_feat, enc_src.mean(1)) 
        #loss = 0.9 * loss + 0.1 * ctc_loss
        #loss = 0.9 * loss + 0.1 * ctc_loss + 0.2 * spk_loss + 0.2 * mi_loss
        return vsr_loss, aux_loss


    def greedy_decoding(self, src_vid, src_pts, src_motion, src_lens, bos_id, eos_id, pad_id=0):  # (bs, src len)
        print('greedy decoding ......')
        tgt_len = self.opt.max_dec_len
        bs = src_vid.shape[0]
        tgt_align = torch.tensor([[bos_id]] * bs).to(src_vid.device)  # (bs, 1)
        with torch.no_grad():
            enc_src, src_mask = self.encoder(src_vid, src_pts, src_motion, src_lens)
            ctc_score = self.ctc_dec(enc_src)
           
            for t in range(tgt_len):
                attn_score, _ = self.decoder(tgt_align, enc_src, src_mask)
                pred = attn_score.argmax(dim=-1)[:, -1]  # (bs, tgt_len) -> (bs, )   # greedy decoding
                # pred = self.topp_decoding(output[:, -1], top_p=0.96)
                # pred = self.topk_decoding(output[:, -1])
                if pred.cpu().tolist() == [eos_id] * bs or pred.cpu().tolist() == [pad_id] * bs:
                    break
                tgt_align = torch.cat((tgt_align, pred.unsqueeze(1)), dim=1).contiguous()
            
            L = min(ctc_score.shape[1], attn_score.shape[1])
            score = 0.1 * ctc_score[:, :L] + 0.9 * attn_score[:, :L]
        dec_pred = score.argmax(dim=-1)[1:].detach().cpu().numpy()
        #dec_pred = tgt_align[:, 1:].detach().cpu().numpy()  # (bs, tgt_len)
        return dec_pred

    # def greedy_decoding(self, src_inp, bos_id, eos_id, pad_id=0):  # (bs, src len)
    #     tgt_len = self.opt.max_dec_len
    #     bs = src_inp.shape[0]
    #     res = torch.zeros((bs, tgt_len)).to(src_inp.device)
    #     tgt_align = torch.zeros((bs, tgt_len), dtype=torch.long).to(src_inp.device)
    #     tgt_align[:, 0] = bos_id
    #     with torch.no_grad():
    #         enc_src, src_mask = self.encoder(src_inp)
    #         for t in range(tgt_len):
    #             output, attn = self.decoder(tgt_align, enc_src, src_mask)
    #             pred = output.argmax(dim=-1)
    #             if pred[:, t].cpu().tolist() == [eos_id] * bs or pred[:, t].cpu().tolist() == [pad_id] * bs:
    #                 break
    #             res[:, t] = pred[:, t]
    #             if t < tgt_len - 1:
    #                 tgt_align[:, t + 1] = pred[:, t]
    
    #     dec_pred = res.detach().cpu().numpy()  # (bs, tgt_len)
    #     return dec_pred

    
    def topk_decoding(self, next_token_logits, k=3, T=0.7):
        """Sample one token per row from the `k` highest-scoring candidates.

        Args:
            next_token_logits: (batch, vocab) scores; left unmodified.
            k: number of candidates kept per row (clamped to the vocab size).
            T: softmax temperature applied after filtering.

        Returns:
            (batch,) tensor of sampled token indices.
        """
        k = min(k, next_token_logits.size(-1))
        # Score of the k-th best token in each row; anything below it is dropped.
        kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
        # masked_fill is out-of-place, so the caller's tensor is not overwritten.
        filtered = next_token_logits.masked_fill(next_token_logits < kth_best, -float("Inf"))
        probs = nn.functional.softmax(filtered / T, dim=-1)
        # multinomial draws indices according to the given weights.
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_token
        

    def topp_decoding(self, scores, top_p=0.95, filter_value=-float("Inf")):
        """Nucleus-style sampling: sample among the top tokens within `top_p`.

        Tokens are sorted by score; a token is dropped when the cumulative
        probability up to and including it exceeds `top_p`. The best token of
        every row is always kept, so a row can never end up fully masked.

        Args:
            scores: (batch, vocab) logits; not modified.
            top_p: cumulative-probability threshold.
            filter_value: score assigned to dropped tokens before the softmax.

        Returns:
            (batch,) tensor of sampled token indices.
        """
        # sorted_indices[b][v] is the location of sorted_logits[b][v] in scores[b]
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)

        # prefix sums of the sorted probabilities, shape (batch, vocab_size)
        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Always keep the top-1 token per row. This replaces the former
        # whole-batch argmax fallback, which missed batches where only some
        # rows were fully masked (those rows became NaN after the softmax).
        sorted_indices_to_remove[..., 0] = False

        # scatter sorted tensors to original indexing
        # indices_to_remove[b][sorted_indices[b][v]] = sorted_indices_to_remove[b][v]
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # fill in filter_value in the positions we don't want to sample
        # so that they have probability 0 after the softmax
        scores = scores.masked_fill(indices_to_remove, filter_value)
        probs = torch.nn.functional.softmax(scores, dim=-1)   # (B, V)
        tokens = torch.multinomial(probs, num_samples=1)   # sample 1 for each from batch
        return tokens.squeeze(1)
    
     
    def beam_search_decoding(self, lips, src_vid, src_pts, src_motion, src_lens, bos_id, eos_id, max_dec_len=80, pad_id=0, mode='attn'):
        assert mode in ['ctc', 'attn']
        if mode == 'ctc':
            with torch.no_grad():
                enc_src, src_mask = self.encoder(lips, src_vid, src_pts, src_motion, src_lens)[:2]   # [batch size, src len, hid dim]
                output = self.ctc_dec(enc_src)
                res = output.argmax(dim=-1)
        else:
            with torch.no_grad():
                enc_src, src_mask, _ = self.encoder(lips, src_vid, src_pts, src_motion, src_lens)
                #spk_feat = self.spk_fc2(self.spk_fc1(vid_feat.mean(dim=1)))  # (B, D)
                #np.savetxt('spk0.txt', spk_feat.detach().cpu().numpy(), fmt='%.4f')
                res = beam_decode(self.decoder, enc_src, src_mask, bos_id, eos_id, max_output_length=max_dec_len, beam_size=6)
        return res.detach().cpu().numpy()
    
    
    ''' 
    # rescoring
    def beam_search_decoding(self, src_vid, src_pts, src_motion, src_lens, bos_id, eos_id, max_dec_len=100, pad_id=0):
        res = []
        with torch.no_grad():
            enc_src, src_mask, _ = self.encoder(src_vid, src_pts, src_motion, src_lens)
            # 用CTC对beam search解码结果进行打分
            ctc_scores = self.ctc_dec(enc_src)
            beam_outs, beam_scores = beam_decode(self.decoder, enc_src, src_mask, bos_id, eos_id, beam_size=10, n_best=10)
            print(beam_outs[0], beam_scores[0])
            # bug: ctc_score中包含blank标签的得分
            beam_scores = 0.1 * ctc_scores + 0.9 * beam_scores
            best_pred_idx = beam_scores.argmax(dim=-1).cpu().numpy()
            res = beam_outs[best_pred_idx].cpu().numpy().tolist()
        return res
    '''
