"""Training Functions."""

import os

from experiment.utils.tokenizer import prepare_tokenizer
from experiment.utils.data import get_dataset
from experiment.utils.arguments import prepare_training_arguments
from experiment.utils.model import ModelPrepareDict
from experiment.utils.metrics import MetricsBuildDict

from transformers import Seq2SeqTrainer
from experiment.pointer_generator.trainer import Trainer as PgnTrainer
from experiment.pointer_generator.decode import BeamSearch


# %% huggingface-supported models
# t5, bart, bert-to-bert

def run_train(args):
    """Fine-tune a HuggingFace seq2seq model (t5 / bart / b2b) with Seq2SeqTrainer.

    Args:
        args: experiment namespace. Attributes read directly here are
            ``tokenizer_name``, ``experiment_name``, ``train_outpath``,
            ``valid_outpath``, ``model_name``, ``model_path``, ``device``,
            ``decode_maxlen`` and ``num_beams``; the whole namespace is also
            forwarded to ``get_dataset`` and ``prepare_training_arguments``.
    """
    tokenizer = prepare_tokenizer(name=args.tokenizer_name)

    def _load_split(split_path):
        # Both splits are built identically apart from the data file.
        return get_dataset(
            experiment_name=args.experiment_name,
            data_files=split_path,
            tokenizer=tokenizer,
            args=args,
        )

    train_split = _load_split(args.train_outpath)
    valid_split = _load_split(args.valid_outpath)

    training_arguments = prepare_training_arguments(args)
    build_model = ModelPrepareDict[args.experiment_name]
    model = build_model(args.model_name, args.model_path, args.device)

    # The evaluation metric is fixed to BLEU; ``args.metrics`` is not consulted.
    compute_metrics = MetricsBuildDict['bleu'](tokenizer)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_arguments,
        train_dataset=train_split,
        eval_dataset=valid_split,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # Generation settings are written straight onto the trainer's private attributes.
    # NOTE(review): depending on the installed transformers version,
    # Seq2SeqTrainer.evaluate() may reset _max_length/_num_beams from its own
    # arguments or the generation_* training args, so these values might not
    # reach evaluation run inside train() -- confirm.
    trainer._max_length = args.decode_maxlen
    trainer._num_beams = args.num_beams

    trainer.train()

# %% pointer-generator network

def find_latest_model_path(model_dir):
    """Find the path of the latest model checkpoint within the given directory.

    Checkpoint files are expected to be named ``<prefix>_<N>.<ext>``: the text
    before the first ``.`` is split on ``_`` and its last piece is parsed as
    the integer index ``N``. The file with the largest index wins; on ties the
    first one returned by ``os.listdir`` is kept.

    Entries whose name does not end in an integer index (e.g. ``.DS_Store``,
    log or note files) are ignored instead of raising ``ValueError``.

    Args:
        model_dir: directory to scan (not recursive).

    Returns:
        ``os.path.join(model_dir, <filename>)`` of the latest checkpoint, or
        None if the directory contains no checkpoint-like file.
    """
    latest_index = None
    latest_file = None
    for fn in os.listdir(model_dir):
        index_text = fn.split('.')[0].split('_')[-1]
        try:
            model_index = int(index_text)
        except ValueError:
            continue  # not a checkpoint file; skip rather than crash training
        # Strict '>' keeps the first file seen for a tied index.
        if latest_index is None or model_index > latest_index:
            latest_index, latest_file = model_index, fn

    if latest_file is None:
        return None
    return os.path.join(model_dir, latest_file)



def run_train_pgn(args, verbose=True):
    """Train the pointer-generator network epoch by epoch, with periodic evaluation.

    Epochs run from ``args.start_iepoch`` for ``args.num_train_epochs`` epochs.
    After each epoch the newest checkpoint in ``trainer.model_dir`` (as found by
    ``find_latest_model_path``) is stored on ``args.latest_model_path``. Every
    ``args.num_eval_epochs`` epochs that checkpoint, if any, is decoded on
    ``args.eval_data_path`` with ``BeamSearch``.

    Args:
        args: experiment namespace. Note that ``args.latest_model_path`` is
            overwritten after every epoch.
        verbose: if True (default), print progress messages. This does not
            affect any output produced by ``BeamSearch.run``.
    """
    trainer = PgnTrainer(args)
    # Prefer the most recent known checkpoint, otherwise the configured model.
    if args.latest_model_path is not None:
        model_path = args.latest_model_path
    else:
        model_path = args.model_path
    if verbose:
        print(f'run with model from [{model_path}]')

    first_epoch = args.start_iepoch
    for iepoch in range(first_epoch, first_epoch + args.num_train_epochs):
        if verbose:
            print(f'\n <<< START of the #{iepoch} EPOCH >>>')
        # (iepoch + 1) suggests iepoch is a 0-based counter, so these fire after
        # every num_eval_epochs-th / num_save_model_epochs-th completed epoch.
        do_eval = (iepoch + 1) % args.num_eval_epochs == 0
        do_save_model = (iepoch + 1) % args.num_save_model_epochs == 0
        # NOTE(review): model_path is chosen once before the loop and passed
        # unchanged every epoch; confirm run_one_epoch does not (re)load from it.
        trainer.run_one_epoch(
            iepoch=iepoch,
            model_path=model_path,
            interval=args.logging_steps,
            save_model=do_save_model,
        )
        args.latest_model_path = find_latest_model_path(trainer.model_dir)
        if do_eval and args.latest_model_path is not None:
            if verbose:
                print(f'EVAL using model [{args.latest_model_path}]')
            tester = BeamSearch(args, args.latest_model_path, args.eval_data_path)
            tester.run(args.logging_steps)
        if verbose:
            print(f' <<< END of the #{iepoch} EPOCH >>>\n')
    




# %% collection

# Experiment name -> training entry point. The HuggingFace seq2seq models
# (t5, bart, bert-to-bert) share ``run_train``; the pointer-generator
# network ('pg') uses its own epoch loop in ``run_train_pgn``.
TrainFunctionDict = dict(
    t5=run_train,
    bart=run_train,
    b2b=run_train,
    pg=run_train_pgn,
)