import json
import argparse
import numpy as np
import matplotlib.pyplot as plt


def get_args():
    """Parse the command-line options of the plotting script.

    Options:
        --results_json    path of the JSON results file
                          (default ``../../pull_results.json``).
        --pattern         optional string; accepted but not read by the code
                          visible in this file.
        --x_key, --y_key  required names of the two hyper-parameters to plot.
        --3d              boolean switch, stored under the attribute name
                          ``"3d"`` and therefore only reachable through
                          ``getattr(args, "3d")``.
        --old_parser      boolean switch selecting the legacy job-name layout.
        --point           boolean switch selecting a scatter plot instead of
                          a surface.
        --pass_threshold  optional float; scores at or below it are dropped.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # Where the results come from.
    cli.add_argument("--results_json", type=str, default="../../pull_results.json")
    cli.add_argument("--pattern", type=str, default=None)

    # The two mandatory axis selectors share the same shape.
    for axis_flag in ("--x_key", "--y_key"):
        cli.add_argument(axis_flag, type=str, required=True)

    # Plain on/off switches, all defaulting to False.
    for switch in ("--3d", "--old_parser", "--point"):
        cli.add_argument(switch, action='store_true')

    cli.add_argument("--pass_threshold", type=float, default=None)

    namespace = cli.parse_args()
    return namespace


def visualization_3d(x, y, z, args):
    """Draw ``z`` over the ``(x, y)`` plane on a 3D matplotlib axes and show it.

    Args:
        x, y, z: coordinate collections handed straight to matplotlib.
        args: object providing ``x_key`` and ``y_key`` (used as the x/y axis
            labels) and the boolean ``point``.

    The z axis is always labelled ``"perf"``. When ``args.point`` is true the
    data is drawn as a scatter plot coloured by ``z``; otherwise it is drawn
    as a triangulated surface. Finishes by calling ``plt.show()``.
    """
    axes3d = plt.axes(projection='3d')

    # Label the three axes in one pass.
    labelled_setters = (
        (axes3d.set_xlabel, args.x_key),
        (axes3d.set_ylabel, args.y_key),
        (axes3d.set_zlabel, "perf"),
    )
    for set_label, text in labelled_setters:
        set_label(text)

    if not args.point:
        axes3d.plot_trisurf(x, y, z, cmap='viridis', edgecolor='none')
    else:
        axes3d.scatter(x, y, z, c=z, cmap='viridis', linewidth=0.5)

    plt.show()


def _parse_old_format(job_name):
    """Decode the legacy job-name layout into a dict of raw values.

    The layout, as decoded by this function, is
    ``<prefix>_bsz<B>-<LR>x<EPOCH>_wmu<WARMUP>_<2 chars><WD>[ld<LD>][_...]``.
    ``batch_size`` is returned as an ``int``; every other value is returned
    as the substring found in the name.
    """
    head, tail = job_name.split("_wmu")
    _, head = head.split("_bsz")
    bsz = head.split('-')[0]
    keys = {"batch_size": int(bsz)}
    # Everything after "<bsz>-" is "<lr>x<epoch>"; slicing (rather than
    # splitting on '-') keeps exponents such as "1e-5" intact.
    keys["lr"], keys["epoch"] = head[len(bsz) + 1:].split('x')
    keys["warmup"], weight_decay = tail.split('_')[:2]
    # Drop the two-character prefix in front of the number (presumably "wd").
    weight_decay = weight_decay[2:]
    if "ld" in weight_decay:
        weight_decay, keys["layer_decay"] = weight_decay.split("ld")
    keys["weight_decay"] = weight_decay
    return keys


def _parse_new_format(job_name):
    """Decode ``<name><number>`` tokens separated by ``_`` into a dict of floats.

    Tokens that are not of that shape are ignored: tokens without any digit
    (``run``), tokens with no name in front of the digits (``2023``), and
    tokens whose tail is not a valid float (``bert2bert``).
    """
    keys = {}
    for part in job_name.split('_'):
        # The name is the leading run of non-digit characters.
        split_at = next((i for i, ch in enumerate(part) if ch.isdigit()), len(part))
        name, raw_value = part[:split_at], part[split_at:]
        if not name or not raw_value:
            continue
        try:
            value = float(raw_value)
        except ValueError:
            # Not a hyper-parameter token, e.g. part of a model name.
            continue
        if name in keys:
            # Explicit raise instead of `assert`, so the check survives `python -O`.
            raise ValueError(
                f"duplicate hyper-parameter {name!r} in job name {job_name!r}"
            )
        keys[name] = value
    return keys


def parse(job_name, args):
    """Extract the two plotted hyper-parameters from a job name.

    Args:
        job_name: name of the job, in the legacy layout when
            ``args.old_parser`` is true, otherwise in the ``_``-separated
            ``<name><number>`` layout, for example
            ``run_mgn1_bsz32_lr3e-6_epoch3_wmu0.1_wd0.01_ld1_d0.1_c0.1``.
        args: object providing ``old_parser``, ``x_key`` and ``y_key``.

    Returns:
        ``(x_value, y_value)``: the values stored under ``args.x_key`` and
        ``args.y_key``. Values are floats, except the legacy ``batch_size``
        which is an ``int``.

    Raises:
        ValueError: if a legacy name does not match the expected layout, a
            legacy value is not numeric, or a hyper-parameter name appears
            twice in a new-style name.
        KeyError: if ``args.x_key`` or ``args.y_key`` is not present in the
            job name; the message lists the keys that were found.
    """
    if args.old_parser:
        raw_keys = _parse_old_format(job_name)
    else:
        raw_keys = _parse_new_format(job_name)

    # The legacy decoder yields substrings; normalise them to floats.
    all_keys = {
        key: float(value) if isinstance(value, str) else value
        for key, value in raw_keys.items()
    }

    missing = [key for key in (args.x_key, args.y_key) if key not in all_keys]
    if missing:
        raise KeyError(
            f"{missing} not found in job name {job_name!r}; "
            f"available keys: {sorted(all_keys)}"
        )
    return all_keys[args.x_key], all_keys[args.y_key]


def main():
    """Load the results file, average each job's scores and plot them.

    The JSON file named by ``--results_json`` is indexed by job name; each
    entry is iterated as a collection of scores. For every job the x/y
    coordinates are decoded from its name via ``parse``. When
    ``--pass_threshold`` is given, only scores strictly greater than it are
    kept. Jobs left without any score are skipped; the rest contribute the
    mean of their scores as the z coordinate.
    """
    args = get_args()
    with open(args.results_json, mode="r", encoding="utf-8") as handle:
        results = json.load(handle)

    threshold = args.pass_threshold
    xs, ys, means = [], [], []
    for job_name in results:
        # Decode the coordinates first, so a malformed name fails early.
        coord_x, coord_y = parse(job_name, args)

        scores = results[job_name]
        if threshold is not None:
            scores = [score for score in scores if score > threshold]

        if len(scores) > 0:
            average = sum(scores) / len(scores)
            means.append(average)
            xs.append(coord_x)
            ys.append(coord_y)

    visualization_3d(xs, ys, means, args)


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
