import argparse
import torch
from transformers import *
from transformers import LineByLineTextDataset

import argparse
def parse_args(argv=None):
    """Build the command-line parser for this training script and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list makes the function usable from tests or notebooks.

    Returns:
        argparse.Namespace holding the parsed options. ``--config_name`` and
        ``--device`` are defined here because the ``__main__`` section of this
        script reads ``args.config_name`` and ``args.device``; without them
        those reads raised AttributeError.

    Raises:
        SystemExit: raised by argparse when ``--domain_txt_file`` is missing
            or an option value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', help='', required=False, default="bert", type=str)
    parser.add_argument('--model_name_or_path', help='', required=False, default="bert", type=str)
    parser.add_argument("--config_name", default="", type=str, help="Optional pretrained config name or path if not the same as model_name_or_path")
    parser.add_argument("--cache_dir", default=None, type=str, help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)")
    parser.add_argument("--tokenizer_name", default="", type=str, help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
    parser.add_argument("--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.")
    parser.add_argument("--block_size", default=-1, type=int, help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).")
    parser.add_argument("--domain_txt_file", default=None, type=str, required=True, help="input domain txt file for training MLM model")
    # None means "not specified"; the caller decides how to pick a device.
    parser.add_argument("--device", default=None, type=str, help="Optional torch device to move the model to, e.g. 'cpu' or 'cuda:0'")
    return parser.parse_args(argv)

    
if __name__ == '__main__':
    # Maps --model_type to its (config class, model class, tokenizer class).
    # All of these names come from the wildcard `from transformers import *`.
    MODEL_CLASSES = {
        "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
        "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
        "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
        "bert-seq": (BertConfig, BertModel, BertTokenizer),
        "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
        "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
        "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
        "xlmr": (XLMRobertaConfig, XLMRobertaForMaskedLM, XLMRobertaTokenizer),
    }

    args = parse_args()

    # Fail with a readable message instead of a bare KeyError.
    if args.model_type not in MODEL_CLASSES:
        raise SystemExit(
            "Unknown --model_type {!r}; expected one of: {}".format(
                args.model_type, ", ".join(sorted(MODEL_CLASSES))
            )
        )
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Device selection. Previously GPU 1 was pinned unconditionally (crashing
    # on single-GPU / CPU-only hosts) and args.device was read but never set.
    # Honour --device when given; otherwise keep the original preference for
    # cuda:1 when a second GPU exists, falling back to cuda:0, then CPU.
    # getattr keeps this working even if parse_args lacks a --device option.
    device_name = getattr(args, "device", None)
    if not device_name:
        if torch.cuda.is_available():
            device_name = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
        else:
            device_name = "cpu"
    args.device = torch.device(device_name)
    if args.device.type == "cuda":
        torch.cuda.set_device(args.device)

    cache_dir = args.cache_dir if args.cache_dir else None

    # Load the config exactly once: an explicit --config_name wins, otherwise
    # the checkpoint's own config; with neither, use a default config. (The
    # old code loaded it, then reloaded via args["..."], which raises
    # TypeError because argparse.Namespace is not subscriptable.)
    config_source = getattr(args, "config_name", "") or args.model_name_or_path
    if config_source:
        config = config_class.from_pretrained(config_source, cache_dir=cache_dir)
    else:
        config = config_class()

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir,
    )

    if args.model_name_or_path:
        model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
    else:
        # No checkpoint to load: build the model from the config above,
        # mirroring the config_class() fallback.
        model = model_class(config)
    model.to(args.device)
    print('Number of model parameters: {}'.format(model.num_parameters()))

    # block_size <= 0 means "use the model's maximum"; in every case clamp to
    # that maximum so sequences never exceed what the model accepts.
    if args.block_size <= 0:
        args.block_size = tokenizer.max_len_single_sentence
    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)

    ## Load dataset from the required --domain_txt_file (previously ignored in
    ## favour of a hard-coded "../../comcrawl/bus_en_all.txt").
    dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path=args.domain_txt_file, block_size=args.block_size)
    model.train()
    # NOTE(review): no training loop is visible in this section; `dataset` is
    # built but not yet consumed.
    # TODO: re-enable once parse_args defines sys_token / usr_token:
    #tokenizer.add_tokens([args.sys_token, args.usr_token])
    #model.resize_token_embeddings(len(tokenizer))
    
    
    
    