from itertools import chain

import torch
from dataclasses import dataclass
from torch.nn import LeakyReLU, Module

from coli.basic_tools.common_utils import AttrDict, NoPickle, try_cache_keeper, NullContextManager
from coli.basic_tools.dataclass_argparse import argfield
from coli.span.const_tree import ConstTree
from coli.torch_extra.layers import create_mlp
from coli.torch_extra.parser_base import SimpleParser
from coli.torch_extra.utils import broadcast_gather, cross_entropy_nd
from coli.torch_hrg.pred_tagger import ConstTreeExtra, PredTagStatistics, SentenceFeaturesWithTags, PredTagger, \
    FScoreCalculator
from coli.torch_span.parser import SpanParser
from coli.torch_tagger.config import TaggerHParams
from coli.torch_tagger.tagger import Tagger


class ConstTreeForInternalLabels(ConstTreeExtra):
    """ConstTree variant whose ``extra`` carries internal attachment labels.

    Labels live under ``extra["InternalAttachments"]``; :meth:`replaced_labels`
    installs predicted ones and :meth:`internal_evaluate` scores them.
    """

    # noinspection PyMethodOverriding
    def replaced_labels(self, internal_labels):
        """Build a new ``ConstTree`` holding ``internal_labels`` as its attachments.

        The new tree has the same ``tag`` and ``span`` as ``self``. It shares
        ``self.children`` (the list object itself is reused, not copied). Its
        ``extra`` is built from a shallow copy of ``self.extra``, in which:

        * ``"InternalAttachments"`` is set to ``internal_labels``;
        * ``"InternalNodesPred"`` holds one list per element of
          ``internal_labels``. Each list contains ``(node_id, attachment)``
          pairs. ``node_id`` is a string drawn from a single counter that
          starts at 10000 and advances by one per attachment across the
          whole tree.

        :param internal_labels: iterable of iterables of attachment labels.
        :return: the newly built ``ConstTree``.
        """
        copied = ConstTree(tag=self.tag, span=self.span, extra_info=dict(self.extra))
        copied.children = self.children
        copied.extra["InternalAttachments"] = internal_labels

        next_id = 10000  # first id handed out to a synthetic internal node
        nodes_per_position = []
        for attachments in copied.extra["InternalAttachments"]:
            position_nodes = [(str(next_id + offset), attachment)
                              for offset, attachment in enumerate(attachments)]
            next_id += len(position_nodes)
            nodes_per_position.append(position_nodes)

        copied.extra["InternalNodesPred"] = nodes_per_position
        return copied

    @classmethod
    def internal_evaluate(cls, gold_sents,
                          system_sents,
                          log_file: str = None,
                          is_external: bool = False
                          ):
        """Compute an F score of predicted vs. gold internal attachments.

        Gold and system trees are paired with ``zip``, so any surplus items in
        the longer sequence are ignored. Each pair's
        ``extra["InternalAttachments"]`` is fed to an ``FScoreCalculator``
        named ``"Internal Labels"``.

        :param gold_sents: trees carrying gold ``InternalAttachments``.
        :param system_sents: trees carrying predicted ``InternalAttachments``.
        :param log_file: when not None, the textual report is appended to it.
        :param is_external: not used by this implementation.
        :return: the F value returned by ``FScoreCalculator.get_report``.
        """
        scorer = FScoreCalculator("Internal Labels")
        for gold, system in zip(gold_sents, system_sents):
            scorer.update_sets(gold.extra["InternalAttachments"],
                               system.extra["InternalAttachments"])

        report, f_value = scorer.get_report()
        if log_file is not None:
            with open(log_file, "a") as log:
                log.write(report)
        return f_value


class InternalTaggerNetwork(Module):
    """Classify span-parser features into internal-bag labels.

    The span parser's network is called with ``return_span_features=True``.
    Its output is gathered along dimension 1 with ``inputs.internal_indices``
    and projected by an MLP onto ``len(statistics.internal_bags)`` classes.
    """

    def __init__(self, parser: SpanParser,
                 hparams: "InternalTaggerHparams",
                 statistics):
        """
        :param parser: loaded span parser whose ``network`` supplies features.
        :param hparams: hyper-parameters; ``stop_grad`` and ``dims_hidden`` are read.
        :param statistics: label statistics; ``internal_bags`` fixes the output size.
        """
        super(InternalTaggerNetwork, self).__init__()
        self.statistics: PredTagStatistics = statistics
        self.hparams = hparams

        # Assign first, then optionally re-wrap: keep this order, since
        # Module.__setattr__ treats re-assignment of a registered child specially.
        self.parser_network = parser.network
        if hparams.stop_grad:
            self.parser_network = NoPickle(self.parser_network)

        bag_count = len(statistics.internal_bags)
        self.internal_bag_count = bag_count
        self.internal_projection = create_mlp(
            self.parser_network.d_model,
            bag_count,
            hidden_dims=self.hparams.dims_hidden,
            activation=lambda: LeakyReLU(0.1),
            layer_norm=True,
            last_bias=True,
        )

    def get_span_features(self, batch_sents, inputs):
        """Run the parser network and gather span features at ``inputs.internal_indices``.

        When ``hparams.stop_grad`` is set, the parser runs under ``torch.no_grad()``.
        """
        grad_context = torch.no_grad() if self.hparams.stop_grad else NullContextManager()
        with grad_context:
            span_features = self.parser_network(batch_sents, inputs,
                                                return_span_features=True)
        return broadcast_gather(span_features, 1, inputs.internal_indices)

    def forward(self, batch_sentences, inputs):
        """Score internal bags.

        In training mode, returns ``AttrDict(loss, sent_count)``. ``loss`` is the
        mean cross-entropy against ``inputs.internal_bags``, and ``sent_count`` is
        ``inputs.words.shape[0]``. Otherwise, returns
        ``AttrDict(internal_pred=...)`` holding the argmax over the last dimension.
        """
        logits = self.internal_projection(
            self.get_span_features(batch_sentences, inputs))
        if not self.training:
            return AttrDict(internal_pred=logits.argmax(dim=-1))
        loss = cross_entropy_nd(logits, inputs.internal_bags, reduction="mean")
        return AttrDict(loss=loss, sent_count=inputs.words.shape[0])


@dataclass
class InternalTaggerHparams(TaggerHParams):
    """Hyper-parameters for InternalTagger, extending TaggerHParams."""
    # NOTE(review): not read anywhere in this file; the MLP in
    # InternalTaggerNetwork uses ``dims_hidden`` instead (presumably inherited
    # from TaggerHParams) — confirm whether this field is still used.
    tagger_mlp_dim: int = 100
    # When True, InternalTaggerNetwork wraps the parser network in NoPickle and
    # computes span features under torch.no_grad().
    stop_grad: bool = False
    # NOTE(review): not read anywhere in this file — confirm usage elsewhere.
    dim_postag: int = 0


class InternalTagger(SimpleParser):
    """Predict bags of internal attachment labels on top of a pretrained span parser.

    Features come from the span parser loaded from ``span_model``. The label
    statistics (including the ``internal_bags`` vocabulary) are taken from the
    ``PredTagger`` loaded from ``pred_tagger``. Predictions are written back
    through ``ConstTreeForInternalLabels.replaced_labels``.
    """
    available_data_formats = {"default": ConstTreeForInternalLabels}
    sentence_feature_class = SentenceFeaturesWithTags

    @dataclass
    class Options(Tagger.Options):
        span_model: str = argfield(predict_time=True)
        pred_tagger: str = argfield(predict_time=True)
        hparams: InternalTaggerHparams = InternalTaggerHparams.get_default()

    def __init__(self, args: "InternalTagger.Options", train_trees):
        """Load the span parser and PredTagger statistics, then build the network and optimizer.

        :param args: parsed ``InternalTagger.Options``.
        :param train_trees: training trees, forwarded to ``SimpleParser.__init__``.
        """
        super(InternalTagger, self).__init__(args, train_trees)

        self.args: InternalTagger.Options
        self.hparams: InternalTaggerHparams

        # NOTE(review): try_cache_keeper presumably memoizes the loaded model
        # keyed by the given path — confirm in coli.basic_tools.common_utils.
        @try_cache_keeper(args.span_model)
        def get_parser():
            parser = SpanParser.load(
                args.span_model,
                AttrDict(gpu=args.gpu,
                         bilm_path=args.hparams.pretrained_contextual.elmo_options.path,
                         ))
            parser.network.eval()
            return parser

        parser: SpanParser = get_parser()

        @try_cache_keeper(args.pred_tagger)
        def get_tagger():
            # Loaded with gpu=False; only ``statistics`` is used below.
            tagger = PredTagger.load(
                args.pred_tagger,
                AttrDict(gpu=False,
                         bilm_path=args.hparams.pretrained_contextual.elmo_options.path,
                         span_model=parser
                         ))
            tagger.network.eval()
            return tagger

        tagger: PredTagger = get_tagger()

        self.statistics: PredTagStatistics = tagger.statistics

        self.network = InternalTaggerNetwork(parser,
                                             self.hparams, self.statistics,
                                             )

        # Only parameters that require gradients are handed to the optimizer.
        self.trainable_parameters = [param for param in self.network.parameters()
                                     if param.requires_grad]

        self.optimizer, self.scheduler = self.get_optimizer_and_scheduler(self.trainable_parameters)

        if self.options.gpu:
            self.network.cuda()

    def split_batch_result(self, results):
        """Yield the per-sentence rows of ``results.internal_pred`` after moving it to CPU."""
        yield from results.internal_pred.cpu()

    def merge_answer(self, sent_feature, internal_pred):
        """Turn predicted bag ids into labels and attach them to the original tree.

        Only the first ``len(words)`` predictions are used. Each id is looked up
        in ``statistics.internal_bags.int_to_word``. A non-empty bag string is
        split on ``";"`` into a list of labels; an empty one becomes ``[]``.

        :return: the tree produced by ``replaced_labels``.
        """
        sent = sent_feature.original_obj
        bag_vocab = self.statistics.internal_bags.int_to_word

        new_internal_labels = []
        for label_id in internal_pred[:len(sent.words)]:
            # int() turns a 0-d tensor element into a plain Python index.
            internal_bags = bag_vocab[int(label_id)]
            new_internal_labels.append(internal_bags.split(";") if internal_bags else [])
        return sent.replaced_labels(new_internal_labels)

    def post_load(self, new_options):
        """Re-attach a span parser network after the model has been loaded.

        :param new_options: load-time options. ``span_model`` is either an
            already-loaded ``SpanParser`` or a path that is passed to
            ``SpanParser.load`` together with ``gpu`` and ``bilm_path``.
        :raises ValueError: if ``new_options.span_model`` is None.
        """
        super(InternalTagger, self).post_load(new_options)

        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if new_options.span_model is None:
            raise ValueError("span_model is required to load an InternalTagger")
        if isinstance(new_options.span_model, SpanParser):
            parser = new_options.span_model
        else:
            parser = NoPickle(SpanParser.load(
                new_options.span_model,
                AttrDict(gpu=new_options.gpu,
                         bilm_path=new_options.bilm_path
                         )))

        # NOTE(review): InternalTaggerNetwork.__init__ wraps the parser network in
        # NoPickle when hparams.stop_grad is set, but here the bare network is
        # assigned regardless of stop_grad — confirm this asymmetry is intended.
        self.network.parser_network = parser.network


if __name__ == '__main__':
    # Command-line entry point; ``main`` is not defined in this file
    # (presumably inherited from SimpleParser — confirm).
    InternalTagger.main()
