from data_loader import init_template, load_data, load_file, load_vocab, batchify, load_all_data
import numpy as np
from os.path import join, isfile
from os import listdir
from glob import glob
import logging
from module import Prober
from trainer_multiTask import GenerativeModelTrainer
import argparse
import torch

# Configure root logging once for the whole script; each record shows the
# timestamp, level, logger name and message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

# Prefer the GPU, but fall back to the CPU so the script still starts on
# machines without CUDA instead of failing later when models are moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main(args):
    """Train the generative model on all relations and evaluate on the test splits.

    Relation directories are discovered under ``dataDir``; each directory name
    is used as the relation name to look up its template in the relation
    metadata file. A single training loader is built from all relations
    (requested with ``shuffle=True``), while dev and test loaders are kept per
    relation in dictionaries keyed by relation name.

    Args:
        args: Parsed command-line namespace. This function reads
            ``train_batch_size``, ``eval_batch_size`` and ``alpha`` directly
            and forwards the whole namespace to ``Prober`` and
            ``GenerativeModelTrainer``.

    Raises:
        FileNotFoundError: If no relation directory exists under ``dataDir``.
        KeyError: If a relation directory has no entry in the metadata file.
    """
    # Local import: the module header only pulls ``join``/``isfile`` from os.path.
    from os.path import basename, normpath

    dataDir = "../LAMA_data/autoprompt_data"
    relation_file = "../relation_metainfo/LAMA_relations_revision.jsonl"
    vocab_filename = "../common_vocabs/common_vocab_cased_be_ro_al.txt"

    relations_meta = load_file(relation_file)
    relation_directories = glob(dataDir+"/*/")
    if not relation_directories:
        # Fail early with a clear message, before any model is loaded.
        raise FileNotFoundError("No relation directories found under %s" % dataDir)

    # Relation name -> prompt template, taken from the metadata records.
    meta_dict = {item["relation"]: item["template"] for item in relations_meta}

    vocab_subset = load_vocab(vocab_filename)
    logger.info('Common vocab: %s, size: %d', vocab_filename, len(vocab_subset))

    model = Prober(args)
    generator_model = model.generator_model.to(device)
    generator_config = model.generator_config
    generator_tokenizer = model.generator_tokenizer
    mlm_model = model.mlm_model.to(device)
    mlm_tokenizer = model.mlm_tokenizer

    # Index structures restricting both models to the common vocabulary.
    # NOTE(review): presumably generator_map2_mlm maps generator token ids to
    # MLM token ids -- confirm against Prober.get_vocab_map.
    _, generator_index_list = model.init_indices_for_filter_logprobs(vocab_subset, generator_tokenizer, logger)
    mlm_filter_indices, mlm_index_list = model.init_indices_for_filter_logprobs(vocab_subset, mlm_tokenizer, logger)
    generator_map2_mlm = model.get_vocab_map(generator_index_list, mlm_index_list)
    generator_banned_token_ids = model.get_generator_banned_token_ids(generator_index_list)

    training_samples = load_all_data(relation_directories, relations_meta, vocab_subset=vocab_subset,
                                     mask_token=generator_tokenizer.mask_token,
                                     mode="train", shuffle=True)
    train_loader = batchify(training_samples, batch_size=args.train_batch_size)

    dev_loader_dict = {}
    test_loader_dict = {}
    for relation_dir in relation_directories:
        # glob("<dataDir>/*/") yields paths with a trailing separator; normpath
        # strips it so basename returns the directory (relation) name on any OS.
        relation_name = basename(normpath(relation_dir))
        template = meta_dict[relation_name]
        dev_samples = load_data(data_path=join(dataDir, relation_name, "dev.jsonl"),
                                template=template, vocab_subset=vocab_subset,
                                mask_token=generator_tokenizer.mask_token)
        test_samples = load_data(data_path=join(dataDir, relation_name, "test.jsonl"),
                                 template=template, vocab_subset=vocab_subset,
                                 mask_token=generator_tokenizer.mask_token)
        dev_loader_dict[relation_name] = batchify(dev_samples, batch_size=args.eval_batch_size)
        test_loader_dict[relation_name] = batchify(test_samples, batch_size=args.eval_batch_size)

    trainer = GenerativeModelTrainer(args, generator_config, generator_model, generator_tokenizer,
                                     mlm_model, mlm_tokenizer, generator_banned_token_ids, generator_map2_mlm,
                                     device, mlm_filter_indices, mlm_index_list, alpha=args.alpha)
    trainer.train(train_loader, dev_loader_dict, wiki=False)
    # Evaluate on every relation's test split, not just the loader left over
    # from the last loop iteration.
    # NOTE(review): assumes trainer.eval accepts the same relation -> loader
    # mapping that trainer.train takes for dev data -- confirm against
    # GenerativeModelTrainer.eval.
    trainer.eval(test_loader_dict, wiki=False, output_topk=None)


if __name__ == "__main__":
    # Command-line options as (flag, type, default, help) rows, listed in the
    # order they are registered; a help of None means no help text.
    cli_options = [
        # Prediction output and reproducibility.
        ('--k', int, 5, 'how many predictions will be outputted'),
        ('--seed', int, 6, None),

        # Models, batching and sequence lengths.
        ('--language_model_dir', str, 'bert-base-cased', 'the huggingface model name'),
        ('--generative_model_dir', str, 'facebook/bart-large', 'the huggingface model name'),
        ('--train_batch_size', int, 32, 'training batch size per GPU'),
        ('--eval_batch_size', int, 32, None),
        ('--num_train_epochs', int, 30, None),
        ('--max_steps', int, -1, None),
        ('--max_seq_len', int, 90, None),
        ('--max_input_seq_len', int, 64, None),

        # Optimisation hyper-parameters and checkpoint location.
        ('--clip', float, 5.0, None),
        ('--learning_rate', float, 5e-5, None),
        ('--adam_beta1', float, 0.9, None),
        ('--adam_beta2', float, 0.999, None),
        ('--adam_epsilon', float, 1e-8, None),
        ('--weight_decay', float, 1e-3, None),
        ('--warmup_steps', int, -1, None),
        ('--warmup_ratio', float, 0.1, None),
        ('--save_model_dir', str, "save_model_multiTask11", None),

        # Logging / checkpoint cadence, in steps.
        ('--logging_step', int, 20, None),
        ('--save_step', int, 1000, None),

        ('--alpha', float, 0.4, 'loss weight for mutual information'),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()

    main(args)
