import json
import os
from collections import Counter, defaultdict
# import numpy as np
import argparse
import pprint
import random

# Prefix that get_file_name() prepends to every result-file path.
# NOTE(review): this is a machine-specific absolute path; the commented-out
# relative alternative below is the one to enable when running elsewhere.
PATH = "/Users/dxm/Project/MultiPoT/bbh/"
# PATH = "./"


def softmax(logits):
    """Return the softmax of a sequence of logits as a NumPy array.

    The maximum logit is subtracted before exponentiation so that large
    values do not overflow.  An empty input yields an empty array.

    NumPy is imported inside the function because the module-level import
    is commented out; without it every call failed with ``NameError``.
    """
    import numpy as np

    logits = np.asarray(logits, dtype=float)
    if logits.size == 0:
        # np.max() raises on empty input; the softmax of nothing is nothing.
        return logits
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()


def get_file_name(dataset, lang, args):
    """Build the path of the result file for one dataset/language pair.

    The path is ``PATH + "<middle_dir>/<dataset>_<lang><output_suffix>.json"``.
    When ``args.old`` is set and a sibling file whose name ends in
    ``1.json`` instead of ``.json`` exists, that file is reported on stdout
    and returned instead.
    """
    stem = PATH + f"{args.middle_dir}/{dataset}_{lang}{args.output_suffix}"
    numbered = stem + "1.json"
    use_numbered = args.old and os.path.exists(numbered)
    if not use_numbered:
        return stem + ".json"
    print(f"{numbered} exist!")
    return numbered


def evaluate(dataset, langs, args):
    """Score one dataset over one or more languages and print the accuracy.

    The result file of ``langs[0]`` is loaded first; for every further
    language the per-example ``exec_ans``, ``code`` and ``RE`` lists (plus
    ``loss`` / ``verfier_loss`` when the first file has them) are appended
    to the matching example.  With more than one language the pooled
    samples of each example are then reordered:

    * if the example has ``loss``, by descending loss, and ``d["langs"]``
      records the language of every sample in that order;
    * otherwise randomly (``random.shuffle``).

    The merged examples are passed to ``calcu`` and the resulting accuracy
    is printed.

    NOTE(review): ``verfier_loss`` is extended but not reordered together
    with the other lists, so after sorting it no longer lines up with them.
    """

    def load(lang):
        # Context manager so the file handle is closed right after parsing.
        with open(get_file_name(dataset, lang, args), "r") as f:
            return json.load(f)

    ds = load(langs[0])
    # Language of every pooled sample, kept parallel to d["exec_ans"].
    # Using the bare ``langs`` list here would make the zip below truncate
    # all per-sample lists to len(langs) whenever a language contributes
    # more than one sample.
    sample_langs = [[langs[0]] * len(d["exec_ans"]) for d in ds]
    for lang in langs[1:]:
        cur_ds = load(lang)
        for d, cd, owners in zip(ds, cur_ds, sample_langs):
            d["exec_ans"].extend(cd["exec_ans"])
            d["code"].extend(cd["code"])
            d["RE"].extend(cd["RE"])
            if "loss" in d:
                d["loss"].extend(cd["loss"])
            if "verfier_loss" in d:
                d["verfier_loss"].extend(cd["verfier_loss"])
            owners.extend([lang] * len(cd["exec_ans"]))
    if len(langs) > 1:
        for d, owners in zip(ds, sample_langs):
            if "loss" in d:
                # Sort on the loss only: comparing whole tuples would fall
                # through to the answers/code/flags on equal losses and can
                # raise TypeError when those have mixed types.
                ranked = sorted(
                    zip(d["loss"], d["exec_ans"], d["code"], d["RE"], owners),
                    key=lambda row: row[0],
                    reverse=True,
                )
                if ranked:  # zip(*[]) cannot be unpacked into five names
                    d["loss"], d["exec_ans"], d["code"], d["RE"], d["langs"] = zip(*ranked)
            else:
                tmp = list(zip(d["exec_ans"], d["code"], d["RE"]))
                random.shuffle(tmp)
                if tmp:  # same unpacking guard as above
                    d["exec_ans"], d["code"], d["RE"] = zip(*tmp)
    ans, ds = calcu(ds, args)
    print(ans)


def print_can_correct(can_correct):
    """Write a Markdown table of the given examples to ``can_correct.md``.

    The file starts with a ``SUM: <count>`` line.  The table header is
    ``idx``, ``target`` and one column per entry of the first example's
    ``langs``.  Each row then shows, for every value in ``exec_ans``:
    a check mark when it equals ``str(target)``, ``RE`` when it is any
    other string, and the value itself otherwise.

    An empty ``can_correct`` produces only the ``SUM: 0`` line.
    """
    # UTF-8 is set explicitly because the check mark cannot be encoded by
    # every platform default; the context manager guarantees flush/close.
    with open("can_correct.md", "w", encoding="utf-8") as writer:
        writer.write(f"SUM: {len(can_correct)}\n\n")
        if not can_correct:
            # No first example to take the column headers from.
            return
        langs = can_correct[0]['langs']
        writer.write(f"| idx | target | {' | '.join(langs)} |\n")
        writer.write(f"| --- | --- | {'--- | ' * len(langs)}\n")
        for d in can_correct:
            expected = str(d["target"])
            cells = []
            for e in d["exec_ans"]:
                if e == expected:
                    cells.append("✅ | ")
                else:
                    cells.append(f"{'RE' if isinstance(e, str) else e } | ")
            writer.write(f"| {d['idx']} | {d['target']} | " + "".join(cells) + "\n")


def _normalize_answers(d):
    """Map the raw ``exec_ans`` of one example onto comparable strings."""
    answers = d['exec_ans']
    if not answers:
        return []
    target = d['target']

    if isinstance(target, int):
        normalized = []
        for a in answers:
            try:
                value = float(a)
                matches = abs(value - target) < 1e-3
            except (TypeError, ValueError, OverflowError):
                # Not a number: keep the raw answer so it can never match.
                normalized.append(a)
                continue
            normalized.append(str(target) if matches else str(round(value, 3)))
        return normalized

    if target in ["yes", "no"]:
        # Last space-separated token of the input without its final
        # character; an answer equal to it is counted as "yes".
        extract_ans = d["input"].split(" ")[-1][:-1]
        normalized = []
        for a in answers:
            # Non-string answers are stringified instead of crashing on
            # .lower().
            text = a if isinstance(a, str) else str(a)
            is_yes = text.lower() in ("true", "yes") or text == "1" or a == extract_ans
            normalized.append("yes" if is_yes else "no")
        return normalized

    return list(answers)


def calcu(ds, args):
    """Compute the accuracy of ``ds`` under the strategy ``args.opt``.

    Answers are first normalized per example:

    * integer target: an answer parseable as a float becomes
      ``str(target)`` when it is within ``1e-3`` of the target and
      ``str(round(value, 3))`` otherwise; unparseable answers are kept;
    * ``"yes"`` / ``"no"`` target: ``true`` / ``yes`` (case-insensitive),
      ``"1"`` or the last token of ``d["input"]`` become ``"yes"``,
      everything else ``"no"``;
    * any other target: answers are used unchanged.

    Answers whose ``RE`` flag is truthy are ignored when scoring.

    * ``opt == "best"``: an example is correct when any remaining answer
      equals ``str(target)``; the outcome is stored in ``d["passed"]``.
    * ``opt == "vote"``: an example is correct when the most frequent
      remaining answer equals ``str(target)``; examples with no remaining
      answer count as incorrect.
    * any other value: nothing is scored and the accuracy is ``0``.
      Weighted schemes are not implemented here.

    Returns:
        ``(accuracy, ds)``.  An empty ``ds`` yields an accuracy of ``0.0``
        instead of dividing by zero.
    """
    total = len(ds)
    if total == 0:
        return 0.0, ds

    correct = 0
    for d in ds:
        expected = str(d['target'])
        re_ans = _normalize_answers(d)
        if args.opt == "best":
            c = Counter([a for a, b in zip(re_ans, d["RE"]) if not b])
            d["passed"] = expected in c
            if d["passed"]:
                correct += 1
        elif args.opt == "vote":
            c = Counter([a for a, b in zip(re_ans, d["RE"]) if not b])
            if c and c.most_common(1)[0][0] == expected:
                correct += 1
    return correct / total, ds


if __name__ == "__main__":
    # Command-line entry point: --datasets and --langs are comma-separated
    # lists; every dataset is evaluated either once over all languages
    # together (--multi) or once per language.
    cli = argparse.ArgumentParser()
    cli.add_argument("--datasets", type=str, default="color")
    cli.add_argument("--langs", type=str, default="Python")
    cli.add_argument("--output_suffix", type=str, default="_ans")
    cli.add_argument('--middle_dir', type=str, default='../outputs')
    cli.add_argument("--multi", action="store_true", default=False)
    cli.add_argument("--opt", type=str, default="best")
    cli.add_argument("--old", action="store_true", default=False)
    args = cli.parse_args()

    # Echo the effective configuration before doing any work.
    print(pprint.pformat(vars(args)))
    datasets = args.datasets.strip().split(",")
    langs = args.langs.strip().split(",")

    for dataset in datasets:
        # One group holding every language, or one single-language group each.
        lang_groups = [langs] if args.multi else [[lang] for lang in langs]
        for group in lang_groups:
            evaluate(dataset, group, args)
