import csv
import os.path

import pandas as pd
from tqdm import tqdm

from data.load_eval_df import load_train_df, load_dev_df, load_test_df
from evaluation.evaluate import save_corr
from project_root import join_with_root
from os import listdir, path

# Number of leading train ground-truth scores that are compared against the
# baselines (the train baselines are limited to 500 samples).
TRAIN_SAMPLE_LIMIT = 500


def is_score_file_for(file_name, dataset, task):
    """Return True if *file_name* is a baseline score file for dataset/task.

    A file matches when its name contains ``"<dataset>_"`` and the task name
    as substrings and does not contain ``"generated_text"``.
    """
    return (
        f"{dataset}_" in file_name
        and task in file_name
        and "generated_text" not in file_name
    )


def split_baseline_name(baseline_name):
    """Split a ``___``-separated baseline name into ``(approach, model)``.

    The approach is the second ``___``-separated field. Only approaches whose
    name contains ``"DSBA"`` or ``"MQM"`` carry a model: the third field, cut
    at its first ``"."``. For every other approach the model is ``""``.

    Raises:
        IndexError: if the name has fewer fields than required.
    """
    parts = baseline_name.split("___")
    approach = parts[1]
    if "DSBA" in approach or "MQM" in approach:
        model = parts[2].split(".")[0]
    else:
        model = ""
    return approach, model


def main():
    """Correlate every raw baseline score file with the ground-truth scores.

    For each dataset split and each task in it, every matching file in
    ``outputs/raw_baselines`` is loaded, its first column is correlated with
    the ``GT_Score`` column via ``save_corr``, and the annotated results are
    written to ``outputs/baseline_correlations/baseline_correlations.json``.

    Raises:
        ValueError: if a baseline file holds a different number of scores
            than the ground truth it is compared against.
    """
    datasets = {
        "train": load_train_df(),
        "dev": load_dev_df(),
        "test": load_test_df()
    }

    baseline_path_head = join_with_root("outputs/raw_baselines")
    baseline_paths = listdir(baseline_path_head)

    correlations_list = []
    for dataset, df in datasets.items():
        for task in df["task"].unique().tolist():
            scores = df[df["task"] == task]["GT_Score"]
            if dataset == "train":
                # Truncate once per task instead of once per matching file.
                scores = scores[:TRAIN_SAMPLE_LIMIT]

            # Wrap the list itself (not enumerate) so tqdm knows the total.
            for baseline in tqdm(baseline_paths, desc=f"{dataset}/{task}"):
                if not is_score_file_for(baseline, dataset, task):
                    continue

                baseline_df = pd.read_json(path.join(baseline_path_head, baseline))
                # The first column holds the scores; its name encodes the
                # approach (and, for some approaches, the model).
                baseline_name = baseline_df.columns[0]
                baseline_scores = baseline_df[baseline_name].tolist()

                # Explicit check: an assert would be stripped under ``-O``.
                if len(scores) != len(baseline_scores):
                    raise ValueError(
                        f"Score count mismatch for {baseline!r}: "
                        f"{len(scores)} ground-truth scores vs "
                        f"{len(baseline_scores)} baseline scores"
                    )

                # NOTE(review): save_corr presumably returns a dict of
                # correlation metrics; it is extended with string keys here.
                correlations = save_corr(scores, baseline_scores)
                approach, model = split_baseline_name(baseline_name)
                correlations["approach"] = approach
                correlations["model"] = model
                correlations["dataset"] = dataset
                correlations["task"] = task
                correlations_list.append(correlations)

    output_file = join_with_root("outputs/baseline_correlations/baseline_correlations.json")
    # Make sure the target directory exists so a fresh checkout does not
    # lose the computed correlations at the very last step.
    output_dir = path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    baseline_correlations = pd.DataFrame(correlations_list)
    baseline_correlations.to_json(output_file)


if __name__ == '__main__':
    main()
