# -*- coding: utf-8 -*-

from typing import Any

import torch

from ..common.utils import DotDict
from ..torch_extra.utils import pad_and_stack_1d
from .dataset import DataFeatures
from .vocab import Tokens, lookup_characters, lookup_words


def _iter_items(container):
    yield from (container.items() if isinstance(container, dict) else container)


def collect_sentence_vocabs(samples, vocabs, sos_and_eos=False):
    """Populate *vocabs* with the tokens found in *samples*.

    A new ``'word'`` vocabulary is created and updated with every sample's
    ``words``. When a sample exposes ``chars``, they are added to the
    ``'char'`` vocabulary; when it exposes ``extra_properties`` (a dict or
    an iterable of ``(name, values)`` pairs), each property is added to a
    vocabulary registered under its own name.

    Args:
        samples: Iterable of objects with a ``words`` attribute and
            optional ``chars`` / ``extra_properties`` attributes.
        vocabs: Vocabulary collection providing ``new`` and ``get_or_new``.
        sos_and_eos: When true, ``Tokens.SOS`` and ``Tokens.EOS`` are
            appended to the initial tokens of the word and extra-property
            vocabularies.
    """
    base_tokens = (Tokens.PAD, Tokens.UNK)
    initial = (base_tokens + (Tokens.SOS, Tokens.EOS)) if sos_and_eos else base_tokens

    word_vocab = vocabs.new('word', initial=initial)

    for sample in samples:
        word_vocab.update(sample.words)

        sample_chars = getattr(sample, 'chars', None)
        if sample_chars is not None:
            # The character vocabulary is created without initial tokens.
            vocabs.get_or_new('char').update(sample_chars)

        properties = getattr(sample, 'extra_properties', None)
        if properties is None:
            continue
        for name, values in _iter_items(properties):
            vocabs.get_or_new(name, initial=initial).update(values)


class SentenceFeatures(DataFeatures):
    """Indexed features of one sentence plus the logic to batch them.

    ``create`` converts a raw sentence object into vocabulary indices and
    ``pack_to_batch`` pads and stacks a list of such samples into tensors.
    """

    # Result of ``lookup_words`` for the sentence's words
    # (presumably a 1-D numpy array, since it is fed to ``torch.from_numpy``).
    words: Any = None
    # Character indices from ``lookup_characters``; stays None when no
    # 'char' vocabulary is available.
    chars: Any = None
    # Mapping of property name -> result of ``lookup_words`` for that property.
    extra_properties: Any = None

    # Number of words, plus 2 when SOS/EOS markers are added.
    sentence_length: int = -1
    # Lengths returned by ``lookup_characters(..., return_lengths=True)``.
    word_lengths: Any = None

    @classmethod
    def create(cls, original_index, original_object, plugins, statistics,
               sos_and_eos=True, lower_case=False):
        """Build a sample from *original_object*.

        Args:
            original_index: Index forwarded to the sample constructor.
            original_object: Raw sentence; must expose ``words`` and may
                expose ``extra_properties`` (dict or iterable of pairs).
            plugins: Forwarded to ``run_plugins_for_sample``.
            statistics: Vocabulary collection queried with ``get(name)``.
            sos_and_eos: Whether SOS/EOS markers are requested from the
                lookup helpers (and counted in ``sentence_length``).
            lower_case: Lower-case every word before lookup.

        Returns:
            The populated sample.
        """
        # NOTE(review): the default here is sos_and_eos=True while
        # collect_sentence_vocabs defaults to False; callers presumably pass
        # the flag explicitly to keep both in sync — confirm.
        sample = cls(original_index=original_index, original_object=original_object)

        words = [(word.lower() if lower_case else word) for word in original_object.words]

        sample.sentence_length = len(words) + (2 if sos_and_eos else 0)
        sample.words = lookup_words(words, statistics.get('word'), sos_and_eos=sos_and_eos)

        # Character features are optional: only computed when a 'char'
        # vocabulary exists.
        chars = statistics.get('char')
        if chars is not None:
            sample.word_lengths, sample.chars = \
                lookup_characters(words, chars, sos_and_eos=sos_and_eos, return_lengths=True)

        sample.extra_properties = {}
        extra_properties = getattr(original_object, 'extra_properties', None)
        if extra_properties is not None:
            for name, extra_property in _iter_items(extra_properties):
                sample.extra_properties[name] = \
                    lookup_words(extra_property, statistics.get(name), sos_and_eos=sos_and_eos)

        cls.run_plugins_for_sample(sample, plugins, sos_and_eos=sos_and_eos)
        return sample

    @classmethod
    def pack_to_batch(cls, batch_samples, plugins, statistics):
        """Pad and stack *batch_samples* into a ``DotDict`` of tensors.

        The result holds ``words`` and ``encoder_lengths``; ``chars`` and
        ``word_lengths`` when the first sample carries character features;
        and one entry per extra-property name of the first sample.

        Args:
            batch_samples: Non-empty sequence of samples (the first one
                decides which optional features are packed).
            plugins: Forwarded to ``run_plugins_for_batch``.
            statistics: Unused here; kept for interface compatibility.

        Returns:
            The ``DotDict`` of batched inputs.
        """
        inputs = DotDict()

        inputs.words = pad_and_stack_1d(
            [torch.from_numpy(sample.words) for sample in batch_samples]
        )
        inputs.encoder_lengths = \
            torch.tensor([sample.sentence_length for sample in batch_samples])

        first_sample = batch_samples[0]

        if first_sample.chars is not None:
            inputs.chars = pad_and_stack_1d(
                [torch.from_numpy(sample.chars) for sample in batch_samples]
            )
            inputs.word_lengths = pad_and_stack_1d(
                [torch.from_numpy(sample.word_lengths) for sample in batch_samples]
            )

        # ``extra_properties`` defaults to None at class level; treat that
        # the same as "no extra properties".
        for name in (first_sample.extra_properties or ()):
            # Convert each sample's array individually; from_numpy accepts a
            # single ndarray, not a generator over the whole batch.
            inputs[name] = pad_and_stack_1d(
                [torch.from_numpy(sample.extra_properties[name]) for sample in batch_samples]
            )

        cls.run_plugins_for_batch(batch_samples, inputs, plugins)
        return inputs
