import dataclasses
import typing

import torch

_entity_type = tuple[
    list,  # tokens
    int,  # type id
    str,  # text
    str,  # type str
]
_link_type = tuple[
    _entity_type,  # entity
    int,           # link id
    _entity_type,  # entity
]


@dataclasses.dataclass
class Metrics:
    """Precision, recall and F1 for entity recognition (NER) and relation extraction (ERE).

    The twelve required fields are positional, in this order: the plain NER
    scores, the plain ERE scores, the macro NER scores, the macro ERE scores.
    The two optional ``*_per_class`` fields default to ``None`` and are not
    included in ``repr()`` or ``to_dict()``.
    """

    # Scores without the ``macro_`` prefix.
    ner_pr: float
    ner_rec: float
    ner_f1: float
    ere_pr: float
    ere_rec: float
    ere_f1: float

    # Macro-averaged scores.
    macro_ner_pr: float
    macro_ner_rec: float
    macro_ner_f1: float
    macro_ere_pr: float
    macro_ere_rec: float
    macro_ere_f1: float

    # Optional per-class breakdowns; the annotation suggests class id -> (pr, rec, f1).
    ner_per_class: "dict[int, tuple[float, float, float]]" = None
    ere_per_class: "dict[int, tuple[float, float, float]]" = None

    def __repr__(self):
        # Only the non-macro scores are shown, with six decimals each.
        ner_summary = f"NER :: pr={self.ner_pr:.6f} rec={self.ner_rec:.6f} f1={self.ner_f1:.6f}"
        ere_summary = f"ERE :: pr={self.ere_pr:.6f} rec={self.ere_rec:.6f} f1={self.ere_f1:.6f}"
        return " ".join((ner_summary, ere_summary))

    def to_dict(self):
        """Flatten the scores into a dict; unprefixed keys are ERE, ``ner_`` keys are NER."""
        exported_fields = (
            ("f1", "ere_f1"),
            ("pr", "ere_pr"),
            ("rec", "ere_rec"),
            ("ner_f1", "ner_f1"),
            ("ner_pr", "ner_pr"),
            ("ner_rec", "ner_rec"),
            ("macro_f1", "macro_ere_f1"),
            ("macro_pr", "macro_ere_pr"),
            ("macro_rec", "macro_ere_rec"),
            ("macro_ner_f1", "macro_ner_f1"),
            ("macro_ner_pr", "macro_ner_pr"),
            ("macro_ner_rec", "macro_ner_rec"),
        )
        return {key: getattr(self, attr) for key, attr in exported_fields}


def calculate_f1(metrics: dict | list = None, average: str = "micro"):
    if average == "micro":
        return calculate_f1_micro(metrics)
    elif average == "macro":
        outputs = torch.tensor([calculate_f1_micro(metric) for metric in metrics]).mean(dim=0)  # torch.float64
        if outputs.dim() == 0:
            return 0., 0., 0.
        return outputs[0].item(), outputs[1].item(), outputs[2].item()
    else:
        raise ValueError(average)


def calculate_f1_micro(metrics: dict = None):
    """Compute ``(precision, recall, f1)`` from one ``{"tp", "fn", "fp"}`` count dict.

    Args:
        metrics: Count dict with keys ``"tp"``, ``"fn"`` and ``"fp"``. Missing
            keys, an empty dict, or ``None`` are treated as zero counts.

    Returns:
        ``(precision, recall, f1)`` as Python floats. Any ratio whose
        denominator is zero is reported as ``0.0`` instead of raising.
    """
    counts = metrics or {}
    # float() accepts ints as well as 0-d tensors; plain Python floats keep the
    # result in double precision (torch scalar division defaulted to float32).
    tp = float(counts.get("tp", 0))
    fn = float(counts.get("fn", 0))
    fp = float(counts.get("fp", 0))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1


def metrics_ere(
        links: list[_link_type] | list[list[_link_type]],
        gold_links: list[_link_type] | list[list[_link_type]],
        use_entity_tag: bool = False,
        average: str = "micro",
        batched: bool = False,
        link_types: list[str | int] = None,
):
    """Count true positives, false negatives and false positives for predicted links.

    Links are compared as exact tuples via set operations, so duplicates
    within one list count once.

    Args:
        links: Predicted links; with ``batched=True`` a list of per-example link lists.
        gold_links: Reference links, same layout as ``links``.
        use_entity_tag: When False, the type id and type string of both entities
            of every link are masked before matching, so only tokens, text and
            link id have to agree.
        average: ``"micro"`` returns one ``{"tp", "fn", "fp"}`` dict; ``"macro"``
            returns ``{link_type: {"tp", "fn", "fp"}}`` with one entry per
            element of ``link_types`` (links are selected by position 1).
        batched: Whether ``links`` / ``gold_links`` are lists of per-example lists;
            per-example counts are summed.
        link_types: Link ids to score separately; required when ``average="macro"``.

    Returns:
        A count dict (micro) or a mapping from link type to count dict (macro).
    """
    if average == "macro":
        assert link_types is not None
        if batched:
            per_example = [
                metrics_ere(
                    example_links, example_gold,
                    use_entity_tag=use_entity_tag,
                    batched=False,
                    average="macro",
                    link_types=link_types,
                )
                for example_links, example_gold in zip(links, gold_links)
            ]
            return {
                link_type: accumulate_metrics(
                    [counts.get(link_type, {}) for counts in per_example],
                    average="micro",
                )
                for link_type in link_types
            }
        return {
            link_type: metrics_ere(
                filtered_links(links, link_type),
                filtered_links(gold_links, link_type),
                use_entity_tag=use_entity_tag,
                batched=False,
                average="micro",
                link_types=link_types,
            )
            for link_type in link_types
        }

    if batched:
        # Micro counts are plain sums, so the class count is only forwarded when
        # known; link_types defaults to None and len(None) would raise TypeError.
        num_classes = len(link_types) if link_types is not None else None
        return accumulate_metrics(
            [
                metrics_ere(
                    example_links, example_gold,
                    use_entity_tag=use_entity_tag,
                    batched=False,
                    average=average,
                )
                for example_links, example_gold in zip(links, gold_links)
            ],
            average,
            num_classes=num_classes,
        )

    if not use_entity_tag:
        links = remove_entity_tag_from_links(links)
        gold_links = remove_entity_tag_from_links(gold_links)

    predicted, gold = set(links), set(gold_links)
    return {
        "tp": len(predicted & gold),
        "fn": len(gold - predicted),
        "fp": len(predicted - gold),
    }


def metrics_ner(
        entities: list[_entity_type] | list[list[_entity_type]],
        gold_entities: list[_entity_type] | list[list[_entity_type]],
        use_entity_tag: bool = True,
        average: str = "micro",
        batched: bool = False,
        entity_types: list[str | int] = None
):
    """Count true positives, false negatives and false positives for predicted entities.

    Entities are compared as exact tuples via set operations, so duplicates
    within one list count once.

    Args:
        entities: Predicted entities; with ``batched=True`` a list of per-example lists.
        gold_entities: Reference entities, same layout as ``entities``.
        use_entity_tag: When False, the type id and type string of every entity
            are masked before matching, so only tokens and text have to agree.
            Honoured in every mode (micro/macro, batched or not).
        average: ``"micro"`` returns one ``{"tp", "fn", "fp"}`` dict; ``"macro"``
            returns ``{entity_type: {"tp", "fn", "fp"}}`` with one entry per
            element of ``entity_types`` (entities are selected by position 1,
            the type id).
        batched: Whether the inputs are lists of per-example lists; per-example
            counts are summed.
        entity_types: Types to score separately; required when ``average="macro"``.

    Returns:
        A count dict (micro) or a mapping from entity type to count dict (macro).
    """
    if average == "macro":
        assert entity_types is not None
        if batched:
            per_example = [
                metrics_ner(
                    example_entities, example_gold,
                    use_entity_tag=use_entity_tag,
                    batched=False,
                    average="macro",
                    entity_types=entity_types,
                )
                for example_entities, example_gold in zip(entities, gold_entities)
            ]
            return {
                entity_type: accumulate_metrics(
                    [counts.get(entity_type, {}) for counts in per_example],
                    average="micro",
                )
                for entity_type in entity_types
            }
        return {
            entity_type: metrics_ner(
                filtered_entities(entities, entity_type),
                filtered_entities(gold_entities, entity_type),
                use_entity_tag=use_entity_tag,
                batched=False,
                average="micro",
            )
            for entity_type in entity_types
        }

    if batched:
        # Micro counts are plain sums, so the class count is only forwarded when
        # known; entity_types defaults to None and len(None) would raise TypeError.
        num_classes = len(entity_types) if entity_types is not None else None
        return accumulate_metrics(
            [
                metrics_ner(
                    example_entities, example_gold,
                    use_entity_tag=use_entity_tag,
                    batched=False,
                )
                for example_entities, example_gold in zip(entities, gold_entities)
            ],
            average,
            num_classes=num_classes,
        )

    if not use_entity_tag:
        entities = remove_entity_tag_from_entities(entities)
        gold_entities = remove_entity_tag_from_entities(gold_entities)

    predicted, gold = set(entities), set(gold_entities)
    return {
        "tp": len(predicted & gold),
        "fn": len(gold - predicted),
        "fp": len(predicted - gold)
    }


def accumulate_metrics(
        metrics: list[dict] | list[dict[int, dict]],
        average: str,
        num_classes: int = None
) -> dict:
    if len(metrics) == 0:
        if average == "macro":
            return {cls: {"tp": 0, "fp": 0, "fn": 0} for cls in range(num_classes)}
        return {"tp": 0, "fp": 0, "fn": 0}
    if average == "macro":
        return {a: accumulate_metrics([item[a] for item in metrics], "micro") for a in metrics[0].keys()}
    output = {k: 0 for k in metrics[0].keys()}
    for metric_dict in metrics:
        for k, v in metric_dict.items():
            output[k] += v
    return output


def filtered_entities(entities: list[_entity_type], entity_type: int):
    """Return the entities whose type id (tuple position 1) equals ``entity_type``.

    Input order is preserved and duplicates are kept.
    """
    def has_requested_type(entity):
        return entity[1] == entity_type

    return list(filter(has_requested_type, entities))


def filtered_links(links: list[_link_type], link_type: int):
    """Return the links whose link id (tuple position 1) equals ``link_type``.

    Input order is preserved and duplicates are kept.
    """
    def has_requested_type(link):
        return link[1] == link_type

    return list(filter(has_requested_type, links))


def remove_entity_tag_from_links(links: list[_link_type]) -> list[_link_type]:
    """Return new 3-tuples with both entities of each link passed through
    ``remove_entity_tag_from_entity`` (type id -> -1, type string -> "<ignored>").

    The link id at position 1 is copied unchanged.
    """
    def strip_link(link):
        head = remove_entity_tag_from_entity(link[0])
        tail = remove_entity_tag_from_entity(link[2])
        return head, link[1], tail

    return list(map(strip_link, links))


def remove_entity_tag_from_entities(entities: list[_entity_type]):
    """Apply ``remove_entity_tag_from_entity`` to every entity, returning a new list."""
    return list(map(remove_entity_tag_from_entity, entities))


def remove_entity_tag_from_entity(entity: _entity_type) -> _entity_type:
    """Mask the type information of ``entity`` so that matching ignores it.

    Keeps the tokens (position 0) and text (position 2); the type id becomes
    ``-1`` and the type string becomes ``"<ignored>"``.
    """
    tokens, text = entity[0], entity[2]
    return tokens, -1, text, "<ignored>"
