"""Run the Hierarchical-Table-to-Text task.
1. train: train a model, using optimizer loss
2. valid & test (w/o logging): do generation (optionally write to file), calc the metrics.
"""

import os
import argparse

from utils import get_dataset_path, get_run_dir
from experiment.train import TrainFunctionDict
from experiment.test import TestFunctionDict


# Per-experiment argument presets, keyed by experiment name.
# Every preset provides the tokenizer/model identifiers together with the
# train/eval batch sizes, learning rate, and number of training epochs that
# main() copies onto the parsed arguments when --rewrite_with_suite is set.
ArgumentSuite = dict(
    t5=dict(
        tokenizer_name='t5-base',
        model_name='t5-base',
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=1e-4,
        num_train_epochs=20,
    ),
    bart=dict(
        tokenizer_name='facebook/bart-base',
        model_name='facebook/bart-base',
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        learning_rate=1e-4,
        num_train_epochs=20,
    ),
    b2b=dict(
        tokenizer_name='bert-base-uncased',
        model_name='bert-base-uncased',
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        learning_rate=3e-5,
        num_train_epochs=50,
    ),
    # The 'pg' preset names no pretrained tokenizer/model.
    pg=dict(
        tokenizer_name=None,
        model_name=None,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        learning_rate=0.05,
        num_train_epochs=100,
    ),
)


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every option of this script."""
    parser = argparse.ArgumentParser()

    # i/o
    parser.add_argument('--dataset_subdir', type=str, default='../data')
    parser.add_argument('--split_method', type=str, default='table',
        choices=['subsent', 'table', 'page', 'domain', 'website'])
    parser.add_argument('--link_method', type=str, default='formula', choices=['raw', 'formula'])
    parser.add_argument('--serial_method', type=str, default='concat', choices=['concat', 'pair'])
    parser.add_argument('--dataset_name', type=str, default='both0513')
    parser.add_argument('--vocab_path', type=str, default='./experiment/pointer_generator/vocab')
    parser.add_argument('--vocab_size', type=int, default=30000)

    parser.add_argument('--train_filename', type=str, default='train.json')
    parser.add_argument('--valid_filename', type=str, default='valid.json')
    parser.add_argument('--test_filename', type=str, default='test.json')

    parser.add_argument('--run_subdir', type=str, default='runs')
    parser.add_argument('--log_subdir', type=str, default='logs',
        help='Logging file (multiple predictions and single references) of the testset.')

    # model
    parser.add_argument('--experiment_name', type=str, default='b2b', choices=['t5', 'bart', 'b2b', 'pg'])
    parser.add_argument('--tokenizer_name', type=str, default='t5-base',
        choices=['t5-base', 'facebook/bart-base', 'bert-base-uncased'])
    # 't5-base' has to be an allowed choice: it is both the default and the
    # value used by the 't5' suite. argparse validates only explicitly passed
    # values against `choices`, so without it `--model_name t5-base` was
    # rejected even though the default silently worked.
    parser.add_argument('--model_name', type=str, default='t5-base',
        choices=['t5-base', 't5-small', 'facebook/bart-base', 'bert-base-uncased'])
    parser.add_argument('--model_path', type=str, default=None)

    # hyper params
    parser.add_argument('--per_device_train_batch_size', type=int, default=8)
    parser.add_argument('--per_device_eval_batch_size', type=int, default=8)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--warmup_steps', type=int, default=100)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    parser.add_argument('--start_iepoch', type=int, default=0, help='Index of the starting epoch.')
    parser.add_argument('--num_train_epochs', type=int, default=5, help='Number of epochs for continual tuning.')
    parser.add_argument('--num_eval_epochs', type=int, default=1, help='Number of epochs per validation.')
    parser.add_argument('--num_save_model_epochs', type=int, default=1, help='Number of epochs to save model ckpt.')

    parser.add_argument('--num_beams', type=int, default=5, help='Number of the searching beam size for sequence generation.')
    parser.add_argument('--input_maxlen', type=int, default=512, help='Max number of tokens of input sequences.')
    parser.add_argument('--decode_maxlen', type=int, default=60, help='Max number of tokens of generated sequnces.')
    parser.add_argument('--num_return_sequences', type=int, default=5, help='Number of generated sentences for comparison.')

    parser.add_argument('--logging_steps', type=int, default=100)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=47, help='Random seed for seq2seq training arguments.')

    # command
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_eval', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--do_decode', action='store_true')
    parser.add_argument('--test_decode_name', type=str, default='decoded_test.log')
    parser.add_argument('--rewrite_with_suite', action='store_true')

    parser.add_argument('--metrics', type=str, default='bleu')   # 'bleu,parent'

    return parser


def _apply_argument_suite(args: argparse.Namespace) -> None:
    """Overwrite model and hyper-parameter arguments with the experiment preset.

    Looks up ``ArgumentSuite[args.experiment_name]`` and copies its tokenizer
    name, model name, batch sizes, learning rate, and epoch count onto
    ``args`` in place. A notice is printed whenever the tokenizer or model
    name actually changes.
    """
    exp_suite = ArgumentSuite[args.experiment_name]
    if args.tokenizer_name != exp_suite['tokenizer_name']:
        print(f"align tokenizer-name from {args.tokenizer_name} to {exp_suite['tokenizer_name']}")
        args.tokenizer_name = exp_suite['tokenizer_name']
    if args.model_name != exp_suite['model_name']:
        print(f"align model-name from {args.model_name} to {exp_suite['model_name']}")
        args.model_name = exp_suite['model_name']
    args.per_device_train_batch_size = exp_suite['per_device_train_batch_size']
    args.per_device_eval_batch_size = exp_suite['per_device_eval_batch_size']
    args.learning_rate = exp_suite['learning_rate']
    args.num_train_epochs = exp_suite['num_train_epochs']


def main(argv=None):
    """Parse the command line and dispatch to the train and/or test routine.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, so a plain ``main()`` call
            behaves as a normal script invocation.

    The parsed namespace is enriched with derived paths (``dataset_dir``,
    ``train_outpath``, ``valid_outpath``, ``test_outpath``, ``run_dir``) and
    ``metrics`` is converted from a comma-separated string to a list. With
    ``--do_train`` or ``--do_eval`` the ``TrainFunctionDict`` entry for the
    experiment is called; with ``--do_test`` or ``--do_decode`` the
    ``TestFunctionDict`` entry is called. If no command flag is given,
    nothing is run.
    """
    args = _build_parser().parse_args(argv)

    # Resolve the dataset directory and the three split files inside it.
    args.dataset_dir = get_dataset_path(args)
    args.train_outpath = os.path.join(args.dataset_dir, args.train_filename)
    args.valid_outpath = os.path.join(args.dataset_dir, args.valid_filename)
    args.test_outpath = os.path.join(args.dataset_dir, args.test_filename)
    if args.experiment_name == 'pg':
        # Extra attribute names set only for the 'pg' experiment; presumably
        # these are what the pointer-generator code reads (confirm there).
        args.train_data_path = args.train_outpath
        args.eval_data_path = args.valid_outpath
        args.decode_data_path = args.test_outpath
        args.latest_model_path = args.model_path
        args.train_sleep_time = 15

    print(f'experiment with name [{args.experiment_name}]')
    if args.rewrite_with_suite:
        _apply_argument_suite(args)

    # The run directory is resolved after the suite has been applied, so it
    # sees the final argument values.
    args.run_dir = get_run_dir(args)
    print(f'out to running directory: {args.run_dir}')

    # Drop blank entries so input such as 'bleu,' or 'bleu, ,parent' does not
    # produce an empty metric name.
    args.metrics = [m.strip() for m in args.metrics.split(',') if m.strip()]
    print(f'do evaluation with metrics: {args.metrics}')

    if args.do_train or args.do_eval:
        TrainFunctionDict[args.experiment_name](args)

    if args.do_test or args.do_decode:
        TestFunctionDict[args.experiment_name](args)


# Invoke the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
