import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import numpy as np


class GraphConv2(nn.Module):
    """Graph convolution ``adj @ W(x)`` with optional post-processing and a residual.

    After neighbourhood aggregation the layer applies, in this order and each
    only when enabled: layer normalisation, a learnable bias, the activation
    and dropout.  The input is then added back as a residual connection when
    ``in_feature == out_feature``.  When the two widths differ the residual is
    skipped, because ``x`` and the projected output cannot be summed
    feature-wise (the unconditional sum used to raise a shape error).

    Args:
        in_feature: size of the last dimension of the input features.
        out_feature: size of the last dimension of the output features.
        bias: if True, add a learnable bias initialised to zeros.
        layer_norm: if True, apply ``nn.LayerNorm`` over the output features.
        activation: optional callable applied to the output.
        dropout: dropout probability; dropout only acts in training mode.
    """

    def __init__(self, in_feature, out_feature, bias=False, layer_norm=False, activation=None, dropout=0.0):
        super(GraphConv2, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        self.bias = bias
        self.layer_norm = layer_norm
        self.activation = activation
        self.dropout = dropout

        # The linear projection never carries its own bias; the optional bias
        # is a separate parameter added after the (optional) layer norm.
        self.W = nn.Linear(in_feature, out_feature, bias=False)
        if self.bias:
            self.h_bias = nn.Parameter(torch.zeros(out_feature))

        if self.layer_norm:
            self.layer_norm_w = nn.LayerNorm(out_feature, elementwise_affine=True)

    def forward(self, x, adj):
        '''
        x: (B, N, D)
        adj: (B, N, N)

        Returns a tensor of shape (B, N, out_feature).
        '''
        # (N, N) * (N, H) -> (N, H)
        output = torch.matmul(adj, self.W(x))
        if self.layer_norm:
            output = self.layer_norm_w(output)
        if self.bias:
            output = output + self.h_bias
        if self.activation:
            output = self.activation(output)
        if self.dropout > 0:
            output = F.dropout(output, p=self.dropout, training=self.training)

        # Residual connection only makes sense when the feature widths agree.
        if self.in_feature == self.out_feature:
            return x + output
        return output

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_feature}, {self.out_feature})'


class GCNLayer(nn.Module):
    """Stack of three ``GraphConv2`` layers sharing one adjacency matrix.

    The first two layers map ``in_feature -> in_feature`` with ReLU and
    dropout; the third maps ``in_feature -> out_feature`` with neither.

    Args:
        in_feature: width of the input (and hidden) features.
        out_feature: width of the output features.
        dropout: dropout probability handed to the first two layers.
        add_skip: if True, the block input is added to the block output
            (this needs ``in_feature == out_feature`` for the sum to work).
    """

    def __init__(self, in_feature, out_feature, dropout=0.5, add_skip=False):
        super(GCNLayer, self).__init__()
        relu_conv_kwargs = dict(dropout=dropout)
        self.gcn_layer1 = GraphConv2(in_feature, in_feature, activation=nn.ReLU(), **relu_conv_kwargs)
        self.gcn_layer2 = GraphConv2(in_feature, in_feature, activation=nn.ReLU(), **relu_conv_kwargs)
        self.gcn_layer3 = GraphConv2(in_feature, out_feature)
        # Kept as attributes for introspection; forward() only reads add_skip.
        self.dropout = dropout
        self.add_skip = add_skip

    def forward(self, x, adj):
        """Run the three graph convolutions in sequence on ``x`` with ``adj``."""
        hidden = x
        for conv in (self.gcn_layer1, self.gcn_layer2, self.gcn_layer3):
            hidden = conv(hidden, adj)

        if not self.add_skip:
            return hidden
        return hidden + x


class MultiHeadGCNLayer(nn.Module):
    """Several ``GraphConv2`` heads run in parallel and concatenated.

    Each head maps ``in_feature -> out_feature // nheads`` (ReLU + dropout),
    so the concatenated width is ``nheads * (out_feature // nheads)``.

    Args:
        in_feature: width of the input features.
        out_feature: nominal width of the concatenated output.
        nheads: number of parallel heads.
        dropout: dropout probability handed to every head.
        add_skip: if True, blend input and output through a learned sigmoid
            gate; the blend multiplies the gate with both ``x`` and the
            concatenated output, so their widths have to be compatible.
    """

    def __init__(self, in_feature, out_feature, nheads=4, dropout=0.2, add_skip=False):
        super(MultiHeadGCNLayer, self).__init__()
        head_width = out_feature // nheads
        heads = []
        for _ in range(nheads):
            heads.append(GraphConv2(in_feature, head_width, activation=nn.ReLU(), dropout=dropout))
        self.gcn_layers = nn.ModuleList(heads)
        self.dropout = dropout
        self.add_skip = add_skip
        if self.add_skip:
            # Gate producing per-feature mixing weights from the layer input.
            self.gate_fc = nn.Linear(in_feature, out_feature)
            nn.init.xavier_uniform_(self.gate_fc.weight)

    def forward(self, x, adj):
        """Concatenate all head outputs along the feature axis, optionally gated."""
        head_outputs = [head(x, adj) for head in self.gcn_layers]
        output = torch.cat(tuple(head_outputs), dim=-1).contiguous()
        if not self.add_skip:
            return output

        gate = torch.sigmoid(self.gate_fc(x))
        return (1 - gate) * x + gate * output


# Non-overlapping
class PatchEmbedding(nn.Module):
    """Embed image patches with three strided, non-overlapping convolutions.

    The convolutions use kernel == stride (4, 2, 2), shrinking the spatial
    size by 16x in total while the channels grow
    ``in_chans -> embed_dim // 4 -> embed_dim // 2 -> embed_dim``.  BatchNorm
    and ReLU follow the first two convolutions.  A global max pool then
    collapses the remaining spatial grid to one vector per input.

    Args:
        in_chans: number of input image channels.
        embed_dim: size of the produced embedding vector.
    """

    def __init__(self, in_chans=3, embed_dim=768):
        super().__init__()
        quarter = embed_dim // 4
        half = embed_dim // 2
        stages = [
            nn.Conv2d(in_chans, quarter, kernel_size=4, stride=4),  # H/4 x W/4
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            nn.Conv2d(quarter, half, kernel_size=2, stride=2),      # H/8 x W/8
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(half, embed_dim, kernel_size=2, stride=2),    # H/16 x W/16
        ]
        self.proj = nn.Sequential(*stages)

        # Global max pooling over whatever spatial extent is left.
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Map a batch of images (BTN, C, H, W) to embeddings (BTN, embed_dim)."""
        pooled = self.gmp(self.proj(x))  # (BTN, embed_dim, 1, 1)
        return pooled.flatten(1)


# Overlapping
class EarlyConv(nn.Module):
    """Embed images with three overlapping 3x3, stride-2 convolutions.

    Channels grow ``in_chans -> embed_dim // 4 -> embed_dim // 2 -> embed_dim``
    with BatchNorm and ReLU after the first two convolutions.  A global max
    pool reduces the final feature map to one vector per input.

    Args:
        in_chans: number of input image channels.
        embed_dim: size of the produced embedding vector.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_chans=1, embed_dim=768, bias=True):
        super().__init__()
        quarter = embed_dim // 4
        half = embed_dim // 2
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1, bias=bias)
        stages = [
            nn.Conv2d(in_chans, quarter, **conv_kwargs),
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            nn.Conv2d(quarter, half, **conv_kwargs),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            nn.Conv2d(half, embed_dim, **conv_kwargs),
        ]
        self.conv_layers = nn.Sequential(*stages)
        self.gmp = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Return the globally max-pooled feature vector for each input image."""
        feature_map = self.conv_layers(x)
        return self.gmp(feature_map).flatten(1)


class DynGCN(nn.Module):
    """GCN over patch embeddings with an adjacency computed from the features.

    Patches are embedded (linear layer on flattened patches, or the
    convolutional ``PatchEmbedding``), a softmax-normalised similarity matrix
    between the patch embeddings is used as adjacency, and the GCN output is
    averaged over the patches.

    Args:
        patch_size: patch side length; only used to size the linear
            embedding when ``patch_flatten`` is True.
        in_chans: number of channels per patch.
        embed_dim: embedding width used throughout.
        drop_ratio: accepted but not used by the code in this class.
        patch_flatten: choose the linear (True) or convolutional (False)
            patch embedding.
    """

    def __init__(self, patch_size=16, in_chans=1, embed_dim=256, drop_ratio=0.1, patch_flatten=False):
        super(DynGCN, self).__init__()
        self.embed_dim = embed_dim
        self.patch_flatten = patch_flatten
        if patch_flatten:
            self.patch_embed = nn.Linear(in_chans * patch_size * patch_size, embed_dim)
        else:
            self.patch_embed = PatchEmbedding(in_chans, embed_dim)

        # Projects to twice the embedding width; see get_dynamic_graph().
        self.fc = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
        self.gcn = GCNLayer(embed_dim, embed_dim, dropout=0.2, add_skip=False)

    def get_dynamic_graph(self, x):
        """Return a row-softmaxed similarity matrix between the nodes of ``x``."""
        # NOTE(review): only the first half of the projection is used; the
        # second half is discarded. Confirm whether the similarity was meant
        # to be computed between the two halves instead.
        hn, _ = self.fc(x).chunk(2, dim=-1)
        scale = self.embed_dim ** 0.5
        s = torch.matmul(hn, hn.transpose(-1, -2).contiguous()) / scale
        return F.softmax(s, dim=-1)

    def forward(self, x, pos_embed=None):   # input is landmark patches: (B, T, Np, C, H, W)
        '''
          x: (B, T, N, C, H, W)
          pos_embed: (B*T, N, D)

          Returns a tensor reshaped to (B, T, -1).
        '''
        batch, steps, nodes = x.shape[:3]
        frames = batch * steps

        if self.patch_flatten:
            # [BT, Np, CHW] -> [BT, Np, embed_dim]
            tokens = self.patch_embed(x.reshape(frames, nodes, -1))
        else:
            patches = x.flatten(0, 2).contiguous()  # (BTN, C, H, W)
            tokens = self.patch_embed(patches)      # (BTN, D)
            tokens = tokens.reshape(frames, nodes, -1)

        if pos_embed is not None:
            tokens = tokens + pos_embed

        adjacency = self.get_dynamic_graph(tokens)
        tokens = self.gcn(tokens, adjacency)
        pooled = torch.mean(tokens, dim=1)  # global avg pooling over the nodes
        return pooled.reshape(batch, steps, -1)



# =======================================================================================


class GAT(nn.Module):
    """Stack of ``GATLayer`` blocks followed by mean pooling over the nodes.

    The first layer maps ``input_dim -> output_dim``; each of the remaining
    ``nstack - 1`` layers maps ``output_dim -> output_dim``.

    Args:
        input_dim: feature width of the incoming node features.
        output_dim: feature width produced by every layer.
        nheads: number of attention heads requested for each layer.
        nstack: number of layers applied in ``forward``.
    """

    def __init__(self, input_dim: int, output_dim: int, nheads=4, nstack=3):
        super().__init__()
        self.nstack = nstack
        layer_inputs = [input_dim] + [output_dim] * (nstack - 1)
        self.gat = nn.ModuleList([GATLayer(width, output_dim, nheads) for width in layer_inputs])

    def forward(self, x, pos_embed=None):
        """Apply the layer stack to ``x`` of shape (B, T, N, D).

        ``pos_embed`` is accepted but not used here.  The result is reshaped
        to (B, T, -1).
        """
        batch, steps = x.shape[:2]
        x = x.flatten(0, 1).contiguous()      # (BT, N, D)
        # The attention layers are built from Conv1d, so channels go second.
        x = x.transpose(-1, -2).contiguous()  # (BT, D, N)

        for depth in range(self.nstack):
            # Attention probabilities are returned by each layer but unused.
            x, _ = self.gat[depth](x)

        x = x.transpose(-1, -2).contiguous()  # back to (BT, N, D')
        pooled = torch.mean(x, dim=1)  # global avg pooling
        return pooled.reshape(batch, steps, -1)


class GATLayer(nn.Module):
    """Attention-based message passing with a residual connection.

    When ``input_dim != output_dim`` the residual branch is projected with a
    single 1x1 convolution, and the head count is lowered to the largest
    value not above ``num_heads`` that divides ``input_dim``.

    Args:
        input_dim: channel width of the incoming features.
        output_dim: channel width of the produced features.
        num_heads: requested number of attention heads.
    """

    def __init__(self, input_dim: int, output_dim: int, num_heads=4):
        super().__init__()
        self.reshape = None
        heads = num_heads
        if input_dim != output_dim:
            # Largest head count <= num_heads that evenly divides input_dim.
            divisors = (h for h in range(num_heads, 0, -1) if input_dim % h == 0)
            heads = next(divisors, num_heads)
            self.reshape = MLP([input_dim, output_dim])

        self.attention = MessagePassing(input_dim, heads, output_dim)

    def forward(self, features):
        """Return ``(residual + message, attention probabilities)``."""
        message, prob = self.attention(features)
        residual = features if self.reshape is None else self.reshape(features)
        return residual + message, prob


def MLP(channels: list):
    """Build a point-wise MLP from 1x1 ``Conv1d`` layers.

    Consecutive entries of ``channels`` give the input/output width of each
    convolution.  Every convolution except the last is followed by
    ``BatchNorm1d`` and ``ReLU``.

    Returns:
        An ``nn.Sequential`` holding the layers (empty if fewer than two
        widths are given).
    """
    final_index = len(channels) - 1
    stack = []
    for index, (width_in, width_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
        stack.append(nn.Conv1d(width_in, width_out, kernel_size=1, bias=True))
        if index != final_index:
            stack.extend((nn.BatchNorm1d(width_out), nn.ReLU()))
    return nn.Sequential(*stack)


class MessagePassing(nn.Module):
    """Self-attention message followed by an MLP over ``[features, message]``.

    Args:
        feature_dim: channel width of the incoming features.
        num_heads: number of attention heads.
        out_dim: channel width of the output.  Defaults to ``feature_dim``
            when left as ``None`` (passing ``None`` on to the convolution
            would fail at construction time).
    """

    def __init__(self, feature_dim: int, num_heads: int, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = feature_dim
        self.attn = Attention(num_heads, feature_dim)
        # Input is the channel-wise concatenation of features and message.
        self.mlp = MLP([feature_dim*2, feature_dim*2, out_dim])

    def forward(self, features):
        """Return ``(mlp(cat(features, message)), attention probabilities)``."""
        message, prob = self.attn(features, features, features)
        return self.mlp(torch.cat([features, message], dim=1)), prob


class Attention(nn.Module):
    """Multi-head scaled dot-product attention built from 1x1 ``Conv1d`` layers.

    Args:
        num_heads: number of attention heads; must divide ``feature_dim``.
        feature_dim: channel width of query, key and value.
        dropout: dropout probability applied to the attention weights
            (defaults to the previously hard-coded 0.1).
    """

    def __init__(self, num_heads: int, feature_dim: int, dropout: float = 0.1):
        super().__init__()
        assert feature_dim % num_heads == 0
        self.dim = feature_dim // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(feature_dim, feature_dim, kernel_size=1)
        # The three input projections start as copies of the merge layer.
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Inputs have channels in dimension 1 (they are fed to ``Conv1d``).

        Returns:
            ``(output, prob)`` where ``output`` has ``feature_dim`` channels
            and ``prob`` holds the attention weights before dropout.
        """
        batch_dim = query.size(0)
        # Project, then split the channels into (per-head dim, heads).
        query, key, value = [l(x).reshape(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = self.attention(query, key, value)
        return self.merge(x.contiguous().reshape(batch_dim, self.dim*self.num_heads, -1)), prob

    def attention(self, query, key, value):
        """Scaled dot-product attention over tensors shaped (b, d, h, n)."""
        dim = query.shape[1]
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** 0.5
        prob = F.softmax(scores, dim=-1)
        # Dropout only affects the weights used for mixing, not the returned prob.
        return torch.einsum('bhnm,bdhm->bdhn', self.dropout(prob), value), prob


