from transformers.modeling_outputs import CausalLMOutputWithPast, CausalLMOutputWithCrossAttentions
from transformers import GPT2LMHeadModel, AdamW
from datetime import datetime
import json
import torch
from mydataset import MyDataset
from typing import *
import random
import numpy as np
import argparse
import transformers
from transformers import AutoTokenizer, CodeGenForCausalLM, AutoModelForCausalLM, XGLMForCausalLM, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
import os
from data_collator import MyDataCollatorWithPadding
from dataclasses import asdict, dataclass, field
from peft import get_peft_config, get_peft_model, TaskType, LoraConfig
from peft import PeftModel, PeftConfig

# nohup python -u train.py --task-type steps-genetate --gradient-checkpointing > logs/gpt2_large_luogu_added_steps-genetate.log &

# Process-wide environment for a single-GPU, single-process run: only CUDA
# device index 0 is exposed, and the rank / world-size / master address
# variables describe a one-process group. Transformers advisory warnings are
# silenced as well.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    'RANK': '0',
    'WORLD_SIZE': '1',
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': '1234',
    'TRANSFORMERS_NO_ADVISORY_WARNINGS': 'true',
})
# os.environ["ACCELERATE_USE_DEEPSPEED"] = 'true'


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and CUDA), writes
    the seed to the ``PYTHONHASHSEED`` environment variable, and switches cuDNN
    to deterministic mode with benchmarking turned off.

    Args:
        seed: Seed value handed to each generator.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def main(args):
    """Run the whole fine-tuning job: load, train, and save.

    Steps, in order: seed the RNGs, load the tokenizer and causal-LM weights,
    build the train/test datasets, dump the command-line arguments to
    ``command.json`` inside ``args.save_dir``, train, and finally write the
    model and tokenizer to ``<save_dir>/final_checkpoint``.

    Args:
        args: Parsed command-line namespace. This function reads ``seed``,
            ``model_path``, ``resume``, ``save_dir_root`` and ``save_dir``; the
            namespace is also passed on to ``get_dataset`` and ``run_training``.
    """
    set_seed(args.seed)
    model_path = args.model_path
    print("Loading tokenizer...")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # tokenizer = GPT2Tokenizer.from_pretrained(model_path)  # gpt2
    # Use the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    print("Finish loading tokenizer!\nLoading model ...")
    # When resuming, the weights come from the checkpoint instead of the base
    # model; the tokenizer above always comes from ``model_path``.
    if args.resume:
        print(f"Loaded model from {args.resume}")
        weights_path = args.resume
    else:
        weights_path = model_path
    # Alternatives kept for reference:
    #   CodeGenForCausalLM.from_pretrained(...)  # codegen
    #   GPT2LMHeadModel.from_pretrained(...)     # gpt2
    model = AutoModelForCausalLM.from_pretrained(
        weights_path)  # codegen2 | starcoderbase

    # Remove this section during testing or when fine-tuning is not needed
    # lora_config = LoraConfig(
    #     task_type=TaskType.CAUSAL_LM,
    #     r=8,
    #     lora_alpha=32,
    #     # target_modules=["qkv_proj", "out_proj", "fc_in", "fc_out"], # codegen
    #     # target_modules=["c_attn", "c_proj", "c_fc", "c_proj"], # starcoder
    #     target_modules=["q_proj", "v_proj",
    #                     "k_proj", "out_proj", "fc1", "fc2"],  # incoder
    #     lora_dropout=0.1,
    #     # codegen: lora[27,32) | [15,20);
    #     # incoder: [21, 24)
    #     # codegen2: [6,16)
    #     # starcoderbase: [9,24)
    #     layers_to_transform=[i for i in range(21, 24)],
    #     bias="none")

    # model = get_peft_model(model, lora_config)  # use during fine-tuning
    # model.print_trainable_parameters()
    # print(f"model with LoRA:\n{model}")
    print(model)
    for name, param in model.named_parameters():
        print(f'Parameter: {name}, Requires grad: {param.requires_grad}')

    print("Finish loading model!")

    train_data, test_data = get_dataset(args, tokenizer)

    # Create the output directories. ``makedirs(exist_ok=True)`` also handles
    # nested roots and avoids the check-then-create race of exists() + mkdir().
    os.makedirs(args.save_dir_root, exist_ok=True)
    os.makedirs(args.save_dir, exist_ok=True)

    # Save the command-line arguments next to the checkpoints.
    command_file = os.path.join(args.save_dir, "command.json")
    with open(command_file, 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, indent=2)

    run_training(args, train_data, test_data, tokenizer, model)

    print("Saving model ...")
    final_dir = os.path.join(args.save_dir, "final_checkpoint")
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("Finish save model!")


def get_dataset(args, tokenizer: AutoTokenizer):
    """Build the training and test datasets from the configured files.

    Args:
        args: Parsed command-line namespace; ``apps_train_files``,
            ``apps_test_files``, ``max_tokens`` and ``task_type`` are read.
        tokenizer: Tokenizer passed through to each ``MyDataset``.

    Returns:
        A ``(train_data, test_data)`` tuple of ``MyDataset`` instances.
    """
    def _load(data_path):
        # Both splits share every setting except the file they read from.
        return MyDataset(
            data_path=data_path,
            max_tokens=args.max_tokens,
            task_type=args.task_type,
            tokenizer=tokenizer,
        )

    train_split = _load(args.apps_train_files)
    test_split = _load(args.apps_test_files)
    return train_split, test_split


def _start_iteration_from_checkpoint(resume):
    """Return the step number encoded at the end of a checkpoint path.

    The last path component is expected to end in ``-<step>`` (for example
    ``checkpoint-150``). Trailing path separators are ignored.

    Args:
        resume: Checkpoint path given on the command line.

    Returns:
        The step number as an ``int``.

    Raises:
        ValueError: If the last path component does not end in an integer.
    """
    # normpath drops a trailing separator, which would otherwise end up inside
    # the text handed to int().
    checkpoint_name = os.path.basename(os.path.normpath(resume))
    step_text = checkpoint_name.rsplit("-", 1)[-1]
    try:
        return int(step_text)
    except ValueError as err:
        raise ValueError(
            f"Cannot infer the start iteration from resume path {resume!r}; "
            "expected it to end in '-<step>', e.g. 'checkpoint-150'."
        ) from err


def run_training(args, train_data: MyDataset, test_data: MyDataset, tokenizer: AutoTokenizer, model: CodeGenForCausalLM):
    """Configure a ``transformers.Trainer`` and run training.

    Args:
        args: Parsed command-line namespace supplying the hyperparameters and
            output locations (``save_dir``, ``resume``, ``eval_steps``,
            ``epochs``, ``batch_size``, ``grad_acc_steps``, ``lr``,
            ``lr_scheduler``, ``weight_decay``, ``warmup_ratio``, ``log_freq``,
            ``save_freq``, ``no_cuda``, ``seed``, ``gradient_checkpointing``,
            ``fp16``).
        train_data: Training dataset; its ``start_iteration`` attribute is set
            here before training.
        test_data: Dataset used for the periodic evaluation.
        tokenizer: Tokenizer handed to the padding data collator.
        model: Model to train.

    Raises:
        ValueError: If ``args.resume`` is set but does not end in ``-<step>``.
    """
    if args.resume:
        start_iteration = _start_iteration_from_checkpoint(args.resume)
        print("start_iteration = ", start_iteration)
    else:
        start_iteration = 0

    ## Dataloading ########################################################
    # NOTE(review): how MyDataset uses start_iteration is not visible here.
    train_data.start_iteration = start_iteration

    ## Start Loop ########################################################
    print("Setting up trainer ...")

    training_args = transformers.TrainingArguments(
        output_dir=args.save_dir,

        # Use this to continue training if output_dir points to a checkpoint directory.
        overwrite_output_dir=False,
        resume_from_checkpoint=args.resume,
        evaluation_strategy='steps',
        eval_steps=args.eval_steps,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_acc_steps,

        learning_rate=args.lr,
        lr_scheduler_type=args.lr_scheduler,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        # warmup_steps=args.lr_warmup_steps,

        logging_dir=args.save_dir,
        logging_steps=args.log_freq,
        save_steps=args.save_freq,
        # Reload the checkpoint with the lowest evaluation loss when training ends.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",

        save_total_limit=1,

        no_cuda=args.no_cuda,
        seed=args.seed,
        data_seed=args.seed,
        # local_rank=args.local_rank,
        dataloader_drop_last=True,
        dataloader_num_workers=2,
        gradient_checkpointing=args.gradient_checkpointing,
        # deepspeed=args.deepspeed,
        fp16=args.fp16,
        # bf16=args.bf16,
    )

    # Pad each batch to the length of its longest sample.
    data_collator = MyDataCollatorWithPadding(tokenizer, padding="longest")

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=test_data,
        data_collator=data_collator,
    )
    print("Finish set up trainer!")

    print("Starting training...")
    # print(trainer.evaluate())
    # Passing None is the same as calling train() without the argument.
    trainer.train(resume_from_checkpoint=args.resume if args.resume else None)
    print("Finish training!")


def str2bool(value):
    """Convert a command-line token to a real ``bool``.

    ``argparse`` with ``type=bool`` (or no type at all) treats every non-empty
    string, including ``"False"``, as truthy; this converter parses the text.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as ``"true"``,
            ``"False"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Create the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` with all options registered.
    """
    parser = argparse.ArgumentParser(description="Language Modelling on Code")

    parser.add_argument('--model-path',
                        default=r"Salesforce/codegen-350M-multi",
                        type=str)
    parser.add_argument('--adapter-path',
                        default=r"Salesforce/codegen-350M-multi",
                        type=str)
    parser.add_argument('--task-type', default=None, required=True, type=str)

    parser.add_argument(
        '--resume', default=None, type=str)
    parser.add_argument('--seed', default=2333, type=int)
    # Boolean options accept an explicit value ("--no-cuda False") or can be
    # given bare ("--no-cuda"), which means True.
    parser.add_argument('--no-cuda', default=False, type=str2bool,
                        nargs='?', const=True)
    parser.add_argument('--gradient-checkpointing', action="store_true")
    # Dataloading
    parser.add_argument('--apps-dataroot', default='resources', type=str)
    parser.add_argument('--apps-train-files',
                        default='resources/train_luogu_added.json', type=str)
    parser.add_argument('--apps-test-files',
                        default='resources/test_luogu_added.json', type=str)
    parser.add_argument('--max-tokens', default=1024, type=int)

    # Training
    parser.add_argument('--epochs', default=8, type=int)
    parser.add_argument('--lr', default=2e-5, type=float)
    parser.add_argument(
        '--lr-scheduler', default="constant_with_warmup", type=str)
    parser.add_argument('--weight-decay', default=0.01, type=float)
    # Hyphenated spelling matches the other options; the original underscore
    # spelling is kept as an alias. Both store into ``args.warmup_ratio``.
    parser.add_argument('--warmup-ratio', '--warmup_ratio',
                        default=0.15, type=float)
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--grad-acc-steps', default=8, type=int)
    # parser.add_argument('--local-rank', default=-1, type=int)
    # parser.add_argument(
    #     '--deepspeed', default="deepspeed_config.json", type=str)
    parser.add_argument('--fp16', default=True, type=str2bool,
                        nargs='?', const=True)
    parser.add_argument('--bf16', default=True, type=str2bool,
                        nargs='?', const=True)

    parser.add_argument('--save-dir-root', default="checkpoints", type=str)
    # NOTE: the script entry point below replaces this value with a
    # timestamped directory under --save-dir-root.
    parser.add_argument('--save-dir', default="checkpoints/TEMP", type=str)
    parser.add_argument('--log-freq', default=5, type=int)  # 10 | 5
    parser.add_argument('--eval-steps', default=75, type=float)  # 150 | 75
    parser.add_argument('--save-freq', default=150, type=int)  # 300 | 150
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Every run writes to its own month-day-hour-minute directory.
    args.save_dir = os.path.join(
        args.save_dir_root, datetime.now().strftime("%m-%d-%H-%M"))

    main(args)
