# -*- coding: utf-8 -*-

from typing import List

import torch
import torch.nn as nn

from ...common.dataclass_options import BranchSelect, OptionsBase, argfield
from .dropout import FeatureDropout
from .sequential import LSTMLayer


class CharLSTMLayer(nn.Module):
    """Turn the character embeddings of every word into one vector per word.

    Each word's character sequence is run through an ``LSTMLayer`` whose
    hidden size is ``output_size // 2``; only the final (non-sequence)
    output is kept. The final output is expected to have ``2 * hidden_size``
    features (presumably a bidirectional LSTM -- confirm in ``.sequential``).
    """

    class Options(OptionsBase):
        num_layers: int = 2  # number of stacked LSTM layers
        dropout: float = 0  # dropout rate forwarded to the LSTM's input

    def __init__(self, options, input_size, output_size):
        super().__init__()

        # The LSTM hidden size is half of output_size, so it has to be even.
        assert output_size % 2 == 0
        half_size = output_size // 2

        lstm_options = LSTMLayer.Options()
        lstm_options.num_layers = options.num_layers
        lstm_options.hidden_size = half_size
        # The same rate is forwarded both as a dropout rate and as a keep
        # probability; which field LSTMLayer actually reads is not visible here.
        lstm_options.input_dropout = options.dropout
        lstm_options.input_keep_prob = 1 - options.dropout

        self.char_lstm = LSTMLayer(lstm_options, input_size)

    def forward(self, char_embeded_4d, char_lengths, reuse=False):
        """Encode characters into word vectors.

        :param char_embeded_4d: tensor of shape
            ``[batch_size, max_seq_length, max_characters, embed_size]``.
        :param char_lengths: per-word character counts; flattened to one
            value per word before being passed to the LSTM.
        :param reuse: unused; kept for interface compatibility.
        :return: tensor of shape ``[batch_size, max_seq_length, -1]``.
        """
        batch_size, max_seq_length, max_characters, embed_size = char_embeded_4d.shape

        # Fold the word axis into the batch axis: one LSTM sequence per word.
        flat_chars = char_embeded_4d.view(-1, max_characters, embed_size)
        flat_lengths = char_lengths.view(-1)

        # One vector per word: [batch_size * max_seq_length, 2 * hidden_size]
        word_vectors = self.char_lstm(flat_chars, flat_lengths, return_sequence=False)
        return word_vectors.view(batch_size, max_seq_length, -1)


# TODO: add a CNN-based character encoder as another CharacterEmbeddingOptions branch.

class CharacterEmbeddingOptions(BranchSelect):
    """Select the character encoder; only the LSTM ('rnn') branch exists so far."""
    type = 'rnn'
    branches = dict(rnn=CharLSTMLayer)


class SentenceEmbeddings(nn.Module):
    """Build the per-token input features of a sentence.

    Depending on ``Options`` the features are: word embeddings, character-based
    word vectors, embeddings of extra token properties and the outputs of
    plugin modules. They are merged by concatenation (``mode='concat'``) or
    element-wise addition (``mode='add'``), optionally followed by LayerNorm.
    """

    class Options(OptionsBase):
        word_size: int = 100  # word embedding size; 0 disables word embeddings
        extra_property_names: List[str] = argfield(default_factory=list)
        extra_property_sizes: List[int] = argfield(default_factory=list)

        word_dropout: float = 0.4
        extra_property_dropout: float = 0.2

        char_size: int = 0  # size of char-based word vectors; 0 disables them
        char_lookup_size: int = 50  # size of each character embedding
        character: CharacterEmbeddingOptions

        use_layer_norm: bool = True
        mode: str = argfield('concat', choices=['add', 'concat'])

        # When set, char-based vectors replace the word embedding of unknown
        # words instead of being added as a separate feature.
        replace_unk_with_chars: bool = False

        def create(self, vocabs, plugins=None):
            return SentenceEmbeddings(self, vocabs, plugins)

    def __init__(self, options: Options, vocabs, plugins=None):
        """Create the embedding tables and compute ``self.output_size``.

        :param options: ``SentenceEmbeddings.Options``.
        :param vocabs: object whose ``get(name)`` returns a vocabulary with
            ``len()``, ``pad_id`` and (for words) ``unk_id``.
        :param plugins: optional modules called on the inputs; each must
            expose ``output_size``.
        :raises ValueError: on inconsistent option combinations.
        :raises RuntimeError: in ``'add'`` mode when feature sizes differ.
        """
        super().__init__()

        self.options = options
        self.mode = options.mode
        self.replace_unk_with_chars = options.replace_unk_with_chars

        self.plugins = nn.ModuleList(plugins or [])

        input_sizes = {}
        if options.word_size != 0:  # setup word embedding
            vocab = vocabs.get('word')
            self.word_embeddings = \
                nn.Embedding(len(vocab), options.word_size, padding_idx=vocab.pad_id)

            # Token id whose positions receive char-based vectors when
            # ``replace_unk_with_chars`` is set (previously pad_id by mistake).
            self.word_unk_id = vocab.unk_id
            self.word_dropout = FeatureDropout(options.word_dropout)
            input_sizes['word'] = options.word_size
        else:
            self.word_embeddings = None

        if options.char_size > 0:  # setup char embedding
            char_size = options.char_size
            char_lookup_size = options.char_lookup_size

            vocab = vocabs.get('char')
            self.char_lookup = \
                nn.Embedding(len(vocab), char_lookup_size, padding_idx=vocab.pad_id)

            # NOTE(review): confirm that BranchSelect.create maps these arguments
            # onto CharLSTMLayer(options, input_size, output_size) correctly.
            self.char_embeded = options.character.create(char_lookup_size, options, char_size)

            if not options.replace_unk_with_chars:
                input_sizes['char'] = char_size
            elif char_size != options.word_size:
                # Replacement writes char vectors into word-embedding slots, so
                # word embeddings must exist and have the same size.
                raise ValueError(
                    'replace_unk_with_chars requires char_size == word_size '
                    f'(got char_size={char_size}, word_size={options.word_size})')
        else:
            self.char_lookup = None

        if len(options.extra_property_names) != len(options.extra_property_sizes):
            # zip() would silently drop the surplus, and forward() would then
            # fail on a missing embedding table.
            raise ValueError(
                'extra_property_names and extra_property_sizes differ in length: '
                f'{options.extra_property_names} vs {options.extra_property_sizes}')

        extra_embeddings = {}
        for name, size in zip(options.extra_property_names,
                              options.extra_property_sizes):  # setup extra embeddings
            vocab = vocabs.get(name)
            extra_embeddings[name] = nn.Embedding(len(vocab), size, padding_idx=vocab.pad_id)
            input_sizes[name] = size

        self.extra_property_names = options.extra_property_names
        if self.extra_property_names:
            self.extra_dropout = FeatureDropout(options.extra_property_dropout)

        for index, plugin in enumerate(self.plugins):
            input_sizes[f'plugin_{index}'] = plugin.output_size

        if options.mode == 'concat':
            self.output_size = sum(input_sizes.values())
        else:
            assert options.mode == 'add'
            uniq_input_dims = set(input_sizes.values())
            if len(uniq_input_dims) != 1:
                raise RuntimeError(f'Different input sizes {input_sizes} in mode "add"')
            self.output_size = uniq_input_dims.pop()

        self.extra_embeddings = nn.ModuleDict(extra_embeddings)
        self.layer_norm = \
            nn.LayerNorm(self.output_size, eps=1e-6) if options.use_layer_norm else None

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal initialize every embedding table that exists.

        The ``padding_idx`` row is reset to zeros afterwards, restoring the
        initial state ``nn.Embedding`` gives it.
        """
        tables = [self.word_embeddings, self.char_lookup, *self.extra_embeddings.values()]
        for embedding in tables:
            if embedding is None:  # the feature is disabled
                continue
            torch.nn.init.xavier_normal_(embedding.weight.data)
            if embedding.padding_idx is not None:
                embedding.weight.data[embedding.padding_idx].zero_()

    def forward(self, inputs):
        """Compute the merged input features.

        :param inputs: batch object providing ``words``, ``chars`` and
            ``word_lengths`` attributes (as enabled), item access by extra
            property name, and whatever the plugins read.
        :return: merged features whose last dimension is ``self.output_size``.
        """
        all_features = []

        word_embeded_by_char = None
        if self.char_lookup is not None:
            # shape: [batch_size, seq_length, word_length, char_lookup_size]
            char_embeded_4d = self.char_lookup(inputs.chars)
            word_embeded_by_char = self.char_embeded(char_embeded_4d, inputs.word_lengths)
            if not self.replace_unk_with_chars:
                all_features.append(word_embeded_by_char)

        if self.word_embeddings is not None:
            word_embedding = self.word_dropout(self.word_embeddings(inputs.words))
            if word_embeded_by_char is not None and self.replace_unk_with_chars:
                # Out-of-place replacement avoids mutating the dropout output.
                unk = inputs.words.eq(self.word_unk_id).unsqueeze(-1)
                word_embedding = torch.where(unk, word_embeded_by_char, word_embedding)

            all_features.append(word_embedding)

        for name in self.extra_property_names:
            all_features.append(self.extra_dropout(self.extra_embeddings[name](inputs[name])))

        for plugin in self.plugins:
            all_features.append(plugin(inputs))

        if self.mode == 'concat':
            input_embeddings = torch.cat(all_features, dim=-1)
        else:
            input_embeddings = sum(all_features)

        if self.layer_norm is not None:
            input_embeddings = self.layer_norm(input_embeddings)

        return input_embeddings
