import argparse
import json
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

from dataloader import Recorddataset
from evaluate import eval
# from model import Model
from model import Model


def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to True, so the text is interpreted explicitly instead.

    Args:
        value: the raw argument string (a bool is passed through unchanged).

    Returns:
        True for "1"/"true"/"t"/"yes"/"y", False for "0"/"false"/"f"/"no"/"n"/""
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def set_seed(seed):
    """Seed the NumPy and PyTorch (CPU and all CUDA devices) random generators."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def predict_test_results(model, batches, device):
    """Predict a label for every test example.

    Args:
        model: the classifier. It is called as ``model(sent, [], device)``,
            i.e. with an empty label list, and the first element of its output
            is used as the per-example scores.
        batches: iterable of ``(sent, uuid)`` pairs, e.g. the test DataLoader
            (optionally wrapped in ``tqdm``).
        device: forwarded unchanged to ``model``.

    Returns:
        dict mapping ``str(uuid)`` to ``{"Prediction": "Contradiction"}`` when
        the argmax score is class 0, otherwise ``{"Prediction": "Entailment"}``.
    """
    results = {}
    with torch.no_grad():
        for sent, uuid in batches:
            scores = model(sent, [], device)[0]
            # Walk the rows actually returned, so a short final batch is handled.
            for i, row in enumerate(scores):
                prediction = 'Contradiction' if torch.argmax(row) == 0 else "Entailment"
                results[str(uuid[i])] = {"Prediction": prediction}
    return results


def write_json(path, obj):
    """Write ``obj`` to ``path`` as JSON indented by 4 spaces."""
    with open(path, 'w') as json_file:
        json_file.write(json.dumps(obj, indent=4))


def main():
    """Fine-tune ``Model`` once per alpha value and export test predictions.

    For every alpha in ``args.alpha_list`` a freshly built model is trained on
    the "train" split (``--mode trn``) or the "trn&dev" split (``--mode mix``).
    Every ``--eval_every`` steps it is evaluated on the train and dev loaders.
    Whenever the dev F1 improves, the model is saved to the run's output
    directory and test predictions are written to
    ``<output_dir>/<epoch>_best_results.json``. In ``mix`` mode, test
    predictions are also written after every epoch to
    ``<output_dir>/epoch<epoch + 1>_results.json``.
    """
    # Only used to tag the output directory name.
    lambda_data = 10
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", default=0, type=int)

    # Other pretrained LMs that were tried:
    #   microsoft/BiomedVLP-CXR-BERT-general
    #   AshtonIsNotHere/GatorTron-OG-bc-ctr-nli   (//73)
    #   albert/albert-base-v1
    #   valenaparicio16/bioBERT-finetuned-financial-phrasebank
    #   pritamdeka/BioBert-PubMed200kRCT
    parser.add_argument("--ptlm", default='cnut1648/biolinkbert-large-mnli-snli', type=str)

    # Recorded experiment results:
    # test datasets: total_loss_per_epoch:  696.2706050872803 best_val 0.75 best_r 0.84 best_p 0.6774193548387096 best_val_epoch 18
    # lambda = 0;    best_val 0.7364016736401673 best_r 0.88 best_p 0.6330935251798561 best_val_epoch 12
    # lambda = 0.1;  best_val 0.7663551401869159 best_r 0.82 best_p 0.7192982456140351 best_val_epoch 14
    # lambda = 0.15; best_val 0.7572815533980584 best_r 0.78 best_p 0.7358490566037735 best_val_epoch 16
    # lambda = 0.2;  best_val 0.7727272727272727 best_r 0.85 best_p 0.7083333333333334 best_val_epoch 21
    # lambda = 0.25; best_val 0.7699530516431924 best_r 0.82 best_p 0.7256637168141593 best_val_epoch 11
    # lambda = 0.3;  best_val 0.7649769585253456 best_r 0.83 best_p 0.7094017094017094 best_val_epoch 13
    # lambda = 1;    best_val 0.6782006920415224 best_r 0.98 best_p 0.5185185185185185 best_val_epoch 2

    parser.add_argument("--lmn", default="deberta", type=str, help="name of the language model")
    parser.add_argument('--data', type=str, help='data dir')
    parser.add_argument("--epoch", default=20, type=int)
    parser.add_argument("--eval_every", default=400, type=int)
    parser.add_argument("--prompt", default=2, choices=[0, 1, 2, 3], type=int)
    parser.add_argument("--mode", default='trn', choices=['mix', 'trn'])
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument("--from_check_point", default=False, type=str2bool)
    parser.add_argument("--tokenizer_dir", default=None, type=str, help='the tokenizer check point dir')
    parser.add_argument("--model_dir", default=None, type=str, help='the model check point dir')
    parser.add_argument("--seed", default=1024, type=int)
    args = parser.parse_args()
    # Sweep over these alpha values; args.alpha is set before each model is built.
    args.alpha_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # set_device fails on machines without CUDA, so only call it when available.
        torch.cuda.set_device(args.gpu)
    device = torch.device('cuda' if use_cuda else 'cpu')

    trn_split = "train" if args.mode == 'trn' else "trn&dev"
    trn_dataset = Recorddataset(args, args.data, trn_split)
    dev_dataset = Recorddataset(args, args.data, "dev")
    tst_dataset = Recorddataset(args, args.data, "test")

    test_batch_size = 128
    trn_loader = DataLoader(trn_dataset, batch_size=4, shuffle=True, drop_last=False)
    dev_loader = DataLoader(dev_dataset, batch_size=2, shuffle=False, drop_last=False)
    tst_loader = DataLoader(tst_dataset, batch_size=test_batch_size, shuffle=False, drop_last=False)

    epochs = args.epoch
    num_total_steps = len(trn_loader) * epochs
    # Linear warm-up over the first eighth of the epochs (rounded down).
    num_warmup_steps = len(trn_loader) * (epochs // 8)

    for test_alpha in args.alpha_list:
        args.alpha = test_alpha
        # Re-seed per alpha so every run starts from the same RNG state (model
        # init, shuffling); otherwise each run depends on the runs before it.
        set_seed(args.seed)
        output_dir = "./result/lamda_{}_{}_prompt{}_mode{}_epoch{}_eval{}_alpha{}/".format(
            lambda_data, args.ptlm, args.prompt, args.mode, args.epoch, args.eval_every, args.alpha)
        os.makedirs(output_dir, exist_ok=True)

        model = Model(args, args.ptlm, args.from_check_point, args.tokenizer_dir, args.model_dir)
        model.to(device)

        optimizer = AdamW(model.parameters(), lr=5e-6, correct_bias=True)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                    num_training_steps=num_total_steps)

        best_val, best_val_epoch = 0, 0
        best_recall, best_precision = 0, 0
        for epoch in range(epochs):
            # The evaluation helper / test inference may leave the model in eval
            # mode; make sure every epoch trains in training mode.
            model.train()
            total_loss = 0
            for step, (sent, label) in enumerate(tqdm(trn_loader, desc=f'epoch: {epoch + 1}/{epochs}')):
                label = label.to(device)
                # With labels supplied, the first model output is the training loss.
                loss = model(sent, label, device)[0]
                total_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                if step == 0 or step % args.eval_every != 0:
                    continue

                with torch.no_grad():
                    train_l, train_f, train_p, train_r = eval(model, trn_loader, device, print_on_screen=False)
                    print(
                        f"The Train result at epoch {epoch + 1} iter {step}: train_loss: {train_l}, train_f1: {train_f}, train_precision: {train_p}, train_recall: {train_r}")

                    l, f, p, r = eval(model, dev_loader, device, print_on_screen=False)
                print(
                    f"The Validation result at epoch {epoch + 1} iter {step}: val_loss: {l}, val_f1: {f}, val_precision: {p}, val_recall: {r}")
                if f > best_val:
                    best_val_epoch = epoch + 1
                    best_val = f
                    best_precision = p
                    best_recall = r
                    model.save_model(output_dir)
                    model.eval()
                    test_results = predict_test_results(model, tqdm(tst_loader), device)
                    write_json("{}/{}_best_results.json".format(output_dir, str(epoch)), test_results)
                # Resume training mode after evaluation / inference.
                model.train()

            print("total_loss_per_epoch: ", total_loss, "best_val", best_val, 'best_r', best_recall, 'best_p',
                  best_precision, "best_val_epoch", best_val_epoch)
            if args.mode == 'mix':
                # Same label-free inference as the best-checkpoint export, so every
                # test example (including a short final batch) gets a prediction.
                model.eval()
                test_results = predict_test_results(model, tqdm(tst_loader), device)
                write_json("{}/epoch{}_results.json".format(output_dir, epoch + 1), test_results)

        # Free this run's model/optimizer before the next alpha builds a new one,
        # so two large models are never resident at the same time.
        del model, optimizer, scheduler
        if use_cuda:
            torch.cuda.empty_cache()


if __name__ == "__main__":
    # Start the alpha sweep only when run as a script, not when imported.
    main()
