import os
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer, pipeline
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from orm.orm_data import prepare_orm_data

def orm_ppo_trainer(config):
    """Fine-tune the generator with PPO, scoring samples with a reward model.

    The training prompts come from ``prepare_orm_data(config)``; the result
    must expose a ``"text"`` column and a ``map`` method (presumably a
    ``datasets.Dataset`` -- confirm against ``orm.orm_data``).

    Args:
        config: Nested mapping. The keys read here are:
            ``config['generator']``: ``model_name``, ``learning_rate``,
            ``report_to``, ``batch_size`` (used as the reward pipeline's
            batch size) and ``output_dir``;
            ``config['generator_trainer']['num_train_epochs']``;
            ``config['reward_model']['model_name']``.

    Returns:
        None. The trained model is written to
        ``{output_dir}/model_{num_train_epochs}``.
    """
    ppo_config = PPOConfig(
        model_name=config['generator']['model_name'],
        learning_rate=config['generator']['learning_rate'],
        log_with=config['generator']['report_to'],
    )
    # Arguments for every reward-model call. ``return_all_scores`` makes the
    # pipeline return one entry per label, which the ``output[1]`` lookup in
    # the loop depends on; ``function_to_apply="none"`` keeps the raw scores.
    sent_kwargs = {
        "return_all_scores": True,
        "function_to_apply": "none",
        "batch_size": config["generator"]["batch_size"],
    }

    train_dataset = prepare_orm_data(config)
    tokenizer = AutoTokenizer.from_pretrained(config['generator']["model_name"])
    # Reuse EOS as the padding token so the padding collator has one to use.
    tokenizer.pad_token = tokenizer.eos_token

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    train_dataset = train_dataset.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    model = AutoModelForCausalLMWithValueHead.from_pretrained(config['generator']["model_name"])
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config['generator']["model_name"])

    ppo_trainer = PPOTrainer(
        ppo_config,
        model,
        ref_model,
        tokenizer,
        dataset=train_dataset,
        data_collator=data_collator
    )
    device = ppo_trainer.accelerator.device
    if ppo_trainer.accelerator.num_processes == 1:
        device = 0 if torch.cuda.is_available() else 'cpu'
    # Place the reward model on the device selected above; previously the
    # value was computed but never handed to the pipeline.
    reward_model = pipeline(
        "text-classification",
        model=config['reward_model']['model_name'],
        device=device,
    )

    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.eos_token_id,
    }
    num_epochs = config['generator_trainer']['num_train_epochs']
    for _ in tqdm(range(num_epochs), "epoch: "):
        for batch in tqdm(ppo_trainer.dataloader):
            # The collator pads every row to a common length, while
            # PPOTrainer.generate/step take a list of 1-D tensors; drop the
            # padded positions using the attention mask when it is present.
            input_ids = batch["input_ids"]
            if "attention_mask" in batch:
                keep_masks = batch["attention_mask"].bool()
                query_tensors = [
                    ids[mask].to(device) for ids, mask in zip(input_ids, keep_masks)
                ]
            else:
                query_tensors = [ids.to(device) for ids in input_ids]

            ### Get response from SFT model
            # return_prompt=False keeps only the newly generated tokens, so
            # the prompt is not repeated in the reward text or the PPO step.
            response_tensors = ppo_trainer.generate(
                query_tensors, return_prompt=False, **generation_kwargs
            )
            batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

            # The raw "text" column may not reach the batch (string columns
            # cannot be padded into tensors), so fall back to decoding the
            # query tokens.
            if "text" in batch:
                prompts = batch["text"]
            else:
                prompts = tokenizer.batch_decode(query_tensors, skip_special_tokens=True)
            if "query" not in batch:
                # Expose the prompts under "query" for stats logging.
                batch["query"] = prompts

            ### Compute reward scores
            texts = [q + r for q, r in zip(prompts, batch["response"])]
            pipe_outputs = reward_model(texts, **sent_kwargs)
            # Index 1 selects the second label's score -- assumes the reward
            # model has at least two labels; TODO confirm which label is
            # the "good outcome" class.
            rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

            ### Run PPO step
            stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
            ppo_trainer.log_stats(stats, batch, rewards)

    ppo_trainer.save_model(f"{config['generator']['output_dir']}/model_{config['generator_trainer']['num_train_epochs']}")