import torch, math
import torch.nn as nn
from utils.global_attention import GlobalAttention

class Salience(nn.Module):
    """Score sentences with attention and spread the scores over token positions.

    A ``GlobalAttention`` module scores every sentence representation in
    ``tgt_sent`` against ``src_doc``. Every source token belonging to sentence
    ``j`` then receives sentence ``j``'s score.
    """

    def __init__(self, model_opt, coverage_attn=False,
                 attn_type="general", attn_func="softmax",
                 eps=1e-20):
        '''
        :param model_opt: option mapping. Reads ``"enc_rnn_size"``; twice that
            value becomes ``self.in_feat``, the size passed to GlobalAttention
            (presumably a bidirectional encoder; TODO confirm). Also reads
            ``"gpu"``: the string ``"True"`` sets ``self.gpu``.
        :param coverage_attn: forwarded to GlobalAttention as ``coverage``.
        :param attn_type: forwarded to GlobalAttention.
        :param attn_func: forwarded to GlobalAttention.
        :param eps: constant added to every attention score in ``forward``.
        '''
        super(Salience, self).__init__()
        self.in_feat = int(model_opt["enc_rnn_size"]) * 2
        self.attention = GlobalAttention(
            self.in_feat, coverage=coverage_attn,
            attn_type=attn_type, attn_func=attn_func
        )
        # Kept for backward compatibility. forward() now allocates on the
        # device of the attention scores instead of consulting this flag.
        self.gpu = model_opt["gpu"] == "True"
        self.eps = eps

    def forward(self, src_doc, tgt_sent, tgt_sent_num, sent_end_idx, src_max_len):
        '''
        Score each sentence, then broadcast the scores onto source-token slots.

        :param src_doc: shape [batch, enc_size*2]; the attention query.
        :param tgt_sent: shape [batch, sent_num, enc_size*2]; the attention
            memory.
        :param tgt_sent_num: number of valid sentences per batch item, passed
            to the attention as ``memory_lengths``.
        :param sent_end_idx: shape [batch, sent_num]. Despite the name, entry
            ``[i, j]`` is used as the token *length* of sentence ``j``: the
            sentence covers the tokens right after sentence ``j - 1``.
            Non-zero entries are counted as sentences, so zeros are assumed
            to appear only as trailing padding.
        :param src_max_len: number of token positions in ``output``.
        :return: ``(p_attn, output)``.
            - ``p_attn`` is the attention score plus ``eps``, indexed as
              ``[batch, sentence]``.
            - ``output`` has shape [batch, src_max_len]. Token slots past the
              last counted sentence get the smallest score of that row.
        '''
        batch, _ = sent_end_idx.size()
        p_attn, _ = self.attention(
            src_doc,
            tgt_sent,
            memory_lengths=tgt_sent_num)  # how many sentences in each batch
        p_attn = p_attn + self.eps

        # Allocate next to the scores. Relying on the "gpu" option could
        # produce a CPU/GPU mismatch, or call .cuda() on a CPU-only host.
        output = torch.zeros((batch, src_max_len), device=p_attn.device)

        # Bring lengths/counts to the host once rather than synchronising on
        # a 0-d tensor at every loop step.
        sent_count = torch.sum((sent_end_idx != 0).int(), dim=1).tolist()
        sent_lens = sent_end_idx.tolist()
        for i in range(batch):
            start = 0
            for j in range(sent_count[i]):
                end = start + sent_lens[i][j]
                output[i, start:end] = p_attn[i, j]
                start = end
            if start < src_max_len:
                # NOTE(review): this min spans every sentence slot of row i,
                # including padded ones. If the attention zeroes those, they
                # hold only eps here. Confirm whether the min should cover
                # just the first sent_count[i] entries.
                output[i, start:] = torch.min(p_attn[i])
        return p_attn, output  # output shape: [batch, src_max_len]

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over sentence vectors.

    Queries, keys and values come from one fused ``Affine`` projection of the
    input. The combined head outputs are projected to ``hidden_size // 2``
    features.
    """

    def __init__(self, hidden_size, num_heads=1, dropout=0.9):
        '''
        :param hidden_size: input feature size. Must be divisible by
            ``num_heads`` (see ``split_heads``).
        :param num_heads: number of attention heads.
        :param dropout: drop probability applied to the attention weights
            while training; passed as ``p`` to
            ``torch.nn.functional.dropout``.
        '''
        # NOTE(review): a default *drop* probability of 0.9 discards 90% of the
        # attention weights during training. It looks like a keep-probability
        # carried over from another framework. Left unchanged for
        # compatibility; confirm.
        super(MultiHeadAttention, self).__init__()

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.dropout = dropout

        self.qkv_transform = Affine(hidden_size, 3 * hidden_size)
        self.o_transform = Affine(hidden_size, hidden_size // 2)

        self.reset_parameters()

    def forward(self, query, sent_end_idx, bias=None):
        '''
        :param query: sentence embeddings [batch, sent_num, hidden_size]
        :param sent_end_idx: [batch, sent_num]. Zero entries mark padded
            sentences, which are excluded as attention keys.
        :param bias: optional tensor added to the attention logits. It must
            broadcast to [batch, num_heads, sent_num, sent_num].
        :return: [batch, sent_num, hidden_size // 2]
        '''
        qkv = self.qkv_transform(query)  # [batch, sent_num, 3 * hidden_size]
        q, k, v = torch.split(qkv, self.hidden_size, dim=-1)
        del qkv
        # split heads: [batch, nhead, sent_num, hidden_size // nhead]
        qh = self.split_heads(q, self.num_heads)
        kh = self.split_heads(k, self.num_heads)
        vh = self.split_heads(v, self.num_heads)
        del q, k, v
        # scale query
        qh = qh * (self.hidden_size // self.num_heads) ** -0.5

        # dot-product attention
        kh = torch.transpose(kh, -2, -1)  # [batch, nhead, hidden // nhead, sent_num]
        logits = torch.matmul(qh, kh)  # [batch, nhead, sent_num, sent_num]
        if bias is not None:
            logits = logits + bias
        # Exclude padded keys. Filling their logits with 0 would still give
        # them weight exp(0) after softmax. The most negative finite value
        # gives them zero weight, and an all-padded row stays finite
        # (uniform) instead of turning into NaN as -inf would.
        key_mask = (sent_end_idx == 0)[:, None, None, :]
        logits = logits.masked_fill(key_mask, torch.finfo(logits.dtype).min)
        weights = torch.nn.functional.dropout(torch.softmax(logits, dim=-1),
                                              p=self.dropout,
                                              training=self.training)  # [batch, nhead, sent_num, sent_num]

        x = torch.matmul(weights, vh)  # [batch, nhead, sent_num, hidden // nhead]

        # combine heads, then project to hidden_size // 2 features
        return self.o_transform(self.combine_heads(x))

    def reset_parameters(self, initializer="orthogonal"):
        '''Initialise both projections: orthogonal (default) or Xavier-uniform.'''
        if initializer == "orthogonal":
            self.qkv_transform.orthogonal_initialize()
            self.o_transform.orthogonal_initialize()
        else:
            # 6 / (4 * hidden_size) -> 6 / (2 * hidden_size)
            nn.init.xavier_uniform_(self.qkv_transform.weight)
            nn.init.xavier_uniform_(self.o_transform.weight)
            nn.init.constant_(self.qkv_transform.bias, 0.0)
            nn.init.constant_(self.o_transform.bias, 0.0)

    @staticmethod
    def split_heads(x, heads):
        '''[batch, length, channels] -> [batch, heads, length, channels // heads]'''
        batch = x.shape[0]
        length = x.shape[1]
        channels = x.shape[2]

        y = torch.reshape(x, [batch, length, heads, channels // heads])
        return torch.transpose(y, 2, 1)

    @staticmethod
    def combine_heads(x):
        '''[batch, heads, length, channels] -> [batch, length, heads * channels]'''
        batch = x.shape[0]
        heads = x.shape[1]
        length = x.shape[2]
        channels = x.shape[3]

        y = torch.transpose(x, 2, 1)

        return torch.reshape(y, [batch, length, heads * channels])

class Affine(nn.Module):
    """Linear map ``y = x @ weight.T + bias``, with an extra orthogonal init.

    ``weight`` has shape [out_features, in_features]. ``bias`` has shape
    [out_features], or is ``None`` when built with ``bias=False``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Affine, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Values are filled in by reset_parameters() below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        '''Default init: same scheme as ``torch.nn.Linear``.'''
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def orthogonal_initialize(self, gain=1.0):
        '''Re-initialise ``weight`` orthogonally (scaled by ``gain``) and zero
        the bias, if there is one.'''
        nn.init.orthogonal_(self.weight, gain)
        # Layers built with bias=False have no bias to zero.
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input):
        return nn.functional.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )