import json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go


# Category labels in question order: get_model_df() assigns one label per run
# of ten consecutive question ids, with id 81 starting the first label.
CATEGORIES = ["Writing", "Roleplay", "Reasoning", "Math", "Coding", "Extraction", "STEM", "Humanities"]


def get_model_df(path="./data/mt_bench/model_judgment/gpt-4_single.jsonl"):
    """Load judgment records from a JSON Lines file into a DataFrame.

    Each non-blank line is parsed as one JSON object and gains a
    ``"category"`` key derived from its ``"question_id"``.

    Args:
        path: Location of the ``.jsonl`` judgment file. Defaults to the
            single-answer judgment file this script has always read.

    Returns:
        A ``pandas.DataFrame`` with one row per record, holding every key
        found in the file plus the added ``"category"`` column. An empty
        file yields an empty DataFrame.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"question_id"``.
        IndexError: If a question id maps past the end of ``CATEGORIES``.
    """
    records = []
    # The context manager closes the handle even if a line fails to parse.
    # JSON text is UTF-8 by specification, so do not rely on the locale default.
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            obj = json.loads(line)
            # Ids 81-90 -> CATEGORIES[0], 91-100 -> CATEGORIES[1], and so on.
            obj["category"] = CATEGORIES[(obj["question_id"] - 81) // 10]
            records.append(obj)
    return pd.DataFrame(records)

# Load every judgment row once; all aggregates below are derived from this frame.
df = get_model_df()

all_models = df["model"].unique()
print(all_models)

# Rows with a negative score are score-format errors (<1% of cases) and are
# excluded from every average.
well_formed = df["score"] >= 0

# One record per (model, category) pair holding the mean of the usable scores.
# Models form the outer loop and categories the inner loop, so records for the
# same model are contiguous and follow CATEGORIES order.
scores_all = []
for model in all_models:
    model_rows = well_formed & (df["model"] == model)
    for cat in CATEGORIES:
        picked = df[model_rows & (df["category"] == cat)]
        scores_all.append({"model": model, "category": cat, "score": picked["score"].mean()})

# NOTE: a pairwise variant (win rate and win-or-tie rate per model/category,
# computed from a separate `df_pair` frame and stored under the extra keys
# "winrate" / "wtrate") was sketched here at one point; it is currently disabled.

# Named groups of models that can be drawn together on one chart. Keeping them
# in a table (instead of a stack of assignments that overwrite each other)
# makes the active choice explicit and keeps the alternatives available.
MODEL_PRESETS = {
    "llama-2-chat-ablation": ["llama-2-chat", "llama-2-chat-dpo", "llama-2-chat-sft", "llama-2-chat-sft-dpo"],
    "llama-2": ["llama-2", "llama-2-sft", "llama-2-sft-dpo"],
    "mistral": ["mistral", "mistral-sft", "mistral-sft-dpo"],
    "zephyr-ablation": ["zephyr", "zephyr-dpo", "zephyr-sft", "zephyr-sft-dpo"],
    "llama-2-chat-vs-larger": ["llama-2-chat", "llama-2-chat-dpo", "Llama-2-13b-chat", "Llama-2-70b-chat"],
    "llama-2-chat-and-zephyr-ablation": [
        "llama-2-chat", "llama-2-chat-dpo", "llama-2-chat-sft", "llama-2-chat-sft-dpo",
        "zephyr", "zephyr-dpo", "zephyr-sft", "zephyr-sft-dpo",
    ],
    "llama-2-chat-vs-zephyr-dpo": ["llama-2-chat", "llama-2-chat-dpo", "zephyr", "zephyr-dpo"],
    "llama-vs-mistral": ["llama-2", "llama-2-sft", "llama-2-sft-dpo", "mistral", "mistral-sft", "mistral-sft-dpo"],
}

# Change the key to chart a different comparison.
# To chart every model present in the data instead, use: all_models.tolist()
target_models = MODEL_PRESETS["llama-vs-mistral"]

# Keep only the per-category averages that belong to the selected models.
scores_target = [entry for entry in scores_all if entry["model"] in target_models]

# Fail with an actionable message rather than an opaque KeyError further down
# when none of the selected models exist in the judgment data.
if not scores_target:
    raise ValueError(
        f"None of the target models {target_models} appear in the judgment data; "
        f"available models: {list(all_models)}"
    )

# Order rows by position in target_models, last-listed model first. The sort is
# stable, so each model's rows keep their existing category order.
scores_target = sorted(scores_target, key=lambda x: target_models.index(x["model"]), reverse=True)

# scores_target is already restricted to target_models, so no further filtering
# of the frame is needed.
df_score = pd.DataFrame(scores_target)

# Raw model identifiers mapped to the display names used in the chart legend.
rename_map = {
    "llama-13b": "LLaMA-13B",
    "alpaca-13b": "Alpaca-13B",
    "vicuna-33b-v1.3": "Vicuna-33B",
    "vicuna-13b-v1.3": "Vicuna-13B",
    "gpt-3.5-turbo": "GPT-3.5-turbo",
    "claude-v1": "Claude-v1",
    "gpt-4": "GPT-4",
}

# Swap every raw identifier found anywhere in the frame for its display name.
df_score.replace(to_replace=rename_map, inplace=True)
# Optional CSV dump of the plotted data:
# df_score.to_csv("all_models.csv", sep=";", decimal=",")

# Radar chart: one closed, marker-decorated trace per model, with the spokes
# laid out in CATEGORIES order.
polar_options = dict(
    r="score",
    theta="category",
    color="model",
    line_close=True,
    markers=True,
    category_orders={"category": CATEGORIES},
    color_discrete_sequence=px.colors.qualitative.Pastel,
)
fig = px.line_polar(df_score, **polar_options)

fig.write_image(file="llama-vs-mistral_good.png", format="png")