import shutil
import sys
import time
from typing import Any, List

import numpy as np
import torch
from dataclasses import dataclass

from coli.torch_extra.utils import to_cuda
from coli.torch_span.config import SpanParserOptions
from coli.torch_span.data_loader import SentenceFeatures
from coli.torch_span.network import AttentiveSpanNetwork
from coli.basic_tools.common_utils import AttrDict, Progbar, NoPickle, cache_result
from coli.basic_tools.dataclass_argparse import argfield
from coli.span.const_tree import ConstTree, ConstTreeStatistics
from coli.torch_extra.parser_base import SimpleParser, PyTorchParserBase


class SpanParser(SimpleParser):
    """Span-based parser that produces ``ConstTree`` outputs.

    Wires together:

    * vocabulary statistics computed from the training trees
      (``ConstTreeStatistics``),
    * an ``AttentiveSpanNetwork`` built from those statistics and the
      ``hparams`` options,
    * an optimizer/scheduler pair obtained from the ``SimpleParser`` base.

    ``evaluate`` reduces the base-class scores to the
    ``"Bracketing FMeasure"`` value.
    """
    sentence_feature_class = SentenceFeatures
    available_data_formats = {"default": ConstTree}

    @dataclass
    class Options(SimpleParser.Options):
        # Network / training hyper-parameters; "default" and "lstm" presets
        # come from SpanParserOptions.
        hparams: SpanParserOptions = argfield(default_factory=SpanParserOptions.get_default,
                                              choices={"default": SpanParserOptions.get_default(),
                                                       "lstm": SpanParserOptions.get_lstm_options()})
        # The options below are not read inside this class; presumably they
        # are consumed by the network / decoder. predict_time=True presumably
        # allows overriding them at prediction time -- confirm in argfield.
        concurrent_count: int = argfield(6, predict_time=True)
        use_rules: bool = argfield(False, predict_time=True)
        restrict_root_rule: bool = argfield(False, predict_time=True)

    def sentence_convert_func(self, sent_idx: int, sentence: ConstTree,
                              padded_length: int):
        """Convert one tree into ``SentenceFeatures``.

        :param sent_idx: sentence index, forwarded unchanged to
            ``SentenceFeatures.from_sentence_obj``.
        :param sentence: the input tree.
        :param padded_length: length the features are padded to.
        :return: features built with this parser's statistics, its
            ``predict_postags`` / ``strip_top`` hparams and its plugins.
        """
        return SentenceFeatures.from_sentence_obj(
            sent_idx, sentence, self.statistics,
            padded_length,
            include_postags=self.hparams.predict_postags,
            strip_top=self.hparams.strip_top,
            plugins=self.plugins
        )

    def __init__(self, args: Any, data_train):
        """Build statistics, network, optimizer and scheduler.

        :param args: parser options. ``args.output`` and ``args.debug_cache``
            select the statistics cache file and whether caching is enabled.
        :param data_train: training trees the statistics are computed from.
        """
        super().__init__(args, data_train)
        # Type hint only: ``self.hparams`` is not assigned here, so it is
        # provided by the base class.
        self.hparams: SpanParserOptions

        # Cached at "<output>/statistics.pkl"; cache_result's ``enable`` flag
        # is driven by args.debug_cache.
        @cache_result(args.output + "/" + "statistics.pkl",
                      enable=args.debug_cache)
        def load_statistics():
            return ConstTreeStatistics.from_sentences(
                data_train, include_postags=self.hparams.predict_postags,
                strip_top=self.hparams.strip_top)

        self.statistics = load_statistics()

        self.network = AttentiveSpanNetwork(self.args,
                                            self.hparams, self.statistics,
                                            self.plugins
                                            )

        # Move the model to the GPU *before* collecting parameters and
        # building the optimizer, as the torch.optim documentation advises,
        # so the optimizer always references the on-device parameters.
        if self.options.gpu:
            self.network.cuda()

        # Only parameters with requires_grad are handed to the optimizer.
        self.trainable_parameters = [param for param in self.network.parameters()
                                     if param.requires_grad]

        # lr=1. is passed to the base helper; presumably the scheduler supplies
        # the effective learning-rate scale -- confirm in
        # get_optimizer_and_scheduler.
        self.optimizer, self.scheduler = self.get_optimizer_and_scheduler(
            self.trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)

    def split_batch_result(self, batch_result):
        """Not supported by this parser."""
        raise NotImplementedError

    def merge_answer(self, sent, answer):
        """Not supported by this parser."""
        raise NotImplementedError

    def get_parsed(self, bucket,
                   return_original=True,
                   return_span_features=False,
                   return_chart=False,
                   train_mode=False):
        """Run the network over ``bucket``, lazily yielding one item per sentence.

        :param bucket: data bucket providing ``generate_batches``.
        :param return_original: if True, yield ``(original, parsed)`` pairs;
            otherwise yield only ``parsed``.
        :param return_span_features: forwarded to the network.
        :param return_chart: forwarded to the network.
        :param train_mode: passed to ``self.network.train``. The network is
            left in this mode afterwards.
        :return: a generator. When neither ``return_span_features`` nor
            ``return_chart`` is set, ``parsed`` items come from the network's
            ``"decoded_list"`` output. Otherwise the raw network output is
            zipped with the batch (assumes it is per-sentence -- confirm).
        """
        self.network.train(train_mode)
        # NOTE(review): this ``with`` block stays entered while the generator
        # is suspended at ``yield``, so autograd is disabled for the caller's
        # code between items until the generator finishes or is closed.
        with torch.no_grad():
            for subbatches_itr in bucket.generate_batches(
                    self.hparams.test_batch_size,
                    original=True, use_sub_batch=True,
                    plugins=self.plugins
                    # sort_key_func=lambda x: x.sent_length
            ):
                for batch_sent, feed_dict in subbatches_itr:
                    # NOTE(review): reads ``self.args.gpu`` whereas __init__
                    # reads ``self.options.gpu``; presumably the same object.
                    if self.args.gpu:
                        to_cuda(feed_dict)
                    results = self.network(batch_sent, AttrDict(feed_dict),
                                           return_span_features=return_span_features,
                                           return_chart=return_chart
                                           )
                    if not return_span_features and not return_chart:
                        results = results["decoded_list"]
                    for original, parsed in zip(batch_sent, results):
                        yield (original, parsed) if return_original else parsed

    def evaluate(self, gold, outputs, log_file):
        """Score ``outputs`` against ``gold`` using the base-class evaluation.

        :return: only the ``"Bracketing FMeasure"`` entry of the base scores.
        """
        scores = super().evaluate(gold, outputs, log_file)
        return scores["Bracketing FMeasure"]

    def predict_bucket(self, bucket,
                       return_span_features=False,
                       return_chart=False
                       ):
        """Parse every sentence in ``bucket`` and collect the results.

        :return: a list with one slot per bucket entry. Each result is stored
            at its item's ``original_idx`` (presumably the input position).
            ``ConstTree`` results get a shallow copy of the source sentence's
            ``extra`` dict.
        """
        # None is immutable, so list repetition is safe here.
        outputs: List[Any] = [None] * len(bucket)
        for sent_feature, result in self.get_parsed(
                bucket, return_span_features=return_span_features,
                return_chart=return_chart):
            if isinstance(result, ConstTree):
                # Copy so later edits to result.extra don't alias the input's.
                result.extra = dict(sent_feature.original_obj.extra)
            outputs[sent_feature.original_idx] = result
        return outputs

    def post_load(self, new_options):
        """Run the base-class post-load step, then ``self.network.create_pool()``."""
        super().post_load(new_options)
        self.network.create_pool()
