#%%
import os
import copy
import json
import torch
import random
import warnings
import argparse
import numpy as np
#%%
import torch.nn as nn
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
# torch.set_num_threads(1)
from utils import setup_seed

# Silence all warnings for the remainder of the run.
warnings.filterwarnings("ignore")

# ------ Command-line interface ------
parser = argparse.ArgumentParser(description='Train Setting')
parser.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)')
parser.add_argument('-r', '--resume', default=None, type=str,
                    help='path to latest checkpoint (default: None)')
# parser.add_argument('-d', '--device', default='0', type=str,
                    # help='indices of GPUs to enable (default: all)')
parser.add_argument('-s', '--seed', default=None, type=int,
                    help='random seed')
args = parser.parse_args()

# ------ Load configuration ------
# An explicit config file takes precedence over the config stored inside a
# checkpoint; at least one of the two must be supplied.
if args.config:
    # Context manager so the file handle is closed deterministically.
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)
elif args.resume:
    # Only the stored config is read here, so map tensors to CPU instead of
    # allocating them on whatever device the checkpoint was saved from.
    config = torch.load(args.resume, map_location='cpu')['config']
    # config['model'] is a dict (train() reads config['model']['name']), so the
    # experiment name has to be derived from its 'name' entry; concatenating
    # the dict itself with a string raises TypeError.
    config['exp_name'] = config['model']['name'] + '-resume'
else:
    raise AssertionError("Configuration file need to be specified. Add '-c config.json', for example.")

# ------ Seed override ------
# `is not None` (rather than truthiness) so that an explicit `--seed 0` is
# honoured. `--seed -1` asks for a randomly drawn seed. The chosen seed is
# appended to the experiment name.
if args.seed is not None:
    seed = random.randint(0, 10000) if args.seed == -1 else args.seed
    config["exp_name"] = config["exp_name"] + str(seed)
    config["trainer"]["random_seed"] = seed

# if args.device:
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.device

# print ("Training on Device: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
setup_seed(config["trainer"]["random_seed"])

from trainer import Trainer, TrainerBert
from transformers import BertTokenizer, AlbertTokenizer
from data import TextPairProcessor, BertTextPairProcessor
from model import WarmupLinearSchedule
from model import DFITN, DFITNEmbedder

from transformers import BertModel, BertConfig, AlbertConfig, AlbertModel, AlbertForSequenceClassification, BertForSequenceClassification
from time import time
class BERT(nn.Module):
    """Sequence-pair classifier wrapping a pretrained BERT / ALBERT model.

    A frozen one-dimensional embedding (``word_emb``) is filled from
    ``data.TEXT.vocab.vectors`` and used in ``forward`` as a lookup table that
    converts the indices in ``input_pair`` into the ids passed to the wrapped
    model (presumably the pretrained model's token ids -- confirm against the
    data processor that builds ``vectors``).

    Args:
        data: Object exposing ``TEXT.vocab`` (its length and its ``vectors``).
        path: Name or path handed to ``from_pretrained``. If it contains
            ``"albert"`` the ALBERT classes are used, otherwise the BERT ones.
    """

    def __init__(self, data, path):
        super(BERT, self).__init__()
        if "albert" in path:
            self.bert_config = AlbertConfig.from_pretrained(path)
            self.bert_model = AlbertForSequenceClassification.from_pretrained(path, num_labels=3)
        else:
            self.bert_config = BertConfig.from_pretrained(path)
            # NOTE(review): unlike the ALBERT branch, no num_labels is passed
            # here, so the label count comes from the pretrained config.
            # Confirm this asymmetry is intended.
            self.bert_model = BertForSequenceClassification.from_pretrained(path)

        # Width-1 embedding used purely as an index -> id lookup table.
        self.word_emb = nn.Embedding(len(data.TEXT.vocab), 1)
        self.word_emb.weight.data.copy_(data.TEXT.vocab.vectors)
        self.word_emb.weight.requires_grad = False

        self.dropout = nn.Dropout(0.2)

    def forward(self, input_s, input_t, input_pair, len_s, len_t, len_pair, seq_pair):
        """Classify a batch of sentence pairs.

        Only ``input_pair`` and ``seq_pair`` are used; the other arguments are
        accepted but ignored by this method.

        Args:
            input_pair: Integer index tensor looked up in ``word_emb``.
            seq_pair: Passed to the wrapped model as its third positional
                argument (token type ids).

        Returns:
            Tuple ``(logits, [], [])`` where ``logits`` is the first element of
            the wrapped model's output.
        """
        # Drop only the trailing width-1 embedding axis. A bare ``squeeze()``
        # would also remove the batch axis for a batch of one (or the sequence
        # axis for length-one inputs) and hand the model a mis-shaped tensor.
        bert_id_pair = self.word_emb(input_pair).long().squeeze(-1)
        # Attention mask: 1 wherever the id is non-zero, 0 elsewhere.
        mask_pair = bert_id_pair.ne(0).to(dtype=torch.long)
        # Positional order: input_ids, attention_mask, token_type_ids.
        logits = self.bert_model(bert_id_pair, mask_pair, seq_pair)[0]
        outputs = (logits, [], [])
        return outputs


# Registry: model name -> the classes needed to build and train that model.
model_choices = {
    "DFITN": dict(
        model=DFITN,
        embed=DFITNEmbedder,
        trainer=Trainer,
        processor=TextPairProcessor,
    ),
}

# Registry: tokenizer name -> tokenizer instance.
# NOTE: every tokenizer is instantiated right here, when this statement runs,
# regardless of which one (if any) the selected model ends up using.
tokenizer_dict = dict([
    ("bert", BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)),
    ("bert-chinese", BertTokenizer.from_pretrained("bert-base-chinese")),
    ("albert", AlbertTokenizer.from_pretrained("albert-base-v2", do_lower_case=True)),
])

# Registry: optimizer name (as written in the config) -> optimizer class.
optim_dict = dict(Adam=Adam, AdamW=AdamW)

# Registry: scheduler name (as written in the config) -> scheduler class.
scheduler_dict = dict(
    CosineAnnealingLR=CosineAnnealingLR,
    WarmupLinearSchedule=WarmupLinearSchedule,
)

#%%
def train(config, resume=None):
    """Build every component described by ``config`` and run training.

    The data processor, embedder, model, optimizer, LR scheduler and trainer
    are constructed in that order, each from its section of ``config``, and
    ``trainer.train()`` is then called.

    ``config`` is not modified: keyword arguments that need runtime objects
    (``data``, ``embed``, ``params``, ``optimizer``) are assembled on shallow
    copies of the corresponding ``args`` sections.

    Args:
        config: Nested experiment configuration. Keys read here: ``model``
            (``name`` / ``args``), ``data_loader``, ``embedder`` (``args``),
            ``optimizer`` (``name`` / ``args``), ``lr_scheduler``
            (``name`` / ``args``) and ``exp_name``.
        resume: Forwarded unchanged to the trainer's ``resume`` argument.

    Raises:
        KeyError: If a required config key is missing, or a model / optimizer /
            scheduler name is not present in the corresponding registry.
    """
    # The trainer gets its own independent copy of the configuration.
    trainer_config = copy.deepcopy(config)

    # ------ Choose Model ------
    model_name = config['model']['name']
    model_dict = model_choices[model_name]

    # ------ Load Data ------
    data_kwargs = dict(config['data_loader'])
    # NOTE(review): no "DFITN_bert" entry is visible in model_choices in this
    # file, so the lookup above would fail before this branch can run.
    if model_name == "DFITN_bert":
        # Swap the tokenizer name from the config for the tokenizer object.
        data_kwargs['tokenizer'] = tokenizer_dict[data_kwargs['tokenizer']]
    data = model_dict['processor'](**data_kwargs)

    # ------ Load Embedder ------
    embedder_kwargs = {**config['embedder']['args'], 'data': data}
    embed = model_dict['embed'](**embedder_kwargs)

    # ------ Load Model ------
    model_kwargs = {**config['model']['args'], 'embed': embed}
    model = model_dict['model'](**model_kwargs)

    # ------ Model Parameter ------
    # Collected once into a list so the same parameters can be both counted
    # and handed to the optimizer (a filter object can only be consumed once).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    params_num = sum(p.numel() for p in trainable_params)
    print ("Trainable Model Parameters: {}".format(params_num))

    # ------ Set Optimizer ------
    optim_name = config['optimizer']['name']
    optim_kwargs = {**config['optimizer']['args'], 'params': trainable_params}
    optimizer = optim_dict[optim_name](**optim_kwargs)

    # ------ Set LRScheduler ------
    scheduler_name = config['lr_scheduler']['name']
    scheduler_kwargs = {**config['lr_scheduler']['args'], 'optimizer': optimizer}
    lr_scheduler = scheduler_dict[scheduler_name](**scheduler_kwargs)

    # ------ Set Trainer ------
    trainer = model_dict["trainer"](
        config=trainer_config,
        data=data,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        resume=resume
    )
    print (config["exp_name"])
    trainer.train()


# Script entry point: start training with the configuration prepared above,
# handing the --resume checkpoint path (or None) through to the trainer.
train(config, resume=args.resume)
