import argparse
import json
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
from scipy.stats import ttest_ind
from fastchat.llm_judge.visualisation.mt_bench_single_score_catplot import get_colors

from fastchat.llm_judge.visualisation.render_win_tie_bar_chart import get_rename_map

CATEGORIES = ["Writing", "Roleplay", "Reasoning", "Math", "Coding", "Extraction", "STEM", "Humanities"]


def get_model_df(path):
    """Load single-answer judgments from a JSON-lines file into a DataFrame.

    Every non-blank line is parsed as a JSON object and tagged with a
    ``category``: question ids 81-90 map to ``CATEGORIES[0]``, 91-100 to
    ``CATEGORIES[1]``, and so on in blocks of ten. String ids such as
    ``"81_1"`` use the integer prefix before the first underscore.

    NOTE(review): ids outside 81-160 are not validated (ids below 81 wrap to a
    negative list index) — confirm that only MT-Bench ids are ever passed.

    Args:
        path: Path of the ``.jsonl`` judgment file.

    Returns:
        A DataFrame with one row per judgment: all JSON fields plus ``category``.
    """
    records = []
    # ``with`` guarantees the handle is closed; the original leaked it.
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing empty line), which json.loads rejects.
            if not line.strip():
                continue
            obj = json.loads(line)
            question_id = obj["question_id"]
            numeric_question_id = int(question_id.split("_")[0]) if isinstance(question_id, str) else question_id
            obj["category"] = CATEGORIES[(numeric_question_id - 81) // 10]
            records.append(obj)
    return pd.DataFrame(records)


def toggle(res_str):
    if res_str == "win":
        return "loss"
    elif res_str == "loss":
        return "win"
    return "tie"


def get_model_df_pair(path="gpt-4_pair.jsonl"):
    """Load pairwise GPT-4 judgments and reduce each to an outcome for ``model_1``.

    A judgment counts as ``"win"`` only if both judge games (``g1_winner`` and
    ``g2_winner``) pick ``model_1``, as ``"loss"`` only if both pick
    ``model_2``, and as ``"tie"`` otherwise. The category is derived from the
    question id exactly as in ``get_model_df`` (blocks of ten starting at 81;
    string ids like ``"81_1"`` use their integer prefix).

    Args:
        path: JSON-lines file with pairwise judgments. Defaults to the
            previously hard-coded ``"gpt-4_pair.jsonl"``.

    Returns:
        A DataFrame with columns ``qid``, ``turn``, ``result``, ``category`` and ``model``.
    """
    records = []
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            obj = json.loads(line)
            question_id = obj["question_id"]
            # String ids crashed the arithmetic below; parse them like get_model_df does.
            numeric_question_id = int(question_id.split("_")[0]) if isinstance(question_id, str) else question_id

            first, second = obj["g1_winner"], obj["g2_winner"]
            if first == second == "model_1":
                outcome = "win"
            elif first == second == "model_2":
                outcome = "loss"
            else:
                outcome = "tie"

            records.append(
                {
                    "qid": str(question_id),
                    "turn": str(obj["turn"]),
                    "result": outcome,
                    "category": CATEGORIES[(numeric_question_id - 81) // 10],
                    "model": obj["model_1"],
                }
            )

    return pd.DataFrame(records)


def collect_scores(df: pd.DataFrame):
    """Aggregate per-judgment scores into one row per (model, category).

    Rows with a negative ``score`` (score-format errors from the judge) are
    ignored. Every model is paired with every entry of ``CATEGORIES``, even if a
    pair has no valid judgments (its mean/std are then NaN). Models are listed
    in reverse order of their first appearance in ``df``; within a model,
    categories follow ``CATEGORIES``.

    Args:
        df: Judgments with at least ``model``, ``category`` and ``score`` columns.

    Returns:
        A DataFrame with columns ``model``, ``category``, ``score`` (mean),
        ``std`` (sample standard deviation) and ``val`` (a copy of ``score``).
        Empty input yields an empty frame with these columns.
    """
    models = df["model"].unique().tolist()
    # Drop score-format errors (<1% of judgments) once, instead of per cell.
    valid = df[df["score"] >= 0]

    rows = []
    for model in reversed(models):
        for cat in CATEGORIES:
            cell = valid.loc[(valid["category"] == cat) & (valid["model"] == model), "score"]
            rows.append({"model": model, "category": cat, "score": cell.mean(), "std": cell.std()})

    # Explicit columns keep the schema (and the "val" assignment) valid for empty input.
    df_score = pd.DataFrame(rows, columns=["model", "category", "score", "std"])
    df_score["val"] = df_score["score"]
    return df_score

def convert_color(color: tuple):
    """Turn a colour with 0-1 float channels into a Plotly ``rgb(...)`` string.

    Each channel is scaled by 255 and truncated to ``int``, e.g.
    ``(1.0, 0.5, 0.0)`` -> ``"rgb(255, 127, 0)"``.
    """
    channels = tuple(int(component * 255) for component in color)
    return "rgb" + str(channels)

def render(df: pd.DataFrame, name: str, save_path: str | None = None):
    """Draw a per-category radar (polar line) plot of model scores, optionally saving it as PDF.

    Model names are shortened for the legend. The "24EU-1T-" prefix is dropped,
    "pre-train" becomes "24EU-1T-pre-train", "Mulima" becomes "Lima" and
    "Bactrian-X" becomes "Bactrian". Colours come from ``get_colors()``, keyed by
    language code. The caller's DataFrame is NOT modified; previously it was,
    which broke the family averages computed by ``create_table`` afterwards.

    When ``save_path`` is given, the figure is styled for the paper:
    Bactrian-X models are drawn solid, Mulima models dashed and all others
    dotted. Traces are ordered solid -> dash -> dot. The figure is written to
    ``<save_path>/<basename of save_path>_radar_plot.pdf``.

    Args:
        df: Long-format scores with at least ``model``, ``category`` and ``score`` columns.
        name: Benchmark display name (currently unused; the plot title is disabled).
        save_path: Output directory. If ``None``, the figure is built but not saved.
    """
    df = df.copy()
    original_names = df["model"]
    short_names = (
        original_names.str.replace("24EU-1T-", "")
        .str.replace("pre-train", "24EU-1T-pre-train")
        .str.replace("Mulima", "Lima")
        .str.replace("Bactrian-X", "Bactrian")
    )
    # The model family must be read from the ORIGINAL names: after shortening, neither
    # "Mulima" nor "Bactrian-X" occurs anymore, so the old lookup always came back empty.
    bactrian_models = set(short_names[original_names.str.contains("Bactrian-X", case=False, na=False)])
    mulima_models = set(short_names[original_names.str.contains("Mulima", case=False, na=False)])
    df["model"] = short_names

    target_colors = get_colors()
    colors = {
        f"{model_name}-{lang_code}": convert_color(target_colors[lang_code])
        for model_name in ["Lima", "Bactrian"]
        for lang_code in "EN DE FR IT ES sampled ENDEFRITES".split()
    }
    colors["24EU-1T-pre-train"] = convert_color(target_colors["pre-train"])

    fig = px.line_polar(
        df,
        r="score",
        theta="category",
        line_close=True,
        category_orders={"category": CATEGORIES},
        color="model",
        markers=True,
        color_discrete_map=colors,
    )

    if save_path is None:
        return

    # Hack: prevents "Loading [MathJax]/extensions/MathMenu.js" from showing in the corner of the PDF.
    import plotly.io as pio

    pio.kaleido.scope.mathjax = None
    fig.update_layout(
        height=460,
        polar=dict(
            radialaxis=dict(tickvals=list(range(1, 7)), range=[0, 7]),  # should go to 10 for fair visual comparison
        ),
        font=dict(
            size=16.5,
        ),
        autosize=False,
        margin=dict(
            l=50,
            r=30,
            b=0,
            t=0,
        ),
        legend=dict(
            y=0.11,
            bgcolor="rgba(255, 255, 255, 0)",
            bordercolor="rgba(255, 255, 255, 0)",
            title="",  # no legend title
        ),
    )
    fig.update_traces(
        showlegend=True,
        line=dict(width=2.5),
        marker=dict(size=6, symbol="diamond"),
    )

    # Line style per model family (px names each trace after its "model" value).
    for trace in fig.data:
        if trace.name in bactrian_models:
            trace.line.dash = "solid"
        elif trace.name in mulima_models:
            trace.line.dash = "dash"
        else:
            trace.line.dash = "dot"

    # Group traces (and thus legend entries) by line style; sorted() is stable,
    # so the plotting order within a style is preserved.
    dash_order = ["solid", "dash", "dot", "dashdot"]
    fig.data = sorted(fig.data, key=lambda trace: dash_order.index(trace.line.dash))

    save_file_path = Path(save_path) / f"{Path(save_path).name}_radar_plot.pdf"
    fig.write_image(save_file_path, scale=2)
    print(f"Saved radar plot to {save_file_path}")


def create_table_with_std(df_score, name):
    columns = []
    for cat in df_score.category.unique():
        columns.append((cat, "val"))
        columns.append((cat, "std"))
    df = pd.DataFrame(index=df_score.model.unique(), columns=pd.MultiIndex.from_tuples(columns))

    for row in df.index.tolist():
        for category, value in df.columns.tolist():
            val = df_score.where((df_score.model == row) & (df_score.category == category)).dropna().iloc[0][value]
            df.at[row, (category, value)] = val
    df.index.name = name
    df["avg."] = df.loc[:, [v == "val" for cat, v in df.columns]].mean(axis=1)
    df = df.astype(float).round(2)
    df = df.sort_index()

    for row in df.index.tolist():
        for category, value in df.columns.tolist():
            if value == "std":
                df.at[row, (category, value)] = f"±{df.loc[row, (category, value)]}"

    # df = df.rename(
    #     columns={
    #         "Writing": "Wr.",
    #         "Roleplay": "Role.",
    #         "Reasoning": "Reas.",
    #         "Math": "Math",
    #         "Coding": "Code",
    #         "Extraction": "Extr.",
    #         "STEM": "STEM",
    #         "Humanities": "Hum.",
    #     }
    # )
    print()
    df.index = list(map("{: <30}".format, df.index))
    df.columns = ["" if v == "std" else cat for (cat, v) in df.columns]
    df.columns = df.columns.set_names(names=name)
    print(df.iloc[:, :8].to_latex(formatters=["{: <5}".format] * len(df.columns)))
    print(df.iloc[:, 8:].to_latex(formatters=["{: <5}".format] * len(df.columns)))


def create_table(df_score, name):
    df = df_score.drop(columns=["std", "score"])
    df = df.pivot(index="model", columns="category")

    df.index.name = name
    df.columns = df.columns.droplevel(level=0)
    df["Avg."] = df.mean(axis=1)
    # mulimax_mean = df[df.index.str.contains("Mulima-X")].mean(axis=0)
    bactrianx_mean = df[df.index.str.contains("Bactrian-X")].mean(axis=0)
    mulima_mean = df[df.index.str.contains("Mulima") & ~df.index.str.contains("Mulima-X")].mean(axis=0)
    # mistral_mean = df[df.index.str.contains("Mistral")].mean(axis=0)
    # df = pd.concat([df, pd.DataFrame([mulimax_mean], index=["Mulima-X-Avg."], columns=df.columns)], axis=0)
    df = pd.concat([df, pd.DataFrame([bactrianx_mean], index=["Bactrian-X-Avg."], columns=df.columns)], axis=0)
    df = pd.concat([df, pd.DataFrame([mulima_mean], index=["Mulima-Avg."], columns=df.columns)], axis=0)
    # df = pd.concat([df, pd.DataFrame([mistral_mean], index=["Mistral-Avg."], columns=df.columns)], axis=0)
    # resolve hack for calculating correct Mistral mean
    df.rename(columns={"Mitral-pre-trained": "Mistral-pre-trained"}, inplace=True)

    df = df.astype(float).round(2)
    df = df[["Writing", "Roleplay", "Reasoning", "Math", "Coding", "Extraction", "STEM", "Humanities", "Avg."]]

    df = df.rename(
        columns={
            "Writing": "Wr.",
            "Roleplay": "Role.",
            "Reasoning": "Reas.",
            "Math": "Math",
            "Coding": "Code",
            "Extraction": "Extr.",
            "STEM": "STEM",
            "Humanities": "Hum.",
        }
    )

    df.index = list(map("{: <30}".format, df.index))
    df.columns = df.columns.set_names(names=name)
    print(df.to_latex(formatters=["{: <5}".format] * len(df.columns)))


def render_error_bar_bars(df):
    """Show a grouped bar chart of mean score per category, one bar per model, with ±std error bars."""
    bar_chart = px.bar(
        df,
        x="category",
        y="score",
        error_y="std",
        color="model",
        color_discrete_sequence=px.colors.qualitative.Light24,
        barmode="group",
    )

    # Categories with the largest total score come first on the x-axis.
    bar_chart.update_layout(xaxis=dict(categoryorder="total descending"))
    bar_chart.show()


def filter_and_rename(df: pd.DataFrame, target_models: List[str]) -> pd.DataFrame:
    if len(target_models) > 0:
        df = df.loc[df["model"].isin(target_models)]
    rename_map = get_rename_map()

    for k, v in rename_map.items():
        df.replace(k, v, inplace=True)
    return df


if __name__ == "__main__":
    # CLI: radar plot + LaTeX table for one benchmark's GPT-4 single-answer judgments.
    cli = argparse.ArgumentParser()
    cli.add_argument("--bench-name", type=str, required=True)
    cli.add_argument("--model-list", type=str, nargs="+", default=[])
    cli_args = cli.parse_args()

    # Human-readable benchmark label, e.g. "mt_bench" -> "MT-Bench".
    pretty_name = cli_args.bench_name
    for old, new in (("_", "-"), ("mt", "MT"), ("bench", "Bench")):
        pretty_name = pretty_name.replace(old, new)

    judgments = get_model_df(f"data/{cli_args.bench_name}/model_judgment/gpt-4_single.jsonl")
    scores = filter_and_rename(df=collect_scores(df=judgments), target_models=cli_args.model_list)

    render(scores, save_path=f"data/{cli_args.bench_name}", name=pretty_name)
    create_table(df_score=scores, name=pretty_name)
