'''
Train Feedback model with BERT & source tasks 
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_bert_multitask.py
'''
import argparse
import os
import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
import copy
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from models import LanguageBERTModel_MT
from transformers import BertTokenizer as Tokenizer

import utils
from dataset import create_dataset, create_sampler, create_loader
from scheduler import create_scheduler
from optim import create_optimizer
from evaluate import evaluate_task
import transformers
transformers.logging.set_verbosity_error()

# train
def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler):
    """Run one training epoch and return the averaged logged metrics.

    Args:
        model: callable invoked as ``model(text_input, targets=..., tasks=..., train=True)``;
            its return value is used as the loss.
        data_loader: iterable yielding ``(text1, text2, targets, tasks)`` batches.
        optimizer: optimizer whose ``param_groups[0]["lr"]`` is logged and which is stepped
            once per batch.
        tokenizer: callable used to encode ``(text1, text2)`` pairs.
        epoch: index of the current epoch; warmup stepping only happens when it is 0.
        warmup_steps: number of warmup scheduler steps; one step is taken every
            ``step_size`` (100) iterations during epoch 0.
        device: device the tensors are moved to.
        scheduler: learning-rate scheduler stepped during warmup.

    Returns:
        dict mapping each meter name to its global average formatted with 4 decimals.
    """
    model.train()
    # logger
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50
    step_size = 100
    # warmup
    warmup_iterations = warmup_steps * step_size
    # training step
    # NOTE: the batch counter is named `step` and the per-sample index `row`. They used to
    # share the name `i`, so the inner loop clobbered the batch counter and the warmup
    # condition below was evaluated against `batch_size - 1` instead of the iteration index.
    for step, (text1, text2, targets, tasks) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # convert target and inputs
        targets = targets.to(device, non_blocking=True)
        text_input = tokenizer(text1, text2, padding='longest', truncation=True, max_length=512)
        for row, second_text in enumerate(text2):
            if second_text == '':
                # Keep only the tokens before (first type-1 position - 1) and zero out the rest.
                # Presumably this strips the separator tokens added for the empty second
                # segment -- verify against the tokenizer layout.
                # NOTE(review): attention_mask is left unchanged for the zeroed positions.
                ids = text_input['input_ids'][row]
                type_ids = text_input['token_type_ids'][row]
                text_len = type_ids.index(1) - 1
                text_input['input_ids'][row] = ids[:text_len] + [0] * (len(ids) - text_len)
                text_input['token_type_ids'][row] = type_ids[:text_len] + [0] * (len(type_ids) - text_len)
        text_input['input_ids'] = torch.tensor(text_input['input_ids']).to(device)
        text_input['token_type_ids'] = torch.tensor(text_input['token_type_ids']).to(device)
        text_input['attention_mask'] = torch.tensor(text_input['attention_mask']).to(device)
        # model forward
        loss = model(text_input, targets=targets, tasks=tasks, train=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item())

        # Warmup: during the first epoch, advance the scheduler once every `step_size` batches.
        if epoch == 0 and step % step_size == 0 and step <= warmup_iterations:
            scheduler.step(step // step_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}

# test
@torch.no_grad()
def evaluate(model, data_loader, tokenizer, device):
    """Run the model over ``data_loader`` and compute the task metric on the main process.

    Each process writes its predictions to ``result_rank{rank}.json`` in ``args.output_dir``;
    the main process then reads every rank's file and scores the merged predictions with
    ``evaluate_task``.

    This function reads the module-level ``args`` namespace (``output_dir``,
    ``distributed`` and ``task``), which is created in the ``__main__`` section.

    Args:
        model: callable invoked as ``model(text_input, targets=..., tasks=..., train=False)``;
            its return value is appended to the result list. Each item is expected to
            provide ``'prediction'``, ``'target'`` and ``'task'`` keys.
        data_loader: iterable yielding ``(text1, text2, targets, tasks)`` batches.
        tokenizer: callable used to encode ``(text1, text2)`` pairs.
        device: device the tensors are moved to.

    Returns:
        dict of formatted meter averages plus ``'result_metric'``. The metric is only
        computed on the main process; other processes report ``0.0``.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    # record the results
    header = 'Evaluation:'
    print_freq = 50
    global_rank = utils.get_rank()
    world_size = utils.get_world_size()
    # eval step
    result_metric = 0.0
    results = []
    for text1, text2, targets, tasks in metric_logger.log_every(data_loader, print_freq, header):
        # convert target and inputs
        targets = targets.to(device, non_blocking=True)
        text_input = tokenizer(text1, text2, padding='longest', truncation=True, max_length=512)
        for row, second_text in enumerate(text2):
            if second_text == '':
                # Keep only the tokens before (first type-1 position - 1) and zero out the rest.
                # NOTE(review): attention_mask is left unchanged for the zeroed positions.
                ids = text_input['input_ids'][row]
                type_ids = text_input['token_type_ids'][row]
                text_len = type_ids.index(1) - 1
                text_input['input_ids'][row] = ids[:text_len] + [0] * (len(ids) - text_len)
                text_input['token_type_ids'][row] = type_ids[:text_len] + [0] * (len(type_ids) - text_len)
        text_input['input_ids'] = torch.tensor(text_input['input_ids']).to(device)
        text_input['token_type_ids'] = torch.tensor(text_input['token_type_ids']).to(device)
        text_input['attention_mask'] = torch.tensor(text_input['attention_mask']).to(device)
        # getting model predict result
        prediction = model(text_input, targets=targets, tasks=tasks, train=False)

        results += prediction
    # save result for each node; the file is closed (and therefore flushed) before the
    # barrier so the main process never reads a partially written file
    rank_file = os.path.join(args.output_dir, "result_rank{}.json".format(global_rank))
    with open(rank_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    # block until all process finish
    if args.distributed:
        dist.barrier()
    # evaluate
    if utils.is_main_process():
        predictions = []
        targets = []
        tasks = []
        for rank in range(world_size):
            with open(os.path.join(args.output_dir, "result_rank{}.json".format(rank)), encoding='utf-8') as f:
                result = json.load(f)
            for item in result:
                predictions.append(item['prediction'])
                targets.append(item['target'])
                tasks.append(item['task'])
        # getting results
        result_metric = evaluate_task(predictions, targets, args.task, tasks)
        result_metric = round(result_metric, 6)
        print("Result for {} : {}".format(args.task, result_metric))
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    if args.distributed:
        dist.barrier()

    stats = {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
    stats.update({"result_metric": result_metric})
    return stats


def main(args, config):
    """Build the datasets, model, optimizer and scheduler, then run the train/evaluate loop.

    Args:
        args: parsed command-line namespace. Fields read here include ``device``, ``seed``,
            ``distributed``, ``gpu``, ``batch_size``, ``checkpoint``, ``text_encoder``,
            ``evaluate``, ``output_dir`` and the prompt-related flags.
        config: training configuration mapping; ``config['optimizer']`` and
            ``config['schedular']`` (including ``epochs`` and ``warmup_epochs``) are read here.

    Side effects:
        Appends one JSON record per epoch to ``log.txt`` in ``args.output_dir``, saves
        ``checkpoint_best.pth`` whenever the test metric improves, and optionally writes
        ``task2prompt_index.json`` / ``task2head_index.json``.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    train_dataset, test_dataset = create_dataset('multitask', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset, test_dataset], [True, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader([train_dataset, test_dataset], samplers,
                                              batch_size=[args.batch_size, args.batch_size],
                                              num_workers=[4, 4], is_trains=[True, False],
                                              collate_fns=[None, None])

    #### Model ####
    print("Creating model")
    model = LanguageBERTModel_MT(config, text_encoder=args.text_encoder, use_prompt=args.use_prompt, only_prompt=args.only_prompt,
                  fix_prompt=args.fix_prompt, prompt_config=args.prompt_config, fix_prompt_pre_round=args.fix_prompt_pre_round,
                  fix_word_embeddings=args.fix_word_embeddings,
                  multitask_train_prompt=args.multitask_train_prompt)

    ### save prompt2index and head2index (files are closed explicitly via `with`)
    if args.use_prompt and args.multitask_train_prompt:
        task2prompt_index = model.text_encoder.task2prompt_index
        with open(os.path.join(args.output_dir, 'task2prompt_index.json'), 'w') as f:
            json.dump(task2prompt_index, f, ensure_ascii=False, indent=4)
        task2head_index = model.task2head_index
        with open(os.path.join(args.output_dir, 'task2head_index.json'), 'w') as f:
            json.dump(task2head_index, f, ensure_ascii=False, indent=4)

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        state_dict = checkpoint['model']

        # for key in list(state_dict.keys()):
        #     if 'cls_head' in key:
        #         del state_dict[key]

        # strict=False: keys missing from the checkpoint are reported instead of raising
        msg = model.load_state_dict(state_dict, strict=False)
        print('load checkpoint from %s' % args.checkpoint)
        print('[Missing_keys]', msg.missing_keys)

    model = model.to(device)

    tokenizer = Tokenizer.from_pretrained(args.text_encoder)

    # keep a handle on the unwrapped module so checkpoints hold plain parameter names
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # if utils.is_main_process():
    #     save_obj = {
    #         'model': model_without_ddp.state_dict(),
    #         'config': config,
    #         'epoch': -1,
    #     }
    #     torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_initial.pth'))
    # if args.distributed:
    #     dist.barrier()

    arg_opt = utils.AttrDict(config['optimizer'])
    optimizer = create_optimizer(arg_opt, model)
    arg_sche = utils.AttrDict(config['schedular'])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)

    max_epoch = config['schedular']['epochs']
    warmup_steps = config['schedular']['warmup_epochs']
    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()

    for epoch in range(0, max_epoch):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
            train_start_time = time.time()
            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler)
            train_time = str(datetime.timedelta(seconds=int(time.time() - train_start_time)))

        test_stats = evaluate(model, test_loader, tokenizer, device)

        if utils.is_main_process():
            if args.evaluate:
                log_stats = {**{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                             }

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                             'train_time': train_time
                             }

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # save_obj = {
                #         'model': model_without_ddp.state_dict(),
                #         'epoch': epoch,
                #     }
                # torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_{}.pth'.format(epoch)))

                # only the best-scoring epoch is checkpointed
                if float(test_stats['result_metric']) > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = float(test_stats['result_metric'])
                    best_epoch = epoch

        # evaluation-only mode runs a single pass
        if args.evaluate:
            break
        lr_scheduler.step(epoch + warmup_steps + 1)
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if utils.is_main_process():
        with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
            # terminate the line so a later run appending to the same log starts on a fresh line
            f.write("best epoch: %d\n" % best_epoch)


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would treat every non-empty string (including ``"False"``) as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser()
    # task settings
    parser.add_argument('--task', default='multitask')
    parser.add_argument('--config', default='./configs/Source/multitask.yaml')
    parser.add_argument('--output_dir', default='output/bert_base/p20/seed42/multitask/finetuning')
    # model settings
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--text_encoder', default='bert-base-uncased')
    parser.add_argument('--use_prompt', action='store_true')
    parser.add_argument('--only_prompt', action='store_true')
    parser.add_argument('--fix_prompt', action='store_true')
    parser.add_argument('--fix_prompt_pre_round', action='store_true')
    parser.add_argument('--multitask_train_prompt', action='store_true')
    parser.add_argument('--fix_word_embeddings', action='store_true')
    parser.add_argument('--prompt_config', default='configs/config_prompt_p20.json', type=str)
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=False, type=_str2bool)
    # dataset settings
    parser.add_argument('--pre_round_invisible', action='store_true')
    parser.add_argument('--dataset_root', default='/home/zenghang/Bi-KnowledgeTransfer/Language/data', type=str)
    parser.add_argument('--task2data_map', default='/home/zenghang/Bi-KnowledgeTransfer/Language/configs/task2data.json')
    parser.add_argument('--task_info', default='/home/zenghang/Bi-KnowledgeTransfer/Language/configs/task_info.json')
    args = parser.parse_args()

    # load training config
    # NOTE(security): yaml.Loader can construct arbitrary Python objects; only load
    # configuration files from trusted sources.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    config['file_root'] = args.dataset_root
    with open(args.task2data_map) as f:
        task2data = json.load(f)

    config['train_file'], config['test_file'] = [], []
    # copy so that extending the task list does not modify config['tasks']['curr_round']
    tasks = copy.copy(config['tasks']['curr_round'])

    if not args.pre_round_invisible:
        tasks += config['tasks']['pre_round']
    for t in tasks:
        config['train_file'] += task2data[t]['train']
        config['test_file'] += task2data[t]['test']

    with open(args.task_info) as f:
        config['task_info'] = json.load(f)

    # save config, arguments and (optionally) the prompt config next to the run outputs
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)
    with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)
    if args.use_prompt:
        with open(args.prompt_config) as f:
            prompt_config = json.load(f)
        with open(os.path.join(args.output_dir, 'config_prompt.json'), 'w') as f:
            json.dump(prompt_config, f, ensure_ascii=False, indent=4)

    main(args, config)
