# -*- coding: utf-8 -*-

import os
import pickle
import random
import tempfile

import torch

from framework.common.dataclass_options import argfield
from framework.common.logger import open_file
from framework.common.utils import DotDict
from framework.data.vocab import VocabularySet
from framework.torch_extra.model_base import TMP_TOKEN, ModelBase
from pyshrg_utils.parser import PySHRGOptions, pyshrg_init

from ..batch_utils import process_batch
from ..evaluate import EVALUATORS
from .network import HyperParams, SHRGSelector


def build_grammar_nonterminals(manager, nonterminal_vocab):
    """Build a per-rule table of nonterminal ids for every HRG rule.

    Args:
        manager: object exposing ``hrg_size`` and ``iter_hrgs()``; each
            yielded rule provides ``is_empty``, ``label`` and
            ``nonterminal_edges``.
        nonterminal_vocab: object whose ``add(label)`` returns the value
            stored for that label (it is called for every label seen).

    Returns:
        A ``(manager.hrg_size, 3)`` ``torch.long`` tensor. Column 0 holds the
        value for the rule's own label, columns 1 and 2 the values for its
        first and second nonterminal edges. Rows of empty rules and unused
        edge slots stay 0.
    """
    table = torch.zeros(manager.hrg_size, 3, dtype=torch.long)
    register = nonterminal_vocab.add

    for rule_index, rule in enumerate(manager.iter_hrgs()):
        if not rule.is_empty:
            # The rule's own label is registered before its children so the
            # vocabulary sees labels in the same order as rules are iterated.
            table[rule_index, 0] = register(rule.label)

            children = rule.nonterminal_edges
            arity = len(children)
            assert arity <= 2  # the table only has room for two children
            if arity:
                table[rule_index, 1] = register(children[0].label)
                if arity == 2:
                    table[rule_index, 2] = register(children[1].label)

    return table


class SHRGGenerator(ModelBase):
    """Model wrapper tying an ``SHRGSelector`` network to a pyshrg grammar.

    The instance owns the pyshrg manager (``self.manager``), the raw grammar
    bytes (``self.grammar_content``), the per-rule nonterminal table
    (``self.grammar_nonterminals``) and an evaluator built from ``EVALUATORS``.
    The grammar bytes and the nonterminal table are stored in the checkpoint
    so that a saved model can be restored without the original grammar file.
    """

    # NOTE(review): presumably read by ModelBase for model selection
    # (higher BLEU is better) -- confirm against ModelBase.
    METRICS_MODE = 'max'
    METRICS_NAME = 'BLEU'

    NETWORK_CLASS = SHRGSelector

    class Options(ModelBase.Options):
        # Path passed to VocabularySet.from_file on a fresh start.
        vocab_path: str
        # Pickle file with the gold trees used for non-'predict' evaluation.
        dev_gold_trees_path: str

        # Which part of the network is trained; see build_network().
        train_mode: str = argfield('joint', choices=['joint', 'hrg', 'cfg'])

        # Grammar file handed to pyshrg_init on a fresh start.
        grammar_path: str
        evaluator_mode: str = argfield('multithreads', choices=EVALUATORS.keys(),
                                       active_time='both')

        hyper_params: HyperParams
        pyshrg: PySHRGOptions

    def move_to_device(self, device=None):
        """Move the model via the base class, then move the nonterminal table.

        Returns:
            The device returned by ``ModelBase.move_to_device``.
        """
        device = super().move_to_device(device)
        # The table is assigned as a plain attribute of network.hrg here, so
        # it is moved to the target device explicitly.
        self.network.hrg.grammar_nonterminals = self.grammar_nonterminals.to(device)

        return device

    def state_dict(self):
        """Return the base state extended with grammar and evaluator data."""
        saved_state = super().state_dict()
        saved_state['nonterminals'] = self.grammar_nonterminals
        saved_state['grammar'] = self.grammar_content
        saved_state['evaluator'] = self.evaluator.state_dict()
        return saved_state

    def load_state_dict(self, saved_state):
        """Restore the base state and, when training, the evaluator state."""
        super().load_state_dict(saved_state)

        if self.options.training:  # only load while training
            self.evaluator.load_state_dict(saved_state['evaluator'])

    def build_network(self):
        """Create the network and evaluator, freezing parameters per train_mode.

        Side effects: sets ``self.plugins`` and ``self.evaluator``.

        Returns:
            The constructed ``NETWORK_CLASS`` instance.
        """
        self.plugins = []

        embedding_options = self.hyper_params.external_embedding
        if embedding_options.filename is not None:
            self.plugins.append(embedding_options.create(self.statistics.get('lemma')))

        network = self.NETWORK_CLASS(self.hyper_params, self.statistics, self.plugins,
                                     num_hrg_grammars=self.manager.hrg_size,
                                     num_cfg_grammars=self.manager.shrg_size,
                                     grammar_nonterminals=self.grammar_nonterminals)

        self.evaluator = EVALUATORS.invoke(self.options.evaluator_mode,
                                           network, self.statistics, self.device,
                                           return_output=True)

        # 'hrg' freezes every parameter whose name starts with 'cfg';
        # 'cfg' freezes every other parameter; 'joint' freezes nothing.
        mode = self.options.train_mode
        for name, params in network.named_parameters():
            is_cfg_param = name.startswith('cfg')
            if (mode == 'hrg' and is_cfg_param) or (mode == 'cfg' and not is_cfg_param):
                params.requires_grad_(False)

        return network

    def _initialize_parser(self, saved_state):
        """Initialise pyshrg from the grammar file or from a checkpoint.

        Sets ``self.manager`` and ``self.grammar_content``. When restoring,
        ``options.grammar_path`` is overwritten with ``'<saved_state>'``.

        Args:
            saved_state: ``None`` on a fresh start, otherwise a checkpoint
                mapping that contains the grammar bytes under ``'grammar'``.
        """
        options = self.options
        if saved_state is None:  # init
            with open_file(options.grammar_path, 'rb') as fp:
                grammar_content = fp.read()
            self.manager = pyshrg_init(options.grammar_path, options=options.pyshrg)
        else:
            options.grammar_path = '<saved_state>'
            grammar_content = saved_state['grammar']
            # pyshrg_init takes a path, so the saved bytes are spilled to a
            # temporary file that is removed as soon as initialisation is done.
            with tempfile.NamedTemporaryFile() as temp_file:
                with open_file(temp_file.name, 'wb') as fp:
                    fp.write(grammar_content)
                self.manager = pyshrg_init(temp_file.name, options=options.pyshrg)

        self.grammar_content = grammar_content

    def initialize(self, saved_state):
        """Set up grammar, vocabularies, network and (when training) optimizer.

        Args:
            saved_state: ``None`` for a fresh training run, otherwise the
                checkpoint to restore. It must be given when not training.
        """
        training = self.options.training
        assert training or saved_state is not None

        self._initialize_parser(saved_state)

        if saved_state is None:  # init
            vocabs = VocabularySet.from_file(self.options.vocab_path)
            self.grammar_nonterminals = \
                build_grammar_nonterminals(self.manager, vocabs.get_or_new('nonterminal'))
            self.statistics = vocabs
        else:
            self.grammar_nonterminals = saved_state['nonterminals']
            self.statistics = VocabularySet()
            self.statistics.load_state_dict(saved_state['statistics'])

        self.logger.info('statistics:\n%s', self.statistics)

        self.network = self.build_network()

        if training:
            trainable_params = self.get_trainable_params()
            self.optimizer, self.scheduler = self.get_optimizer_and_scheduler(trainable_params)
            self.logger.info('optimizer: %s', self.optimizer)
            self.logger.info('scheduler: %s', self.scheduler)

            # NOTE: pickle executes arbitrary code on load; this path must
            # point to a trusted file.
            with open_file(self.options.dev_gold_trees_path, 'rb') as fp:
                self.dev_gold_trees = pickle.load(fp)

        if saved_state is not None:
            self.load_state_dict(saved_state)

    def iter_batches(self, path, mode):
        """Yield ``(sentence_ids, processed_batch)`` pairs in random order.

        Args:
            path: pickle file containing a list of batches.
            mode: must be ``'train'``.
        """
        assert mode == 'train'

        # NOTE: pickle executes arbitrary code on load; `path` must point to
        # a trusted file.
        with open_file(path, 'rb') as fp:
            batches = pickle.load(fp)
        random.shuffle(batches)

        for batch in batches:
            batch = DotDict(batch)
            yield batch.sentence_ids, process_batch(batch, self.plugins, self.device)

    def compute_stats(self, average_meter, _, inputs, outputs):
        """Record HRG/CFG accuracy when present, then defer to the base class."""
        hrg_outputs = outputs.get('hrg')
        if hrg_outputs is not None:
            average_meter.add('hrg_acc', hrg_outputs.correct, hrg_outputs.total)

        cfg_outputs = outputs.get('cfg')
        if cfg_outputs is not None:
            average_meter.add('cfg_acc', cfg_outputs.correct, cfg_outputs.total)

        return super().compute_stats(average_meter, _, inputs, outputs)

    def run_batch(self, _, inputs):
        """Run the network on ``inputs`` using the configured train mode."""
        return self.network(inputs, self.options.train_mode)

    def evaluate_entry(self, path=None, mode=None):
        """Load graphs from ``path`` and run the evaluator on them.

        Args:
            path: graph file to evaluate; defaults to ``options.dev_path``
                (in which case ``mode`` is forced to ``'dev'``).
            mode: required when ``path`` is given. ``'predict'`` writes to
                ``options.output_prefix`` without gold trees or scoring; any
                other mode writes to a temporary prefix and uses
                ``self.dev_gold_trees``.

        Returns:
            A tuple ``(metrics.get(), metrics.save(...))``.

        Raises:
            RuntimeError: if the manager fails to load the graphs.
        """
        if path is None:
            path = self.options.dev_path
            mode = 'dev'
        else:
            assert mode is not None

        if mode == 'predict':
            output_prefix = self.options.output_prefix
            gold_trees = None
        else:
            output_prefix = self.get_path(TMP_TOKEN)
            # NOTE: dev_gold_trees is only set by initialize() when training.
            gold_trees = self.dev_gold_trees

        # Loading must not live inside an `assert`: under `python -O` the
        # whole statement would be stripped and no graphs would be loaded.
        if not self.manager.load_graphs(path):
            raise RuntimeError(f'failed to load graphs from {path!r}')

        self.network.eval()
        with torch.no_grad():
            metrics = self.evaluator(predict_cfg=(self.options.train_mode != 'hrg'),
                                     gold_trees=gold_trees)

        return metrics.get(), metrics.save(output_prefix + f'.{os.path.basename(path)}',
                                           no_score=(mode == 'predict'))
