# -*- coding: utf-8 -*-

from typing import Optional

import torch
import torch.nn as nn

from ...common.dataclass_options import ExistFile, OptionsBase, argfield
from ...data.file_reader import read_embedding_file
from ...data.vocab import lookup_words
from ..input_plugin_base import InputPluginBase
from ..utils import pad_and_stack_1d


class ExternalEmbeddingPlugin(InputPluginBase):
    """Input plugin that looks up frozen, pre-trained word vectors.

    Vectors are loaded from ``options.filename`` with ``read_embedding_file``
    and wrapped in a frozen ``nn.Embedding``. During sample postprocessing each
    sample's words are mapped to row indices (stored under
    ``'words_pretrained'``); ``forward`` turns the batched indices into
    vectors, optionally followed by a trainable linear projection to
    ``options.project_to`` dimensions.
    """

    class Options(OptionsBase):
        # Path of the embedding file read by initialize().
        filename: Optional[ExistFile] = argfield(None, active_time='both')
        # When truthy, add a trainable nn.Linear mapping the pre-trained
        # vectors to this many dimensions; otherwise no projection is used.
        project_to: Optional[int] = None
        # Text encoding handed to read_embedding_file.
        encoding: str = 'utf-8'
        # Lower-case sample words before looking them up.
        lower: bool = False

        def create(self, *args, **kwargs):
            """Build an ExternalEmbeddingPlugin configured by these options."""
            return ExternalEmbeddingPlugin(self, *args, **kwargs)

    def __init__(self, options: Options, vocab=None):
        super().__init__()
        self.options = options
        self.initialize(vocab)

    def extra_repr(self):
        # The frozen table is not a registered submodule (it lives in a plain
        # list), so nn.Module's default repr would not show it; print it here.
        pretrained = self.embedding_[0]
        return repr(pretrained)

    def initialize(self, vocab):
        """Load the embedding file and build the lookup table and layers.

        Sets ``self.lookup`` (word -> row index), ``self.embedding_``,
        ``self.projection`` and ``self.output_size``.

        :param vocab: forwarded to ``read_embedding_file`` as ``vocab=``;
            presumably restricts which words are kept -- TODO confirm.
        """
        opts = self.options
        # add_unk=True presumably appends an unknown-word row -- TODO confirm
        # against read_embedding_file.
        vocab_words, table = read_embedding_file(
            opts.filename, opts.encoding,
            tensor_fn=torch.tensor, vocab=vocab, add_unk=True)

        # word -> row index; a word listed more than once keeps its last row.
        index_of = {}
        for row, token in enumerate(vocab_words):
            index_of[token] = row
        self.lookup = index_of

        pretrained = nn.Embedding.from_pretrained(table, freeze=True)
        # Wrapped in a plain list so nn.Module does not register it as a
        # submodule: it stays out of parameters()/state_dict() and is not
        # moved by .to()/.cuda(), which is why forward() runs it on CPU.
        self.embedding_ = [pretrained]

        source_dim = pretrained.embedding_dim
        target_dim = opts.project_to
        self.projection = nn.Linear(source_dim, target_dim) if target_dim else None
        self.output_size = target_dim if target_dim else source_dim

    def postprocess_sample(self, sample, sos_and_eos=False, **_kwargs):
        """Store the pre-trained row indices of ``sample.words``.

        The value returned by ``lookup_words`` is written to
        ``sample.attrs['words_pretrained']``; ``postprocess_batch`` converts
        it with ``torch.from_numpy``, so it is expected to be a numpy array.
        """
        raw = sample.words
        tokens = [token.lower() for token in raw] if self.options.lower else raw
        indices = lookup_words(tokens, self.lookup, len(tokens),
                               sos_and_eos=sos_and_eos)
        sample.attrs['words_pretrained'] = indices

    def postprocess_batch(self, batch_samples, inputs, **_kwargs):
        """Collect per-sample index arrays into ``inputs['words_pretrained']``.

        Each sample's array becomes a tensor via ``torch.from_numpy`` and the
        list is combined by ``pad_and_stack_1d`` (presumably padding to a
        common length before stacking -- TODO confirm).
        """
        sequences = [torch.from_numpy(item.attrs['words_pretrained'])
                     for item in batch_samples]
        inputs['words_pretrained'] = pad_and_stack_1d(sequences)

    def forward(self, inputs):
        """Embed ``inputs.words_pretrained`` and optionally project it.

        The lookup runs on CPU (the frozen table is never moved) and the
        result is moved back to the device of the index tensor. The output
        has the index tensor's shape plus a trailing feature axis of size
        ``self.output_size``.
        """
        indices = inputs.words_pretrained
        target_device = indices.device
        vectors = self.embedding_[0](indices.cpu()).to(target_device)
        if self.projection is None:
            return vectors
        return self.projection(vectors)
