import sys
import torch
import copy
import os
import wandb
import pickle
import transformers
import datasets
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from accelerate import Accelerator
from torchkeras import KerasModel
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.optim.lr_scheduler import LambdaLR
from accelerate import Accelerator
from models import RC_Model,load_from_pt,model_setting
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification,DataCollatorWithPadding,AutoTokenizer



# Enable the TF32 flags for CUDA matmul and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch_seed = 777
##--------------------------------------------init accelerator and wandb--------------------------------------------------------

# Expected command line: <file_path> <loss> <beta> <lambda>
cli_args = sys.argv[1:]
if len(cli_args) != 4:
    raise ValueError(
        f"usage: {sys.argv[0]} <file_path> <loss: rc|dpo> <beta> <lambda> "
        f"(got {len(cli_args)} argument(s))"
    )
file_path, arg_loss, arg_beta, arg_lambda = cli_args
print(arg_loss)

# The training step only implements these two objectives, so reject anything
# else here instead of failing after all the models have been loaded.
if arg_loss not in ("rc", "dpo"):
    raise ValueError(f"unsupported loss {arg_loss!r}; expected 'rc' or 'dpo'")

# Checkpoint paths are built by plain concatenation (file_path + name), so
# make sure the directory ends with a separator. Joining with an empty last
# component appends one only when it is missing.
file_path = os.path.join(file_path, "")

arg_beta, arg_lambda = float(arg_beta), float(arg_lambda)

accelerator = Accelerator(mixed_precision="bf16", cpu=False, gradient_accumulation_steps=16)

# Only the local main process talks to wandb.
if accelerator.is_local_main_process:
    wandb.init(project=arg_loss + "_tldr")

##--------------------------------------------check if exist checkpoints--------------------------------------------------------

folders = os.listdir(file_path)
checkpoint_path = file_path + "state_checkpoint"

# Policy snapshots are saved under file_path + str(batch_count), so only
# all-digit entries are resumable steps. Everything else (for example the
# "state_checkpoint" entry) is ignored, and an output directory without any
# numbered snapshot means a fresh run starting at batch 0.
saved_steps = [int(name) for name in folders if name.isdecimal()]
init_batch = max(saved_steps, default=0)

training_args = {
    "beta": float(arg_beta),
    "batch_count": init_batch,
}

##-------------------------------------------load policy model and tokenizer----------------------------------------------------

print("init_batch: ", init_batch)

# A fresh run starts from the SFT weights; a resumed run reloads the policy
# stored under the latest batch-count directory.
policy_path = "arxiv_models/sft_model.pt/" if init_batch == 0 else file_path + str(init_batch)

# The policy and the reference model are loaded with identical options.
causal_lm_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.bfloat16,
    "attn_implementation": "flash_attention_2",
}

policy_model = transformers.AutoModelForCausalLM.from_pretrained(policy_path, **causal_lm_options)
model_setting(policy_model)

# The reference model always comes from the SFT weights.
ref_model = transformers.AutoModelForCausalLM.from_pretrained("arxiv_models/sft_model.pt/", **causal_lm_options)

policy_tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b", add_eos_token=False)
# Reuse the EOS token for padding and truncate over-long inputs from the left.
policy_tokenizer.pad_token_id = policy_tokenizer.eos_token_id
policy_tokenizer.truncation_side = "left"

##------------------------------------------load reward model and tokenizer---------------------------------------------------
reward_path = "arxiv_models/reward_model.pt/"
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_path, trust_remote_code=True)
model_setting(reward_model)
reward_tokenizer = transformers.AutoTokenizer.from_pretrained(reward_path)

##--------------------------------------------------init rc model--------------------------------------------------------------

# Bundle policy, reward and reference models (plus their tokenizers) into the
# single module that is trained below.
rc_model = RC_Model(
    policy_model,
    reward_model,
    policy_tokenizer,
    reward_tokenizer,
    accelerator,
    ref_model=ref_model,
)

# Only the local main process registers the model with wandb.
if accelerator.is_local_main_process:
    wandb.watch(rc_model)

##--------------------------------------------------load dataset--------------------------------------------------------------

# NOTE: unpickling can execute code embedded in the file, so
# train_set.pickle must come from a trusted source.
with open("train_set.pickle", mode="rb") as dataset_file:
    train_set = pickle.load(dataset_file)
   
def my_collator(batch):
    """Turn a one-example batch into prompt/response texts and token ids.

    Only ``batch[0]`` is read; it must provide a ``"prompt"`` string and a
    ``"responses"`` list of strings.

    Returns:
        A 4-tuple ``(prompt, prompt_ids, responses, response_ids)`` where
        ``prompt_ids`` are the token ids of the prompt alone and
        ``response_ids`` the token ids of the prompt concatenated with each
        response.
    """
    example = batch[0]
    prompt = example["prompt"]
    responses = example["responses"]

    # Both tokenizer calls use the same padding / truncation settings.
    tokenizer_options = dict(padding=True, return_tensors="pt", max_length=1024, truncation=True)

    prompt_ids = policy_tokenizer(prompt, **tokenizer_options)["input_ids"]
    full_texts = [prompt + response for response in responses]
    response_ids = policy_tokenizer(full_texts, **tokenizer_options)["input_ids"]

    return prompt, prompt_ids, responses, response_ids

# Seed torch's global RNG before building the shuffling loader.
torch.manual_seed(12345)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=1,
    shuffle=True,
    collate_fn=my_collator,
)

##--------------------------------------------------init my step runner---------------------------------------------------------

class StepRunner:
    """Per-batch step for the ``rc`` / ``dpo`` preference objectives.

    Each call takes one batch in the layout returned by ``my_collator``
    (prompt text, prompt token ids, response texts, prompt+response token
    ids), scores it with ``net`` and, in the ``"train"`` stage with an
    optimizer, runs backward and the optimizer step inside the accelerator's
    gradient-accumulation context.

    Reads the module-level ``arg_loss``, ``arg_lambda`` and ``file_path``.

    Args:
        net: Model wrapper exposing ``get_reward``, ``get_logits``,
            ``policy`` and ``ref``.
        loss_fn: Stored but not used by this runner.
        batch_count: Number of batches already processed; drives the
            logging and checkpoint cadence.
        beta: Scale applied to both the policy and the reference logits.
        accelerator: ``Accelerator`` to use; a default one is created when
            omitted.
        stage: ``"train"`` puts ``net`` in train mode, anything else in eval
            mode.
        metrics_dict: Stored but not used by this runner.
        optimizer: Optimizer whose ``step()`` is called after each backward
            pass.
        lr_scheduler: Optional scheduler stepped alongside the optimizer.
    """

    def __init__(self, net, loss_fn, batch_count = 0, beta = 0.1, accelerator=None, stage="train", metrics_dict=None,
                 optimizer=None, lr_scheduler=None
                 ):

        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        self.beta = beta
        self.batch_count = batch_count

        if self.stage == 'train':
            self.net.train()
        else:
            self.net.eval()

    @staticmethod
    def _preference_loss(loss_name, lam, policy_logits, ref_logits, rewards, verbose=False):
        """Return the scalar loss for ``loss_name`` (``"rc"`` or ``"dpo"``).

        Both objectives compare every pair of responses through
        ``x - x.T``, which presumes column-shaped ``(n, 1)`` inputs
        -- TODO confirm against ``RC_Model.get_logits`` / ``get_reward``.

        Args:
            loss_name: ``"rc"`` or ``"dpo"``.
            lam: Temperature of the ``rc`` log-sum-exp; unused for ``dpo``.
            policy_logits: Beta-scaled policy scores, one per response.
            ref_logits: Beta-scaled reference scores, one per response.
            rewards: Reward-model scores, one per response.
            verbose: When true, print the unreduced ``dpo`` loss matrix.

        Raises:
            ValueError: If ``loss_name`` is neither ``"rc"`` nor ``"dpo"``.
        """
        if loss_name == "rc":
            # NOTE(review): if all rewards are equal the std is zero and the
            # scores become non-finite.
            scores = policy_logits - ref_logits - rewards / rewards.std()
            pairwise = (scores - scores.T)**2/2 / lam
            # torch.logsumexp needs an explicit dim; flatten so the reduction
            # covers every pairwise entry and yields a scalar.
            # NOTE(review): confirm that all-entries (rather than per-row)
            # reduction is the intended objective.
            return lam * torch.logsumexp(pairwise.flatten(), dim=0)

        if loss_name == "dpo":
            margins = policy_logits - ref_logits
            margins = margins - margins.T
            # Keep only the ordered pairs whose first response has the
            # strictly higher reward; the mean still runs over all pairs.
            preferred = (rewards - rewards.T) > 0
            loss = -F.logsigmoid(margins) * preferred.float()
            if verbose:
                print(loss)
            return loss.mean()

        raise ValueError(f"unsupported loss {loss_name!r}; expected 'rc' or 'dpo'")

    def __call__(self, batch):
        """Process one batch and return ``(step_losses, step_metrics)``."""
        # Accumulate through the runner's own accelerator rather than the
        # module-level one, so the runner stays self-consistent whichever
        # accelerator it was constructed with.
        with self.accelerator.accumulate(self.net):

            prompt, prompt_ids, responses, response_ids = batch
            # Last-axis length of the tokenized prompt minus one, handed to
            # get_logits (presumably to locate where the response starts
            # -- confirm in RC_Model).
            prompt_len = prompt_ids.shape[-1]-1

            rewards =  self.net.get_reward([prompt], responses)
            # NOTE(review): nothing here detaches the rewards or ref_logits;
            # whether gradients reach the reward / reference models depends
            # on RC_Model and model_setting -- confirm.
            policy_logits = self.net.get_logits(response_ids, prompt_len, self.net.policy) * self.beta
            ref_logits = self.net.get_logits(response_ids, prompt_len, self.net.ref) * self.beta

            is_main = self.accelerator.is_local_main_process
            if is_main:
                print(prompt)
                print(responses)
                print(rewards)

            loss = self._preference_loss(
                arg_loss, arg_lambda, policy_logits, ref_logits, rewards,
                verbose=is_main,
            )

            # backward()
            if self.optimizer is not None and self.stage == "train":
                self.accelerator.backward(loss)
                # Clip only on steps where gradients are synchronised.
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
                self.optimizer.step()
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()

        # losses (or plain metrics that can be averaged)
        _loss = self.accelerator.gather(loss).mean()
        loss_value = _loss.item()  # read back once and reuse

        step_metrics = {"loss": loss_value}
        step_losses = {self.stage + "_loss": loss_value}

        # Log every 10th batch to wandb.
        if self.batch_count % 10 == 9:
            if self.accelerator.is_local_main_process:
                wandb.log(step_metrics)

        # Snapshot the policy every 2000 batches, in a directory named after
        # the number of batches completed.
        if self.batch_count % 2000 == 1999:
            if self.accelerator.is_local_main_process:
                self.save_ckpt(file_path+str(self.batch_count+1))

        self.batch_count += 1
        return step_losses, step_metrics

    def save_ckpt(self, path):
        """Save the unwrapped policy model to ``path`` via ``save_pretrained``."""
        self.accelerator.unwrap_model(self.net.policy).save_pretrained(path,safe_serialization=False)

    def load_ckpt(self, path):
        """Replace ``self.net`` with the result of its ``from_pretrained`` for ``path``."""
        self.net = self.net.from_pretrained(self.net, path)
        self.from_scratch = False

# Plug the custom step logic and checkpoint helpers into torchkeras.
KerasModel.StepRunner = StepRunner
KerasModel.save_ckpt = StepRunner.save_ckpt
KerasModel.load_ckpt = StepRunner.load_ckpt

##-----------------------------------------load hyper parameters and keras model---------------------------------------------

# Resumed runs use a shorter warm-up than fresh ones.
warmup_steps = 200 if init_batch == 0 else 50


def _warmup_factor(step):
    """Linear warm-up multiplier that rises to 1.0 and then stays there."""
    return min(1.0, (step + 1) / (warmup_steps + 1))


optimizer = torch.optim.RMSprop(rc_model.parameters(), lr=5e-7)
lr_scheduler = LambdaLR(optimizer, lr_lambda=_warmup_factor)

keras_model = KerasModel(
    rc_model,
    loss_fn=None,
    accelerator=accelerator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    **training_args,
)

##-----------------------------------------------------start training---------------------------------------------------------

# One pass over the training loader, starting from the last saved batch count.
fit_options = {
    "train_data": train_loader,
    "checkpoint_path": checkpoint_path,
    "init_batch": init_batch,
    "epochs": 1,
    "patience": 100,
    "monitor": "train_loss",
    "mode": "min",
    "mixed_precision": "bf16",
}
keras_model.fit(**fit_options)
