import argparse
import sys
import os
import numpy as np
np.seterr(divide='ignore', invalid='ignore')  

import multiprocessing as mp

from itertools import chain

from makePolarEmbedding import has_edge



def eval(tuples):
    """Score a single threshold candidate ``c`` on the validation pairs.

    Written as a ``Pool.map`` worker, so every input arrives packed in one
    tuple. NOTE: this module-level name shadows the built-in ``eval``; it is
    kept because the grid search hands this function to ``Pool.map`` by name.

    Args:
        tuples: ``(angles, rs, vx1s, vx2s, vy, c)`` -- ``angles``, ``rs``,
            ``vx1s``, ``vx2s`` and ``c`` are forwarded unchanged to
            ``has_edge``; ``vy`` holds the 0/1 label of each pair.

    Returns:
        ``(c, f1)`` where ``f1`` is the F1 score of ``has_edge``'s predictions
        against ``vy``, replaced by ``0.0`` whenever it evaluates to NaN.
    """
    angles, rs, vx1s, vx2s, vy, c = tuples
    predicted = has_edge(angles, rs, vx1s, vx2s, c)

    # Predictions that fall on label-1 pairs are the true positives
    # (assumes `predicted` is a boolean / 0-1 array -- TODO confirm has_edge).
    true_pos = np.sum(predicted[vy == 1])

    # Zero denominators give NaN silently because of the module-level
    # np.seterr(divide='ignore', invalid='ignore'); the NaN is mapped to 0.0.
    prec = true_pos / np.sum(predicted)
    rec = true_pos / np.sum(vy)
    f1 = 2 * (prec * rec) / (prec + rec)

    return (c, 0.0 if np.isnan(f1) else f1)


def _read_labelled_pairs(pos_path, neg_path):
    """Read a positive and a negative pair file into parallel numpy arrays.

    Every non-blank line is split on whitespace and parsed as integers; the
    three-way unpack below expects each line to hold exactly two ids
    (TODO confirm the file format). Lines from *pos_path* are labelled 1 and
    lines from *neg_path* are labelled 0, positives first.

    Returns:
        ``(x1s, x2s, ys)`` as numpy arrays of equal length.

    Raises:
        ValueError: if neither file contains any sample.
    """
    rows = []
    for path, label in ((pos_path, 1), (neg_path, 0)):
        with open(path) as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    # A blank (e.g. trailing) line would otherwise become a
                    # one-element row and break the zip-unpack below.
                    continue
                rows.append([int(v) for v in fields] + [label])
    if not rows:
        raise ValueError("no samples in {} / {}".format(pos_path, neg_path))
    x1s, x2s, ys = zip(*rows)
    return np.array(x1s), np.array(x2s), np.array(ys)


def _precision_recall_f1(pred, y):
    """Return ``(precision, recall, f1)`` of predictions *pred* vs 0/1 labels *y*.

    Zero denominators yield NaN precision/recall silently (module-level
    ``np.seterr``); a NaN F1 is reported as 0.0, the same convention the
    validation worker ``eval`` uses.
    """
    true_pos = np.sum(pred[y == 1])
    precision = true_pos / np.sum(pred)
    recall = true_pos / np.sum(y)
    f = 2 * (precision * recall) / (precision + recall)
    if np.isnan(f):
        f = 0.0
    return precision, recall, f


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--rfile')
    opt = parser.parse_args()

    # Candidate thresholds c = 0.1, 0.2, ..., 2.0.
    C_GRID = [i / 10 for i in range(1, 21)]
    # Number of values --rfile must provide; presumably the number of
    # embedded nodes in the model -- TODO confirm against the npz contents.
    EXPECTED_RFILE_LEN = 82114

    # Model load (arr_0 -> rs, arr_1 -> angles). The context manager closes
    # the underlying zip file once both arrays have been read into memory.
    # NOTE(review): allow_pickle=True can run arbitrary code from a crafted
    # file -- only pass trusted model files via --model.
    with np.load(opt.model, allow_pickle=True) as npz:
        rs = npz['arr_0']
        angles = npz['arr_1']

    if opt.rfile:
        # Optionally override the radii with one float per non-blank line.
        with open(opt.rfile) as fh:
            values = [float(line) for line in fh if line.strip()]
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(values) != EXPECTED_RFILE_LEN:
            parser.error("--rfile must contain {} values, got {}".format(
                EXPECTED_RFILE_LEN, len(values)))
        rs = np.array(values)

    basedir = "data/noun"

    # Grid search on the validation split, candidates evaluated in parallel.
    validfile_pos = os.path.join(basedir, "noun_closure.tsv.valid")
    vx1s, vx2s, vy = _read_labelled_pairs(validfile_pos, validfile_pos + "_neg")

    tasks = [(angles, rs, vx1s, vx2s, vy, c) for c in C_GRID]
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.map(eval, tasks)  # list of (c, f1), in C_GRID order

    # max() returns the first of equally-scored candidates (smallest c),
    # the same choice the previous stable descending sort made.
    param, f = max(results, key=lambda r: r[1])
    print("Best parameters:")
    print("c : ", param)
    print("F1: {:.3f}".format(f))
    print()

    # Run the selected threshold on the test split.
    testfile_pos = os.path.join(basedir, "noun_closure.tsv.test")
    tx1s, tx2s, ty = _read_labelled_pairs(testfile_pos, testfile_pos + "_neg")

    pred = has_edge(angles, rs, tx1s, tx2s, param)
    precision, recall, f = _precision_recall_f1(pred, ty)

    print("Test score")
    print("Precision: {:.3f}".format(precision))
    print("Recall   : {:.3f}".format(recall))
    print("F1       : {:.3f}".format(f))
