import torch
import torch.nn as nn
import torch.functional as F
from .st_gcn import STModel

# Adapted from: https://github.com/TengdaHan/DPC


class DPC_RNN_Pretrain(nn.Module):
    """Dense Predictive Coding (DPC) pretraining model on an ST-GCN encoder.

    Each of the ``N`` clips of a sample is encoded independently by
    ``STModel``. A GRU aggregates the (rectified) features of the first
    ``N - pred_step`` clips, then autoregressively predicts the features of
    the last ``pred_step`` clips, feeding each rectified prediction back in.
    ``forward`` returns dot-product scores between every prediction and every
    ground-truth future feature in the batch, plus a label mask.

    Args:
        num_seq: Clips per sample (stored only; ``forward`` reads N from input).
        seq_len: Frames per clip (stored only).
        pred_step: Number of future clips to predict.
        in_channels, hidden_channels, hidden_dim, num_class, dropout,
        edge_importance_weighting, **kwargs: Forwarded to ``STModel``.
        graph_args: Graph config forwarded to ``STModel``. ``None`` selects
            ``{"layout": "mediapipe-27", "strategy": "spatial"}``.
    """

    def __init__(
        self,
        num_seq=7,
        seq_len=10,
        pred_step=3,
        in_channels=2,
        hidden_channels=64,
        hidden_dim=256,
        num_class=60,
        dropout=0.5,
        graph_args=None,
        edge_importance_weighting=True,
        **kwargs
    ):
        super(DPC_RNN_Pretrain, self).__init__()
        if graph_args is None:
            # Fresh dict per instance: a shared dict default would leak any
            # downstream mutation into every later construction.
            graph_args = {"layout": "mediapipe-27", "strategy": "spatial"}

        self.num_seq = num_seq
        self.seq_len = seq_len
        self.pred_step = pred_step

        self.conv_encoder = STModel(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            hidden_dim=hidden_dim,
            num_class=num_class,
            dropout=dropout,
            graph_args=graph_args,
            edge_importance_weighting=edge_importance_weighting,
            **kwargs
        )

        # NOTE(review): width is fixed at 256 independently of hidden_dim;
        # this assumes STModel emits 256 features per clip — confirm against
        # STModel before changing hidden_dim.
        feature_size = 256
        self.feature_size = feature_size
        self.agg = nn.GRU(feature_size, feature_size, batch_first=True)
        self.network_pred = nn.Sequential(
            nn.Linear(feature_size, feature_size),
            nn.ReLU(inplace=True),
            nn.Linear(feature_size, feature_size),
        )

        self.relu = nn.ReLU(inplace=False)
        # Cached label mask; rebuilt in forward when batch size/device changes.
        self.mask = None
        self._initialize_weights(self.agg)
        self._initialize_weights(self.network_pred)

    def forward(self, block):
        """Predict future clip features and score them against the truth.

        Args:
            block: Tensor laid out as (B, N, T, V, C); N must exceed
                ``pred_step`` so at least one clip serves as context.

        Returns:
            ``[score, mask]``, both shaped (B, pred_step, B, pred_step).
            ``score[b, i, c, j]`` is the dot product of prediction ``i`` of
            sample ``b`` with the pre-ReLU encoder feature of future clip
            ``j`` of sample ``c``. ``mask`` (int8) is 1 for positives
            (b == c, i == j), -1 for same-sample other steps, 0 otherwise.

        Raises:
            ValueError: if N <= pred_step.
        """
        block = block.permute(0, 1, 4, 2, 3)  # B, N, T, V, C -> B, N, C, T, V
        B, N, C, T, V = block.shape
        if N <= self.pred_step:
            raise ValueError(
                f"need more than pred_step={self.pred_step} clips per sample, got {N}"
            )
        # reshape (not view) also accepts a non-contiguous input tensor.
        feature = self.conv_encoder(block.reshape(B * N, C, T, V))

        # Targets are the raw (pre-ReLU) features of the last pred_step clips.
        feature_inf_all = feature.view(B, N, self.feature_size)
        feature_inf = feature_inf_all[:, N - self.pred_step :, :]
        feature = self.relu(feature).view(B, N, self.feature_size)  # [0, +inf)

        ### aggregate context, predict future ###
        _, hidden = self.agg(feature[:, : N - self.pred_step, :].contiguous())
        hidden = hidden[-1]  # last GRU layer -> (B, feature_size)

        pred = []
        for _step in range(self.pred_step):
            # Sequentially predict the next clip, then feed it back in.
            p_tmp = self.network_pred(hidden)
            pred.append(p_tmp)
            _, hidden = self.agg(self.relu(p_tmp).unsqueeze(1), hidden.unsqueeze(0))
            hidden = hidden[-1]
        pred = torch.stack(pred, 1)  # (B, pred_step, feature_size)

        score = self._similarity_score(pred, feature_inf)

        if (
            self.mask is None
            or self.mask.shape != score.shape
            or self.mask.device != score.device
        ):
            self.mask = self._build_mask(B, self.pred_step, device=score.device)

        return [score, self.mask]

    @staticmethod
    def _similarity_score(pred, target):
        """Dot every predicted vector with every target vector.

        ``pred`` is (B, P, D) and ``target`` is (Bt, N, D); returns
        (B, P, Bt, N) with ``out[b, i, c, j] = <pred[b, i], target[c, j]>``.
        The feature dim must stay last when flattening so each row is one
        whole feature vector.
        """
        B, P, D = pred.shape
        Bt, N, _ = target.shape
        flat_pred = pred.reshape(B * P, D)
        flat_target = target.reshape(Bt * N, D)
        return torch.matmul(flat_pred, flat_target.t()).view(B, P, Bt, N)

    @staticmethod
    def _build_mask(batch_size, pred_step, device=None):
        """Build the int8 (B, P, B, P) label mask that pairs with the score.

        1 where the sample and step match (positive), -1 where the sample
        matches but the step differs (temporal negative), 0 elsewhere.
        """
        mask = torch.zeros(
            (batch_size, pred_step, batch_size, pred_step),
            dtype=torch.int8,
            device=device,
        )
        batch_idx = torch.arange(batch_size, device=device)
        step_idx = torch.arange(pred_step, device=device)
        mask[batch_idx, :, batch_idx, :] = -1  # temporal neg
        rows = batch_idx.view(-1, 1)
        steps = step_idx.view(1, -1)
        mask[rows, steps, rows, steps] = 1  # pos
        return mask

    def _initialize_weights(self, module):
        """Zero every bias and orthogonally initialise every weight (gain 1)."""
        for name, param in module.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.orthogonal_(param, 1)


class DPC_RNN_Finetune(nn.Module):
    """Sequence classifier reusing the DPC encoder and GRU aggregator.

    Every clip is encoded by ``STModel`` and rectified. A GRU then aggregates
    the clip features. Its output at the last clip is batch-normalised and
    passed through dropout and a linear layer to produce class logits.

    Args:
        num_seq, seq_len, pred_step: Stored only; not read by ``forward``.
        in_channels, hidden_channels, hidden_dim, dropout,
        edge_importance_weighting, **kwargs: Forwarded to ``STModel``.
        num_class: Number of output classes (also forwarded to ``STModel``).
        dropout: Also used before the final linear classifier.
        graph_args: Graph config forwarded to ``STModel``. ``None`` selects
            ``{"layout": "mediapipe-27", "strategy": "spatial"}``.
    """

    def __init__(
        self,
        num_seq=6,
        seq_len=10,
        pred_step=2,
        in_channels=2,
        hidden_channels=64,
        hidden_dim=256,
        num_class=60,
        dropout=0.5,
        graph_args=None,
        edge_importance_weighting=True,
        **kwargs
    ):
        super(DPC_RNN_Finetune, self).__init__()
        if graph_args is None:
            # Fresh dict per instance instead of a shared mutable default.
            graph_args = {"layout": "mediapipe-27", "strategy": "spatial"}

        self.num_seq = num_seq
        self.seq_len = seq_len
        self.pred_step = pred_step
        self.num_class = num_class
        self.conv_encoder = STModel(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            hidden_dim=hidden_dim,
            num_class=num_class,
            dropout=dropout,
            graph_args=graph_args,
            edge_importance_weighting=edge_importance_weighting,
            **kwargs
        )

        # NOTE(review): fixed at 256 regardless of hidden_dim; assumes STModel
        # emits 256 features per clip — confirm before changing hidden_dim.
        feature_size = 256
        self.feature_size = feature_size
        self.agg = nn.GRU(feature_size, feature_size, batch_first=True)

        self.final_bn = nn.BatchNorm1d(self.feature_size)
        self.final_bn.weight.data.fill_(1)
        self.final_bn.bias.data.zero_()

        self.final_fc = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(self.feature_size, self.num_class)
        )

        self._initialize_weights(self.final_fc)

    def forward(self, block):
        """Classify a batch of clip sequences.

        Args:
            block: Tensor laid out as (B, N, T, V, C).

        Returns:
            Logits of shape (B, num_class).
        """
        block = block.permute(0, 1, 4, 2, 3)  # B, N, T, V, C -> B, N, C, T, V
        B, N, C, T, V = block.shape
        # reshape (not view) also accepts a non-contiguous input tensor.
        feature = self.conv_encoder(block.reshape(B * N, C, T, V))
        # torch.relu: the file-level ``F`` alias is ``torch.functional``,
        # which does not provide ``relu``.
        feature = torch.relu(feature)  # [0, +inf)

        feature = feature.view(B, N, self.feature_size)
        ### aggregate over clips, classify from the last step ###
        context, _ = self.agg(feature)
        context = context[:, -1, :]
        context = self.final_bn(context)
        return self.final_fc(context).view(B, self.num_class)

    def _initialize_weights(self, module):
        """Zero every bias and orthogonally initialise every weight (gain 1)."""
        for name, param in module.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.orthogonal_(param, 1)
