import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """Sequence classifier built from stacked cluster-attention encoders.

    ``args`` is read for: num_vocab, dim_embd, dim_model, embedding_pretrained,
    data_file and freeze (pretrained path only), max_length, dropout, device,
    num_head, dim_inner, num_encoder, num_class, n_centroid.
    """

    def __init__(self, args):
        super(Model, self).__init__()

        V = args.num_vocab
        D = args.dim_embd
        self.padding_idx = 0  # token id treated as padding (masked out in attention)

        if args.embedding_pretrained:
            # NOTE(review): torch.load unpickles the file; only point data_file at trusted data.
            pretrain_embed = torch.load('{}glove_pretrain_embed.pth'.format(args.data_file))['pretrain']
            # from_pretrained is a classmethod: call it on the class instead of on a
            # throwaway randomly-initialised V x D table. padding_idx mirrors the
            # non-pretrained branch so the padding row is never updated.
            self.embedding = nn.Embedding.from_pretrained(
                pretrain_embed, freeze=args.freeze, padding_idx=self.padding_idx)
        else:
            self.embedding = nn.Embedding(V, D, padding_idx=self.padding_idx)

        # Project embeddings to the model width only when the two widths differ.
        self.emb_convert = None
        if args.dim_model != args.dim_embd:
            self.emb_convert = nn.Linear(args.dim_embd, args.dim_model)

        # Positional encodings are added at embedding width, before emb_convert.
        self.postion_embedding = Positional_Encoding(args.dim_embd, args.max_length, args.dropout, args.device)

        self.encoders = nn.ModuleList([
            Encoder(args.dim_model, args.num_head, args.dim_inner, args.dropout)
            for _ in range(args.num_encoder)])

        self.num_head = args.num_head
        self.fc1 = nn.Linear(args.dim_model, args.num_class)

        # Learnable cluster centers of shape (1, n_centroid, dim_model), shared across the batch.
        self.clusterCenter = nn.Parameter(
            torch.empty(1, args.n_centroid, args.dim_model, dtype=torch.float), requires_grad=True)
        nn.init.xavier_uniform_(self.clusterCenter)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer token ids, (batch, seq_len).

        Returns:
            (logits, losses): logits of shape (batch, num_class), and one
            ``(loss1, loss2)`` clustering-loss pair per encoder layer.
        """
        out = self.postion_embedding(self.embedding(x))

        if self.emb_convert is not None:
            out = self.emb_convert(out)

        mask = x.eq(self.padding_idx)  # True at padding positions

        # Broadcast the shared centers to the batch without copying.
        clusterCenter = self.clusterCenter.expand(x.size(0), -1, -1)

        losses = []
        for encoder in self.encoders:
            out, clusterCenter, loss1, loss2 = encoder(out, clusterCenter, mask)
            losses.append((loss1, loss2))

        # NOTE(review): the mean runs over every position, padding included.
        out = self.fc1(out.mean(dim=1))

        return (out, losses)


class Encoder(nn.Module):
    """One encoder layer: cluster attention followed by a position-wise feed-forward block."""

    def __init__(self, dim_model, num_head, hidden, dropout):
        super().__init__()
        self.attention = ClusterAttention(dim_model, num_head, dropout)
        self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout)

    def forward(self, x, c, mask=None):
        """Run attention then feed-forward; pass the updated centers and both losses through."""
        attended, centers, loss_a, loss_b = self.attention(x, c, mask=mask)
        return self.feed_forward(attended), centers, loss_a, loss_b


def get_index(similarity_matrix):
    """Return the arg-max along dim 1 for every position of the last dim.

    For a (batch, n_clusters, n_tokens) score tensor this yields an int64
    tensor of shape (batch, n_tokens) holding each token's best cluster id.
    """
    per_token_scores = similarity_matrix.transpose(1, 2)  # move dim 1 last
    return per_token_scores.max(dim=-1).indices

def expand_dim(t, dim, k):
    """Insert a new axis at ``dim`` and broadcast it to size ``k`` (an expanded view, no copy)."""
    unsqueezed = torch.unsqueeze(t, dim)
    target_shape = list(unsqueezed.size())
    target_shape[dim] = k
    return unsqueezed.expand(*target_shape)

class ClusterAttention(nn.Module):
    """Multi-head self-attention computed inside learned token clusters.

    Each token is assigned to one of ``n_clusters`` centers. Tokens are then
    sorted by cluster id and split into ``n_clusters`` equal-size chunks, and
    every chunk attends to itself plus the previous chunk (cyclically). The
    centers are refined from the tokens on every call.
    """

    def __init__(self, dim_model, num_head, dropout=0.0):
        super().__init__()

        self.n_head = num_head
        assert dim_model % num_head == 0
        self.d_v = self.d_k = dim_model // num_head

        self.w_qkc = nn.Linear(dim_model, dim_model)  # tokens as keys/values for the centers
        self.w_qk = nn.Linear(dim_model, self.n_head * self.d_k)  # shared query/key projection
        self.w_vs = nn.Linear(dim_model, self.n_head * self.d_v)

        self.layer_norm = nn.LayerNorm(dim_model)
        # NOTE(review): not used in forward(); kept so existing checkpoints still load.
        self.layer_norm_c = nn.LayerNorm(self.n_head * self.d_k)

        self.fc = nn.Linear(self.n_head * self.d_v, dim_model)

        self.dropout = nn.Dropout(dropout)

        self.scale_cc = dim_model ** -0.5
        self.scale = self.d_k ** -0.5

    def perform_attention(self, Q, K, V, mask=None, need_attn=False):
        """Batched softmax attention ``softmax(Q @ K^T) @ V``.

        Q is (B, M, d), K is (B, N, d), V is (B, N, d_v); any scaling must be
        applied to Q by the caller. ``mask`` (broadcastable to (B, M, N)) is
        True at positions to exclude. Returns the context, plus the attention
        weights when ``need_attn`` is set.
        """
        attention = torch.matmul(Q, K.permute(0, 2, 1))

        if mask is not None:
            # -65504 is the most negative finite float16 value, so this is fp16-safe.
            attention = attention.masked_fill(mask, -65504)

        attention = F.softmax(attention, dim=-1)
        context = torch.matmul(attention, V)

        if need_attn:
            return context, attention

        return context

    def forward(self, x, c, mask=None):
        """Cluster-restricted self-attention over ``x``.

        Args:
            x: token features, (batch, len_q, dim_model).
            c: cluster centers, (batch, n_clusters, dim_model).
            mask: optional bool tensor (batch, len_q), True at padding positions.

        Returns:
            (output, clusterCenter, loss1, loss2): the attended features
            (batch, len_q, dim_model); the updated, L2-normalised centers;
            loss1, the negative mean dot product between each normalised token
            and its assigned center; and loss2, the negative mean dot product
            between each input center and its cyclic predecessor.

        Raises:
            ValueError: if len_q is not a multiple of n_clusters (tokens are
                split into n_clusters equal-size chunks).
        """
        residual = x
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        n_clusters = c.size(-2)
        sz_b, len_q, _ = x.size()

        if len_q % n_clusters != 0:
            raise ValueError(
                'sequence length {} must be divisible by the number of clusters {}'.format(len_q, n_clusters))
        cluster_size = len_q // n_clusters

        qkc = self.w_qkc(x)
        qk = self.w_qk(x)
        v = self.w_vs(x)
        x_normed = F.normalize(x, 2, dim=-1)

        # 1) Centers attend over tokens; the attention map doubles as the assignment score.
        cluster_mask = None if mask is None else mask.unsqueeze(1).expand(-1, n_clusters, -1)
        clusterCenter, similarity_matrix = self.perform_attention(
            c * self.scale_cc, qkc, qkc, mask=cluster_mask, need_attn=True)

        buckets_index = get_index(similarity_matrix)  # (batch, len_q): cluster id per token
        clusterCenter = F.normalize(clusterCenter + c, 2, dim=-1)

        # 2) Clustering losses.
        expand_cluster = expand_dim(buckets_index, -1, c.shape[-1])
        x_means = clusterCenter.gather(1, expand_cluster)  # assigned center for every token
        loss1 = -torch.mul(x_normed, x_means).sum() / (sz_b * len_q)
        # roll by 1 along the cluster axis pairs center i with center i-1 (0 with the last).
        loss2 = -torch.mul(c, torch.roll(c, shifts=1, dims=1)).sum() / (c.shape[0] * c.shape[1])

        # 3) Sort tokens by cluster id and split them into n_clusters equal chunks.
        # Sort indices are already int64; no float round-trip (that loses precision in fp16).
        _, idx1 = torch.sort(buckets_index, -1)
        index11 = idx1.unsqueeze(-1).expand_as(qk)  # d_v == d_k, so this also indexes v
        qk_sort = qk.gather(1, index11).contiguous().view(sz_b, n_clusters, cluster_size, n_head, d_k)
        v_sort = v.gather(1, index11).contiguous().view(sz_b, n_clusters, cluster_size, n_head, d_v)

        def look_one_back(t):
            # Append the previous chunk (chunk 0 wraps to the last) along the within-chunk axis.
            t_prev = torch.cat([t[:, -1:], t[:, :-1]], dim=1)
            return torch.cat([t, t_prev], dim=2)

        k_sort = look_one_back(qk_sort)
        v_sort = look_one_back(v_sort)

        # Fold heads, batch and chunks into one batch axis: (n_head * sz_b * n_clusters, ...).
        qk_sort = qk_sort.permute(3, 0, 1, 2, 4).contiguous().view(-1, cluster_size, d_k)
        k_sort = k_sort.permute(3, 0, 1, 2, 4).contiguous().view(-1, 2 * cluster_size, d_k)
        v_sort = v_sort.permute(3, 0, 1, 2, 4).contiguous().view(-1, 2 * cluster_size, d_v)

        mask_sort = None
        if mask is not None:
            mask_sort = mask.gather(1, idx1).contiguous().view(sz_b, n_clusters, cluster_size)
            mask_sort = look_one_back(mask_sort)
            mask_sort = mask_sort.unsqueeze(0).unsqueeze(3).expand(n_head, -1, -1, cluster_size, -1)
            mask_sort = mask_sort.contiguous().view(-1, cluster_size, 2 * cluster_size)
        output = self.perform_attention(qk_sort * self.scale, k_sort, v_sort, mask=mask_sort)

        # 4) Unfold heads and restore the original token order.
        output = output.view(n_head, sz_b, n_clusters, cluster_size, d_v)
        output = output.permute(1, 2, 3, 0, 4).contiguous().view(sz_b, len_q, n_head * d_v)
        _, idx2 = torch.sort(idx1)  # inverse permutation of idx1
        output = output.gather(1, idx2.unsqueeze(-1).expand_as(output))

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, clusterCenter, loss1, loss2

class Positional_Encoding(nn.Module):
    """Add fixed sinusoidal position encodings to ``x`` and apply dropout.

    Args:
        embed: width of the encodings (must match the last dim of x).
        pad_size: maximum supported sequence length.
        dropout: dropout probability applied after the addition.
        device: stored as ``self.device``; the table itself follows x's device.
    """

    def __init__(self, embed, pad_size, dropout, device):
        super(Positional_Encoding, self).__init__()
        self.device = device
        # angle[pos, i] = pos / 10000^(2*(i//2)/embed); sin on even columns, cos on odd.
        self.pe = torch.Tensor([[pos / (10000.0 ** (i // 2 * 2.0 / embed)) for i in range(embed)] for pos in range(pad_size)])
        self.pe[:, 0::2] = torch.sin(self.pe[:, 0::2])
        self.pe[:, 1::2] = torch.cos(self.pe[:, 1::2])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` with seq_len = x.size(-2) (must be <= pad_size)."""
        # Move the table once per device instead of copying it on every call.
        if self.pe.device != x.device:
            self.pe = self.pe.to(x.device)
        out = x + self.pe[:x.size(-2)]
        out = self.dropout(out)
        return out

class Position_wise_Feed_Forward(nn.Module):
    """Two-layer ReLU MLP applied at every position.

    The MLP output goes through dropout, is added back to the input as a
    residual connection, and the sum is LayerNorm-ed.
    """

    def __init__(self, dim_model, hidden, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim_model, hidden)
        self.fc2 = nn.Linear(hidden, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        hidden_act = F.relu(self.fc1(x))
        projected = self.dropout(self.fc2(hidden_act))
        # Residual connection, then normalise.
        return self.layer_norm(projected + x)
