import argparse
import glob
import json
import os

import pandas as pd
import numpy as np

# Absolute path of the directory that contains this script (symlinks resolved).
thisdir = os.path.dirname(os.path.realpath(__file__))

# Command-line interface. Every option uses argparse's default "store" action.
parser = argparse.ArgumentParser(description="Export SEAT results.")
# Defaults to the parent directory of this script's directory.
parser.add_argument(
    "--persistent_dir",
    type=str,
    default=os.path.realpath(os.path.join(thisdir, os.pardir)),
    help="Directory where all persistent data will be stored.",
)
# Threshold a p-value must fall below for a result to count as significant.
parser.add_argument(
    "--alpha",
    type=float,
    default=0.01,
    help="Alpha value for reporting significant results.",
)
# Selects which group of SEAT tests is reported.
parser.add_argument(
    "--bias_type",
    type=str,
    default="gender",
    choices=["gender", "race", "religion"],
    help="Determines which results are exported.",
)


# SEAT tests exported for each --bias_type choice. Entries are compared against
# the "test" field of the result records, so spelling must match exactly. Most
# tests sit next to a companion test carrying a "b"/"_b" suffix.
GENDER_TESTS = [
    "sent-weat6", "sent-weat6b",
    "sent-weat7", "sent-weat7b",
    "sent-weat8", "sent-weat8b",
]


RACE_TESTS = [
    "sent-angry_black_woman_stereotype", "sent-angry_black_woman_stereotype_b",
    "sent-weat3", "sent-weat3b",
    "sent-weat4",
    "sent-weat5", "sent-weat5b",
]


RELIGION_TESTS = [
    "sent-religion1", "sent-religion1b",
    "sent-religion2", "sent-religion2b",
]


if __name__ == "__main__":
    args = parser.parse_args()

    print("Exporting SEAT results:")
    print(f" - persistent_dir: {args.persistent_dir}")
    print(f" - alpha: {args.alpha}")
    print(f" - bias_type: {args.bias_type}")

    # Filter to only a subset of the tests.
    if args.bias_type == "gender":
        tests = GENDER_TESTS
    elif args.bias_type == "race":
        tests = RACE_TESTS
    else:
        tests = RELIGION_TESTS

    # Load the results.
    results = []
    for file_path in glob.glob(f"{args.persistent_dir}/results/seat/*.json"):
        with open(file_path, "r") as f:
            results.extend(json.load(f))
    df = pd.DataFrame.from_records(results)

    # Find the significant results.
    print("Significant results:")
    df_significant_results = df.copy()
    df_significant_results["significant"] = df["p_value"] < args.alpha

    # with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    df_significant_results = df_significant_results[
        df_significant_results["experiment_id"]
        == "seat_m-INLPGPT2Model_c-gpt2_t-gender"
    ]
    print(df_significant_results[df_significant_results["significant"]])

    # Reformat results table so that for a given model, the effect size for each
    # SEAT test is a column.
    df = df[["experiment_id", "test", "effect_size"]]
    df = df[df["test"].isin(tests)]

    df = df.pivot_table(df, index="experiment_id", columns="test")
    df.columns = df.columns.droplevel(0)
    df.columns.name = None
    df = df.reset_index()

    # Compute the average absolute effect size.
    df_avg = df.copy()
    df_avg = df_avg.apply(lambda x: x.abs() if np.issubdtype(x.dtype, np.number) else x)
    df_avg["average_absolute_effect_size"] = df_avg.iloc[:, 1:].mean(axis=1)

    df = pd.merge(
        df,
        df_avg[["experiment_id", "average_absolute_effect_size"]],
        on="experiment_id",
        how="left",
    )

    # Parse the experiment ID.
    df["model"] = df["experiment_id"].str.extract(
        r"seat_m-([A-Za-z0-9-]+)_c-[A-Za-z0-9-]+_t-[A-Za-z-]+"
    )

    df["bias_type"] = df["experiment_id"].str.extract(
        r"seat_m-[A-Za-z-]+_c-[A-Za-z0-9-]+_t-([A-Za-z-]+)"
    )
    df["model_name_or_path"] = df["experiment_id"].str.extract(
        r"seat_m-[A-Za-z-]+_c-([A_Za-z0-9-]+)_t-[A-Za-z-]+"
    )

    with pd.option_context("max_colwidth", 1000):
        print(
            df.to_latex(
                float_format="%.3f",
                # columns=["experiment_id"] + tests + ["average_absolute_effect_size"],
                columns=["experiment_id", "average_absolute_effect_size"],
                index=False,
            )
        )

    with pd.option_context("max_colwidth", 1000):
        with open(f"{args.persistent_dir}/table.tex", "w") as f:
            f.write(
                df.to_latex(
                    float_format="%.3f",
                    columns=["experiment_id"]
                    + tests
                    + ["average_absolute_effect_size"],
                    index=False,
                )
            )
