import sys
from ast import literal_eval
from collections import defaultdict
from typing import Any, List, Tuple

import torch
from dataclasses import dataclass, field
from torch.nn import LeakyReLU, Module

from coli.basic_tools.common_utils import AttrDict, NoPickle, cache_result, try_cache_keeper
from coli.basic_tools.dataclass_argparse import argfield
from coli.data_utils.dataset import PAD, UNKNOWN, START_OF_SENTENCE, END_OF_SENTENCE
from coli.data_utils.vocab_utils import Dictionary
from coli.span.const_tree import ConstTree, Lexicon, ConstTreeStatistics
from coli.torch_extra.dataset import lookup_list
from coli.torch_extra.layers import create_mlp, ContextualUnits
from coli.torch_extra.parser_base import SimpleParser
from coli.torch_extra.sentence import SentenceEmbeddings
from coli.torch_extra.utils import cross_entropy_nd, pad_and_stack_1d
from coli.torch_span.data_loader import SentenceFeatures
from coli.torch_span.parser import SpanParser
from coli.torch_tagger.config import TaggerHParams
from coli.torch_tagger.crf import CRF
from coli.torch_tagger.tagger import Tagger


class FScoreCalculator(object):
    """Accumulate gold/system/correct counts and report precision, recall and F.

    Counts are kept per label and, under the pseudo-label ``all_label``,
    for all items together.  Scores are expressed as percentages.
    """
    all_label = "__ALL__"

    def __init__(self, name):
        """Create an empty calculator; ``name`` is used in report headings."""
        self.name = name
        self.correct_count = defaultdict(int)
        self.gold_count = defaultdict(int)
        self.system_count = defaultdict(int)

    def update(self, gold_items, system_items):
        """Count position-aligned predictions.

        ``gold_items`` and ``system_items`` are zipped, so any surplus items
        in the longer sequence are ignored.  A position is correct when the
        two items compare equal.
        """
        for gold_item, system_item in zip(gold_items, system_items):
            self.gold_count[self.all_label] += 1
            self.gold_count[gold_item] += 1
            self.system_count[self.all_label] += 1
            self.system_count[system_item] += 1
            if gold_item == system_item:
                self.correct_count[self.all_label] += 1
                self.correct_count[gold_item] += 1

    def update_sets(self, gold_sets, system_sets):
        """Count unordered predictions, one gold/system collection per position.

        Every element of a collection is counted for its side; the elements
        present on both sides (set intersection) are counted as correct.
        """
        for gold_set, system_set in zip(gold_sets, system_sets):
            self.gold_count[self.all_label] += len(gold_set)
            for gold_item in gold_set:
                self.gold_count[gold_item] += 1
            self.system_count[self.all_label] += len(system_set)
            for system_item in system_set:
                self.system_count[system_item] += 1

            correct_set = set(gold_set) & set(system_set)
            self.correct_count[self.all_label] += len(correct_set)
            for correct_item in correct_set:
                self.correct_count[correct_item] += 1

    def get_p_r_f(self, label):
        """Return ``(precision, recall, f_score)`` for ``label`` as percentages.

        Unseen labels score ``0.0``.  The counters are read with ``.get`` so
        that querying a label never inserts a zero entry into the
        ``defaultdict`` counters (which would change later reports).
        """
        epsilon = sys.float_info.epsilon  # avoids division by zero
        correct = self.correct_count.get(label, 0)
        p = correct / \
            (self.system_count.get(label, 0) + epsilon) * 100
        r = correct / \
            (self.gold_count.get(label, 0) + epsilon) * 100
        f = 2 * p * r / (p + r + epsilon)
        return p, r, f

    def get_report(self):
        """Return ``(report_text, total_f)``.

        The report lists the overall scores, then one row per label: labels
        seen in gold first (ascending gold count), followed by labels that
        only the system produced (ascending system count).  Calling this
        method does not modify the counters, so repeated calls give the same
        text.
        """
        total_p, total_r, total_f = self.get_p_r_f(self.all_label)
        pieces = [
            f"===== {self.name} (total) =====\n",
            f"P: {total_p:.2f}\n",
            f"R: {total_r:.2f}\n",
            f"F: {total_f:.2f}\n\n",
            f"===== {self.name} (each label) =====\n",
            "Label\t\tGold\t\tSystem\t\tP\t\tR\t\tF\n",
        ]
        labels = [label for label, _ in
                  sorted(self.gold_count.items(), key=lambda x: x[1])]
        labels.extend(label for label, _ in
                      sorted(self.system_count.items(), key=lambda x: x[1])
                      if label not in self.gold_count)
        for label in labels:
            label_p, label_r, label_f = self.get_p_r_f(label)
            gold_count = self.gold_count.get(label, 0)
            system_count = self.system_count.get(label, 0)
            pieces.append(
                f"{label}\t\t{gold_count}\t\t{system_count}\t\t"
                f"{label_p:.2f}\t\t{label_r:.2f}\t\t{label_f:.2f}\n")
        pieces.append("\n")
        return "".join(pieces), total_f


class ConstTreeExtra(ConstTree):
    """Constituent tree whose ``extra`` dict carries lexical annotations.

    ``from_file`` decodes the string-encoded annotation fields with
    ``ast.literal_eval`` and fills in defaults for the missing ones, so the
    rest of this module can treat them as Python lists.
    """
    NodesInfo = List[List[Tuple[str, str]]]

    # Annotation keys decoded by ``from_file``, in processing order, each
    # paired with a factory for the value used when the key is absent
    # (``None`` means the key is left absent).  The per-word defaults build a
    # fresh inner list for every word: ``[[]] * n`` would make all entries
    # alias one list object, so mutating one word's list would change all.
    _EXTRA_FIELDS = (
        ("LexicalLabels", lambda sent: ["None"] * len(sent.words)),
        ("LexicalAttachments", lambda sent: [[] for _ in range(len(sent.words))]),
        ("InternalAttachments", lambda sent: [[] for _ in range(len(sent.words))]),
        ("LexicalNodes", lambda sent: [[] for _ in range(len(sent.words))]),
        ("InternalNodes", lambda sent: [[] for _ in range(len(sent.words))]),
        ("LexicalNodesPred", None),
        ("InternalNodesPred", None),
        ("StructualEdges", lambda sent: []),
        ("LeaftagsPred", None),
    )

    @classmethod
    def from_file(cls, file_name, use_edge=None, limit=float("inf")):
        """Load trees via the parent loader and normalise their ``extra`` fields.

        Each key listed in ``_EXTRA_FIELDS`` that is present is replaced by
        the result of ``literal_eval`` on its stored value; absent keys
        receive their default when one is defined.

        :return: whatever the parent ``from_file`` returns, with each
            sentence's ``extra`` updated in place.
        """
        ret = super(ConstTreeExtra, cls).from_file(file_name, use_edge, limit)
        for sent in ret:
            extra = sent.extra
            for key, make_default in cls._EXTRA_FIELDS:
                if key in extra:
                    extra[key] = literal_eval(extra[key])
                elif make_default is not None:
                    extra[key] = make_default(sent)
        return ret

    def __len__(self):
        """Return the number of lexical labels stored for this tree."""
        return len(self.extra["LexicalLabels"])

    def replaced_labels(self, new_labels, new_attachments):
        """Return a tree sharing this tree's children but with new annotations.

        The result is a plain ``ConstTree`` with a shallow copy of ``extra``
        in which ``LexicalLabels`` and ``LexicalAttachments`` are replaced and
        ``LexicalNodesPred`` is rebuilt: for every word, a list of
        ``(node_id, label)`` pairs with consecutive string ids, holding the
        lexical label first (unless it is ``"None"``) and then each
        attachment.
        """
        new_tree = ConstTree(tag=self.tag, span=self.span, extra_info=dict(self.extra))
        new_tree.children = self.children
        new_tree.extra["LexicalLabels"] = new_labels
        new_tree.extra["LexicalAttachments"] = new_attachments

        node_idx = 0
        lexical_nodes = []
        for lexical_label, attachments in zip(new_labels, new_attachments):
            # "None" marks a word without a lexical node of its own.
            node_labels = [] if lexical_label == "None" else [lexical_label]
            node_labels.extend(attachments)
            this_list = []
            for node_label in node_labels:
                this_list.append((str(node_idx), node_label))
                node_idx += 1
            lexical_nodes.append(this_list)

        new_tree.extra["LexicalNodesPred"] = lexical_nodes

        return new_tree

    @classmethod
    def internal_evaluate(cls, gold_sents,
                          system_sents,
                          log_file: str = None,
                          is_external: bool = False
                          ):
        """Score lexical labels and attachments of ``system_sents`` against gold.

        Prints both overall F-scores and, when ``log_file`` is given, appends
        the two detailed reports to it.

        :param gold_sents: sentences providing the reference ``extra`` fields.
        :param system_sents: sentences providing the predicted ``extra``
            fields; paired with ``gold_sents`` by position.
        :param log_file: optional path opened in append mode.
        :param is_external: accepted for interface compatibility; not used.
        :return: lexical-label F plus attachment F (both percentages).
        """
        lexical_labels_scorer = FScoreCalculator("Lexical Labels")
        attachments_scorer = FScoreCalculator("Attachments")

        for gold_sent, system_sent in zip(gold_sents, system_sents):
            lexical_labels_scorer.update(gold_sent.extra["LexicalLabels"],
                                         system_sent.extra["LexicalLabels"])
            attachments_scorer.update_sets(gold_sent.extra["LexicalAttachments"],
                                           system_sent.extra["LexicalAttachments"])

        lexical_labels_str, lexical_labels_f = lexical_labels_scorer.get_report()
        attachments_str, attachments_f = attachments_scorer.get_report()

        print(f"Lexical: {lexical_labels_f:.2f}")
        print(f"Attachments: {attachments_f:.2f}")

        if log_file is not None:
            with open(log_file, "a") as f:
                f.write(lexical_labels_str)
                f.write(attachments_str)

        return lexical_labels_f + attachments_f

    @classmethod
    def evaluate_with_external_program(cls, gold_file, system_file, perf_file=None, print=True):
        """Evaluate two tree files with ``internal_evaluate``.

        :param gold_file: path of the reference trees.
        :param system_file: path of the predicted trees.
        :param perf_file: report path; defaults to ``system_file + ".txt"``.
        :param print: accepted for interface compatibility; not used (it only
            shadows the builtin inside this method).
        """
        if perf_file is None:
            perf_file = system_file + ".txt"
        cls.internal_evaluate(cls.from_file(gold_file),
                              cls.from_file(system_file),
                              log_file=perf_file,
                              is_external=True)


@dataclass
class PredTagStatistics(object):
    """Vocabularies collected from the ``extra`` annotations of training trees."""
    pred_tags: Dictionary = None
    attachment_tags: Dictionary = None
    attachment_bags: Dictionary = None
    internal_bags: Dictionary = None
    structural_edges: Dictionary = None
    all_nodes: Dictionary = None

    @classmethod
    def from_sentences(cls, sentences):
        """Build all vocabularies in a single pass over ``sentences``.

        Attachment and internal "bags" are the ``;``-joined, sorted labels of
        one word; individual attachment labels are also recorded on their own.
        """
        lexical_vocab = Dictionary(initial=())
        attachment_vocab = Dictionary(initial=())
        attachment_bag_vocab = Dictionary(initial=())
        internal_bag_vocab = Dictionary(initial=())
        edge_vocab = Dictionary(initial=("__EMPTY__",))
        node_vocab = Dictionary()

        for sentence in sentences:
            extra = sentence.extra
            for label in extra["LexicalLabels"]:
                lexical_vocab.update_and_get_id(label)
            for word_attachments in extra["LexicalAttachments"]:
                attachment_bag_vocab.update_and_get_id(";".join(sorted(word_attachments)))
                for single_attachment in word_attachments:
                    attachment_vocab.update_and_get_id(single_attachment)
            for word_internals in extra["InternalAttachments"]:
                internal_bag_vocab.update_and_get_id(";".join(sorted(word_internals)))
            # Only the label of each (start, end, label) edge is recorded.
            for _, _, edge_label in extra["StructualEdges"]:
                edge_vocab.update_and_get_id(edge_label)
            # Lexical and internal nodes share one vocabulary of node labels.
            for nodes_key in ("LexicalNodes", "InternalNodes"):
                for nodes_info in extra[nodes_key]:
                    for _, node_label in nodes_info:
                        node_vocab.update_and_get_id(node_label)

        return cls(pred_tags=lexical_vocab,
                   attachment_tags=attachment_vocab,
                   attachment_bags=attachment_bag_vocab,
                   internal_bags=internal_bag_vocab,
                   structural_edges=edge_vocab,
                   all_nodes=node_vocab)

    def __str__(self):
        """Summarise the sizes of the tag and bag vocabularies."""
        return ", ".join((
            f"{len(self.pred_tags)} pred tags",
            f"{len(self.attachment_tags)} attachment tags",
            f"{len(self.attachment_bags)} attachment bags",
            f"{len(self.internal_bags)} internal bags",
        ))


@dataclass
class SentenceFeaturesWithTags(SentenceFeatures):
    """Sentence features extended with tagger targets and rule indices."""
    words_for_tagger: Any = None
    pred_tags: Any = None
    attachment_tags: Any = None

    lexical_indices: Any = None
    internal_indices: Any = None
    internal_count: Any = None
    lexical_count: Any = None

    # Filled in by ``from_sentence_obj`` and read by ``get_feed_dict``;
    # declared here (with defaults, after the existing fields) so that the
    # dataclass documents every attribute it carries.
    attachment_bags: Any = None
    internal_bags: Any = None

    @classmethod
    def from_sentence_obj(cls, original_idx, sent: ConstTreeExtra,
                          statistics: PredTagStatistics,
                          padded_length=None, lower=True,
                          plugins=None,
                          *args, **kwargs
                          ):
        """Build the features of one sentence.

        On top of what the parent class produces, this adds the word ids used
        by the tagger, the indices of lexical and internal rules, and the
        target ids for lexical labels, attachment bags and internal bags.

        :param lower: lower-case each word before the vocabulary lookup.
        :param padded_length: length that every produced tensor is padded to.
        :param plugins: optional mapping of plugins; each one's
            ``process_sentence_feature`` is called on the result.
        """
        # noinspection PyTypeChecker
        ret = super(SentenceFeaturesWithTags, cls).from_sentence_obj(
            original_idx, sent, statistics, padded_length, *args,
            plugins=plugins, **kwargs)

        # Select the normaliser up front.  Writing this as
        # ``lambda x: x.lower() if lower else lambda x: x`` parses as a single
        # lambda whose else-branch *returns a function*, so with lower=False
        # every word would be replaced by a function object.
        if lower:
            def normalize(word):
                return word.lower()
        else:
            def normalize(word):
                return word

        def ignore_filled(shape, *, dtype):
            # Padding positions hold -100 (PyTorch's default ``ignore_index``
            # for cross entropy; presumably honoured by cross_entropy_nd).
            return torch.full(shape, -100, dtype=dtype)

        def bag_key(items):
            # Same bag encoding as PredTagStatistics.from_sentences.
            return ";".join(sorted(items))

        ret.words_for_tagger = lookup_list(
            (normalize(i) for i in sent.words), statistics.words.word_to_int,
            padded_length=padded_length, default=1,
            start_and_stop=False
        )

        # Partition rule positions into lexical rules (first child is a
        # Lexicon) and internal rules.
        ret.internal_indices = torch.zeros((padded_length,), dtype=torch.int64)
        ret.lexical_indices = torch.zeros((padded_length,), dtype=torch.int64)
        internal_count = 0
        lexical_count = 0
        for idx, rule in enumerate(sent.generate_rules()):
            if isinstance(rule.children[0], Lexicon):
                ret.lexical_indices[lexical_count] = idx
                lexical_count += 1
            else:
                ret.internal_indices[internal_count] = idx
                internal_count += 1

        ret.internal_count = internal_count
        ret.lexical_count = lexical_count

        ret.pred_tags = lookup_list(
            sent.extra["LexicalLabels"], statistics.pred_tags.word_to_int,
            padded_length=padded_length, default=1,
            start_and_stop=False,
            tensor_factory=ignore_filled
        )

        # Fallback when a sentence has no attachment annotation: one empty
        # bag per word.  (A list of "None" strings would be joined character
        # by character into "N;e;n;o".)
        word_count = ret.sent_length - 2
        ret.attachment_bags = lookup_list(
            (bag_key(i) for i in sent.extra.get(
                "LexicalAttachments", [[] for _ in range(word_count)])),
            statistics.attachment_bags.word_to_int,
            padded_length=padded_length, default=1,
            start_and_stop=False,
            tensor_factory=ignore_filled
        )

        ret.internal_bags = lookup_list(
            (bag_key(i) for i in sent.extra.get(
                "InternalAttachments", [[] for _ in range(word_count)])),
            statistics.internal_bags.word_to_int,
            padded_length=padded_length, default=1,
            tensor_factory=ignore_filled,
            start_and_stop=False
        )

        if plugins:
            for plugin in plugins.values():
                plugin.process_sentence_feature(sent, ret, padded_length)

        return ret

    @classmethod
    def get_feed_dict(cls, pls, batch_sentences, plugins=None):
        """Stack per-sentence tensors into a batch keyed by the entries of ``pls``.

        :param pls: object whose attributes (``words``, ``sent_lengths``, ...)
            serve as the dictionary keys.
        :param batch_sentences: feature objects built by ``from_sentence_obj``.
        :param plugins: optional mapping of plugins; each one's
            ``process_batch`` may add to the returned dictionary.
        """
        # noinspection PyCallingNonCallable
        ret = {pls.words: torch.stack([i.words for i in batch_sentences]),
               pls.sent_lengths: torch.tensor([i.sent_length for i in batch_sentences]),
               pls.pred_tags: torch.stack([i.pred_tags for i in batch_sentences]),
               pls.attachment_bags: torch.stack([i.attachment_bags for i in batch_sentences]),
               pls.internal_bags: torch.stack([i.internal_bags for i in batch_sentences]),
               pls.internal_indices: pad_and_stack_1d([i.internal_indices for i in batch_sentences])
               }

        if plugins:
            for plugin in plugins.values():
                plugin.process_batch(pls, ret, batch_sentences)

        return ret


class TaggerNetworkExtra(Module):
    """Network with two heads over shared embeddings: lexical labels and attachment bags."""

    def __init__(self, hparams: "PredTaggerHParams",
                 statistics, plugins,
                 target="labels"):
        """Build embeddings, one contextual encoder + MLP per head, and an optional CRF.

        :param target: name of the ``statistics`` attribute whose length gives
            the number of lexical labels.
        """
        super().__init__()
        self.hparams = hparams
        self.statistics: PredTagStatistics = statistics

        def leaky_relu():
            return LeakyReLU(0.1)

        self.embeddings = SentenceEmbeddings(hparams.sentence_embedding, statistics, plugins)
        self.embeddings.reset_parameters()

        # One contextual encoder per head, both fed by the same embeddings.
        self.rnn = ContextualUnits.get(
            hparams.contextual, input_size=self.embeddings.output_dim)
        self.attachment_rnn = ContextualUnits.get(
            hparams.attachment_contextual, input_size=self.embeddings.output_dim)

        # Lexical-label head.
        self.label_count = len(getattr(statistics, target))
        self.projection = create_mlp(
            self.rnn.output_dim,
            self.label_count,
            hidden_dims=self.hparams.dims_hidden,
            layer_norm=True,
            last_bias=True,
            activation=leaky_relu)

        # Attachment-bag head.
        self.attachment_bag_count = len(statistics.attachment_bags)
        self.attachment_projection = create_mlp(
            self.attachment_rnn.output_dim,
            self.attachment_bag_count,
            hidden_dims=self.hparams.attachment_dims_hidden,
            layer_norm=True,
            last_bias=True,
            activation=leaky_relu)

        if self.hparams.use_crf:
            self.crf_unit = CRF(self.label_count)

    def load_bilm(self, bilm_path, gpu):
        """Forward the request to the embedding layer."""
        self.embeddings.load_bilm(bilm_path, gpu)

    def _crf_loss(self, lexical_logits, trimmed_inputs):
        """Mean over the batch of (CRF normaliser - score of the gold sequence)."""
        norm_scores = self.crf_unit(lexical_logits, trimmed_inputs.sent_lengths)
        # Padding targets are -100; clip them to 0 so they can be used as
        # gather indices (their scores are masked out below).
        # noinspection PyCallingNonCallable
        gold = torch.max(trimmed_inputs.pred_tags,
                         torch.tensor(0, device=trimmed_inputs.pred_tags.device))
        emission = torch.gather(lexical_logits, 2, gold.unsqueeze(-1)).squeeze(-1)
        word_mask = trimmed_inputs.words.gt(0)
        emission_total = (emission * word_mask.float()).sum(-1)
        transition_total = self.crf_unit.transition_score(gold, trimmed_inputs.sent_lengths)
        gold_scores = transition_total + emission_total
        return (norm_scores - gold_scores).mean()

    def forward(self, batch_sentences, inputs):
        """Run both heads on a batch.

        The first and last column of ``inputs.words`` are dropped and the
        lengths reduced by two before embedding (presumably removing the
        sentence start/end markers).  The result always holds ``sent_count``
        and both logits; it additionally holds ``loss`` in training mode, or
        ``lexical_pred`` / ``attachments_pred`` (argmax ids) otherwise.
        """
        trimmed = AttrDict(inputs)
        trimmed.sent_lengths = inputs.sent_lengths - 2
        trimmed.words = inputs.words[:, 1:-1]

        embedded = self.embeddings(trimmed)
        lexical_context = self.rnn(embedded, trimmed.sent_lengths)
        attachment_context = self.attachment_rnn(embedded, trimmed.sent_lengths)

        lexical_logits = self.projection(lexical_context)
        attachment_logits = self.attachment_projection(attachment_context)

        ret = AttrDict(sent_count=embedded.shape[0],
                       lexical_logits=lexical_logits,
                       attachment_logits=attachment_logits)

        if not self.training:
            ret["lexical_pred"] = lexical_logits.argmax(dim=-1)
            ret["attachments_pred"] = attachment_logits.argmax(dim=-1)
            return ret

        if self.hparams.use_crf:
            lexical_loss = self._crf_loss(lexical_logits, trimmed)
        else:
            lexical_loss = cross_entropy_nd(
                lexical_logits, inputs.pred_tags,
                reduction="mean")

        attachments_loss = cross_entropy_nd(
            attachment_logits, inputs.attachment_bags,
            reduction='mean')

        ret["loss"] = lexical_loss + attachments_loss
        return ret


# Hyper-parameters of the predicate tagger; extends the base tagger
# hyper-parameters with the options of the attachment head.
@dataclass
class PredTaggerHParams(TaggerHParams):
    # NOTE(review): presumably overrides a default inherited from
    # TaggerHParams — confirm against that class.
    dim_postag: int = 0

    # NOTE(review): tagger_mlp_dim and stop_grad are not read anywhere in
    # this file; verify whether other modules use them.
    tagger_mlp_dim: int = 100
    stop_grad: bool = False
    # Options of the contextual encoder used by the attachment head
    # (TaggerNetworkExtra.attachment_rnn).
    attachment_contextual: ContextualUnits.Options = field(
        default_factory=ContextualUnits.Options)
    # Hidden layer sizes of the attachment MLP
    # (TaggerNetworkExtra.attachment_projection).
    attachment_dims_hidden: List[int] = field(default_factory=lambda: [100])


class PredTagger(SimpleParser):
    """Parser front-end that trains and applies ``TaggerNetworkExtra``."""
    available_data_formats = {"default": ConstTreeExtra}
    sentence_feature_class = SentenceFeaturesWithTags

    @dataclass
    class Options(Tagger.Options):
        span_model: str = argfield(None)
        hparams: PredTaggerHParams = PredTaggerHParams.get_default()

    def __init__(self, args: "PredTagger.Options", train_trees):
        """Collect statistics, build the network and its optimizer.

        When ``args.span_model`` is set, the shared vocabularies are copied
        from that saved span parser; otherwise they are computed from
        ``train_trees``.
        """
        super().__init__(args, train_trees)
        self.args: PredTagger.Options
        self.hparams: PredTaggerHParams

        @try_cache_keeper(args.span_model)
        def get_parser():
            load_options = AttrDict(
                gpu=args.gpu,
                bilm_path=args.hparams.pretrained_contextual.elmo_options.path)
            return SpanParser.load(args.span_model, load_options)

        @cache_result(self.options.output + "/" + "statistics.pkl",
                      enable=self.options.debug_cache)
        def get_statistics():
            tag_statistics = PredTagStatistics.from_sentences(train_trees)

            if args.span_model is None:
                base_statistics = ConstTreeStatistics.from_sentences(train_trees)
            else:
                span_parser: SpanParser = get_parser()
                base_statistics = span_parser.statistics

            # Borrow the generic vocabularies from the base statistics.
            for attr_name in ("words", "postags", "characters", "labels", "leaftags"):
                setattr(tag_statistics, attr_name, getattr(base_statistics, attr_name))

            return tag_statistics

        self.statistics: PredTagStatistics = get_statistics()
        print(self.statistics)

        self.network = TaggerNetworkExtra(
            self.hparams, self.statistics, self.plugins, target="pred_tags")

        self.trainable_parameters = list(
            filter(lambda param: param.requires_grad, self.network.parameters()))

        self.optimizer, self.scheduler = self.get_optimizer_and_scheduler(self.trainable_parameters)

        if self.options.gpu:
            self.network.cuda()

    def split_batch_result(self, results):
        """Yield one ``(lexical_pred, attachments_pred)`` pair per sentence."""
        lexical_rows = results.lexical_pred.cpu()
        attachment_rows = results.attachments_pred.cpu()
        yield from zip(lexical_rows, attachment_rows)

    def merge_answer(self, sent_feature, answer):
        """Turn predicted ids back into labels and attach them to the sentence.

        Predictions beyond the number of words are discarded.  An empty
        attachment bag string decodes to an empty list; otherwise the bag is
        split on ``;``.
        """
        sent = sent_feature.original_obj
        word_count = len(sent.words)
        lexical_pred, attachments_pred = answer

        label_names = self.statistics.pred_tags.int_to_word
        bag_names = self.statistics.attachment_bags.int_to_word

        new_labels = [label_names[i] for i in lexical_pred[:word_count]]

        new_attachments = []
        for i in attachments_pred[:word_count]:
            bag = bag_names[i]
            new_attachments.append(bag.split(";") if bag else [])

        return sent.replaced_labels(new_labels, new_attachments)


# Script entry point: hand control to PredTagger.main (not defined in this
# file; presumably inherited from SimpleParser) when run directly.
if __name__ == '__main__':
    PredTagger.main()
