import argparse, json, math, os, random
from tqdm import tqdm, trange
from time import time
import numpy as np
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from tensorboardX import SummaryWriter
from transformers import AdamW, get_linear_schedule_with_warmup, BartTokenizerFast
from dataloader import get_loader
from model import ClarET
from utils import *
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import argparse


def reduce_mean(tensor):
    """Return a copy of ``tensor`` all-reduced over the default process group.

    Despite the name, the reduction uses ``ReduceOp.SUM`` and no division by
    the number of processes happens here; the result is the element-wise sum
    of every rank's value. The input tensor itself is left untouched because
    the collective operates on a clone.
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed


def Tokenizer_Mapping(model_type):
    """Return the pretrained BART fast tokenizer that matches ``model_type``.

    The choice is made by substring: a ``model_type`` containing ``'base'``
    selects ``facebook/bart-base``; otherwise one containing ``'large'``
    selects ``facebook/bart-large``.

    Args:
        model_type: model identifier string, e.g. ``'ClarET-large'``.

    Returns:
        A ``BartTokenizerFast`` instance.

    Raises:
        ValueError: if ``model_type`` contains neither ``'base'`` nor
            ``'large'`` (previously this silently returned ``None``).
    """
    if 'base' in model_type:
        return BartTokenizerFast.from_pretrained('facebook/bart-base')
    if 'large' in model_type:
        return BartTokenizerFast.from_pretrained('facebook/bart-large')
    raise ValueError(
        "Unsupported model_type {!r}: expected a name containing 'base' or 'large'".format(model_type)
    )
    

def num_correct(out, labels):
    """Count the rows of ``out`` whose arg-max column index equals the label.

    Args:
        out: 2-D array-like of per-class scores, one row per example.
        labels: array-like of expected class indices, compared element-wise
            against the arg-max taken along axis 1.

    Returns:
        The number of matching positions, as produced by ``np.sum``.
    """
    hits = np.argmax(out, axis=1) == labels
    return np.sum(hits)


def generate_file_dir(args):
    """Derive the per-run directory and file paths from the parsed arguments.

    The run folder name encodes the hyper-parameters (model type, ablation,
    learning rate, batch size, accumulation step, seed, warmup rate, label
    smoothing and the two loss weights) so that different configurations do
    not overwrite each other.

    ``args.output_dir`` and ``args.logs_dir`` may be given with or without a
    trailing path separator; one is appended when missing. Previously a value
    such as ``models`` was fused with the run folder name.

    Args:
        args: namespace providing ``model_type``, ``ablation``, ``lr``,
            ``batch_size``, ``accumulation_step``, ``seed``, ``warmup_rate``,
            ``soft_label``, ``k_decay``, ``ct_decay``, ``output_dir``,
            ``logs_dir`` and ``metrics_out_file``.

    Returns:
        Tuple ``(output_dir, model_file, metrics_out_file, logs_dir,
        best_model_file, para_file)``; both directories end with ``/``.
    """
    model_folder_name = '_'.join([
        str(args.model_type),
        'abl', str(args.ablation),
        'lr', str(args.lr),
        'batch_size', str(args.batch_size),
        'acc_step', str(args.accumulation_step),
        'seed', str(args.seed),
        'warmup_rate', str(args.warmup_rate),
        'smooth', str(args.soft_label),
        'prob', str(args.k_decay), str(args.ct_decay),
    ]) + '/'

    # os.path.join(path, '') appends a separator only when one is missing and
    # leaves an empty root empty, so well-formed inputs produce the same paths
    # as plain concatenation did.
    output_root = os.path.join(args.output_dir, '')
    logs_root = os.path.join(args.logs_dir, '')

    logs_dir = logs_root + model_folder_name
    output_dir = output_root + model_folder_name
    metrics_out_file = output_dir + args.metrics_out_file
    model_file = output_dir + 'model.ckpt'
    best_model_file = output_dir + 'best_model.ckpt'
    para_file = output_dir + 'para.json'

    return output_dir, model_file, metrics_out_file, logs_dir, best_model_file, para_file


def write_para(args, para_file):
    """Append the run parameters to ``para_file``, one JSON object per line.

    Each parameter is written as its own single-key JSON object followed by a
    newline. The file is opened in append mode, so repeated calls add to any
    existing content.

    Args:
        args: the parsed arguments, either an ``argparse.Namespace`` (or any
            object supporting ``vars()``) or a plain ``dict``. The value that
            is passed in is what gets written; the command line is no longer
            re-parsed through a global ``parser``.
        para_file: path of the file to append to.
    """
    params = args if isinstance(args, dict) else vars(args)
    # The context manager closes the file; no explicit close() is needed.
    with open(para_file, "a") as outfile:
        for key, value in params.items():
            json.dump({key: value}, outfile)
            outfile.write('\n')


def form_inputs(batch, ablation, global_step, soft_label=0.1):
    """Build the keyword arguments for a training forward pass from a batch.

    Positions 0-7 of ``batch`` are moved to the GPU as the infilling and
    positive/negative piece inputs. The ``class_*`` entries are formed by
    concatenating the first ``cls_batch_size`` rows of two batch positions
    along dim 0: positions 8-10 with 11-13 when ``global_step`` is not a
    multiple of 5, positions 14-16 with 17-19 when it is.

    Args:
        batch: indexable collection of at least 20 tensors.
        ablation: value stored unchanged under ``'ablation'``.
        global_step: step counter selecting which classification group to use.
        soft_label: value stored unchanged under ``'soft_label'``.

    Returns:
        dict of CUDA tensors plus the ``ablation`` and ``soft_label`` entries.
    """
    # Half of the leading batch dimension, but never fewer than one row.
    half = max(batch[0].size(0) // 2, 1)

    leading_keys = (
        "infilling_input_ids",
        "infilling_attention_mask",
        "infilling_labels",
        "infilling_mask_loc",
        "positive_piece_input_ids",
        "positive_piece_attention_mask",
        "negative_piece_input_ids",
        "negative_piece_attention_mask",
    )
    inputs = {
        key: batch[position].cuda(non_blocking=True)
        for position, key in enumerate(leading_keys)
    }

    # Every fifth step draws from the second group of classification tensors.
    first = 14 if global_step % 5 == 0 else 8
    class_keys = ("class_input_ids", "class_attention_mask", "class_labels")
    for offset, key in enumerate(class_keys):
        head = batch[first + offset][:half, ...]
        tail = batch[first + offset + 3][:half, ...]
        inputs[key] = torch.cat([head, tail], 0).cuda(non_blocking=True)

    inputs['ablation'] = ablation
    inputs["soft_label"] = soft_label
    return inputs


def form_inputs_eval(batch, ablation, dev_step):
    """Build the keyword arguments for an evaluation forward pass from a batch.

    Positions 0-7 of ``batch`` are moved to the GPU as the infilling and
    positive/negative piece inputs. The ``class_*`` and ``tag_*`` entries come
    from positions 8-13 when ``dev_step`` is not a multiple of 5 and from
    positions 14-19 when it is. Unlike the training variant, no rows are
    sliced off or concatenated.

    Args:
        batch: indexable collection of at least 20 tensors.
        ablation: value stored unchanged under ``'ablation'``.
        dev_step: evaluation step index selecting the tensor group.

    Returns:
        dict of CUDA tensors plus ``ablation`` and ``eval_mode=True``.
    """
    names = [
        "infilling_input_ids",
        "infilling_attention_mask",
        "infilling_labels",
        "infilling_mask_loc",
        "positive_piece_input_ids",
        "positive_piece_attention_mask",
        "negative_piece_input_ids",
        "negative_piece_attention_mask",
        "class_input_ids",
        "class_attention_mask",
        "class_labels",
        "tag_input_ids",
        "tag_attention_mask",
        "tag_labels",
    ]

    # Every fifth evaluation step reads the second group of class/tag tensors.
    group_start = 14 if dev_step % 5 == 0 else 8
    positions = list(range(8)) + list(range(group_start, group_start + 6))

    inputs = {
        name: batch[position].cuda(non_blocking=True)
        for name, position in zip(names, positions)
    }
    inputs['ablation'] = ablation
    inputs['eval_mode'] = True
    return inputs


def train(output_dir, model_file, metrics_out_file, logs_dir, best_model_file, para_file,
          model_type, ablation, batch_size, num_workers, inf_max_length, cls_max_length, event_max_length, conj_max_length,
          dropout_prob, margin, local_rank, device,
          lr, epochs, warmup_rate, weight_decay, soft_label, accumulation_step, grad_clip,
          k_decay, ct_decay,
          log_step, save_step, eval_step):
    """Pre-train ClarET under DistributedDataParallel.

    Builds the tokenizer, data loaders, model, AdamW optimizer and linear
    warmup schedule, resumes from ``model_file`` when it exists, then iterates
    over the training loader for ``epochs`` epochs. Along the way it logs
    running losses to TensorBoard every ``log_step`` steps (rank 0), writes a
    resumable checkpoint every ``save_step`` steps (rank 0), and every
    ``eval_step`` steps evaluates on the eval loader on all ranks, saving
    ``best_model_file`` on rank 0 whenever the total eval loss improves.

    Args:
        output_dir: run directory; created by rank 0 if missing.
        model_file: path of the periodic (resumable) checkpoint.
        metrics_out_file: accepted for interface compatibility; not used here.
        logs_dir: TensorBoard log directory (rank 0 only).
        best_model_file: path of the best-eval-loss checkpoint.
        para_file: file the run parameters are written to on first creation
            of ``output_dir``.
        model_type: passed to ``Tokenizer_Mapping`` and the ``ClarET``
            constructor.
        ablation: one of ``'none'``, ``'wok'``, ``'woct'``, ``'woall'``;
            selects which loss terms make up the training loss.
        batch_size, num_workers, inf_max_length, cls_max_length,
        event_max_length, conj_max_length: forwarded to ``get_loader``.
        dropout_prob, margin: forwarded to the ``ClarET`` constructor.
        local_rank: this process's rank; also the DDP device id.
        device: device the model is moved to before DDP wrapping.
        lr, weight_decay: AdamW settings.
        epochs: number of passes over the training loader.
        warmup_rate: fraction of the optimizer steps used for warmup.
        soft_label: forwarded to the model through ``form_inputs``.
        accumulation_step: number of batches between optimizer steps.
        grad_clip: max gradient norm for ``clip_grad_norm_``.
        k_decay, ct_decay: weights of the ``k_loss`` and ``ct_loss`` terms.
        log_step, save_step, eval_step: logging, checkpointing and evaluation
            periods, measured in batches (``global_step``).

    Raises:
        ValueError: if ``ablation`` is not one of the four supported values
            (previously this surfaced as a ``NameError`` on ``loss`` inside
            the training loop).
    """
    if ablation not in ('none', 'wok', 'woct', 'woall'):
        raise ValueError(
            "Unsupported ablation {!r}: expected 'none', 'wok', 'woct' or 'woall'".format(ablation)
        )

    curr_global_step = None

    if not os.path.exists(output_dir) and local_rank == 0:
        os.makedirs(output_dir)
        # NOTE(review): `args` is not a parameter of train(); this resolves to
        # the module-level `args` bound by the __main__ block, so it only
        # works when the file is run as a script.
        write_para(args, para_file)

    if os.path.exists(model_file):
        print('continue training')
        checkpoint = torch.load(model_file, map_location=torch.device('cpu'))
        curr_global_step = checkpoint['global_step']

    tokenizer = Tokenizer_Mapping(model_type)
    print('Loading data')
    train_dataloader, eval_dataloader = get_loader(batch_size=batch_size, tkr=tokenizer, num_workers=num_workers,
                                                   inf_max_length=inf_max_length, cls_max_length=cls_max_length,
                                                   event_max_length=event_max_length, conj_max_length=conj_max_length,
                                                   test=False)

    # Only rank 0 owns a TensorBoard writer; every use below is guarded the same way.
    if local_rank == 0:
        writer = SummaryWriter(logs_dir)

    model = ClarET(model_type, dropout_prob, margin)
    model.to(device)

    # The scheduler counts optimizer steps, which happen once per `accumulation_step` batches.
    num_train_steps = int(len(train_dataloader) * epochs / accumulation_step)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=math.floor(warmup_rate*num_train_steps), num_training_steps=num_train_steps)

    global_step = 0
    best_eval_loss = 10.0
    if local_rank == 0:
        log_loss = 0
        log_gloss = 0
        log_kloss = 0
        log_ctloss = 0

    if curr_global_step is not None:
        try:
            model.load_state_dict(checkpoint['model'])
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatch; retry
            # with the keys rewritten by the project helper.
            model.load_state_dict(modify_state_dict(checkpoint['model']))
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        best_eval_loss = checkpoint['best_eval_loss']

    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

    for epoch_num in range(int(epochs)):

        # Pass the current epoch index (not the total epoch count) so that the
        # sampler's epoch value changes from one epoch to the next.
        train_dataloader.sampler.set_epoch(epoch_num)

        model.train()
        assert model.training

        tr_loss = 0
        tr_gloss = 0
        tr_kloss = 0
        tr_ctloss = 0
        accumulation_tr_steps = 0
        batch_tqdm = tqdm(train_dataloader)
        for step, batch in enumerate(batch_tqdm):
            global_step += 1

            # When resuming, consume batches without training until the
            # checkpointed step has been passed.
            if (curr_global_step is not None) and (global_step <= curr_global_step):
                continue

            batch = tuple(t for t in batch)
            batch = form_inputs(batch, ablation, global_step, soft_label)
            g_loss, k_loss, ct_loss = model(**batch)

            if ablation == 'none':
                loss = g_loss + k_decay*k_loss + ct_decay*ct_loss
            elif ablation == 'wok':
                loss = g_loss + ct_decay*ct_loss
            elif ablation == 'woct':
                loss = g_loss + k_decay*k_loss
            else:  # 'woall' (ablation was validated at function entry)
                loss = g_loss

            if local_rank == 0:
                log_loss += loss.item()
                log_gloss += g_loss.item()
                log_kloss += k_loss.item()
                log_ctloss += ct_loss.item()

            if global_step % log_step == 0 and local_rank == 0:
                writer.add_scalar('ClarET/train_g_loss', log_gloss/log_step, global_step)
                writer.add_scalar('ClarET/train_k_loss', log_kloss/log_step, global_step)
                # In training mode the model returns a single ct_loss, so both
                # the c and t curves are fed from the same running sum.
                writer.add_scalar('ClarET/train_c_loss', log_ctloss/log_step, global_step)
                writer.add_scalar('ClarET/train_t_loss', log_ctloss/log_step, global_step)
                writer.add_scalar('ClarET/train_total_loss', log_loss/log_step, global_step)
                log_loss = 0
                log_gloss = 0
                log_kloss = 0
                log_ctloss = 0

            tr_loss += loss.item()/accumulation_step
            tr_gloss += g_loss.item()/accumulation_step
            tr_kloss += k_loss.item()/accumulation_step
            tr_ctloss += ct_loss.item()/accumulation_step

            loss.backward()

            # Gradient accumulation: gradients from `accumulation_step`
            # consecutive batches are summed by backward() before a single
            # clipped optimizer/scheduler step, after which they are cleared.
            if global_step % accumulation_step == 0:
                clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                accumulation_tr_steps += 1

                batch_tqdm.set_description("Epoch: {}; Loss: {}; Iteration".format(epoch_num+1, round(tr_loss / accumulation_tr_steps, 3)))

            if (global_step % save_step == 0) and local_rank == 0:
                # Save the unwrapped module so the checkpoint loads without DDP.
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler':scheduler.state_dict(), 'global_step': global_step, 'best_eval_loss': best_eval_loss}
                torch.save(state, model_file)

            if global_step % eval_step == 0:
                model.eval()
                eval_tr_loss = 0
                eval_g_loss = 0
                eval_k_loss = 0
                eval_c_loss = 0
                eval_t_loss = 0
                eval_accumulation_tr_steps = 0
                eval_batch_tqdm = tqdm(eval_dataloader)

                for dev_step, batch in enumerate(eval_batch_tqdm):
                    with torch.no_grad():
                        batch = tuple(t for t in batch)
                        batch = form_inputs_eval(batch, ablation, dev_step)
                        # In eval mode the model returns four losses.
                        g_loss, k_loss, c_loss, t_loss = model(**batch)
                        loss = g_loss + k_decay*k_loss + ct_decay*c_loss + ct_decay*t_loss
                        eval_tr_loss += float(loss.cpu().item())
                        eval_g_loss += float(g_loss.cpu().item())
                        eval_k_loss += float(k_loss.cpu().item())
                        eval_c_loss += float(c_loss.cpu().item())
                        eval_t_loss += float(t_loss.cpu().item())
                    eval_accumulation_tr_steps += 1

                eval_accumulation_tr_steps = torch.tensor(eval_accumulation_tr_steps).cuda()
                eval_tr_loss = torch.tensor(eval_tr_loss).cuda()
                eval_g_loss = torch.tensor(eval_g_loss).cuda()
                eval_k_loss = torch.tensor(eval_k_loss).cuda()
                eval_c_loss = torch.tensor(eval_c_loss).cuda()
                eval_t_loss = torch.tensor(eval_t_loss).cuda()
                torch.distributed.barrier()
                # reduce_mean() sums across ranks; dividing the summed losses
                # by the summed step count below yields the per-batch mean
                # over all processes.
                eval_accumulation_tr_steps = reduce_mean(eval_accumulation_tr_steps)
                eval_tr_loss = reduce_mean(eval_tr_loss)
                eval_g_loss = reduce_mean(eval_g_loss)
                eval_k_loss = reduce_mean(eval_k_loss)
                eval_c_loss = reduce_mean(eval_c_loss)
                eval_t_loss = reduce_mean(eval_t_loss)

                if local_rank == 0:
                    eval_tr_loss = round(float(eval_tr_loss.cpu()) / int(eval_accumulation_tr_steps.cpu()), 3)
                    eval_g_loss = round(float(eval_g_loss.cpu()) / int(eval_accumulation_tr_steps.cpu()), 3)
                    eval_k_loss = round(float(eval_k_loss.cpu()) / int(eval_accumulation_tr_steps.cpu()), 3)
                    eval_c_loss = round(float(eval_c_loss.cpu()) / int(eval_accumulation_tr_steps.cpu()), 3)
                    eval_t_loss = round(float(eval_t_loss.cpu()) / int(eval_accumulation_tr_steps.cpu()), 3)
                    writer.add_scalar('ClarET/eval_g_loss', eval_g_loss, global_step)
                    writer.add_scalar('ClarET/eval_k_loss', eval_k_loss, global_step)
                    writer.add_scalar('ClarET/eval_c_loss', eval_c_loss, global_step)
                    writer.add_scalar('ClarET/eval_t_loss', eval_t_loss, global_step)
                    writer.add_scalar('ClarET/eval_total_loss', eval_tr_loss, global_step)

                    if best_eval_loss > eval_tr_loss:
                        best_eval_loss = eval_tr_loss
                        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler':scheduler.state_dict(), 'global_step': global_step, 'best_eval_loss': best_eval_loss}
                        torch.save(state, best_model_file)
                model.train()
    if local_rank == 0:
        writer.close()


def main(args):
    """Initialise distributed training and launch ``train`` from parsed args.

    Reads the settings from ``args``, initialises the NCCL process group,
    binds this process to the CUDA device given by ``args.local_rank``, seeds
    the random number generators through ``seed_all`` and derives the run
    paths with ``generate_file_dir``. ``train`` is invoked only when
    ``args.mode`` is ``None`` or ``"train"``.

    Args:
        args: parsed command-line namespace.
    """
    # Settings that train() receives under the same name they have on `args`.
    forwarded = (
        'model_type', 'ablation', 'batch_size', 'num_workers',
        'inf_max_length', 'cls_max_length', 'event_max_length', 'conj_max_length',
        'dropout_prob', 'margin', 'local_rank',
        'lr', 'epochs', 'warmup_rate', 'weight_decay', 'soft_label',
        'accumulation_step', 'grad_clip', 'k_decay', 'ct_decay',
        'log_step', 'save_step', 'eval_step',
    )
    # Read everything from `args` before touching the process group.
    train_kwargs = {name: getattr(args, name) for name in forwarded}
    run_mode = args.mode
    seed_value = args.seed
    rank = train_kwargs['local_rank']

    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(rank)
    train_kwargs['device'] = torch.device("cuda", rank)
    seed_all(seed_value)

    path_names = ('output_dir', 'model_file', 'metrics_out_file',
                  'logs_dir', 'best_model_file', 'para_file')
    train_kwargs.update(zip(path_names, generate_file_dir(args)))

    if run_mode is not None and run_mode != "train":
        return
    train(**train_kwargs)


if __name__ == '__main__':
    # Command-line entry point: declare the options, echo them, and run main().
    parser = argparse.ArgumentParser(description='Pretrain ClarET model and save')

    # Accept both spellings used by the PyTorch launchers and fall back to the
    # LOCAL_RANK environment variable when the flag is not passed at all.
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", -1)))

    # os parameters
    parser.add_argument('--output_dir', type=str, default='models/')
    parser.add_argument('--logs_dir', type=str, default='models/')
    parser.add_argument('--metrics_out_file', type=str, default="metrics.json")

    # data parameters
    parser.add_argument('--model_type', type=str, default='ClarET-large', help='ClarET-large, ClarET-base')
    parser.add_argument('--ablation', type=str, default='none', help='none, wok, woct, woall')
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--num_workers', type=int, default=10)
    parser.add_argument("--inf_max_length", type=int, default=168)
    parser.add_argument("--cls_max_length", type=int, default=224)
    parser.add_argument("--event_max_length", type=int, default=30)
    parser.add_argument("--conj_max_length", type=int, default=10)

    # model parameters
    parser.add_argument('--dropout_prob', type=float, default=0.1)
    parser.add_argument('--margin', type=float, default=0.5)

    # training parameters
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--epochs', type=int, default=10)
    # warmup_rate is a fraction of the training steps, so it must parse as a
    # float; with type=int a value such as "0.05" was rejected by argparse.
    parser.add_argument('--warmup_rate', type=float, default=0.03)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--soft_label', type=float, default=0)
    parser.add_argument('--accumulation_step', type=int, default=12)
    parser.add_argument('--grad_clip', type=float, default=1.0)
    parser.add_argument('--k_decay', type=float, default=1.0)
    parser.add_argument('--ct_decay', type=float, default=1.0)

    # Other parameters
    parser.add_argument("--log_step", type=int, default=120)
    parser.add_argument("--save_step", type=int, default=480)
    parser.add_argument("--eval_step", type=int, default=7200)

    args = parser.parse_args()
    print('====Input Arguments====')
    print(json.dumps(vars(args), indent=2, sort_keys=True))
    print("=======================")
    main(args)
