import torch
import copy
import transformers
import datasets
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification,DataCollatorWithPadding,AutoTokenizer
from accelerate import Accelerator
from torchkeras import KerasModel
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.optim.lr_scheduler import LambdaLR
# Random-seed constant. Nothing in this part of the file reads it — presumably
# consumed by the training script elsewhere (confirm).
torch_seed = 999

def load_from_pt(dir,model):
    """Load the state dict stored under ``'state'`` in a checkpoint into ``model``.

    The file is read with ``torch.load(..., map_location='cpu')`` and applied
    with ``strict=False``. On its own, that silently drops every checkpoint key
    the model does not have (e.g. a leftover ``"module."`` prefix makes every
    key miss, so nothing is loaded). A ``UserWarning`` is therefore emitted
    whenever any checkpoint key goes unused. Model keys absent from the
    checkpoint are still accepted quietly, since partial loading is presumably
    why ``strict=False`` is used.

    Args:
        dir: checkpoint path passed to ``torch.load``; the loaded object must be
            indexable by ``'state'``. (The name shadows the builtin but is kept
            for keyword-argument compatibility.)
        model: module whose ``load_state_dict`` receives the weights; it is
            modified in place.

    Returns:
        The same ``model`` object.
    """
    import warnings

    # torch.load unpickles the file (unless weights-only loading is in effect),
    # so only load checkpoints from trusted sources.
    checkpoint = torch.load(dir, map_location='cpu')
    incompatible = model.load_state_dict(checkpoint['state'], strict=False)
    if incompatible.unexpected_keys:
        warnings.warn(
            f"load_from_pt({dir!r}): {len(incompatible.unexpected_keys)} checkpoint "
            f"key(s) did not match the model and were ignored "
            f"(first few: {incompatible.unexpected_keys[:5]}); "
            f"{len(incompatible.missing_keys)} model key(s) were not loaded.",
            stacklevel=2,
        )
    return model

def model_setting(model, grad_ckpt = True):
    """Configure ``model`` in place for training; returns ``None``.

    When ``grad_ckpt`` is true, gradient checkpointing is turned on (see
    ``_turn_on_gradient_checkpointing``). In every case ``is_parallelizable``
    and ``model_parallel`` are set to ``True`` on the model.
    """
    if grad_ckpt:
        _turn_on_gradient_checkpointing(model)
    for flag_name in ("is_parallelizable", "model_parallel"):
        setattr(model, flag_name, True)


def _turn_on_gradient_checkpointing(model):
    """Enable gradient checkpointing on ``model`` and then set ``config.use_cache`` to False."""
    # Forced to True, presumably so gradient_checkpointing_enable() accepts
    # models that do not declare support themselves.
    model.supports_gradient_checkpointing = True
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    model.config.use_cache = False
    
class RC_Model(torch.nn.Module):
    def __init__(self, policy_model, reward_model, policy_tokenizer, reward_tokenizer, accelerator, ref_model=None):
        """Bundle the policy, reward and reference models with their tokenizers.

        Args:
            policy_model: module being trained; its parameters are set to
                require gradients.
            reward_model: scoring module; its parameters are frozen.
            policy_tokenizer: tokenizer paired with the policy model.
            reward_tokenizer: tokenizer paired with the reward model.
            accelerator: stored as ``self.accelerator``; not used in this
                constructor.
            ref_model: reference module, frozen here. When omitted, a deep copy
                of ``policy_model`` (taken before any ``requires_grad_`` change
                made below) is used instead.
        """
        super().__init__()
        # Assignment order is kept: for nn.Module values it fixes the order of
        # registered submodules (policy, reward, ref).
        self.policy = policy_model
        self.reward = reward_model
        self.policy_tokenizer = policy_tokenizer
        self.reward_tokenizer = reward_tokenizer
        self.accelerator = accelerator
        self.ref = copy.deepcopy(policy_model) if ref_model is None else ref_model

        # Applied in this order, so if ref_model is policy_model the policy
        # ends up trainable.
        freeze_plan = ((self.ref, False), (self.policy, True), (self.reward, False))
        for module, trainable in freeze_plan:
            module.requires_grad_(trainable)

    def get_reward(self, prompt, response):
        """Score ``response`` texts with the reward model.

        ``prompt`` is repeated with ``* len(response)`` and paired with
        ``response`` in a single reward-tokenizer call (``padding=True``,
        ``return_tensors="pt"``). It is presumably a one-element list holding
        the shared prompt; a plain ``str`` would be concatenated with itself
        instead, so confirm with callers.

        Returns:
            The ``.logits`` attribute of the reward model's output.
        """
        repeated_prompts = prompt * len(response)
        encoded = self.reward_tokenizer(
            repeated_prompts,
            response,
            padding=True,
            return_tensors="pt",
        )
        encoded = encoded.to(self.reward.device)
        reward_output = self.reward(**encoded)
        return reward_output.logits
    
    def get_logits(self, input_ids, prompt_len, model):
        """Return the summed log-probability of each row's response tokens.

        Despite the name, the values are log-probabilities (``log_softmax`` of
        the model's logits), not raw logits.

        Args:
            input_ids: token ids indexed as ``(batch, seq)``, with the prompt
                followed by the response.
            prompt_len: number of leading prompt positions; tokens at positions
                ``>= prompt_len`` are scored.
            model: callable returning an object whose ``.logits`` is indexed as
                ``(batch, seq, vocab)`` (assumed from the indexing below).

        Returns:
            Tensor of shape ``(batch, 1)``: for each row, the sum over positions
            ``t >= prompt_len`` of the log-probability the model assigns to
            ``input_ids[:, t]`` at position ``t - 1``. A position is skipped
            when the token before it is the policy tokenizer's EOS id.
        """
        log_probs = model(input_ids).logits.log_softmax(-1)

        # Position t - 1 predicts token t, so each response token is scored with
        # the distribution from the position just before it.
        targets = input_ids[:, prompt_len:]
        token_log_probs = torch.gather(
            log_probs[:, prompt_len - 1:-1, :], dim=2, index=targets.unsqueeze(2)
        ).squeeze(2)

        # Zero a token's score when the token *before* it is EOS. The first EOS
        # of a response is therefore still scored; everything after it is not.
        # The mask is applied to the gathered (batch, seq) scores rather than to
        # the full (batch, seq, vocab) tensor. The values are identical, and no
        # second vocab-sized tensor is materialised.
        not_after_eos = input_ids[:, prompt_len - 1:-1] != self.policy_tokenizer.eos_token_id
        token_log_probs = token_log_probs * not_after_eos

        return token_log_probs.sum(-1, keepdim=True)
        
    def generate(self,batch,prompt_len,generator,**generation_kwargs):
        """Sample continuations of ``batch`` with ``generator`` and decode them.

        For the call, ``generator`` is switched to inference settings
        (``config.use_cache = True``, gradient checkpointing off). Its previous
        settings are restored afterwards, including when
        ``generator.generate`` raises.

        Args:
            batch: mapping of keyword inputs for ``generator.generate``; its
                ``"input_ids"`` entry is decoded as the prompts.
            prompt_len: number of leading positions of the generated ids that
                belong to the prompt.
            generator: model providing ``generate``, ``config.use_cache`` and
                ``gradient_checkpointing_enable`` / ``gradient_checkpointing_disable``.
            **generation_kwargs: forwarded unchanged to ``generator.generate``.

        Returns:
            ``(outputs_ids, prompts, responses, length)``:
            - the generated ids, with one extra all-zero column appended when
              fewer than 128 positions follow the prompt;
            - the decoded prompts;
            - the decoded ids after ``prompt_len``;
            - a ``(batch, 1)`` float tensor counting, per row, the positions
              after ``prompt_len`` whose id is not the policy tokenizer's
              EOS id.
        """
        # Snapshot the training-time settings so they can be restored exactly.
        # If the model does not expose `is_gradient_checkpointing`, assume it
        # was on, so checkpointing is re-enabled afterwards.
        was_checkpointing = getattr(generator, "is_gradient_checkpointing", True)
        previous_use_cache = generator.config.use_cache

        generator.config.use_cache = True
        if was_checkpointing:
            generator.gradient_checkpointing_disable()
        try:
            outputs_ids = generator.generate(**batch, **generation_kwargs)
        finally:
            generator.config.use_cache = previous_use_cache
            if was_checkpointing:
                generator.gradient_checkpointing_enable()

        # Append one column of id 0 when fewer than 128 positions follow the
        # prompt. NOTE(review): 128 presumably mirrors max_new_tokens; confirm.
        # `new_zeros` allocates on the ids' own device and dtype, so this works
        # on CPU and on any GPU index without a float round-trip.
        if outputs_ids.shape[-1] - prompt_len < 128:
            padding = outputs_ids.new_zeros((outputs_ids.shape[0], 1))
            outputs_ids = torch.cat([outputs_ids, padding], dim=1).long()

        response_ids = outputs_ids[:, prompt_len:]
        prompts = self.policy_tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
        responses = self.policy_tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        # NOTE(review): the appended id 0 (if any) is part of `response_ids`.
        # It therefore counts toward `length` and is decoded into `responses`,
        # unless the tokenizer treats id 0 as special or it equals EOS.
        # Confirm this is intended.
        length = (response_ids != self.policy_tokenizer.eos_token_id).float().sum(-1, keepdim=True)

        return outputs_ids, prompts, responses, length