# -*- coding: utf-8 -*-

import math
from typing import Optional

import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.init as init

from ...common.dataclass_options import OptionsBase, argfield
from ..utils import sequence_mask
from .dropout import FeatureDropout


class ScaledDotProductAttention(jit.ScriptModule):
    """
    Scaled dot-product attention over batched 3-D tensors.

    Raw scores ``q @ k^T`` are divided by ``sqrt(d_model)``, masked positions
    are set to ``-inf``, a softmax is taken over the key axis, dropout is
    applied to the resulting weights, and the weights are used to average
    the values.

    Args:
        d_model: size whose square root is used as the score divisor.
        attention_dropout: dropout probability applied to the attention
            weights.
    """
    __constants__ = ['temper']

    def __init__(self, d_model, attention_dropout=0.1):
        super().__init__()

        # Divisor applied to the raw dot products before the softmax.
        self.temper = d_model ** 0.5
        self.dropout = nn.Dropout(attention_dropout)
        # dim=1 is paired with the transposes in forward(), so the softmax
        # effectively normalizes over the key axis.
        self.softmax = nn.Softmax(dim=1)

    @jit.script_method
    def forward(self, q, k, v, attn_mask):
        # q, k, v: [batch, slot, feat]
        # attn_mask: [batch, query slot, key slot]; entries that evaluate as
        # true are excluded from the attention distribution.
        # Returns (output, attention weights after dropout).
        scores = torch.bmm(q, k.transpose(1, 2))
        scores = scores / self.temper

        assert attn_mask.size() == scores.size()

        # Masked scores become -inf so that they get zero softmax weight.
        scores.masked_fill_(attn_mask, -float('inf'))

        # The softmax is run on the transposed scores and the result is
        # transposed back, to avoid
        # https://github.com/pytorch/pytorch/issues/4893
        weights = self.softmax(scores.transpose(1, 2))
        weights = weights.transpose(1, 2)

        # Dropout on the weights means each row no longer sums to 1. The t2t
        # code applies dropout the same way; whether this is the best place
        # for it is an open question.
        weights = self.dropout(weights)

        return torch.bmm(weights, v), weights


class MultiHeadAttention(jit.ScriptModule):
    """
    Multi-head self-attention sublayer.

    The input is projected into per-head queries, keys and values, run
    through scaled dot-product attention, projected back to ``d_model``
    features, and finally combined with the input through residual dropout,
    a residual connection and layer normalization.

    When ``d_positional`` is given the module is "partitioned": the first
    ``d_model - d_positional`` input features and the remaining
    ``d_positional`` features get separate projection weights, each
    producing half of every head's query/key/value, and the attention output
    is projected back to the two partitions separately.

    Args:
        n_head: number of attention heads.
        d_model: size of the last dimension of the input.
        d_k: per-head query/key size.
        d_v: per-head value size; must be even when partitioned.
        residual_dropout: rate handed to ``FeatureDropout`` for the sublayer
            output.
        attention_dropout: dropout rate for the attention weights.
        d_positional: number of trailing input features treated as the
            positional partition, or ``None`` for the unpartitioned variant.

    Raises:
        ValueError: if partitioned and ``d_v`` is odd (the two value halves
            would not add up to ``d_v`` and ``forward`` could never succeed).
    """
    # Only names that are actually assigned in __init__ are listed; the
    # never-assigned 'w_qs', 'w_ks' and 'w_vs' entries have been dropped.
    __constants__ = ['n_head', 'partitioned', 'd_v', 'd_content']

    def __init__(self, n_head, d_model, d_k, d_v,
                 residual_dropout=0.1, attention_dropout=0.1, d_positional=None):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.partitioned = d_positional is not None

        if self.partitioned:
            if d_v % 2 != 0:
                # Each partition produces d_v // 2 value features; with an odd
                # d_v the reshape to d_v features in forward() cannot succeed.
                raise ValueError(
                    'd_v must be even for partitioned attention, got %d' % d_v)

            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            # Content partition: first d_content input features.
            self.w_qs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_ks1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_k // 2))
            self.w_vs1 = nn.Parameter(torch.FloatTensor(n_head, self.d_content, d_v // 2))

            # Positional partition: remaining d_positional input features.
            self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_k // 2))
            self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, self.d_positional, d_v // 2))

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)
        else:
            self.d_content = d_model
            # The "*2" names refer to the very same parameter as the "*1"
            # names, so both attribute sets exist in either mode.
            self.w_qs1 = self.w_qs2 = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_ks1 = self.w_ks2 = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
            self.w_vs1 = self.w_vs2 = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

            init.xavier_normal_(self.w_qs1)
            init.xavier_normal_(self.w_ks1)
            init.xavier_normal_(self.w_vs1)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        if not self.partitioned:
            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            self.proj1 = self.proj2 = nn.Linear(n_head * d_v, d_model, bias=False)
        else:
            self.proj1 = nn.Linear(n_head * (d_v // 2), self.d_content, bias=False)
            self.proj2 = nn.Linear(n_head * (d_v // 2), self.d_positional, bias=False)

        self.residual_dropout = FeatureDropout(residual_dropout)

    @jit.script_method
    def split_qkv_packed(self, inp):
        # inp: [batch, length, feat]
        # Returns q, k, v, each of shape [n_head * batch, length, per-head feat],
        # laid out head-major along the first axis.
        batch_size, max_length, feature_count = inp.shape
        input_2d = inp.contiguous().view(-1, feature_count)
        # One copy of the flattened input per head so that a single bmm
        # applies every head's projection at once.
        inp_repeated = input_2d.unsqueeze(0).repeat(self.n_head, 1, 1)

        if not self.partitioned:
            q_s = torch.bmm(inp_repeated, self.w_qs1)  # n_head x len_inp x d_k
            k_s = torch.bmm(inp_repeated, self.w_ks1)  # n_head x len_inp x d_k
            v_s = torch.bmm(inp_repeated, self.w_vs1)  # n_head x len_inp x d_v
        else:
            # Project the two partitions separately and concatenate the
            # halves along the feature axis.
            q_s = torch.cat([
                torch.bmm(inp_repeated[:, :, :self.d_content], self.w_qs1),
                torch.bmm(inp_repeated[:, :, self.d_content:], self.w_qs2),
            ], -1)
            k_s = torch.cat([
                torch.bmm(inp_repeated[:, :, :self.d_content], self.w_ks1),
                torch.bmm(inp_repeated[:, :, self.d_content:], self.w_ks2),
            ], -1)
            v_s = torch.cat([
                torch.bmm(inp_repeated[:, :, :self.d_content], self.w_vs1),
                torch.bmm(inp_repeated[:, :, self.d_content:], self.w_vs2),
            ], -1)
        return q_s.view(self.n_head * batch_size, max_length, -1), \
            k_s.view(self.n_head * batch_size, max_length, -1), \
            v_s.view(self.n_head * batch_size, max_length, -1)

    @jit.script_method
    def combine_v(self, outputs):
        # Combine attention information from the different heads.
        # outputs: n_head x len_inp x d_v; returns a 2-D tensor with one row
        # per input position.
        n_head = self.n_head

        if not self.partitioned:
            # Switch from n_head x len_inp x d_v to len_inp x (n_head * d_v)
            outputs = torch.transpose(outputs, 0, 1).contiguous().view(-1, n_head * self.d_v)

            # Project back to residual size
            outputs = self.proj1(outputs)
        else:
            # The first half of each head's value features came from the
            # content partition, the second half from the positional one;
            # each half is projected back to its own partition.
            d_v1 = self.d_v // 2
            outputs1 = outputs[:, :, :d_v1]
            outputs2 = outputs[:, :, d_v1:]
            outputs1 = torch.transpose(outputs1, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs2 = torch.transpose(outputs2, 0, 1).contiguous().view(-1, n_head * d_v1)
            outputs = torch.cat([
                self.proj1(outputs1),
                self.proj2(outputs2),
            ], -1)

        return outputs

    @jit.script_method
    def forward(self, inp, sentence_mask):
        # inp: [batch, length, feat]
        # sentence_mask: [batch, length]; non-zero / True marks a valid
        # position, zero / False marks a position that must not be attended to.
        # Returns a tensor shaped like inp.
        batch_size, max_length, feature_count = inp.shape
        residual = inp

        # While still using a packed representation, project to obtain the
        # query/key/value for each head
        q_padded, k_padded, v_padded = self.split_qkv_packed(inp)

        # Comparing against zero inverts the mask for boolean as well as
        # integer masks; `1 - mask` is rejected for boolean tensors.
        invalid_mask = sentence_mask == 0
        # Broadcast over query positions, then tile once per head so that the
        # layout matches the head-major q/k/v tensors.
        attn_mask = invalid_mask.unsqueeze(1).expand(
            batch_size, max_length, max_length).repeat(self.n_head, 1, 1)

        outputs_padded, attns_padded = self.attention(
            q_padded, k_padded, v_padded,
            attn_mask=attn_mask
        )

        outputs = self.combine_v(outputs_padded.view(
            self.n_head, batch_size * max_length, self.d_v))

        outputs = outputs.view(batch_size, max_length, -1)

        outputs = self.residual_dropout(outputs)

        return self.layer_norm(outputs + residual)


class PositionwiseFeedForward(jit.ScriptModule):
    """
    Position-wise feed-forward sublayer.

    Applies ``w_1`` (``d_hid`` -> ``d_ff``), an activation (ReLU when
    ``leaky_relu_slope`` is 0.0, LeakyReLU with that slope otherwise),
    dropout, and ``w_2`` (``d_ff`` -> ``d_hid``). The result goes through
    residual dropout, is added to the input, and is layer-normalized.
    """

    def __init__(self, d_hid, d_ff, relu_dropout=0.1, residual_dropout=0.1,
                 leaky_relu_slope=0.1):
        super().__init__()

        self.w_1 = nn.Linear(d_hid, d_ff)
        self.w_2 = nn.Linear(d_ff, d_hid)

        self.layer_norm = nn.LayerNorm(d_hid, eps=1e-6)
        # The transformer paper only describes residual dropout, but the t2t
        # code on github also drops activations after the non-linearity.
        # Both are available here; either can be disabled with a rate of zero.
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        if leaky_relu_slope == 0.0:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(leaky_relu_slope)

    @jit.script_method
    def forward(self, x, sentence_mask):
        # sentence_mask is not used by this layer; the parameter is accepted
        # but ignored.
        hidden = self.relu(self.w_1(x))
        hidden = self.relu_dropout(hidden)

        projected = self.residual_dropout(self.w_2(hidden))
        return self.layer_norm(projected + x)


class PartitionedPositionwiseFeedForward(jit.ScriptModule):
    """
    Position-wise feed-forward sublayer with separate content and positional
    paths.

    The first ``d_hid - d_positional`` input features and the remaining
    ``d_positional`` features are each sent through their own
    linear -> activation -> dropout -> linear stack with a hidden size of
    ``d_ff // 2``. The two results are concatenated, passed through residual
    dropout, added to the input, and layer-normalized.
    """
    __constants__ = ['d_content']

    def __init__(self, d_hid, d_ff, d_positional, relu_dropout=0.1, residual_dropout=0.1,
                 leaky_relu_slope=0.1):
        super().__init__()
        half_ff = d_ff // 2
        self.d_content = d_hid - d_positional
        self.w_1c = nn.Linear(self.d_content, half_ff)
        self.w_1p = nn.Linear(d_positional, half_ff)
        self.w_2c = nn.Linear(half_ff, self.d_content)
        self.w_2p = nn.Linear(half_ff, d_positional)
        self.layer_norm = nn.LayerNorm(d_hid, eps=1e-6)
        # The transformer paper only describes residual dropout, but the t2t
        # code on github also drops activations after the non-linearity.
        # Both are available here; either can be disabled with a rate of zero.
        self.relu_dropout = FeatureDropout(relu_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)
        if leaky_relu_slope == 0.0:
            self.relu = nn.ReLU()
        else:
            self.relu = nn.LeakyReLU(leaky_relu_slope)

    @jit.script_method
    def forward(self, x, sentence_mask):
        # sentence_mask is not used by this layer; the parameter is accepted
        # but ignored.

        # Content path: leading d_content features.
        content = self.w_1c(x[:, :, :self.d_content])
        content = self.relu_dropout(self.relu(content))
        content = self.w_2c(content)

        # Positional path: trailing features.
        positional = self.w_1p(x[:, :, self.d_content:])
        positional = self.relu_dropout(self.relu(positional))
        positional = self.w_2p(positional)

        merged = torch.cat([content, positional], -1)
        merged = self.residual_dropout(merged)
        return self.layer_norm(merged + x)


class TransformerEncoder(jit.ScriptModule):
    """
    Stack of alternating multi-head attention and feed-forward sublayers
    with a position ("timing") signal mixed into the input.

    With ``options.d_positional`` unset, the timing signal has ``input_size``
    features and is added to the input. With it set, the timing signal has
    ``d_positional`` features and is concatenated to the input, and the
    partitioned attention / feed-forward variants are used.

    Attributes:
        output_size: size of the last dimension of the encoder output
            (``input_size`` plus ``d_positional`` when partitioned).
    """
    __constants__ = ['partitioned', 'num_layers', 'layers', 'use_timing_layer_norm']

    class Options(OptionsBase):
        # Number of (attention, feed-forward) sublayer pairs.
        num_layers: int = 2
        num_heads: int = 8
        # Per-head size used for queries, keys and values alike.
        d_kv: int = 32
        d_ff: int = 1024
        d_positional: Optional[int] = \
            argfield(None, help='Use partitioned transformer if it is not None')
        relu_dropout: float = 0.1
        residual_dropout: float = 0.1
        attention_dropout: float = 0.1
        timing_dropout: float = 0.0
        timing_method: str = argfield('embedding', choices=['embedding', 'sinusoidal'])
        use_timing_layer_norm: bool = False
        # Number of rows in the position table.
        max_length: int = 512
        leaky_relu_slope: float = 0.1

    def __init__(self, options: Options, input_size):
        super().__init__()

        d_positional = options.d_positional
        d_model = input_size + (d_positional or 0)
        d_k = d_v = options.d_kv
        d_ff = options.d_ff

        self.output_size = d_model
        self.use_timing_layer_norm = options.use_timing_layer_norm
        self.partitioned = d_positional is not None
        # Declared in __constants__, so it has to be assigned here.
        self.num_layers = options.num_layers

        attention_dropout = options.attention_dropout
        residual_dropout = options.residual_dropout
        relu_dropout = options.relu_dropout
        timing_dropout = options.timing_dropout
        leaky_relu_slope = options.leaky_relu_slope

        # Sublayers are stored flat: [attn_0, ff_0, attn_1, ff_1, ...].
        layers = []
        for i in range(options.num_layers):
            attn = MultiHeadAttention(options.num_heads, d_model, d_k, d_v,
                                      residual_dropout=residual_dropout,
                                      attention_dropout=attention_dropout,
                                      d_positional=d_positional)
            if d_positional is None:
                ff = PositionwiseFeedForward(d_model, d_ff,
                                             relu_dropout=relu_dropout,
                                             residual_dropout=residual_dropout,
                                             leaky_relu_slope=leaky_relu_slope)
            else:
                ff = PartitionedPositionwiseFeedForward(d_model, d_ff, d_positional,
                                                        relu_dropout=relu_dropout,
                                                        residual_dropout=residual_dropout,
                                                        leaky_relu_slope=leaky_relu_slope)

            layers.append(attn)
            layers.append(ff)

        self.layers = nn.ModuleList(layers)
        self.timing_dropout = FeatureDropout(timing_dropout)

        # Width of the position table: the positional partition when
        # partitioned, otherwise the full input width (the table is added).
        position_size = d_positional if d_positional is not None else input_size
        timing_method = options.timing_method
        max_length = options.max_length
        if timing_method == 'embedding':
            # Learned embeddings
            self.position_table = \
                nn.Parameter(torch.FloatTensor(max_length, position_size))
            init.normal_(self.position_table)
        else:
            assert timing_method == 'sinusoidal'

            # Fixed (non-trainable) sinusoidal table. The frequencies are
            # derived from the table's own width so that the partitioned
            # variant works when d_positional differs from input_size.
            position = torch.arange(0, max_length).view(-1).unsqueeze(-1).float()
            div_term = torch.exp((torch.arange(0, position_size, 2, dtype=torch.float) *
                                  -(math.log(10000.0) / position_size)))
            angles = position * div_term

            table = torch.zeros(max_length, position_size)
            table[:, 0::2] = torch.sin(angles)
            # For an odd width there is one fewer odd-indexed column than
            # even-indexed ones, hence the slice.
            table[:, 1::2] = torch.cos(angles[:, :position_size // 2])
            self.position_table = nn.Parameter(table, requires_grad=False)

        # Always constructed; only applied when use_timing_layer_norm is set.
        self.timing_layer_norm = nn.LayerNorm(position_size, eps=1e-6)

    @jit.script_method
    def forward(self, inputs, lengths_or_mask, use_mask: bool = False):
        # inputs: 3-D tensor; dim 0 and dim 1 are used as batch and sequence
        #   position, and dim 1 must not exceed the position table length.
        # lengths_or_mask: passed to sequence_mask() when use_mask is False
        #   (presumably per-sentence lengths -- confirm against that helper),
        #   otherwise used directly as the sentence mask.
        # Returns (encoded inputs, None); the second element is always None.
        if not use_mask:
            sentence_mask = sequence_mask(lengths_or_mask, inputs.shape[1])
        else:
            sentence_mask = lengths_or_mask

        # Take the first `length` rows of the table and replicate per batch.
        timing_signal = self.position_table[:inputs.shape[1], :]
        timing_signal = timing_signal.unsqueeze(0).repeat(inputs.shape[0], 1, 1)
        timing_signal = self.timing_dropout(timing_signal)

        if self.use_timing_layer_norm:
            timing_signal = self.timing_layer_norm(timing_signal)

        if self.partitioned:
            inputs = torch.cat([inputs, timing_signal], dim=-1)
        else:
            # Out-of-place add so that the caller's tensor is left untouched.
            inputs = inputs + timing_signal

        for layer in self.layers:
            inputs = layer(inputs, sentence_mask)

        return inputs, None
