import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.cuda.amp import autocast
import random
from typing import Union, List
from .rnns import GRU
from .attn import AttentiveAttention, MLPAttention


# https://github.com/prajwalkr/vtp/blob/master/modules.py
class Conv3d(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU unit with an optional identity skip.

    Adapted from https://github.com/prajwalkr/vtp (modules.py).

    Args:
        cin: number of input channels.
        cout: number of output channels.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv3d``.
        bias: whether the convolution has a bias term.
        residual: if True, the block input is added to the normalised
            convolution output before the ReLU. That addition only works when
            the output shape equals the input shape (``cin == cout`` and a
            shape-preserving stride/padding); this class does not check it.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, bias=True, residual=False):
        super().__init__()
        conv = nn.Conv3d(cin, cout, kernel_size, stride, padding, bias=bias)
        norm = nn.BatchNorm3d(cout)
        # Keep a Sequential so state_dict keys stay ``conv_block.0.*`` (conv)
        # and ``conv_block.1.*`` (batch norm).
        self.conv_block = nn.Sequential(conv, norm)
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        y = self.conv_block(x)
        if self.residual:
            # Identity skip connection, applied before the non-linearity.
            y = y + x
        return self.act(y)


class CNN_3d(nn.Module):
    """3D-CNN face encoder mapping 96x96 face crops to one vector per frame.

    Only the stem mixes information across time (temporal kernel 5, padding 2,
    stride 1, so the frame count T is preserved). Every later layer uses a
    (1, 3, 3) kernel. The spatial size goes 96 -> 48 -> 24 -> 12 -> 6 -> 3
    through the stride-2 stages, then to 1 through the final unpadded 3x3 conv.

    Adapted from https://github.com/prajwalkr/vtp (modules.py).

    Args:
        d_model: output channels of the last layer, i.e. the per-frame
            embedding size.
    """

    def __init__(self, d_model):
        super().__init__()

        def down(cin, cout):
            # Stride-2 stage that halves H and W.
            return Conv3d(cin, cout, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        def res(ch):
            # Shape-preserving residual stage.
            return Conv3d(ch, ch, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True)

        # Layer order determines the state_dict indices (encoder.0 ... encoder.10)
        # used when loading checkpoints, so it must not change.
        self.encoder = nn.Sequential(
            Conv3d(3, 64, kernel_size=5, stride=(1, 2, 2), padding=2),  # 48 x 48
            down(64, 128), res(128),                                    # 24 x 24
            down(128, 256), res(256), res(256),                         # 12 x 12
            down(256, 512), res(512), res(512),                         # 6 x 6
            down(512, 512),                                             # 3 x 3
            Conv3d(512, d_model, kernel_size=(1, 3, 3), stride=1, padding=(0, 0, 0)),  # 1 x 1
        )

    @autocast()
    def forward(self, faces, mask):
        """Encode a batch of face clips.

        Args:
            faces: 5-D tensor in ``nn.Conv3d`` (N, C, D, H, W) layout, i.e.
                (B, 3, T, 96, 96); H and W are asserted to be 96.
            mask: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        assert faces.size(3) == 96
        assert faces.size(4) == 96
        feats = self.encoder(faces)            # (B, d_model, T, 1, 1)
        feats = feats.squeeze(-1).squeeze(-1)  # (B, d_model, T)
        return feats.transpose(1, 2)           # (B, T, d_model)


class CNN_3d_featextractor(nn.Module):
    """Truncated ``CNN_3d`` front end that returns a spatial feature map.

    The stack always starts with the 5x5x5 stride-(1, 2, 2) stem. It then adds
    stages until the spatial resolution reaches ``till`` (for 96x96 input):

        till   channels   output H x W
        48     64         48 x 48
        24     128        24 x 24
        12     256        12 x 12
        6      512        6 x 6

    Args:
        d_model: accepted for signature parity with ``CNN_3d``; not used.
        till: target spatial resolution; one of 48, 24, 12 or 6.

    Raises:
        ValueError: if ``till`` is not a supported resolution. Previously any
            other value silently built a stack whose output resolution did not
            match ``till`` (e.g. ``till=3`` produced a 12 x 12 map).
    """

    _SUPPORTED_TILL = (48, 24, 12, 6)

    def __init__(self, d_model, till):
        super().__init__()
        if till not in self._SUPPORTED_TILL:
            raise ValueError(f"till must be one of {self._SUPPORTED_TILL}, got {till!r}")

        layers = [Conv3d(3, 64, kernel_size=5, stride=(1, 2, 2), padding=2)]  # 48, 48
        if till <= 24:
            layers.extend([Conv3d(64, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),  # 24, 24
                           Conv3d(128, 128, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True), ])
        if till <= 12:
            layers.extend([Conv3d(128, 256, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),  # 12, 12
                           Conv3d(256, 256, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True),
                           Conv3d(256, 256, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True), ])
        if till <= 6:
            layers.extend([Conv3d(256, 512, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),  # 6, 6
                           Conv3d(512, 512, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True),
                           Conv3d(512, 512, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), residual=True), ])

        # Layer order fixes the state_dict indices (encoder.0, encoder.1, ...).
        self.encoder = nn.Sequential(*layers)

    @autocast()
    def forward(self, faces, mask):
        """Return the feature map for a batch of face clips.

        Args:
            faces: 5-D tensor in ``nn.Conv3d`` (N, C, D, H, W) layout, i.e.
                (B, 3, T, 96, 96); H and W are asserted to be 96.
            mask: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (B, C, T, till, till), with C as in the class table.
        """
        assert faces.size(3) == 96
        assert faces.size(4) == 96
        face_embeddings = self.encoder(faces)  # (B, C, T, H, W)
        return face_embeddings


# ==========================================================================================

class TextEncoder(nn.Module):
    """Token embedding followed by a bidirectional ``GRU`` (from ``.rnns``).

    Args:
        num_embeddings: size of the embedding table (vocabulary size).
        emb_dim: embedding dimension, also the GRU input size.
        hidden_dim: GRU hidden size passed to ``GRU``.
        drop_rate: dropout probability applied to the embeddings; only active
            while the module is in training mode.
    """

    def __init__(self, num_embeddings, emb_dim, hidden_dim, drop_rate=0.):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, emb_dim)
        self.gru = GRU(emb_dim, hidden_dim, bidirectional=True)
        self.drop_rate = drop_rate

    def forward(self, src, src_mask=None):
        """Encode token ids ``src``; ``src_mask`` is forwarded to the GRU as
        ``non_pad_mask``. Returns the GRU's per-step outputs (the final hidden
        state is discarded)."""
        tokens = self.embedding(src)
        tokens = F.dropout(tokens, p=self.drop_rate, training=self.training)
        outputs, _final_hidden = self.gru(tokens, non_pad_mask=src_mask)
        return outputs


class VisualEncoder(nn.Module):
    """3D-CNN front end followed by a bidirectional GRU over video frames.

    Args:
        in_channel: number of input image channels.
        init_dim: output channels of the 3D convolution.
        out_dim: GRU hidden size (per direction).
        num_layers: number of GRU layers.
        drop_rate: stored as ``self.drop_rate``; not used by this module.
        cnn_feat_dim: GRU input size, i.e. the flattened per-frame CNN feature
            size ``init_dim * H' * W'``, where ``H' x W'`` is the spatial size
            after ``stcnn``. The default 32768 is the value previously
            hard-coded here; it must match the input resolution in use.
    """

    def __init__(self, in_channel=1, init_dim=64, out_dim=256, num_layers=1, drop_rate=0.2,
                 cnn_feat_dim=32768):
        super(VisualEncoder, self).__init__()
        self.drop_rate = drop_rate
        self.num_layers = num_layers
        # 3D-CNN: (5, 7, 7) conv with spatial stride 2, then a spatial stride-2
        # max-pool; the temporal length is preserved by both.
        self.stcnn = nn.Sequential(OrderedDict([
            ('conv', nn.Conv3d(in_channel, init_dim, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3),
                               bias=False)),
            ('norm', nn.BatchNorm3d(init_dim)),
            ('relu', nn.ReLU(inplace=True)),
            ('pool', nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))]))

        # ResNet layers intentionally omitted.

        self.gru = GRU(cnn_feat_dim, out_dim, num_layers=num_layers, bidirectional=True, batch_first=True)

    @staticmethod
    def _merge_directions(hn, num_layers):
        """Concatenate each sample's per-direction final states, layer by layer.

        ``hn`` is assumed to follow the ``torch.nn.GRU`` layout
        ``(num_layers * num_directions, B, H)`` with the direction index
        varying fastest (NOTE(review): confirm for ``.rnns.GRU``). Returns
        ``(num_layers, B, num_directions * H)``.

        A plain ``hn.reshape(num_layers, B, -1)`` would be wrong here: it
        interleaves memory across batch elements, so sample 0 would receive
        the forward states of samples 0 and 1.
        """
        stacked, batch, hidden = hn.shape
        num_dirs = stacked // num_layers
        hn = hn.reshape(num_layers, num_dirs, batch, hidden)
        # Move the direction axis next to the feature axis so that each
        # sample's forward and backward states end up side by side.
        return hn.transpose(1, 2).reshape(num_layers, batch, num_dirs * hidden)

    def forward(self, x):
        """Encode a padded batch of frame sequences.

        Args:
            x: frames laid out as (B, T, C, H, W).

        Returns:
            ``(out, hn, mask)``:
            - ``out`` is the GRU output sequence.
            - ``hn`` is the final hidden state reshaped to
              (num_layers, B, 2 * out_dim).
            - ``mask`` is a (B, T) bool tensor that is False for padding frames.
        """
        # A frame counts as padding when the absolute sum of its values is 0.
        # A real frame whose values happen to sum to exactly 0 (e.g. after mean
        # normalisation) would also be masked out.
        mask = torch.abs(torch.sum(x, dim=(-1, -2, -3))) > 0  # (B, T)
        x = x.transpose(1, 2).contiguous()  # (B, C, T, H, W) for nn.Conv3d
        cnn = self.stcnn(x)  # (B, C', T, H', W')
        cnn = cnn.permute(0, 2, 1, 3, 4).contiguous()  # (B, T, C', H', W')
        batch, seq = cnn.size(0), cnn.size(1)
        cnn = cnn.reshape(batch, seq, -1)  # (B, T, C' * H' * W')
        out, hn = self.gru(cnn, non_pad_mask=mask)  # hn: (n_layer * n_direct, B, D)
        hn = self._merge_directions(hn, self.num_layers)
        return out, hn, mask


class TextDecoder(nn.Module):
    """Single-step attention GRU decoder.

    Each step embeds the current token and attends over the encoder states
    using the top-layer decoder state. It feeds ``[embedding; context]`` to the
    GRU and projects the GRU output to vocabulary logits.

    Args:
        vocab_size: number of output tokens (embedding rows and logit size).
        dec_embed_size: token embedding dimension.
        hidden_size: GRU hidden size; the attention context is assumed to have
            this size too, since the GRU input is ``dec_embed_size + hidden_size``.
        num_layers: number of GRU layers.
        drop_rate: passed to the GRU as ``dropout``.
    """

    def __init__(self, vocab_size, dec_embed_size, hidden_size, num_layers=1, drop_rate=0.2):
        super(TextDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, dec_embed_size)
        self.attn_layer = MLPAttention(hidden_size)
        self.gru = GRU(dec_embed_size + hidden_size, hidden_size, num_layers=num_layers, dropout=drop_rate, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def forward(self, cur_input, state, enc_states, enc_mask=None):
        """Run one decoding step.

        Args:
            cur_input: token ids of shape (batch,).
            state: decoder hidden state, (num_layers * num_directions, batch,
                hidden_size); ``state[-1]`` is used as the attention query.
            enc_states: encoder outputs passed to the attention layer.
            enc_mask: optional mask passed to the attention layer.

        Returns:
            ``(logits, new_state)`` with logits of shape (batch, vocab_size).
        """
        tok_emb = self.embedding(cur_input)
        context = self.attn_layer(state[-1], enc_states, enc_mask)
        step_in = torch.cat([tok_emb, context], dim=1).contiguous()    # (B, D)
        rnn_out, new_state = self.gru(step_in.unsqueeze(1), state)    # (B, 1, H)
        logits = self.fc_out(rnn_out.squeeze(dim=1))                   # (B, V)
        # Alternative that was tried: project [rnn_out; step_in] instead of rnn_out alone.
        return logits, new_state


# class TextDecoder(nn.Module):
#     def __init__(self, vocab_size, dec_embed_size, hidden_size, attn_size, num_layers=1, drop_rate=0.):
#         super(TextDecoder, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, dec_embed_size)
#         self.attn_model = self.attn_layer(2 * hidden_size, attn_size)
#         # The GRU input is the attention context vector c concatenated with the actual (embedded) input token
#         self.gru = GRU(dec_embed_size + hidden_size, hidden_size, num_layers, dropout=drop_rate, batch_first=True)
#         self.fc_out = nn.Linear(hidden_size, vocab_size)
#
#     def attn_layer(self, input_size, attn_size):
#         return nn.Sequential(nn.Linear(input_size, attn_size, bias=False),
#                              nn.Tanh(),
#                              nn.Linear(attn_size, 1, bias=False))
#
#     def attn_forward(self, enc_states, dec_state, src_mask=None):
#         """
#         model: the model returned by the attention_model function
#         enc_states: encoder outputs, shape (batch_size, seq_len, hidden_dim)
#         dec_state: decoder output for a single time step, shape (batch_size, hidden_dim)
#         """
#         # (batch_size, 1, hidden_dim) -> (batch_size, seq_len, hidden_dim)
#         enc_dec_states = torch.cat((enc_states, dec_state.unsqueeze(dim=1).expand_as(enc_states)), dim=2)  # (batch_size, seq_len, 2*hidden_dim)
#         e = self.attn_model(enc_dec_states).squeeze(-1)  # (batch_size, seq_len, 1) -> (batch_size, seq_len)
#         if src_mask is not None:
#             e = e.masked_fill(src_mask == 0, -1e9)
#         alpha = F.softmax(e, dim=1)
#         return (alpha.unsqueeze(-1) * enc_states).sum(dim=1)  # context vector  (batch_size, hidden_dim)
#
#     def forward(self, cur_input, state, enc_states, enc_mask=None):
#         """
#          cur_input shape: (batch, )
#          state shape: (num_layers, batch, num_hiddens)
#          """
#         embedded = self.embedding(cur_input)
#         ctx = self.attn_forward(enc_states, state[-1], enc_mask)
#         input_and_ctx = torch.cat((embedded, ctx), dim=1)  # (batch_size, 2*embed_size)
#         output, state = self.gru(input_and_ctx.unsqueeze(1), state)  # (batch_size, 1, 2*embed_size)
#         output = self.fc_out(output.squeeze(dim=1))  # (batch_size, vocab_size)
#         # output = self.fc_out(torch.cat((output.squeeze(dim=1), input_and_ctx), dim=-1))  # (batch_size, vocab_size)
#         return output, state



class Seq2Seq(nn.Module):
    """Video-to-text model: ``VisualEncoder`` followed by an attention ``TextDecoder``.

    ``opt`` must provide:
    - ``in_channel``, ``enc_layers``, ``hidden_dim``, ``dropout``,
      ``tgt_vocab_size`` and ``dec_layers`` (read in ``__init__``);
    - ``teacher_forcing_ratio`` and ``tgt_pad_idx`` (read in ``forward``);
    - ``max_dec_len`` (read in ``greedy_decoding``).
    """

    def __init__(self, opt):
        super(Seq2Seq, self).__init__()
        self.encoder = VisualEncoder(in_channel=opt.in_channel,
                                     num_layers=opt.enc_layers,
                                     out_dim=opt.hidden_dim,
                                     drop_rate=opt.dropout)

        # The encoder GRU is bidirectional, so its merged final state is
        # 2 * hidden_dim wide; the decoder's hidden size matches that.
        self.decoder = TextDecoder(vocab_size=opt.tgt_vocab_size,
                                   dec_embed_size=opt.hidden_dim,
                                   hidden_size=2*opt.hidden_dim,
                                   num_layers=opt.dec_layers,
                                   drop_rate=opt.dropout)
        self.opt = opt

    def forward(self, src, tgt=None):
        """Decode against ``tgt`` with scheduled teacher forcing and return the loss.

        Args:
            src: visual input passed straight to ``self.encoder``.
            tgt: target token ids of shape (B, L). Column 0 is fed as the
                initial decoder input (start token); columns 1..L-1 are the
                prediction targets. Required despite the ``None`` default,
                which is kept only for signature compatibility.

        Returns:
            ``(loss, dec_outs)``:
            - ``dec_outs`` holds (B, L - 1, V) logits predicting ``tgt[:, 1:]``.
            - ``loss`` is their cross-entropy, ignoring ``opt.tgt_pad_idx``.

        Raises:
            ValueError: if ``tgt`` is None or has fewer than two time steps.
                Previously these cases crashed with an opaque TypeError or
                RuntimeError.
        """
        if tgt is None:
            raise ValueError("Seq2Seq.forward requires `tgt`; use greedy_decoding() for inference")
        tgt_len = tgt.size(1)
        if tgt_len < 2:
            raise ValueError(f"tgt must have at least 2 time steps (start token + 1 target), got {tgt_len}")

        enc_outputs, enc_state, enc_mask = self.encoder(src)
        dec_state = enc_state
        dec_outputs = []
        dec_input = tgt[:, 0]
        for t in range(1, tgt_len):
            dec_output_t, dec_state = self.decoder(dec_input, dec_state, enc_outputs, enc_mask)
            dec_outputs.append(dec_output_t.unsqueeze(1))   # (B, V) -> (B, 1, V)
            # Feed the ground-truth token with probability
            # opt.teacher_forcing_ratio, otherwise the model's own greedy
            # prediction. This applies in both train and eval mode.
            if random.random() < self.opt.teacher_forcing_ratio:
                dec_input = tgt[:, t]
            else:
                dec_input = dec_output_t.argmax(dim=-1)

        dec_outs = torch.cat(dec_outputs, dim=1).contiguous()   # (B, L-1, V)
        # cross_entropy expects class scores on dim 1: (B, V, L-1) vs targets (B, L-1).
        loss = F.cross_entropy(dec_outs.transpose(-1, -2), tgt[:, 1:].long(), ignore_index=self.opt.tgt_pad_idx)
        return loss, dec_outs

    def greedy_decoding(self, src_inp, bos_id, eos_id, pad_id=0):
        """Greedy (argmax) decoding without gradients.

        Decoding starts from ``bos_id`` and runs at most ``opt.max_dec_len``
        steps. It stops early at the first step where every sequence predicts
        ``eos_id``, or every sequence predicts ``pad_id``; that step's tokens
        are not included. A single sequence may therefore contain EOS followed
        by further tokens; callers should truncate per row.

        Returns:
            NumPy integer array of shape (B, N) with N <= ``opt.max_dec_len``.
            N is 0 when the very first step stops decoding; previously that
            case crashed in ``torch.cat``.
        """
        max_len = self.opt.max_dec_len
        bs = src_inp.size(0)
        dec_preds = []
        dec_input = torch.full((bs,), bos_id, dtype=torch.long, device=src_inp.device)
        with torch.no_grad():
            enc_outputs, enc_state, enc_mask = self.encoder(src_inp)
            dec_state = enc_state
            for _ in range(max_len):
                dec_output_t, dec_state = self.decoder(dec_input, dec_state, enc_outputs, enc_mask)
                pred = dec_output_t.argmax(dim=-1)  # (B, )
                if bool((pred == eos_id).all()) or bool((pred == pad_id).all()):
                    break
                dec_preds.append(pred.unsqueeze(1))
                dec_input = pred
        if not dec_preds:
            return torch.empty((bs, 0), dtype=torch.long).numpy()
        dec_pred = torch.cat(dec_preds, dim=1)  # (B, N)
        return dec_pred.detach().cpu().numpy()


# def train(train_loader, opt=None):
#     model = Seq2Seq(opt)
#     param_nums = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f'The model has {param_nums} trainable parameters')
#     optimizer_parameters = [
#         {'params': [p for p in model.encoder.parameters() if p.requires_grad],
#          'weight_decay': opt.weight_decay, 'lr': opt.enc_lr},
#         {'params': [p for p in model.decoder.parameters() if p.requires_grad],
#          'weight_decay': opt.weight_decay, 'lr': opt.dec_lr}]
#     optimizer = torch.optim.AdamW(optimizer_parameters, lr=opt.enc_lr, eps=opt.eps)
#     model.train()
#     ep_loss = 0.
#     for i, batch in enumerate(train_loader):
#         optimizer.zero_grad()
#         r = model(batch.src, batch.trg)
#         r.loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
#         optimizer.step()
#         ep_loss += r.loss_value
#     return ep_loss / len(train_loader)
# 
# 
# def evaluate(test_loader, ckpt_path='save.pt', opt=None):
#     model = Seq2Seq(opt)
#     model.load_state_dict(torch.load(ckpt_path))
#     model.eval()
#     outputs = []
#     with torch.no_grad():
#         for i, batch in enumerate(test_loader):
#             decoder_predict = model.greedy_search_decode(batch.src)
#             outputs.extend(decoder_predict)
#     return outputs