import numpy as np
import torch
from typing import Optional, Tuple, List

from allennlp.modules.attention import BilinearAttention
from dataclasses import dataclass, field
from torch import nn, Tensor
from torch.nn import Module

from coli.basic_tools.common_utils import split_to_batches, cache_result
from coli.basic_tools.dataclass_argparse import DataClassArgParser
from coli.torch_extra.graph_embedding.dataset import Batch
from coli.torch_extra.graph_embedding.graph_encoder import GraphRNNEncoder
from coli.torch_extra.graph_embedding.reader import graph_embedding_types
from coli.hrgguru.hrg import CFGRule
from coli.hrg_parser.hrg_statistics import HRGStatistics


class GraphEmbeddingHRGScorer(Module):
    """Score candidate HRG rules for a span using a graph encoding of each rule.

    For every candidate rule, the rule's HRG (converted to a "simple graph")
    is encoded with ``GraphRNNEncoder``. Its node embeddings are pooled with
    bilinear attention keyed on the reduced span feature, passed through
    ``graph_reducer``, and scored against the reduced span feature by a
    bilinear layer.

    Scoring is deferred and batched across calls:

    - ``get_best_rule`` only registers its inputs and yields ``None``.
    - ``calculate_results`` then scores every registered (graph, span) pair
      at once, after which the pending ``get_best_rule`` generators can be
      resumed to read their scores.
    - ``refresh`` clears the pending inputs and results.
    """

    @dataclass
    class Options(object):
        hrg_mlp_dim: int = 100  # NOTE(review): not read anywhere in this class
        # Key into ``graph_embedding_types``; selects the class instantiated
        # as ``self.statistics`` in ``__init__``.
        graph_embedding_type: str = field(
            default="normal",
            metadata={"choices": graph_embedding_types})
        graph_encoder: GraphRNNEncoder.Options = field(default_factory=GraphRNNEncoder.Options)
        span_reduced_dim: int = 256  # output size of ``span_reducer``
        graph_reduced_dim: int = 128  # output size of ``graph_reducer``
        hrg_batch_size: int = 128  # rule inputs per ``split_to_batches`` chunk
        span_dropout: float = 0.33  # dropout applied before ``span_reducer``'s linear layer

    @classmethod
    def add_parser_arguments(cls, arg_parser):
        """Add an argument group for this scorer's ``Options``.

        The group is handed to ``DataClassArgParser`` under the name
        ``graph_embedding_scorer``, which is the attribute ``__init__`` later
        reads from ``new_options``.

        :type arg_parser: argparse.ArgumentParser
        """
        group = arg_parser.add_argument_group(cls.__name__)
        DataClassArgParser("graph_embedding_scorer", group,
                           choices={"default": cls.Options()})

    def __init__(self,
                 grammar,
                 hrg_statistics,  # type: HRGStatistics
                 options,
                 new_options
                 ):
        """Build the sub-networks and pre-convert the grammar's HRGs.

        :param grammar: mapping whose values are mappings keyed by
            ``CFGRule``. Every rule with a non-``None`` ``hrg`` is converted
            to a simple graph up front.
        :param hrg_statistics: unused by this class.
        :param options: only ``options.hparams.d_model`` (the span feature
            size) is read.
        :param new_options: ``graph_embedding_scorer`` (this scorer's
            ``Options``), ``gpu``, ``output`` and ``debug_cache`` are read.
        """
        super(GraphEmbeddingHRGScorer, self).__init__()
        self.options: GraphEmbeddingHRGScorer.Options = new_options.graph_embedding_scorer
        self.use_gpu = new_options.gpu

        # NOTE(review): ``cache_result`` presumably persists the result to the
        # given path (subject to ``debug_cache``); confirm in common_utils.
        @cache_result(
            new_options.output + "/simple_graphs.pkl", new_options.debug_cache)
        def get_simple_graphs_and_statistics():
            simple_graphs = {}
            statistics = graph_embedding_types[self.options.graph_embedding_type]()
            for rules_and_counts in grammar.values():
                for rule in rules_and_counts.keys():
                    if rule.hrg is not None:
                        simple_graphs[rule.hrg] = statistics.read_sync_rule_and_update(
                            rule, statistics)
            return simple_graphs, statistics

        self.simple_graphs, self.statistics = get_simple_graphs_and_statistics()

        span_dim = options.hparams.d_model

        # NOTE(review): this stores the class (not an instance) and is not used
        # by this class; kept in case external code reads the attribute.
        self.activation = nn.ReLU
        self.attention = BilinearAttention(self.options.span_reduced_dim,
                                           self.options.graph_encoder.model_hidden_size,
                                           activation=torch.tanh
                                           )

        self.graph_encoder = GraphRNNEncoder(self.options.graph_encoder, self.statistics)

        # The layer order inside each Sequential determines the state_dict keys,
        # so keep it stable for checkpoint compatibility.
        self.span_reducer = nn.Sequential(
            nn.Dropout(self.options.span_dropout),
            nn.Linear(span_dim, self.options.span_reduced_dim),
            nn.LayerNorm(self.options.span_reduced_dim),
            nn.LeakyReLU(0.1)
        )

        self.graph_reducer = nn.Sequential(
            nn.Linear(self.options.graph_encoder.model_hidden_size, self.options.graph_reduced_dim),
            nn.LayerNorm(self.options.graph_reduced_dim),
            nn.LeakyReLU(0.1)
        )

        self.bilinear_layer = nn.Bilinear(self.options.span_reduced_dim,
                                          self.options.graph_reduced_dim,
                                          1)

        # Pending inputs registered by ``get_best_rule`` and consumed by
        # ``calculate_results``.
        self.rule_inputs: List[Tuple[object, int]] = []  # (simple_graph, index into span_inputs)
        self.span_inputs: List[Tensor] = []
        self.results: Optional[Tensor] = None  # results[i] scores rule_inputs[i]

    def calculate_results(self):
        """Score every pending (rule graph, span) pair into ``self.results``.

        After this call ``self.results[i]`` is the score of
        ``self.rule_inputs[i]``. Inputs are sorted by their number of entities
        before batching (presumably so graphs in one batch need similar
        padding). The original order is restored at the end.
        """
        if not self.rule_inputs:
            # Nothing to score; torch.stack/torch.cat below reject empty lists.
            device = self.span_inputs[0].device if self.span_inputs else None
            self.results = torch.zeros(0, device=device)
            return

        span_inputs_reduced = self.span_reducer(torch.stack(self.span_inputs))

        sort_idx = np.argsort([len(graph["entities"]) for graph, _ in self.rule_inputs])
        unsort_idx = np.argsort(sort_idx)
        inputs_sorted = [self.rule_inputs[i] for i in sort_idx]

        span_feature_chunks = []
        graph_embedding_chunks = []
        for _, _, batch_pending_inputs in split_to_batches(
                inputs_sorted, self.options.hrg_batch_size):
            simple_graphs, span_features_idxs = zip(*batch_pending_inputs)
            span_idxs = torch.tensor(span_features_idxs, device=span_inputs_reduced.device)
            span_features_r = span_inputs_reduced[span_idxs]

            batch = Batch()
            batch.init(self.options.graph_encoder, simple_graphs)
            if self.use_gpu:
                batch = batch.cuda()
            graph_outputs = self.graph_encoder(batch)
            node_embeddings = graph_outputs[0][-1]
            node_mask = batch.entities_mask.squeeze(-1)
            # Attention-weighted sum of node embeddings, keyed on the span feature.
            attention_weights = self.attention(span_features_r, node_embeddings, node_mask)
            graph_embedding_batch = torch.bmm(
                attention_weights.unsqueeze(1), node_embeddings).squeeze(1)

            span_feature_chunks.append(span_features_r)
            graph_embedding_chunks.append(graph_embedding_batch)

        span_features_all_r = torch.cat(span_feature_chunks, dim=0)
        graph_embedding_all_r = self.graph_reducer(torch.cat(graph_embedding_chunks, dim=0))
        dense_out = self.bilinear_layer(span_features_all_r, graph_embedding_all_r)

        # squeeze(-1) rather than squeeze(): with a single input, squeeze()
        # would yield a 0-d tensor that cannot be indexed by unsort_idx.
        self.results = dense_out.squeeze(-1)[
            torch.as_tensor(unsort_idx, device=dense_out.device)]

    def refresh(self):
        """Drop all pending inputs and results before the next round of scoring."""
        self.rule_inputs.clear()
        self.span_inputs.clear()
        self.results = None

    def get_best_rule(self,
                      span_feature,
                      rules_and_counts,  # type: List[Tuple[CFGRule, int]]
                      gold=None,
                      ):
        """Two-phase generator that picks the highest-scoring candidate rule.

        Phase 1 (first ``next()``) registers ``span_feature`` and one graph
        input per candidate rule, then yields ``None``. ``calculate_results``
        must run before the generator is resumed.

        Phase 2 (second ``next()``) yields ``(best_rule, loss, best_rule)``.
        ``loss`` is the cross-entropy of the candidate scores against the
        position of ``gold``, or ``None`` when ``gold`` is ``None``.

        When given, ``gold`` must compare equal to one of the candidate rules;
        if several match, the last match is the target.
        NOTE(review): an absent gold rule currently fails inside
        ``torch.tensor([None])``.
        """
        span_idx = len(self.span_inputs)
        self.span_inputs.append(span_feature)

        idx_to_result_idx = []
        gold_idx = None
        for idx, (rule, _count) in enumerate(rules_and_counts):
            if rule == gold:
                gold_idx = idx
            result_idx = len(self.rule_inputs)
            simple_graph = self.simple_graphs.get(rule.hrg)
            if simple_graph is None:
                # Rule not seen in the grammar: convert it on the fly, and
                # cache the conversion only when a gold rule is supplied.
                simple_graph = self.statistics.read_sync_rule_and_update(rule, self.statistics)
                if gold is not None:
                    self.simple_graphs[rule.hrg] = simple_graph
            self.rule_inputs.append((simple_graph, span_idx))
            idx_to_result_idx.append(result_idx)

        yield None

        scores = self.results[idx_to_result_idx]
        best_idx = int(torch.argmax(scores))
        best_rule = rules_and_counts[best_idx][0]

        if gold is not None:
            loss = torch.nn.functional.cross_entropy(
                scores.unsqueeze(0),
                torch.tensor([gold_idx], device=self.results.device))
            yield best_rule, loss, best_rule
        else:
            yield best_rule, None, best_rule
