import experiments.config
import ast, json
import os
import polars as pl
import numpy as np
import pandas as pd

# Name of the results directory. The analysis functions below join it onto the
# parent of this file's directory and then append the experiment's config name.
# Alternative location, kept so it can be switched by hand:
# data_folder = "data"
data_folder = "final_data"

def get_top2_accuracy(y_trues, ranks):
    """Flag, per sample, whether the true label is among the first two candidates.

    Args:
        y_trues: Iterable of true labels; each label is converted with ``str``
            before the membership test.
        ranks: Indexable collection where ``ranks[i]`` is the ordered candidate
            sequence belonging to the i-th label.

    Returns:
        A NumPy array holding one boolean per label.
    """
    hits = []
    for idx, label in enumerate(y_trues):
        top_two = ranks[idx][:2]
        hits.append(str(label) in top_two)
    return np.array(hits)

def get_top5_accuracy(y_trues, ranks):
    """Flag, per sample, whether the true label is among the first five candidates.

    The label is compared as-is (it is not converted with ``str``).

    Args:
        y_trues: Iterable of true labels.
        ranks: Indexable collection where ``ranks[i]`` is the ordered candidate
            sequence belonging to the i-th label.

    Returns:
        A NumPy array holding one boolean per label.
    """
    def _in_top_five(position, label):
        return label in ranks[position][:5]

    return np.array([_in_top_five(*pair) for pair in enumerate(y_trues)])
    
def analyze_1_time(config, print_results=True):
    """Load the timing summary stored for one experiment run.

    Reads ``data.json`` from ``<dir of this file>/../<data_folder>/<config name>/``.

    Args:
        config: Mapping with a ``"name"`` key that names the run's directory.
        print_results: When true, print the config name and both timings.

    Returns:
        Tuple ``(total_time_elapsed, llm_time_elapsed)`` as read from the JSON file.
    """
    save_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.json"
    )
    # A context manager closes the file handle deterministically instead of
    # leaving it to the garbage collector.
    with open(save_path) as f:
        data = json.load(f)
    if print_results:
        print(f"Config: {config['name']}")
        print(f"Total Time Elapsed: {data['total_time_elapsed']:.2f} seconds")
        print(f"LLM Time Elapsed: {data['llm_time_elapsed']:.2f} seconds")
    return data["total_time_elapsed"], data["llm_time_elapsed"]

def analyze_1_accuracy(config, print_results=True):
    """Compute per-sample top-1 (and, when available, top-2/top-5) hits for one run.

    Reads ``data.csv`` from ``<dir of this file>/../<data_folder>/<config name>/``
    with the ``authorId``, ``pred_authorId`` and ``textId`` columns forced to
    strings. If the file has a ``rank`` column, each cell is parsed with
    ``ast.literal_eval`` and used for the top-2/top-5 checks.

    Args:
        config: Mapping with a ``"name"`` key that names the run's directory.
        print_results: When true, print the sample count and each accuracy as
            ``mean ± standard error``.

    Returns:
        Tuple ``(top1, top2, top5)`` of boolean arrays; ``top2`` and ``top5``
        are ``None`` when the CSV has no ``rank`` column.
    """
    def _summary(hits):
        # mean ± standard error of the mean
        return f"{np.mean(hits):.4f} ± {np.std(hits)/np.sqrt(len(hits)):.4f}"

    csv_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    string_columns = ("authorId", "pred_authorId", "textId")
    results = pl.read_csv(
        csv_path,
        schema_overrides={column: pl.String for column in string_columns},
    )

    truth = results["authorId"].to_numpy()
    predicted = results["pred_authorId"].to_numpy()

    has_rank = "rank" in results.columns
    top2 = None
    top5 = None
    if has_rank:
        parsed = results["rank"].map_elements(
            ast.literal_eval, return_dtype=pl.List(pl.String)
        )
        rank_lists = parsed.to_numpy()
        top2 = get_top2_accuracy(truth, rank_lists)
        top5 = get_top5_accuracy(truth, rank_lists)
    top1 = np.array(truth == predicted)

    if print_results:
        print(f"Config: {config['name']}")
        print(f"The number of samples is {len(top1)}")
        print(f"Top 1 Accuracy: {_summary(top1)}")
        if has_rank:
            print(f"Top 2 Accuracy: {_summary(top2)}")
            print(f"Top 5 Accuracy: {_summary(top5)}")

    return top1, top2, top5


def analyze_bias_gender():
    """Print top-1/2/5 accuracy for one experiment, overall and split by gender.

    Uses the config at index 1 of ``experiments.config.large_exp1`` and reads
    its ``data.csv``. One table row is produced for all samples (labelled
    "both") and one each for rows whose ``gender`` column equals "male" or
    "female". Accuracies are formatted as ``mean ± standard error``.
    """
    config = experiments.config.large_exp1[1]

    save_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    out_ds = pl.read_csv(
        save_path,
        schema_overrides={
            "authorId": pl.String,
            "pred_authorId": pl.String,
            "textId": pl.String,
        },
    )

    print("Bias: gender")

    records = []
    for gender in [None, "male", "female"]:
        # ``None`` stands for the unfiltered dataset.
        if gender is None:
            lds = out_ds
        else:
            lds = out_ds.filter(pl.col("gender") == gender)
        y_true = lds["authorId"].to_numpy()
        y_pred = lds["pred_authorId"].to_numpy()
        # ``Series.map_elements`` replaces the deprecated ``Series.apply`` and
        # matches how analyze_1_accuracy parses the same column.
        ranks = (
            lds["rank"]
            .map_elements(ast.literal_eval, return_dtype=pl.List(pl.String))
            .to_numpy()
        )
        l = np.array(y_true == y_pred)
        l2 = get_top2_accuracy(y_true, ranks)
        l5 = get_top5_accuracy(y_true, ranks)
        records.append(
            {
                "gender": gender or "both",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )

    ds = pl.DataFrame(records)
    print(ds)

def analyze_1_accuracy_by_gender(config, print_results=True):
    """Compute top-1/2/5 hits separately for the "male" and "female" rows of one run.

    Reads ``data.csv`` from ``<dir of this file>/../<data_folder>/<config name>/``
    and splits it on the ``gender`` column.

    Args:
        config: Mapping with a ``"name"`` key that names the run's directory.
        print_results: When true, print each gender's accuracies as
            ``mean ± standard error``.

    Returns:
        Tuple ``(top1, top2, top5)`` of boolean arrays for the LAST gender
        processed ("female"); the "male" arrays are only printed.
    """
    save_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    out_ds = pl.read_csv(
        save_path,
        schema_overrides={
            "authorId": pl.String,
            "pred_authorId": pl.String,
            "textId": pl.String,
        },
    )

    gender_data = out_ds["gender"]

    for gender in ["male", "female"]:
        # Filter once per gender rather than once per extracted column.
        subset = out_ds.filter(gender_data == gender)
        y_true = subset["authorId"].to_numpy()
        y_pred = subset["pred_authorId"].to_numpy()
        # ``Series.map_elements`` replaces the deprecated ``Series.apply``; the
        # explicit return dtype matches analyze_1_accuracy.
        ranks = (
            subset["rank"]
            .map_elements(ast.literal_eval, return_dtype=pl.List(pl.String))
            .to_numpy()
        )

        l2 = get_top2_accuracy(y_true, ranks)
        l5 = get_top5_accuracy(y_true, ranks)
        l = np.array(y_true == y_pred)

        if print_results:
            label = gender.capitalize()
            print(f"{label} - Top 1 Accuracy: {np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}")
            print(f"{label} - Top 2 Accuracy: {np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}")
            print(f"{label} - Top 5 Accuracy: {np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}")
    return l, l2, l5


def analyze_1_accuracy_by_rate(config, print_results=True):
    """Compute top-1/2/5 hits for one run, overall and per ``rating`` band.

    Reads ``data.csv`` from ``<dir of this file>/../<data_folder>/<config name>/``.
    After the overall figures, the rows are split into five bands on the
    ``rating`` column: 1-2, 3-4, 5-6, 7-8 and 9-10 (both ends inclusive).

    Args:
        config: Mapping with a ``"name"`` key that names the run's directory.
        print_results: When true, print the overall and per-band accuracies as
            ``mean ± standard error``.

    Returns:
        Tuple ``(top1, top2, top5)`` of boolean arrays for the LAST band
        (ratings 9-10); the overall arrays are only printed.
    """
    save_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    out_ds = pl.read_csv(
        save_path,
        schema_overrides={
            "authorId": pl.String,
            "pred_authorId": pl.String,
            "textId": pl.String,
        },
    )

    def _parse_ranks(frame):
        # ``Series.map_elements`` replaces the deprecated ``Series.apply``; the
        # explicit return dtype matches analyze_1_accuracy.
        return (
            frame["rank"]
            .map_elements(ast.literal_eval, return_dtype=pl.List(pl.String))
            .to_numpy()
        )

    y_true = out_ds["authorId"].to_numpy()
    y_pred = out_ds["pred_authorId"].to_numpy()
    ranks = _parse_ranks(out_ds)
    l2 = get_top2_accuracy(y_true, ranks)
    l5 = get_top5_accuracy(y_true, ranks)
    l = np.array(y_true == y_pred)
    if print_results:
        print(f"Top 1 Accuracy: {np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}")
        print(f"Top 2 Accuracy: {np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}")
        print(f"Top 5 Accuracy: {np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}")

    for i in range(5):
        # Band i covers ratings 1+2i .. 2+2i inclusive.
        subset = out_ds.filter((out_ds["rating"] >= 1+i*2) & (out_ds["rating"] <= 2+i*2))
        subset_y_true = subset["authorId"].to_numpy()
        subset_y_pred = subset["pred_authorId"].to_numpy()
        subset_ranks = _parse_ranks(subset)

        l2 = get_top2_accuracy(subset_y_true, subset_ranks)
        l5 = get_top5_accuracy(subset_y_true, subset_ranks)
        l = np.array(subset_y_true == subset_y_pred)
        if print_results:
            print(f"Subset {i+1} Top 1 Accuracy: {np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}")
            print(f"Subset {i+1} Top 2 Accuracy: {np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}")
            print(f"Subset {i+1} Top 5 Accuracy: {np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}")

    return l, l2, l5

def analyze_1_accuracy_by_length(config, print_results=True):
    """Compute top-1/2/5 hits for one run, overall and per ``word_length`` bin.

    Reads ``data.csv`` from ``<dir of this file>/../<data_folder>/<config name>/``.
    After the overall figures, the rows are split into five equal-width bins
    spanning the minimum to the maximum of the ``word_length`` column. Bins are
    closed on the left and open on the right, except the last bin which is
    closed on both sides, so every row (including the shortest and longest)
    lands in exactly one bin.

    Args:
        config: Mapping with a ``"name"`` key that names the run's directory.
        print_results: When true, print the overall and per-bin accuracies as
            ``mean ± standard error``.

    Returns:
        Tuple ``(top1, top2, top5)`` of boolean arrays for the LAST bin; the
        overall arrays are only printed.
    """
    save_path = os.path.join(
        os.path.dirname(__file__), "..", data_folder, config["name"], "data.csv"
    )
    out_ds = pl.read_csv(
        save_path,
        schema_overrides={
            "authorId": pl.String,
            "pred_authorId": pl.String,
            "textId": pl.String,
        },
    )

    def _parse_ranks(frame):
        # ``Series.map_elements`` replaces the deprecated ``Series.apply``; the
        # explicit return dtype matches analyze_1_accuracy.
        return (
            frame["rank"]
            .map_elements(ast.literal_eval, return_dtype=pl.List(pl.String))
            .to_numpy()
        )

    y_true = out_ds["authorId"].to_numpy()
    y_pred = out_ds["pred_authorId"].to_numpy()
    ranks = _parse_ranks(out_ds)
    l2 = get_top2_accuracy(y_true, ranks)
    l5 = get_top5_accuracy(y_true, ranks)
    l = np.array(y_true == y_pred)
    if print_results:
        print(f"Top 1 Accuracy: {np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}")
        print(f"Top 2 Accuracy: {np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}")
        print(f"Top 5 Accuracy: {np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}")

    max_length = out_ds['word_length'].max()
    min_length = out_ds['word_length'].min()
    bin_edges = np.linspace(min_length, max_length, 6)
    last_bin = len(bin_edges) - 2
    for i in range(len(bin_edges)-1):
        # Strict comparisons on both sides would drop every row sitting exactly
        # on an edge (notably the min and max rows), so include the lower edge
        # always and the upper edge for the final bin only.
        in_bin = pl.col("word_length") >= bin_edges[i]
        if i == last_bin:
            in_bin = in_bin & (pl.col("word_length") <= bin_edges[i+1])
        else:
            in_bin = in_bin & (pl.col("word_length") < bin_edges[i+1])
        subset = out_ds.filter(in_bin)
        subset_y_true = subset["authorId"].to_numpy()
        subset_y_pred = subset["pred_authorId"].to_numpy()
        subset_ranks = _parse_ranks(subset)

        l2 = get_top2_accuracy(subset_y_true, subset_ranks)
        l5 = get_top5_accuracy(subset_y_true, subset_ranks)
        l = np.array(subset_y_true == subset_y_pred)
        if print_results:
            print(f"Subset {i+1} ({bin_edges[i]}-{bin_edges[i+1]}) Top 1 Accuracy: {np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}")
            print(f"Subset {i+1} ({bin_edges[i]}-{bin_edges[i+1]}) Top 2 Accuracy: {np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}")
            print(f"Subset {i+1} ({bin_edges[i]}-{bin_edges[i+1]}) Top 5 Accuracy: {np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}")

    return l, l2, l5


def analyze_ablation_exp1():
    """Print an accuracy table with one row per ``ablation_exp1`` config.

    Each row holds the config's ``num_authors`` and the top-1/2/5 accuracy
    formatted as ``mean ± standard error``.

    NOTE: a later definition of ``analyze_ablation_exp1`` in this module rebinds
    the name, so this version is not the one reached by name after import.
    """
    config = experiments.config.ablation_exp1
    records = []
    for c in config:
        l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        records.append(
            {
                "num_authors": c["num_authors"],
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    # The table used to be built and then discarded; print it like the
    # sibling ablation analyses do.
    print(ds)
    

def analyze_ablation_exp2():
    """Print an accuracy table with one row per ``ablation_exp2`` config.

    Each row holds the config's ``num_authors`` and the top-1/2/5 accuracy
    formatted as ``mean ± standard error``.

    NOTE: a later definition of ``analyze_ablation_exp2`` in this module rebinds
    the name, so this version is not the one reached by name after import.
    """
    def _fmt(hits):
        # mean ± standard error of the mean
        return f"{np.mean(hits):.4f} ± {np.std(hits)/np.sqrt(len(hits)):.4f}"

    def _row(cfg):
        top1, top2, top5 = analyze_1_accuracy(cfg, print_results=False)
        return {
            "num_authors": cfg["num_authors"],
            "top 1 accuracy": _fmt(top1),
            "top 2 accuracy": _fmt(top2),
            "top 5 accuracy": _fmt(top5),
        }

    configs = experiments.config.ablation_exp2
    table = pl.DataFrame([_row(cfg) for cfg in configs])
    print(table)

def analyze_ablation_exp3():
    """Print an accuracy table with one row per ``ablation_exp3`` config.

    The config list is printed first, and each config is printed again just
    before it is analyzed. Each row holds the config's ``num_authors`` and the
    top-1/2/5 accuracy formatted as ``mean ± standard error``.
    """
    def _fmt(hits):
        # mean ± standard error of the mean
        return f"{np.mean(hits):.4f} ± {np.std(hits)/np.sqrt(len(hits)):.4f}"

    configs = experiments.config.ablation_exp3
    rows = []
    print(configs)
    for cfg in configs:
        print(cfg)
        top1, top2, top5 = analyze_1_accuracy(cfg, print_results=False)
        row = {"num_authors": cfg["num_authors"]}
        row["top 1 accuracy"] = _fmt(top1)
        row["top 2 accuracy"] = _fmt(top2)
        row["top 5 accuracy"] = _fmt(top5)
        rows.append(row)
    print(pl.DataFrame(rows))


def analyze_ablation_exp4():
    """Print an accuracy table with one row per ``ablation_exp4`` config.

    The config list is printed first, and each config is printed again just
    before it is analyzed. Each row holds the config's ``num_authors``, the
    sample count, and the top-1/2/5 accuracy formatted as
    ``mean ± standard error``.

    NOTE: a later definition of ``analyze_ablation_exp4`` in this module rebinds
    the name, so this version is not the one reached by name after import.
    """
    configs = experiments.config.ablation_exp4
    rows = []
    print(configs)
    for cfg in configs:
        print(cfg)
        top1, top2, top5 = analyze_1_accuracy(cfg, print_results=False)
        row = {"num_authors": cfg["num_authors"], "num_samples": len(top1)}
        labelled_hits = (
            ("top 1 accuracy", top1),
            ("top 2 accuracy", top2),
            ("top 5 accuracy", top5),
        )
        for column, hits in labelled_hits:
            mean = np.mean(hits)
            stderr = np.std(hits) / np.sqrt(len(hits))
            row[column] = f"{mean:.4f} ± {stderr:.4f}"
        rows.append(row)
    table = pl.DataFrame(rows)
    print(table)



def analyze_ablation_exp1():
    """Print an accuracy table with one row per ``ablation_exp1`` config.

    Prints the first config's name as a header. Each row holds the config's
    ``num_authors`` (as a string), the sample count, and the top-1/2/5 accuracy
    formatted as ``mean ± standard error``. Configs whose results cannot be
    analyzed are skipped, and a notice is printed so the missing row is
    explained.
    """
    config = experiments.config.ablation_exp1
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "num_authors": f"{c['num_authors']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)

def analyze_ablation_exp2():
    """Print an accuracy table with one row per ``ablation_exp2`` config.

    Prints the first config's name as a header. Each row holds the config's
    ``filter_config["gender"]`` value (as a string), the sample count, and the
    top-1/2/5 accuracy formatted as ``mean ± standard error``. Configs whose
    results cannot be analyzed are skipped, and a notice is printed so the
    missing row is explained.
    """
    config = experiments.config.ablation_exp2
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "gender": f"{c['filter_config']['gender']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_ablation_exp4():
    """Print an accuracy table with one row per ``ablation_exp4`` config.

    Prints the first config's name as a header. Each row holds the age range
    taken from ``filter_config["min_age"]`` and ``filter_config["max_age"]``,
    the sample count, and the top-1/2/5 accuracy formatted as
    ``mean ± standard error``. Configs whose results cannot be analyzed are
    skipped, and a notice is printed so the missing row is explained.
    """
    config = experiments.config.ablation_exp4
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "age_range": f"{c['filter_config']['min_age']}-{c['filter_config']['max_age']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_ablation_exp5():
    """Print an accuracy table with one row per ``ablation_exp5`` config.

    Prints the first config's name as a header. Each row holds the source code
    of the config's ``prompt_composer`` (via ``inspect.getsource``), the sample
    count, and the top-1/2/5 accuracy formatted as ``mean ± standard error``.
    Configs whose results cannot be analyzed are skipped, and a notice is
    printed so the missing row is explained.
    """
    config = experiments.config.ablation_exp5
    records = []
    print(f"Config: {config[0]['name']}")

    import inspect

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "code": inspect.getsource(c["prompt_composer"]).strip(),
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_ablation_exp6():
    """Print an accuracy table with one row per ``ablation_exp6`` config.

    Prints the first config's name as a header. Each row holds the rating range
    taken from ``filter_config["min_rating"]`` and
    ``filter_config["max_rating"]``, the sample count, and the top-1/2/5
    accuracy formatted as ``mean ± standard error``. Configs whose results
    cannot be analyzed are skipped, and a notice is printed so the missing row
    is explained.
    """
    config = experiments.config.ablation_exp6
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "rating_range": f"{c['filter_config']['min_rating']}-{c['filter_config']['max_rating']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_compare_exp1():
    """Print an accuracy table with one row per ``compare_exp1`` config.

    Prints the first config's name as a header. Each row holds the model name
    from ``completion_kwargs["model"]``, the config's ``ds_name``, the sample
    count, and the top-1/2/5 accuracy formatted as ``mean ± standard error``.
    Configs whose results cannot be analyzed are skipped, and a notice is
    printed so the missing row is explained. The table is printed with every
    row shown and long strings left untruncated (up to 1000 characters).
    """
    config = experiments.config.compare_exp1
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            l, l2, l5 = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "model_name": f"{c['completion_kwargs']['model']}",
                "ds_name": f"{c['ds_name']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "top 2 accuracy": f"{np.mean(l2):.4f} ± {np.std(l2)/np.sqrt(len(l2)):.4f}",
                "top 5 accuracy": f"{np.mean(l5):.4f} ± {np.std(l5)/np.sqrt(len(l5)):.4f}",
            }
        )
    ds = pl.DataFrame(records)

    # Temporarily widen the display so no row or string is elided.
    with pl.Config(fmt_str_lengths=1000, tbl_rows=ds.height):
        print(ds)


def analyze_compare_exp2():
    """Print a top-1 accuracy table with one row per ``compare_exp2`` config.

    Prints the first config's name as a header. Each row holds the model name
    from ``completion_kwargs["model"]``, the config's ``ds_name``, the sample
    count, and the top-1 accuracy formatted as ``mean ± standard error``.
    Configs whose results cannot be analyzed are skipped, and a notice is
    printed so the missing row is explained.
    """
    config = experiments.config.compare_exp2
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        try:
            # Only the top-1 hits are reported for this comparison.
            l, _, _ = analyze_1_accuracy(c, print_results=False)
        except Exception as e:
            # Best-effort: keep building the table, but report the failure
            # instead of dropping the row silently.
            print(f"Skipping a config that could not be analyzed: {e!r}")
            continue
        records.append(
            {
                "model_name": f"{c['completion_kwargs']['model']}",
                "ds_name": f"{c['ds_name']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_compare_exp3():
    """Print a top-1 accuracy and LLM-time table with one row per ``compare_exp3`` config.

    Prints the first config's name as a header. Each row holds the config name,
    the model name from ``completion_kwargs["model"]``, the config's
    ``ds_name``, the sample count, the top-1 accuracy formatted as
    ``mean ± standard error``, and the LLM time formatted to two decimals.

    Any error raised while loading a config's results propagates to the caller.
    """
    config = experiments.config.compare_exp3
    records = []
    print(f"Config: {config[0]['name']}")

    for c in config:
        # The previous try/except only re-raised (its ``continue`` was
        # unreachable), so failures are simply left to propagate.
        l, _, _ = analyze_1_accuracy(c, print_results=False)
        _, llm_time_elapsed = analyze_1_time(c, print_results=False)
        records.append(
            {
                "name": f"{c['name']}",
                "model_name": f"{c['completion_kwargs']['model']}",
                "ds_name": f"{c['ds_name']}",
                "num_samples": len(l),
                "top 1 accuracy": f"{np.mean(l):.4f} ± {np.std(l)/np.sqrt(len(l)):.4f}",
                "llm_time_elapsed": f"{llm_time_elapsed:.2f}",
            }
        )
    ds = pl.DataFrame(records)
    print(ds)


def analyze_all():
    """Run the ablation, gender-bias and comparison analyses in sequence.

    Each analysis prints its own table. Results are located through the
    module-level ``data_folder``; assigning to that name inside this function
    would only bind an unused local, so no such assignment is made here.
    """
    # Debug runs, kept for reference (presumably meant to be used with the
    # module-level ``data_folder`` set to "data" - confirm before enabling):
    # analyze_1_accuracy(experiments.config.debug_exp1)
    # analyze_1_accuracy(experiments.config.debug_exp2)
    # analyze_1_accuracy(experiments.config.debug_exp3)
    # analyze_1_accuracy(experiments.config.debug_exp4)
    analyze_ablation_exp1()
    analyze_ablation_exp2()
    analyze_ablation_exp4()
    analyze_ablation_exp5()
    analyze_ablation_exp6()
    analyze_bias_gender()
    analyze_compare_exp1()
    analyze_compare_exp2()
    analyze_compare_exp3()


# When executed as a script, run the analyses listed in analyze_all().
if __name__ == "__main__":
    analyze_all()
