import torch
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, Adafactor, AdamW, T5Config, BartTokenizer, \
    BartForSequenceClassification, BartTokenizerFast, BartConfig
from transformers.optimization import AdafactorSchedule, get_linear_schedule_with_warmup

from Classifier.bart import bart_model
from Classifier.dataset import T5Dataset, BartDataset

from Classifier.utils import read_json, train_one_epoch_forBART, \
    validate_forBART


def _build_optimizer_and_scheduler(args, model, steps_per_epoch):
    """Create the optimizer and learning-rate scheduler selected by ``args``.

    Three configurations are supported:

    * ``use_Adafactor`` and ``use_AdafactorSchedule``: Adafactor with its own
      relative-step learning rate, driven by ``AdafactorSchedule``.
    * ``use_Adafactor`` only: Adafactor with the fixed ``gen_learning_rate``
      and a linear warmup/decay schedule.
    * otherwise: AdamW with ``gen_learning_rate`` / ``gen_weight_decay`` and
      the same linear warmup/decay schedule.

    Args:
        args: run configuration; reads ``use_Adafactor``,
            ``use_AdafactorSchedule``, ``gen_learning_rate``,
            ``gen_weight_decay``, ``gen_train_epochs`` and
            ``gen_lr_warmup_ratio``.
        model: the module whose parameters are optimized.
        steps_per_epoch: number of scheduler steps per epoch (the number of
            batches in the training loader).

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    if args.use_Adafactor and args.use_AdafactorSchedule:
        # Relative-step mode: Adafactor derives its own learning rate, so no
        # explicit lr is passed (see the Adafactor docs below).
        # https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/optimizer_schedules#transformers.Adafactor
        optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True,
                              warmup_init=True, lr=None)
        return optimizer, AdafactorSchedule(optimizer)

    if args.use_Adafactor:
        optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False,
                              warmup_init=False, lr=args.gen_learning_rate)
    else:
        optimizer = AdamW(model.parameters(), lr=args.gen_learning_rate,
                          weight_decay=args.gen_weight_decay)

    # One scheduler step per training batch, across all epochs.
    total_steps = steps_per_epoch * args.gen_train_epochs
    # Left as a (possibly fractional) float, exactly as before; the linear
    # schedule compares and divides by it directly.
    warmup_steps = total_steps * args.gen_lr_warmup_ratio
    lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,
                                                   num_warmup_steps=warmup_steps,
                                                   num_training_steps=total_steps)
    return optimizer, lr_scheduler


def bart_train(args, logger, trpath, depath, tagpath, metric, pretrain, mode='cot'):
    """Fine-tune the BART sequence classifier and checkpoint the best epoch.

    Each epoch trains on the training set and evaluates on the validation set.
    All metrics are logged. The model's ``state_dict`` is saved whenever the
    dev score ``'won_' + metric + '_f1'`` improves.

    Args:
        args: run configuration. Reads ``device``, ``HFmodel`` (pretrained
            model path), ``dataset``, ``version``, ``gen_batch_size`` and
            ``gen_train_epochs``, plus the optimizer options read by
            ``_build_optimizer_and_scheduler``.
        logger: logger that receives progress and metric messages.
        trpath: path of the training-data JSON (loaded with ``read_json``).
        depath: path of the validation-data JSON (loaded with ``read_json``).
        tagpath: path of the label-map JSON. It is passed as ``label2id``,
            and its length sets the number of output labels.
        metric: F1 variant used for model selection. It is combined as
            ``'won_' + metric + '_f1'``, so ``'macro'``, ``'micro'`` or
            ``'weighted'`` match the keys logged below.
        pretrain: not used by this function.
        mode: not used by this function.

    Returns:
        None. The best checkpoint is written to
        ``SaveModels/<dataset>/vbart/won_model<version>.bin``.
    """
    import os

    device = args.device
    model_path = args.HFmodel

    # Logging and checkpoint location.
    won_model_file = f"SaveModels/{args.dataset}/vbart/won_model{args.version}.bin"
    # torch.save does not create parent directories. Create them up front so a
    # missing folder does not crash the run after the first epoch's training.
    os.makedirs(os.path.dirname(won_model_file), exist_ok=True)
    logger.info("----" + str(model_path) + "----")

    # Data.
    train_data = read_json(trpath)
    valid_data = read_json(depath)
    tokenizer = BartTokenizer.from_pretrained(model_path)
    classes_map = read_json(tagpath)

    train_set = BartDataset(args, True, train_data, tokenizer, classes_map)
    train_loader = DataLoader(train_set,
                              batch_size=args.gen_batch_size,
                              shuffle=True,
                              sampler=None,
                              pin_memory=True,
                              # num_workers=args.num_workers,
                              collate_fn=train_set.collate_fn,
                              drop_last=False)
    valid_set = BartDataset(args, False, valid_data, tokenizer, classes_map)
    # NOTE(review): shuffling the validation set is unnecessary if
    # validate_forBART only aggregates over all samples; kept as-is.
    valid_loader = DataLoader(valid_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=0,
                              collate_fn=valid_set.collate_fn,
                              drop_last=False)

    # Model and optimization.
    model = bart_model.BartForSequenceClassification.from_pretrained(model_path,
                                                                     num_labels=len(classes_map))
    model.to(device)
    optimizer, lr_scheduler = _build_optimizer_and_scheduler(args, model, len(train_loader))

    # Key in dev_result that drives checkpoint selection, e.g. 'won_macro_f1'.
    won_best_metric = 'won_' + metric + '_f1'
    # Start below any attainable score so at least one checkpoint is always
    # written, even if the dev F1 never rises above 0.
    won_best_f1 = float('-inf')

    for epoch in range(args.gen_train_epochs):
        train_result = train_one_epoch_forBART(args=args,
                                               model=model,
                                               label2id=classes_map,
                                               device=device,
                                               data_loader=train_loader,
                                               epoch=epoch,
                                               optimizer=optimizer,
                                               lr_scheduler=lr_scheduler)

        dev_result = validate_forBART(args=args,
                                      model=model,
                                      label2id=classes_map,
                                      device=device,
                                      data_loader=valid_loader,
                                      epoch=epoch)

        results = {
            'learning_rate': optimizer.param_groups[0]["lr"],
            'train_loss': train_result['loss'],
            'dev_loss': dev_result['loss'],
            'train_accuracy': train_result['accuracy'],
            'dev_accuracy': dev_result['accuracy'],
            'dev_macro_f1': dev_result['macro_f1'],
            'dev_micro_f1': dev_result['micro_f1'],
            'dev_weighted_f1': dev_result['weighted_f1'],
            'won_dev_macro_f1': dev_result['won_macro_f1'],
            'won_dev_micro_f1': dev_result['won_micro_f1'],
            'won_dev_weighted_f1': dev_result['won_weighted_f1']
        }

        if epoch % 10 == 9:
            # NOTE(review): the /10 scaling and this step formula are kept from
            # the original; confirm they are intended.
            step = epoch * len(train_loader) + epoch
            scaled_loss = train_result['loss'] / 10
            # Previously the step was passed to format() without a placeholder
            # and silently dropped; it is now actually logged.
            logger.info('Training/training loss: {:.4f} (step {})'.format(scaled_loss, step))
            print('Training/training loss', scaled_loss, step)

        logger.info("=" * 100)
        logger.info(f"epoch: {epoch}")
        # Log every metric recorded for this epoch.
        for key, value in results.items():
            logger.info(f"{key}: {value}")

        # Keep the checkpoint with the best dev score for the selected metric.
        if dev_result[won_best_metric] > won_best_f1:
            torch.save(model.state_dict(), won_model_file)
            won_best_f1 = dev_result[won_best_metric]
            print(f'won_best:{won_best_f1}')
            logger.info(f"won_best-{metric}-f1:{won_best_f1}  epoch:{epoch}")

