import sys
import time
import traceback
import weakref
from itertools import chain, product, zip_longest
from typing import Optional, Iterable, Union, List

import attr
import numpy as np
import torch
import torch.nn.functional as F
from allennlp.modules.attention import BilinearAttention
from dataclasses import dataclass, field
from torch import Tensor, nn

from coli.basic_tools.common_utils import cache_result, add_slots, split_to_batches
from coli.basic_tools.dataclass_argparse import argfield, OptionsBase
from coli.basic_tools.logger import logger
from coli.data_utils.dataset import TensorflowHParamsBase
from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.sub_graph import SubGraph
from coli.span.const_tree import ConstTree, Lexicon
from coli.torch_extra.autobatch import AutoBatchModule
from coli.torch_extra.graph_embedding.dataset import Batch
from coli.torch_extra.graph_embedding.graph_encoder import GraphRNNEncoder
from coli.torch_extra.graph_embedding.reader import graph_embedding_types
from coli.torch_extra.utils import pad_and_stack_1d
from coli.torch_hrg.hrg_parser_base import UdefQParserBase


@add_slots
@dataclass
class BeamItem(object):
    """One candidate (sync rule plus chosen child candidates) in a node's beam.

    Items are ordered by ``score`` (see ``__lt__``) and compared for equality
    by identity only (see ``__eq__``).
    """
    # weak reference to the owning traversal item (or None for the shared
    # placeholder item)
    node_info_ref: weakref
    # rule selected for this node by this candidate
    sync_rule: CFGRule
    # index/mask applied to the per-node outputs of the network to pick the
    # entries stored in `external_hidden` / `external_embeddings`
    external_mask: List[bool]
    # position of this candidate's outputs inside the network's result lists
    result_idx: int
    # candidates chosen for the left / right child (None for leaf candidates)
    left: Optional["BeamItem"]
    right: Optional["BeamItem"]
    # whether this candidate agrees with the gold derivation
    is_gold: bool

    # filled in after the network has produced its results
    external_embeddings: Tensor = None
    external_hidden: Tensor = None
    score: Optional[Tensor] = None

    def __lt__(self, other):
        """Order candidates by their score."""
        own_score, other_score = self.score, other.score
        return own_score < other_score

    def __eq__(self, other):
        """Two beam items are equal only when they are the same object."""
        return other is self

    def make_graph(self, tree_node, lexicon_to_lemma=None):
        """Recursively build the ``SubGraph`` described by this candidate.

        A candidate without children produces a leaf graph from ``tree_node``
        and its first child's ``string``. Otherwise the graphs of the children
        are built first (a child that is missing or carries no sync rule
        contributes ``None``) and merged under this candidate's rule.
        """
        left_item, right_item = self.left, self.right

        if left_item is None and right_item is None:
            return SubGraph.create_leaf_graph(
                tree_node, self.sync_rule,
                word=tree_node.children[0].string,
                lexicon_to_lemma=lexicon_to_lemma)

        def build_child(child_item, position):
            # Children without a rule have no graph of their own.
            if child_item is None or child_item.sync_rule is None:
                return None
            return child_item.make_graph(tree_node.children[position], lexicon_to_lemma)

        left_graph = build_child(left_item, 0)
        right_graph = build_child(right_item, 1)
        return SubGraph.merge(tree_node, self.sync_rule, left_graph, right_graph)


@attr.s(slots=True)
class TraversalSequenceItem(object):
    """Per-tree-node bookkeeping record used during the tree beam search.

    One item is created for every CFG tree node; it carries the candidate
    rules, the resulting beam and links to the items of the child nodes.
    """
    # Declared explicitly on this slotted class; instances are passed to
    # `weakref.ref(...)` when beam items are created.
    __weakref__ = attr.ib(init=False, hash=False, repr=False, cmp=False)
    # the CFG tree node this item describes
    cfg_node = attr.ib(type=ConstTree)
    # gold rule for this node (None when no derivation is supplied)
    gold_rule = attr.ib(default=None)
    # candidates kept for this node after scoring
    beam = attr.ib(type=Iterable[BeamItem], default=None)
    # candidate rules for this node; left as None for nodes that are skipped
    # by the rule lookup (tags ending in "#0" / "#None")
    correspondents = attr.ib(type=List[CFGRule], default=None)
    # traversal items of the first / second child, when those are tree nodes
    left = attr.ib(default=None)
    right = attr.ib(default=None)
    # the beam item that matches the gold derivation, if any
    gold_item = attr.ib(default=None)
    # NOTE(review): not read or written elsewhere in this file
    early_updated = attr.ib(default=False)
    # computed in __attrs_post_init__, not an __init__ argument
    is_preterminal = attr.ib(type=bool, init=False)

    def __attrs_post_init__(self):
        # A node is pre-terminal when its first child is a Lexicon.
        self.is_preterminal = isinstance(self.cfg_node.children[0], Lexicon)


@dataclass
class PrintLogger(object):
    """Accumulates accuracy/loss counters between two progress reports.

    ``total_count`` starts at machine epsilon instead of zero so that the
    accuracy ratio in :meth:`print` never divides by zero.
    """
    total_count: float = sys.float_info.epsilon
    correct_count: int = 0
    total_loss: float = 0
    idx: int = 0
    # Taken when the instance is created. A plain ``start_time = time.time()``
    # would be evaluated once at import time and shared by every instance,
    # which makes the first reported speed of each new logger wrong.
    start_time: float = field(default_factory=time.time)

    def print(self, idx):
        """Log the statistics gathered since the last report and reset them.

        :param idx: zero-based index of the most recently processed sentence.
        """
        end_time = time.time()
        # Guard against a clock that did not advance between two reports.
        elapsed = max(end_time - self.start_time, sys.float_info.epsilon)
        logger.info("Sent {}, Correctness: {:.2f}, "
                    "loss: {:.2f}, speed: {:.2f}".format(
            idx + 1,
            self.correct_count / self.total_count * 100,
            self.total_loss,
            (idx - self.idx) / elapsed
        ))
        # start a fresh measurement window
        self.start_time = end_time
        self.idx = idx
        self.total_count = sys.float_info.epsilon
        self.correct_count = 0
        self.total_loss = 0


# Shared placeholder candidate used as the only beam entry of tree nodes that
# carry no rule; it has no owner, rule, mask or children and counts as gold.
empty_beam_item = BeamItem(
    node_info_ref=None,
    sync_rule=None,
    external_mask=None,
    result_idx=0.0,
    left=None,
    right=None,
    is_gold=True,
)


class SubGraphEmbeddingNetwork(AutoBatchModule):
    @dataclass
    class HParams(OptionsBase):
        # Hyper-parameters of the sub-graph scoring network.

        # NOTE(review): not read anywhere in this file
        hrg_mlp_dim: int = 100
        # key into `graph_embedding_types`, selecting the statistics/reader
        # class used to convert rules into graphs
        graph_embedding_type: str = field(
            default="normal",
            metadata={"choices": graph_embedding_types})
        # options forwarded to GraphRNNEncoder
        graph_encoder: GraphRNNEncoder.Options = field(default_factory=GraphRNNEncoder.Options)
        # output size of the span feature reducer
        span_reduced_dim: int = 256
        # output size of the graph embedding reducer
        graph_reduced_dim: int = 128
        # number of graphs encoded per chunk inside forward()
        hrg_batch_size: int = 128
        # dropout probability applied to span features before reduction
        span_dropout: float = 0.33

    def __init__(self,
                 grammar,
                 d_model,
                 new_options
                 ):
        """Build the graph reader statistics and all sub-modules.

        :param grammar: mapping whose values map rules to counts; every rule
            with a non-None ``hrg`` is converted into a graph up front.
        :param d_model: size of the incoming span feature vectors.
        :param new_options: parser options; ``hparams.graph_embedding``,
            ``gpu``, ``output`` and ``debug_cache`` are read from it.
        """
        super().__init__()
        hparams: SubGraphEmbeddingNetwork.HParams = new_options.hparams.graph_embedding
        self.options = hparams
        self.use_gpu = new_options.gpu

        @cache_result(
            new_options.output + "/simple_graphs.pkl", new_options.debug_cache)
        def get_simple_graphs_and_statistics():
            # Convert every rule that owns an HRG into its graph form while
            # the statistics object records what it has seen.
            statistics = graph_embedding_types[hparams.graph_embedding_type]()
            graphs_by_hrg = {}
            for rules_and_counts in grammar.values():
                for rule in rules_and_counts.keys():
                    if rule.hrg is None:
                        continue
                    graphs_by_hrg[rule.hrg] = statistics.read_sync_rule_and_update(rule, True)
            return graphs_by_hrg, statistics

        self.simple_graphs, self.statistics = get_simple_graphs_and_statistics()

        encoder_dim = hparams.graph_encoder.model_hidden_size

        def build_reducer(input_dim, output_dim, dropout=None):
            # (optional dropout) -> linear -> layer norm -> leaky ReLU
            layers = []
            if dropout is not None:
                layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(input_dim, output_dim))
            layers.append(nn.LayerNorm(output_dim))
            layers.append(nn.LeakyReLU(0.1))
            return nn.Sequential(*layers)

        # Sub-modules are created in a fixed order so that parameter
        # registration and random initialisation stay reproducible.
        self.activation = nn.ReLU
        self.attention = BilinearAttention(hparams.span_reduced_dim,
                                           encoder_dim,
                                           activation=torch.tanh)
        self.graph_encoder = GraphRNNEncoder(hparams.graph_encoder, self.statistics)
        self.span_reducer = build_reducer(d_model, hparams.span_reduced_dim,
                                          dropout=hparams.span_dropout)
        self.graph_reducer = build_reducer(encoder_dim, hparams.graph_reduced_dim)
        self.bilinear_layer = nn.Bilinear(hparams.span_reduced_dim,
                                          hparams.graph_reduced_dim,
                                          1)

    def forward(self, span_features_list: List[Tensor], batch_graphs: list,
                initial_states_list: List[Tensor],
                entity_embeddings_extra: List[Tensor]
                ):
        """Score every (span, graph) pair of the queued inputs.

        The four arguments are parallel lists, one entry per queued input.

        :return: a triple in the original input order: a list with one
            ``graph_outputs[0][-1]`` row per input, a list with one
            ``graph_outputs[1]`` row per input, and a tensor holding one
            bilinear score per input.
        """
        chunk_size = self.options.hrg_batch_size

        # The same span tensor object is usually queued many times; reduce
        # each distinct object once and expand the result afterwards.
        distinct_spans = []
        slot_of_tensor = {}
        expand_indices = []
        for span_tensor in span_features_list:
            key = id(span_tensor)
            if key not in slot_of_tensor:
                slot_of_tensor[key] = len(distinct_spans)
                distinct_spans.append(span_tensor)
            expand_indices.append(slot_of_tensor[key])

        reduced_distinct = self.span_reducer(torch.stack(distinct_spans))
        reduced_spans = reduced_distinct[expand_indices]

        # Process graphs in order of their entity count; `inverse_order`
        # restores the caller's ordering at the end.
        order = np.argsort([len(graph["entities"]) for graph in batch_graphs])
        inverse_order = np.argsort(order)
        graphs_by_size = [batch_graphs[i] for i in order]
        spans_by_size = reduced_spans[order]
        states_by_size = [initial_states_list[i] for i in order]
        extras_by_size = [entity_embeddings_extra[i] for i in order]

        span_chunks = []
        pooled_chunks = []
        final_node_states = []
        initial_node_embeddings = []
        for start_idx, _, chunk_graphs in split_to_batches(graphs_by_size, chunk_size):
            window = slice(start_idx, start_idx + chunk_size)
            span_chunk = spans_by_size[window]

            batch = Batch()
            batch.init(self.options.graph_encoder, chunk_graphs, device=span_features_list[0].device)
            padded_length = batch.entity_labels.size(1)
            chunk_states = pad_and_stack_1d(states_by_size[window],
                                            padded_length,
                                            device=states_by_size[0].device)
            chunk_extras = pad_and_stack_1d(extras_by_size[window],
                                            padded_length,
                                            device=extras_by_size[0].device)

            encoder_outputs = self.graph_encoder(batch, chunk_states, chunk_extras)
            nodes_final = encoder_outputs[0][-1]
            nodes_initial = encoder_outputs[1]

            # Attention of the span over the graph's nodes, then a weighted
            # sum of the node states gives one vector per graph.
            mask = batch.entities_mask.squeeze(-1)
            weights = self.attention(span_chunk, nodes_final, mask)
            pooled = torch.bmm(weights.unsqueeze(1), nodes_final).squeeze(1)

            span_chunks.append(span_chunk)
            pooled_chunks.append(pooled)
            final_node_states.extend(
                [nodes_final[row] for row in range(nodes_final.shape[0])])
            initial_node_embeddings.extend(
                [nodes_initial[row] for row in range(nodes_initial.shape[0])])

        all_spans = torch.cat(span_chunks, dim=0)
        all_pooled = torch.cat(pooled_chunks, dim=0)
        scores = self.bilinear_layer(all_spans, self.graph_reducer(all_pooled))

        return (
            [final_node_states[i] for i in inverse_order],
            [initial_node_embeddings[i] for i in inverse_order],
            scores.squeeze(-1)[inverse_order],
        )


class GraphEmbeddingUdefQParser(UdefQParserBase):
    # Expose the module-level PrintLogger dataclass as a class attribute of the
    # parser. NOTE(review): presumably consumed by UdefQParserBase -- confirm.
    PrintLogger = PrintLogger

    @dataclass
    class HParams(UdefQParserBase.HParams):
        # Hyper-parameters of the scoring network. A default_factory is used
        # so that every HParams instance owns its own (mutable) sub-options;
        # a shared instance default is also rejected by dataclasses on
        # Python >= 3.11 because the object is unhashable.
        graph_embedding: SubGraphEmbeddingNetwork.HParams = field(
            default_factory=SubGraphEmbeddingNetwork.HParams)
        # "hinge" selects the margin-free hinge loss in training_session;
        # any other value selects cross entropy over the candidates.
        loss_type: str = "hinge"

    @dataclass
    class Options(UdefQParserBase.Options):
        # Options of this parser: the base options with `hparams` replaced by
        # this parser's own HParams. The lambda defers the lookup of
        # GraphEmbeddingUdefQParser until an Options instance is created,
        # because the outer class does not exist yet while this body runs.
        hparams: "GraphEmbeddingUdefQParser.HParams" = argfield(
            default_factory=lambda: GraphEmbeddingUdefQParser.HParams())

    def create_network(self):
        """Create the sub-graph scoring network and its Adam optimizer.

        The network is built from the grammar, the output size of the span
        parser's contextual unit and the parser options, and is moved to the
        GPU when ``options.gpu`` is set. With ``hparams.stop_grad`` only the
        scoring network is optimised; otherwise the trainable parameters of
        the span parser's network are optimised as well.
        """
        # delay network creation
        grammar = self.grammar
        span_dim = self.parser.network.contextual_unit.output_dim
        self.network = SubGraphEmbeddingNetwork(grammar, span_dim, self.options)

        if self.options.gpu:
            self.network.cuda()

        if self.hparams.stop_grad:
            params_to_optimize = list(self.network.parameters())
        else:
            joint_parameters = chain(self.parser.network.parameters(),
                                     self.network.parameters())
            params_to_optimize = [p for p in joint_parameters if p.requires_grad]

        self.optimizer = torch.optim.Adam(params_to_optimize)

    # No first hook for this parser.
    # NOTE(review): when and how the hooks are invoked is defined outside this
    # file (presumably by UdefQParserBase) -- confirm.
    hook_1 = None

    def hook_2(self):
        """Ask the scoring network to compute results for its queued inputs.

        NOTE(review): presumably called while training sessions are suspended
        at a ``yield``, so that ``self.network.results`` is populated when
        they resume -- confirm against the base class.
        """
        self.network.calculate_results()

    def hook_3(self):
        """Call ``refresh()`` on the scoring network.

        NOTE(review): presumably clears the queued inputs/results between
        rounds -- ``refresh`` is defined outside this file, confirm.
        """
        self.network.refresh()

    def training_session(self, tree, span_features,
                         lexical_labels, attachment_bags, internal_bags,
                         print_logger, derivations=()):
        """Run the tree beam search for one sentence as a generator.

        The method queues candidate graphs on ``self.network`` with
        ``add_input`` and then yields ``None``; when it is resumed it reads
        the scores from ``self.network.results``. The driver is therefore
        expected to compute the network results between a ``yield`` and the
        following resumption (presumably via ``hook_2`` -- confirm).

        :param tree: constituency tree; ``tree.extra`` and
            ``tree.generate_rules()`` are used.
        :param span_features: per-node span features, indexed by the position
            of the node in ``tree.generate_rules()``.
        :param lexical_labels: per-word values passed to ``rule_lookup``
            (only read when ``self.tagger`` is not None).
        :param attachment_bags: per-word values passed to ``rule_lookup``
            (only read when ``self.tagger`` is not None).
        :param internal_bags: per-internal-node values passed to
            ``rule_lookup``, or None.
        :param print_logger: object whose ``total_count`` / ``correct_count``
            counters are updated in training mode.
        :param derivations: gold rules, one per tree node; a non-empty value
            enables training mode.
        :return: (as the generator's return value) the accumulated loss in
            training mode, otherwise the best ``BeamItem`` of the last node in
            the traversal. If ``rule_lookup`` raises ``ValueError`` in
            training mode the traceback is printed, ``None`` is yielded once
            and ``0.0`` is returned; outside training the error is re-raised.
        """
        device = next(self.network.parameters()).device
        # Gold derivations are only supplied during training.
        is_train = bool(derivations)

        if "DelphinSpan" not in tree.extra:
            self.populate_delphin_spans(tree, args_and_names=False)

        # each cfg tree node is assigned a beam and a list of sync rule correspondents
        traversal_sequence = []
        cfg_node_to_item = {}
        # NOTE(review): span_to_expr is filled below but never read in this method.
        span_to_expr = {}

        for idx, node in enumerate(tree.generate_rules()):
            span_to_expr[node.span] = span_features[idx]

        tree_nodes = list(tree.generate_rules())
        total_loss = 0.0

        # generate expressions
        if derivations:
            assert len(derivations) == len(tree_nodes)

        # Running positions into the per-word and per-internal-node inputs.
        word_idx = 0
        internal_idx = 0
        # Without derivations zip_longest pads gold_rule with None for every node.
        for node_idx, (gold_rule, tree_node) in enumerate(zip_longest(derivations, tree_nodes)):
            if self.tagger is None:
                lexical_labels_i, attachment_bags_i, internal_bags_i = [], [], []
            else:
                if isinstance(tree_node.children[0], Lexicon):
                    # pre-terminal node: consume one entry of the per-word inputs
                    lexical_labels_i = lexical_labels[word_idx]
                    attachment_bags_i = attachment_bags[word_idx]
                    internal_bags_i = []
                    word_idx += 1
                else:
                    # internal node: consume one entry of the internal bags, if given
                    lexical_labels_i = []
                    attachment_bags_i = []
                    if internal_bags is not None:
                        internal_bags_i = internal_bags[internal_idx]
                    else:
                        internal_bags_i = []
                    internal_idx += 1

            # Nodes whose tag ends in "#0" / "#None" get an item without
            # correspondents and skip the rule lookup entirely.
            if tree_node.tag.endswith("#0") or tree_node.tag.endswith("#None"):
                node_info = TraversalSequenceItem(tree_node)
                cfg_node_to_item[tree_node] = node_info
                traversal_sequence.append(node_info)
                continue
            try:
                rules_dict = self.rule_lookup(tree_node, is_train,
                                              lexical_labels_i, attachment_bags_i,
                                              internal_bags_i)
                correspondents = list(rules_dict)
                # make sure the gold rule is always among the candidates
                if is_train and gold_rule not in correspondents:
                    correspondents.append(gold_rule)
            except ValueError as e:
                traceback.print_exc()
                if is_train:
                    # give control back once, then finish this sentence without loss
                    yield None
                    # NOTE(review): `derivations` defaults to () and is truthy
                    # here, so this always evaluates to 0.0.
                    return 0.0 if derivations is not None else None
                else:
                    raise
            node_info = TraversalSequenceItem(tree_node, gold_rule, correspondents=correspondents)
            traversal_sequence.append(node_info)
            cfg_node_to_item[tree_node] = node_info
            # Link to the items of the children; the lookups rely on children
            # having been visited before their parent.
            if isinstance(tree_node.children[0], ConstTree):
                node_info.left = cfg_node_to_item[tree_node.children[0]]
            if len(tree_node.children) == 2 and isinstance(tree_node.children[1], ConstTree):
                node_info.right = cfg_node_to_item[tree_node.children[1]]

        # do tree beam search
        for node_idx, node_info in enumerate(traversal_sequence):
            if node_info.correspondents is None:
                # deal with semantic null
                node_info.beam = [empty_beam_item]
                node_info.gold_item = empty_beam_item if derivations else None
                continue

            sync_rules = node_info.correspondents
            if node_info.is_preterminal:
                # deal with pre-terminal nodes
                items = []

                for sync_rule in sync_rules:
                    is_gold = sync_rule == node_info.gold_rule
                    # use the precomputed graph when available, otherwise read the rule now
                    graph = self.network.simple_graphs.get(sync_rule.hrg)
                    if graph is None:
                        graph = self.network.statistics.read_sync_rule_and_update(sync_rule, is_train)
                    # leaves start from all-zero node states and extra embeddings
                    hidden_states = torch.zeros((len(graph["entities"]), self.network.graph_encoder.hidden_size),
                                                device=device)
                    extra_node_embeddings = torch.zeros((len(graph["entities"]),
                                                         self.network.graph_encoder.node_embedding_dim),
                                                        device=device)
                    # queue the candidate; result_idx locates its outputs later
                    result_idx = self.network.add_input(span_features[node_idx], graph, hidden_states,
                                                        extra_node_embeddings)
                    beam_item = BeamItem(weakref.ref(node_info), sync_rule,
                                         graph["external_indices"], result_idx, None, None,
                                         is_gold)
                    if is_gold:
                        node_info.gold_item = beam_item
                    items.append(beam_item)

                # calculate scores in batch
                yield None

                # get scores
                for item in items:
                    nodes_hidden = self.network.results[0][item.result_idx]
                    item.external_hidden = nodes_hidden[item.external_mask]

                    nodes_embeddings = self.network.results[1][item.result_idx]
                    item.external_embeddings = nodes_embeddings[item.external_mask]

                    item.score = self.network.results[2][item.result_idx]  # - (0.5 if item.is_gold else 0.0)

                # best-scoring candidate first (BeamItem orders by score)
                items.sort(reverse=True)
                node_info.beam = items

                if is_train:
                    # rank of the gold candidate within the sorted beam
                    gold_index = -1
                    for index, item in enumerate(node_info.beam):
                        if item.is_gold:
                            gold_index = index
                            break
                    # accuracy is only counted when there was a real choice
                    if len(items) > 1:
                        print_logger.total_count += 1
                        if gold_index == 0:
                            print_logger.correct_count += 1
                    assert gold_index >= 0

                    # if self.hparams.greedy_at_leaf or gold_index >= self.options.beam_size:
                    # NOTE(review): the condition above is disabled, so the
                    # update below runs for every node during training.
                    if True:
                        # early update
                        if self.hparams.loss_type == "hinge":
                            total_loss += node_info.beam[0].score - node_info.gold_item.score
                        else:
                            # noinspection PyCallingNonCallable
                            total_loss += F.cross_entropy(
                                torch.stack([i.score for i in node_info.beam]).unsqueeze(0),
                                torch.tensor([gold_index], device=device))
                        # continue the search from the gold candidate only
                        node_info.beam = [node_info.gold_item]
                node_info.beam = node_info.beam[:self.options.beam_size]
                # gold_index is only bound in training mode; the short-circuit
                # on is_train keeps it from being evaluated otherwise.
                # noinspection PyUnboundLocalVariable
                if is_train and gold_index >= self.options.beam_size:
                    node_info.beam[-1] = node_info.gold_item
            else:  # aka: node is not preterminal
                items = []
                # deal with non-leaf nodes
                for sync_rule in sync_rules:
                    # NOTE(review): assumes both node_info.left and
                    # node_info.right are set (i.e. two tree-node children).
                    for left_item, right_item in product(node_info.left.beam,
                                                         node_info.right.beam):
                        # gold only if the rule and both chosen children are gold
                        is_gold = left_item.is_gold and right_item.is_gold and sync_rule == node_info.gold_rule
                        left_item: BeamItem
                        right_item: BeamItem

                        # if not is_train:
                        #     edge_tuples = [(edge.nodes[0], edge.label)
                        #                    for edge in sub_graph.graph.edges
                        #                    if len(edge.nodes) == 2 and edge.span is None]
                        #     # check consistency
                        #     if len(edge_tuples) != len(set(edge_tuples)):
                        #         pending_score -= 10
                        #
                        #     # sometimes there are nodes without pred edge
                        #     pred_nodes = set(edge.nodes[0]
                        #                      for edge in sub_graph.graph.edges
                        #                      if len(edge.nodes) == 1 and edge.span is not None
                        #                      )
                        #     if sub_graph.graph.nodes - set(sub_graph.external_nodes) - pred_nodes:
                        #         pending_score -= 100
                        #
                        #     # sometimes there are multiple pred edges in a node
                        #     pred_nodes_count = defaultdict(int)
                        #     for edge in sub_graph.graph.edges:
                        #         if len(edge.nodes) == 1 and edge.span is not None:
                        #             pred_nodes_count[edge.nodes[0]] += 1
                        #     if max(pred_nodes_count.values()) > 1:
                        #         pending_score -= 100

                        graph = self.network.simple_graphs.get(sync_rule.hrg)
                        if graph is None:
                            graph = self.network.statistics.read_sync_rule_and_update(sync_rule, is_train)
                        hidden_states = torch.zeros((len(graph["entities"]), self.network.graph_encoder.hidden_size),
                                                    device=device)
                        extra_node_embeddings = torch.zeros((len(graph["entities"]),
                                                             self.network.graph_encoder.node_embedding_dim),
                                                            device=device)
                        # Add the children's external states/embeddings to the
                        # rows given by graph.left_nodes / graph.right_nodes.
                        # The shared placeholder item has no external state
                        # and is skipped.
                        if left_item.external_hidden is not None:
                            # noinspection PyCallingNonCallable
                            hidden_states.index_add_(
                                0, graph.left_nodes.to(device),
                                left_item.external_hidden)
                            extra_node_embeddings.index_add_(
                                0, graph.left_nodes.to(device),
                                left_item.external_embeddings)
                        if right_item.external_hidden is not None:
                            # noinspection PyCallingNonCallable
                            hidden_states.index_add_(
                                0, graph.right_nodes.to(device),
                                right_item.external_hidden)
                            extra_node_embeddings.index_add_(
                                0, graph.right_nodes.to(device),
                                right_item.external_embeddings)
                        result_idx = self.network.add_input(span_features[node_idx], graph, hidden_states,
                                                            extra_node_embeddings)
                        beam_item = BeamItem(weakref.ref(node_info), sync_rule,
                                             graph["external_indices"], result_idx, left_item, right_item,
                                             is_gold)
                        if is_gold:
                            node_info.gold_item = beam_item
                        items.append(beam_item)

                # calculate scores in batch
                yield None

                # get scores
                for item in items:
                    nodes_hidden = self.network.results[0][item.result_idx]
                    item.external_hidden = nodes_hidden[item.external_mask]

                    nodes_embeddings = self.network.results[1][item.result_idx]
                    item.external_embeddings = nodes_embeddings[item.external_mask]

                    item.score = self.network.results[2][item.result_idx]  # - (0.5 if item.is_gold else 0.0)
                    # if item.left.score is not None:
                    #     item.score += item.left.score.detach()
                    # if item.right.score is not None:
                    #     item.score += item.right.score.detach()

                # best-scoring candidate first (BeamItem orders by score)
                items.sort(reverse=True)
                node_info.beam = items

                if is_train:
                    # rank of the gold candidate within the sorted beam
                    gold_index = -1
                    for index, item in enumerate(node_info.beam):
                        if item.is_gold:
                            gold_index = index
                            break
                    if len(items) > 1:
                        print_logger.total_count += 1
                        if gold_index == 0:
                            print_logger.correct_count += 1

                    # if node_idx == len(traversal_sequence) - 1 or gold_index >= self.options.beam_size:
                    # NOTE(review): the condition above is disabled, so the
                    # update below runs for every node during training.
                    if True:
                        # early update
                        assert gold_index >= 0
                        if self.hparams.loss_type == "hinge":
                            total_loss += node_info.beam[0].score - node_info.gold_item.score
                        else:
                            # noinspection PyCallingNonCallable
                            total_loss += F.cross_entropy(
                                torch.stack([i.score for i in node_info.beam]).unsqueeze(0),
                                torch.tensor([gold_index], device=device))
                        # continue the search from the gold candidate only
                        node_info.beam = [node_info.gold_item]
                node_info.beam = node_info.beam[:self.options.beam_size]
                # gold_index is only bound in training mode; the short-circuit
                # on is_train keeps it from being evaluated otherwise.
                # noinspection PyUnboundLocalVariable
                if is_train and gold_index >= self.options.beam_size:
                    node_info.beam[-1] = node_info.gold_item

        if not is_train:
            # decoding: hand back the best candidate of the last traversed node
            final_beam_item = traversal_sequence[-1].beam[0]
            return final_beam_item

        return total_loss


# Script entry point. `main` is not defined in this file; presumably it is
# inherited from UdefQParserBase -- confirm.
if __name__ == '__main__':
    GraphEmbeddingUdefQParser.main()
