import argparse
import os
import torch
import logging
import random
import numpy as np
from torch.utils.data.distributed import DistributedSampler
from peft import LoraConfig
from accelerate import Accelerator
from peft import PeftModelForCausalLM
import time
import torch.nn.functional as F
try:
    from torch.utils.tensorboard import SummaryWriter
except:
    from tensorboardX import SummaryWriter
from transformers import AdamW
from transformers import LlamaForCausalLM,LlamaTokenizer,LlamaConfig

from torch.utils.data import DataLoader, SequentialSampler, RandomSampler

from tqdm import tqdm, trange

from transformers import get_linear_schedule_with_warmup

import transformers

import data_utils
import utils

from typing import Dict, Optional, Sequence
# Module-level logger used by train()/main().
logger = logging.getLogger(__name__)
# Same value as the default ignore_index of PyTorch's cross-entropy loss.
# Not referenced in this file — presumably kept for parity with data_utils; verify.
IGNORE_INDEX = -100
# Fallback special tokens; main() adds each one only if the loaded tokenizer lacks it.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
def find_all_linear_names(model):
    """Return the short (last dotted component) names of all ``torch.nn.Linear`` submodules.

    Names are de-duplicated, so e.g. every layer's ``q_proj`` contributes a single
    ``'q_proj'`` entry. ``'lm_head'`` is excluded so it is not targeted by LoRA.
    The order of the returned list is unspecified (it comes from a set).
    """
    short_names = {
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    short_names.discard('lm_head')
    return list(short_names)
def set_seed(args):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch from ``args.seed``.

    ``args.seed`` is converted with ``int()`` first (a float is truncated).
    This is needed because NumPy's legacy ``np.random.seed`` raises ``TypeError``
    for float seeds, and the CLI may deliver the seed as a float.
    """
    seed = int(args.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add ``special_tokens_dict`` to ``tokenizer`` and resize ``model``'s embeddings to ``len(tokenizer)``.

    If any tokens were actually added, their rows in both the input and the output
    embedding matrices are set to the mean of all pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added_count = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if added_count <= 0:
        return

    embedding_matrices = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for matrix in embedding_matrices:
        # Only the trailing rows are overwritten, so the mean of the leading rows is
        # unaffected even if input and output embeddings share storage.
        matrix[-added_count:] = matrix[:-added_count].mean(dim=0, keepdim=True)

def _save_model(args, model, output_dir):
    """Save ``model`` via ``save_pretrained`` into ``output_dir``, creating the directory if needed.

    A ``.module`` wrapper (DataParallel/DDP) is stripped, then ``args.accelerator``
    unwraps anything it wrapped, before saving.
    """
    os.makedirs(output_dir, exist_ok=True)
    model_to_save = model.module if hasattr(model, 'module') else model
    unwrap_net = args.accelerator.unwrap_model(model_to_save)
    unwrap_net.save_pretrained(output_dir)
    logger.info("Saving model checkpoint to %s", output_dir)


def _build_optimizer(args, model):
    """Create AdamW over the trainable parameters; bias/LayerNorm weights get no weight decay."""
    no_decay = ['bias', 'LayerNorm.weight']
    # Frozen parameters (e.g. base weights under LoRA) never receive gradients,
    # so there is no reason to hand them to the optimizer.
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    optimizer_grouped_parameters = [
        {'params': [p for n, p in trainable if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in trainable if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    return AdamW(optimizer_grouped_parameters, lr=args.lr)


def train(args, train_dataset, model, data_collator):
    """Fine-tune ``model`` on ``train_dataset`` for ``args.max_steps`` optimizer steps.

    Gradients are accumulated over ``args.grad_steps`` micro-batches per optimizer
    step, with gradient-norm clipping and a linear LR decay (no warmup).
    Whenever a completed optimizer step has the lowest loss seen so far, the model is
    saved to ``<args.output_dir>/best_loss--1``. The final model is saved to
    ``args.output_dir``.

    Args:
        args: parsed CLI namespace. Reads batch_size, max_steps, grad_steps, lr,
            weight_decay, max_grad_norm, eval_steps, bf16, device, seed,
            tensorbord_dir, output_dir and accelerator. Sets ``args.num_train_epochs``.
        train_dataset: sized dataset whose collated batches provide ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        model: causal LM whose ``outputs[0]`` is the loss when ``labels`` is passed.
        data_collator: ``collate_fn`` for the DataLoader.

    Returns:
        ``(global_step, tr_loss / max(global_step, 1))``: the number of optimizer
        steps taken and the mean per-step training loss.
    """
    tb_writer = SummaryWriter(args.tensorbord_dir)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size,
                                  collate_fn=data_collator)
    t_total = args.max_steps
    # At least one optimizer step per epoch, so a dataloader shorter than one
    # accumulation window does not cause a division by zero.
    steps_per_epoch = max(1, len(train_dataloader) // args.grad_steps)
    args.num_train_epochs = t_total // steps_per_epoch + 1

    optimizer = _build_optimizer(args, model)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=t_total)
    # --bf16 uses native autocast. bf16 keeps fp32's exponent range, so no loss
    # scaling is required. When --bf16 is off the autocast context is disabled (no-op).
    device_type = torch.device(args.device).type

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.grad_steps)
    logger.info("  Total optimization steps = %d", t_total)
    best_loss = float('inf')
    global_step = 0
    micro_step = 0  # micro-batches processed, counted across epochs
    tr_loss = 0.0
    logging_loss = 0.0
    window_start_loss = 0.0  # value of tr_loss when the current accumulation window began
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    try:
        for _ in train_iterator:
            model.train()
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for batch in epoch_iterator:
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=args.bf16):
                    outputs_pcot = model(batch['input_ids'].to(args.device),
                                         attention_mask=batch['attention_mask'].to(args.device),
                                         labels=batch['labels'].to(args.device))
                loss = outputs_pcot[0]  # model outputs are always tuple in transformers (see doc)
                if args.grad_steps > 1:
                    loss = loss / args.grad_steps
                loss.backward()
                tr_loss += loss.item()
                micro_step += 1

                # A global (not per-epoch) micro-batch counter keeps every window at
                # exactly grad_steps batches, so gradients never carry across epochs.
                if micro_step % args.grad_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1
                    # Loss of this whole optimizer step (sum of its scaled micro-batch losses).
                    step_loss = tr_loss - window_start_loss
                    window_start_loss = tr_loss

                    if global_step % args.eval_steps == 0:
                        avg_loss = (tr_loss - logging_loss) / args.eval_steps
                        tb_writer.add_scalar('loss', avg_loss, global_step)
                        tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step)
                        print(f"{global_step}:{avg_loss}")
                        logging_loss = tr_loss
                    if step_loss < best_loss:
                        best_loss = step_loss
                        print("save_loss", best_loss)
                        _save_model(args, model, os.path.join(args.output_dir, '{}-{}'.format('best_loss', -1)))

                # Stop right after the max_steps-th optimizer step. The linear schedule
                # has decayed the LR to zero there, so further steps would be wasted.
                if global_step >= t_total:
                    epoch_iterator.close()
                    break
                torch.cuda.empty_cache()
            if global_step >= t_total:
                train_iterator.close()
                break
    finally:
        tb_writer.close()

    print('save to:', args.output_dir)
    _save_model(args, model, args.output_dir)
    return global_step, tr_loss / max(global_step, 1)

def main(args):
    """Load LLaMA and its tokenizer, attach LoRA adapters to every Linear layer, and train.

    Missing special tokens are added and the embeddings resized *before* the model is
    wrapped with PEFT. The tokenizer is saved to ``args.output_dir`` after training,
    because its vocabulary may have been extended here.
    """
    print(args.train_data_file)
    set_seed(args)
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path, model_max_length=args.max_input_length,
                                               padding_side="right")
    model = LlamaForCausalLM.from_pretrained(args.model_name_or_path, device_map=args.device)

    # Finalize the vocabulary / embedding shapes first, so the PEFT wrapper is built
    # on the model's final modules rather than having them replaced underneath it.
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    # More info: https://github.com/huggingface/transformers/pull/24906
    modules = find_all_linear_names(model)
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=modules,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = PeftModelForCausalLM(model, peft_config)
    model.print_trainable_parameters()

    # NOTE(review): this only assigns an attribute on the wrapper; it does not appear
    # to turn gradient checkpointing on. Confirm whether that was intended.
    model.supports_gradient_checkpointing = True
    logger.info("Training/evaluation parameters %s", args)
    # Training
    train_dataset = data_utils.load_and_cache_examples_single(args, tokenizer)
    data_collator = data_utils.DataCollatorForSupervisedDataset_single(tokenizer=tokenizer)
    print('len', len(train_dataset))
    global_step, tr_loss = train(args, train_dataset, model, data_collator)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
    # Keep the (possibly extended) tokenizer next to the saved model so they match on reload.
    tokenizer.save_pretrained(args.output_dir)

    


    


    

if __name__ == "__main__":
    start_time = time.time()
    # Configure root logging so the logger.info(...) calls in main()/train() are shown;
    # the root logger's default level is WARNING. (No-op if logging is already configured.)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
    )
    parser = argparse.ArgumentParser()
    parser.add_argument('--subsample', type=float, default=1.0)
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--max_steps', type=int, default=16000)
    parser.add_argument('--eval_steps', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--optimizer_name', type=str, default='AdamW')
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--run', type=int, default=0)
    parser.add_argument('--label_type', type=str, default='gt')
    parser.add_argument('--llm', type=str, default='palm')
    parser.add_argument('--max_input_length', type=int, default=750)
    parser.add_argument('--grad_steps', type=int, default=8)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--gen_max_len', type=int, default=512)
    parser.add_argument('--a', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    # Integer seed: NumPy's legacy np.random.seed rejects float seeds with TypeError.
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--parallelize', action='store_true')
    parser.add_argument('--model_type', type=str, default='task_prefix')
    parser.add_argument('--bf16', action='store_true')
    parser.add_argument('--no_log', action='store_true')
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--model_name_or_path", default="meta-llama/Llama-2-7b-hf", type=str,
                        help="The model checkpoint for weights initialization.")
    parser.add_argument('--output_rationale', action='store_true')
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--n_cpot', type=int, default=10)
    parser.add_argument("--tensorbord_dir", default=None, type=str, required=True)
    parser.add_argument('--device', type=int, default=6, help='device ID for CUDA')
    parser.add_argument("--train_data_file", default=None, type=str, required=True,
                        help="The input training data file (a text file).")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    args = parser.parse_args()
    # Replace the integer CUDA index with a torch.device (falls back to CPU without CUDA).
    args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    args.accelerator = Accelerator()
    main(args)
    print(f'run time: {time.time() - start_time:.1f} s.')
    
