import argparse
import pytorch_lightning as pl
import os
import sys
sys.path.append('src')
from models.nli_model import NLIModel
from models.multnat_model import MultNatModel
from train.utils import add_dataset_specific_args, load_custom_data

# Registry of the model classes selectable through the positional ``model``
# command-line argument; keys are the names accepted on the command line.
str2model = dict(
    NLI=NLIModel,
    MultNat=MultNatModel,
)

def make_suite(ranges, data_path, min_precision, stem, pr_rec_stem=None):
    """Build one evaluation-suite record.

    Args:
        ranges: ``--range`` values that select this suite.
        data_path: test file handed to ``load_custom_data``.
        min_precision: value passed to ``model.set_minimum_precision``.
        stem: file-name stem of the score output (``<stem>_Y.txt``).
        pr_rec_stem: file-name stem of the precision/recall curve output
            (``<pr_rec_stem>_pr_rec.txt``); defaults to ``stem``.

    Returns:
        Tuple ``(ranges, data_path, min_precision, pr_rec_file, score_file)``.
    """
    if pr_rec_stem is None:
        pr_rec_stem = stem
    return (tuple(ranges), data_path, min_precision,
            pr_rec_stem + '_pr_rec.txt', stem + '_Y.txt')


_EN = ('all', 'en')
_ZH = ('all', 'zh', 'zhraw')
_NC_ALL = 'nc_qaeval_all'
_CLUE_ALL = 'clue_qaeval_all'

# Every evaluation set this script can run, in execution order.
TEST_SUITES = (
    make_suite(_EN, '../datasets/data_en_levyholt/full/test.txt',
               0.21910, 'en_levyholt_test'),  # 2831 / 12921
    make_suite(_EN, '../datasets/data_en_sherliic/levy_holt/test.txt',
               0.33255, 'en_sherliic_test'),  # 994 / 2989
    make_suite(_ZH, '../datasets/data_zh_levyholt/full/test.txt',
               0.21910, 'zh_levyholt_raw_test'),  # 2831 / 12921
    make_suite(_ZH, '../datasets/data_sherliic_raw_raw/levy_holt/test.txt',
               0.33255, 'zh_sherliic_raw_test'),
    make_suite(_EN, '../datasets/data_en_levyholt/directional/test.txt',
               0.5, 'en_levyholt_dir_test'),
    make_suite(_ZH, '../datasets/data_zh_levyholt/directional/test.txt',
               0.5, 'zh_levyholt_raw_dir_test'),
    make_suite(_EN, '../datasets/data_en_levyholt/symmetric/test.txt',
               0.1741, 'en_levyholt_sym_test'),
    make_suite(_ZH, '../datasets/data_zh_levyholt/symmetric/test.txt',
               0.1741, 'zh_levyholt_raw_sym_test'),

    make_suite(['dirsym_posi'], '../datasets/data_en_levyholt/dirsym_posi/test.txt',
               0.3500, 'en_levyholt_dirsym_posi_test'),
    make_suite(['dirsym_negi'], '../datasets/data_en_levyholt/dirsym_negi/test.txt',
               0.091, 'en_levyholt_dirsym_negi_test'),
    make_suite(['dirsym_cross'], '../datasets/data_en_levyholt/dirposi_symnegi/test.txt',
               0.0884, 'en_levyholt_dirposi_symnegi_test'),
    make_suite(['dirsym_cross'], '../datasets/data_en_levyholt/dirnegi_symposi/test.txt',
               0.6849, 'en_levyholt_dirnegi_symposi_test'),

    make_suite(['hypo_only'], '../datasets/data_en_levyholt/directional_hypo_only/test.txt',
               0.5, 'endir_levyholt_hypoonly_test'),
    make_suite(['hypo_only'], '../datasets/data_en_levyholt/full_hypo_only/test.txt',
               0.2191, 'en_levyholt_hypoonly_test'),
    make_suite(['hypo_only'], '../datasets/data_en_levyholt/symmetric_hypo_only/test.txt',
               0.1741, 'ensym_levyholt_hypoonly_test'),

    make_suite(['zh_hypoonly'], '../datasets/data_zh_levyholt/directional_hypo_only/test.txt',
               0.5, 'zhdir_levyholt_hypoonly_test'),
    make_suite(['zh_hypoonly'], '../datasets/data_zh_levyholt/full_hypo_only/test.txt',
               0.21910, 'zh_levyholt_hypoonly_test'),
    make_suite(['zh_hypoonly'], '../datasets/data_zh_levyholt/symmetric_hypo_only/test.txt',
               0.1741, 'zhsym_levyholt_hypoonly_test'),

    make_suite(['hypo_only_dirsym'], '../datasets/data_en_levyholt/dirsym_posi_hypo_only/test.txt',
               0.3500, 'en_dirsym_posi_hypoonly_levyholt_test'),
    make_suite(['hypo_only_dirsym'], '../datasets/data_en_levyholt/dirsym_negi_hypo_only/test.txt',
               0.091, 'en_dirsym_negi_hypoonly_levyholt_test'),
    make_suite(['hypo_only_dirsym_cross'],
               '../datasets/data_en_levyholt/dirposi_symnegi_hypo_only/test.txt',
               0.0884, 'en_dirposi_symnegi_hypoonly_levyholt_test'),
    make_suite(['hypo_only_dirsym_cross'],
               '../datasets/data_en_levyholt/dirnegi_symposi_hypo_only/test.txt',
               0.6849, 'en_dirnegi_symposi_hypoonly_levyholt_test'),

    make_suite(['qaeval_hyponly_mckenna', _NC_ALL],
               '../datasets/mckenna_dataset/hypoonly_typearg_lhsize/test.txt',
               0.5425, 'en_qaeval_mckenna_hyponly_typearg_test'),
    make_suite(['qaeval_hyponly_namearg_freqmap', _NC_ALL],
               '../datasets/booqa_en/hypoonly_namearg_lhsize/test.txt',
               0.3627, 'booqa_en_hyponly_test'),
    make_suite(['qaeval_hyponly_typearg_freqmap', _NC_ALL],
               '../datasets/booqa_en/hypoonly_typearg_lhsize/test.txt',
               0.3581, 'booqa_en_hyponly_typearg_test'),
    # NOTE: the pr/rec file name of this suite has historically lacked the
    # "hyponly" part that its score file carries; kept as-is so previously
    # written results and downstream tooling keep matching.
    make_suite(['clue_qaeval_hyponly_namearg_freqmap', _CLUE_ALL],
               '../datasets/booqa_zh/hypoonly_namearg_lhsize/test.txt',
               0.3737, 'booqa_zh_hyponly_namearg_test',
               pr_rec_stem='booqa_zh_namearg_test'),
    make_suite(['clue_qaeval_hyponly_typearg_freqmap', _CLUE_ALL],
               '../datasets/booqa_zh/hypoonly_typearg_lhsize/test.txt',
               0.3777, 'booqa_zh_hyponly_typearg_test'),
)

# Range names that are accepted for backward compatibility although no
# evaluation set is registered for them (selecting one runs nothing).
LEGACY_NOOP_RANGES = ('hypo_only_oneposi_allnegi',)


def collect_valid_ranges(suites, extra=()):
    """Return every distinct range name used by ``suites`` plus ``extra``.

    Names are listed in first-appearance order so that help and error
    messages are stable.
    """
    names = []
    for suite in suites:
        for name in suite[0]:
            if name not in names:
                names.append(name)
    for name in extra:
        if name not in names:
            names.append(name)
    return tuple(names)


# Derived from the table so that the accepted values can never drift away
# from the suites that actually exist.
VALID_RANGES = collect_valid_ranges(TEST_SUITES, LEGACY_NOOP_RANGES)


def select_suites(range_name):
    """Return the suites selected by ``range_name``, in execution order."""
    return [suite for suite in TEST_SUITES if range_name in suite[0]]


def run_suite(args, model, ckpt_root, suite):
    """Evaluate ``model`` on one suite, writing outputs under ``ckpt_root``.

    Args:
        args: parsed command-line namespace (forwarded to ``load_custom_data``;
            ``args.model`` and ``args.gpus`` are read here).
        model: the loaded model to test.
        ckpt_root: directory that receives the pr/rec curve and score files.
        suite: a record produced by ``make_suite``.
    """
    _, data_path, min_precision, pr_rec_file, score_file = suite
    print(f"Testing on: '{data_path}'")
    model.set_minimum_precision(min_precision)
    model.set_pr_rec_curve_path(os.path.join(ckpt_root, pr_rec_file))
    model.set_score_outfile(os.path.join(ckpt_root, score_file))
    dataloader = load_custom_data(args, args.model, model, data_path)
    # A fresh trainer per suite; both locals are released when this returns.
    trainer = pl.Trainer(gpus=args.gpus, logger=False)
    trainer.test(model, test_dataloaders=dataloader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model', choices=str2model.keys())
    parser.add_argument('checkpoint')
    parser.add_argument('--classification_threshold', type=float, default=None)
    parser.add_argument('--gpus', type=int, nargs='+', default=[])
    parser.add_argument('--range', type=str, default='all',
                        help='[' + ' / '.join(VALID_RANGES) + ']')
    parser.add_argument('--is_directional', action='store_true')
    add_dataset_specific_args(parser)
    args = parser.parse_args()
    print(args)

    # An explicitly empty --range means "all".
    if not args.range:
        args.range = 'all'
    # Explicit validation instead of `assert`, which is skipped under -O.
    if args.range not in VALID_RANGES:
        parser.error(f"invalid --range {args.range!r}; "
                     f"expected one of: {', '.join(VALID_RANGES)}")

    cls = str2model[args.model]
    model = cls.load_from_checkpoint(args.checkpoint)
    print("Model loaded.")

    if args.classification_threshold is not None:
        model.set_classification_threshold(args.classification_threshold)

    # Outputs are written next to the checkpoint file.
    ckpt_root = os.path.dirname(args.checkpoint)

    selected = select_suites(args.range)
    if not selected:
        print(f"No test sets are registered for range '{args.range}'; nothing to run.")
    for suite in selected:
        run_suite(args, model, ckpt_root, suite)
