import json
import os
from collections import Counter, defaultdict
# import numpy as np
import argparse
import pprint
import random

PATH = "../"
# PATH = "/Users/dxm/Project/MultiPoT/"


def detailed_analyse(dataset, lang1, lang2):
    """Compare two result files question by question and write a mismatch report.

    Both inputs are currently hard-coded to the two
    ``gsm_Python_Classcial_8_sc_*`` files under ``analyse/same_language_overlap``;
    ``dataset``, ``lang1`` and ``lang2`` are only used as labels in the
    generated ``README.md``.

    Items are paired by their ``input`` text. A question counts towards the
    merged accuracy when at least one side has a truthy ``passed``. Pairs whose
    ``passed`` values differ are written to ``r_wrong.json`` (first file has
    ``passed`` equal to ``False``) and ``python_wrong.json`` (second file has
    ``passed`` equal to ``False``) and listed with their code in ``README.md``.

    Args:
        dataset: Label written into the report header.
        lang1: Label used for the first file in the report.
        lang2: Label used for the second file in the report.

    Raises:
        AssertionError: If the two files hold a different number of items.
        KeyError: If an ``input`` of the first file is missing from the second.
    """
    # Earlier revisions read per-language files
    # (outputs/CodeLLM/{dataset}_{lang}_Classcial_8_ans.json) instead.
    overlap_dir = PATH + "/analyse/same_language_overlap/"
    with open(overlap_dir + "gsm_Python_Classcial_8_sc_1.1.json", "r") as reader:
        r = json.load(reader)
    with open(overlap_dir + "gsm_Python_Classcial_8_sc_1.json", "r") as reader:
        python = json.load(reader)
    assert len(r) == len(python)

    data_dict_r = {item['input']: item for item in r}
    data_dict_python = {item['input']: item for item in python}

    total = len(r)
    correct = 0
    mismatches = []
    for input_text, item_r in data_dict_r.items():
        item_python = data_dict_python[input_text]
        if item_python['passed'] or item_r['passed']:
            correct += 1
        if item_python['passed'] != item_r['passed']:
            mismatches.append((item_r, item_python))
    # Guard the empty case so the report can still be written.
    merge_accuracy = correct / total if total else 0.0

    # {lang1} wrong against {lang2}
    r_wrong = [item for item in mismatches if item[0]['passed'] is False]
    # {lang2} wrong against {lang1}
    python_wrong = [item for item in mismatches if item[1]['passed'] is False]

    output_path = PATH + "analyse/same_language_overlap/"
    os.makedirs(output_path, exist_ok=True)
    # Context managers make sure every output file is flushed and closed.
    with open(output_path + "r_wrong.json", "w") as writer:
        json.dump(r_wrong, writer, indent=4)
    with open(output_path + "python_wrong.json", "w") as writer:
        json.dump(python_wrong, writer, indent=4)
    with open(output_path + "README.md", "w") as writer:
        writer.write(f"Total in {dataset}: {total}\n")
        writer.write(f"Merge Accuracy: {merge_accuracy}\n")
        writer.write("---\n\n")
        writer.write(f"{lang2} wrong against {lang1}: {len(python_wrong)}\n")
        for i, item in enumerate(python_wrong):
            writer.write(f"Q {i}: {item[0]['input']}\n")
            writer.write(f"Target: {item[0]['target']}\n")
            writer.write(f"{lang2} output: {item[1]['exec_ans']}\n\n")
            # Failing program first, then the passing one for comparison.
            writer.write(f"```{lang2}\n{item[1]['code'][0]}\n```\n\n")
            writer.write(f"```{lang1}\n{item[0]['code'][0]}\n```\n\n")
        writer.write("---\n\n")
        writer.write(f"{lang1} wrong against {lang2} right: {len(r_wrong)}\n")
        for i, item in enumerate(r_wrong):
            writer.write(f"Q {i}: {item[0]['input']}\n")
            writer.write(f"Target: {item[0]['target']}\n")
            writer.write(f"{lang1} output: {item[0]['exec_ans']}\n\n")
            writer.write(f"```{lang1}\n{item[0]['code'][0]}\n```\n\n")
            writer.write(f"```{lang2}\n{item[1]['code'][0]}\n```\n\n")


def choose_best(dataset, langs, middle_dir, num_examples=None, output_suffix=None):
    """Return the oracle accuracy: a question counts if any language passed it.

    Args:
        dataset: Dataset name used in the result file names.
        langs: Languages whose result files are combined.
        middle_dir: Directory (relative to ``PATH``) holding the result files.
        num_examples: Number of few-shot examples in the file name. Falls back
            to the module-level ``args.num_examples`` when omitted.
        output_suffix: File-name suffix. Falls back to the module-level
            ``args.output_suffix`` when omitted.

    Returns:
        Fraction of questions for which at least one language has a truthy
        ``passed`` flag.

    Raises:
        ValueError: If the result files are not aligned on ``input``.
    """
    # The directory now comes from the ``middle_dir`` parameter (it used to be
    # ignored in favour of the global ``args``). The two remaining file-name
    # parts keep the old global fallback so existing call sites still work.
    if num_examples is None:
        num_examples = args.num_examples
    if output_suffix is None:
        output_suffix = args.output_suffix

    ds = []
    for lang in langs:
        file_name = PATH + f"{middle_dir}/{dataset}_{lang}_Classcial_{num_examples}{output_suffix}.json"
        with open(file_name, "r") as reader:
            ds.append(json.load(reader))
    total = len(ds[0])
    correct = 0
    for i in range(total):
        passed = False
        for j, lang in enumerate(langs):
            if ds[j][i]['input'] != ds[0][i]['input']:
                # A misaligned file invalidates the result; fail loudly rather
                # than printing "ERROR" and exiting with a success status.
                raise ValueError(
                    f"{dataset}: item {i} of {lang} does not match item {i} of {langs[0]}"
                )
            passed = passed or ds[j][i]['passed']
        correct += int(passed)
    return correct/total


def vote(dataset, langs, args):
    """Return the accuracy of a majority vote across languages.

    For every question the first execution answer of each language is
    collected, and the most frequent one (compared by its ``str`` form) is
    checked against the target with an absolute tolerance of ``1e-3``.

    Only answers decoded as ``float`` take part in the vote.
    NOTE(review): integer answers are therefore ignored - confirm this is
    intended.

    Args:
        dataset: Dataset name used in the result file names.
        langs: Languages whose result files are combined.
        args: Namespace providing ``middle_dir``, ``num_examples`` and
            ``output_suffix``.

    Returns:
        ``correct / total``; questions without any float answer count as wrong.
    """
    ds = []
    for lang in langs:
        file_name = PATH + f"{args.middle_dir}/{dataset}_{lang}_Classcial_{args.num_examples}{args.output_suffix}.json"
        # Close each result file as soon as it has been parsed.
        with open(file_name, "r") as reader:
            ds.append(json.load(reader))
    total = len(ds[0])
    correct = 0
    for i in range(total):
        target = ds[0][i]['target']
        ans = []
        for lang_ds in ds:
            exec_ans = lang_ds[i]['exec_ans'][0]
            if isinstance(exec_ans, float):
                ans.append(str(exec_ans))
        if not ans:
            continue
        # With equal counts the answer seen first (earliest language) wins.
        exec_ans = Counter(ans).most_common(1)[0][0]
        if abs(float(target) - float(exec_ans)) < 1e-3:
            correct += 1
    return correct/total


def vote_sc(dataset, lang, args):
    """Return the self-consistency (majority vote) accuracy for one language.

    All execution answers of a question that are ``float`` take part in the
    vote, compared by their ``str`` form. The winner is checked against the
    target with an absolute tolerance of ``1e-3``.

    Args:
        dataset: Dataset name used in the result file name.
        lang: Language whose result file is evaluated.
        args: Namespace providing ``middle_dir``, ``num_examples`` and
            ``output_suffix``.

    Returns:
        ``correct / total``; questions without any float answer count as wrong.
    """
    file_name = PATH + f"{args.middle_dir}/{dataset}_{lang}_Classcial_{args.num_examples}{args.output_suffix}.json"
    # Close the result file as soon as it has been parsed.
    with open(file_name, "r") as reader:
        ds = json.load(reader)
    total = len(ds)
    correct = 0
    for d in ds:
        target = d['target']
        ans = [str(a) for a in d['exec_ans'] if isinstance(a, float)]
        if not ans:
            continue
        # With equal counts the answer sampled first wins.
        exec_ans = Counter(ans).most_common(1)[0][0]
        if abs(float(target) - float(exec_ans)) < 1e-3:
            correct += 1
    return correct/total


def test_float2string():
    """Print correct float answers whose string form is unusually long.

    Scans the 34B ``sc_5`` answer files of every dataset/language pair and
    prints each float answer that is within ``1e-3`` of the target while its
    ``str`` representation is longer than 15 characters. Such values show
    where string-keyed voting could split numerically equal answers.
    """
    for dataset in ["gsm", "svamp", "asdiv"]:
        for lang in ["Python", "R", "C++", "Java", "Javascript"]:
            file_name = PATH + f"/outputs/CodeLLM/34B/ans/{dataset}_{lang}_Classcial_8_sc_5_ans.json"
            # Fifteen files are read here; close each one once it is parsed.
            with open(file_name, "r") as reader:
                ds = json.load(reader)
            for d in ds:
                for a in d['exec_ans']:
                    if not isinstance(a, float):
                        continue
                    if len(str(a)) > 15 and abs(d["target"]-a)<1e-3:
                        print(a)
                        print("target:", d["target"], "dataset:", dataset, "lang:", lang)


def softmax(logits):
    """Return the softmax of ``logits`` as a NumPy array.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow; the result is unchanged by this shift.

    Args:
        logits: Non-empty sequence of numbers.

    Returns:
        ``numpy.ndarray`` of probabilities that sum to 1.
    """
    # The module-level numpy import is commented out, so ``np`` must be bound
    # here; without it this function raises NameError.
    import numpy as np

    shifted = np.array(logits)
    exps = np.exp(shifted - np.max(shifted))
    return exps / exps.sum()


def lever(dataset, langs, args):
    """Return the accuracy of score-weighted answer selection across languages.

    For every question the first sample of each language contributes its
    answer (as a string) with a weight chosen by ``args.opt``:

    * ``"lever"``: ``softmax(loss)`` over the languages times the verifier score
    * ``"verfier_loss"``: the verifier score alone
    * ``"lm_loss"``: ``exp(loss)``

    Weights of identical answers are summed and the answer with the largest
    total is compared to the target with an absolute tolerance of ``1e-3``.

    Args:
        dataset: Dataset name used in the result file names.
        langs: Languages whose result files are combined.
        args: Namespace providing ``opt``, ``middle_dir``, ``num_examples`` and
            ``output_suffix``.

    Returns:
        ``correct / total``.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of the options above.
    """
    # The module-level numpy import is commented out, so bind ``np`` locally.
    import numpy as np

    opt = args.opt
    # Fail fast on an unknown option instead of after loading every file.
    if opt not in ("lever", "verfier_loss", "lm_loss"):
        raise NotImplementedError(f"unknown opt: {opt!r}")

    ds = []
    for lang in langs:
        file_name = PATH + f"{args.middle_dir}/{dataset}_{lang}_Classcial_{args.num_examples}{args.output_suffix}.json"
        with open(file_name, "r") as reader:
            ds.append(json.load(reader))
    total = len(ds[0])
    correct = 0
    for i in range(total):
        target = ds[0][i]['target']
        ans = []
        loss = []
        v_loss = []
        for lang_ds in ds:
            ans.append(str(lang_ds[i]['exec_ans'][0]))
            loss.append(lang_ds[i]['loss'][0])
            # First component of the first sample's verifier entry.
            v_loss.append(lang_ds[i]['verfier_loss'][0][0])
        if opt == "lever":
            weights = [l * v_l for l, v_l in zip(softmax(loss), v_loss)]
        elif opt == "verfier_loss":
            weights = v_loss
        else:  # "lm_loss"
            weights = list(np.exp(np.array(loss)))
        scores = defaultdict(float)
        for a, w in zip(ans, weights):
            scores[a] += w
        exec_ans = max(scores, key=scores.get)
        try:
            if abs(float(target) - float(exec_ans)) < 1e-3:
                correct += 1
        except (TypeError, ValueError, OverflowError):
            # The winning "answer" is not numeric (e.g. an execution error
            # message); the question simply counts as wrong.
            pass
    return correct/total


def lever_sc(dataset, lang, args):
    """Return the accuracy of score-weighted answer selection for one language.

    Every sample of a question contributes its answer (as a string) with a
    weight chosen by ``args.opt``:

    * ``"lever"``: ``softmax(loss)`` over the samples times the verifier score
    * ``"verfier_loss"``: the verifier score alone
    * ``"lm_loss"``: ``exp(loss)``

    Weights of identical answers are summed and the answer with the largest
    total is compared to the target with an absolute tolerance of ``1e-3``.

    Args:
        dataset: Dataset name used in the result file name.
        lang: Language whose result file is evaluated.
        args: Namespace providing ``opt``, ``middle_dir``, ``num_examples`` and
            ``output_suffix``.

    Returns:
        ``correct / total``; questions without samples count as wrong.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of the options above.
    """
    # The module-level numpy import is commented out, so bind ``np`` locally.
    import numpy as np

    opt = args.opt
    # Fail fast on an unknown option instead of after loading the file.
    if opt not in ("lever", "verfier_loss", "lm_loss"):
        raise NotImplementedError(f"unknown opt: {opt!r}")

    file_name = PATH + f"{args.middle_dir}/{dataset}_{lang}_Classcial_{args.num_examples}{args.output_suffix}.json"
    with open(file_name, "r") as reader:
        ds = json.load(reader)
    total = len(ds)
    correct = 0
    for item in ds:
        target = item['target']
        ans = []
        loss = []
        v_loss = []
        for i, a in enumerate(item['exec_ans']):
            ans.append(str(a))
            loss.append(item['loss'][i])
            # First component of this sample's verifier entry.
            v_loss.append(item['verfier_loss'][i][0])
        if not ans:
            # No samples at all: nothing to rank (max() would raise).
            continue
        if opt == "lever":
            weights = [l * v_l for l, v_l in zip(softmax(loss), v_loss)]
        elif opt == "verfier_loss":
            weights = v_loss
        else:  # "lm_loss"
            weights = list(np.exp(np.array(loss)))
        # Kept separate from the loop variable, which used to be overwritten.
        scores = defaultdict(float)
        for a, w in zip(ans, weights):
            scores[a] += w
        exec_ans = max(scores, key=scores.get)
        try:
            # Some questions make every sample hit a runtime error (RE); in
            # that case there is no numeric exec_ans to compare.
            if abs(float(target) - float(exec_ans)) < 1e-3:
                correct += 1
        except (TypeError, ValueError, OverflowError):
            pass
    return correct/total


def calculate(dataset, ds, label="merge"):
    """Print an accuracy line for one group of extracted answers.

    Args:
        dataset: Name printed at the start of the report line.
        ds: Sequence of ``(answers, target, verifier_scores)`` records. In
            ``"single"`` mode only ``answers[0]`` and ``target`` are read.
        label: ``"single"`` scores the first answer of every record;
            ``"merge"`` scores the answer with the largest summed verifier
            score among the float answers. Any other value prints nothing.
    """
    total = len(ds)

    if label == "single":
        hits = 0
        for record in ds:
            first_answer = record[0][0]
            gold = record[1]
            # String answers (e.g. error messages) can never match.
            if not isinstance(first_answer, str):
                if abs(float(gold) - float(first_answer)) < 1e-3:
                    hits += 1
        print(f"{dataset} Single Accuracy: {hits}/{total} : {hits/total}")
        return

    if label != "merge":
        return

    vote_hits = 0
    upper_bound_hits = 0
    verifier_hits = 0
    for record in ds:
        answers, gold, scores = record[0], record[1], record[2]
        counts = defaultdict(float)
        weighted = defaultdict(float)
        any_match = False
        for answer, score in zip(answers, scores):
            if not isinstance(answer, float):
                continue
            if abs(float(gold) - float(answer)) < 1e-3:
                any_match = True
            # Answers are bucketed by their value rounded to 6 decimals.
            key = str(round(answer, 6))
            counts[key] += 1
            weighted[key] += score
        if not counts:
            # No float answer at all: the record counts as wrong everywhere.
            continue
        top_voted = max(counts, key=counts.get)
        top_scored = max(weighted, key=weighted.get)

        upper_bound_hits += int(any_match)
        if abs(float(gold) - float(top_voted)) < 1e-3:
            vote_hits += 1
        if abs(float(gold) - float(top_scored)) < 1e-3:
            verifier_hits += 1
    # Vote and upper-bound tallies are computed but only the verifier
    # accuracy is reported.
    print(f"{dataset} Verifier Accuracy: {verifier_hits}/{total} : {verifier_hits/total}")
            

def math(langs, middle_dir, label):
    """Report per-problem-type accuracy on the math dataset.

    Loads ``math_{lang}_Classcial_8_sc_5_ans_verifier.json`` for every
    language, concatenates the answers and verifier scores of all languages
    per question, groups the questions by their ``type`` field and hands each
    group to ``calculate``.

    NOTE: this function shadows the name of the standard ``math`` module
    inside this file.

    Args:
        langs: Non-empty sequence of languages to merge.
        middle_dir: Directory (relative to ``PATH``) holding the result files.
        label: Scoring mode forwarded to ``calculate``.

    Raises:
        ValueError: If ``langs`` is empty.
    """
    dss = {}
    reference = None
    for lang in langs:
        file_name = PATH + f"{middle_dir}/math_{lang}_Classcial_8_sc_5_ans_verifier.json"
        with open(file_name, "r") as reader:
            dss[lang] = json.load(reader)
        # Targets and types are read from the last file loaded, as before.
        reference = dss[lang]
    if reference is None:
        raise ValueError("langs must contain at least one language")

    extracts = defaultdict(list)
    for i, ref_item in enumerate(reference):
        ans = []
        v_loss = []
        for lang in langs:
            ans.extend(dss[lang][i]["exec_ans"])
            # Keep the first component of every verifier entry.
            v_loss.extend([v[0] for v in dss[lang][i]["verfier_loss"]])
        extracts[ref_item["type"]].append((ans, ref_item["target"], v_loss))
    for type_, items in extracts.items():
        calculate(type_, items, label=label)


def get_file_name(dataset, lang, args):
    """Build the result-file path for ``dataset``/``lang``.

    When ``args.bak`` is set and a sibling ``.bak`` file exists, the path of
    that backup is returned instead of the regular file.
    """
    path = PATH + f"{args.middle_dir}/{dataset}_{lang}_{args.examples}_{args.num_examples}{args.output_suffix}.json"
    if args.bak:
        backup = path + ".bak"
        if os.path.exists(backup):
            return backup
    return path


def _reorder_samples(item, order):
    """Apply the permutation ``order`` to every per-sample list of ``item``."""
    for field in ("loss", "exec_ans", "code", "verfier_loss"):
        values = item.get(field)
        # Only lists that are parallel to the permuted one are touched.
        if values is not None and len(values) == len(order):
            item[field] = [values[k] for k in order]


def evaluate(dataset, langs, args):
    """Score ``dataset`` for one language or for several languages merged.

    The samples of all ``langs`` are concatenated per question. When more than
    one language is merged, the samples are ordered by descending ``loss`` if
    that field is present, otherwise shuffled; this decides which answer wins
    a tie in ``calcu``. The resulting accuracy is printed.

    For a single language the file is rewritten with the records returned by
    ``calcu``.

    Args:
        dataset: Dataset name used in the result file names.
        langs: Non-empty list of languages; the first one provides the base
            records.
        args: Parsed command-line namespace (see ``get_file_name``/``calcu``).
    """
    with open(get_file_name(dataset, langs[0], args), "r") as reader:
        ds = json.load(reader)
    for lang in langs[1:]:
        with open(get_file_name(dataset, lang, args), "r") as reader:
            cur_ds = json.load(reader)
        for d, cd in zip(ds, cur_ds):
            d["exec_ans"].extend(cd["exec_ans"])
            d["code"].extend(cd["code"])
            if "loss" in d:
                d["loss"].extend(cd["loss"])
            if "verfier_loss" in d:
                d["verfier_loss"].extend(cd["verfier_loss"])
    if len(langs) > 1:
        for d in ds:
            if "loss" in d:
                losses = d["loss"]
                # Sort on the loss only: comparing whole (loss, answer, code)
                # tuples raises TypeError when losses tie and the answers mix
                # strings with numbers. Ties keep their original order.
                order = sorted(range(len(losses)), key=losses.__getitem__, reverse=True)
            else:
                # Shuffle indices rather than exec_ans itself so the other
                # per-sample lists can follow the same permutation.
                order = list(range(len(d["exec_ans"])))
                random.shuffle(order)
            _reorder_samples(d, order)
    ans, ds = calcu(ds, args)
    if len(langs) == 1:
        with open(get_file_name(dataset, langs[0], args), "w") as writer:
            json.dump(ds, writer, indent=4)
    print(ans)
    # NOTE: a per-type breakdown for the math dataset (group records by
    # d['type'] and call calcu on each group) was sketched here previously
    # and is currently disabled.


def calcu(ds, args):
    """Compute the accuracy of ``ds`` under the strategy in ``args.opt``.

    Numeric answers are first normalised: values below 1 are multiplied by 100
    when the question text contains "percent", answers within ``1e-3`` of the
    target are snapped to the rounded target, and everything else is rounded
    to 3 decimals. String answers are kept as they are.

    * ``"best"``: a record is correct when any normalised answer equals the
      rounded target; the outcome is stored in ``record["passed"]``.
    * ``"vote"``: a record is correct when the most frequent numeric answer
      equals the rounded target; records without numeric answers are wrong.
    * any other value: nothing is counted.

    Returns:
        ``(accuracy, ds)`` where ``ds`` is the (possibly annotated) input.
    """
    total = len(ds)
    hits = 0
    for record in ds:
        gold = record['target']
        normalised = []
        for raw in record['exec_ans']:
            if isinstance(raw, str):
                normalised.append(raw)
                continue
            value = raw
            # Fractions on "percent" questions are rescaled to percentages.
            if "percent" in record["input"] and value < 1:
                value = value * 100
            close_enough = abs(value - gold) < 1e-3
            normalised.append(round(gold, 3) if close_enough else round(value, 3))

        if args.opt == "best":
            solved = round(gold, 3) in normalised
            if solved:
                hits += 1
            record["passed"] = solved
        elif args.opt == "vote":
            numeric = [value for value in normalised if not isinstance(value, str)]
            if not numeric:
                continue
            winner = Counter(numeric).most_common(1)[0][0]
            if winner == round(gold, 3):
                hits += 1

    return hits/total, ds


if __name__ == "__main__":
    # Command-line entry point: evaluate every requested dataset, either with
    # all languages merged (--multi) or one language at a time.
    cli_options = [
        ("--datasets", dict(type=str, default="math")),
        ("--langs", dict(type=str, default="Python")),
        ("--examples", dict(type=str, default="Classcial")),
        ("--output_suffix", dict(type=str, default="sc_40_ans")),
        ("--num_examples", dict(type=int, default=3)),
        ('--middle_dir', dict(type=str, default='outputs_deepseek')),
        ("--multi", dict(action="store_true", default=False)),
        ("--bak", dict(action="store_true", default=False)),
        ("--opt", dict(type=str, default="vote")),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Echo the effective configuration before doing any work.
    pprint.pprint(vars(args))

    # --datasets and --langs are comma-separated lists.
    datasets = args.datasets.strip().split(",")
    langs = args.langs.strip().split(",")

    for dataset in datasets:
        lang_groups = [langs] if args.multi else [[lang] for lang in langs]
        for lang_group in lang_groups:
            evaluate(dataset, lang_group, args)

    # NOTE: earlier experiments were launched by hand from this spot by
    # calling choose_best / vote / vote_sc / lever / lever_sc / math /
    # detailed_analyse / test_float2string on the gsm, svamp and asdiv
    # datasets with results under "outputs/CodeLLM/34B/with_loss". Those
    # scratch calls used older signatures and are not part of the CLI flow.
            