import os
import json
import glob
import argparse
import collections

import numpy as np


def get_results(perf, keys):
    """Follow a path of keys into a nested container and return the leaf.

    Args:
        perf: Nested structure (anything supporting ``obj[key]``) to descend
            into.
        keys: Sequence of keys/indices applied one after another, outermost
            first. An empty sequence returns ``perf`` itself.

    Returns:
        The value reached after indexing ``perf`` with every entry of
        ``keys`` in order.
    """
    node = perf
    for step in keys:
        # Each step narrows the view one level deeper into the structure.
        node = node[step]
    return node


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", default=None, type=str, required=True,
                        help="Glob pattern matching the result JSON files.")
    parser.add_argument("--merge", action='store_true',
                        help="Group runs by their directory name with the last "
                             "'_'-separated field removed, instead of pooling "
                             "every run into a single 'all' group.")
    parser.add_argument("--max", action='store_true', help="Only use the max score for each group. ")
    parser.add_argument("--x", type=float, default=1.0,
                        help="Multiplier applied to every score.")
    parser.add_argument("--ban", type=str, default=None,
                        help="Semicolon-separated substrings; runs whose name "
                             "contains any of them are skipped.")
    parser.add_argument("--key", default="best_f1", type=str,
                        help="Semicolon-separated key path to the score inside "
                             "each result JSON.")
    parser.add_argument("--results_json", default=None, type=str,
                        help="Optional path to write the grouped scores as JSON.")
    return parser.parse_args(argv)


def _best_score(perf, key_path, scale):
    """Return the scaled score of one result file, or None if it has no entries.

    A list is treated as one entry per epoch and the highest scaled score is
    returned; any other value is treated as a single entry.
    """
    if isinstance(perf, list):
        scores = [scale * get_results(epoch, key_path) for epoch in perf]
        return max(scores) if scores else None
    return scale * get_results(perf, key_path)


def _print_summary(result, use_max):
    """Print per-group statistics and return ``(best_group, best_mean)``."""
    best_ckpt = None
    best_mean = None
    for ckpt in sorted(result):
        perf = result[ckpt]
        if use_max:
            perf = [max(perf)]
        mean = sum(perf) / len(perf)
        formatted = " ".join("%.3f" % p for p in perf)
        print("CKTP = %s Ave = %.3f, Std = %.3f, Max = %.3f, Min = %.3f, all = %s" %
              (ckpt, mean, np.std(perf), max(perf), min(perf), formatted))

        if best_mean is None or mean > best_mean:
            best_mean = mean
            best_ckpt = ckpt
    return best_ckpt, best_mean


def main(argv=None):
    """Aggregate scores from result JSON files and report per-group statistics.

    Every file matched by ``--input`` is assigned to a group (derived from the
    name of its parent directory with ``--merge``, otherwise the single group
    ``"all"``), its score is read via the ``--key`` path and scaled by ``--x``,
    and mean/std/max/min are printed for each group.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        dict mapping each group name to the list of scores collected for it.
    """
    args = _parse_args(argv)

    result_files = glob.glob(args.input)

    print("Result files:")
    for r in result_files:
        print(r)

    # Drop empty entries (e.g. from a trailing ';'): the empty string is a
    # substring of every name and would silently ban all runs.
    ban_list = [ban for ban in args.ban.split(";") if ban] if args.ban is not None else []
    # The key path does not depend on the file, so split it only once.
    key_path = args.key.split(';')
    result = collections.defaultdict(list)

    for result_file in result_files:
        # basename/dirname instead of splitting on '/' so that paths using
        # the platform's own separator are handled as well.
        run_name = os.path.basename(os.path.dirname(result_file))
        if args.merge:
            ckpt = "_".join(run_name.split('_')[:-1])
            ban_target = ckpt
        else:
            ckpt = "all"
            # Match bans against the run directory name; matching against the
            # constant label "all" would make --ban meaningless in this mode.
            ban_target = run_name
        if any(ban in ban_target for ban in ban_list):
            continue  # Banned runs are skipped without opening the file.

        with open(result_file, mode="r", encoding="utf-8") as reader:
            perf = json.load(reader)

        score = _best_score(perf, key_path, args.x)
        if score is None:
            # An empty per-epoch list has no score; storing None would crash
            # the statistics below.
            print("Warning: no entries in %s, skipping." % result_file)
            continue
        result[ckpt].append(score)

    best_ckpt, best_mean = _print_summary(result, args.max)

    print("Best ckpt = {} and perf = {}".format(best_ckpt, best_mean))

    if args.results_json:
        with open(args.results_json, mode="w", encoding="utf-8") as writer:
            writer.write(json.dumps(result, indent=2))

    return dict(result)


if __name__ == '__main__':
    main()
