from dataclasses import dataclass
import os
import argparse
from typing import Optional
from collections.abc import Mapping
import numpy as np
import torch
import wandb
from omegaconf import OmegaConf
from transformers import TrainerCallback, TrainingArguments, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin, _torch_collate_batch, pad_without_fast_tokenizer_warning


def get_compute_metrics(tokenizer):
    """Build a ``compute_metrics`` function for next-token prediction.

    Args:
        tokenizer: Object exposing ``pad_token_id`` (may be ``None``). The
            attribute is read each time the returned function is called.

    Returns:
        A function taking ``(preds, labels)`` and returning a dict with
        ``"accuracy"`` (mean per-sequence token accuracy) and
        ``"exact_match"`` (fraction of sequences with every scored token
        correct).
    """

    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]

        # Align predictions with their targets: the prediction made at
        # position t is compared with the label at position t + 1, so the
        # first label and the last prediction are dropped.
        labels = np.asarray(labels)[:, 1:]
        preds = np.asarray(preds)[:, :-1]

        # A position is scored unless its label is the ignore index (-100)
        # or, when the tokenizer defines one, the pad token.
        scored = labels != -100
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is not None:
            scored &= labels != pad_token_id

        n_scored = scored.sum(axis=-1)
        n_wrong = ((preds != labels) & scored).sum(axis=-1)

        # Sequences without any scored token would yield 0/0 and turn the
        # averages into NaN, so they are left out of both metrics.
        has_tokens = n_scored > 0
        if not has_tokens.any():
            return {"accuracy": 0.0, "exact_match": 0.0}

        accuracies = 1 - n_wrong[has_tokens] / n_scored[has_tokens]
        exact_matches = n_wrong[has_tokens] == 0

        return {"accuracy": accuracies.mean().item(), "exact_match": exact_matches.mean().item()}

    return compute_metrics


def preprocess_logits_for_metrics(preds, labels):
    """Reduce logits to predicted token ids before metrics are computed.

    Args:
        preds: Logits tensor, or a tuple whose first element is the logits
            tensor (the same tuple case ``compute_metrics`` in this file
            guards against).
        labels: Unused; accepted to match the expected call signature.

    Returns:
        Tensor holding the argmax over the last dimension of the logits.
    """
    if isinstance(preds, tuple):
        # Only the first element is treated as logits; any extra model
        # outputs are dropped.
        preds = preds[0]
    return preds.argmax(dim=-1)


def get_latest_checkpoint(ckpt_dir: str) -> str:
    """Return the path of the ``checkpoint-<step>`` entry with the highest step.

    Entries whose suffix after the last ``-`` is not an integer (for example
    ``checkpoint-best``) are ignored instead of aborting the lookup.

    Args:
        ckpt_dir: Directory whose entries are scanned.

    Returns:
        ``ckpt_dir`` joined with the name of the latest checkpoint entry.

    Raises:
        IndexError: If ``ckpt_dir`` contains no numbered checkpoint entry.
    """
    latest_step = None
    latest_path = None
    for name in os.listdir(ckpt_dir):
        if not name.startswith("checkpoint-"):
            continue
        try:
            step = int(name.rsplit("-", 1)[-1])
        except ValueError:
            continue
        if latest_step is None or step >= latest_step:
            latest_step = step
            latest_path = os.path.join(ckpt_dir, name)

    if latest_path is None:
        # IndexError is kept as the "nothing found" signal for existing
        # callers, but now carries an explanatory message.
        raise IndexError(f"no 'checkpoint-<step>' entries found in {ckpt_dir!r}")
    return latest_path


class LogGenerationsCallback(TrainerCallback):
    """Callback to log examples of generated continuations to wandb.

    On each evaluation, up to three validation examples are split into a
    prompt (the first ``prompt_length`` token ids) and a target (the next
    ``max_new_tokens`` token ids). The model generates from the prompt, and
    the decoded prompt, generation and target are logged as a wandb table.

    Args:
        tokenized_dataset: Mapping with a ``"validation"`` split whose slices
            expose an ``"input_ids"`` column.
        tokenizer: Tokenizer used to decode token ids.
        model: Model providing ``generate`` and ``device``.
        num_beams: Forwarded to ``model.generate``.
        prompt_length: Maximum number of leading tokens used as the prompt.
        max_new_tokens: Forwarded to ``model.generate``; also the length of
            the logged ground-truth continuation.
    """

    def __init__(self, tokenized_dataset, tokenizer, model, num_beams=1, prompt_length=32,
                 max_new_tokens=64):
        self.tokenized_dataset = tokenized_dataset
        self.tokenizer = tokenizer
        self.model = model
        self.num_beams = num_beams
        self.prompt_length = prompt_length
        self.max_new_tokens = max_new_tokens

    def on_evaluate(self, args: TrainingArguments, state, control, **kwargs):
        """Generate from a few validation prompts and log them to wandb."""
        # Only log from the main process, and only when a wandb run is active.
        if not wandb.run or not state.is_world_process_zero:
            return

        dataset = self.tokenized_dataset["validation"]
        n_samples = len(dataset)
        if n_samples == 0:
            # Nothing to sample from; an empty slice cannot be turned into a
            # 2-D tensor of token ids.
            return

        table = wandb.Table(columns=["prompt", "generated", "ground truth"])

        # Spread the samples over the dataset; the set removes duplicate
        # indices that appear when the dataset has fewer than three rows.
        sample_indices = sorted({0, n_samples // 3, n_samples * 2 // 3})

        for idx in sample_indices:
            examples = dataset[idx:idx + 1]

            input_ids = torch.tensor(examples["input_ids"], device=self.model.device)
            inputs = input_ids[:, :self.prompt_length]
            targets = input_ids[:, self.prompt_length: self.prompt_length + self.max_new_tokens]

            predictions = self.model.generate(inputs, max_new_tokens=self.max_new_tokens, num_beams=self.num_beams)
            # Drop the echoed prompt using its actual width: an example
            # shorter than prompt_length yields a shorter prompt, and slicing
            # by prompt_length would also cut generated tokens.
            # NOTE(review): assumes generate() returns the prompt followed by
            # the new tokens - confirm for the model in use.
            predictions = predictions[:, inputs.shape[1]:]

            decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=False)
            decoded_preds = self.tokenizer.batch_decode(predictions, skip_special_tokens=False)
            decoded_labels = self.tokenizer.batch_decode(targets, skip_special_tokens=False)
            table.add_data(decoded_inputs[0], decoded_preds[0], decoded_labels[0])

        wandb.log({"generated": table, "global_step": state.global_step})


class EvaluateOnFirstStepCallback(TrainerCallback):
    """Callback that requests an evaluation when the first step begins."""

    def on_step_begin(self, args, state, control, **kwargs):
        """Set ``control.should_evaluate`` while ``global_step`` is still 0."""
        if state.global_step != 0:
            return
        control.should_evaluate = True



class LogClassificationsCallback(TrainerCallback):
    """Callback to log examples of predictions to wandb.

    On each evaluation, the validation examples at indices 0, 100 and 200
    (those that exist) are run through the model. The decoded input, the
    argmax prediction and the ground-truth label are logged as a wandb table.

    Args:
        tokenized_dataset: Mapping with a ``"validation"`` split whose slices
            expose ``"input_ids"`` and ``"labels"`` columns.
        tokenizer: Tokenizer used to decode the input ids.
        model: Model called with ``input_ids``; its output must expose
            ``logits``.
    """

    def __init__(self, tokenized_dataset, tokenizer, model):
        self.tokenized_dataset = tokenized_dataset
        self.tokenizer = tokenizer
        self.model = model

    def on_evaluate(self, args: TrainingArguments, state, control, **kwargs):
        """Classify a few validation examples and log them to wandb."""
        # Only log from the main process, and only when a wandb run is active.
        if not wandb.run or not state.is_world_process_zero:
            return

        dataset = self.tokenized_dataset["validation"]
        n_samples = len(dataset)

        table = wandb.Table(columns=["input", "prediction", "ground truth"])

        for idx in (0, 100, 200):
            if idx >= n_samples:
                # Smaller validation sets would produce an empty slice here;
                # the remaining (larger) indices are out of range as well.
                break

            example = dataset[idx:idx + 1]

            input_ids = torch.tensor(example["input_ids"], device=self.model.device)
            # This forward pass is for logging only, so no gradients are needed.
            with torch.no_grad():
                logits = self.model(input_ids=input_ids).logits
            prediction = logits.argmax(dim=-1).item()
            target = example["labels"][0]

            table.add_data(self.tokenizer.decode(input_ids[0]), prediction, target)

        wandb.log({"predictions": table, "global_step": state.global_step})


# NOTE(review): a class with this exact name and body is already defined
# earlier in this file; this later definition rebinds the name. Consider
# keeping only one of the two.
class EvaluateOnFirstStepCallback(TrainerCallback):
    """Callback that requests an evaluation when the first step begins."""

    def on_step_begin(self, args, state, control, **kwargs):
        """Set ``control.should_evaluate`` while ``global_step`` is still 0."""
        is_first_step = state.global_step == 0
        if is_first_step:
            control.should_evaluate = True



def get_config():
    """Load the configuration file named by the ``--config`` CLI option.

    The option defaults to ``config.yaml``. The path is parsed from the
    process command line and passed to ``OmegaConf.load``.

    Returns:
        The object returned by ``OmegaConf.load`` for that path.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, default="config.yaml")
    parsed = cli.parse_args()
    return OmegaConf.load(parsed.config)



@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
    """Collate examples into a padded batch with language-modeling labels.

    If the examples are mappings containing a ``"labels"`` entry, those labels
    are padded and used. Otherwise the labels are a copy of the padded
    ``input_ids``. Label positions equal to the tokenizer's pad token id are
    replaced with -100.
    """

    # Tokenizer used for padding and for looking up the pad token id.
    tokenizer: PreTrainedTokenizerBase
    # Forwarded to the padding helpers.
    pad_to_multiple_of: Optional[int] = None
    # Not read by torch_call, which always requests "pt" tensors.
    return_tensors: str = "pt"

    def torch_call(self, examples):
        """Build the batch for a list of examples.

        Args:
            examples: Non-empty list of mappings (optionally carrying a
                ``"labels"`` entry) or of token-id sequences/tensors.

        Returns:
            The padded batch with a ``"labels"`` entry added.
        """
        first = examples[0]

        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(first, Mapping):
            # Only mappings can carry a "labels" entry. Testing
            # `"labels" in example` on a tensor example raises, so the check
            # is confined to this branch.
            has_labels = "labels" in first
            # Labels are padded separately below, so strip them first.
            features = [{k: v for k, v in example.items() if k != "labels"} for example in examples]
            batch = pad_without_fast_tokenizer_warning(
                self.tokenizer, features, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            has_labels = False
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        if has_labels:
            labels = _torch_collate_batch(
                [example["labels"] for example in examples],
                self.tokenizer,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            labels = batch["input_ids"].clone()

        # NOTE(review): every label equal to pad_token_id is masked, not just
        # padding positions. If the tokenizer reuses another token (e.g. EOS)
        # as pad, those labels are masked too - confirm this is intended.
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch