import os
import pickle
import time
import traceback
from itertools import chain, zip_longest

import torch
import typing
from argparse import ArgumentParser
from pprint import pformat
from typing import List, Mapping

import sys

from dataclasses import dataclass
from nltk import WordNetLemmatizer
from torch import Tensor

from coli.basic_tools.dataclass_argparse import DataClassArgParser, argfield, ExistFile
from coli.torch_extra.utils import to_cuda
from coli.torch_hrg.hrg_parser_base import RuleScorers
from coli.torch_hrg.internal_tagger import InternalTagger
from coli.torch_hrg.pred_tagger import PredTagger, ConstTreeExtra, SentenceFeaturesWithTags
from coli.torch_span.data_loader import SentenceFeatures
from coli.torch_span.parser import SpanParser
from coli.basic_tools.common_utils import AttrDict, cache_result, NoPickle, Progbar, \
    try_cache_keeper
from coli.data_utils.dataset import TensorflowHParamsBase, HParamsBase
from coli.hrg_parser.parser_mixin import HRGParserMixin
from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.sub_graph import SubGraph
from coli.span.const_tree import ConstTree, Lexicon

from coli.hrg_parser.count_based_scorer import CountBasedHRGScorer
from coli.hrg_parser.hrg_statistics import HRGStatistics
from coli.torch_hrg.cache_manager import SpanCacheManager
from coli.torch_hrg.feature_based_scorer import StructuredPeceptronHRGScorer
from coli.torch_hrg.graph_embedding_based_scorer import GraphEmbeddingHRGScorer
from coli.torch_extra.parser_base import PyTorchParserBase


class UdefQParserGreedy(HRGParserMixin, PyTorchParserBase):
    available_data_formats = {"default": ConstTreeExtra}
    scorers = {"feature": StructuredPeceptronHRGScorer,
               "count": CountBasedHRGScorer,
               "graph": GraphEmbeddingHRGScorer}
    sentence_feature_class = SentenceFeaturesWithTags

    @dataclass
    class HParams(HParamsBase):
        """Hyper-parameters added by this parser on top of ``HParamsBase``."""
        # When true, the span parser is run under ``torch.no_grad()``, the
        # pre-trained sub-models are wrapped in ``NoPickle``, and only the
        # scorer's parameters are handed to the optimizer.
        stop_grad: bool = argfield(False)
        # Options forwarded to ``RuleScorers.get`` to build the rule scorer.
        scorer: RuleScorers.Options = argfield(default_factory=RuleScorers.Options)
        # Path of a span-feature cache file. When set, ``train`` opens it with
        # ``SpanCacheManager`` and span features are read from the cache
        # instead of being computed by the span parser.
        span_cache: ExistFile = argfield(None)
        # When true, the span parser network is switched with ``train(False)``
        # while the scorer is being trained.
        disable_span_dropout: bool = argfield(False)

    @dataclass
    class Options(PyTorchParserBase.Options):
        """Command-line options of the parser, extending the base parser options."""
        # Hyper-parameters; a fresh ``HParams`` instance by default.
        hparams: "UdefQParserGreedy.HParams" = argfield(
            default_factory=lambda: UdefQParserGreedy.HParams())
        # Optional ``PredTagger`` model. When None no tagger is used and the
        # statistics are taken from the span parser instead.
        pred_tagger: ExistFile = argfield(None)
        # Optional ``InternalTagger`` model. When None no internal tagger is used.
        internal_tagger: ExistFile = argfield(None)
        # Location passed to ``SpanParser.load``.
        span_model: str = argfield(predict_time=True)
        # Pickle file with the gold derivations; indexed by sentence ID in ``train``.
        derivations: ExistFile = argfield()
        # Pickle file holding the grammar and the lexicon/POS/label mappings.
        grammar: ExistFile = argfield()
        # NOTE(review): not read anywhere in this class; presumably consumed elsewhere.
        embed_file: ExistFile = argfield(None)
        # One of "eds", "dmrs", "lf".
        # NOTE(review): not read in this class; presumably used by the mixin/base class.
        graph_type: str = argfield("eds", choices=["eds", "dmrs", "lf"], predict_time=True)
        # NOTE(review): not read in this class; presumably used by the mixin/base class.
        deepbank_dir: str = argfield(predict_time=True)
        # When true, the scorer network and the feed dicts are moved to CUDA.
        gpu: bool = argfield(False, predict_time=True)
        # NOTE(review): not read in this class; presumably used by the mixin/base class.
        word_threshold: int = argfield(0, predict_time=True)

    def __init__(self, options: "UdefQParserGreedy.Options", train_trees):
        """Load the pre-trained sub-models and the grammar resources.

        The rule scorer itself is NOT created here; ``create_network`` builds
        it lazily (``train`` calls it on first use).

        :param options: parser options; ``options.hparams`` carries the
            hyper-parameters.
        :param train_trees: training trees, forwarded unchanged to the base
            class constructor.
        """
        super(UdefQParserGreedy, self).__init__(options, train_trees)

        self.args: UdefQParserGreedy.Options
        self.hparams: UdefQParserGreedy.HParams

        # NOTE(review): ``self.options`` is read inside the loaders below
        # before it is assigned further down, so the base ``__init__`` is
        # presumably expected to have set it already.
        @try_cache_keeper(options.span_model)
        def get_parser():
            parser = SpanParser.load(
                options.span_model,
                AttrDict(gpu=options.gpu and not options.hparams.span_cache,
                         bilm_path=self.options.bilm_path
                         ))
            if self.options.hparams.stop_grad:
                parser: SpanParser = NoPickle(parser)
            return parser

        self.parser: SpanParser = get_parser()

        @try_cache_keeper(options.pred_tagger)
        def get_tagger():
            tagger = PredTagger.load(
                options.pred_tagger,
                AttrDict(gpu=options.gpu,
                         bilm_path=self.options.bilm_path,
                         span_model=self.parser
                         ))
            if self.options.hparams.stop_grad:
                tagger = NoPickle(tagger)
            tagger.network.eval()
            return tagger

        if options.pred_tagger is not None:
            self.tagger = get_tagger()
            self.statistics = self.tagger.statistics
        else:
            # Without a tagger the statistics come from the span parser and
            # plain sentence features (without tags) are used.
            self.tagger = None
            self.statistics = self.parser.statistics
            self.sentence_feature_class = SentenceFeatures

        @try_cache_keeper(options.internal_tagger)
        def get_internal_tagger():
            # Load from the internal tagger's own path (the same path the
            # cache keeper above is keyed on), not from the pred tagger's.
            tagger = InternalTagger.load(
                options.internal_tagger,
                AttrDict(gpu=options.gpu,
                         bilm_path=self.options.bilm_path,
                         span_model=self.parser
                         ))
            if self.options.hparams.stop_grad:
                tagger = NoPickle(tagger)
            tagger.network.eval()
            return tagger

        if options.internal_tagger is not None:
            self.internal_tagger = get_internal_tagger()
        else:
            self.internal_tagger = None

        self.plugins = self.parser.plugins
        self.old_options = self.parser.options
        self.args = options
        self.hparams = options.hparams
        self.logger.info(self.statistics)
        self.options = options
        self.logger.info(pformat(self.options.__dict__))
        self.global_step = 0
        self.global_epoch = 0
        self.best_score = 0.0
        self.gold_graphs = NoPickle({})
        # Opened lazily by ``train`` when ``hparams.span_cache`` is set.
        self.span_cache = None

        @try_cache_keeper(options.derivations)
        def load_derivations():
            # SECURITY: pickle.load can execute arbitrary code; only point
            # this at derivation files from a trusted source.
            with open(options.derivations, "rb") as f:
                return pickle.load(f)

        self.logger.info("Loading derivations...")
        self.derivations = NoPickle(load_derivations())

        @cache_result(
            self.options.output + "/hrg_statistics.pkl", self.options.debug_cache)
        def load_hrg_statistics():
            return HRGStatistics.from_derivations(self.derivations)

        self.logger.info("Statistics HRG...")
        self.hrg_statistics = load_hrg_statistics()

        @try_cache_keeper(options.grammar)
        def load_grammar():
            # SECURITY: pickle.load can execute arbitrary code; only point
            # this at grammar files from a trusted source.
            with open(options.grammar, "rb") as f:
                return pickle.load(f)

        # The grammar pickle holds a 5-tuple; ``self.grammar`` is annotated in
        # the original as Mapping[str, Mapping[CFGRule, int]].
        self.grammar, self.lexicon_to_graph, self.lexicon_to_lemma, \
        self.postag_mapping, self.lexical_label_mapping = load_grammar()

        self.lemmatizer = WordNetLemmatizer()

        # delay network creation
        self.network = self.scorer_network = self.optimizer = None

        self.file_logger = NoPickle(self.get_logger(self.options, False))
        self.progbar = NoPickle(Progbar(self.hparams.train_iters, log_func=self.file_logger.info))

    def create_network(self):
        """Build the rule scorer and the optimizer that trains it.

        The scorer is obtained from ``RuleScorers.get`` and stored as both
        ``self.network`` and ``self.scorer_network``. ``self.optimizer`` is an
        Adam optimizer over the trainable parameters, or None when there are
        none.
        """
        hparams = self.options.hparams
        scorer = RuleScorers.get(
            hparams.scorer,
            grammar=self.grammar,
            statistics=self.hrg_statistics,
            contextual_dim=self.parser.network.contextual_unit.output_dim,
        )
        self.network = scorer
        self.scorer_network = scorer

        if self.options.gpu:
            self.network.cuda()

        if hparams.stop_grad:
            # The span parser is frozen: optimize the scorer only.
            trainable_parameters = list(self.network.parameters())
        else:
            # Fine-tune the span parser together with the scorer, skipping
            # parameters that do not require gradients.
            every_parameter = chain(self.parser.network.parameters(),
                                    self.network.parameters())
            trainable_parameters = [p for p in every_parameter if p.requires_grad]

        # Gradient clipping is currently not applied to these parameters.
        self.optimizer = (torch.optim.Adam(trainable_parameters)
                          if trainable_parameters else None)

    def get_span_features(self, batch_sents, feed_dict, device):
        """Return the span features of every sentence in the batch.

        :param batch_sents: sentence feature objects of the batch.
        :param feed_dict: network inputs of the batch; wrapped in ``AttrDict``
            before being passed to the span parser.
        :param device: device on which cached features are materialized; only
            used when a span cache is active.
        :return: one entry per sentence. For a count-based scorer this is a
            list of ``None`` placeholders (one per span start); with a span
            cache it is a tensor built from the cached value of the sentence
            ID; otherwise it is whatever the span parser network returns with
            ``return_span_features=True``.
        """
        # A count-based scorer never looks at the features.
        if isinstance(self.network, CountBasedHRGScorer):
            return [[None] * len(sent.span_starts) for sent in batch_sents]

        if self.span_cache:
            # noinspection PyCallingNonCallable
            return [
                torch.tensor(self.span_cache.get(sent.original_obj.extra["ID"]), device=device)
                for sent in batch_sents]

        if self.options.hparams.stop_grad:
            # The span parser is frozen, so no graph needs to be recorded.
            with torch.no_grad():
                return self.parser.network(batch_sents, AttrDict(feed_dict),
                                           return_span_features=True)

        return self.parser.network(batch_sents, AttrDict(feed_dict),
                                   return_span_features=True)

    def get_tagger_result(self, batch_sents, feed_dict):
        """Run the tagger(s) and return candidate indices ranked by score.

        :param batch_sents: sentence feature objects of the batch.
        :param feed_dict: network inputs of the batch.
        :return: ``(lexical_labels, attachment_bags, internal_bags)``. Each is
            a CPU tensor of indices ordered by descending logit along the last
            dimension. All three are empty lists when no tagger is configured,
            and the last one is an empty list when there is no internal tagger.
        """
        if self.tagger is None:
            return [], [], []

        def rank_descending(scores):
            # Sorting the negated scores in ascending order gives the indices
            # in descending score order; only the indices are kept.
            return (-scores).sort(dim=-1)[1]

        tagger_output = self.tagger.network(batch_sents, AttrDict(feed_dict))
        lexical_scores = tagger_output.lexical_logits
        attachment_scores = tagger_output.attachment_logits

        lexical_ranking = rank_descending(lexical_scores)
        attachment_ranking = rank_descending(attachment_scores)
        lexical_ranking = lexical_ranking.cpu()
        attachment_ranking = attachment_ranking.cpu()

        if self.internal_tagger is None:
            return lexical_ranking, attachment_ranking, []

        internal_output = self.internal_tagger.network(batch_sents, AttrDict(feed_dict))
        internal_ranking = rank_descending(internal_output["internal_logits"]).cpu()
        return lexical_ranking, attachment_ranking, internal_ranking

    def train(self, train_buckets, dev_args_list):
        """Train the rule scorer for one pass over ``train_buckets``.

        On the first call the scorer network is created and, when
        ``hparams.span_cache`` is set, the span-feature cache is generated (if
        the file does not exist yet) and opened; the span parser is then
        dropped because its features are read from the cache.

        A count-based scorer is not trained: the dev sets are evaluated once
        and the method returns.

        :param train_buckets: training data; must provide ``generate_batches``.
        :param dev_args_list: iterable of ``(filename, data, buckets)`` triples
            used for evaluation, or None to skip evaluation.
        """
        opt_hparams = self.options.hparams

        # The scorer needs the span parser's output dimension, so it has to be
        # built while ``self.parser`` still exists (it is deleted just below
        # when a span cache is in use).
        if self.network is None:
            self.create_network()

        if opt_hparams.span_cache is not None and self.span_cache is None:
            # ``span_cache`` and ``disable_span_dropout`` are declared on
            # HParams, so they are read from there.
            if not os.path.exists(opt_hparams.span_cache):
                dev_buckets = [] if dev_args_list is None else [i[2] for i in dev_args_list]
                SpanCacheManager.generate_cache_file(
                    self.parser, opt_hparams.span_cache,
                    train_buckets, dev_buckets,
                    training_features=not opt_hparams.disable_span_dropout
                )
            self.span_cache = NoPickle(SpanCacheManager(
                opt_hparams.span_cache, mode="r",
                span_dim=self.parser.network.contextual_unit.output_dim))
            del self.parser

        if isinstance(self.scorer_network, CountBasedHRGScorer):
            self.logger.info("No need to train a count-based scorer.")
            if dev_args_list is not None:
                self._evaluate_dev_sets(dev_args_list)
            return

        device = next(self.network.parameters()).device

        # Starts at epsilon rather than 0 so the accuracy ratio is always defined.
        total_count = sys.float_info.epsilon
        correct_count = 0
        total_loss = 0.0
        sent_count = 0
        start_time = time.time()
        self.global_epoch += 1
        if not self.span_cache:
            self.parser.network.train(not opt_hparams.disable_span_dropout)
        self.network.train()

        batch_itr = train_buckets.generate_batches(
            self.hparams.train_batch_size,
            shuffle=True, original=True, use_sub_batch=True,
            plugins=self.plugins
        )

        for batch_data in batch_itr:
            for batch_sents, feed_dict in batch_data:
                if self.options.gpu:
                    to_cuda(feed_dict)
                span_features_list = self.get_span_features(batch_sents, feed_dict, device)
                lexical_labels_list, attachment_bags_list, internal_bags_list = self.get_tagger_result(
                    batch_sents, feed_dict)
                pending = []
                # NOTE(review): gradients are cleared for every sub-batch while
                # the optimizer steps once per batch, so only the last
                # sub-batch's gradients reach ``step()`` — confirm this is intended.
                self.optimizer.zero_grad()
                assert len(batch_sents) == len(span_features_list)
                # zip_longest: the tagger lists are empty when the matching
                # tagger is not configured, which yields None for them.
                for sent_feature, span_features, lexical_labels, attachment_bags, internal_bags in zip_longest(
                        batch_sents, span_features_list, lexical_labels_list, attachment_bags_list, internal_bags_list):
                    tree = sent_feature.original_obj
                    sent_id = tree.extra["ID"]
                    derivations = self.derivations[sent_id]  # type: List[CFGRule]
                    cfg_nodes = list(tree.generate_rules())  # type: List[ConstTree]
                    span_features = span_features[:len(derivations)]
                    assert len(derivations) == len(cfg_nodes)

                    word_idx = 0
                    internal_idx = 0
                    for gold_rule, tree_node, span_feature in zip(
                            derivations, cfg_nodes, span_features):
                        if isinstance(tree_node.children[0], Lexicon):
                            # Pre-terminal node: use the tagger output of the
                            # current word, if a tagger is configured.
                            lexical_labels_i = lexical_labels[word_idx] \
                                if lexical_labels is not None else []
                            attachment_bags_i = attachment_bags[word_idx] \
                                if attachment_bags is not None else []
                            internal_bags_i = []
                            word_idx += 1
                        else:
                            lexical_labels_i = []
                            attachment_bags_i = []
                            if self.internal_tagger is not None:
                                internal_bags_i = internal_bags[internal_idx]
                            else:
                                internal_bags_i = []
                            # Advance per internal node, exactly as
                            # predict_bucket does.
                            internal_idx += 1
                        if tree_node.tag.endswith("#0") or tree_node.tag.endswith("#None"):
                            continue
                        try:
                            correspondents = set(
                                self.rule_lookup(tree_node, True, lexical_labels_i,
                                                 attachment_bags_i, internal_bags_i).items())
                        except ValueError as e:
                            # Best effort: report the failing node and go on.
                            print(e)
                            traceback.print_exc()
                            continue

                        correspondents_list = [(rule, count) for rule, count in correspondents
                                               if rule.hrg is not None]
                        # ``correspondents`` holds (rule, count) pairs, so this
                        # membership test compares the gold rule against pairs.
                        if gold_rule not in correspondents:
                            correspondents_list.append((gold_rule, 1))

                        # get_best_rule is a generator: the first next()
                        # registers the request, the second one (after
                        # calculate_results) yields the outcome.
                        loss_calculator = self.scorer_network.get_best_rule(
                            span_feature,
                            correspondents_list,
                            gold_rule)
                        next(loss_calculator)
                        pending.append((gold_rule, loss_calculator))

                self.scorer_network.calculate_results()
                # noinspection PyCallingNonCallable
                loss = 0.0
                for gold_rule, loss_calculator in pending:
                    best_rule, this_loss, real_best_rule = next(loss_calculator)
                    if this_loss is not None:
                        total_count += 1
                        loss += this_loss
                        if real_best_rule == gold_rule:
                            correct_count += 1
                sent_count += len(batch_sents)
                if not isinstance(loss, Tensor):
                    # loss == 0.0: nothing in this sub-batch produced a loss.
                    continue
                backed = False
                while not backed:
                    try:
                        loss.backward()
                        backed = True
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            self.logger.info("OOM. Clear cache and try again.")
                            torch.cuda.empty_cache()
                            continue
                        else:
                            raise
                self.scorer_network.refresh()
                total_loss += loss.detach().cpu().numpy()
                del loss

            # update parameters
            self.optimizer.step()
            self.global_step += 1

            # Progress report; evaluation only happens on steps that are
            # multiples of both print_every and evaluate_every.
            if self.global_step % self.hparams.print_every == 0:
                end_time = time.time()
                speed = sent_count / (end_time - start_time)
                start_time = end_time
                self.progbar.update(
                    self.global_step,
                    exact=[("Epoch", self.global_epoch),
                           ("Loss", total_loss),
                           ("Speed", speed),
                           ("Corr.", correct_count / total_count * 100),
                           ]
                )
                sent_count = 0
                total_loss = 0.0
                # Reset to epsilon (not 0) so the next report cannot divide by
                # zero when no rule was scored in between.
                total_count = sys.float_info.epsilon
                correct_count = 0
                if self.global_step % self.hparams.evaluate_every == 0:
                    if dev_args_list is not None:
                        self._evaluate_dev_sets(dev_args_list)
                    # predict_bucket switches the networks to eval mode.
                    self.network.train()
                    if not self.span_cache:
                        self.parser.network.train(not opt_hparams.disable_span_dropout)

    def _evaluate_dev_sets(self, dev_args_list):
        """Predict and score every dev set; save the model on a new best score."""
        for filename, _data, buckets in dev_args_list:
            results = self.predict_bucket(buckets)
            output_file = self.get_output_name(
                self.options.output, filename, self.global_step)
            new_score = self.evaluate_hg(results, output_file)
            if new_score > self.best_score:
                self.logger.info("New best score: %.2f > %.2f", new_score, self.best_score)
                self.save(self.args.output + "/model")
                self.best_score = new_score
            else:
                self.logger.info("Not best score, %.2f <= %.2f", new_score, self.best_score)

    def predict_bucket(self, buckets, return_derivation=False):
        """Greedily pick a rule for every tree node and build the graphs.

        :param buckets: data to parse; must provide ``generate_batches`` and
            ``len()``.
        :param return_derivation: when true, each result also carries the list
            produced by ``construct_derivation``.
        :return: a list indexed by the sentences' ``original_idx``; each entry
            is ``(sentence ID, graph)`` or ``(sentence ID, graph, derivation)``.
        """
        if not self.span_cache:
            self.parser.network.eval()
        self.network.eval()
        results = [None] * len(buckets)
        batch_itr = buckets.generate_batches(
            self.hparams.train_batch_size,
            original=True, use_sub_batch=True, plugins=self.plugins
        )
        # ``device`` is only used to materialize cached span features, and in
        # that mode ``train`` has already deleted ``self.parser``. Take it from
        # the scorer network instead (as ``train`` does); fall back to the CPU
        # for a scorer without parameters.
        scorer_param = next(iter(self.network.parameters()), None)
        device = scorer_param.device if scorer_param is not None else torch.device("cpu")

        for batch_data in batch_itr:
            for batch_sents, feed_dict in batch_data:
                if self.options.gpu:
                    to_cuda(feed_dict)
                span_features_list = self.get_span_features(batch_sents, feed_dict, device)
                lexical_labels_list, attachment_bags_list, internal_bags_list = self.get_tagger_result(
                    batch_sents, feed_dict)
                pending = []
                # zip_longest: the tagger lists are empty when the matching
                # tagger is not configured, which yields None for them.
                for sent_feature, span_features, lexical_labels, attachment_bags, internal_bags in zip_longest(
                        batch_sents, span_features_list, lexical_labels_list, attachment_bags_list, internal_bags_list):
                    tree = sent_feature.original_obj
                    self.populate_delphin_spans(tree)
                    cfg_nodes = list(tree.generate_rules())  # type: List[ConstTree]
                    pending_i = []
                    pending.append(pending_i)

                    span_features = span_features[:len(cfg_nodes)]
                    word_idx = 0
                    internal_idx = 0
                    for tree_node, span_feature in zip(cfg_nodes, span_features):
                        if isinstance(tree_node.children[0], Lexicon):
                            # Pre-terminal node: use the tagger output of the
                            # current word, if a tagger is configured.
                            if lexical_labels is not None:
                                lexical_labels_i = lexical_labels[word_idx]
                            else:
                                lexical_labels_i = []
                            if attachment_bags is not None:
                                attachment_bags_i = attachment_bags[word_idx]
                            else:
                                attachment_bags_i = []
                            internal_bags_i = []
                            word_idx += 1
                        else:
                            lexical_labels_i = []
                            attachment_bags_i = []
                            if self.internal_tagger is not None:
                                internal_bags_i = internal_bags[internal_idx]
                            else:
                                internal_bags_i = []
                            internal_idx += 1
                        if tree_node.tag.endswith("#0") or tree_node.tag.endswith("#None"):
                            # No rule is chosen for these nodes; keep a
                            # placeholder so they are skipped below.
                            pending_i.append((tree_node, None, None))
                            continue
                        correspondents = set(
                            self.rule_lookup(tree_node, False, lexical_labels_i, attachment_bags_i,
                                             internal_bags_i).items())
                        correspondents_list = [(rule, count) for rule, count in correspondents
                                               if rule.hrg is not None]
                        # get_best_rule is a generator: the first next()
                        # registers the request, the second one (after
                        # calculate_results) yields the chosen rule.
                        rule_getter = self.scorer_network.get_best_rule(
                            span_feature,
                            correspondents_list,
                            None)
                        next(rule_getter)
                        pending_i.append((tree_node, rule_getter, correspondents))

                self.scorer_network.calculate_results()

                for sent_feature, pending_i in zip(batch_sents, pending):
                    tree = sent_feature.original_obj
                    sub_graphs = {}
                    correspondents_map = {}
                    sync_rule_map = {}
                    for tree_node, rule_getter, correspondents in pending_i:
                        if rule_getter is None:
                            continue
                        best_rule, this_loss, real_best_rule = next(rule_getter)
                        sync_rule_map[tree_node] = best_rule
                        correspondents_map[tree_node] = correspondents
                        if isinstance(tree_node.children[0], Lexicon):
                            # preterminal
                            sub_graphs[tree_node] = SubGraph.create_leaf_graph(
                                tree_node, best_rule, tree_node.children[0].string, self.lexicon_to_lemma)
                        else:
                            # Children's graphs are looked up with .get(), so a
                            # child without a graph contributes None.
                            left_sub_graph = sub_graphs.get(tree_node.children[0])
                            right_sub_graph = sub_graphs.get(tree_node.children[1]) if len(
                                tree_node.children) >= 2 else None
                            sub_graphs[tree_node] = SubGraph.merge(tree_node, best_rule, left_sub_graph,
                                                                   right_sub_graph)

                    if return_derivation:
                        results[sent_feature.original_idx] = tree.extra["ID"], sub_graphs[tree].graph, list(
                            self.construct_derivation(sub_graphs, correspondents_map, sync_rule_map, tree))
                    else:
                        results[sent_feature.original_idx] = tree.extra["ID"], sub_graphs[tree].graph

                self.scorer_network.refresh()
        return results

    def post_load(self, new_options: typing.Optional[AttrDict] = None):
        """Restore the sub-models that were not stored with the saved parser.

        :param new_options: prediction-time options. Although the signature
            allows None, ``gpu`` and ``bilm_path`` are always read from it, and
            ``span_model_prefix`` / ``pred_tagger`` are read when the matching
            sub-model has to be reloaded.
        """
        super(UdefQParserGreedy, self).post_load(new_options)

        # The span parser stays on the CPU when a span cache was configured.
        # The cache path is declared on HParams, so fall back to it when the
        # options object itself has no ``span_cache``. It is only looked up
        # when the GPU was requested.
        parser_gpu = new_options.gpu
        if parser_gpu:
            span_cache = getattr(self.options, "span_cache", None)
            if span_cache is None:
                span_cache = getattr(
                    getattr(self.options, "hparams", None), "span_cache", None)
            parser_gpu = not span_cache

        # ``train`` deletes the attribute altogether in span-cache mode, so a
        # missing attribute is treated like None.
        if getattr(self, "parser", None) is None:
            self.parser: SpanParser = SpanParser.load(
                new_options.span_model_prefix,
                AttrDict(gpu=parser_gpu,
                         bilm_path=new_options.bilm_path)
            )
        else:
            self.parser.post_load(
                AttrDict(gpu=parser_gpu,
                         bilm_path=new_options.bilm_path, feature_only=True
                         ))

        if self.tagger is None and new_options.pred_tagger is not None:
            self.tagger = PredTagger.load(
                new_options.pred_tagger,
                AttrDict(gpu=new_options.gpu,
                         bilm_path=new_options.bilm_path,
                         span_model=self.parser
                         ))
            if self.options.hparams.stop_grad:
                self.tagger = NoPickle(self.tagger)
            self.tagger.network.eval()

        self.gold_graphs = NoPickle({})
