import os
import torch
from torch.utils.data import DataLoader
from accelerate.state import PartialState
from accelerate.utils import release_memory, InitProcessGroupKwargs
import datasets
from datasets import Dataset
from transformers import (
    AdamW, 
    AutoTokenizer, 
    DataCollatorWithPadding, 
    AutoModelForSequenceClassification, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer,
    get_constant_schedule_with_warmup,
    DataCollatorWithPadding,
)
from reward_model.prm.prm_data import prepare_prm_data
from datetime import timedelta
from accelerate import Accelerator
import warnings
warnings.simplefilter('ignore')
from tqdm import tqdm

# Import-time side effects: release memory held by the CUDA caching allocator
# and raise the element-count threshold above which printed tensors are
# summarized instead of shown in full.
torch.cuda.empty_cache()
torch.set_printoptions(threshold=10_000)

def prm_classification_trainer(config):
    """Fine-tune a two-label sequence classifier with the Hugging Face ``Trainer``.

    The training set comes from ``prepare_prm_data(config)``; its ``"text"``
    column is tokenized without truncation and padded per batch by
    ``DataCollatorWithPadding``. After training, the model weights are written
    to the configured output directory.

    Args:
        config: Mapping passed as-is to ``prepare_prm_data`` and expected to
            hold a ``"reward_model"`` section with ``model_name``,
            ``output_dir``, ``learning_rate``,
            ``per_device_train_batch_size``, ``num_train_epochs``,
            ``weight_decay``, ``save_strategy`` and ``push_to_hub``.

    Returns:
        None. The trained model is saved to ``output_dir``.
    """
    rm_config = config["reward_model"]
    output_dir = rm_config["output_dir"]

    train_dataset = prepare_prm_data(config)
    tokenizer = AutoTokenizer.from_pretrained(rm_config["model_name"])
    # Reuse EOS as the padding token so that dynamic padding works.
    tokenizer.pad_token = tokenizer.eos_token

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=False)

    train_dataset = train_dataset.map(preprocess_function, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Evaluation-related options (per_device_eval_batch_size,
    # evaluation_strategy, load_best_model_at_end) are left unset; no eval
    # dataset is passed to the Trainer below.
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=rm_config["learning_rate"],
        per_device_train_batch_size=rm_config["per_device_train_batch_size"],
        num_train_epochs=rm_config["num_train_epochs"],
        weight_decay=rm_config["weight_decay"],
        save_strategy=rm_config["save_strategy"],
        push_to_hub=rm_config["push_to_hub"],
    )

    model = AutoModelForSequenceClassification.from_pretrained(rm_config["model_name"], num_labels=2)
    # Keep the model's notion of padding in sync with the tokenizer override
    # above. Classification heads that pool the last non-padding token need
    # ``config.pad_token_id`` to handle padded batches larger than one.
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    # exist_ok avoids the check-then-create race when several processes save.
    os.makedirs(output_dir, exist_ok=True)
    trainer.model.save_pretrained(output_dir)

class prm_ebm_trainer():
    """Energy-based trainer for a reward model, driven by ``accelerate``.

    The energy of a sample is the negated model logit. Training lowers the
    mean energy of samples labelled ``1`` and raises the mean energy of
    samples labelled ``0``, with an optional L2 penalty on the energies.

    Metrics are sent through ``Accelerator.log``. This class never calls
    ``init_trackers``; presumably the caller does so before training.
    """

    def __init__(self, config):
        """Build tokenizer, accelerator, model, optimizer and LR scheduler.

        Args:
            config: Mapping with a ``"reward_model"`` section providing
                ``tokenizer_name``, ``model_name``, ``type`` (must contain
                ``"classification"`` or ``"generation"``),
                ``gradient_accumulation_steps``, ``learning_rate`` and
                ``num_warmup_steps``. Other methods additionally read
                ``per_device_train_batch_size``, ``l2_reg_coef``,
                ``add_special_tokens`` and the top-level ``"energy_temp"``.

        Raises:
            NotImplementedError: If ``type`` names neither supported family.
        """
        self.config = config
        rm_config = config["reward_model"]

        self.tokenizer = AutoTokenizer.from_pretrained(rm_config["tokenizer_name"], truncation_side="left")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Silence the "fast tokenizer + pad()" advisory emitted while collating.
        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

        kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=96000))
        self.accelerator = Accelerator(
            split_batches=False,
            mixed_precision='fp16',
            gradient_accumulation_steps=rm_config["gradient_accumulation_steps"],
            log_with='wandb',
            device_placement=True,
            kwargs_handlers=[kwargs],
        )

        model_type = rm_config["type"]
        if 'classification' in model_type:
            self.model = AutoModelForSequenceClassification.from_pretrained(rm_config["model_name"], num_labels=1)
        elif 'generation' in model_type:
            self.model = AutoModelForCausalLM.from_pretrained(rm_config["model_name"])
        else:
            raise NotImplementedError(f"Unsupported reward model type: {model_type!r}")
        # Keep the model's padding id in sync with the tokenizer override so
        # that padded batches are handled by heads that look for the last
        # non-padding token.
        if self.tokenizer.pad_token_id is not None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        self.optimizer = AdamW(self.model.parameters(), lr=rm_config["learning_rate"])
        self.scheduler = get_constant_schedule_with_warmup(
            self.optimizer, num_warmup_steps=rm_config["num_warmup_steps"]
        )
        # Without prepare() the model is never moved to the accelerator device
        # and the fp16 grad scaler is not wired to the optimizer, so
        # accumulate()/backward()/clip_grad_norm_() would not work as intended.
        self.model, self.optimizer, self.scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.scheduler
        )
        self.accelerator.print(
            f"Distributed: {self.accelerator.distributed_type}, Mixed precision: {self.accelerator.mixed_precision}"
        )

    def build_dataloader(self, batch_dataset):
        """Wrap ``batch_dataset`` in a shuffled, padded, accelerator-prepared loader.

        Args:
            batch_dataset: Dataset of tokenized features accepted by
                ``DataCollatorWithPadding``.

        Returns:
            The ``DataLoader`` returned by ``Accelerator.prepare``.
        """
        data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
        dataloader_params = {
            "batch_size": self.config["reward_model"]["per_device_train_batch_size"],
            "collate_fn": data_collator,
            "num_workers": 0,
            "pin_memory": True,
            "shuffle": True,
        }
        return self.accelerator.prepare(DataLoader(batch_dataset, **dataloader_params))

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the energy-based loss for one batch and log its parts.

        ``loss = mean(E_pos) - mean(E_neg) + alpha * mean(E ** 2)`` where
        ``E = -logits``, the positive/negative energies are divided by
        ``config["energy_temp"]`` and ``alpha`` is ``l2_reg_coef``. A missing
        class in the batch contributes a mean energy of zero.

        NOTE(review): for the ``"generation"`` model type the logits presumably
        keep a vocabulary dimension that ``squeeze(-1)`` does not remove, so
        this formulation looks designed for the single-logit classifier -
        confirm before using it with a causal LM.

        Args:
            model: Callable invoked as ``model(**inputs)`` whose result maps
                ``"logits"`` to a tensor.
            inputs: Mapping of batch tensors including ``"labels"`` (1 for
                positive, 0 for negative). ``"labels"`` is popped from it.
            return_outputs: If true, return ``(loss, outputs)``.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` when
            ``return_outputs`` is true.
        """
        device = self.accelerator.device
        labels = inputs.pop("labels").long().to(device)
        # Works for both BatchEncoding and plain dict batches.
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }
        outputs = model(**inputs)

        alpha = self.config["reward_model"]["l2_reg_coef"]
        energy_temp = self.config["energy_temp"]

        energies = -outputs["logits"].squeeze(-1)
        pos_energy = energies[labels == 1] / energy_temp
        neg_energy = energies[labels == 0] / energy_temp

        # mean() of an empty tensor is NaN; fall back to a zero energy when a
        # class is absent from the batch.
        if pos_energy.shape[0] == 0:
            pos_energy = energies.new_zeros(1)
        if neg_energy.shape[0] == 0:
            neg_energy = energies.new_zeros(1)

        pos_mean = pos_energy.mean()
        neg_mean = neg_energy.mean()
        ml_loss = pos_mean - neg_mean
        if alpha != 0:
            l2_loss = alpha * energies.square().mean()
        else:
            l2_loss = energies.new_zeros(())
        loss = ml_loss + l2_loss

        # One call so that all metrics share the same tracker step.
        self.accelerator.log({
            "total_loss": loss.item(),
            "l2_loss": l2_loss.item(),
            "ml_loss": ml_loss.item(),
            "pos_energy": pos_mean.item(),
            "neg_energy": neg_mean.item(),
        })
        return (loss, outputs) if return_outputs else loss

    def train_step(self, train_loader):
        """Run one pass over ``train_loader`` with gradient accumulation.

        Gradients are clipped to a norm of 1.0 whenever the accelerator
        synchronizes them; the loss averaged over the accumulation window,
        the gradient norm, the learning rate and the batch index are logged.

        Args:
            train_loader: Iterable of batches accepted by ``compute_loss``,
                typically the result of ``build_dataloader``.
        """
        progress_bar = tqdm(
            total=len(train_loader),
            desc="Training",
            disable=not self.accelerator.is_local_main_process,
        )
        accumulation_steps = self.accelerator.gradient_accumulation_steps
        running_loss = 0.0
        self.model.train()
        for step, batch in enumerate(train_loader):
            with self.accelerator.accumulate(self.model):
                loss = self.compute_loss(model=self.model, inputs=batch)
                loss_value = loss.item()
                running_loss += loss_value
                self.accelerator.backward(loss)
                if self.accelerator.sync_gradients:
                    grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                    metrics = {"avg_loss": running_loss / accumulation_steps}
                    # Some backends report no norm; log it only when present.
                    if grad_norm is not None:
                        metrics["gradient_norm"] = torch.as_tensor(grad_norm).float().mean().item()
                    self.accelerator.log(metrics)
                    running_loss = 0.0
                # The prepared optimizer/scheduler skip their update while
                # gradients are still being accumulated.
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_description(f"Loss: {loss_value:.4f}")
            self.accelerator.log({
                "learning_rate": self.scheduler.get_last_lr()[0],
                "update_step": step,
            })
        progress_bar.close()
        release_memory()

    def get_score_from_texts(self, input_texts, mode="sum_logits"):
        """Score texts with the model in eval mode, without gradients.

        Args:
            input_texts: Text or list of texts passed to the tokenizer
                (padded, truncated, returned as PyTorch tensors).
            mode: Currently unused; kept for interface compatibility.

        Returns:
            The detached model logits with a trailing size-1 dimension
            squeezed away.
        """
        inputs = self.tokenizer(
            input_texts,
            return_tensors="pt",
            add_special_tokens=self.config["reward_model"]["add_special_tokens"],
            padding=True,
            truncation=True,
        ).to(self.accelerator.device)

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs["logits"].detach().squeeze(-1)