# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from ...common.dataclass_options import BranchSelect, OptionsBase
from ...common.logger import LOGGER
from ..activations import get_activation
from ..utils import sort_and_pack_sequences, unpack_and_unsort_sequences
from .transformer_encoder import TransformerEncoder


def make_mlp_layers(input_size, output_size, hidden_sizes=(),
                    dropout=0, activation='relu',
                    use_layer_norm=False,
                    use_last_bias=True,
                    use_last_dropout=False,
                    use_last_layer_norm=False,
                    use_last_activation=False):
    """Build a feed-forward stack as an ``nn.Sequential``.

    One ``nn.Linear`` is created for every consecutive pair of sizes in
    ``[input_size, *hidden_sizes, output_size]``. Each linear layer may be
    followed, in this order, by ``nn.LayerNorm`` (when ``use_layer_norm``),
    ``nn.Dropout`` (when ``dropout`` is non-zero) and the module returned by
    ``get_activation(activation)``.

    Hidden layers always receive a bias, an activation and every enabled
    extra. The final layer receives each of them only when the matching
    ``use_last_*`` switch is set.

    Linear weights are Xavier-normal initialized; biases are zeroed.
    """
    dims = [input_size, *hidden_sizes, output_size]
    final_index = len(dims) - 2
    layers = []

    for position, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
        is_final = position == final_index
        # Hidden layers get everything unconditionally; the final layer defers
        # to the corresponding ``use_last_*`` switch.
        with_bias = use_last_bias if is_final else True
        with_norm = use_last_layer_norm if is_final else True
        with_dropout = use_last_dropout if is_final else True
        with_activation = use_last_activation if is_final else True

        dense = nn.Linear(fan_in, fan_out, with_bias)
        init.xavier_normal_(dense.weight)
        if with_bias:
            init.zeros_(dense.bias)
        layers.append(dense)

        if use_layer_norm and with_norm:
            layers.append(nn.LayerNorm(fan_out))
        if dropout != 0 and with_dropout:
            layers.append(nn.Dropout(dropout))
        if with_activation:
            layers.append(get_activation(activation))

    return nn.Sequential(*layers)


class ModuleRef(nn.Module):
    """Non-owning proxy to another module.

    The target is stored inside a tuple, so ``nn.Module.__setattr__`` does not
    register it as a submodule. As a result, its parameters are not returned by
    ``self.parameters()`` and it is not listed among this module's children.
    Calls and unknown attribute lookups are forwarded to the target.
    """

    def __init__(self, name, model):
        super().__init__()

        self.name = name  # label shown by extra_repr
        # Wrapping in a tuple hides the module from nn.Module's registration.
        self._ref = (model, )

    def extra_repr(self):
        return f'&{self.name}'

    def forward(self, *args, **kwargs):
        return self._ref[0](*args, **kwargs)

    def __getattr__(self, name):
        # Read `_ref` straight from the instance dict. Normal attribute access
        # would re-enter __getattr__ and recurse forever on an instance whose
        # __init__ has not run (e.g. one created via __new__).
        ref = self.__dict__.get('_ref')
        if ref is None:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(ref[0], name)


class LSTMLayer(nn.Module):
    """Stacked (optionally bidirectional) recurrent encoder over padded batches.

    ``DEFAULT_CELL`` selects the torch recurrent module to build; subclasses
    override it to swap the cell type.
    """
    DEFAULT_CELL = torch.nn.LSTM

    class Options(OptionsBase):
        hidden_size: int = 500  # RNN dimension per direction
        num_layers: int = 2
        # Keep probability between stacked layers; passed as dropout = 1 - this.
        input_keep_prob: float = 0.5
        # Must stay 1: recurrent dropout is not supported (checked in __init__).
        recurrent_keep_prob: float = 1
        use_layer_norm: bool = False  # LayerNorm over the per-step outputs
        input_dropout: float = 0  # dropout applied to the input sequences
        bidirectional: bool = True

    def __init__(self, options: Options, input_size):
        """Build the RNN described by ``options`` for ``input_size`` features.

        Raises:
            NotImplementedError: if ``options.recurrent_keep_prob`` is not 1.
        """
        super().__init__()

        if options.recurrent_keep_prob != 1.0:
            raise NotImplementedError('Pytorch RNN does not support recurrent dropout.')

        hidden_size = options.hidden_size
        num_directions = 2 if options.bidirectional else 1
        self.rnn = self.DEFAULT_CELL(input_size=input_size,
                                     hidden_size=hidden_size,
                                     num_layers=options.num_layers,
                                     batch_first=True,
                                     dropout=1 - options.input_keep_prob,
                                     bidirectional=options.bidirectional)
        # Feature size of both the per-step outputs and the returned h_n.
        self.output_size = hidden_size * num_directions

        if options.use_layer_norm:
            LOGGER.warning('Pytorch RNN only support layer norm at last layer.')
            # Normalize over the real output width: hidden_size * num_directions,
            # which is only hidden_size * 2 in the bidirectional case.
            self.layer_norm = nn.LayerNorm(self.output_size)
        else:
            self.layer_norm = None

        input_dropout = options.input_dropout
        if input_dropout != 0:
            self.input_dropout = nn.Dropout(input_dropout)
        else:
            self.input_dropout = None

        self.reset_parameters()

    def reset_parameters(self):
        """Zero every RNN bias and orthogonally initialize every RNN weight."""
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                param.data.fill_(0)
            else:
                init.orthogonal_(param.data)

    def forward(self, seqs, lengths, return_sequence=True, is_sorted=False):
        """Encode a padded, batch-first batch of sequences.

        Args:
            seqs: padded input with the batch dimension first (the RNN is
                built with ``batch_first=True``).
            lengths: per-sequence lengths, forwarded to
                ``sort_and_pack_sequences``.
            return_sequence: when True, also return the per-step outputs.
            is_sorted: forwarded to ``sort_and_pack_sequences``; when True,
                the final states are not re-ordered by ``unsort_indices``.

        Returns:
            ``(outputs, h_n)`` if ``return_sequence`` else ``h_n``. ``h_n`` is
            the top layer's final state with the directions concatenated, shape
            ``(batch, output_size)``.
        """
        if self.input_dropout is not None:
            seqs = self.input_dropout(seqs)

        packed_seqs, unsort_indices = sort_and_pack_sequences(seqs, lengths, is_sorted=is_sorted)
        outputs_pack, state_n = self.rnn(packed_seqs)

        h_n = self._last_layer_state(state_n, seqs.size(0))
        if not is_sorted:
            h_n = h_n.index_select(0, unsort_indices)

        if not return_sequence:
            return h_n

        outputs = unpack_and_unsort_sequences(outputs_pack, unsort_indices, seqs.shape)
        if self.layer_norm is not None:
            outputs = self.layer_norm(outputs)
        return outputs, h_n

    def _last_layer_state(self, state_n, total_batch):
        """Return the top layer's final hidden state as ``(total_batch, -1)``."""
        # LSTM returns (h_n, c_n); GRU returns h_n alone.
        if isinstance(state_n, tuple) and len(state_n) == 2:
            h_n = state_n[0]
        else:
            h_n = state_n

        batch_size = h_n.size(1)
        # [num_layers * num_directions, batch, hidden]
        #   -> [num_layers, num_directions, batch, hidden]
        h_n = h_n.view(self.rnn.num_layers, -1, *h_n.shape[1:])
        # Keep the top layer and lay the directions side by side per row.
        h_n = h_n[-1].transpose(0, 1).contiguous().view(batch_size, -1)

        # The packed batch can be smaller than the padded one (presumably
        # sort_and_pack_sequences drops empty sequences -- TODO confirm);
        # append zero rows so the batch matches the input again.
        if batch_size < total_batch:
            h_n = F.pad(h_n, [0, 0, 0, total_batch - batch_size])
        return h_n


class GRULayer(LSTMLayer):
    """Same encoder as ``LSTMLayer``, but backed by ``torch.nn.GRU``."""
    DEFAULT_CELL = torch.nn.GRU


class ConvolutionLayer(nn.Module):
    """Stack of length-preserving 1-D convolutions with ReLU.

    Operates on padded, batch-first sequences and produces ``channels``
    features per step. When ``lengths`` is given, padded steps are zeroed
    before the first layer and after every layer. This stops them from
    leaking into valid positions through the convolution window.
    """

    class Options(OptionsBase):
        num_layers: int = 5
        kernel_size: int = 3  # must be odd so the "same" padding is symmetric
        channels: int = 300
        input_dropout: float = 0.1  # feature-wise dropout on the input features
        hidden_dropout: float = 0.1  # feature-wise dropout after each conv layer

    def __init__(self, options, input_size):
        """Build ``options.num_layers`` conv layers for ``input_size`` features.

        Raises:
            ValueError: if ``options.kernel_size`` is even.
        """
        super().__init__()

        kernel_size = options.kernel_size
        channels = options.channels
        # An odd kernel with (k - 1) // 2 padding keeps the length unchanged.
        if kernel_size % 2 != 1:
            raise ValueError(f'kernel_size must be odd, got {kernel_size}')

        self.conv = nn.ModuleList([
            nn.Conv1d(input_size if i == 0 else channels, channels,
                      kernel_size,
                      padding=(kernel_size - 1) // 2)
            for i in range(options.num_layers)])

        input_dropout = options.input_dropout
        hidden_dropout = options.hidden_dropout
        self.input_dropout = nn.Dropout(input_dropout) if input_dropout > 0 else None
        # Applied after every conv layer.
        self.layer_dropout = nn.Dropout(hidden_dropout) if hidden_dropout > 0 else None

        self.output_size = channels

    @staticmethod
    def _length_mask(lengths, max_length, device, dtype):
        """Return a ``(batch, 1, max_length)`` 0/1 mask of valid steps, or None."""
        if lengths is None:
            return None
        lengths = torch.as_tensor(lengths, device=device)
        steps = torch.arange(max_length, device=device)
        return (steps.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1).to(dtype)

    def forward(self, seqs, lengths, _return_sequence=True, _is_sorted=False):
        """Apply the convolution stack.

        Args:
            seqs: padded input of shape ``(batch, length, input_size)``.
            lengths: valid length of each sequence; steps at or beyond it are
                zeroed. ``None`` disables masking.
            _return_sequence, _is_sorted: accepted for call compatibility with
                the recurrent layers; ignored.

        Returns:
            Tensor of shape ``(batch, length, channels)``.
        """
        device = seqs.device
        batch_size, max_length, feature_count = seqs.shape
        mask = self._length_mask(lengths, max_length, device, seqs.dtype)

        # Out-of-place ops only: the transposed view shares storage with the
        # caller's `seqs`, so in-place multiplication would modify it.
        outputs_NCL = seqs.transpose(1, 2)  # N x Channel x Length
        if mask is not None:
            outputs_NCL = outputs_NCL * mask
        if self.input_dropout is not None:
            # One dropout draw per (sample, feature), shared across all steps.
            noise = torch.ones((batch_size, feature_count, 1), device=device, dtype=seqs.dtype)
            outputs_NCL = outputs_NCL * self.input_dropout(noise)

        for layer in self.conv:
            outputs_NCL = layer(outputs_NCL)
            if self.layer_dropout is not None:
                noise = torch.ones((batch_size, outputs_NCL.size(1), 1),
                                   device=device, dtype=outputs_NCL.dtype)
                outputs_NCL = outputs_NCL * self.layer_dropout(noise)
            outputs_NCL = F.relu(outputs_NCL)
            if mask is not None:
                # The bias makes padded steps non-zero; reset them so the next
                # kernel does not read them.
                outputs_NCL = outputs_NCL * mask

        return outputs_NCL.transpose(1, 2).contiguous()


class ContextualOptions(BranchSelect):
    """Choice of contextual encoder.

    ``type`` names one of the keys of ``branches`` (``'lstm'`` here);
    presumably ``BranchSelect`` uses it to pick the encoder class -- confirm.
    """
    type = 'lstm'
    branches = dict(
        lstm=LSTMLayer,
        gru=GRULayer,
        conv=ConvolutionLayer,
        transformer=TransformerEncoder,
    )
