import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
from project_root import join_with_root


def gen_box_plots(df, main_axis="kendall", sub1="regex", sub2="task", sub3="model",
                  *, figsize=(10, 10), y_pad=0.02, show=True):
    """Draw a grid of box plots of ``main_axis``, one panel per (sub1, sub2) pair.

    Rows correspond to the unique values of ``sub1`` and columns to the unique
    values of ``sub2``. Inside each panel one box is drawn per unique value of
    ``sub3``. All panels share the same y-limits, taken from the global minimum
    and maximum of ``main_axis`` (plus ``y_pad`` headroom).

    Args:
        df: DataFrame containing the columns ``main_axis``, ``sub1``, ``sub2``
            and ``sub3``. If the first value of ``sub1`` is a dict, the whole
            ``sub1`` column is replaced in place by each entry's ``"name"`` key.
        main_axis: Column whose values are summarized by the boxes.
        sub1: Column defining the grid rows.
        sub2: Column defining the grid columns.
        sub3: Column defining the individual boxes within a panel.
        figsize: Figure size forwarded to ``plt.subplots``.
        y_pad: Extra space added above the global maximum on the y-axis.
        show: Whether to call ``plt.show()`` after drawing.

    Returns:
        The input DataFrame (with ``sub1`` possibly normalized as described).
    """
    if df.empty:
        # Nothing to plot, and plt.subplots(0, ...) would fail.
        return df

    # Normalize dict-valued entries to their "name" field (detected from the
    # first row, as the column is assumed to be homogeneous).
    if isinstance(df[sub1].iloc[0], dict):
        df[sub1] = [u["name"] for u in df[sub1].tolist()]

    u1 = df[sub1].unique()
    u2 = df[sub2].unique()

    # squeeze=False keeps `axes` 2-D even with a single row or column, so the
    # axes[i, j] indexing below always works.
    _fig, axes = plt.subplots(len(u1), len(u2), figsize=figsize, squeeze=False)

    # Common y-axis limits so all panels are directly comparable.
    ymin = df[main_axis].min()
    ymax = df[main_axis].max()
    have_limits = pd.notna(ymin) and pd.notna(ymax)

    for i, s1 in enumerate(u1):
        for j, s2 in enumerate(u2):
            ax = axes[i, j]
            subset = df[(df[sub1] == s1) & (df[sub2] == s2)]

            # One series of main_axis values per sub3 group -> one box each.
            s3_data = {
                s3: subset.loc[subset[sub3] == s3, main_axis]
                for s3 in subset[sub3].unique()
            }

            # Skip drawing for combinations with no rows; keep the empty panel.
            if s3_data:
                sns.boxplot(data=s3_data, ax=ax)
            ax.set_title(f'{s1} - {s2}')

            if have_limits:
                ax.set_ylim(ymin, ymax + y_pad)

    plt.tight_layout()
    if show:
        plt.show()

    return df

if __name__ == '__main__':
    # Evaluation summaries to plot. Add entries to combine several runs, e.g.
    # join_with_root("outputs/evaluation/corr_few_shot_train_avg.json")
    train_files = [
        join_with_root("outputs/evaluation/corr_zero_shot_train_avg.json"),
    ]
    frames = [pd.read_json(path) for path in train_files]
    df = pd.concat(frames).reset_index()
    gen_box_plots(df)