"""
Compute agreement among judges.

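Usage:
python compute_agreement.py --mt-bench <benchmark> --votefiles human_judgments.json <run_dir>/gpt-4_pair_judgments.json
python compute_agreement.py --mt-bench <benchmark> --votefiles human_judgments.json <run_dir>/gpt-4_pair_judgments.json --model-list <model_1> <model_2>

<benchmark> selects the question file data/<benchmark>/question.jsonl. A vote
file is treated as GPT-4 judgments when its path contains "gpt-4"; the name of
its parent directory is used as the run name.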
"""
import argparse
from collections import Counter, OrderedDict, defaultdict
from itertools import chain, combinations
import json
import logging
import os
from pathlib import Path
from typing import Dict
import warnings
from irrCAC.raw import CAC

from nltk.metrics.agreement import AnnotationTask
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def get_judge_name(judge):
    """Map the raw ``judge`` field of a vote record to a canonical judge name.

    Args:
        judge: Either a list such as ``["gpt-4", "pair-v2"]`` or a string
            such as ``"expert..."`` / ``"author..."``.

    Returns:
        ``"gpt4-pair"`` for a list whose first item is ``"gpt-4"`` and whose
        second item starts with ``"pair"``, ``"human"`` for strings starting
        with ``"expert"``, ``"author"`` for strings starting with
        ``"author"``, and ``None`` for every other value.
    """
    if isinstance(judge, (list, tuple)):
        # Only the GPT-4 pairwise judge is recognised among list-style judges.
        # Any other list (too short, different model, different prompt) is
        # "unknown" rather than a crash on ``list.startswith``.
        if (
            len(judge) >= 2
            and judge[0] == "gpt-4"
            and isinstance(judge[1], str)
            and judge[1].startswith("pair")
        ):
            return "gpt4-pair"
        return None
    if not isinstance(judge, str):
        return None
    if judge.startswith("expert"):
        return "human"
    if judge.startswith("author"):
        return "author"
    return None


def revert(vote):
    """Return ``vote`` as seen after swapping the two compared models.

    Both label conventions are mirrored: ``"model_a"``/``"model_b"`` and the
    short ``"A"``/``"B"`` labels. Any other vote (for example a tie) is
    returned unchanged, because it does not depend on the model order.

    Args:
        vote: The winner label of a single vote.

    Returns:
        The mirrored label, or ``vote`` itself when it names no side.
    """
    mirrored = (
        ("model_a", "model_b"),
        ("model_b", "model_a"),
        ("A", "B"),
        ("B", "A"),
    )
    for label, opposite in mirrored:
        if vote == label:
            return opposite
    return vote


def get_mt_bench_votes_data(raw_votes, model_list, questions):
    """Group votes by turn, question/model pair and judge, tagged by category.

    Args:
        raw_votes: Iterable of vote lists; every vote is a dict with the keys
            ``model_a``, ``model_b``, ``turn``, ``question_id``, ``winner``
            and ``judge``.
        model_list: When non-empty, only votes whose two models are both in
            this collection are kept.
        questions: DataFrame with ``question_id`` and ``category`` columns.

    Returns:
        A two-element list (index = ``turn - 1``). Each element maps
        ``(question_id, model_x, model_y, category)`` -- with the two model
        names in ascending order -- to ``{judge_name: [converted votes]}``.
    """
    per_turn = [{}, {}]

    for judge_votes in raw_votes:
        for vote in judge_votes:
            if len(model_list) > 0:
                both_listed = vote["model_a"] in model_list and vote["model_b"] in model_list
                if not both_listed:
                    continue
            turn_idx = vote["turn"] - 1
            # Normalise the pair so that the smaller model name comes first;
            # the winner is passed through revert() when the order is swapped.
            if vote["model_a"] < vote["model_b"]:
                pair_key = (vote["question_id"], vote["model_a"], vote["model_b"])
                winner = vote["winner"]
            else:
                pair_key = (vote["question_id"], vote["model_b"], vote["model_a"])
                winner = revert(vote["winner"])
            judge = get_judge_name(vote["judge"])
            by_judge = per_turn[turn_idx].setdefault(pair_key, {})
            by_judge.setdefault(judge, []).append(convertvote(winner))

    with_category = [{}, {}]
    for turn_idx, turn_votes in enumerate(per_turn):
        for (q_id, model_x, model_y), by_judge in turn_votes.items():
            matching = questions.loc[questions["question_id"] == q_id]
            category = matching["category"].to_list()[0]
            with_category[turn_idx][(q_id, model_x, model_y, category)] = by_judge
    return with_category


def convertvote(vote):
    """Normalise tie-like votes to a canonical lower-case label.

    A vote containing ``"both bad"`` (case-insensitive) becomes
    ``"both bad"``; otherwise a vote containing ``"tie"`` becomes ``"tie"``.
    Every other vote is returned unchanged.
    """
    lowered = vote.lower()
    # "both bad" is checked first so that it wins over "tie".
    for label in ("both bad", "tie"):
        if label in lowered:
            return label
    return vote


def equalvote(vote1, vote2):
    """Return 1 when the two votes compare equal, otherwise 0."""
    return 1 if vote1 == vote2 else 0


def _pairwise_self_agreement(data, judge_key, ban, cat_bans):
    """Count agreeing pairs among all votes stored under ``judge_key``."""
    agree, total = 0, 0
    for (_qid, _model1, _model2, cat), votes in data.items():
        if cat in cat_bans or judge_key not in votes:
            continue
        # Every unordered pair (i < j) of votes is compared exactly once.
        for first, second in combinations(votes[judge_key], 2):
            if convertvote(first) in ban or convertvote(second) in ban:
                continue
            total += 1
            agree += equalvote(first, second)
    return agree, total


# data: Dict[(qid, model1, model2, category) -> Dict[judge -> List[vote]]]
def get_mt_bench_agreement(data, judge1, judge2, ban, cat_bans):
    """Count agreeing votes between two kinds of judges.

    Supported combinations, checked in this order:
      * ``judge1`` starting with ``"gpt4"`` and ``judge2`` containing
        ``"human"``: the single vote of ``judge1`` against every human vote,
        or against the human majority vote(s) when ``judge2`` contains
        ``"majority"``.
      * ``"human"`` / ``"human"``: every pair of human votes.
      * ``"human"`` / ``"human-majority"``: every human vote against the
        human majority vote(s).
      * both starting with ``"gpt4"``: every pair of ``"gpt4-pair"`` votes.

    Args:
        data: Mapping described in the comment above the function.
        judge1: Name of the first judge.
        judge2: Name of the second judge.
        ban: Vote labels (after ``convertvote``) to leave out of the count.
        cat_bans: Categories whose entries are skipped entirely.

    Returns:
        ``(agree, total)``. ``agree`` is weighted by the score returned from
        ``get_majority_votes`` in the majority modes.

    Raises:
        Exception: If the judge combination is not supported.
    """
    if judge1.startswith("gpt4") and "human" in judge2:
        agree, total = 0, 0
        against_majority = "majority" in judge2
        for (_qid, _model1, _model2, cat), votes in data.items():
            if cat in cat_bans or judge1 not in votes or "human" not in votes:
                continue
            assert len(votes[judge1]) == 1
            machine_vote = votes[judge1][0]
            if convertvote(machine_vote) in ban:
                continue

            if against_majority:
                human_votes, score = get_majority_votes(votes["human"])
            else:
                human_votes, score = votes["human"], 1

            for human_vote in human_votes:
                if convertvote(human_vote) in ban:
                    continue
                total += 1
                agree += score * equalvote(machine_vote, human_vote)
        return agree, total

    if judge1 == "human" and judge2 == "human":
        return _pairwise_self_agreement(data, "human", ban, cat_bans)

    if judge1 == "human" and judge2 == "human-majority":
        agree, total = 0, 0
        for (_qid, _model1, _model2, cat), votes in data.items():
            if cat in cat_bans or judge1 not in votes:
                continue

            majority_votes, score = get_majority_votes(votes["human"])

            for majority_vote in majority_votes:
                for human_vote in votes["human"]:
                    if convertvote(human_vote) in ban or convertvote(majority_vote) in ban:
                        continue
                    total += 1
                    agree += score * equalvote(majority_vote, human_vote)
        return agree, total

    if judge1.startswith("gpt4") and judge2.startswith("gpt4"):
        return _pairwise_self_agreement(data, "gpt4-pair", ban, cat_bans)

    raise Exception("Unsupported judges.")

def get_majority_votes(votes):
    """Return the most frequent vote(s) and the weight of each of them.

    Args:
        votes: Sequence of vote labels.

    Returns:
        ``(winners, weight)`` where ``winners`` lists every label that reaches
        the highest count (in the sorted order produced by ``np.unique``) and
        ``weight`` is ``1 / len(winners)``, so tied labels share one vote.
    """
    labels, frequencies = np.unique(votes, return_counts=True)
    is_top = frequencies == frequencies.max()
    winners = labels[is_top].tolist()
    return winners, 1 / len(winners)

def filter_by_min_votes_and_trim_tail(data: Dict[tuple, Dict[str, list]], num_votes_humans: int, num_votes_machine: int, human_name="human", machine_name="gpt4-pair"):
    """Keep entries with enough votes and cut every vote list to a fixed size.

    An entry is kept only when it has votes from both ``human_name`` and
    ``machine_name`` and at least ``num_votes_humans`` / ``num_votes_machine``
    of them respectively. The kept vote lists are truncated to exactly those
    sizes (the first votes are retained).

    Args:
        data: Mapping ``(qid, model1, model2, category) -> {judge: votes}``.
        num_votes_humans: Required (and retained) number of human votes.
        num_votes_machine: Required (and retained) number of machine votes.
        human_name: Judge key of the human votes.
        machine_name: Judge key (matched as a substring) of the machine votes.

    Returns:
        A new dict with the filtered, trimmed entries.

    Raises:
        ValueError: If a kept entry contains a judge that is neither
            ``human_name`` nor contains ``machine_name``.
    """
    kept = {}
    # key <-- (qid, model1, model2, category)
    for key, judge_vote_mapping in data.items():
        if human_name not in judge_vote_mapping.keys() or machine_name not in judge_vote_mapping.keys():
            continue
        too_few_votes = (
            len(judge_vote_mapping[human_name]) < num_votes_humans
            or len(judge_vote_mapping[machine_name]) < num_votes_machine
        )
        if too_few_votes:
            continue

        trimmed = {}
        for judge, votes in judge_vote_mapping.items():
            if judge == human_name:
                limit = num_votes_humans
            elif machine_name in judge:
                limit = num_votes_machine
            else:
                raise ValueError("Unkown judge!")
            trimmed[judge] = votes[:limit]
        kept[key] = trimmed
    return kept


def run_mt_bench_agreement(votefiles, model_list, questions):
    """Load vote files, compute judge agreement and print LaTeX tables.

    Steps:
      1. Read every file in ``votefiles`` as JSON lines. Files whose path
         contains ``"gpt-4"`` are converted with ``transform_gpt_data_schema``
         and counted as one run (run name = parent directory name).
      2. Print position-bias statistics via ``inspect_pos_bias``.
      3. For each minimum number of human votes (2, 3, 4; one machine vote),
         filter the data, compute agreement for several judge pairs, vote
         bans and category selections, and print the resulting table together
         with Fleiss' kappa of the human votes per turn.

    Args:
        votefiles: Paths of the JSON-lines vote files.
        model_list: Optional list of model names to restrict the votes to.
        questions: DataFrame with ``question_id`` and ``category`` columns.

    Returns:
        None. All results are printed.
    """
    # votes[i]: list of vote records loaded from votefiles[i]
    votes = []
    pos_bias = []
    run = ""
    num_runs = 0
    for filename in votefiles:
        data = []
        with open(filename, "r") as f:
            for line in f:
                # Lines keep their trailing newline, so a blank line is the
                # truthy string "\n"; strip before testing so it is skipped
                # instead of being handed to json.loads.
                if line.strip():
                    data.append(json.loads(line))
        if "gpt-4" in filename:
            num_runs += 1
            run = Path(filename).parent.name
            data, pos_bias = transform_gpt_data_schema(data, pos_bias, run)
        votes.append(data)
    
    inspect_pos_bias(pos_bias, num_runs, questions)

    data = get_mt_bench_votes_data(votes, model_list, questions)

    for num_min_votes_human, num_min_votes_machine in [(2,1), (3,1), (4,1)]:
        if num_min_votes_human > 0 and num_min_votes_machine > 0:
            data_turn_1 = filter_by_min_votes_and_trim_tail(data[0], num_votes_humans=num_min_votes_human, num_votes_machine=num_min_votes_machine)
            data_turn_2 = filter_by_min_votes_and_trim_tail(data[1], num_votes_humans=num_min_votes_human, num_votes_machine=num_min_votes_machine)
        else:
            num_min_votes_human, num_min_votes_machine = "varying", "varying"
            data_turn_1 = data[0]
            data_turn_2 = data[1]

        categories = ["coding", "math", "reasoning", "humanities", "stem", "roleplay", "writing", "extraction"]
        # Category selections: no ban (all categories), a fixed three-category
        # ban, and every ban of 7 categories (i.e. each category on its own).
        combinations_len_r =  [[]] + [["humanities", "stem", "writing"]] + list(chain.from_iterable([combinations(categories, r) for r in [7]]))
        pivots = {}
        for cat_bans in combinations_len_r:
            agreement_entries = []
            for judges in [["gpt4-pair", "human"], ["human", "human"], ["gpt4-pair", "human-majority"], ["human", "human-majority"]]:       
                for turn, turn_data in enumerate([data_turn_1, data_turn_2], start=1):
                    for ban in [[], ["tie", "both bad"], ["both bad"], ["tie"], ["A", "B"]]:
                        mode = ("no_" + "_".join(ban)) if len(ban) > 0 else "all"
                        agree, total = get_mt_bench_agreement(turn_data, judges[0], judges[1], ban=ban, cat_bans=cat_bans)
                        # NOTE(review): presumably the chance-agreement level
                        # for four possible labels minus the banned ones.
                        random = f"{int(100 / (4 - len(ban)))}%"
                        if total == 0:
                            ratio = 0.0
                        else:
                            ratio = agree/total
                        agreement_entries.append(dict(ratio=ratio, agree=f"{agree:.2f}/{total}", judge1=judges[0], judge2=judges[1], mode=mode, turn=turn, random=random))
            df = pd.DataFrame.from_records(agreement_entries)
            pivot = df.pivot(columns=["judge2", "turn"], index=["mode", "judge1", "random"], values=["agree", "ratio"])
            pivots[tuple(cat_bans)] = pivot 
        
        tables = []
        for cat_bans, df in pivots.items():
            cats = set(categories).difference(cat_bans)
            # Sort the remaining categories so the label (and therefore the
            # row order of the final table) does not depend on set ordering.
            cat = "all" if len(cats) == 8 else " ".join(sorted(cats))
            # print(f"\n\nWith categories: {cat}, human votes {num_min_votes_human}, machine votes {num_min_votes_machine}")
            # print(df)
            table = calc_agreement_table(df)
            table = table.droplevel(level=0, axis=1)
            table["category"] = cat
            tables.append(table)
        tables = pd.concat(tables, axis=0)
        
        tables = tables.reset_index().sort_values(["turn", "category"])
        tables.category = tables.category.apply(str.capitalize)
        tables = tables[["turn", "category", "GPT-4", "Human"]]
        tables.columns = list(map(str.capitalize, tables.columns.to_list()[:2])) + tables.columns.to_list()[2:]
        tables = pd.pivot(tables, index="Category", columns=["Turn"])
        tables = tables.swaplevel(0,1,axis=1).sort_index(axis=1)

        note = f"human votes {num_min_votes_human} machine votes {num_min_votes_machine}"
        # Backslashes are written explicitly ("\\m") instead of relying on the
        # invalid escape sequence "\m"; the resulting text is unchanged.
        note = "\\multicolumn{5}{c}{" + note + "}\\\\"
        result = ""
        for turn, turn_data in enumerate([data_turn_1, data_turn_2], start=1):
            # Index becomes (category, question id); one row per compared pair.
            df = pd.DataFrame(turn_data).T.droplevel([1,2]).swaplevel(0,1).sort_index()
            # Fleiss' kappa over the human votes, one column per rater slot.
            human_agreement_by_fleiss = CAC(pd.DataFrame(df.human.to_list())).fleiss()
            coefficient_value = human_agreement_by_fleiss["est"]["coefficient_value"]
            p_value = human_agreement_by_fleiss["est"]["p_value"]
            coef = f"Turn {turn} (Fleiss' kappa {coefficient_value} (p-value: {p_value:g}) on {len(df)} samples)"
            result += "\\multicolumn{5}{c}{" + coef + "}\\\\\n"
        
        warnings.simplefilter(action='ignore', category=FutureWarning)
        print(note)
        print(result)
        print(tables.to_latex(escape = False))

def inspect_pos_bias(pos_bias, num_runs, questions):
    """Print per-category position-bias statistics of the GPT-4 judge.

    Args:
        pos_bias: List of dicts with the keys ``run``, ``qid``, ``turn`` and
            ``is_pos_bias`` (as built by ``transform_gpt_data_schema``).
        num_runs: Number of GPT-4 vote files that were loaded.
        questions: DataFrame with ``question_id`` and ``category`` columns.

    Returns:
        None. The statistics are printed. When ``pos_bias`` is empty (no
        GPT-4 vote file was loaded) nothing is printed.
    """
    # Without any record the DataFrame below has no ``qid`` column and the
    # attribute access would fail, so there is nothing to report.
    if len(pos_bias) == 0:
        return

    num_turns = 2
    # NOTE(review): the caller already counts one run per GPT-4 file, so the
    # extra +1 in the denominator looks suspicious -- confirm the intent.
    num_runs = num_runs+1
    num_qid_per_category = 10
    pos_bias = pd.DataFrame.from_records(pos_bias)
    pos_bias['category'] = pos_bias.qid.apply(lambda x: questions.loc[questions["question_id"] == x]["category"].to_list()[0])
    # Share (in percent) of position-biased judgments per category.
    pos_bias_mean_across_runs_and_turns = pos_bias.groupby(["category"]).is_pos_bias.sum().to_frame().sort_values("is_pos_bias", ascending=False)
    pos_bias_mean_across_runs_and_turns = 100 * pos_bias_mean_across_runs_and_turns / (num_turns*num_qid_per_category*num_runs)

    pos_bias_count = pos_bias.groupby(["category", "qid", "turn"]).is_pos_bias.sum().to_frame().sort_values("is_pos_bias", ascending=False)
    # Keep only (category, qid, turn) groups with exactly 3 biased judgments.
    # NOTE(review): 3 is presumably the number of runs -- confirm.
    pos_bias_after_run_merge = pos_bias_count[pos_bias_count == 3].dropna().droplevel([1,2]).reset_index().groupby("category").count().sort_values("is_pos_bias", ascending=False)
    pos_bias_after_run_merge = 100 * pos_bias_after_run_merge / (num_turns*num_qid_per_category)
    
    pos_bias_mean_across_runs_and_turns.columns = pd.MultiIndex.from_tuples([("Pos. Bias", "Across Runs")])
    pos_bias_mean_across_runs_and_turns.index = pos_bias_mean_across_runs_and_turns.index.str.capitalize()
    print(pos_bias_mean_across_runs_and_turns)
    print(pos_bias_mean_across_runs_and_turns.style.format(precision=2).to_latex())

    print(pos_bias_after_run_merge)

def calc_agreement_table(df):
    """Build the per-turn agreement table for the ``"all"`` vote mode.

    Args:
        df: Pivot table whose row index is ``(mode, judge1, random)`` and
            whose columns are ``(value, judge2, turn)``; only the ``"ratio"``
            values of the ``"all"`` mode are used.

    Returns:
        DataFrame indexed by turn with columns ``(R label, judge)`` where the
        judge is ``"GPT-4"`` or ``"Human"``. Values are integer percentages of
        agreement with the human votes; the GPT-4 cells are LaTeX strings
        that carry the human vs. human-majority percentage as a subscript.
    """
    ratio = df.loc[["all"], "ratio"]
    ratio.index = ratio.index.droplevel(0)
    # Truncate to whole percentages, then switch to object dtype so that the
    # string cells written by update() below do not depend on pandas silently
    # upcasting an integer column (deprecated behaviour).
    ratio = (ratio * 100).astype(int).astype(object)
    ratio.update("$" + ratio.loc[("gpt4-pair", slice(None)), ("human", slice(None))].astype(str) + "_{" + ratio.loc[("human", slice(None)), ("human-majority", slice(None))].astype(str).values + "}$")
    ratio = ratio.loc[slice(None), "human"]
    ratio = ratio.swaplevel(0,1,axis=0).T
    # v is the chance level such as "25%"; its percent sign is replaced by an
    # escaped one for LaTeX (written "\\%" to avoid an invalid escape).
    ratio.columns = pd.MultiIndex.from_tuples([(f"$R={v[:-1]}\\%$", v2.replace("gpt4-pair", "GPT-4").replace("human", "Human")) for v, v2 in ratio.columns.values])
    return ratio

def print_stats(data):
    """Print the distribution of votes per category and judge.

    Args:
        data: Iterable of mappings
            ``(qid, model1, model2, category) -> {judge: [votes]}``.

    Prints two tables: the share of every (converted) vote label per
    category/judge rounded to two decimals, and the absolute counts with an
    added ``total`` column.
    """
    rows = []
    for turn_votes in data:
        for (_qid, _model1, _model2, category), judge_votes in turn_votes.items():
            for judge, winners in judge_votes.items():
                rows.extend(
                    {"category": category, "value": convertvote(winner), "judge": judge}
                    for winner in winners
                )
    frame = pd.DataFrame.from_records(rows)
    stats = frame.groupby(["category", "judge", "value"]).size().unstack(fill_value=0)

    row_totals = stats.sum(axis=1)
    stats_in_percent = stats.div(row_totals, axis=0).round(2)
    print("Stats in Percent")
    print(stats_in_percent)

    stats["total"] = row_totals
    print("\nStats in Counts")
    print(stats)

    # df2 = stats_in_percent.reset_index().pivot(columns="category", index="judge", values=["A", "B", "Tie"])
    # sns.heatmap(df2)
    # plt.show()


def transform_gpt_data_schema(data, pos_bias: list, run: str):
    """Convert GPT-4 pairwise judgments to the schema of the human votes.

    Input entries look like::

        {
            "question_id": "136_EN",
            "model_1": "7B-bactrianx-FR-checkpoint-497",
            "model_2": "7B-bactrianx-EN-checkpoint-497",
            "g1_winner": "model_1",
            "g2_winner": "model_1",
            "judge": ["gpt-4", "pair-v2"],
            "turn": 1,
            "tstamp": 1702319706.8325589
        }

    and are mapped to::

        {
            "question_id": "136_EN",
            "model_a": "7B-bactrianx-FR-checkpoint-497",
            "model_b": "7B-bactrianx-EN-checkpoint-497",
            "turn": 1,
            "language": "EN",
            "category": "",
            "judge": ["gpt-4", "pair-v2"],
            "timestamp": 1702319706.8325589,
            "winner": "A"
        }

    Only entries whose two verdicts (``g1_winner`` and ``g2_winner``) are
    equal are converted. Entries with differing verdicts are dropped and
    recorded as position-biased.

    Args:
        data: Iterable of GPT-4 judgment dicts.
        pos_bias: List that receives one record per entry with the keys
            ``run``, ``qid``, ``turn`` and ``is_pos_bias``; it is extended in
            place.
        run: Name of the run stored in every record.

    Returns:
        ``(converted_votes, pos_bias)`` where ``pos_bias`` is the list that
        was passed in.
    """
    converted = []
    for entry in data:
        stat = {"run": run, "qid": entry["question_id"], "turn": entry["turn"]}
        verdicts_agree = entry["g1_winner"] == entry["g2_winner"]
        if verdicts_agree:
            converted.append(
                {
                    "question_id": entry["question_id"],
                    "model_a": entry["model_1"],
                    "model_b": entry["model_2"],
                    "turn": entry["turn"],
                    # The language is the suffix of the question id.
                    "language": entry["question_id"].split("_")[-1],
                    "category": "",
                    "judge": entry["judge"],
                    "timestamp": entry["tstamp"],
                    "winner": transform_winner(entry["g1_winner"]),
                }
            )
        stat["is_pos_bias"] = not verdicts_agree
        pos_bias.append(stat)
    return converted, pos_bias


def transform_winner(gpt_winner):
    """Translate a GPT-4 winner label to the label used by the human votes.

    ``"model_1"`` becomes ``"A"``, ``"model_2"`` becomes ``"B"``,
    ``"both bad"`` stays ``"both bad"`` and every other value becomes
    ``"Tie"``.
    """
    translations = (
        ("model_1", "A"),
        ("model_2", "B"),
        ("both bad", "both bad"),
    )
    for gpt_label, vote_label in translations:
        if gpt_winner == gpt_label:
            return vote_label
    return "Tie"



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: the value selects data/<name>/question.jsonl below. Without it
    # the path would silently become "data/None/question.jsonl".
    parser.add_argument(
        "--mt-bench",
        type=str,
        required=True,
        help="benchmark name; questions are read from data/<name>/question.jsonl",
    )
    parser.add_argument(
        "--votefiles",
        nargs="+",
        type=str,
        default=["gpt4_judgments.json", "human_judgments.json"],
        help="JSON-lines vote files; paths containing 'gpt-4' are treated as GPT-4 judgments",
    )
    parser.add_argument(
        "--model-list",
        type=str,
        nargs="+",
        default=[],
        help="only use votes where both compared models are in this list",
    )
    args = parser.parse_args()
    questions = pd.read_json(f"data/{args.mt_bench}/question.jsonl", lines=True, orient="records")
    run_mt_bench_agreement(args.votefiles, args.model_list, questions)
