from functools import partial

import torch.cuda
from torch.utils.data import DataLoader
from tqdm import tqdm

from iter import ITER
from iter.datasets import CoNLL, sparse_collate_fn
from iter.datasets.util import move_to_device_and_de_sparsify_collate_fn
from iter.misc.config import with_argparse
from iter.misc.metrics import accumulate_metrics, metrics_ner, calculate_f1, metrics_ere, Metrics
from iter.misc.seeding import setup_seed


@with_argparse
def evaluate(
        model: str,
        dataset: str = None,
        split: str = "test",
        seed: int = 42,
):
    """Command-line entry point: load a pretrained ITER model and print its metrics on one split.

    Args:
        model: Identifier/path handed to ``ITER.from_pretrained``.
        dataset: Dataset name for ``CoNLL.from_name``; when empty/None, the name stored in
            the loaded model's config (``config.dataset``) is used instead.
        split: Dataset split to evaluate.
        seed: Seed passed to ``setup_seed`` before anything else runs.
    """
    setup_seed(seed)

    # Use the first GPU when one is visible, otherwise fall back to the CPU.
    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    loaded: ITER = ITER.from_pretrained(model).to(target_device)
    tok = loaded.tokenizer

    dataset_name = dataset if dataset else loaded.config.dataset
    conll = CoNLL.from_name(dataset_name, tokenizer=tok)
    conll.setup_dataset()

    results = evaluate0(loaded, conll, split)
    print(results)


def evaluate0(
        model: ITER,
        dataset: CoNLL,
        split: str,
        batch_size: int = 8,
) -> Metrics:
    """Score ``model`` on ``dataset[split]`` and return entity (NER) and relation (ERE) metrics.

    For each batch, predictions come from ``model.generate`` and the gold structures are
    obtained by decoding the batch's gold ``actions`` / pair flags with
    ``model.decode_actions_and_pairings``. The two are compared with ``metrics_ner`` and,
    unless ``model.is_feature_ner_only`` is set, with ``metrics_ere``.

    The model is put into eval mode for the evaluation. If it was in training mode when
    this function was called, it is switched back to training mode afterwards, including
    when an exception is raised.

    Args:
        model: Model to evaluate; batches are moved to ``model.device``.
        dataset: Dataset providing the split, the tokenizer and the entity/link types.
        split: Key of the split to evaluate, used as ``dataset[split]``.
        batch_size: Number of examples per evaluation batch (no shuffling).

    Returns:
        ``Metrics`` holding the micro-averaged precision/recall/F1 for entities and
        relations, their macro-averaged counterparts, and per-class scores for both.
    """
    device = model.device
    # Evaluation must run in eval mode (dropout etc. disabled). Remember the caller's mode
    # so that evaluating from inside a training loop does not leave the model in eval mode.
    was_training = model.training
    model.eval()

    def call_model_decode(bt: dict):
        # Decode the batch's *gold* actions / pair flags; the result is scored against
        # the output of ``model.generate`` below.
        return model.decode_actions_and_pairings(
            bt["input_ids"], bt["actions"], bt["lr_pair_flag"], bt.get("rr_pair_flag", None),
            entity_types=dataset.entity_types, link_types=dataset.link_types
        )

    try:
        collate_fn = partial(sparse_collate_fn, tokenizer=dataset.tokenizer)
        collate_fn = move_to_device_and_de_sparsify_collate_fn(collate_fn, device)

        ner_only = model.is_feature_ner_only
        # Class ids scored per batch; the trailing id is dropped when the extra lr/rr
        # class feature is enabled.
        entity_type_ids = list(range(model.num_types - int(model.is_feature_extra_lr_class)))
        link_type_ids = None if ner_only else list(
            range(model.num_links - int(model.is_feature_extra_rr_class))
        )

        # NOTE(review): the accumulators are seeded with `num_* - int(extra class flag)`
        # classes but re-accumulated in the loop with the unadjusted `dataset.num_*`. One of
        # the two counts is probably wrong when the extra lr/rr class is enabled; confirm
        # against accumulate_metrics before relying on the macro scores.
        metrics = accumulate_metrics([], "macro", dataset.num_types - int(model.is_feature_extra_lr_class))
        ere_metrics = accumulate_metrics([], "macro", dataset.num_links - int(model.is_feature_extra_rr_class))

        with torch.no_grad():
            dataloader = DataLoader(dataset[split], batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
            batch: dict
            for batch in tqdm(dataloader, leave=False, desc=f"Evaluating {dataset.name}"):
                # Only the decoded predictions are scored; raw actions/pairings are unused.
                _, _, decoded_pairings, decoded_links = model.generate(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    entity_types=dataset.entity_types,
                    link_types=dataset.link_types,
                )
                gold_decoded_pairings, gold_decoded_links = call_model_decode(batch)

                batch_metrics = metrics_ner(
                    decoded_pairings,
                    gold_decoded_pairings,
                    use_entity_tag=True,
                    average="macro",
                    batched=True,
                    entity_types=entity_type_ids,
                )
                metrics = accumulate_metrics([metrics, batch_metrics], "macro", dataset.num_types)
                if not ner_only:
                    ere_batch_metrics = metrics_ere(
                        decoded_links,
                        gold_decoded_links,
                        use_entity_tag=True,
                        average="macro",
                        batched=True,
                        link_types=link_type_ids,
                    )
                    ere_metrics = accumulate_metrics([ere_metrics, ere_batch_metrics], "macro", dataset.num_links)
    finally:
        # Restore the caller's training mode (eval mode is left as-is).
        if was_training:
            model.train()

    # Entity scores: per class, macro over classes, and micro over the pooled counts.
    per_class = {k: calculate_f1(v, average="micro") for k, v in metrics.items()}
    m_pr, m_rec, m_f1 = calculate_f1(list(metrics.values()), average="macro")
    pr, rec, f1 = calculate_f1(accumulate_metrics(list(metrics.values()), average="micro"), average="micro")

    # Relation scores, computed the same way.
    ere_per_class = {k: calculate_f1(v, average="micro") for k, v in ere_metrics.items()}
    ere_m_pr, ere_m_rec, ere_m_f1 = calculate_f1(list(ere_metrics.values()), average="macro")
    ere_pr, ere_rec, ere_f1 = calculate_f1(accumulate_metrics(list(ere_metrics.values()), average="micro"), average="micro")

    return Metrics(
        pr, rec, f1,
        ere_pr, ere_rec, ere_f1,
        m_pr, m_rec, m_f1,
        ere_m_pr, ere_m_rec, ere_m_f1,
        per_class,
        ere_per_class
    )


if __name__ == "__main__":
    # `evaluate` is called without arguments even though `model` is a required parameter,
    # so the @with_argparse decorator is expected to supply them (presumably parsed from
    # the command line).
    evaluate()
