# import faulthandler
# faulthandler.enable()
import os
import torch
import torch.nn as nn
import numpy as np
import argparse
import random
import json
import time
import logging
from tqdm import tqdm, trange
import pdb
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from utils.data_utils_fst import Processor, MultiWozDataset, extract_data
from utils.eval_utils import model_evaluation, prediction_inference
from utils.label_lookup import get_label_lookup_from_first_token, combine_slot_values
from models.ModelBERT_S2full import UtteranceEncoding, BeliefTracker

from transformers import BertTokenizer, AdamW, get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup


# Configure the root logger once at import time (timestamp - level - logger
# name - message, INFO and above) and create this module's own logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
)
logger = logging.getLogger(__name__)

def main(args):
    """Run iterative teacher-student self-training and evaluate each student.

    The function loads the labeled train split, the dev and test splits and
    the complement train split (files loaded with the ``"sup-"`` prefix),
    restores a teacher ``BeliefTracker`` from
    ``<args.teacher_dir>/model_best_acc.bin`` and then repeats
    ``args.override_loop`` times:

    1. the current teacher predicts pseudo labels for the complement split;
    2. the pseudo-labeled instances are concatenated with the labeled ones;
    3. a freshly initialised student is trained on that data, evaluated on
       dev every ``args.eval_epoch`` epochs, and the weights with the best
       dev joint accuracy are saved to
       ``<args.save_dir>/student_model_best_acc_iter<k>.bin``;
    4. the student is evaluated on test and becomes the next teacher.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``save_dir``, ``n_epochs``, ``random_seed``, ``aug_type``,
            ``data_dir``, ``data_ratio``, ``pretrained_model``,
            ``teacher_dir``, ``override_loop``, ``word_dropout``,
            ``train_batch_size``, ``num_workers``, ``weight_decay``,
            ``enc_lr``, ``enc_warmup``, ``dec_lr``, ``dec_warmup``,
            ``eval_epoch`` and ``patience``. The namespace is also handed to
            ``Processor`` and ``BeliefTracker``, which may read more.

    Raises:
        ValueError: If ``args.aug_type`` is not ``"self_aug"``, ``"cl_aug"``
            or ``"mix_aug"``.
    """
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own deterministic NumPy seed.
        np.random.seed(args.random_seed + worker_id)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    # logger
    # The log file is named after the second '/'-separated component of
    # save_dir; fall back to the first one when save_dir has no '/' at all
    # (indexing [1] unconditionally would raise IndexError).
    save_dir_parts = args.save_dir.split('/')
    run_tag = save_dir_parts[1] if len(save_dir_parts) > 1 else save_dir_parts[0]
    logger_file_name = run_tag + '_' + str(args.n_epochs) + '_init50unlabel' + str(args.random_seed)
    fileHandler = logging.FileHandler(os.path.join(args.save_dir, "%s.txt" % logger_file_name))
    logger.addHandler(fileHandler)
    logger.info(args)
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if args.aug_type not in ("self_aug", "cl_aug", "mix_aug"):
        raise ValueError(
            "aug_type must be one of 'self_aug', 'cl_aug', 'mix_aug', got %r" % (args.aug_type,))

    # cuda setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info("device: {}".format(device))

    # set random seed
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    # `device` is a torch.device, which never compares equal to the string
    # "cuda"; compare its `.type` so the CUDA seeding actually runs.
    if device.type == "cuda":
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    #******************************************************
    # extract data ratio
    #******************************************************
    extract_data(args.data_dir, args.data_ratio, args.random_seed)

    #******************************************************
    # load data
    #******************************************************
    processor = Processor(args)
    slot_meta = processor.slot_meta
    label_list = processor.label_list
    num_labels = [len(labels) for labels in label_list]
    logger.info(slot_meta)

    tokenizer = BertTokenizer.from_pretrained(args.pretrained_model)

    # File-name prefixes of the labeled subset and of its complement.
    labeled_prefix = str(args.data_ratio) + "-" + str(args.random_seed) + "-"
    complement_prefix = "sup-" + labeled_prefix

    train_data_raw = processor.get_train_instances(args.data_dir, labeled_prefix, tokenizer)
    print("# train examples %d" % len(train_data_raw))

    dev_data_raw = processor.get_dev_instances(args.data_dir, tokenizer)
    print("# dev examples %d" % len(dev_data_raw))

    test_data_raw = processor.get_test_instances(args.data_dir, tokenizer)
    print("# test examples %d" % len(test_data_raw))

    complement_train_data_raw = processor.get_train_instances(args.data_dir, complement_prefix, tokenizer)
    print("# unlabel complement train examples %d" % len(complement_train_data_raw))
    logger.info("Data loaded!")

    #******************************************************
    # build teacher model
    #******************************************************
    ## Initialize slot and value embeddings
    # The encoder used to build the lookups is frozen.
    sv_encoder = UtteranceEncoding.from_pretrained(args.pretrained_model)
    for p in sv_encoder.bert.parameters():
        p.requires_grad = False

    new_label_list, slot_value_pos = combine_slot_values(slot_meta, label_list) # without slot head
    logger.info(slot_value_pos)
    slot_lookup = get_label_lookup_from_first_token(slot_meta, tokenizer, sv_encoder, device)
    value_lookup = get_label_lookup_from_first_token(new_label_list, tokenizer, sv_encoder, device)

    # load state_dict
    ckpt_path = os.path.join(args.teacher_dir, 'model_best_acc.bin')

    teacher_model = BeliefTracker(args, slot_lookup, value_lookup, num_labels, slot_value_pos, device)
    # NOTE(review): torch.load unpickles the file; only point teacher_dir at
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Replace the stored lookup weights with the ones just computed above.
    ckpt['slot_lookup.weight'] = slot_lookup.weight
    ckpt['value_lookup.weight'] = value_lookup.weight

    teacher_model.load_state_dict(ckpt)
    teacher_model.to(device)

    # `epoch` is passed to the first pseudo-labeling call and afterwards
    # carries over the last epoch index of the previous student.
    epoch = 0

    #******************************************************
    #  student override loop
    #******************************************************
    for override in range(args.override_loop):
        print("************ override loop [%d/%d] ************" % (override+1, args.override_loop))
        # derive pseudo labels for the complement split with the current teacher
        prediction_pseudo_labels = prediction_inference(teacher_model, complement_train_data_raw, tokenizer, slot_meta, label_list, epoch+1)
        pseudo_train_data_raw = processor.get_pseudo_train_instances(args.data_dir, complement_prefix, tokenizer, prediction_pseudo_labels)

        # Train on pseudo-labeled plus labeled instances.
        pseudo_train_data_raw = pseudo_train_data_raw + train_data_raw

        noised_train_data = MultiWozDataset(pseudo_train_data_raw, tokenizer, word_dropout=args.word_dropout)
        num_train_steps = int(len(pseudo_train_data_raw) / args.train_batch_size * args.n_epochs)

        logger.info("***** Run student training *****")
        logger.info(" Num examples = %d", len(pseudo_train_data_raw))
        logger.info(" Batch size = %d", args.train_batch_size)
        logger.info(" Num steps = %d", num_train_steps)

        noised_train_sampler = RandomSampler(noised_train_data)
        noised_train_dataloader = DataLoader(noised_train_data,
                                             sampler=noised_train_sampler,
                                             batch_size=args.train_batch_size,
                                             collate_fn=noised_train_data.collate_fn,
                                             num_workers=args.num_workers,
                                             worker_init_fn=worker_init_fn)

        # A new student (and new optimizers/schedulers) is created in every
        # override loop; it is not warm-started from the teacher.
        print("Student Model Initialization...")
        student_model = BeliefTracker(args, slot_lookup, value_lookup, num_labels, slot_value_pos, device)
        student_model.to(device)

        ## prepare optimizer
        # Encoder: weight decay on everything except biases and LayerNorm.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        enc_param_optimizer = list(student_model.encoder.named_parameters())
        enc_optimizer_grouped_parameters = [
            {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

        enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
        enc_scheduler = get_linear_schedule_with_warmup(enc_optimizer, int(num_train_steps * args.enc_warmup), num_train_steps)

        # Decoder: separate optimizer with its own learning rate and warmup.
        dec_param_optimizer = list(student_model.decoder.parameters())
        dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
        dec_scheduler = get_linear_schedule_with_warmup(dec_optimizer, int(num_train_steps * args.dec_warmup), num_train_steps)

        logger.info(enc_optimizer)
        logger.info(dec_optimizer)

        #******************************************************
        # training
        #******************************************************
        logger.info("Student Training...")

        best_loss = None
        best_acc = None
        last_update = None  # epoch index of the last dev-accuracy improvement
        for epoch in trange(int(args.n_epochs), desc="Epoch"):
            batch_loss = []
            batch_acc = []

            for step, batch in enumerate(tqdm(noised_train_dataloader)):
                student_model.train()

                batch = [b.to(device) if b is not None else b for b in batch]
                input_ids, segment_ids, input_mask, label_ids = batch

                # forward
                loss, _, acc, _, _ = student_model(input_ids=input_ids, attention_mask=input_mask,
                                                   token_type_ids=segment_ids, labels=label_ids, aug_type=args.aug_type)

                loss.backward()
                enc_optimizer.step()
                enc_scheduler.step()
                dec_optimizer.step()
                dec_scheduler.step()
                student_model.zero_grad()

                batch_loss.append(loss.item())
                batch_acc.append(acc)
                # Report running means every 300 steps, then reset the window.
                if step % 300 == 0:
                    print("[%d/%d] [%d/%d] mean_loss: %.6f, mean_joint_acc: %.6f" % \
                          (epoch+1, args.n_epochs, step, len(noised_train_dataloader), np.mean(batch_loss), np.mean(batch_acc)))
                    batch_loss = []
                    batch_acc = []

            if (epoch+1) % args.eval_epoch == 0:
                eval_res = model_evaluation(student_model, dev_data_raw, tokenizer, slot_meta, label_list, epoch+1)
                if best_loss is None or best_loss > eval_res['loss']:
                    best_loss = eval_res['loss']
                    print("Best Loss : ", best_loss)
                    print("\n")
                # Only a dev joint-accuracy improvement saves a checkpoint and
                # resets the patience counter.
                if best_acc is None or best_acc < eval_res['joint_acc']:
                    best_acc = eval_res['joint_acc']
                    save_path = os.path.join(args.save_dir, 'student_model_best_acc_iter'+str(override+1)+'.bin')
                    torch.save(student_model.state_dict(), save_path)
                    last_update = epoch
                    print("Best Acc : ", best_acc)
                    print("\n")

                logger.info("*** Student Evaluation: Epoch=%d, Last Update=%d, Dev Loss=%.6f, Dev Acc=%.6f, Dev Turn Acc=%.6f, Best Loss=%.6f, Best Acc=%.6f ***",
                            epoch, last_update, eval_res['loss'], eval_res['joint_acc'], eval_res['joint_turn_acc'], best_loss, best_acc)

            # Early stopping. last_update stays None until the first dev
            # evaluation (eval_epoch > 1), so guard against None + int.
            if last_update is not None and last_update + args.patience <= epoch:
                break

        # override
        # NOTE(review): the test evaluation and the next teacher use the
        # student's weights from its last trained epoch, not the saved
        # best-dev-accuracy checkpoint - confirm this is intended.
        test_res = model_evaluation(student_model, test_data_raw, tokenizer, slot_meta, label_list, epoch+1)
        if last_update is None:
            # No dev evaluation ran (n_epochs < eval_epoch): there are no
            # "best" values to report, and formatting None would raise.
            logger.info("*** Student Test Evaluation: Epoch=%d, Tes Loss=%.6f, Tes Acc=%.6f, Tes Turn Acc=%.6f (no dev evaluation was run) ***",
                        epoch, test_res['loss'], test_res['joint_acc'], test_res['joint_turn_acc'])
        else:
            logger.info("*** Student Test Evaluation: Epoch=%d, Last Update=%d, Tes Loss=%.6f, Tes Acc=%.6f, Tes Turn Acc=%.6f, Best Loss=%.6f, Best Acc=%.6f ***",
                        epoch, last_update, test_res['loss'], test_res['joint_acc'], test_res['joint_turn_acc'], best_loss, best_acc)
        teacher_model = student_model
    #----------------------------------------------------------------------
    print("Train Over")
    
    

if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script can
    # be launched without arguments.
    parser = argparse.ArgumentParser()

    # Paths, data selection and model variant
    parser.add_argument("--data_dir", default='data/mwz2.0', type=str)
    parser.add_argument("--pretrained_model", default='bert-base-uncased', type=str)
    parser.add_argument("--save_dir", default='out-bert/exp20/5-shot/css_aug', type=str)
    parser.add_argument("--teacher_dir", default='../teacher/out-bert/exp20/5-shot', type=str)
    parser.add_argument("--data_ratio", default=5, type=int)
    parser.add_argument("--attn_type", default='softmax', type=str,
                        help="softmax or tanh")
    # `choices` mirrors the validation done in main() so that an invalid
    # value is rejected at parse time, before any directory or log file is
    # created.
    parser.add_argument("--aug_type", default='cl_aug', type=str,
                        choices=("self_aug", "cl_aug", "mix_aug"),
                        help="self_aug or cl_aug or mix_aug")
    parser.add_argument("--cutoff_ratio", default=0.1, type=float)

    # Optimisation and schedule
    parser.add_argument("--random_seed", default=80, type=int)
    parser.add_argument("--num_workers", default=4, type=int)
    parser.add_argument("--train_batch_size", default=8, type=int)
    parser.add_argument("--enc_warmup", default=0.1, type=float)
    parser.add_argument("--dec_warmup", default=0.1, type=float)
    parser.add_argument("--enc_lr", default=4e-5, type=float)
    parser.add_argument("--dec_lr", default=1e-4, type=float)
    parser.add_argument("--n_epochs", default=10, type=int)
    parser.add_argument("--override_loop", default=3, type=int)
    parser.add_argument("--eval_epoch", default=1, type=int)
    parser.add_argument("--eval_step", default=10000, type=int,
                        help="Within each epoch, do evaluation as well at every eval_step")

    # Regularisation and model hyper-parameters
    parser.add_argument("--dropout_prob", default=0.1, type=float)
    parser.add_argument("--word_dropout", default=0.1, type=float)
    parser.add_argument("--temperature", default=0.1, type=float)
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--max_seq_length", default=512, type=int)
    parser.add_argument("--patience", default=80, type=int)
    parser.add_argument("--attn_head", default=4, type=int)
    parser.add_argument("--num_history", default=20, type=int)
    parser.add_argument("--distance_metric", default="euclidean", type=str,
                        help="euclidean or cosine")

    parser.add_argument("--num_self_attention_layer", default=6, type=int)

    args = parser.parse_args()

    print('pytorch version: ', torch.__version__)
    main(args)
