import torch
from typing import Optional, Iterable, Tuple

from dataclasses import dataclass
from torch import nn, Tensor
from torch.nn import Module, Parameter, LayerNorm

from coli.basic_tools.common_utils import ensure_dir
from coli.basic_tools.dataclass_argparse import OptionsBase, argfield
from coli.hrgguru.hrg import CFGRule
from coli.hrgguru.hyper_graph import strip_category
from coli.basic_tools.logger import logger
from coli.hrg_parser.hrg_statistics import HRGStatistics, encode_nonterminal


class StructuredPeceptronHRGScorer(Module):
    """Score candidate HRG rules for a span with a small feed-forward network.

    Every candidate rule is converted to a count vector by
    ``extract_features``, concatenated with the span representation and queued
    in ``pending_inputs``.  ``calculate_results`` scores all queued inputs in
    one batch, and the generators returned by ``get_rules`` /
    ``get_best_rule`` then read their scores back from ``results``.
    """

    def extract_features(self,
                         rule,  # type: Optional[CFGRule]
                         count
                         ):
        """Build the hand-crafted feature tuple for one rule.

        The tuple has one slot per entry of ``self.possible_features`` plus a
        trailing slot that holds ``count`` (forced to 1 when ``use_count`` is
        off).

        :param rule: rule to describe; its ``hrg`` and ``rhs`` attributes are
            read.  NOTE(review): the type comment says Optional, but ``None``
            is not handled here -- confirm callers never pass it.
        :param count: count of the rule, stored in the last slot.
        :return: tuple of numbers, usable as a dict key.
        """
        if not self.use_count:
            count = 1
        result = [0] * (len(self.possible_features) + 1)
        result[-1] = count
        if rule.hrg is not None:
            if len(rule.rhs) == 2:
                left_info, right_info = rule.rhs
                if len(left_info) == 2:
                    _, left_edge = left_info
                    if left_edge is not None:
                        # The left child carries exactly the nodes of the LHS.
                        if left_edge.nodes == rule.hrg.lhs.nodes:
                            result[self.feature_index["head_left"]] = 1
                        # elif len(left_edge.nodes) == 2:
                        #     if left_edge.nodes[0] == rule.hrg.lhs.nodes[0]:
                        #         result[self.feature_index["head_left_1/2"]] = 1
                        #     elif len(rule.hrg.lhs.nodes) >= 2 and left_edge.nodes[1] == rule.hrg.lhs.nodes[1]:
                        #         result[self.feature_index["head_left_2/2"]] = 1
                if len(right_info) == 2:
                    _, right_edge = right_info
                    if right_edge is not None:
                        # The right child carries exactly the nodes of the LHS.
                        if right_edge.nodes == rule.hrg.lhs.nodes:
                            result[self.feature_index["head_right"]] = 1
                        # elif len(rule.hrg.lhs.nodes) >= 2 and len(right_edge.nodes) == 2:
                        #     if right_edge.nodes[0] == rule.hrg.lhs.nodes[0]:
                        #         result[self.feature_index["head_right_1/2"]] = 1
                        #     elif right_edge.nodes[1] == rule.hrg.lhs.nodes[1]:
                        #         result[self.feature_index["head_right_2/2"]] = 1
            for edge in rule.hrg.rhs.edges:
                if edge.is_terminal and len(edge.nodes) == 2:
                    label = edge.label
                elif edge.is_terminal and len(edge.nodes) == 1:
                    label = strip_category(edge.label)
                elif not edge.is_terminal:
                    label = encode_nonterminal(edge)
                else:
                    label = "INVALID"
                # ("Edge", label) is a key of feature_index exactly when label
                # is one of self.edge_labels, so one dict lookup replaces the
                # linear scan over the label list.
                slot = self.feature_index.get(("Edge", label))
                if slot is not None:
                    result[slot] += 1
        return tuple(result)

    @dataclass
    class Options(OptionsBase):
        # Width of the hidden layer of the scoring MLP.
        hrg_mlp_dim: int = 100
        # Margin used by the hinge loss and by margin-augmented scoring.
        hrg_loss_margin: float = 0.1
        loss_type: str = argfield("hinge", choices=["hinge", "crossentropy"])
        # When set, drawings of conflicting rule pairs are written here.
        conflict_output_dir: Optional[str] = None
        # When False, every rule count is replaced by 1 in the features.
        use_count: bool = False

    def __init__(self,
                 hrg_mlp_dim,
                 hrg_loss_margin,
                 loss_type,
                 conflict_output_dir,
                 use_count,
                 grammar,
                 statistics,  # type: HRGStatistics
                 contextual_dim,
                 ):
        """Set up the feature inventory and the scoring network.

        :param hrg_mlp_dim: hidden size of the scoring MLP.
        :param hrg_loss_margin: margin of the hinge loss.
        :param loss_type: ``"hinge"`` or ``"crossentropy"``.
        :param conflict_output_dir: directory for conflict drawings; created
            if truthy.
        :param use_count: whether the rule count is kept as a feature value.
        :param grammar: accepted but not used by this constructor.
        :param statistics: source of the edge-label inventory
            (``nonterminals``, ``structural_edges`` and ``categories``).
        :param contextual_dim: size of the span feature vector that is
            concatenated with the rule features.
        """
        super(StructuredPeceptronHRGScorer, self).__init__()
        self.hrg_loss_margin = hrg_loss_margin
        self.loss_type = loss_type
        self.conflict_output_dir = conflict_output_dir
        self.use_count = use_count
        # Stored as a class; it is instantiated when the MLP is built below.
        self.activation = nn.ReLU

        # The 300 most common nonterminals, then all structural edges and
        # categories.
        self.edge_labels = \
            [word for word, _ in statistics.nonterminals.most_common(300)] + \
            list(statistics.structural_edges) + \
            list(statistics.categories)

        self.possible_features = [("Edge", k) for k in self.edge_labels]
        logger.info("Consider {} features as graph embedding".format(
            len(self.possible_features)))
        self.possible_features.append("head_left")
        self.possible_features.append("head_right")
        # self.possible_features.append("head_left_1/2")
        # self.possible_features.append("head_left_2/2")
        # self.possible_features.append("head_right_1/2")
        # self.possible_features.append("head_right_2/2")
        self.feature_index = {i: idx for idx, i in enumerate(self.possible_features)}

        # Input is the span feature followed by the rule features; the "+ 1"
        # is the trailing count slot produced by extract_features.
        self.dense_layer = nn.Sequential(
            nn.Linear(contextual_dim + len(self.possible_features) + 1,
                      hrg_mlp_dim
                      ),
            LayerNorm(hrg_mlp_dim),
            self.activation(),
            nn.Linear(hrg_mlp_dim, 1, bias=False)
        )

        # Not referenced elsewhere in this class; kept because they are
        # registered parameters and therefore part of the module's state.
        self.count_scale = Parameter(torch.randn([1]))
        self.count_scale_2 = Parameter(torch.randn([1]))

        if conflict_output_dir:
            ensure_dir(conflict_output_dir)

        # features and counts
        self.pending_inputs = []
        self.results: Optional[Tensor] = None
        # key=(span_feature, rule_feature, rule_count)

    def calculate_results(self):
        """Score every queued input in one batch and store it in ``results``.

        Does nothing (``results`` is left untouched) when nothing is queued.
        """
        if not self.pending_inputs:
            return

        stack_features = torch.stack(self.pending_inputs)
        match_scores = self.dense_layer(stack_features).view(-1)
        self.results = match_scores

    def refresh(self):
        """Drop all queued inputs and previously computed scores."""
        self.pending_inputs.clear()
        self.results = []

    def get_rules(self,
                  span_feature,
                  rules_and_counts,  # type: Iterable[Tuple[CFGRule, int]]
                  gold=None,
                  ):
        """Same protocol as ``get_best_rule`` with ``get_loss=False``.

        Yields ``None`` first, then a ``rule -> score`` mapping.
        """
        yield from self.get_best_rule(span_feature, rules_and_counts, gold, False)

    def get_best_rule(self,
                      span_feature,
                      rules_and_counts,  # type: Iterable[Tuple[CFGRule, int]]
                      gold=None,
                      get_loss=True
                      ):
        """Generator that scores the candidates of one span in two steps.

        Protocol:

        1. The first ``yield`` is always ``None``.  By then every distinct
           feature vector has been queued in ``pending_inputs``; the caller is
           expected to run ``calculate_results`` before resuming.
        2. The second ``yield`` is the answer:

           * ``get_loss=True``: ``(best_rule, loss, real_best_rule)``.
             ``loss`` is ``None`` without ``gold`` or when all candidates
             share one feature vector, ``0.0`` for a hinge loss whose best
             feature vector equals the gold one, and a tensor otherwise.
           * ``get_loss=False``: a dict mapping each rule to its score
             (``0.0`` for every rule when all candidates share one feature
             vector).

        With ``conflict_output_dir`` set, drawings of a conflicting pair are
        written only if the generator is resumed once more after step 2,
        because that code sits after the final ``yield``.

        :param span_feature: 1-D tensor for the span; rule features are
            concatenated to it.
        :param rules_and_counts: ``(rule, count)`` pairs of the candidates.
        :param gold: gold rule, or ``None`` at prediction time.  Must be one
            of the candidates when given.
        :param get_loss: selects the shape of the second yielded value.
        """
        # Materialize once: the pairs are walked several times below, which
        # would silently produce nothing on later passes for a one-shot
        # iterator.
        rules_and_counts = list(rules_and_counts)

        feature_map = {rule: self.extract_features(rule, count)
                       for rule, count in rules_and_counts}  # rule -> feature
        features = frozenset(feature_map.values())
        if len(features) == 1:
            # The network cannot tell the candidates apart, so fall back to
            # the most frequent rule and queue nothing.
            best_rule = sorted(((count, rule) for rule, count in rules_and_counts),
                               key=lambda x: x[0],
                               reverse=True)[0][1]
            yield None
            if get_loss:
                yield best_rule, None, best_rule
            else:
                yield {rule: 0.0 for rule, count in rules_and_counts}
            return

        def calculate_score(feature):
            # Queue one network input and return its row in the future batch.
            feature_input = torch.cat([
                    span_feature,
                    torch.tensor(feature, dtype=torch.float, device=span_feature.device)])
            result_idx = len(self.pending_inputs)
            self.pending_inputs.append(feature_input)

            return result_idx

        features_to_result_idx = {feature: calculate_score(feature)
                                  for feature in features}  # feature -> score

        # waiting for other batch
        yield None

        if gold is not None:
            gold_feature = feature_map[gold]
        else:
            gold_feature = None

        def get_score_count_rule(rule, count, detach=True):
            feature = feature_map[rule]
            result_idx = features_to_result_idx[feature]
            if detach:
                score = self.results[result_idx].detach().cpu().numpy()
            else:
                score = self.results[result_idx]
            if gold is not None and feature == gold_feature and self.loss_type == "hinge":
                # Out-of-place on purpose: ``score`` is a view of (or, via
                # NumPy on CPU, shares memory with) ``self.results``.  An
                # in-place ``-=`` would alter the shared score itself, once
                # for every rule that has the gold feature vector.
                score = score - self.hrg_loss_margin
            return score, count, rule

        if not get_loss:
            score_count_rule = [get_score_count_rule(rule, count, False)
                                for rule, count in rules_and_counts]  # tuple (score, count, rule)
            yield {rule: score for score, count, rule in score_count_rule}
        else:
            score_count_rule = [get_score_count_rule(rule, count)
                                for rule, count in rules_and_counts]  # tuple (score, count, rule)
            # choose the highest score, if not unique, choose the most frequent
            best_score, best_count, best_rule = sorted(score_count_rule,
                                                       key=lambda x: (x[0], x[1]),
                                                       reverse=True)[0]

            if gold is not None:
                best_feature = feature_map[best_rule]
                best_result_idx = features_to_result_idx[best_feature]
                best_expr = self.results[best_result_idx]
                assert gold in [i[0] for i in rules_and_counts]
                gold_result_idx = features_to_result_idx[gold_feature]
                gold_expr = self.results[gold_result_idx]

                if self.loss_type == "hinge":
                    if best_feature == gold_feature:
                        loss = 0.0
                    else:
                        loss = best_expr - gold_expr + self.hrg_loss_margin
                    # Undo the margin applied to gold-feature rules to find
                    # the rule the un-augmented scores prefer.
                    _, _, real_best_rule = \
                        sorted(
                            ((score if feature_map[rule] != gold_feature else score + self.hrg_loss_margin, count, rule)
                             for score, count, rule in score_count_rule),
                            key=lambda x: (x[0], x[1]),
                            reverse=True)[0]
                else:
                    assert self.loss_type == "crossentropy"
                    real_best_rule = best_rule
                    # Class index of gold = its position among the distinct
                    # feature vectors of this span.
                    result_idxs = list(features_to_result_idx.values())
                    gold_result_idx_idx = result_idxs.index(gold_result_idx)
                    loss = torch.nn.functional.cross_entropy(
                                    self.results[result_idxs].unsqueeze(0),
                                    torch.tensor([gold_result_idx_idx], device=self.results.device))
                yield best_rule, loss, real_best_rule

                # output conflict
                if self.conflict_output_dir:
                    # A "conflict": a different rule than gold wins although
                    # both have the same feature vector.
                    if real_best_rule != gold and feature_map[real_best_rule] == gold_feature:
                        base_name = "{}_{}".format(hash(real_best_rule), hash(gold))
                        real_best_rule.hrg.draw(self.conflict_output_dir + "/" + base_name + "_real",
                                                draw_format="png")
                        gold.hrg.draw(self.conflict_output_dir + "/" + base_name + "_gold",
                                      draw_format="png")
                        # with open(self.options.conflict_output, "a") as f:
                        #     print("Conflict:\n real: {}\n gold: {}".format(
                        #         real_best_rule, gold), file=f)
            else:
                yield best_rule, None, best_rule
