import argparse
import os
from typing import List

import pandas as pd
from visualize import plot_heatmap


def overlap_analysis(
    root_dir: str, langs: List[str], save_dir: str, ratio: float = 0.3
):
    """Plot, per layer, how many top-ranked features of one input overlap
    with the bottom-ranked features of another.

    For every language in ``langs`` the file
    ``{root_dir}/{lang}/hidden_states.csv`` is read; for every language other
    than ``"en"`` the file ``{root_dir}/{lang}-en/hidden_states.csv`` is read
    as well. Each CSV row is treated as one layer, and the feature values are
    the columns from ``"dim0"`` to the last column.

    For each layer and each ordered pair ``(a, b)`` of loaded inputs, the
    function counts the feature indices that are both among the
    ``round(n_features * ratio)`` largest values of ``a`` and among the same
    number of smallest values of ``b``. One heatmap per layer is written to
    ``{save_dir}/layer{l}.overlap.{ratio}.png``.

    Args:
        root_dir: Directory containing one sub-directory per input.
        langs: Language codes to analyse. Must not be empty. The number of
            layers is taken from the first language's CSV.
        save_dir: Output directory; created if it does not exist.
        ratio: Fraction (between 0 and 1) of features taken from each end of
            the ranking.

    Raises:
        ValueError: If ``langs`` is empty or ``ratio`` is outside ``[0, 1]``.
    """
    if not langs:
        raise ValueError("langs must contain at least one language")
    if not 0 <= ratio <= 1:
        raise ValueError(f"ratio must be within [0, 1], got {ratio}")

    os.makedirs(save_dir, exist_ok=True)

    df_dict = {}
    pair_keys = []
    for lang in langs:
        df_dict[lang] = pd.read_csv(f"{root_dir}/{lang}/hidden_states.csv")
        if lang != "en":
            to_en = lang + "-en"
            df_dict[to_en] = pd.read_csv(f"{root_dir}/{to_en}/hidden_states.csv")
            pair_keys.append(to_en)

    # Compare exactly the inputs that were loaded (languages first, then the
    # "<lang>-en" pairs) instead of a fixed fr/es/zh list, so any set of
    # languages works. dict.fromkeys drops repeated entries, keeping order.
    keys = list(dict.fromkeys([*langs, *pair_keys]))

    # round() rather than int(): int(0.29 * 100) would truncate to 28.
    percent = round(ratio * 100)

    num_layers = len(df_dict[langs[0]])
    for layer in range(num_layers):
        # Rank every input's features once per layer (ascending by value).
        ranked = {
            key: df_dict[key].loc[layer, "dim0":].to_numpy().argsort()
            for key in keys
        }

        rows = []
        for lang in keys:
            order = ranked[lang]
            num_selected = round(len(order) * ratio)
            # Slice from an explicit start offset: order[-0:] would return the
            # whole array when num_selected is 0.
            top = set(order[len(order) - num_selected:])
            rows.append(
                [
                    float(len(top & set(ranked[other][:num_selected])))
                    for other in keys
                ]
            )

        # Rows: input supplying the top features; columns: input supplying
        # the bottom features.
        overlap_df = pd.DataFrame(rows, index=keys, columns=keys, dtype=float)

        # NOTE(review): vmax=200 is kept from the original and is presumably
        # the colour-scale ceiling — confirm it suits the feature count and
        # ratio in use.
        plot_heatmap(
            overlap_df,
            annot=True,
            vmin=0,
            vmax=200,
            save_path=f"{save_dir}/layer{layer}.overlap.{ratio}.png",
            fmt=".0f",
            xy_labels=[
                f"Bottom-{percent}% magnitude features",
                f"Top-{percent}% magnitude features",
            ],
            colorbar_title="Number of overlaps",
        )


if __name__ == "__main__":
    # Command-line entry point: parse the options and run the analysis.
    parser = argparse.ArgumentParser(
        description="Plot per-layer overlap heatmaps from hidden_states.csv files."
    )
    # The three path/language options have no usable default (None would be
    # formatted into paths or iterated), so argparse must insist on them.
    parser.add_argument(
        "--root_dir",
        type=str,
        required=True,
        help="directory containing one sub-directory per language",
    )
    parser.add_argument(
        "--langs",
        nargs="+",  # at least one language; the first one sets the layer count
        type=str,
        required=True,
        help="language codes to analyse, e.g. en fr es zh",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="directory the heatmap images are written to",
    )
    parser.add_argument(
        "--ratio",
        type=float,
        default=0.3,
        help="fraction of features taken from each end of the ranking (default: 0.3)",
    )
    args = parser.parse_args()

    overlap_analysis(
        root_dir=args.root_dir,
        langs=args.langs,
        save_dir=args.save_dir,
        ratio=args.ratio,
    )
