import argparse
import logging
import os
import json
import torch
import random
import numpy as np
import time

from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader

from dataset import VAEDataset, WikiDataset, CVAEDataset, Seq2seqDataset, WPDataset
from train import train, valid, generate, interpolating, visual_attention

# GPT2 Model
from model_add_middle import GPT2VAEAAddMiddle
from model_lmf_layerly import GPT2VAELMFLayerly
from model_lmf_last_layer import GPT2VAELMFLastLayerOnly
from model_lmf_layerly_independent import GPT2VAELMFLayerlyIndependent
from model_lmf_layerly_lastloss import GPT2VAELMFLayerlyLastLoss
from model_lmf_layerly_firstloss import GPT2VAELMFLayerlyFirstLoss
from model_embedding import GPT2VAEEmbed
from model_memory_layerly import GPT2VAEMemoryLayerly
from model_memory_last_layer import GPT2VAEMemoryLastLayerOnly
from model_softmax import GPT2VAESoftmax
from model_gpt2 import GPT2LMHeadModel

# T5 Model
from model_t5_mem import T5VAEMem

# Bart Model
from model_bart import BartForConditionalGeneration
from model_bart_embedding import BartVAEEmbed
from model_bart_memory_last import BartVAEMemoryLast
from model_bart_softmax import BartVAESoftmax
from model_bart_lmf import BartVAELMFLayerly
from model_bart_lmf_independent import BartVAELMFLayerlyIndependent
from model_bart_lmf_firstloss import BartVAELMFLayerlyFirstLoss
from model_bart_lmf_lastloss import BartVAELMFLayerlyLastLoss

from transformers import AutoConfig, AutoModel, AutoTokenizer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def get_args(argv=None):
    """Parse the command-line options of this training / evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line.  ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed options.  ``--output_dir`` and
        ``--model_name`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser()

    # Data files, model selection and output locations
    parser.add_argument("--train_file", default=None, type=str,
                        help="Training data file (json format).")
    parser.add_argument("--valid_file", default=None, type=str,
                        help="Validation data file (json format).")
    parser.add_argument("--test_file", default=None, type=str,
                        help="Test data file (json format), used with --eval / --generation.")
    # NOTE(review): the default 'vae' is not one of the choices (argparse does
    # not validate defaults), so this option has to be given explicitly for
    # the model lookup in prepare_model() to succeed.
    parser.add_argument("--model_type", type=str, default='vae',
                        choices=['bartemb', 'bartbase', 'bartsoftmax', 'bartmemorylast', 'bartlmf',
                                 'bartlmfindependent', 'bartlmflastloss', 'bartlmffirstloss',
                                 'addmiddle', 'embed', 'softmax', 'memorylayerly', 'memorylastlayer',
                                 'lmflayerly', 'lmflastlayer', 'lmfindependent', 'lmflayerlastloss', 'lmflayerfirstloss',
                                 't5-small', 'gpt2'],
                        help="Model type to use; must be one of the listed choices.")
    parser.add_argument("--pretrained_model", type=str, default='gpt2',
                        help="Pretrained model name or path used to load the tokenizer and the pretrained weights.")
    parser.add_argument("--dataset_type", type=str, default='vae', choices=['vae', 'wiki', 'cvae', 'seq2seq', 'wp'],
                        help="Dataset to use for training")
    parser.add_argument("--vocab_file", default=None, type=str,
                        help="vocab for training.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")
    parser.add_argument("--model_name", default=None, type=str, required=True,
                        help="Name of this run; used as a sub-directory of --output_dir and --logdir.")
    parser.add_argument("--config_path", default='gpt2', type=str,
                        help="Pretrained config name or path.")
    parser.add_argument("--tokenizer_path", default='gpt2', type=str,
                        help="Pretrained tokenizer name or path.")
    parser.add_argument("--pretrain_model_path", default='gpt2', type=str,
                        help="Pretrained model name or path.")
    parser.add_argument("--logdir", default='./train_logger', type=str,
                        help="The output directory where the log will be written.")
    parser.add_argument("--generation_output_dir", default='./generation_output', type=str,
                        help="The output directory where the generation results will be written.")

    # Other parameters
    parser.add_argument("--begin_epoch", default=None, type=int,
                        help="Epoch N of the checkpoint model_epoch_N.pt to load before running.")
    parser.add_argument("--epochs", default=20, type=int,
                        help="total epochs")
    parser.add_argument("--per_gpu_train_batch_size", default=16, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--learning_rate", default=1e-4, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--label_smoothing", default=0.1, type=float,
                        help="Label smoothing factor.")
    parser.add_argument("--kl_threshold", default=0, type=float,
                        help="KL threshold (copied into the model config).")
    parser.add_argument("--latent_size", default=32, type=int,
                        help="latent size")
    parser.add_argument("--latent_lmf_rank", default=4, type=int,
                        help="latent lmf rank")
    parser.add_argument("--hidden_size", default=512, type=int,
                        help="hidden size")
    parser.add_argument("--max_length", default=200, type=int,
                        help="max length")
    parser.add_argument("--begin_layer", default=None, type=int,
                        help="begin layer (copied into the model config)")
    parser.add_argument("--end_layer", default=None, type=int,
                        help="end layer (copied into the model config)")
    parser.add_argument("--add_layer", default=None, type=int,
                        help="add layer (copied into the model config)")
    parser.add_argument("--num_training_steps", default=-1, type=int,
                        help="set total number of training steps to perform")
    parser.add_argument('--num_classifier_epochs_per_iters', type=int, default=10,
                        help="Number of classifier epochs per iteration.")
    parser.add_argument('--num_generator_epochs_per_iters', type=int, default=10,
                        help="Number of generator epochs per iteration.")
    parser.add_argument('--num_iters', type=int, default=10,
                        help="Number of iterations.")
    parser.add_argument("--warmup_steps", default=3000, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--load_pretrain", action='store_true',
                        help="Initialize the model from the pretrained weights of --pretrained_model.")
    parser.add_argument("--random_prob", default=0.1, type=float,
                        help="prob to random replace a masked token")
    parser.add_argument("--keep_prob", default=0.1, type=float,
                        help="prob to keep no change for a masked token")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=1500,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--log_step', type=int, default=100,
                        help="Step interval for logging.")

    # Decoding
    parser.add_argument('--num_beams', type=int, default=10,
                        help="Beam size for searching")
    parser.add_argument('--top_k', type=int, default=-1,
                        help="Top-k value for sampling.")
    parser.add_argument('--top_p', type=float, default=0.9,
                        help="Top-p value for sampling.")
    parser.add_argument('--repetition_penalty', type=float, default=1.2,
                        help="Repetition penalty for generation.")

    # Hardware / distributed / debugging
    parser.add_argument('--model_parallel', action='store_true',
                        help="Call model.parallelize() when more than one GPU is visible.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--world_size", type=int, default=-1,
                        help="world_size for distributed training on gpus")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")

    # Run modes (see main())
    parser.add_argument('--eval', action='store_true',
                        help="Run validation on --test_file instead of training.")
    parser.add_argument('--generation', action='store_true',
                        help="Run generation on --test_file instead of training.")
    parser.add_argument('--interpolating', action='store_true',
                        help="Run interpolating instead of training.")
    parser.add_argument('--visual_attention', action='store_true',
                        help="Run visual_attention instead of training.")
    parser.add_argument('--eval_metrics', action='store_true',
                        help="Compute evaluation metrics.")
    parser.add_argument('--use_scheduler', action='store_true',
                        help="Use a learning-rate scheduler.")
    parser.add_argument('--gradient_checkpointing', action='store_true',
                        help="Enable gradient checkpointing in the model config (also disables use_cache).")

    # Latent-variable / annealing options (every latent* option is copied
    # into the model config by prepare_model()).
    parser.add_argument('--latent_lmf_on_value', action='store_true',
                        help="latent_lmf_on_value flag forwarded to the model config")
    parser.add_argument('--latent_lmf_split', action='store_true',
                        help="latent variables use lmf model")
    parser.add_argument('--cycle_annealing', action='store_true',
                        help="Use a cyclical annealing schedule.")
    parser.add_argument('--cycle_iters', type=int, default=2,
                        help="cycle_iters value for cyclical annealing")
    parser.add_argument('--sample_times', type=int, default=30,
                        help="Number of sampling times.")
    parser.add_argument('--linear_annealing', action='store_true',
                        help="Use a linear annealing schedule.")
    parser.add_argument('--loss_reduce', action='store_true',
                        help="loss_reduce flag forwarded to the model config")
    args = parser.parse_args(argv)
    return args

def prepare(args):
    """Set up the run environment and attach derived settings to ``args``.

    In order: limits torch to 3 CPU threads, optionally waits for a remote
    debugger, writes the parsed options to ``train_opt.json`` (skipped for
    ``--eval`` / ``--generation`` runs), configures the root logger with a
    stream handler and a timestamped log file, and seeds ``random``, NumPy
    and torch.

    Mutates ``args`` in place by adding:
        n_gpu: number of visible CUDA devices, or 1 when running on CPU.
        batch_size: ``per_gpu_train_batch_size * n_gpu``.
        device: ``cuda:0`` when CUDA is used, otherwise ``cpu``.

    Args:
        args: parsed options as produced by ``get_args``.
    """
    torch.set_num_threads(3)
    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        logging.info("Waiting for debugger attach")
        # NOTE(review): server_port is parsed as a string by get_args();
        # confirm ptvsd accepts a non-integer port.
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    if not args.eval and not args.generation:
        run_dir = os.path.join(args.output_dir, args.model_name)
        os.makedirs(run_dir, exist_ok=True)
        # Dump the options before n_gpu / batch_size / device are attached
        # below: a torch.device is not JSON serialisable.
        with open(os.path.join(run_dir, 'train_opt.json'), 'w') as opt_file:
            json.dump(vars(args), opt_file, sort_keys=True, indent=2)

    # Fall back to CPU when CUDA was requested but is not available; otherwise
    # n_gpu would be 0, giving a batch size of 0 and an unusable cuda:0 device.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    cuda_fallback = not args.no_cuda and not use_cuda
    # A CPU run counts as a single "device" for the batch-size computation.
    args.n_gpu = torch.cuda.device_count() if use_cuda else 1
    args.batch_size = args.per_gpu_train_batch_size * args.n_gpu

    # Setup logging
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    root_logger.addHandler(stream_handler)

    log_dir = os.path.join(args.logdir, args.model_name)
    os.makedirs(log_dir, exist_ok=True)
    logfile = os.path.join(log_dir, time.strftime("%Y_%m_%d_%H_%M", time.localtime()) + '.log')
    file_handler = logging.FileHandler(logfile, mode='w')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)

    if cuda_fallback:
        root_logger.warning("CUDA is not available; falling back to CPU")

    # Set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Safe to call without CUDA (it is silently ignored in that case).
    torch.cuda.manual_seed_all(args.seed)

    root_logger.info("Training/evaluation parameters %s", args)

    args.device = torch.device('cuda:0' if use_cuda else 'cpu')

def _weight_and_bias(*modules):
    """Return the dotted ``<module>.weight`` / ``<module>.bias`` paths for ``modules``."""
    return tuple('{}.{}'.format(module, param) for module in modules for param in ('weight', 'bias'))


# Parameters taken from each GPT-2 block (paths relative to the block).
_GPT2_LAYER_PARAMS = _weight_and_bias(
    'ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj')

# Parameters taken from each BART encoder layer (paths relative to the layer).
_BART_ENCODER_LAYER_PARAMS = _weight_and_bias(
    'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj', 'self_attn.out_proj',
    'self_attn_layer_norm', 'fc1', 'fc2', 'final_layer_norm')

# A BART decoder layer additionally carries the encoder-attention parameters.
_BART_DECODER_LAYER_PARAMS = _BART_ENCODER_LAYER_PARAMS + _weight_and_bias(
    'encoder_attn.k_proj', 'encoder_attn.v_proj', 'encoder_attn.q_proj', 'encoder_attn.out_proj',
    'encoder_attn_layer_norm')

# Model-level BART parameters (paths relative to the model).
_BART_TOP_PARAMS = (
    ('shared.weight', 'encoder.embed_positions.weight', 'decoder.embed_positions.weight')
    + _weight_and_bias('encoder.layernorm_embedding', 'decoder.layernorm_embedding'))


def _share_params(dst, src, paths):
    """Assign ``src.<path>`` to ``dst.<path>`` for every dotted path in ``paths``."""
    for path in paths:
        *parents, leaf = path.split('.')
        dst_owner, src_owner = dst, src
        for name in parents:
            dst_owner = getattr(dst_owner, name)
            src_owner = getattr(src_owner, name)
        setattr(dst_owner, leaf, getattr(src_owner, leaf))


def init_para_frompretrained(model, pm, pretrained_model):
    """Initialise ``model`` with the parameters of the pretrained model ``pm``.

    The parameters are assigned by reference (the very same objects are
    attached to ``model``); nothing is copied.

    Args:
        model: model to initialise.  For ``'gpt2'`` only ``model.encoder`` is
            touched; for ``'facebook/bart-base'`` the shared embedding and
            the encoder / decoder stacks of ``model`` itself are touched.
        pm: the pretrained model providing the parameters.
        pretrained_model: name selecting the parameter layout.  Only
            ``'gpt2'`` and ``'facebook/bart-base'`` are supported; any other
            value leaves ``model`` unchanged and logs a warning.
    """
    if pretrained_model == 'gpt2':
        logger.info('load gpt2 pretrained model parameters')
        encoder = model.encoder
        _share_params(encoder, pm, ('wte.weight', 'wpe.weight'))
        # Iterate over the pretrained blocks: the target must have at least as many.
        for i, src_block in enumerate(pm.h):
            _share_params(encoder.h[i], src_block, _GPT2_LAYER_PARAMS)
        _share_params(encoder, pm, _weight_and_bias('ln_f'))
    elif pretrained_model == 'facebook/bart-base':
        logger.info('load bart-base pretrained model parameters')
        _share_params(model, pm, _BART_TOP_PARAMS)
        # Iterate over the target layers: the pretrained model must have at least as many.
        for i, dst_layer in enumerate(model.encoder.layers):
            _share_params(dst_layer, pm.encoder.layers[i], _BART_ENCODER_LAYER_PARAMS)
        for i, dst_layer in enumerate(model.decoder.layers):
            _share_params(dst_layer, pm.decoder.layers[i], _BART_DECODER_LAYER_PARAMS)
    else:
        # Previously a silent no-op; make the skipped initialisation visible.
        logger.warning('no parameter mapping for pretrained model %s; parameters left unchanged',
                       pretrained_model)

def prepare_model(args):
    """Create the tokenizer and the model selected by ``args.model_type``.

    Steps:
      1. Load the tokenizer of ``args.pretrained_model``, make sure ``<s>``
         and ``</s>`` are in its vocabulary, and attach the helper ids
         ``pad_id``, ``pad_id_label``, ``bos_id`` and ``eos_id`` to it.
      2. Load the config from ``args.config_path`` and override it with the
         run options (every ``latent*`` option is copied verbatim).
      3. Instantiate the model class registered for ``args.model_type``.
      4. With ``--load_pretrain``: the plain models (``t5-small``, ``gpt2``,
         ``bartbase``) are replaced by ``from_pretrained`` weights; the other
         models are initialised through ``init_para_frompretrained``.
      5. With ``--begin_epoch``: restore ``model_epoch_<N>.pt`` from
         ``<output_dir>/<model_name>``.
      6. Place the model on the device(s).

    Side effect: ``args.begin_epoch`` is set to ``-1`` when it was not given.

    Args:
        args: parsed options, already processed by ``prepare`` (``args.device``
            must be set).

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)
    # Public tokenizer API (get_vocab / add_tokens) instead of the
    # `vocab` attribute and the private `_add_tokens`.
    for special in ('<s>', '</s>'):
        if special not in tokenizer.get_vocab():
            tokenizer.add_tokens([special])

    if args.pretrained_model == 'gpt2':
        # presumably the id of GPT-2's <|endoftext|> token, reused as padding -- TODO confirm
        tokenizer.pad_id = 50256
    else:
        tokenizer.pad_id = tokenizer.pad_token_id
    # NOTE(review): the original code first set pad_id_label to -100 for
    # model_type 'gpt2' and then unconditionally overwrote it with pad_id, so
    # the -100 branch never took effect.  The effective behaviour is kept
    # here; confirm whether -100 was intended for the plain GPT-2 model.
    tokenizer.pad_id_label = tokenizer.pad_id
    tokenizer.bos_id = tokenizer.convert_tokens_to_ids('<s>')
    tokenizer.eos_id = tokenizer.convert_tokens_to_ids('</s>')

    model_config = AutoConfig.from_pretrained(args.config_path)
    model_config.model_type = args.model_type
    model_config.max_length = 1024
    model_config.vocab_size = len(tokenizer)
    model_config.pad_token_id = tokenizer.pad_id
    model_config.kl_threshold = args.kl_threshold
    model_config.is_cvae = args.dataset_type in ('cvae', 'wp')
    model_config.begin_layer = args.begin_layer
    model_config.end_layer = args.end_layer
    model_config.add_layer = args.add_layer
    model_config.loss_reduce = args.loss_reduce

    if args.gradient_checkpointing:
        model_config.gradient_checkpointing = True
        model_config.use_cache = False
    # Forward every latent* command-line option to the model config.
    for arg in vars(args):
        if arg.startswith('latent'):
            setattr(model_config, arg, getattr(args, arg))

    model_class = {
        'gpt2': GPT2LMHeadModel,
        't5-small': T5VAEMem,
        'bartbase': BartForConditionalGeneration,
        'bartemb': BartVAEEmbed,
        'bartmemorylast': BartVAEMemoryLast,
        'bartsoftmax': BartVAESoftmax,
        'bartlmf': BartVAELMFLayerly,
        'bartlmfindependent': BartVAELMFLayerlyIndependent,
        'bartlmflastloss': BartVAELMFLayerlyLastLoss,
        'bartlmffirstloss': BartVAELMFLayerlyFirstLoss,
        'lmflastlayer': GPT2VAELMFLastLayerOnly,
        'lmflayerly': GPT2VAELMFLayerly,
        'lmfindependent': GPT2VAELMFLayerlyIndependent,
        'lmflayerfirstloss': GPT2VAELMFLayerlyFirstLoss,
        'lmflayerlastloss': GPT2VAELMFLayerlyLastLoss,
        'addmiddle': GPT2VAEAAddMiddle,
        'embed': GPT2VAEEmbed,
        'softmax': GPT2VAESoftmax,
        'memorylayerly': GPT2VAEMemoryLayerly,
        'memorylastlayer': GPT2VAEMemoryLastLayerOnly,
    }

    model = model_class[args.model_type](model_config)
    if args.load_pretrain:
        if args.model_type in ['t5-small', 'gpt2', 'bartbase']:
            # Plain models: replace the freshly built model by the pretrained one.
            model = model_class[args.model_type].from_pretrained(args.pretrained_model)
            model.resize_token_embeddings(len(tokenizer))
            if args.model_type == 'gpt2':
                model.loss_reduce = args.loss_reduce
                model.pad_token_id = model_config.pad_token_id
        else:
            # VAE variants: take the parameters of the pretrained backbone.
            pretrained_model = AutoModel.from_pretrained(args.pretrained_model)
            logging.info('loading pretrained model parameters...')
            init_para_frompretrained(model, pretrained_model, args.pretrained_model)
            if args.pretrained_model == 'gpt2':
                # Resize for the added tokens, then make the decoder use the
                # encoder's (resized) token embedding.
                model.encoder.resize_token_embeddings(len(tokenizer))
                model.decoder.wte = model.encoder.wte

    if args.begin_epoch is not None:
        model_path = os.path.join(args.output_dir, args.model_name, 'model_epoch_{}.pt'.format(args.begin_epoch))
        model_state_dict = torch.load(model_path, map_location=args.device)
        model.load_state_dict(model_state_dict)
        logging.info('load model_epoch_{}.pt finish'.format(args.begin_epoch))
    else:
        # No checkpoint requested: mark it with -1.
        args.begin_epoch = -1

    if args.model_parallel and torch.cuda.device_count() > 1:
        logging.info('model paralleize...')
        model.parallelize()
    else:
        model = model.to(args.device)
        if torch.cuda.device_count() > 1:
            model = DataParallel(model)
    return model, tokenizer

def prepare_data(tokenizer, args):
    """Build the data loader(s) needed for the current run mode.

    The training loader is shuffled and the validation loader is not; the
    original flags were the other way round, so the training data was always
    visited in file order.

    Args:
        tokenizer: tokenizer handed to the dataset constructor.
        args: parsed options; uses ``dataset_type``, ``batch_size``,
            ``device``, ``eval`` / ``generation`` and the ``*_file`` paths.

    Returns:
        A single test ``DataLoader`` when ``args.eval`` or ``args.generation``
        is set, otherwise the tuple ``(train_iter, valid_iter)``.
    """
    dataset_class = {
        'vae': VAEDataset,
        'wiki': WikiDataset,
        'cvae': CVAEDataset,
        'wp': WPDataset,
        'seq2seq': Seq2seqDataset,
    }
    dataset_cls = dataset_class[args.dataset_type]

    if args.eval or args.generation:
        logging.info("eval model: the epoch {} of {}".format(args.begin_epoch, args.model_name))
        test_dataset = dataset_cls(args.test_file, tokenizer, args.device)
        test_iter = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
        return test_iter

    train_dataset = dataset_cls(args.train_file, tokenizer, args.device)
    valid_dataset = dataset_cls(args.valid_file, tokenizer, args.device)
    # Shuffle the training data each epoch; keep validation order fixed.
    train_iter = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
    valid_iter = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=valid_dataset.collate_fn)
    logging.info('training with {} samples...'.format(len(train_dataset)))
    return train_iter, valid_iter

def main():
    """Script entry point: parse options, build the model, run one mode.

    Mode selection, in priority order:
      * ``--eval`` and/or ``--generation``: run ``valid`` and/or ``generate``
        on the test loader;
      * ``--interpolating``: run ``interpolating`` (the two sentences are
        empty placeholders here);
      * ``--visual_attention``: run ``visual_attention`` (the sentence is an
        empty placeholder here);
      * otherwise: train.
    """
    args = get_args()
    prepare(args)
    model, tokenizer = prepare_model(args)

    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    logging.info('total parameters: {}'.format(n_params))

    # Evaluation / generation on the test set (both may be requested at once).
    if args.eval or args.generation:
        test_iter = prepare_data(tokenizer, args)
        if args.eval:
            valid(model, test_iter, args.begin_epoch, args)
        if args.generation:
            generate(model, test_iter, tokenizer, args)
        return

    if args.interpolating:
        sen1, sen2 = "", ""
        interpolating(model, tokenizer, sen1, sen2, args)
        return

    if args.visual_attention:
        sen = ""
        visual_attention(model, tokenizer, sen, args)
        return

    # Default mode: training.
    train_iter, valid_iter = prepare_data(tokenizer, args)
    train(model, train_iter, valid_iter, args)


if __name__ == "__main__":
    main()
