import os
import torch
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import time

import argparse
# from datasets import DATASETS
from config import *
from model import *
# from dataloader import *
from trainer import *
from datasets.kp20k import load_data

from transformers import BitsAndBytesConfig, AutoConfig, AutoTokenizer
# from pytorch_lightning import seed_everything
from utils import set_seed
from model import LlamaForCausalLM
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    prepare_model_for_kbit_training,
)
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from logger import FileLogger

from trainer.custom_Peft import custom_PeftModel

# Point Weights & Biases at the project named by PROJECT_NAME. That name is not
# defined in this file (presumably it arrives via ``from config import *`` --
# confirm), so a missing or unusable value is reported rather than fatal.
try:
    os.environ['WANDB_PROJECT'] = PROJECT_NAME
except (NameError, TypeError, ValueError):
    # NameError: PROJECT_NAME is not defined.
    # TypeError / ValueError: it is not a string os.environ can store.
    # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
    print('WANDB_PROJECT not available, please set it in config.py')


def main(args, export_root=None):
    """Seed the RNGs and run ``main_worker`` in one or several processes.

    Args:
        args: Run configuration. Reads ``seed``, ``distributed``, ``gpu`` and
            ``world_size``. In distributed mode ``world_size`` is overwritten
            with ``ngpus_per_node * world_size``, so on entry it acts as a
            per-node multiplier (presumably the number of nodes -- confirm).
        export_root: Unused here; kept so existing callers keep working.
            ``main_worker`` computes its own export directory.
    """
    set_seed(args.seed)

    ngpus_per_node = torch.cuda.device_count()

    if not args.distributed:
        # Single-process run: call the worker directly in this process.
        main_worker(args.gpu, ngpus_per_node, args)
        return

    # Rendezvous endpoint read by init_process_group(init_method='env://') in
    # main_worker. setdefault keeps the historical localhost:3478 default but
    # no longer clobbers values exported by a launcher, which is required for
    # multi-node runs and for running two jobs on the same host.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '3478')

    # Total number of processes across all nodes.
    args.world_size = ngpus_per_node * args.world_size
    # mp.spawn passes the process index as the first positional argument,
    # which main_worker receives as ``gpu``.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    """Build the quantized, LoRA-wrapped model and train it in this process.

    Called directly by ``main`` for a single-process run, or once per GPU by
    ``mp.spawn`` for a distributed run.

    Args:
        gpu: Local GPU index of this process. In distributed mode the whole
            model is placed on this device and DDP is bound to it.
        ngpus_per_node: GPU count on this node; used to derive the global rank.
        args: Run configuration. ``gpu``, ``n_gpu``, ``device`` and ``rank``
            are overwritten here. Also reads ``absent``, ``llm_base_model``,
            ``llm_base_tokenizer``, ``llm_cache_dir``, ``distributed``,
            ``dist_backend``, ``world_size``, the ``lora_*`` settings,
            ``resume`` and ``use_wandb``.
    """
    args.gpu = gpu
    args.n_gpu = ngpus_per_node
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert the incoming rank (presumably the node index -- confirm) into a
    # global process rank.
    args.rank = args.rank * ngpus_per_node + gpu

    print('rank: {}, gpu: {}, device: {}'.format(args.rank, args.gpu, args.device))
    is_rank0 = args.rank == 0
    absent_root = '_absent' if args.absent else ''
    export_root = EXPERIMENT_ROOT + absent_root + '/' + args.llm_base_model.split('/')[-1] + '/' + "kp20k"

    device_map = 'auto'
    if args.distributed:
        # Place the whole model on this process's GPU instead of letting it
        # be spread over every visible device.
        device_map = {'': gpu}

        dist.init_process_group(backend=args.dist_backend, init_method='env://', world_size=args.world_size, rank=args.rank)
        # NOTE(review): fixed seed 0 here, independent of args.seed -- confirm
        # this is intended.
        torch.manual_seed(0)

    is_master = (not args.distributed) or (args.rank == 0)
    global log
    log = FileLogger(export_root, is_master=is_master, is_rank0=is_rank0, log_to_file=True)
    log.console(args)

    config = AutoConfig.from_pretrained(
        args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(
        args.llm_base_tokenizer, cache_dir=args.llm_cache_dir)

    # load_data returns (loader, sampler); the samplers are not used here.
    train_loader, _ = load_data(args, config, tokenizer, split="train")
    val_loader, _ = load_data(args, config, tokenizer, split="valid")
    # No test split is loaded in this script.
    test_loader = None

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model = LlamaForCausalLM.from_pretrained(
        args.llm_base_model,
        quantization_config=bnb_config,
        device_map=device_map,
        cache_dir=args.llm_cache_dir,
    )
    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules,
        lora_dropout=args.lora_dropout,
        bias='none',
        task_type="CAUSAL_LM",
    )

    # Resume from the "best-f1" adapter under export_root when asked to and it
    # exists; otherwise attach a fresh LoRA adapter.
    ckpt_model_path = osp.join(export_root, "best-f1")
    if args.resume and osp.exists(ckpt_model_path):
        log.console(f"Resuming model checkpoint from {ckpt_model_path}...")
        model = custom_PeftModel.from_pretrained(model, ckpt_model_path, is_trainable=True, config=lora_config, gpu=gpu)
    else:
        model = get_peft_model(model, lora_config)
        log.console("Training model from scratch")

    model.print_trainable_parameters()

    model.config.use_cache = False
    if args.distributed:
        # device_ids takes the local device index. The global rank only
        # matches it on the first node, so the rank must not be used here.
        model = DistributedDataParallel(model, device_ids=[gpu])

    trainer = LLMTrainer(args, model, train_loader, val_loader, test_loader, tokenizer, config, export_root, args.use_wandb, log)

    start_time = time.time()
    trainer.train()
    end_time = time.time()
    # Synchronise workers before reporting. A barrier without an initialised
    # process group raises, so skip it for single-process runs.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    print('Training hours: {}'.format((end_time - start_time) / 3600))
    # trainer.test(test_retrieval)

    # load best model
    # exit()

    # for dataset in ['inspec', 'krapivin', 'nus', 'semeval', 'kp20k']:
    #     args.data_dir = args.data_dir.replace('kp20k', dataset)
    #     test_loader, _ = load_data(args, config, tokenizer, split="test")



if __name__ == "__main__":
    # ``args`` and ``set_template`` are not defined in this file; presumably
    # they arrive through the star-imports at the top -- confirm.
    args.model_code = 'llm'  # tag the run as the LLM variant before templating
    set_template(args)
    # export_root is left at its default of None.
    main(args)
