import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    """Routing-attention Transformer encoder for sequence classification.

    Pipeline: token embedding -> sinusoidal positional encoding -> optional
    linear projection to ``dim_model`` -> stack of ``Encoder`` layers ->
    mean over axis 1 -> linear classifier.

    Args:
        args: config object; reads ``num_vocab``, ``dim_embd``,
            ``embedding_pretrained``, ``data_file``, ``freeze``, ``dim_model``,
            ``max_length``, ``dropout``, ``device``, ``num_head``,
            ``dim_inner``, ``n_centroid``, ``commitment1``, ``num_encoder``
            and ``num_class``.
    """

    def __init__(self, args):
        super(Model, self).__init__()

        V = args.num_vocab
        D = args.dim_embd
        self.padding_idx = 0

        if args.embedding_pretrained:
            # torch.load unpickles the file: only point data_file at trusted data.
            pretrain_embed = torch.load('{}glove_pretrain_embed.pth'.format(args.data_file))['pretrain']
            # from_pretrained is a classmethod: call it on the class rather than
            # building (and discarding) a randomly initialised V x D table first.
            # NOTE(review): padding_idx is not passed here, unlike the branch
            # below; only matters for gradients when freeze is False.
            self.embedding = nn.Embedding.from_pretrained(pretrain_embed, freeze=args.freeze)
        else:
            self.embedding = nn.Embedding(V, D, padding_idx=self.padding_idx)

        # Project embeddings to the encoder width only when the sizes differ.
        self.emb_convert = None
        if args.dim_model != args.dim_embd:
            self.emb_convert = nn.Linear(args.dim_embd, args.dim_model)

        # The positional encoding is added before emb_convert, hence dim_embd.
        self.postion_embedding = Positional_Encoding(args.dim_embd, args.max_length, args.dropout, args.device)

        self.encoders = nn.ModuleList([
            Encoder(args.dim_model, args.num_head, args.dim_inner, args.dropout, args.n_centroid, args.commitment1)
            for _ in range(args.num_encoder)])

        self.num_head = args.num_head
        self.fc1 = nn.Linear(args.dim_model, args.num_class)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer token ids, assumed ``(batch, seq_len)`` — TODO confirm.
                Positions equal to ``padding_idx`` are masked in attention.

        Returns:
            ``(logits, losses)``: ``logits`` from ``fc1`` (``(batch, num_class)``
            for the assumed input shape) and ``losses``, a list with one
            k-means auxiliary loss per encoder layer.
        """
        out = self.embedding(x)
        out = self.postion_embedding(out)

        if self.emb_convert is not None:
            out = self.emb_convert(out)

        # True where the token is real, False at padding positions.
        mask = x.ne(self.padding_idx)

        losses = []
        for encoder in self.encoders:
            out, loss = encoder(out, mask)
            losses.append(loss)

        # NOTE(review): this mean also averages padded positions; a masked mean
        # may have been intended — confirm before changing (it alters outputs).
        out = torch.mean(out, 1)
        out = self.fc1(out)
        return (out, losses)


class Encoder(nn.Module):
    """One encoder layer: routing self-attention followed by a position-wise FFN.

    The attention output is passed straight to the feed-forward sub-layer,
    which applies its own residual connection and LayerNorm.
    """

    def __init__(self, dim_model, num_head, hidden, dropout, n_centroid, commitment):
        super(Encoder, self).__init__()
        # Attribute names double as state_dict keys, so they stay fixed.
        self.attention = RoutingAttention(
            dim_model, num_head, dropout=dropout, n_centroid=n_centroid, commitment=commitment)
        self.feed_forward = Position_wise_Feed_Forward(dim_model, hidden, dropout=dropout)

    def forward(self, x, mask=None):
        """Return ``(output, aux_loss)``; ``aux_loss`` is produced by the attention sub-layer."""
        attended, aux_loss = self.attention(x, mask=mask)
        return self.feed_forward(attended), aux_loss


class Positional_Encoding(nn.Module):
    """Add fixed sinusoidal position encodings to an embedded sequence, then apply dropout.

    Entry ``[pos, i]`` of the table is ``pos / 10000 ** (2 * (i // 2) / embed)``,
    passed through ``sin`` for even ``i`` and ``cos`` for odd ``i``.

    Args:
        embed: encoding width (must equal the input's last dimension).
        pad_size: number of positions pre-computed (longest supported sequence).
        dropout: dropout probability applied after the addition.
        device: kept for backward compatibility; the table now follows the
            input's device automatically.
    """

    def __init__(self, embed, pad_size, dropout, device):
        super(Positional_Encoding, self).__init__()
        self.device = device
        # Plain attribute, not a buffer: it is not saved in the state_dict.
        self.pe = torch.Tensor(
            [[pos / (10000.0 ** (i // 2 * 2.0 / embed)) for i in range(embed)] for pos in range(pad_size)])
        self.pe[:, 0::2] = torch.sin(self.pe[:, 0::2])
        self.pe[:, 1::2] = torch.cos(self.pe[:, 1::2])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` with ``seq_len = x.size(-2)``.

        Sequences shorter than ``pad_size`` use the leading rows of the table;
        longer ones fail to broadcast (RuntimeError), as before.
        """
        seq_len = x.size(-2)
        if self.pe.device != x.device:
            # Move the table once and cache it, instead of copying it (and
            # wrapping it in a new nn.Parameter) on every call.
            self.pe = self.pe.to(x.device)
        out = x + self.pe[:seq_len]
        out = self.dropout(out)
        return out


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention

    ``softmax(q @ k^T / temperature) @ v`` over batched 3-D tensors
    (``torch.bmm``), with dropout on the attention weights.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        '''Return ``(output, attn_weights)``.

        Args:
            q, k, v: 3-D tensors ``[B, Lq, d]``, ``[B, Lk, d]``, ``[B, Lk, dv]``.
            mask: optional bool tensor broadcastable to ``[B, Lq, Lk]``; True
                entries are excluded (score set to -inf). A fully masked row
                produces NaN weights.
        '''
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -np.inf)
        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights


# --- Routing-attention helper functions and constants ---------------------

# Score reserved for masking a token's attention to itself; not used by any
# active code path in this file.
TOKEN_SELF_ATTN_VALUE = -5.0e4
# Number of k-means iterations run when Kmeans initialises its centroids.
KMEAN_INIT_ITERS = 10


def update_kmeans_on_backwards(module):
    '''Register a backward hook on ``module`` that calls ``update()`` on every Kmeans submodule.

    The Kmeans submodules are collected once into ``module.kmean_modules``;
    the hook re-reads that attribute each time it fires. Returns the handle
    from ``register_backward_hook`` (call ``.remove()`` to detach it).

    NOTE(review): ``Module.register_backward_hook`` is deprecated in recent
    PyTorch in favour of ``register_full_backward_hook``. Also make sure
    Kmeans defines ``update()``; otherwise the hook raises AttributeError.
    '''
    module.kmean_modules = find_modules(module, Kmeans)

    def _refresh_centroids(_, grad_in, grad_out):
        for kmeans_module in module.kmean_modules:
            kmeans_module.update()

    handle = module.register_backward_hook(_refresh_centroids)
    return handle


def find_modules(nn_module, type):
    '''Return every module in ``nn_module.modules()`` (the root included) that is an instance of ``type``.

    Order follows ``nn_module.modules()``. The parameter name shadows the
    builtin ``type`` but is kept for keyword callers.
    '''
    wanted = type
    return list(filter(lambda candidate: isinstance(candidate, wanted), nn_module.modules()))


def max_neg_value(tensor):
    '''Return the most negative finite value of ``tensor``'s floating-point dtype.'''
    limits = torch.finfo(tensor.dtype)
    return -limits.max


def batched_index_select(values, indices):
    '''Gather rows of ``values`` along axis 2 using ``indices``.

    Args:
        values: 4-D tensor ``[b, h, n, d]`` (e.g. per-head token vectors, or
            centroids broadcast over the batch).
        indices: ``[b, h, m]`` integer row ids into axis 2 of ``values``.

    Returns:
        ``[b, h, m, d]`` tensor whose row ``j`` is ``values[..., indices[..., j], :]``.
    '''
    feature_dim = values.shape[-1]
    # Repeat each row id across the feature axis so gather picks whole rows.
    row_ids = indices.unsqueeze(-1).expand(*indices.shape, feature_dim)
    return values.gather(2, row_ids)


def expand_dim(t, dim, k):
    '''Insert a new axis at ``dim`` and broadcast it to size ``k`` (``Tensor.expand`` view).

    Example: ``t`` of shape ``[h, n, d]`` with ``dim=0, k=b`` -> ``[b, h, n, d]``.
    '''
    unsqueezed = t.unsqueeze(dim)
    sizes = [-1 for _ in range(unsqueezed.dim())]
    sizes[dim] = k
    return unsqueezed.expand(*sizes)


def scatter_mean(src, t, index, dim, eps=1e-5):
    '''Average the entries of ``t`` that share a target position in ``index``.

    Step 1: scatter-add ``t`` (and a count of ones) into ``src`` along ``dim``.
    Step 2: divide the sums by the counts (plus ``eps`` to avoid 0/0).

    ``src`` is the accumulation base for BOTH the sums and the counts, so it
    should be a zero tensor (as the caller in this file passes); positions
    never targeted then come out as 0.

    Shapes in this file: ``t`` and ``index`` are ``[b, h, nc*ws, d]``.
    '''
    counts = src.scatter_add(dim, index, torch.ones_like(t))
    totals = src.scatter_add(dim, index, t)
    return totals / (counts + eps)


def split_at_index(dim, index, t):
    '''Split ``t`` along axis ``dim`` into ``(t[..., :index], t[..., index:])``.

    ``dim`` counts from the front (non-negative), since it sets how many
    leading full slices precede the split.
    '''
    leading = [slice(None)] * dim
    left = tuple(leading + [slice(None, index)])
    right = tuple(leading + [slice(index, None)])
    return t[left], t[right]


def dists_and_buckets(q, means):
    '''Score every token against every centroid and pick its best cluster.

    Args:
        q: ``[b, h, len, d]`` token vectors.
        means: ``[h, num_clusters, d]`` centroids.

    Returns:
        ``(dists, buckets)``: ``dists`` is ``[b, h, len, num_clusters]``
        dot-product similarities; ``buckets`` is ``[b, h, len]``, the index of
        the highest-scoring centroid for each token.
    '''
    # Equivalent to torch.einsum('bhld,hcd->bhlc', q, means).
    scores = torch.matmul(q, means.transpose(-1, -2))
    best_cluster = scores.max(dim=-1)[1]
    return scores, best_cluster


def einsum_to_matmul(a, b, Trans=True):
    '''Batched matrix product ``a @ b`` (or ``a @ b^T`` when ``Trans``) via ``torch.matmul``.

    Used in place of ``torch.einsum``; only the last two axes are treated as
    matrices, leading axes broadcast per ``torch.matmul`` rules.
    '''
    rhs = b.transpose(-1, -2) if Trans else b
    return torch.matmul(a, rhs)


def batched_bincount(index, num_clusters, dim=-1):
    '''Count how many entries of ``index`` fall into each cluster id along ``dim``.

    Args:
        index: integer tensor of cluster ids in ``[0, num_clusters)``; in this
            file it is ``[b, h, len]``.
        num_clusters: number of bins.
        dim: axis whose entries are counted; in the output it is resized to
            ``num_clusters``.

    Returns:
        Tensor of ``index.dtype`` shaped like ``index`` except that axis
        ``dim`` has size ``num_clusters`` (``[b, h, num_clusters]`` here).
    '''
    # Derive the output shape from ``index`` itself (previously hard-coded to
    # 3-D with the bins last), so ``dim`` works for any rank and axis.
    out_shape = list(index.shape)
    out_shape[dim] = num_clusters
    out = index.new_zeros(out_shape)
    out.scatter_add_(dim, index, torch.ones_like(index))
    return out


def kmeans_iter(q, means, buckets=None):
    '''Run one k-means step and return the refreshed centroids.

    Args:
        q: ``[b, h, len, d]`` token vectors.
        means: ``[h, num_clusters, d]`` current centroids.
        buckets: optional ``[b, h, len]`` cluster id per token; when omitted,
            each token goes to the centroid with the largest dot product.

    Returns:
        ``[h, num_clusters, d]`` centroids: for every non-empty cluster the
        L2-normalised sum of its members (pooled over the batch), otherwise
        the previous centroid.
    '''
    batch, heads, _, feat = q.size()
    clusters = means.shape[1]
    if buckets is None:
        buckets = dists_and_buckets(q, means)[1]

    # Members per cluster, summed over the batch -> [1, h, num_clusters].
    member_counts = batched_bincount(buckets, clusters).sum(0, keepdim=True)
    is_empty = member_counts.long() == 0

    # Accumulate each token vector into its cluster slot -> [b, h, num_clusters, d].
    cluster_sums = buckets.new_zeros(batch, heads, clusters, feat, dtype=q.dtype)
    cluster_sums.scatter_add_(-2, expand_dim(buckets, -1, feat), q)
    candidates = F.normalize(cluster_sums.sum(0, keepdim=True), dim=-1).type(q.dtype)

    # Empty clusters keep their old centroid; drop the leading size-1 axis.
    return torch.where(is_empty.unsqueeze(-1), means, candidates).squeeze(0)


def distribution(dists, window_size):
    '''Pick, for every cluster, the ``window_size`` tokens most similar to its centroid.

    Args:
        dists: ``[b, h, len, num_clusters]`` token-to-centroid similarities.
        window_size: number of tokens kept per cluster.

    Returns:
        ``[b, h, num_clusters * window_size]`` token indices laid out cluster
        by cluster, each group in decreasing-similarity order (``topk``).
    '''
    top_tokens = dists.topk(k=window_size, dim=-2)[1]   # [b, h, window, nc]
    per_cluster = top_tokens.transpose(-2, -1)          # [b, h, nc, window]
    batch, heads = per_cluster.shape[:2]
    return per_cluster.reshape(batch, heads, -1)


class Kmeans(nn.Module):
    '''Per-head k-means codebook used to route tokens to clusters.

    Buffers:
        means: centroids, ``[num_heads, num_clusters, head_dim]``.
        initted: bool flag set once ``means`` has been initialised from data.

    Args:
        num_heads, head_dim, num_clusters: codebook dimensions.
        ema_decay: weight of the old centroids when ``update`` blends in new ones.
        commitment: scale of the commitment (MSE) loss returned by ``forward``.
    '''

    def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4):
        super().__init__()
        self.commitment = commitment
        self.ema_decay = ema_decay

        self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
        self.register_buffer('initted', torch.tensor(False))
        # Running average of centroid estimates from forward(update_means=True),
        # consumed (and reset) by update().
        self.num_new_means = 0
        self.new_means = None

    @torch.no_grad()
    def init(self, q):
        '''Initialise ``means`` from the first batch seen; no-op afterwards.

        Seeds each head's centroids with token vectors of batch element 0
        (distinct positions when there are enough tokens), then refines them
        with ``KMEAN_INIT_ITERS`` k-means steps over the whole batch.

        Args:
            q: ``[b, h, len, d]`` queries (not yet L2-normalised when called
                from ``forward``).
        '''
        if self.initted:
            return
        num_samples = q.size(2)
        device = q.device

        num_clusters = self.means.shape[1]
        if num_samples >= num_clusters:
            # Random permutation -> distinct token positions.
            indices = torch.randperm(num_samples, device=device)[:num_clusters]
        else:
            # Too few tokens: sample positions with replacement.
            indices = torch.randint(0, num_samples, (num_clusters,), device=device)

        # [h, num_clusters, d]
        means = q[0, :, indices]

        for _ in range(KMEAN_INIT_ITERS):
            means = kmeans_iter(q, means)

        self.num_new_means = 0
        self.means.data.copy_(means)
        self.initted.data.copy_(torch.tensor(True))

    def forward(self, q, update_means=False):
        '''Assign queries to centroids and compute the commitment loss.

        Args:
            q: ``[b, h, len, d]`` queries; L2-normalised internally after init.
            update_means: if True, also run one k-means step on this batch and
                fold the result into ``new_means`` (applied by ``update``).

        Returns:
            ``(dists, loss)``: ``dists`` is ``[b, h, len, num_clusters]``
            dot-product similarity to each centroid; ``loss`` is
            ``commitment * mse(q, assigned centroid)``.
        '''
        self.init(q)

        b = q.size(0)
        means = self.means.type_as(q)
        q = F.normalize(q, 2, dim=-1).type_as(q)

        with torch.no_grad():
            dists, buckets = dists_and_buckets(q, means)

        # routed_means[b, h, len, d]: the centroid each token was assigned to.
        routed_means = batched_index_select(expand_dim(means, 0, b), buckets)
        loss = F.mse_loss(q, routed_means) * self.commitment

        def ema(old, new, decay):
            if old is None:
                return new
            return old * decay + new * (1 - decay)

        if update_means:
            with torch.no_grad():
                means = kmeans_iter(q, means, buckets)
            # decay n/(n+1) makes this an equal-weight mean of all batches
            # accumulated since the last update().
            self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
            self.num_new_means += 1

        return dists, loss

    @torch.no_grad()
    def update(self, new_means=None):
        '''Blend accumulated centroids into ``means`` with an EMA and reset the accumulator.

        The backward hook installed by ``update_kmeans_on_backwards`` calls
        ``update()`` on each Kmeans submodule. When nothing has been
        accumulated (``forward`` never ran with ``update_means=True``) and no
        ``new_means`` is given, this is a no-op.

        Args:
            new_means: optional ``[num_heads, num_clusters, head_dim]`` centroids
                to blend in instead of the accumulated ``self.new_means``.
        '''
        if new_means is None:
            new_means = self.new_means
        if new_means is None:
            return
        new_means = new_means.to(self.means)
        self.means.data.mul_(self.ema_decay).add_(new_means * (1 - self.ema_decay))
        self.new_means = None
        self.num_new_means = 0


class RoutingAttention(nn.Module):
    '''Multi-head self-attention restricted to k-means clusters (routing attention).

    Queries and keys share one projection (``w_qs``). Per head, a ``Kmeans``
    codebook scores tokens against ``n_centroid`` centroids; each cluster
    keeps its ``seq_len // n_centroid`` highest-scoring tokens, and attention
    is computed only inside each cluster. Outputs of tokens picked by several
    clusters are averaged (``scatter_mean``); tokens picked by none get a
    zero vector before the output projection.

    Args:
        dim_model: model width, split evenly into ``num_head`` heads.
        num_head: number of attention heads.
        dropout: dropout on the attention weights and on the projected output.
        n_centroid: number of clusters per head.
        commitment: scale of the k-means commitment loss.
    '''

    def __init__(self, dim_model, num_head, dropout=0.0, n_centroid=4, commitment=0.1):
        super().__init__()

        self.n_head = num_head
        self.d_k = dim_model // num_head
        self.d_v = self.d_k
        d_k = d_v = self.d_k
        d_model = dim_model
        n_head = num_head

        self.n_centroid = n_centroid
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        # NOTE(review): w_ks is unused (Q and K share w_qs); kept so existing
        # checkpoints keep the same state_dict keys.
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)

        self.kmeans = Kmeans(n_head, d_k, n_centroid, ema_decay=0.999, commitment=commitment)
        # NOTE(review): self.attention and self.layer_norm are never used in
        # forward (no residual/LayerNorm around this sub-layer). They are kept
        # for state_dict compatibility — confirm whether a residual was intended.
        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.softmax = nn.Softmax(dim=-1)
        # A single dropout (rate ``dropout``) for attention weights and output;
        # the former hard-coded Dropout(0.1) was dead (immediately overwritten).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        ''' Routing self-attention over ``x``; note that Q and K are shared.

        Args:
            x: ``[bz, seq_len, dim_model]`` input.
            mask: optional ``[bz, seq_len]`` bool tensor, True where a token may
                take part in attention; any (query, key) pair involving a False
                position gets the dtype's most negative score.

        Returns:
            ``(output, aux_loss)``: ``output`` is ``[bz, seq_len, dim_model]``;
            ``aux_loss`` is the k-means commitment loss.

        Raises:
            ValueError: if ``seq_len < n_centroid`` (every cluster would be
                empty and the window reshape would fail).
        '''

        d_k, d_v, n_head, n_centroid = self.d_k, self.d_v, self.n_head, self.n_centroid
        q = k = v = x
        bz, seq_len, _ = q.size()

        # Tokens kept per cluster (the window size).
        num_cluster = seq_len // n_centroid
        if num_cluster == 0:
            raise ValueError(
                'sequence length {} is shorter than n_centroid={}; every cluster '
                'would be empty'.format(seq_len, n_centroid))
        split_heads = lambda t: t.reshape(bz, -1, n_head, d_k).transpose(1, 2).contiguous()

        # Shared QK: keys reuse the query projection.
        qw = self.w_qs(q).view(bz, seq_len, n_head * d_k)
        kw = qw
        vw = self.w_vs(v).view(bz, seq_len, n_head * d_v)
        qw, kw, vw = map(split_heads, (qw, kw, vw))
        # NOTE(review): update_means is never passed, so the centroids are not
        # refreshed after their first data-driven initialisation.
        dists, aux_loss = self.kmeans(qw)
        # [bz, h, n_centroid * window] token ids, grouped by cluster.
        indices = distribution(dists, num_cluster)
        kv_indices = indices

        qw = batched_index_select(qw, indices)
        kw = batched_index_select(kw, kv_indices)
        vw = batched_index_select(vw, kv_indices)

        # -> [bz, h, n_centroid, window, d_k]
        reshape_with_window = lambda t: t.reshape(bz, n_head, n_centroid, -1, d_k)
        qw, kw, vw = map(reshape_with_window, (qw, kw, vw))

        # Attention inside each cluster window.
        output = torch.zeros(bz, n_head, seq_len, d_k).type_as(qw)
        attn = einsum_to_matmul(qw, kw, Trans=True) * (d_k ** -0.5)

        mask_value = max_neg_value(attn)

        if mask is not None:
            # Reorder the padding mask into the cluster/window layout and build
            # a pairwise (query, key) validity mask.
            q_mask = expand_dim(mask, 1, n_head).gather(2, indices)
            kv_mask = expand_dim(mask, 1, n_head).gather(2, kv_indices)
            q_mask, kv_mask = map(lambda t: t.reshape(bz, n_head, n_centroid, -1), (q_mask, kv_mask))
            mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :]
            attn.masked_fill_(~mask, mask_value)
            del mask

        # Causal masking and self-token masking are not applied here.

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        out = einsum_to_matmul(attn, vw, Trans=False)
        out = torch.reshape(out, (bz, n_head, -1, d_k))
        # Scatter window outputs back to token positions, averaging tokens
        # that were selected by more than one cluster.
        output = scatter_mean(output, out, indices.unsqueeze(-1).expand_as(out), -2)
        output = output.transpose(1, 2).contiguous().view(bz, seq_len, -1)
        output = self.dropout(self.fc(output))

        return output, aux_loss


class Position_wise_Feed_Forward(nn.Module):
    """Position-wise two-layer ReLU MLP with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + Dropout(fc2(relu(fc1(x)))))``.
    """

    def __init__(self, dim_model, hidden, dropout=0.0):
        super(Position_wise_Feed_Forward, self).__init__()
        # Construction order is kept: it fixes the parameter-init RNG sequence.
        self.fc1 = nn.Linear(dim_model, hidden)
        self.fc2 = nn.Linear(hidden, dim_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        branch = self.dropout(self.fc2(hidden))
        # Residual connection followed by LayerNorm.
        return self.layer_norm(x + branch)


