# -*- coding: utf-8 -*-

import functools
import os
from abc import abstractmethod

import torch.nn as nn

from ..common.dataclass_options import LOGGER, ExistFile, argfield
from ..data.bucket import BucketOptions
from ..data.vocab import VocabularySet, smart_remove_low_frequency_words
from .layers.sentence import SentenceEmbeddings
from .layers.sequential import ContextualOptions
from .model_base import ModelBase, RestoreToBestSignal
from .pretrained import ExternalEmbeddingPlugin, PretainedPluginOptions


def filter_vocab_by_threshold(vocab, threshold):
    """Return `vocab` with rare words removed according to `threshold`.

    * ``threshold > 1``: ``vocab.copy_without_low_frequency(threshold)``,
      keeping the original vocabulary ``name``.
    * ``threshold < 1``: delegated to ``smart_remove_low_frequency_words``.
    * otherwise (``threshold == 1``): `vocab` itself is returned untouched.
    """
    if threshold > 1:
        return vocab.copy_without_low_frequency(threshold, name=vocab.name)
    if threshold < 1:
        return smart_remove_low_frequency_words(vocab, threshold)
    # Neither branch applies: nothing to filter.
    return vocab


class ParserNetworkBase(nn.Module):
    """Shared front end of a parser network: input embeddings + encoder.

    ``hyper_params.sentence_embedding`` builds the embedding layer from the
    vocabularies and plugins; ``hyper_params.encoder`` then builds an encoder
    whose input size is that layer's ``output_size``.
    """

    def __init__(self, hyper_params, vocabs, plugins):
        super().__init__()

        embeddings = hyper_params.sentence_embedding.create(vocabs, plugins)
        self.input_embeddings = embeddings

        # The encoder consumes whatever width the embedding layer emits.
        encoder_input_size = embeddings.output_size
        self.encoder = hyper_params.encoder.create(encoder_input_size)


class ParserBase(ModelBase):
    """Shared vocabulary, batching, output and evaluation logic for parsers.

    Relies on attributes not defined here (presumably supplied by subclasses):
    ``NETWORK_CLASS`` (built in `build_network`), ``SAMPLE_CLASS`` (reads data
    files and evaluates outputs) and ``FEATURES_CLASS`` (preprocesses samples
    during bucketing). Subclasses must implement `write_outputs`, and need to
    override `build_vocabs` if vocabularies are to be built from scratch.
    """

    class HyperParams(ModelBase.HyperParams):
        # Options used to bucket samples in `create_buckets`.
        bucket: BucketOptions

        sentence_embedding: SentenceEmbeddings.Options
        # Only turned into a plugin when its `filename` is not None.
        external_embedding: ExternalEmbeddingPlugin.Options
        # Only turned into a plugin when its `type` is not 'none'.
        pretrained_encoder: PretainedPluginOptions

        encoder: ContextualOptions

        word_threshold: float = \
            argfield(1, help='If it is less than 1, '
                             'remove low-frequency words by min coverage, '
                             'otherwise by min frequency')

    class Options(ModelBase.Options):
        vocab_path: ExistFile

        input_format: str = argfield('standard', active_time='predict')
        evaluate: bool = argfield(False, active_time='predict')

    def __init__(self, options: Options, training_session=None):
        super().__init__(options, training_session)

        # How many more times `after_eval_hook` may request a rollback to the
        # best state before it asks to stop training instead.
        self._restore_count = 5

    def build_vocabs(self):
        """Build vocabularies from scratch.

        Called by `load_or_build_vocabs` when `options.vocab_path` does not
        exist. This base version cannot build anything and always raises.

        Raises:
            Exception: always.
        """
        raise Exception('`vocab_path` should be a valid vocabulary file')

    def build_network(self):
        """Create the plugin list and instantiate ``NETWORK_CLASS``.

        An external-embedding plugin is added when its filename is configured
        (it receives the 'word' vocabulary), and a pretrained-encoder plugin
        when its type is not 'none'. The list is kept on ``self.plugins``
        because `iter_batches` passes it to batch generation.
        """
        self.plugins = []

        embedding_options = self.hyper_params.external_embedding
        if embedding_options.filename is not None:
            self.plugins.append(embedding_options.create(self.statistics.get('word')))

        pretrained_encoder = self.hyper_params.pretrained_encoder
        if pretrained_encoder.type != 'none':
            self.plugins.append(pretrained_encoder.create())

        return self.NETWORK_CLASS(self.hyper_params, self.statistics, self.plugins)

    def load_or_build_vocabs(self):
        """Load vocabularies from `options.vocab_path`, building them if absent.

        When the file does not exist, `build_vocabs` is called, the 'word'
        vocabulary is filtered with `filter_vocab_by_threshold` using
        ``word_threshold``, and the result is written to `options.vocab_path`
        so later runs can load it.

        Returns:
            The `VocabularySet` in use.
        """
        options = self.options
        logger = self.logger

        if os.path.exists(options.vocab_path):
            logger.info('Loading built vocabs ...')
            vocabs = VocabularySet.from_file(options.vocab_path)
        else:
            logger.info('Building vocabs ...')
            vocabs = self.build_vocabs()
            word_threshold = options.hyper_params.word_threshold
            vocabs.set(filter_vocab_by_threshold(vocabs.get('word'), word_threshold))
            vocabs.to_file(options.vocab_path)

        logger.info('statistics:\n%s', vocabs)
        return vocabs

    def initialize(self, saved_state):
        """Set up statistics, network and (when training) optimizer/scheduler.

        Args:
            saved_state: a previously saved state dict, or None for a fresh
                start. A fresh start is only allowed while training.
        """
        training = self.options.training
        assert training or saved_state is not None

        if saved_state is None:  # init
            self.statistics = self.load_or_build_vocabs()
        else:
            self.statistics = VocabularySet()
            self.statistics.load_state_dict(saved_state['statistics'])

        self.network = self.build_network()

        if training:
            trainable_params = self.get_trainable_params()
            self.optimizer, self.scheduler = self.get_optimizer_and_scheduler(trainable_params)
            self.logger.info('optimizer: %s', self.optimizer)
            self.logger.info('scheduler: %s', self.scheduler)

        # Restore weights etc. only after the network/optimizer exist.
        if saved_state is not None:
            self.load_state_dict(saved_state)

    def create_buckets(self, path, mode):
        """Read samples from `path` and bucket them; memoised per instance.

        Results are cached on ``(path, mode)`` so each data file is read and
        preprocessed only once. The cache lives on the instance rather than in
        ``functools.lru_cache`` on the method, which would key on ``self`` and
        keep every model instance alive for as long as the class-level cache.
        """
        cache = getattr(self, '_buckets_cache', None)
        if cache is None:
            cache = self._buckets_cache = {}

        key = (path, mode)
        if key not in cache:
            original_objects = self.SAMPLE_CLASS.from_file(path, self.options.input_format)
            cache[key] = self.hyper_params.bucket.create(original_objects,
                                                         preprocess_fn=self.FEATURES_CLASS.create,
                                                         logger=self.logger)
        return cache[key]

    def iter_batches(self, path, mode):
        """Yield batches for `path`; batches are shuffled unless `mode` is 'test'."""
        buckets = self.create_buckets(path, mode)
        yield from buckets.generate_batches(shuffle=(mode != 'test'),
                                            return_original=True,
                                            plugins=self.plugins,
                                            statistics=self.statistics)

    @abstractmethod
    def write_outputs(self, output_path, sample, output):
        """Write predictions to `output_path`.

        May return an iterable of additional file paths that were written;
        `evaluate` appends them to its list of output files.
        """
        pass

    def predict(self, data_path, samples, outputs, output_prefix):
        """Write predictions (and optionally score them) via `evaluate`.

        Returns whatever `evaluate` returns, i.e. ``(score, output_files)``
        when scoring is enabled, otherwise None.
        """
        return self.evaluate(data_path, samples, outputs, output_prefix)

    def run_evaluator(self, data_path, samples, outputs, output_files):
        """Score `outputs` with ``SAMPLE_CLASS``.

        Tries ``internal_evaluate`` on the samples' original objects first and
        falls back to ``external_evaluate`` on the written output file if the
        former raises NotImplementedError. ``output_files[1]`` is passed as the
        score-file path in both cases.
        """
        cls = self.SAMPLE_CLASS
        try:
            original_objects = [sample.original_object for sample in samples]
            score = cls.internal_evaluate(original_objects, outputs, output_files[1])
        except NotImplementedError:
            score = cls.external_evaluate(data_path, output_files[0], output_files[1])
        return score

    def evaluate(self, data_path, samples, outputs, output_prefix):
        """Write outputs for `data_path` and score them when requested.

        Outputs go to ``<output_prefix>.<basename(data_path)>``; the score path
        is the same name with a ``.score`` suffix.

        Returns:
            ``(score, output_files)`` when `options.training` or
            `options.evaluate` is truthy, otherwise None.
        """
        LOGGER.info('Running on %s', data_path)

        name = os.path.basename(data_path)
        output_files = [output_prefix + f'.{name}', output_prefix + f'.{name}.score']

        extra_files = self.write_outputs(output_files[0], samples, outputs)
        if extra_files is not None:
            output_files.extend(extra_files)

        options = self.options
        if options.training or options.evaluate:
            score = self.run_evaluator(data_path, samples, outputs, output_files)
            return score, output_files

    def after_eval_hook(self, _step, _max_steps, _epoch, metric_value, logger=None):
        """React to a sudden drop in the evaluation metric.

        When ``self.metrics.is_far_worse(metric_value)`` is true, this raises
        RestoreToBestSignal, presumably handled by the training loop to roll
        back to the best state. Once the restore budget is exhausted it returns
        True to request that training stops. Otherwise it returns None.

        Args:
            logger: logger for the give-up message; defaults to ``self.logger``.
        """
        if logger is None:
            # Previously `logger.critical` would crash on the default None.
            logger = self.logger

        if self.metrics.is_far_worse(metric_value):  # performance suddenly drops
            if self._restore_count == 0:
                logger.critical('Restart a lot of times, but it did not work.')
                return True  # quit training session

            self._restore_count -= 1
            raise RestoreToBestSignal()
