from torch.utils.data import Dataset, DataLoader
from transformers import GPTNeoXForCausalLM, AutoTokenizer
import os
import torch
from os.path import join


class ParaRelDataset(Dataset):
    """Map-style dataset exposing the rows of a DataFrame as dicts.

    The wrapped frame must provide the columns ``'prompt'``, ``'object'`` and
    ``'true object first token id'``.
    """

    def __init__(self, df):
        # Kept by reference; the frame is not copied or re-indexed.
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return self.df.shape[0]

    def __getitem__(self, i):
        """Return the ``i``-th row (by position) as a dict.

        Rows are addressed positionally with ``.iloc`` so that frames whose
        index is not ``0..n-1`` (e.g. after filtering or shuffling) still
        yield the expected row, and so that an out-of-range ``i`` raises
        ``IndexError`` as the sequence protocol expects.

        Returns:
            dict with keys ``'prompt'`` (whitespace-stripped), ``'true object'``
            and ``'true object first token id'``.
        """
        df = self.df
        return {
            'prompt': df['prompt'].iloc[i].strip(),
            'true object': df['object'].iloc[i],
            'true object first token id': df['true object first token id'].iloc[i],
        }



def prepare_batch_position_ids(batch_inputs):
    """Build position ids for a batch from its attention mask.

    Each row is treated as left-padded: only the *number* of non-zero mask
    entries in the row is used. The leading ``n_pad`` slots get position 0
    and the remaining slots get ``0, 1, ..., n_tokens - 1``.

    Args:
        batch_inputs: mapping with an ``'attention_mask'`` tensor indexed as
            (batch, sequence); entries are expected to be 0/1.

    Returns:
        An int64 CPU tensor with the same (batch, sequence) shape as the mask.
        An empty batch yields an empty ``(0, sequence)`` tensor.
    """
    mask = batch_inputs['attention_mask']
    seq_len = mask.shape[1]

    # Padding slots per row; cast so bool/float masks work as well.
    n_pad = seq_len - mask.sum(dim=1).to(torch.long)

    # Shift 0..seq_len-1 left by the pad count of each row; anything that
    # falls inside the padding prefix becomes negative and is clamped to 0.
    columns = torch.arange(seq_len, device=mask.device).unsqueeze(0)
    position_ids = (columns - n_pad.unsqueeze(1)).clamp(min=0)

    # The previous loop-based implementation always built its result on the
    # CPU; keep that so existing callers see the same device.
    return position_ids.cpu()


def load_model(step=0, model_size="1b", deduped=True):
    """
    Load a Pythia checkpoint and its tokenizer via ``from_pretrained``.

    Args:
        step: training step; selects the ``step{step}`` revision of the repo.
        model_size: one of (70m, 160m, 410m, 1b, 1.4b, 2.8b, 6.9b, 12b)
        deduped: if True load ``EleutherAI/pythia-{model_size}-deduped``,
            otherwise ``EleutherAI/pythia-{model_size}``.

    Returns:
        ``(model, tokenizer)``; the model is requested with
        ``torch_dtype=torch.float16``.
    """
    # Single source of truth for the names shared by the model, the tokenizer
    # and the cache location.
    model_name = f"pythia-{model_size}-deduped" if deduped else f"pythia-{model_size}"
    repo_id = f"EleutherAI/{model_name}"
    revision = f"step{step}"

    # Cache under $TRANSFORMERS_CACHE/<model_name>/<revision> when the variable
    # is set; otherwise pass None so the library falls back to its default
    # cache location instead of failing with a KeyError.
    cache_root = os.environ.get('TRANSFORMERS_CACHE')
    cache_dir = join(cache_root, f"{model_name}/{revision}") if cache_root else None

    model = GPTNeoXForCausalLM.from_pretrained(
        repo_id,
        revision=revision,
        cache_dir=cache_dir,
        torch_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        repo_id,
        revision=revision,
        cache_dir=cache_dir,
    )

    return model, tokenizer

def untuple(x):
    """Return the first element of ``x`` if it is a tuple, else ``x`` itself."""
    if not isinstance(x, tuple):
        return x
    return x[0]
    