import sys
sys.path.append('../')

from emb2emb.train import get_train_parser, train
import csv, os

def expand_per_item(values, n_items, flag):
    """Return exactly ``n_items`` values taken from a CLI list argument.

    A single value is broadcast to every item (so ``--epochs 5`` works with
    several ``--tasks``). A longer list is truncated to ``n_items``; extra
    values are ignored, as the original per-index lookup did.

    Args:
        values: list parsed from a ``nargs="+"`` (or ``nargs="*"``) option.
        n_items: number of values required.
        flag: option name without leading dashes, used in the error message.

    Returns:
        A new list of length ``n_items``.

    Raises:
        ValueError: if ``values`` has neither 1 nor at least ``n_items`` entries.
    """
    if len(values) == 1:
        return list(values) * n_items
    if len(values) < n_items:
        raise ValueError(
            "--{} expects 1 or at least {} values, got {}".format(
                flag, n_items, len(values)))
    return list(values[:n_items])


def select_best(results_list):
    """Return ``(dev, test)`` of the run with the highest ``"dev"`` score.

    Ties keep the earliest run. An empty list yields ``(-inf, -inf)``.
    """
    best_dev, best_test = -float("inf"), -float("inf")
    for result in results_list:
        if result["dev"] > best_dev:
            best_dev, best_test = result["dev"], result["test"]
    return best_dev, best_test


def append_results_row(csv_file, fields, row):
    """Append ``row`` to ``csv_file``, writing a header if the file is new or empty."""
    # An existing-but-empty file also needs a header, otherwise the columns are unnamed.
    write_header = not os.path.exists(csv_file) or os.path.getsize(csv_file) == 0
    # newline="" is required by the csv module to avoid spurious blank rows.
    with open(csv_file, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fields)
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def main():
    """Train every (task, lambda) combination and record the best scores.

    For each task, ``train`` is run once per lambda. The run with the highest
    dev score supplies that task's ``<task>_dev`` / ``<task>_test`` columns.
    One row (plus the model directory's basename) is appended to ``--csv_file``.
    """
    parser = get_train_parser()
    parser.add_argument("--csv_file", type=str, default="out.csv")
    parser.add_argument("--lambdas", nargs = "+", type=float, default=[0.5])
    parser.add_argument("--ts", nargs = "+", type=float, default=[])
    parser.add_argument("--root", type=str, help="root folder of data directory", default = "../data/")
    parser.add_argument("--tasks", nargs="+", type=str, default = ["yelp"])
    parser.add_argument("--epochs", nargs="+", type=int, default = [5])
    parser.add_argument("--val_freqs", nargs="+",type=int, default = [-1])
    params, unknown = parser.parse_known_args()

    if unknown:
        raise ValueError("Got unknown parameters " + str(unknown))

    tasks = params.tasks
    lambdas = params.lambdas
    root_folder = params.root
    # Validate/broadcast all per-task and per-lambda lists before any (slow) training.
    epochs = expand_per_item(params.epochs, len(tasks), "epochs")
    val_frequencies = expand_per_item(params.val_freqs, len(tasks), "val_freqs")
    # Thresholds pair with lambdas; an empty list disables FGIM entirely.
    thresholds = (expand_per_item(params.ts, len(lambdas), "ts")
                  if params.ts else [])

    # execute each task
    final_results = {}
    for task, n_epochs, val_freq in zip(tasks, epochs, val_frequencies):
        params.dataset_path = os.path.join(root_folder, task)
        params.n_epochs = n_epochs
        params.validation_frequency = val_freq
        params.real_data_path = os.path.join(root_folder, task, "all_train")

        results_list = []
        for lambda_idx, lam in enumerate(lambdas):
            params.lambda_clfloss = lam
            if thresholds:
                params.fast_gradient_iterative_modification = True
                params.fgim_threshold = thresholds[lambda_idx]
            results = train(params)
            print("Task {}; lambda {}; dev-score {}".format(task, lam, results["dev"]))
            results_list.append(results)

        best_dev, best_test = select_best(results_list)
        final_results[task + "_dev"] = best_dev
        final_results[task + "_test"] = best_test

    fields = ['model'] + [t + "_dev" for t in tasks] + [t + "_test" for t in tasks]
    final_results["model"] = os.path.basename(params.modeldir)
    append_results_row(params.csv_file, fields, final_results)


if __name__ == "__main__":
    main()
