import gc
import pickle
from abc import abstractmethod, abstractproperty, ABCMeta
from itertools import zip_longest
from typing import Mapping, Optional, Any

import torch
from dataclasses import dataclass
from nltk.stem import WordNetLemmatizer
from torch import Tensor

from coli.basic_tools.common_utils import AttrDict, NoPickle, try_cache_keeper, cache_result, Progbar, \
    NullContextManager
from coli.basic_tools.dataclass_argparse import argfield, ExistFile, BranchSelect
from coli.basic_tools.logger import logger
from coli.data_utils.dataset import HParamsBase
from coli.hrg_parser.count_based_scorer import CountBasedHRGScorer
from coli.hrg_parser.hrg_statistics import HRGStatistics
from coli.hrg_parser.parser_mixin import HRGParserMixin
from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.hyper_graph import GraphNode, HyperEdge
from coli.torch_extra.parser_base import PyTorchParserBase
from coli.torch_extra.utils import to_cuda
from coli.torch_hrg.feature_based_scorer import StructuredPeceptronHRGScorer
from coli.torch_hrg.graph_embedding_based_scorer import GraphEmbeddingHRGScorer
from coli.torch_hrg.internal_tagger import InternalTagger
from coli.torch_hrg.pred_tagger import PredTagger, ConstTreeExtra, SentenceFeaturesWithTags
from coli.torch_span.data_loader import SentenceFeatures
from coli.torch_span.parser import SpanParser


class AutoBatchMixin(object):
    """Mixin that queues per-item inputs and evaluates them in one batched call.

    Intended to be combined with a ``torch.nn.Module`` subclass (it relies on
    the module's ``parameters()`` and ``__call__``).  Callers register inputs
    with ``add_input`` and receive the row index under which the output will
    appear in ``self.results`` once ``calculate_results`` has run.
    """
    parameters: Any
    __call__: Any

    def __init__(self, *args, **kwargs):
        super(AutoBatchMixin, self).__init__(*args, **kwargs)
        # Argument tuples queued by ``add_input``, in registration order.
        self.pending_inputs = []
        # Number of leading entries of ``pending_inputs`` already evaluated.
        self.has_processed = 0
        # Outputs for the processed inputs, concatenated along dim 0.
        self.results: Optional[Tensor] = None

    def add_input(self, *args):
        """Queue one set of arguments and return its index in ``self.results``."""
        result_idx = len(self.pending_inputs)
        self.pending_inputs.append(args)
        return result_idx

    def calculate_results(self):
        """Evaluate every input queued since the last call as a single batch.

        Each argument position is stacked across the pending inputs (tensors
        via ``torch.stack``, ints/floats via ``torch.tensor``), the module is
        called once on the stacked arguments, and the output is appended to
        ``self.results`` along dim 0.  Does nothing when no new input is
        pending.

        :raises Exception: if an argument is neither a tensor nor an int/float.
        """
        # Check for pending work first so that an idle call never touches the
        # parameters (a parameterless module would otherwise fail here).
        if self.has_processed == len(self.pending_inputs):
            return

        # Scalars are materialised on the module's device; a module without
        # parameters falls back to torch's default device (``device=None``).
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else None

        def smart_stack(maybe_tensor_list):
            if isinstance(maybe_tensor_list[0], Tensor):
                return torch.stack(maybe_tensor_list)
            if isinstance(maybe_tensor_list[0], (int, float)):
                # noinspection PyCallingNonCallable
                return torch.tensor(maybe_tensor_list, device=device)
            raise Exception("Unknown input type {}".format(maybe_tensor_list[0]))

        # zip(*...) transposes "list of argument tuples" into "one list per
        # argument position".
        batch_data = [smart_stack(i)
                      for i in zip(*self.pending_inputs[self.has_processed:])]
        outputs = self(*batch_data)
        if self.results is None:
            self.results = outputs
        else:
            self.results = torch.cat([self.results, outputs], dim=0)
        self.has_processed = len(self.pending_inputs)

    def refresh(self):
        """Drop all queued inputs and computed results."""
        self.has_processed = 0
        self.pending_inputs = []
        self.results = None


class AutoBatchModule(AutoBatchMixin, torch.nn.Module, metaclass=ABCMeta):
    """A ``torch.nn.Module`` with the input-queueing behaviour of ``AutoBatchMixin``.

    Declared with the ``ABCMeta`` metaclass so subclasses may use abstract
    methods.
    """


class AutoBatchSequential(AutoBatchMixin, torch.nn.Sequential):
    """A ``torch.nn.Sequential`` with the input-queueing behaviour of ``AutoBatchMixin``."""


# Registry of rule-scorer implementations, keyed by the name used to pick one
# (these keys are the valid values of ``RuleScorers.Options.type``).
scorers = {
    "feature": StructuredPeceptronHRGScorer,
    "count": CountBasedHRGScorer,
    "graph": GraphEmbeddingHRGScorer,
}


class RuleScorers(BranchSelect):
    """Branch selector over the rule-scorer implementations in ``scorers``."""
    branches = scorers

    @dataclass
    class Options(BranchSelect.Options):
        # Name of the scorer to use; must be a key of ``scorers``.
        type: str = argfield("feature", choices=scorers.keys())
        # One option group per scorer.  Every group is built through
        # ``default_factory`` so that each ``Options`` instance gets its own
        # fresh sub-options object rather than sharing the Options *class*
        # itself as the default value.
        feature_options: StructuredPeceptronHRGScorer.Options = argfield(
            default_factory=StructuredPeceptronHRGScorer.Options)
        count_options: CountBasedHRGScorer.Options = argfield(
            default_factory=CountBasedHRGScorer.Options)
        graph_options: GraphEmbeddingHRGScorer.Options = argfield(
            default_factory=GraphEmbeddingHRGScorer.Options)


class UdefQParserBase(HRGParserMixin, PyTorchParserBase):
    """Base class for HRG parsers built on a pretrained span parser.

    The constructor loads a ``SpanParser`` and, optionally, a ``PredTagger``
    and an ``InternalTagger``, together with pickled derivations and grammar.
    Subclasses provide the network (``create_network``), the per-sentence
    generator (``training_session``), the ``hook_*`` callbacks and a
    ``PrintLogger`` class.
    """
    # Data-format name -> class handling that format.
    # NOTE(review): presumably consumed by the base parser class - confirm.
    available_data_formats = {"default": ConstTreeExtra}
    # Replaced with ``SentenceFeatures`` in ``__init__`` when no pred tagger
    # is configured.
    sentence_feature_class = SentenceFeaturesWithTags
    # Must be overridden by subclasses; ``train`` / ``predict_bucket``
    # instantiate it with no arguments, and ``train`` uses its ``total_loss``
    # attribute and ``print`` method.
    PrintLogger = abstractproperty()

    @dataclass
    class HParams(HParamsBase):
        # When true, the span parser and taggers are wrapped in ``NoPickle``
        # and span features are computed under ``torch.no_grad()``.
        stop_grad: bool = argfield(False)
        # Not read anywhere in this class.
        # NOTE(review): presumably consumed by subclasses - confirm.
        disable_span_dropout: bool = argfield(False)
        greedy_at_leaf: bool = argfield(False)

    @dataclass
    class Options(PyTorchParserBase.Options):
        # The lambda defers the lookup of ``UdefQParserBase.HParams`` until an
        # ``Options`` instance is created.
        hparams: "UdefQParserBase.HParams" = argfield(
            default_factory=lambda: UdefQParserBase.HParams())
        # Pickle file with the derivations; indexed by sentence ID in ``train``.
        derivations: ExistFile = argfield()
        # Pickle file holding the 5-tuple unpacked at the end of ``__init__``.
        grammar: ExistFile = argfield()
        # When set, the span parser is kept off the GPU.
        span_cache: ExistFile = argfield(None)
        embed_file: ExistFile = argfield(None)
        # Optional pretrained ``PredTagger`` / ``InternalTagger`` models.
        pred_tagger: ExistFile = argfield(None)
        internal_tagger: ExistFile = argfield(None)
        # Location passed to ``SpanParser.load``.
        span_model: str = argfield(predict_time=True)
        deepbank_dir: str = argfield(predict_time=True)
        # When true, feed dicts are moved with ``to_cuda`` before use.
        gpu: bool = argfield(False, predict_time=True)
        graph_type: str = argfield("eds", choices=["eds", "dmrs", "lf"], predict_time=True)
        beam_size: int = argfield(16, predict_time=True)
        word_threshold: int = argfield(0, predict_time=True)

    def __init__(self, options, train_trees):
        """Load the pretrained sub-models, the derivations and the grammar.

        The scoring network is not built here: ``self.network`` and
        ``self.optimizer`` stay ``None`` until ``create_network`` is called
        (``train`` does so lazily).

        :param options: parsed options (see ``Options``); its path fields
            select which models and data files are loaded.
        :param train_trees: forwarded unchanged to the base-class constructor.
        """
        super(UdefQParserBase, self).__init__(options, train_trees)

        @try_cache_keeper(options.span_model)
        def get_parser():
            parser = SpanParser.load(
                options.span_model,
                # The span parser stays off the GPU when a span cache is set.
                AttrDict(gpu=options.gpu and not self.options.span_cache,
                         bilm_path=self.options.bilm_path
                         ))
            if self.options.hparams.stop_grad:
                parser: SpanParser = NoPickle(parser)
            return parser

        self.parser: SpanParser = get_parser()

        @try_cache_keeper(options.pred_tagger)
        def get_tagger():
            tagger = PredTagger.load(
                options.pred_tagger,
                AttrDict(gpu=options.gpu,
                         bilm_path=self.options.bilm_path,
                         span_model=self.parser
                         ))
            if self.options.hparams.stop_grad:
                tagger = NoPickle(tagger)
            tagger.network.eval()
            return tagger

        if options.pred_tagger is not None:
            self.tagger = get_tagger()
            self.statistics = self.tagger.statistics
        else:
            # Without a tagger, fall back to the span parser's statistics and
            # to the tag-free sentence feature class.
            self.tagger = None
            self.statistics = self.parser.statistics
            self.sentence_feature_class = SentenceFeatures

        @try_cache_keeper(options.internal_tagger)
        def get_internal_tagger():
            # Must load from the internal tagger's own path (the same value
            # the cache is keyed on), not from ``options.pred_tagger``.
            tagger = InternalTagger.load(
                options.internal_tagger,
                AttrDict(gpu=options.gpu,
                         bilm_path=self.options.bilm_path,
                         span_model=self.parser
                         ))
            if self.options.hparams.stop_grad:
                tagger = NoPickle(tagger)
            tagger.network.eval()
            return tagger

        if self.options.internal_tagger is not None:
            self.internal_tagger = get_internal_tagger()
        else:
            self.internal_tagger = None

        self.plugins = self.parser.plugins
        self.old_options = self.parser.options
        self.hparams = options.hparams

        @try_cache_keeper(options.derivations)
        def load_derivations():
            # SECURITY: pickle executes arbitrary code on load; only point
            # ``options.derivations`` at trusted files.
            with open(options.derivations, "rb") as f:
                return pickle.load(f)

        self.logger.info("Loading derivations...")
        self.derivations = NoPickle(load_derivations())

        @cache_result(
            self.options.output + "/hrg_statistics.pkl", self.options.debug_cache)
        def load_hrg_statistics():
            return HRGStatistics.from_derivations(self.derivations)

        self.logger.info("Statistics HRG...")
        self.hrg_statistics = load_hrg_statistics()

        @try_cache_keeper(options.grammar)
        def load_grammar():
            # SECURITY: same pickle caveat as for the derivations file.
            with open(options.grammar, "rb") as f:
                return pickle.load(f, encoding="latin1")

        # The grammar pickle holds exactly these five objects, in this order.
        # NOTE(review): ``grammar`` was annotated as
        # Mapping[str, Mapping[CFGRule, int]] - confirm against the producer.
        (self.grammar, self.lexicon_to_graph, self.lexicon_to_lemma,
         self.postag_mapping, self.lexical_label_mapping) = load_grammar()

        self.lemmatizer = WordNetLemmatizer()

        # delay network creation
        self.network = self.optimizer = None

        self.file_logger = NoPickle(self.get_logger(self.options, False))
        self.progbar = NoPickle(Progbar(self.hparams.train_iters, log_func=self.file_logger.info))

        print(self.hrg_statistics)

        # Nothing in this class populates the span cache; while it is falsy,
        # ``get_span_features`` runs the span parser network instead.
        self.span_cache = None
        self.gold_graphs = NoPickle({})
        self.best_score = 0

    def post_load(self, new_options: Optional[AttrDict] = None):
        """Restore the sub-models that are not available after loading.

        The span parser is reloaded from ``new_options.span_model_prefix``
        when it is missing, otherwise its own ``post_load`` is invoked in
        feature-only mode.  The pred tagger is reloaded only when it is
        missing and ``new_options.pred_tagger`` is given.  The gold-graph
        cache is reset in either case.

        NOTE(review): ``new_options`` defaults to ``None`` but is dereferenced
        unconditionally below - callers apparently must always supply it.

        :param new_options: options supplied at load time.
        """
        super(UdefQParserBase, self).post_load(new_options)

        if self.parser is not None:
            # Parser object survived loading: let it refresh itself.
            self.parser.post_load(AttrDict(
                gpu=new_options.gpu and not self.options.span_cache,
                bilm_path=new_options.bilm_path,
                feature_only=True,
            ))
        else:
            self.parser = SpanParser.load(
                new_options.span_model_prefix,
                AttrDict(
                    gpu=new_options.gpu and not self.options.span_cache,
                    bilm_path=new_options.bilm_path,
                ),
            )

        if self.tagger is None and new_options.pred_tagger is not None:
            reloaded = PredTagger.load(
                new_options.pred_tagger,
                AttrDict(
                    gpu=new_options.gpu,
                    bilm_path=new_options.bilm_path,
                    span_model=self.parser,
                ),
            )
            if self.options.hparams.stop_grad:
                reloaded = NoPickle(reloaded)
            self.tagger = reloaded
            self.tagger.network.eval()

        self.gold_graphs = NoPickle({})

    def get_span_features(self, batch_sents, feed_dict, device):
        """Return the span features for every sentence of a batch.

        With a non-empty ``self.span_cache`` the features are looked up by
        sentence ID and turned into tensors on ``device``.  Otherwise the span
        parser network is run on the batch, under ``torch.no_grad()`` when the
        ``stop_grad`` hyper-parameter is set.

        :param batch_sents: sentence feature objects of the batch.
        :param feed_dict: network inputs for the batch.
        :param device: device on which cached features are materialised.
        """
        if self.span_cache:
            # noinspection PyCallingNonCallable
            return [
                torch.tensor(self.span_cache.get(sent.original_obj.extra["ID"]), device=device)
                for sent in batch_sents
            ]

        if self.options.hparams.stop_grad:
            grad_guard = torch.no_grad()
        else:
            grad_guard = NullContextManager()
        with grad_guard:
            return self.parser.network(batch_sents, AttrDict(feed_dict),
                                       return_span_features=True)

    @abstractmethod
    def create_network(self):
        raise NotImplementedError

    @abstractmethod
    def hook_1(self):
        raise NotImplementedError

    def hook_2(self):
        raise NotImplementedError

    def hook_3(self):
        raise NotImplementedError

    @abstractmethod
    def training_session(self, tree, span_features,
                         lexical_labels, attachment_bags, internal_bags,
                         print_logger, derivations=()):
        raise NotImplementedError

    def train(self, train_buckets, dev_args_list):
        """Run one pass over ``train_buckets``, updating the network per sub-batch.

        Each sentence of a sub-batch gets its own ``training_session``
        generator; the generators are advanced in lockstep, with the
        ``hook_*`` callbacks invoked between steps, until every one has
        returned its loss.  The mean loss is back-propagated and the optimizer
        stepped.  Every ``print_every`` steps progress is printed, and on
        those steps that are also multiples of ``evaluate_every`` the dev sets
        are evaluated and the model saved when the score improves.

        :param train_buckets: training data; must provide ``generate_batches``.
        :param dev_args_list: ``None`` or an iterable of
            ``(filename, data, buckets)`` triples used for evaluation.
        """
        # The network is built lazily (``__init__`` leaves it as ``None``).
        if self.network is None:
            self.create_network()

        print_logger = self.PrintLogger()
        self.network.train()
        device = next(self.network.parameters()).device

        batch_itr = train_buckets.generate_batches(
            self.hparams.train_batch_size,
            shuffle=True, original=True, use_sub_batch=True,
            plugins=self.plugins
            # sort_key_func=lambda x: x.sent_length
        )

        sent_idx = 0
        for batch_data in batch_itr:
            # self.schedule_lr(self.global_step)
            # Each batch is a sequence of (sentences, feed_dict) sub-batches.
            for batch_sents, feed_dict in batch_data:
                if self.options.gpu:
                    to_cuda(feed_dict)
                sent_idx += len(batch_sents)
                span_features_list = self.get_span_features(batch_sents, feed_dict, device)
                lexical_labels_list, attachment_bags_list, internal_bags_list = self.get_tagger_result(
                    batch_sents, feed_dict)
                self.optimizer.zero_grad()
                assert len(batch_sents) == len(span_features_list)
                # zip_longest (not zip): the tagger lists are empty when no
                # tagger is configured, in which case they are padded with None.
                sessions = [
                    self.training_session(sent_feature.original_obj, span_features,
                                          lexical_labels, attachment_bags, internal_bags, print_logger,
                                          self.derivations[sent_feature.original_obj.extra["ID"]])
                    for sent_feature, span_features, lexical_labels,
                        attachment_bags, internal_bags in zip_longest(
                        batch_sents, span_features_list, lexical_labels_list, attachment_bags_list, internal_bags_list)
                ]
                # Advance every session once before the first hook.
                if self.hook_1 is not None:
                    for session in sessions:
                        next(session)
                    self.hook_1()
                losses = [None] * len(sessions)
                finish_count = 0
                # Keep polling all sessions until each has returned a loss.
                while finish_count < len(sessions):
                    for idx, session in enumerate(sessions):
                        try:
                            next(session)
                        except StopIteration as e:
                            # An exhausted generator raises StopIteration again
                            # on later passes; record only its first result.
                            if losses[idx] is None:
                                losses[idx] = e.value
                                finish_count += 1
                        self.hook_2()
                loss: Tensor = sum(losses) / len(sessions)
                self.hook_3()
                # Drop references before the backward pass.
                del losses, sessions, span_features_list
                # A zero loss skips the backward pass, the optimizer step and
                # the global_step increment below.
                if loss == 0.0:
                    continue
                gc.collect()
                print_logger.total_loss += loss.detach().cpu().numpy()
                backed = False
                while not backed:
                    try:
                        loss.backward()
                        backed = True
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            # Free cached CUDA memory and retry the backward pass.
                            self.logger.info("OOM. Clear cache and try again.")
                            torch.cuda.empty_cache()
                            continue
                        elif "but the buffers have already been freed" in str(e):
                            # Give up on this sub-batch without an optimizer step.
                            self.logger.info("Backward error")
                            break
                        else:
                            raise
                if backed:
                    self.optimizer.step()
                del loss

                self.global_step += 1
                if self.global_step % self.hparams.print_every == 0:
                    print_logger.print(sent_idx)
                    # Nested on purpose or not, evaluation only runs on steps
                    # that are multiples of both print_every and evaluate_every.
                    if self.global_step % self.hparams.evaluate_every == 0:
                        if dev_args_list is not None:
                            for filename, data, buckets in dev_args_list:
                                results = self.predict_bucket(buckets)
                                output_file = self.get_output_name(
                                    self.options.output, filename, self.global_step)
                                new_score = self.evaluate_hg(results, output_file)
                                if new_score > self.best_score:
                                    self.logger.info("New best score: %.2f > %.2f", new_score, self.best_score)
                                    self.save(self.options.output + "/model")
                                    self.best_score = new_score
                                else:
                                    self.logger.info("Not best score, %.2f <= %.2f", new_score, self.best_score)
                        # Return to training mode after evaluation.
                        self.network.train()

    def get_tagger_result(self, batch_sents, feed_dict):
        """Run the pred tagger and return label indices ranked by score.

        For each kind of logits the indices along the last dimension are
        ordered from highest to lowest logit and moved to the CPU.

        :return: ``(lexical_labels, attachment_bags, internal_bags)``.  All
            three are empty lists when no tagger is configured;
            ``internal_bags`` is also an empty list when no internal tagger
            is configured.
        """
        if self.tagger is None:
            return [], [], []

        def ranked_indices(scores):
            # Sorting the negated scores ascending yields a descending ranking.
            return (-scores).sort(dim=-1)[1].cpu()

        outputs = self.tagger.network(batch_sents, AttrDict(feed_dict))
        lexical_ranking = ranked_indices(outputs.lexical_logits)
        attachment_ranking = ranked_indices(outputs.attachment_logits)

        # The internal logits come from the pred tagger's output; the internal
        # tagger only decides whether they are used.
        if self.internal_tagger is None:
            internal_ranking = []
        else:
            internal_ranking = ranked_indices(outputs.internal_logits)

        return lexical_ranking, attachment_ranking, internal_ranking

    def predict_bucket(self, buckets, return_derivation=False):
        """Predict a graph for every sentence in ``buckets``.

        Sentences are processed sub-batch by sub-batch under
        ``torch.no_grad()``.  Each sentence gets a ``training_session``
        generator (called without derivations); the generators are advanced in
        lockstep, with the ``hook_*`` callbacks invoked between steps, until
        each has returned its final beam item, from which the graph is built.

        :param buckets: data to predict; must provide ``generate_batches`` and
            ``len()``.
        :param return_derivation: when true, each result also carries the
            derivation as a list (see ``construct_derivation``).
        :return: list indexed by each sentence's ``original_idx`` holding
            ``(sentence_id, graph)`` or ``(sentence_id, graph, derivation)``.
        """
        print_logger = self.PrintLogger()
        results = [None] * len(buckets)
        # Predict in evaluation mode.  ``train`` switches back to training
        # mode after evaluating, but nothing else visible here disables
        # training-only behaviour (e.g. dropout) beforehand.
        self.network.eval()
        with torch.no_grad():
            batch_itr = buckets.generate_batches(
                self.hparams.train_batch_size,
                original=True, use_sub_batch=True,
                plugins=self.plugins
            )
            device = next(self.network.parameters()).device
            total_sents = 0
            for batch_data in batch_itr:
                for batch_sents, feed_dict in batch_data:
                    if self.options.gpu:
                        to_cuda(feed_dict)
                    total_sents += len(batch_sents)
                    logger.info("{} predicted".format(total_sents))
                    span_features_list = self.get_span_features(batch_sents, feed_dict, device)
                    lexical_labels_list, attachment_bags_list, internal_bags_list = self.get_tagger_result(
                        batch_sents, feed_dict)
                    assert len(batch_sents) == len(span_features_list)
                    # zip_longest (not zip): the tagger lists are empty when
                    # no tagger is configured and are then padded with None.
                    sessions = [
                        self.training_session(sent_feature.original_obj, span_features,
                                              lexical_labels, attachment_bags, internal_bags,
                                              print_logger)
                        for sent_feature, span_features, lexical_labels,
                        attachment_bags, internal_bags in zip_longest(
                            batch_sents, span_features_list, lexical_labels_list,
                            attachment_bags_list, internal_bags_list)
                    ]
                    # Advance every session once before the first hook.
                    if self.hook_1 is not None:
                        for session in sessions:
                            next(session)
                        self.hook_1()
                    finish_count = 0
                    # Keep polling all sessions until each has produced a result.
                    while finish_count < len(sessions):
                        for sent_feature, session in zip(batch_sents, sessions):
                            try:
                                next(session)
                            except StopIteration as e:
                                # An exhausted generator raises StopIteration
                                # again on later passes; handle only the first.
                                if results[sent_feature.original_idx] is None:
                                    final_beam_item = e.value
                                    tree = sent_feature.original_obj
                                    graph = final_beam_item.make_graph(
                                        tree, lexicon_to_lemma=self.lexicon_to_lemma).graph
                                    if not return_derivation:
                                        results[sent_feature.original_idx] = (tree.extra["ID"], graph)
                                    else:
                                        results[sent_feature.original_idx] = (
                                            tree.extra["ID"], graph,
                                            list(self.construct_derivation(final_beam_item)))
                                    finish_count += 1
                            self.hook_2()
                    self.hook_3()
        return results

    @staticmethod
    def get_real_edges(hg):
        node_mapping = {}  # node -> pred edge
        real_edges = []
        ret_edges = []
        for edge in hg.edges:  # type: HyperEdge
            if len(edge.nodes) == 1:
                main_node = edge.nodes[0]  # type: GraphNode
                if node_mapping.get(main_node) is not None:
                    continue
                if not edge.is_terminal:
                    raise Exception("Non-terminal edge should not exist there {}".format(edge))
                node_mapping[main_node] = edge
            elif len(edge.nodes) == 2:
                real_edges.append(edge)
            else:
                raise Exception("Hyperedge should not exist there")

        for edge in real_edges:
            pred_edges = [node_mapping.get(i) for i in edge.nodes]
            if pred_edges[0] is not None and pred_edges[1] is not None:
                ret_edges.append((pred_edges[0], edge.label, pred_edges[1]))

        return ret_edges, list(node_mapping.values())

    def construct_derivation(self, beam_item):
        if beam_item.left is not None:
            yield from self.construct_derivation(beam_item.left)
        if beam_item.right is not None:
            yield from self.construct_derivation(beam_item.right)
        # ignore empty beam item
        if beam_item.sync_rule is not None:
            # correspondents = [i[0] for i in sorted(
            #     beam_item.node_info_ref().correspondents.items(),
            #     key=lambda x: x[1], reverse=True)]
            yield beam_item.sub_graph, beam_item.sync_rule, []
        else:
            yield None, None, []
