import argparse
import time
import sys
import os

sys.path.append("..")
from tqdm import tqdm
from MassConfig.ConfigBase import *
from Trainer import product_trainer

def config_parser():
    parser = argparse.ArgumentParser()
    # hyper params
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--lr_bert", type=float, default=5e-6)
    parser.add_argument("--lr_crf", type=float, default=1e-2)
    parser.add_argument("--loss_alpha", type=float, default=0.5)
    parser.add_argument("--LEN_ALL_TAGS", type=int, default=len(ALL_TAGS))
    parser.add_argument("--HIDDEN_DIM", type=int, default=HIDDEN_DIM)
    parser.add_argument("--weight_loss", dest="weight_loss", action="store_true")

    # model config
    parser.add_argument("--trainer_name", type=str, default="sim_trainer")
    parser.add_argument("--model_name", type=str, default="MyCrossLingualMain")
    parser.add_argument("--criterion_name", type=str, default="BCELoss")
    parser.add_argument("--optim_name", type=str, default="Adam")
    parser.add_argument("--top_rnn", dest="top_rnn", action="store_true")
    parser.add_argument("--rnn_name", type=str, default="RNN")
    parser.add_argument("--top_crf", dest="top_crf", action="store_true")
    parser.add_argument("--left_isner", dest="left_isner", action="store_true")
    parser.add_argument("--low_resource",  dest="low_resource", action="store_true")
    parser.add_argument("--n_clusters", type=int, default=9)
    parser.add_argument("--only_siamese", dest="only_siamese", action="store_true")
    parser.add_argument("--self_learn", dest="self_learn", action="store_true")
    parser.add_argument("--add_sim", dest="add_sim", action="store_true")
    parser.add_argument("--tgt_lang", type=str, default="es")
    parser.add_argument("--add_cluster", dest="add_cluster", action="store_true")
    parser.add_argument("--add_visual", dest="add_visual", action="store_true")
    parser.add_argument("--knn_k", type=int, default=9)
    parser.add_argument("--add_spc", dest="add_spc", action="store_true")
    parser.add_argument("--spc_tem", type=float, default=0.07)
    parser.add_argument("--fine_tune", dest="fine_tune", action="store_true")
    parser.add_argument("--new_num", type=str, default="")


    # data config
    parser.add_argument("--dataloader_name", type=str, default="SiameseLoader")
    parser.add_argument("--trainset", type=str, default="../data/CoNLL2003/train.txt")
    parser.add_argument("--validset", type=str, default="../data/CoNLL2003/valid.txt")
    parser.add_argument("--testset", type=str, default="../data/CoNLL2003/test.txt")
    parser.add_argument("--trainset_params", type=str, default="200_T_T_F")
    parser.add_argument("--tgt_trainset_params", type=str, default="200_T_T_F")
    parser.add_argument("--validset_params", type=str, default="200_T_F_F")
    parser.add_argument("--testset_params", type=str, default="200_T_F_T")
    parser.add_argument("--validset_sim_params", type=str, default="200_T_F_F")
    parser.add_argument("--testset_sim_params", type=str, default="200_T_F_T")
    parser.add_argument("--target_trainset", type=str, default="../data/CoNLL2003/train.txt")


    return parser

def main():
    # build trainer, reload potential checkpoints / build evaluator
    trainer = product_trainer(hp)

    # training
    for i in tqdm(range(1, hp.n_epochs + 1, 1)):
        # train set : labeled data: english
        trainer.train_epoch(i)

    # end of epoch
    trainer.record_result()

if __name__ == '__main__':
    # parse parameters
    parser = config_parser()
    hp = parser.parse_args()

    # check parameters

    # run experiment
    main()


