import json
from collections import defaultdict
from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
from cycler import cycler
from matplotlib import pyplot as plt

DATA_STRATEGIES = ["random", "uncertainty"]
ANNOTATOR_STRATEGIES = [
    "active",
    "lmin",
    "semdiv",
    # "agreement",
    "representation",
    "oracle",
]

# Validation metrics that get plotted, each paired with whether a higher value
# is better. Keeping name and direction in one tuple prevents RESULT_COLUMNS and
# GRB (which are zipped together downstream) from drifting out of alignment.
_METRIC_DIRECTIONS = [
    ("eval_loss", False),
    ("eval_jsdiv", False),
    ("eval_accuracy", True),
    ("eval_macro avg", True),  # F1
    ("eval_mean_f1_macro_per_annotator", True),
    ("eval_worst_f1_macro_per_annotator", True),
    ("eval_worst_jsdiv_per_annotator", False),
    ("eval_mean_jsdiv_per_annotator", False),
    ("eval_worst_jsdiv_per_item", False),
]
RESULT_COLUMNS = [name for name, _ in _METRIC_DIRECTIONS]
# GRB = "greater is better", aligned index-by-index with RESULT_COLUMNS.
GRB = [greater_is_better for _, greater_is_better in _METRIC_DIRECTIONS]
# Counters used as the x axis of the validation plots.
XS_COLUMNS = ["total_unique_annotations_seen"]


# Maps "<data_strategy>_<annotator_strategy>" to the LaTeX label used in plots
# and printed tables. Raw strings keep the LaTeX readable (no f-string brace
# doubling; none of these labels interpolate anything).
STRAT2LABEL = {
    "random_active": r"$\mathcal{S}_R$$\mathcal{T}_R$",
    "random_lmin": r"$\mathcal{S}_R$$\mathcal{T}_L$",
    "random_semdiv": r"$\mathcal{S}_R$$\mathcal{T}_S$",
    "random_agreement": r"$\mathcal{S}_R$$\mathcal{T}_A$",
    "random_representation": r"$\mathcal{S}_R$$\mathcal{T}_D$",
    "random_oracle": r"$\mathcal{S}_R$$\mathcal{O}$",
    "uncertainty_active": r"$\mathcal{S}_U$$\mathcal{T}_R$",
    "uncertainty_lmin": r"$\mathcal{S}_U$$\mathcal{T}_L$",
    "uncertainty_semdiv": r"$\mathcal{S}_U$$\mathcal{T}_S$",
    "uncertainty_agreement": r"$\mathcal{S}_U$$\mathcal{T}_A$",
    "uncertainty_representation": r"$\mathcal{S}_U$$\mathcal{T}_D$",
    "uncertainty_oracle": r"$\mathcal{S}_U$$\mathcal{O}$",
    "passive_passive": "passive",
}


def read_data(files):
    """
    Load JSONL training logs into one DataFrame per file.

    Only lines holding a JSON object with both a ``"counters"`` and a
    ``"results"`` entry become rows. A row is the merge of ``counters``,
    ``results`` and, if present, ``best_model``; later dicts win on key
    collisions. Blank lines are ignored. Lines that are not valid JSON are
    reported and skipped.

    Args:
        files: Iterable of paths to JSONL files.

    Returns:
        dict mapping each path (as given) to a ``pd.DataFrame`` of its rows.
    """
    file_data = {}
    for file in files:
        rows = []
        with open(file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Error reading {file}")
                    # Skip the bad line; falling through would reuse the
                    # previous line's `data` and duplicate a row.
                    continue
                if not isinstance(data, dict):
                    continue
                if "counters" not in data or "results" not in data:
                    continue
                best_model = data.get("best_model") or {}
                rows.append({**data["counters"], **data["results"], **best_model})
        file_data[file] = pd.DataFrame(rows)
    return file_data


def get_best_score(df, metric, is_passive):
    """
    Return ``(test_score, x_value, round)`` for the test set.

    The test column is derived from ``metric`` by replacing ``"eval_"`` with
    ``"test_"``; exactly one non-missing value is expected in it. For the
    ``"test_macro avg"`` column, dict cells are reduced to their
    ``"f1-score"`` entry.

    For passive learning the x value is the last ``total_unique_annotations_seen``
    and the round is ``-1``. Otherwise both come from the last row of the
    ``best_model_*`` columns.
    """
    test_column = metric.replace("eval_", "test_")
    column = df[test_column]
    if test_column == "test_macro avg":
        column = column.apply(
            lambda cell: cell["f1-score"] if isinstance(cell, dict) else cell
        )
    present = column.dropna()

    assert len(present) == 1
    score = present.item()

    if not is_passive:
        best_round = df["best_model_metric_round"].iloc[-1]
        # "unqiue" is misspelled in the key as written by the training logs.
        best_x = df["best_model_metric_total_unqiue_samples"].iloc[-1]
        return score, best_x, best_round

    # Passive runs have no selection rounds.
    return score, df["total_unique_annotations_seen"].iloc[-1], -1


def dump_plot_data(dt, plot_dict, metric, gib):
    """
    Write the plotted curves of every approach to a tab-separated ``.dat`` file.

    Each approach contributes three columns: ``<approach>_x``,
    ``<approach>_y`` and ``<approach>_label`` (its ``STRAT2LABEL`` entry).
    Columns of different lengths are padded by ``pd.concat``.

    Args:
        dt: Dataset/task name; prefix of the output filename.
        plot_dict: Mapping approach -> dict with ``"x"`` and ``"y"`` lists.
        metric: Metric name; suffix of the output filename.
        gib: "Greater is better". For passive approaches the y values are
            replaced by their max (if ``gib``) or min, repeated for every x.
    """
    out_dir = Path("output/images/data")
    frames = []
    for approach, curve in plot_dict.items():
        xs = curve["x"]
        ys = curve["y"]
        if "passive" in approach:
            # Passive learning is drawn as one constant reference value.
            reference = max(ys) if gib else min(ys)
            ys = [reference] * len(xs)
        columns = {
            f"{approach}_x": xs,
            f"{approach}_y": ys,
            f"{approach}_label": STRAT2LABEL[approach],
        }
        frames.append(pd.DataFrame.from_dict(columns))

    table = pd.concat(frames, axis=1)
    if not out_dir.exists():
        out_dir.mkdir(parents=True, exist_ok=True)
    table.to_csv(f"output/images/data/{dt}_{metric}.dat", sep="\t", index=False)


def get_metric_values(df, metric):
    """
    Return the ``metric`` column of ``df`` as a Series.

    For ``"eval_macro avg"``, cells that are dicts are reduced to their
    ``"f1-score"`` entry. Other cells, such as missing values, pass through
    unchanged.

    Raises:
        ValueError: if ``metric`` is not a column of ``df``.
    """
    if metric not in df:
        raise ValueError(f"Metric {metric} not found in dataframe")

    column = df[metric]
    if metric != "eval_macro avg":
        return column

    def _extract_f1(cell):
        if isinstance(cell, dict):
            return cell["f1-score"]
        return cell

    return column.apply(_extract_f1)


def get_passive_plot_data(xs, ys, greater_is_better):
    """
    Build a horizontal reference line for passive learning.

    The line's y value is the best of ``ys``: the max if ``greater_is_better``,
    otherwise the min. Missing values (NaN) are ignored. The x values are
    ``len(xs)`` evenly spaced points from 0 to the last element of ``xs``.

    Args:
        xs: Sequence or Series of x values (only its length and last value are used).
        ys: Sequence or Series of metric values; may contain NaN.
        greater_is_better: Direction used to pick the best value.

    Returns:
        ``(x_array, y_array)`` as NumPy arrays of length ``len(xs)``. If every
        y is NaN the line is NaN. If ``xs`` is empty, two empty arrays.
    """
    x_values = np.asarray(xs)
    n_points = len(x_values)
    if n_points == 0:
        return np.array([]), np.array([])

    y_values = np.asarray(ys, dtype=float)
    # Builtin max/min return NaN if NaN happens to come first; use the
    # NaN-aware reductions, guarding the all-NaN case to avoid a warning.
    if np.isnan(y_values).all():
        best = np.nan
    elif greater_is_better:
        best = np.nanmax(y_values)
    else:
        best = np.nanmin(y_values)

    return np.linspace(0, x_values[-1], n_points), np.full(n_points, best)


def _make_prop_cycle():
    """Return the 20-entry linestyle/marker/color cycle used for validation curves."""
    return (
        cycler(linestyle=["-", "--", "-.", ":"] * 5)
        + cycler(
            marker=[
                "o",
                "s",
                "x",
                "D",
                "v",
                "^",
                "<",
                ">",
                "1",
                "2",
                "3",
                "4",
                "8",
                "p",
                "P",
                "*",
                "h",
                "H",
                "+",
                "X",
            ]
        )
        + cycler(color=["b", "g", "c", "m"] * 5)
    )


def _aggregate_seeds(file_dataframes, xs_counter, ys_counter):
    """
    Average ``ys_counter`` over seeds, truncated to the shortest run.

    Seeds missing either counter are reported and skipped.

    Returns:
        ``(xs, mean_ys, last_df)`` where ``xs`` comes from the first usable
        seed and ``last_df`` is the last usable seed's DataFrame. Returns
        ``None`` if no seed had both counters.
    """
    agg_xs = []
    agg_ys = []
    last_df = None
    for filename, data in file_dataframes.items():
        try:
            xs = get_metric_values(data, xs_counter)
            ys = get_metric_values(data, ys_counter)
        except ValueError as e:
            print(f"{filename}: {e}")
            continue
        agg_xs.append(xs)
        agg_ys.append(ys)
        last_df = data

    if not agg_ys:
        return None

    shared_len = min(len(x) for x in agg_xs)
    # Assumes every seed logs the same x values; the first seed's are used.
    xs = agg_xs[0].iloc[:shared_len]
    mean_ys = pd.DataFrame([y.iloc[:shared_len] for y in agg_ys]).mean(axis=0)
    return xs, mean_ys, last_df


def plot_validation_during_training(
    dataset_task, glob_data, xs_counter, ys_counter, greater_is_better=True
):
    """
    Plot a validation metric against a data-usage counter for every strategy.

    Args:
        dataset_task: Dataset/task name; used in output filenames.
        glob_data: Mapping whose values hold ``"file_dataframes"`` (one
            DataFrame per seed file), ``"data_strategy"`` and
            ``"annotator_strategy"``, as produced by ``read_training_log``.
        xs_counter: Column used for the x axis.
        ys_counter: Validation metric column to plot.
        greater_is_better: Direction of ``ys_counter``; used to pick the
            passive-learning reference value.

    Returns:
        dict mapping ``"<data_strategy>_<annotator_strategy>"`` to the plotted
        ``"x"``/``"y"`` lists plus a ``"best_score"`` dict with ``"round"``,
        ``"unique_annotations_seen"`` and ``"test_value"``.

    When at least one curve was drawn, saves
    ``output/images/{dataset_task}_{ys_counter}.png`` and dumps the curves via
    ``dump_plot_data``. The figure is always closed.
    """
    fig = plt.figure(figsize=(20, 6))
    plt.rc("axes", prop_cycle=_make_prop_cycle())
    plot_file_dict = {}
    did_plot = False

    try:
        for glob_plot_data in glob_data.values():
            data_strategy = glob_plot_data["data_strategy"]
            annotator_strategy = glob_plot_data["annotator_strategy"]

            aggregated = _aggregate_seeds(
                glob_plot_data["file_dataframes"], xs_counter, ys_counter
            )
            if aggregated is None:
                continue
            xs, ys, last_df = aggregated

            label = f"{data_strategy}_{annotator_strategy}"
            is_passive = "passive" in label

            # Test score comes from the last seed that actually had the
            # counters (not a seed that was skipped above).
            best_score_y, best_score_x, best_score_round = get_best_score(
                last_df, ys_counter, is_passive
            )

            if is_passive:
                xs, ys = get_passive_plot_data(xs, ys, greater_is_better)

            plt.plot(xs, ys, label=label)
            plt.xlabel(xs_counter)
            plt.ylabel(ys_counter)
            did_plot = True

            plot_file_dict[label] = {
                "x": xs.tolist(),
                "y": ys.tolist(),
                "best_score": {
                    "round": best_score_round,
                    "unique_annotations_seen": best_score_x,
                    "test_value": best_score_y,
                },
            }

        if did_plot:
            plt.legend(
                loc="upper center",
                bbox_to_anchor=(0.5, -0.20),
                ncol=1,
                fancybox=True,
                shadow=True,
            )
            Path("output/images").mkdir(parents=True, exist_ok=True)
            plt.savefig(
                f"output/images/{dataset_task}_{ys_counter}.png",
                bbox_inches="tight",
            )
    finally:
        # Release the figure even when nothing was plotted, to avoid leaking
        # one open figure per empty metric.
        plt.close(fig)

    if did_plot:
        dump_plot_data(dataset_task, plot_file_dict, ys_counter, greater_is_better)

    return plot_file_dict


def plot_test_vs_data_usage(dataset_task, plot_data, metric):
    """
    Scatter each strategy's best test score against the annotations it used.

    Saves ``output/images/{dataset_task}_2D_{metric}_test.png`` and creates
    the directory if needed. For ``eval_jsdiv`` and ``eval_macro avg`` it also
    prints a tab-separated ``x  score  label`` line per strategy.

    Args:
        dataset_task: Dataset/task name; used in the filename. ``"DICES"``
            tasks get a fixed x range of 0-60000.
        plot_data: Mapping strategy -> dict with a ``"best_score"`` entry
            holding ``"unique_annotations_seen"`` and ``"test_value"``.
        metric: Metric name (validation column) the scores belong to.
    """
    plt.figure(figsize=(20, 6))

    report_summary = metric in ("eval_jsdiv", "eval_macro avg")
    summary_lines = []

    for strategy, data in plot_data.items():
        best_score = data["best_score"]
        x = best_score["unique_annotations_seen"]
        test_y = best_score["test_value"]
        # Single point per strategy; its marker comes from the axes prop cycle.
        plt.plot(x, test_y, label=strategy)
        if report_summary:
            summary_lines.append(f"{x}\t{test_y:.5f}\t{STRAT2LABEL[strategy]}")
        plt.annotate(strategy, (x, test_y))

    if "DICES" in dataset_task:
        plt.xlim([0, 60000])
    plt.ylabel(f"Test {metric}")
    plt.xlabel("Unique annotations seen")
    Path("output/images").mkdir(parents=True, exist_ok=True)
    plt.savefig(
        f"output/images/{dataset_task}_2D_{metric}_test.png",
        bbox_inches="tight",
    )
    plt.close()

    if summary_lines:
        print(f"========= {metric} =========")
        for line in summary_lines:
            print(line)


def read_training_log(dataset_task, file_pattern):
    """
    Read the JSONL logs of every data/annotator strategy combination.

    ``file_pattern`` is formatted with
    ``(dataset_task, dataset_task, annotator_strategy, data_strategy)`` and
    globbed. Passive-learning logs are globbed as
    ``slurm_passive_{dataset_task}_s*.jsonl`` in the same directory. Files
    whose path contains ``"data_logging"`` are ignored.

    Returns:
        dict mapping each glob pattern to ``{"file_dataframes": {path: DataFrame},
        "data_strategy": ..., "annotator_strategy": ...}``. Within each entry,
        files are in sorted path order.
    """
    file_group = [
        (file_pattern.format(dataset_task, dataset_task, ans, ds), ds, ans)
        for ds in DATA_STRATEGIES
        for ans in ANNOTATOR_STRATEGIES
    ]

    # Passive learning logs live next to the others under a different name.
    passive_path = file_pattern.format(dataset_task, dataset_task, "passive", "passive")
    passive_glob = Path(passive_path).parent / f"slurm_passive_{dataset_task}_s*.jsonl"
    file_group.append((str(passive_glob), "passive", "passive"))

    glob_results = {}
    for file_glob, data_strategy, annotator_strategy in file_group:
        # glob() returns files in arbitrary, filesystem-dependent order; sort so
        # the seed order (first seed's x values, last seed's best score) is
        # reproducible across machines and runs.
        files = sorted(f for f in glob(file_glob) if "data_logging" not in f)
        glob_results[file_glob] = {
            "file_dataframes": read_data(files),
            "data_strategy": data_strategy,
            "annotator_strategy": annotator_strategy,
        }

    return glob_results


def print_annotator_centric_results(test_set_metrics, best_passive_x):
    """
    Print one LaTeX table row per strategy with its annotator-centric test scores.

    Each row has the strategy label, the four annotator-centric test values
    (3 decimals), and the relative change in percent of the annotations used
    versus ``best_passive_x``. The best-round x value must be identical across
    the four metrics (checked with an assert).

    Args:
        test_set_metrics: Nested mapping metric -> x counter -> strategy ->
            dict with a ``"best_score"`` entry.
        best_passive_x: Annotations seen by the best passive-learning model.
    """
    metric_names = [
        "eval_mean_f1_macro_per_annotator",
        "eval_mean_jsdiv_per_annotator",
        "eval_worst_f1_macro_per_annotator",
        "eval_worst_jsdiv_per_annotator",
    ]
    by_counter = test_set_metrics[metric_names[0]]
    x_counter = list(by_counter.keys())[0]

    for strategy in by_counter[x_counter].keys():
        print(f"{STRAT2LABEL[strategy]:36}", end="&")

        seen_counts = []
        for name in metric_names:
            best = test_set_metrics[name][x_counter][strategy]["best_score"]
            seen_counts.append(best["unique_annotations_seen"])
            print(f" {best['test_value']:.3f} ", end="&")

        # All annotator-centric metrics should have picked the same round.
        distinct_counts = set(seen_counts)
        assert len(distinct_counts) == 1
        strategy_x = next(iter(distinct_counts))
        relative_change = (strategy_x / best_passive_x - 1) * 100
        print(f" {relative_change:2.1f} \\\\")


def plot_file_group(dataset_task, results):
    """
    Produce all plots for one dataset/task and print the annotator-centric table.

    For every metric in ``RESULT_COLUMNS`` (direction from ``GRB``) and every
    counter in ``XS_COLUMNS``, this draws the validation curves and the
    test-vs-data-usage plot. It then prints the annotator-centric table,
    relative to the passive baseline's annotation count for ``eval_jsdiv``.
    If no passive baseline was plotted, it prints a notice and skips the
    table instead of raising, so the remaining dataset tasks still run.

    Args:
        dataset_task: Dataset/task name.
        results: Output of ``read_training_log`` for this task.
    """
    test_set_metrics = {}

    for metric, greater_is_better in zip(RESULT_COLUMNS, GRB):
        per_counter = {}
        test_set_metrics[metric] = per_counter
        for counter in XS_COLUMNS:
            plot_data = plot_validation_during_training(
                dataset_task, results, counter, metric, greater_is_better
            )
            per_counter[counter] = plot_data
            plot_test_vs_data_usage(dataset_task, plot_data, metric)

    print("========= Annotator-centric tables =========")
    passive = test_set_metrics["eval_jsdiv"]["total_unique_annotations_seen"].get(
        "passive_passive"
    )
    if passive is None:
        print(f"No passive baseline for {dataset_task}; skipping annotator-centric table")
        return
    best_passive_x = passive["best_score"]["unique_annotations_seen"]
    print_annotator_centric_results(test_set_metrics, best_passive_x)


if __name__ == "__main__":
    # Dataset/task -> glob pattern of its run logs. Each pattern is formatted
    # with (dataset_task, dataset_task, annotator_strategy, data_strategy).
    DATASET_TASKS = {
        "DICES_overall": "output/runs/prajjwal1_bert-tiny/{}_r50_e20_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        # "DICES_quality": "output/runs/prajjwal1_bert-tiny/{}_r70_e20_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MFTC_care": "output/runs/prajjwal1_bert-tiny/{}_r20_e30_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MFTC_loyalty": "output/runs/prajjwal1_bert-tiny/{}_r20_e30_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MFTC_betrayal": "output/runs/prajjwal1_bert-tiny/{}_r20_e30_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MHS_dehumanize": "output/runs/prajjwal1_bert-tiny/{}_r20_e20_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MHS_respect": "output/runs/prajjwal1_bert-tiny/{}_r20_e20_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
        "MHS_genocide": "output/runs/prajjwal1_bert-tiny/{}_r20_e20_lr1e-05/slurm_{}_{}_{}_s*.jsonl",
    }
    for dataset_task, file_pattern in DATASET_TASKS.items():
        print(f"Processing {dataset_task}")
        task_results = read_training_log(dataset_task, file_pattern)
        plot_file_group(dataset_task, task_results)
