# -*- coding: utf-8 -*-

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from ...common.dataclass_options import OptionsBase, argfield
from ..activations import get_activation
from ..utils import sequence_mask
from .dropout import FeatureDropout


class ScaledDotProductAttention(nn.Module):
    """Masked scaled dot-product attention over batched ``(q, k, v)`` tensors.

    Args:
        d_model: attention scores are divided by ``d_model ** 0.5``.
        attention_dropout: dropout rate applied to the attention weights.
    """

    def __init__(self, d_model, attention_dropout=0.1):
        super().__init__()

        self.temper = d_model ** 0.5
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, q, k, v, attention_mask):
        """Attend from ``q`` over ``k`` and aggregate ``v``.

        Args:
            q: queries ``[batch, length_q, d]``.
            k: keys ``[batch, length_kv, d]``.
            v: values ``[batch, length_kv, d_v]``.
            attention_mask: boolean mask with the same shape as the score
                matrix, ``[batch, length_q, length_kv]``; ``True`` marks key
                positions that a query must not attend to.

        Returns:
            ``(outputs, attn)``: attended values ``[batch, length_q, d_v]``
            and the attention weights ``[batch, length_q, length_kv]`` (after
            dropout). A query row whose keys are all masked gets all-zero
            weights, and hence a zero output, instead of NaN.
        """
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper

        assert attention_mask.size() == attn.size()

        attn = attn.masked_fill(attention_mask, -float('inf'))
        attn = F.softmax(attn, dim=2)

        # A row in which every key is masked is a softmax over only -inf,
        # which yields NaN. Zero such rows so the NaN cannot leak into the
        # outputs (and, through residual connections, the whole network).
        fully_masked = attention_mask.all(dim=2, keepdim=True)
        attn = attn.masked_fill(fully_masked, 0.0)

        # Note that this makes the distribution not sum to 1. At some point it
        # may be worth researching whether this is the right way to apply
        # dropout to the attention.
        # Note that the t2t code also applies dropout in this manner
        attn = self.dropout(attn)
        outputs = torch.bmm(attn, v)

        return outputs, attn


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention, optionally partitioned.

    When ``d_positional`` is given, the last ``d_positional`` features of each
    input vector form a separate positional partition. Content and positional
    features are projected by separate weights, each producing half of every
    head's ``d_k`` / ``d_v`` dimensions, and are projected back by separate
    output layers. Without ``d_positional`` a single set of weights is used
    (the ``*2`` attributes then alias the ``*1`` ones).

    Args:
        num_heads: number of attention heads.
        d_model: size of the input/output feature vectors.
        d_k: per-head query/key size (must be even when partitioned).
        d_v: per-head value size (must be even when partitioned).
        attention_dropout: dropout rate applied to the attention weights.
        d_positional: size of the positional partition, or ``None`` for an
            unpartitioned layer.
    """

    def __init__(self, num_heads, d_model, d_k, d_v, attention_dropout=0.1, d_positional=None):
        super().__init__()

        self.num_heads = num_heads
        self.d_k = d_k
        self.d_v = d_v

        self.partitioned = (d_positional is not None)

        if self.partitioned:
            # Each partition contributes half of every head's dimensions.
            assert d_k % 2 == 0 and d_v % 2 == 0
            self.d_content = d_model - d_positional
            self.d_positional = d_positional

            self.w_qs1 = nn.Parameter(torch.FloatTensor(num_heads, self.d_content, d_k // 2))
            self.w_ks1 = nn.Parameter(torch.FloatTensor(num_heads, self.d_content, d_k // 2))
            self.w_vs1 = nn.Parameter(torch.FloatTensor(num_heads, self.d_content, d_v // 2))

            self.w_qs2 = nn.Parameter(torch.FloatTensor(num_heads, self.d_positional, d_k // 2))
            self.w_ks2 = nn.Parameter(torch.FloatTensor(num_heads, self.d_positional, d_k // 2))
            self.w_vs2 = nn.Parameter(torch.FloatTensor(num_heads, self.d_positional, d_v // 2))

            self.proj1 = nn.Linear(num_heads * (d_v // 2), self.d_content, bias=False)
            self.proj2 = nn.Linear(num_heads * (d_v // 2), self.d_positional, bias=False)
        else:
            self.d_content = d_model
            self.w_qs2 = self.w_qs1 = nn.Parameter(torch.FloatTensor(num_heads, d_model, d_k))
            self.w_ks2 = self.w_ks1 = nn.Parameter(torch.FloatTensor(num_heads, d_model, d_k))
            self.w_vs2 = self.w_vs1 = nn.Parameter(torch.FloatTensor(num_heads, d_model, d_v))

            # The lack of a bias term here is consistent with the t2t code, though
            # in my experiments I have never observed this making a difference.
            self.proj2 = self.proj1 = nn.Linear(num_heads * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(d_model, attention_dropout=attention_dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize the per-head projection weights.

        The output projections keep ``nn.Linear``'s default initialization.
        """
        init.xavier_normal_(self.w_qs1)
        init.xavier_normal_(self.w_ks1)
        init.xavier_normal_(self.w_vs1)

        if self.partitioned:
            init.xavier_normal_(self.w_qs2)
            init.xavier_normal_(self.w_ks2)
            init.xavier_normal_(self.w_vs2)

    def split_qkv_packed(self, query, key=None, value=None):
        """Project query/key/value into per-head spaces.

        Args:
            query: ``[batch_size, length_q, d_model]``.
            key: ``[batch_size, length_kv, d_model]``; defaults to ``query``.
            value: ``[batch_size, length_kv, d_model]``; defaults to ``key``.

        Returns:
            ``(q_s, k_s, v_s)`` shaped ``[num_heads * batch_size, length_q, d_k]``,
            ``[num_heads * batch_size, length_kv, d_k]`` and
            ``[num_heads * batch_size, length_kv, d_v]``; row ``h * batch_size + b``
            holds head ``h`` of batch element ``b``.
        """
        num_heads = self.num_heads
        batch_size, length_q, d_model = query.size()

        # Flatten batch and time, then tile once per head for batched matmul.
        query = query.reshape(1, -1, d_model).repeat(num_heads, 1, 1)
        if key is None:
            key = query
            length_kv = length_q
        else:
            # Read the key length BEFORE flattening: afterwards dim 1 is
            # batch_size * length_kv, which would corrupt the final views.
            length_kv = key.size(1)
            key = key.reshape(1, -1, d_model).repeat(num_heads, 1, 1)

        if value is None:
            value = key
        else:
            value = value.reshape(1, -1, d_model).repeat(num_heads, 1, 1)

        if not self.partitioned:
            # num_heads x (batch_size * length_q) x d_k
            q_s = torch.bmm(query, self.w_qs1)
            # num_heads x (batch_size * length_kv) x d_k
            k_s = torch.bmm(key, self.w_ks1)
            # num_heads x (batch_size * length_kv) x d_v
            v_s = torch.bmm(value, self.w_vs1)
        else:
            # Content and positional features are projected separately and
            # concatenated along the per-head feature dimension.
            d_content = self.d_content
            q_s = torch.cat([
                torch.bmm(query[:, :, :d_content], self.w_qs1),
                torch.bmm(query[:, :, d_content:], self.w_qs2),
            ], -1)
            k_s = torch.cat([
                torch.bmm(key[:, :, :d_content], self.w_ks1),
                torch.bmm(key[:, :, d_content:], self.w_ks2),
            ], -1)
            v_s = torch.cat([
                torch.bmm(value[:, :, :d_content], self.w_vs1),
                torch.bmm(value[:, :, d_content:], self.w_vs2),
            ], -1)

        return (q_s.view(num_heads * batch_size, length_q, -1),
                k_s.view(num_heads * batch_size, length_kv, -1),
                v_s.view(num_heads * batch_size, length_kv, -1))

    def combine_v(self, outputs):  # [num_heads, batch_size * length, d_v]
        """Merge the heads and project back to ``[batch_size * length, d_model]``."""
        # Combine attention information from the different heads
        num_heads = self.num_heads

        if not self.partitioned:
            outputs = outputs.transpose(0, 1).contiguous().view(-1, num_heads * self.d_v)
            # Project back to residual size
            outputs = self.proj1(outputs)
        else:
            # First half of each head's value dims is content, second half positional.
            d_v1 = self.d_v // 2
            outputs = torch.cat([
                self.proj1(outputs[:, :, :d_v1]
                           .transpose(0, 1)
                           .contiguous().view(-1, num_heads * d_v1)),
                self.proj2(outputs[:, :, d_v1:]
                           .transpose(0, 1)
                           .contiguous().view(-1, num_heads * d_v1)),
            ], -1)

        return outputs

    def forward(self, key, value, query, mask):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor or None): set of `length_kv`
               key vectors ``(batch, length_kv, d_model)``; ``None`` uses ``query``.
           value (FloatTensor or None): set of `length_kv`
               value vectors ``(batch, length_kv, d_model)``; ``None`` uses ``key``.
           query (FloatTensor): set of `length_q`
               query vectors  ``(batch, length_q, d_model)``
           mask (BoolTensor): ``True`` marks key positions that must not be
               attended to, shaped ``(num_heads * batch, length_q, length_kv)``
               with row ``h * batch + b`` belonging to head ``h``, batch element ``b``.

        Returns:
           (FloatTensor, FloatTensor):
           * outputs context vectors ``(batch, length_q, d_model)``
           * attention weights of all heads ``(num_heads * batch, length_q, length_kv)``
        """
        num_heads = self.num_heads

        # query/key/value for each head
        q_padded, k_padded, v_padded = self.split_qkv_packed(query, key, value)

        outputs_padded, attn = self.attention(q_padded, k_padded, v_padded, attention_mask=mask)

        outputs = self.combine_v(outputs_padded.view(num_heads, -1, self.d_v))

        return outputs.view_as(query), attn


class SelfAttention(MultiHeadAttention):
    """Self-attention sub-layer wrapped with residual dropout and layer norm.

    If ``use_norm_after_input`` is set, the layer norm is applied to the input
    before attention and the residual sum is returned as-is. Otherwise the
    layer norm is applied to the residual sum.
    """

    def __init__(self, num_heads, d_model, d_k, d_v,
                 residual_dropout=0.1, attention_dropout=0.1, d_positional=None,
                 use_norm_after_input=False):
        super().__init__(num_heads, d_model, d_k, d_v,
                         attention_dropout=attention_dropout, d_positional=d_positional)

        self.use_norm_after_input = use_norm_after_input
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.residual_dropout = FeatureDropout(residual_dropout)

    def forward(self, inputs, input_mask):
        """Attend each position over the valid positions of its own sequence.

        ``input_mask`` is ``True`` for valid positions; it is inverted to form
        the attention mask and tiled over heads and query positions.
        """
        normalize_first = self.use_norm_after_input
        attended_in = self.layer_norm(inputs) if normalize_first else inputs

        seq_len = attended_in.size(1)
        padding = (~input_mask).unsqueeze(1)
        attention_mask = padding.repeat(self.num_heads, seq_len, 1)

        context, _ = super().forward(None, None, attended_in, attention_mask)
        summed = self.residual_dropout(context) + inputs

        return summed if normalize_first else self.layer_norm(summed)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward sub-layer applied independently at each position.

    The output is ``proj2(dropout(activation(proj1(x))))`` plus a residual
    connection. The layer norm is applied to the input first when
    ``use_norm_after_input`` is set, and to the residual sum otherwise.
    """

    def __init__(self, d_model, d_ff, activation,
                 output_dropout=0.1, residual_dropout=0.1,
                 use_norm_after_input=False):
        super().__init__()

        self.proj1 = nn.Linear(d_model, d_ff)
        self.proj2 = nn.Linear(d_ff, d_model)

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        # Dropout after the activation mirrors the t2t code (the paper only
        # describes residual dropout); a rate of zero disables it.
        self.output_dropout = FeatureDropout(output_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)

        self.activation = activation
        self.use_norm_after_input = use_norm_after_input

    def forward(self, inputs, input_mask):
        """Apply the sub-layer; ``input_mask`` is accepted but not used here."""
        hidden = self.layer_norm(inputs) if self.use_norm_after_input else inputs

        hidden = self.output_dropout(self.activation(self.proj1(hidden)))
        summed = self.residual_dropout(self.proj2(hidden)) + inputs

        if not self.use_norm_after_input:
            summed = self.layer_norm(summed)
        return summed


class PartitionedPositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with separate content/positional paths.

    The first ``d_model - d_positional`` features (content) and the remaining
    ``d_positional`` features (positional) each pass through their own
    two-layer network with hidden size ``d_ff // 2``. The two results are
    concatenated and added to the input as a residual. The layer norm is
    applied to the input first when ``use_norm_after_input`` is set, and to
    the residual sum otherwise.
    """

    def __init__(self, d_model, d_ff, activation, d_positional,
                 output_dropout=0.1, residual_dropout=0.1,
                 use_norm_after_input=False):
        super().__init__()

        d_content = d_model - d_positional
        d_half = d_ff // 2
        self.d_content = d_content

        self.proj1c = nn.Linear(d_content, d_half)
        self.proj1p = nn.Linear(d_positional, d_half)
        self.proj2c = nn.Linear(d_half, d_content)
        self.proj2p = nn.Linear(d_half, d_positional)

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        # Dropout after the activation mirrors the t2t code (the paper only
        # describes residual dropout); a rate of zero disables it.
        self.output_dropout = FeatureDropout(output_dropout)
        self.residual_dropout = FeatureDropout(residual_dropout)

        self.activation = activation
        self.use_norm_after_input = use_norm_after_input

    def _partition_ff(self, features, expand, contract):
        # expand -> activation -> output dropout -> contract, for one partition.
        return contract(self.output_dropout(self.activation(expand(features))))

    def forward(self, inputs, input_mask):
        """Apply the sub-layer; ``input_mask`` is accepted but not used here."""
        pre_norm = self.use_norm_after_input
        normed = self.layer_norm(inputs) if pre_norm else inputs

        boundary = self.d_content
        # Content path runs before the positional path (same dropout order).
        content_out = self._partition_ff(normed[:, :, :boundary], self.proj1c, self.proj2c)
        positional_out = self._partition_ff(normed[:, :, boundary:], self.proj1p, self.proj2p)

        merged = torch.cat([content_out, positional_out], -1)
        summed = self.residual_dropout(merged) + inputs

        return summed if pre_norm else self.layer_norm(summed)


class TransformerEncoder(nn.Module):
    """Stack of self-attention and feed-forward layers with a position signal.

    The position signal (learned ``'embedding'`` or fixed ``'sinusoidal'``) is
    either added to the inputs (``d_positional is None``) or concatenated to
    them as an extra ``d_positional``-wide partition. In the concatenated
    case the partitioned attention and feed-forward layers are used.
    ``output_size`` is the resulting feature size.
    """
    class Options(OptionsBase):
        num_layers: int = 2
        num_heads: int = 8
        d_kv: int = 32
        d_ff: int = 1024
        d_positional: Optional[int] = \
            argfield(None, help='Use partitioned transformer if it is not None')

        output_dropout: float = 0.1
        residual_dropout: float = 0.1
        attention_dropout: float = 0.1

        timing_dropout: float = 0.0
        timing_method: str = argfield('embedding', choices=['embedding', 'sinusoidal'])
        use_timing_layer_norm: bool = False

        max_length: int = 512

        activation: str = 'leaky_relu/0.1'
        use_norm_after_input: bool = False

    def __init__(self, options: Options, input_size):
        super().__init__()

        d_positional = options.d_positional
        # In partitioned mode the position signal is concatenated, widening the model.
        d_model = input_size + (d_positional or 0)
        d_k = d_v = options.d_kv
        d_ff = options.d_ff

        self.output_size = d_model
        self.use_timing_layer_norm = options.use_timing_layer_norm
        self.partitioned = d_positional is not None

        if d_positional is None:
            ff_class = PositionwiseFeedForward
            extra_args = {}
        else:
            ff_class = PartitionedPositionwiseFeedForward
            extra_args = {'d_positional': d_positional}

        use_norm_after_input = options.use_norm_after_input

        layers = []
        for _ in range(options.num_layers):
            # get_activation is called once per layer.
            activation = get_activation(options.activation)

            layers.append(SelfAttention(options.num_heads, d_model, d_k, d_v,
                                        residual_dropout=options.residual_dropout,
                                        attention_dropout=options.attention_dropout,
                                        d_positional=d_positional,
                                        use_norm_after_input=use_norm_after_input))
            layers.append(ff_class(d_model, d_ff, activation,
                                   output_dropout=options.output_dropout,
                                   residual_dropout=options.residual_dropout,
                                   use_norm_after_input=use_norm_after_input,
                                   **extra_args))

        self.layers = nn.ModuleList(layers)

        self.timing_dropout = FeatureDropout(options.timing_dropout)

        # An added signal must match the input width; a concatenated one has its own width.
        position_size = d_positional if d_positional is not None else input_size
        max_length = options.max_length

        timing_method = options.timing_method
        if timing_method == 'embedding':  # Learned embeddings
            self.position_table = nn.Parameter(torch.FloatTensor(max_length, position_size))
            init.normal_(self.position_table)
        else:
            assert timing_method == 'sinusoidal'
            # Fixed table, kept as a non-trainable parameter.
            self.position_table = nn.Parameter(
                self._sinusoidal_table(max_length, position_size), requires_grad=False)

        self.timing_layer_norm = nn.LayerNorm(position_size, eps=1e-6)

    @staticmethod
    def _sinusoidal_table(max_length, size):
        """Return a fixed ``[max_length, size]`` sinusoidal position table.

        Even columns hold ``sin`` and odd columns ``cos`` of
        ``position / 10000 ** (2i / size)``. The frequencies are derived from
        the table's own ``size`` (not the input size), so the table is correct
        for the partitioned case and for odd sizes.
        """
        table = torch.zeros(max_length, size)

        position = torch.arange(0, max_length, dtype=torch.float).view(-1, 1)
        div_term = torch.exp((torch.arange(0, size, 2, dtype=torch.float) *
                              -(math.log(10000.0) / size)))
        table[:, 0::2] = torch.sin(position * div_term)
        # For odd ``size`` there is one fewer odd column than even columns.
        table[:, 1::2] = torch.cos(position * div_term[:size // 2])
        return table

    def forward(self, inputs, lengths_or_mask, use_mask=False):
        """Encode a padded batch.

        Args:
            inputs: ``[batch, length, input_size]`` features.
            lengths_or_mask: per-sequence lengths (turned into a mask with
                ``sequence_mask``) or, when ``use_mask`` is true, a boolean
                mask of valid positions that is passed to every layer as-is.
            use_mask: whether ``lengths_or_mask`` is already a mask.

        Returns:
            ``(outputs, None)`` where ``outputs`` is ``[batch, length, output_size]``.

        Raises:
            ValueError: if ``length`` exceeds the position table's ``max_length``.
        """
        length = inputs.size(1)
        max_length = self.position_table.size(0)
        if length > max_length:
            raise ValueError(
                f'input length {length} exceeds max_length {max_length} of the position table')

        if not use_mask:
            input_mask = sequence_mask(lengths_or_mask, length)
        else:
            input_mask = lengths_or_mask

        # Repeat (rather than broadcast) so dropout can differ per batch element.
        timing_signal = self.position_table[:length, :]
        timing_signal = timing_signal.unsqueeze(0).repeat(inputs.size(0), 1, 1)
        timing_signal = self.timing_dropout(timing_signal)
        if self.use_timing_layer_norm:
            timing_signal = self.timing_layer_norm(timing_signal)

        if self.partitioned:
            inputs = torch.cat([inputs, timing_signal], -1)
        else:
            # NOTE: do not use +=, it will change original inputs
            inputs = inputs + timing_signal

        for layer in self.layers:
            inputs = layer(inputs, input_mask)

        return inputs, None
