import argparse
import json
from pathlib import Path
from typing import Dict, List
import numpy as np

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.stats import ttest_ind

# Benchmark categories in question-ID order: the loaders in this file map a
# question id to an index into this list via ``(id - 81) // 10``.
CATEGORIES = [
    "Writing",
    "Roleplay",
    "Reasoning",
    "Math",
    "Coding",
    "Extraction",
    "STEM",
    "Humanities",
]


def get_model_df(path):
    """Load a JSONL file of per-question records into a DataFrame.

    Every non-empty line is parsed as a JSON object and gets an extra
    ``category`` field. The category is looked up in ``CATEGORIES`` at index
    ``(id - 81) // 10``, where ``id`` is ``question_id`` itself or, when it
    is a string, the integer before its first ``_``.

    Args:
        path: Path to the JSONL file.

    Returns:
        A DataFrame with one row per non-empty input line.
    """
    records = []
    # `with` guarantees the handle is closed; JSON text is UTF-8 (RFC 8259).
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                # Tolerate blank lines (e.g. an extra newline at EOF).
                continue
            obj = json.loads(line)
            qid = obj["question_id"]
            numeric_question_id = int(qid.split("_")[0]) if isinstance(qid, str) else qid
            # NOTE(review): ids below 81 give a negative index and silently wrap
            # to the end of CATEGORIES; ids above 160 raise IndexError.
            obj["category"] = CATEGORIES[(numeric_question_id - 81) // 10]
            records.append(obj)
    return pd.DataFrame(records)


def toggle(res_str):
    """Return the outcome seen from the opposite side of a comparison.

    ``"win"`` and ``"loss"`` swap; any other value (including ``"tie"`` and
    ``"both_bad"``) becomes ``"tie"``.
    """
    return "loss" if res_str == "win" else ("win" if res_str == "loss" else "tie")


def get_model_df_pair(path):
    """Load pairwise judgments from a JSONL file into a DataFrame.

    Each record carries two judge verdicts, ``g1_winner`` and ``g2_winner``
    (presumably the two presentation orders of the same pair; confirm against
    the judgment generator). They are collapsed into one ``result`` from
    ``model_1``'s point of view:

    * ``"win"``      - both verdicts are ``"model_1"``
    * ``"loss"``     - both verdicts are ``"model_2"``
    * ``"both_bad"`` - both verdicts are ``"both bad"``
    * ``"tie"``      - anything else, including disagreement between the two

    Args:
        path: Path to the JSONL judgment file.

    Returns:
        A DataFrame with columns ``qid`` (str), ``turn`` (str), ``result``,
        ``category``, ``model`` (``model_1``) and ``competitor`` (``model_2``).
    """
    # A verdict only counts when both games agree on it; otherwise "tie".
    agreed = {"model_1": "win", "model_2": "loss", "both bad": "both_bad"}
    records = []
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                # Tolerate blank lines (e.g. an extra newline at EOF).
                continue
            obj = json.loads(line)
            qid = obj["question_id"]
            g1, g2 = obj["g1_winner"], obj["g2_winner"]
            numeric_question_id = int(qid.split("_")[0]) if isinstance(qid, str) else qid
            records.append(
                {
                    "qid": str(qid),
                    "turn": str(obj["turn"]),
                    "result": agreed.get(g1, "tie") if g1 == g2 else "tie",
                    # NOTE(review): ids below 81 wrap to the end of CATEGORIES.
                    "category": CATEGORIES[(numeric_question_id - 81) // 10],
                    "model": obj["model_1"],
                    "competitor": obj["model_2"],
                }
            )
    return pd.DataFrame(records)


def collect_scores(df_pair: pd.DataFrame, categories=None):
    """Compute per-category outcome rates for every (competitor, model) combination.

    Iterates over every unique competitor x every unique model x every
    category, so combinations absent from ``df_pair`` still get an entry.
    Within each combination the ``result`` values ``"win"``, ``"loss"``,
    ``"tie"`` and ``"both_bad"`` are counted.

    Args:
        df_pair: Frame with ``model``, ``competitor``, ``category`` and
            ``result`` columns, e.g. from ``get_model_df_pair``.
        categories: Categories to report, in order. Defaults to ``CATEGORIES``.

    Returns:
        A list of dicts with keys ``model``, ``competitor``, ``category``,
        ``"Loss rate"``, ``"Win rate"``, ``"Tie rate"``, ``"Both Bad rate"``
        and ``"wtrate"`` (``(win + tie + both_bad) / total``). Each rate is a
        fraction of all results for the combination. All rates are NaN when
        the combination has no results.
    """
    if categories is None:
        categories = CATEGORIES

    scores_all = []
    for competitor in df_pair["competitor"].unique():
        for model in df_pair["model"].unique():
            # Loop-invariant across categories, so build it once per pair.
            pair_mask = (df_pair["model"] == model) & (df_pair["competitor"] == competitor)
            for cat in categories:
                counts = df_pair.loc[pair_mask & (df_pair["category"] == cat), "result"].value_counts()
                total = counts.sum()
                if total == 0:
                    # No judgments for this combination: report NaN explicitly
                    # instead of performing a 0/0 division.
                    lossrate = winrate = tierate = bbrate = winrate_adjusted = float("nan")
                else:
                    wincnt = counts.get("win", 0)
                    tiecnt = counts.get("tie", 0)
                    bbcnt = counts.get("both_bad", 0)
                    losscnt = counts.get("loss", 0)
                    winrate = wincnt / total
                    lossrate = losscnt / total
                    tierate = tiecnt / total
                    bbrate = bbcnt / total
                    winrate_adjusted = (wincnt + tiecnt + bbcnt) / total

                scores_all.append(
                    {
                        "model": model,
                        "competitor": competitor,
                        "category": cat,
                        "Loss rate": lossrate,
                        "Win rate": winrate,
                        "Tie rate": tierate,
                        "Both Bad rate": bbrate,
                        "wtrate": winrate_adjusted,
                    }
                )

    return scores_all


def get_rename_map():
    """Return a mapping from raw checkpoint/model identifiers to display names.

    Note that two different pre-train checkpoints both map to ``"pre-train"``.

    Returns:
        A new dict on every call; callers may mutate it freely.
    """
    return {
        "7B-24EU-1T-pre-train-iter_0236250_trfs": "24EU-1T-pre-train",
        "7B-24EU-1T-bactrianx-DE-checkpoint-497": "24EU-1T-Bactrian-X-DE",
        "7B-24EU-1T-limax-ENDEFRITES-checkpoint-20": "24EU-1T-Mulima-ENDEFRITES",
        "7B-24EU-1T-limax-sampled-checkpoint-16": "24EU-1T-Mulima-sampled",
        "7B-24EU-1T-limax-EN-checkpoint-16": "24EU-1T-Mulima-EN",
        "7B-24EU-1T-limax-DE-checkpoint-16": "24EU-1T-Mulima-DE",
        "7B-24EU-1T-bactrianx-EN-checkpoint-497": "24EU-1T-Bactrian-X-EN",
        "7B-24EU-1T-bactrianx-ENDEFRITES-checkpoint-1421": "24EU-1T-Bactrian-X-ENDEFRITES",
        "7B-24EU-1T-bactrianx-ES-24EU-1T-bactrianx-ES": "24EU-1T-Bactrian-X-ES",
        "7B-24EU-1T-bactrianx-FR-checkpoint-497": "24EU-1T-Bactrian-X-FR",
        "7B-24EU-1T-bactrianx-IT-checkpoint-284": "24EU-1T-Bactrian-X-IT",
        "7B-24EU-1T-bactrianx-sampled-checkpoint-248": "24EU-1T-Bactrian-X-sampled",
        "7B-24EU-1T-limax-ES-checkpoint-16": "24EU-1T-Mulima-ES",
        "7B-24EU-1T-limax-FR-checkpoint-16": "24EU-1T-Mulima-FR",
        "7B-24EU-1T-limax-IT-checkpoint-16": "24EU-1T-Mulima-IT",
        "7B-ENDEFRITES-sampled-checkpoint-12": "Mulima-sampled",
        "7B-ENDEFRITES-checkpoint-20": "Mulima-ENDEFRITES",
        "7B-DE-checkpoint-16": "Mulima-DE",
        "7B-ES-checkpoint-20": "Mulima-ES",
        "7B-IT-checkpoint-20": "Mulima-IT",
        "7B-FR-checkpoint-16": "Mulima-FR",
        "7B-EN-checkpoint-12": "Mulima-EN",
        "7B-DE-token-1024-pre-train": "pre-train",
        "7B-bactrianx-ENDEFRITES-sampled-checkpoint-497": "Bactrian-X-sampled",
        "7B-pre-train-7B_ENDEFRITES_iter_0047683_trfs": "pre-train",
        "7B-bactrianx-DE-checkpoint-497": "Bactrian-X-DE",
        "7B-bactrianx-EN-checkpoint-497": "Bactrian-X-EN",
        "7B-bactrianx-ENDEFRITES-checkpoint-1243": "Bactrian-X-ENDEFRITES",
        "7B-bactrianx-FR-checkpoint-497": "Bactrian-X-FR",
        "7B-bactrianx-ES-checkpoint-497": "Bactrian-X-ES",
        "7B-bactrianx-IT-checkpoint-497": "Bactrian-X-IT",
        "7B-mulimax-ENDEFRITES-sampled-checkpoint-12": "Mulima-X-sampled",
        "7B-mulimax-ENDEFRITES-checkpoint-20": "Mulima-X-ENDEFRITES",
        "7B-mulimax-DE-checkpoint-12": "Mulima-X-DE",
        "7B-mulimax-ES-checkpoint-12": "Mulima-X-ES",
        "7B-mulimax-IT-checkpoint-12": "Mulima-X-IT",
        "7B-mulimax-FR-checkpoint-12": "Mulima-X-FR",
        "7B-mulimax-EN-checkpoint-8": "Mulima-X-EN",
        "7B-mistral-DE-checkpoint-8": "Mistral-DE",
        "7B-mistral-EN-checkpoint-4": "Mistral-EN",
        "7B-mistral-ENDEFRITES-sampled-checkpoint-8": "Mistral-sampled",
        # Display label was misspelled "Mitral-pre-train"; aligned with the other Mistral-* labels.
        "7B-mistral-pre-trained-Mistral-7B-v0.1": "Mistral-pre-train",
    }


def render(df_score, save_path: str = None, is_v2: bool = False):
    """Plot per-category outcome rates for every (model, competitor) pair.

    One stacked bar chart is built per (model, competitor) group of
    ``df_score``. When ``save_path`` is given, each chart is written as a PDF
    into that directory, which is created if missing.

    Args:
        df_score: Scores with ``model``, ``competitor``, ``category`` and the
            rate columns produced by ``collect_scores``.
        save_path: Output directory; ``None`` builds the figures without
            saving them.
        is_v2: Also plot "Both Bad rate", use the ``_v2`` file suffix and the
            v2 legend offset.

    Returns:
        A Series with the mean of each numeric column, averaged first within
        each (model, competitor, category) group and then across groups.
    """
    # Shorten model names for display.
    df_score = df_score.replace({"24EU-1T-": "", "Bactrian-X-": "Bactrian-"}, regex=True)
    mean_rates = df_score.groupby(["model", "competitor", "category"]).mean().mean()

    # this is a hack to not show "Loading [MathJax]/extensions/MathMenu.js" in the corner of the PDF
    import plotly.io as pio

    pio.kaleido.scope.mathjax = None

    out_dir = None if save_path is None else Path(save_path)
    if out_dir is not None:
        out_dir.mkdir(parents=True, exist_ok=True)

    bar_columns = ["Loss rate", "Tie rate"] + (["Both Bad rate"] if is_v2 else []) + ["Win rate"]
    suffix = "_v2" if is_v2 else ""

    for (model, competitor), data in df_score.groupby(["model", "competitor"]):
        fig = px.bar(
            data,
            x="category",
            y=bar_columns,
            labels={"value": "", "variable": model, "category": ""},
            color_discrete_map={
                "Win rate": "#32a852",
                "Tie rate": "#3281a8",
                "Loss rate": "#9e2020",
                "Both Bad rate": "#5f5b5b",
            },
            text_auto=True,
            category_orders={"category": CATEGORIES},
        )
        fig.update_traces(texttemplate="%{value:.2f}", textposition="inside")
        if out_dir is None:
            continue
        fig.update_layout(
            legend=dict(
                orientation="h",  # horizontal legend
                # near the left edge; v2 is shifted slightly further left
                # (presumably to fit the extra "Both Bad rate" entry)
                x=-0.05 if is_v2 else 0,
                y=-0.05,  # just below the plotting area
                bgcolor="rgba(255, 255, 255, 0)",  # transparent background
                bordercolor="rgba(255, 255, 255, 0)",  # transparent border
            ),
            xaxis_tickangle=0,
        )
        fig.update_traces(showlegend=True)
        fig.write_image(out_dir / f"{out_dir.name}_bar_plot_{model}_vs_{competitor}{suffix}.pdf", scale=2)

    # Returned only after *all* pairs are rendered (previously the return was
    # inside the loop, so only the first pair was ever plotted).
    return mean_rates


def as_table(scores_all):
    """Pivot per-category score records into a wide table.

    Model identifiers are replaced by their display names from
    ``get_rename_map()``, and the ``wtrate`` column is dropped.

    Args:
        scores_all: List of score dicts as returned by ``collect_scores``.

    Returns:
        A DataFrame indexed by category whose columns form a MultiIndex
        ``(model, competitor, metric)``, with metric being one of
        "Win rate", "Loss rate", "Tie rate", "Both Bad rate". Columns that
        contain any NaN (e.g. combinations without judgments) are dropped.
    """
    df = pd.DataFrame.from_records(scores_all)
    # Single exact-value replacement; no display name is itself a key, so
    # this matches applying the entries one at a time.
    df = df.replace(get_rename_map())
    df = df.drop(columns=["wtrate"])

    piv = df.pivot(
        index="category", columns=["model", "competitor"], values=["Win rate", "Loss rate", "Tie rate", "Both Bad rate"]
    ).dropna(axis=1)
    # (metric, model, competitor) -> (model, competitor, metric)
    piv.columns = piv.columns.reorder_levels(order=[1, 2, 0])
    return piv


def scores_to_pandas(scores_all: List[Dict], target_models: List[str]) -> pd.DataFrame:
    """Select the scores between target models and return them as a DataFrame.

    Only records whose ``model`` and ``competitor`` are both in
    ``target_models`` are kept. They are stably sorted in descending order of
    the model's position in ``target_models``. Raw identifiers are then
    replaced by their display names from ``get_rename_map()``.

    Args:
        scores_all: Score dicts as returned by ``collect_scores``.
        target_models: Raw model identifiers to keep (also defines the order).

    Returns:
        The selected scores; an empty DataFrame when nothing matches.
    """
    wanted = set(target_models)
    scores_target = [s for s in scores_all if s["model"] in wanted and s["competitor"] in wanted]

    # sort by target_models (Python's sort stays stable with reverse=True)
    scores_target.sort(key=lambda s: target_models.index(s["model"]), reverse=True)

    # No re-filtering needed: every kept record already has a target model,
    # and an empty frame (no "model" column) must not raise here.
    df_score = pd.DataFrame(scores_target)
    return df_score.replace(get_rename_map())


def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` of length ``n``; the last may be shorter.

    ``iterable`` must support ``len()`` and slicing (list, str, tuple, ...).
    """
    total = len(iterable)
    # Slicing clamps the stop index, so the final chunk is simply shorter.
    for start in range(0, total, n):
        yield iterable[start:start + n]


def _load_human_judgements(path):
    """Read human pairwise judgements (JSONL) into the schema ``collect_scores`` expects."""
    df = pd.read_json(path, lines=True, orient="records")
    df = df.rename(columns={"question_id": "qid", "winner": "result", "model_a": "model", "model_b": "competitor"})
    # .copy() so the column assignments below act on an independent frame.
    df = df[["qid", "turn", "result", "model", "competitor", "category"]].copy()
    df["result"] = df["result"].replace({"B": "loss", "A": "win", "Tie (both bad)": "both_bad", "Tie": "tie"})

    # Rows whose model_a is the first row's competitor are flipped (model and
    # competitor swapped, win/loss inverted). NOTE(review): this orients every
    # row the same way only if the file holds a single model pair.
    competitor = df.iloc[0]["competitor"]
    model_switched = df.loc[df["model"] == competitor].rename(columns={"model": "competitor", "competitor": "model"})
    winner_switched = model_switched.replace({"win": "loss", "loss": "win"})
    df.update(winner_switched)

    # Category names in the file are lower-case; map them onto CATEGORIES.
    df["category"] = df["category"].replace(dict(zip(map(str.lower, CATEGORIES), CATEGORIES)))
    return df


def human_scores(
    model_list=None,
    bench_name=None,
    judgements_path="human_eval/logs/24EU_bactrianx_pair_wise_copy.jsonl",
):
    """Score and plot the human pairwise judgements.

    Args:
        model_list: Raw model identifiers to keep. Defaults to the
            module-level ``args.model_list`` set when run as a script.
        bench_name: Benchmark name used in the output directory. Defaults to
            ``args.bench_name``.
        judgements_path: JSONL file with the human judgements.

    Returns:
        The mean rates returned by ``render``.
    """
    # `args` only exists when this file runs as __main__; the fallbacks keep
    # the original no-argument call working there.
    if model_list is None:
        model_list = args.model_list
    if bench_name is None:
        bench_name = args.bench_name

    df = _load_human_judgements(judgements_path)
    scores_all = collect_scores(df_pair=df)
    df_score = scores_to_pandas(scores_all=scores_all, target_models=model_list)
    return render(
        df_score=df_score,
        save_path=f"data/{bench_name}/human_eval_{bench_name}/",
        is_v2=True,
    )


def majority_vote(df: pd.DataFrame):
    """Return one representative row per most frequent ``result`` value.

    For every ``result`` value tied for the highest count, the first row
    (in ``df`` order) carrying it is returned. Rows are ordered by the sorted
    result value, which is the order ``np.unique`` produces.

    Args:
        df: Frame with a ``result`` column.

    Returns:
        A subset of ``df``'s rows; ``df`` itself when it is empty.
    """
    if df.empty:
        return df
    # `counts` is aligned with the sorted unique values, not with the rows of
    # `df`; return_index supplies each value's first row position in `df`.
    _, first_pos, counts = np.unique(df.result, return_index=True, return_counts=True)
    return df.iloc[first_pos[counts == counts.max()]]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bench-name", type=str, required=True)
    parser.add_argument("--model-list", type=str, nargs="+", required=True)
    # Kept at module level: human_scores() falls back to this global `args`.
    args = parser.parse_args()

    base_path = Path(f"data/{args.bench_name}/model_judgment/")
    # sorted() makes the run order (and thus float summation order) deterministic.
    run_dirs = sorted(base_path.glob("run*"))
    if not run_dirs:
        raise SystemExit(f"No run* directories found under {base_path}")

    mean_rates = [human_scores()]

    for version_postfix in ["", "_v2"]:
        all_runs = []
        for run_dir in run_dirs:
            path = run_dir / f"gpt-4_pair{version_postfix}.jsonl"
            df = get_model_df_pair(path)
            scores_all = collect_scores(df_pair=df)
            # NOTE(review): the table is not used; the call only exercises the pivot.
            as_table(scores_all)
            all_runs.append(scores_to_pandas(scores_all=scores_all, target_models=args.model_list))

        # Average each (model, competitor, category) across runs.
        merge = pd.concat(all_runs).groupby(["model", "competitor", "category"], as_index=False).mean()
        mean_rates.append(
            render(
                df_score=merge,
                save_path=f"data/{args.bench_name}",
                is_v2=version_postfix != "",
            )
        )

    # keys= labels columns by source; pandas drops None entries together with their key.
    df = pd.concat(mean_rates, axis=1, keys=["human", "gpt4", "gpt4bb"])
    print(df)