from typing import Optional, Any

import torch
from datasets.arrow_dataset import Dataset
from transformers.tokenization_utils_fast import BatchEncoding, PreTrainedTokenizerFast
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset as TorchDataset


def group_texts(
    input_ids: list[list[int]],
    max_length: int,
    bos_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
) -> list[list[int]]:
    """
    Pack tokenized sequences into groups of exactly `max_length` ids.

    Sequences are consumed in order and concatenated into the group that is
    currently being filled.

    Note:
     - The part of a sequence that does not fit into the current group is
       dropped; it is not carried over into the next group.
     - A full group is emitted when the next sequence arrives or when the
       input is exhausted.
     - A trailing group that is not completely filled is discarded.

    Args:
        input_ids: Tokenized sequences to pack.
        max_length: Number of ids in every returned group.
        bos_token_id: If given, every group starts with this id.
        eos_token_id: If given, the last id of every emitted group is
            overwritten with this id.

    Returns:
        The emitted groups, each holding `max_length` ids.
    """
    groups: list[list[int]] = []

    def start_group() -> list[int]:
        # a new group is either empty or seeded with the BOS id
        return [] if bos_token_id is None else [bos_token_id]

    def emit(group: list[int]) -> None:
        # the group is full, so the EOS id (if any) replaces its final id
        if eos_token_id is not None:
            group[-1] = eos_token_id
        groups.append(group)

    group = start_group()
    for tokens in input_ids:
        if max_length - len(group) <= 0:
            # no room left: flush the full group before taking this sequence
            emit(group)
            group = start_group()
        # take as much of the sequence as still fits; the rest is cut off
        group.extend(tokens[: max_length - len(group)])

    # only a completely filled leftover group makes it into the output
    if len(group) == max_length:
        emit(group)
    return groups


class DataCollatorForCausalLM:
    """
    Collate raw text into one batch tokenized by two tokenizers.

    Every example's `"text"` field is tokenized by the LLM tokenizer and by
    the NLLB tokenizer.  The NLLB fields are merged into the LLM batch under
    an `nllb_` prefix.
    """

    def __init__(
        self,
        llm_tokenizer: PreTrainedTokenizerFast,
        llm_tokenizer_kwargs: dict,
        nllb_tokenizer: PreTrainedTokenizerFast,
        nllb_tokenizer_kwargs: dict = {},
    ) -> None:
        """
        Args:
            llm_tokenizer: Tokenizer for the LLM side.  If it has no pad token
                id, its EOS token id is installed as the pad token id.
            llm_tokenizer_kwargs: Keyword arguments for each LLM tokenizer call.
            nllb_tokenizer: Tokenizer for the NLLB side.
            nllb_tokenizer_kwargs: Keyword arguments for each NLLB tokenizer call.
        """
        # without a pad token the tokenizer cannot pad; reuse EOS for that
        if llm_tokenizer.pad_token_id is None:
            llm_tokenizer.pad_token_id = llm_tokenizer.eos_token_id

        self.llm_tokenizer = llm_tokenizer
        self.llm_tokenizer_kwargs = llm_tokenizer_kwargs
        self.nllb_tokenizer = nllb_tokenizer
        self.nllb_tokenizer_kwargs = nllb_tokenizer_kwargs

    def __call__(self, inputs) -> BatchEncoding:
        """
        Tokenize the `"text"` of every example with both tokenizers.

        Each encoding additionally receives an `"eos_token_id"` entry holding
        `attention_mask.sum(-1) - 1` per row (presumably the position of the
        last attended token when padding is on the right -- TODO confirm).

        Returns:
            The LLM encoding, extended with every NLLB field as `nllb_<field>`.
        """
        documents: list[str] = [example["text"] for example in inputs]

        batch = self.llm_tokenizer(documents, **self.llm_tokenizer_kwargs)
        batch["eos_token_id"] = batch["attention_mask"].sum(-1) - 1

        nllb_encoding = self.nllb_tokenizer(documents, **self.nllb_tokenizer_kwargs)
        nllb_encoding["eos_token_id"] = nllb_encoding["attention_mask"].sum(-1) - 1

        # merge the NLLB view into the LLM batch under a distinguishing prefix
        for field, value in nllb_encoding.items():
            batch[f"nllb_{field}"] = value
        return batch


class DataCollatorForSequenceClassification:
    """
    Collate (paired) text classification examples for an LLM tokenizer and,
    optionally, an NLLB tokenizer.

    `columns` maps the roles `"text"`, optionally `"text_pair"` and optionally
    `"label"` to the field names used by the incoming examples.
    """

    def __init__(
        self,
        llm_tokenizer: PreTrainedTokenizerFast,
        llm_tokenizer_kwargs: dict,
        columns: dict[str, str],
        nllb_tokenizer: None | PreTrainedTokenizerFast = None,
        nllb_tokenizer_kwargs: dict = {},
        # Llama eos token, otherwise doesn't get added
        # llm_eos_token: str = "<|end_of_text|>",
        llm_eos_token: str = "",
        *args,
        **kwargs,
    ) -> None:
        """
        Args:
            llm_tokenizer: Tokenizer for the LLM side.  If it has no pad token
                id, its EOS token id is installed as the pad token id.
            llm_tokenizer_kwargs: Keyword arguments for each LLM tokenizer call.
            columns: Mapping from role name to example field name.
            nllb_tokenizer: Optional second tokenizer; skipped when `None`.
            nllb_tokenizer_kwargs: Keyword arguments for each NLLB tokenizer call.
            llm_eos_token: String appended to every text handed to the LLM
                tokenizer (empty by default).
            *args, **kwargs: Accepted and ignored.
        """
        # without a pad token the tokenizer cannot pad; reuse EOS for that
        if llm_tokenizer.pad_token_id is None:
            llm_tokenizer.pad_token_id = llm_tokenizer.eos_token_id

        self.llm_tokenizer = llm_tokenizer
        self.llm_tokenizer_kwargs = llm_tokenizer_kwargs
        self.nllb_tokenizer = nllb_tokenizer
        self.nllb_tokenizer_kwargs = nllb_tokenizer_kwargs
        self.columns = columns
        self.llm_eos_token = llm_eos_token

    def __call__(self, inputs) -> BatchEncoding:
        """
        Build the batch for a list of examples.

        Returns:
            The LLM encoding, extended with
            - `nllb_<field>` entries when an NLLB tokenizer is configured,
            - `labels` when `columns` names a label field,
            - `sequence_embeds` (stacked along dim 0) when the first example
              carries that key.
        """
        text: list[str] = [example[self.columns["text"]] for example in inputs]
        text_pair: None | list[str] = None
        if "text_pair" in self.columns:
            text_pair = [example[self.columns["text_pair"]] for example in inputs]

        # only the LLM side gets the EOS suffix appended to its texts
        suffix = self.llm_eos_token
        llm_text = [t + suffix for t in text]
        llm_text_pair = None
        if text_pair is not None:
            llm_text_pair = [t + suffix for t in text_pair]
        batch = self.llm_tokenizer(
            llm_text, llm_text_pair, **self.llm_tokenizer_kwargs
        )

        if self.nllb_tokenizer is not None:
            nllb_encoding = self.nllb_tokenizer(
                text, text_pair, **self.nllb_tokenizer_kwargs
            )
            for field, value in nllb_encoding.items():
                batch[f"nllb_{field}"] = value

        if "label" in self.columns:
            batch["labels"] = torch.LongTensor(
                [example[self.columns["label"]] for example in inputs]
            )
        if "sequence_embeds" in inputs[0]:
            embeds = [example["sequence_embeds"] for example in inputs]
            batch["sequence_embeds"] = torch.stack(embeds, dim=0)
        return batch


class DataCollatorForAdaptation:
    """
    Pad pre-tokenized examples for the LLM and, optionally, the NLLB model.

    Every example is expected to carry `input_ids` / `attention_mask` for the
    LLM and `nllb_input_ids` / `nllb_attention_mask` for NLLB.  Both halves
    are padded to a fixed length and merged into one batch, the NLLB fields
    under an `nllb_` prefix.
    """

    # Fixed length every sequence is padded to (no truncation happens here).
    _PAD_LENGTH = 512

    def __init__(
        self,
        llm_tokenizer: PreTrainedTokenizerFast,
        nllb_tokenizer: None | PreTrainedTokenizerFast,
        *args,
        **kwargs,
    ) -> None:
        """
        Args:
            llm_tokenizer: Tokenizer used to pad the LLM inputs.  If it has no
                pad token id, its EOS token id is installed as the pad token id.
            nllb_tokenizer: Tokenizer used to pad the NLLB inputs, or `None`
                to produce an LLM-only batch.
            *args, **kwargs: Accepted and ignored.
        """
        self.llm_tokenizer = llm_tokenizer

        # without a pad token the tokenizer cannot pad; reuse EOS for that
        if self.llm_tokenizer.pad_token_id is None:
            self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id

        self.nllb_tokenizer = nllb_tokenizer

    def __call__(self, inputs) -> BatchEncoding:
        """
        Pad a list of pre-tokenized examples into tensors.

        Returns:
            The padded LLM batch; when an NLLB tokenizer is configured it also
            holds every padded NLLB field as `nllb_<field>`.
        """
        pad_kwargs = {
            "return_tensors": "pt",
            "padding": "max_length",
            "max_length": self._PAD_LENGTH,
        }
        llama_inputs = [
            {"input_ids": line["input_ids"], "attention_mask": line["attention_mask"]}
            for line in inputs
        ]
        llama_batch = self.llm_tokenizer.pad(llama_inputs, **pad_kwargs)

        # The constructor allows `nllb_tokenizer=None`; previously this path
        # crashed with an AttributeError on `None.pad`.  Skip the NLLB half
        # instead, like DataCollatorForSequenceClassification does.
        if self.nllb_tokenizer is None:
            return llama_batch

        nllb_inputs = [
            {
                "input_ids": line["nllb_input_ids"],
                "attention_mask": line["nllb_attention_mask"],
            }
            for line in inputs
        ]
        nllb_batch = self.nllb_tokenizer.pad(nllb_inputs, **pad_kwargs)
        for k, v in nllb_batch.items():
            llama_batch[f"nllb_{k}"] = v
        return llama_batch


class DataCollatorForAdaptationGPT:
    """
    Tokenize raw text on the fly with an LLM tokenizer and an NLLB tokenizer.

    Both tokenizers receive the same texts and the same keyword arguments;
    the NLLB fields are merged into the LLM batch under an `nllb_` prefix.
    """

    def __init__(
        self,
        llm_tokenizer: PreTrainedTokenizerFast,
        nllb_tokenizer: PreTrainedTokenizerFast,
        tokenize_kwargs: dict = {},
        *args,
        **kwargs,
    ) -> None:
        """
        Args:
            llm_tokenizer: Tokenizer for the LLM side.  If it has no pad token
                id, its EOS token id is installed as the pad token id.
            nllb_tokenizer: Tokenizer for the NLLB side.
            tokenize_kwargs: Keyword arguments shared by both tokenizer calls.
            *args, **kwargs: Accepted and ignored.
        """
        # without a pad token the tokenizer cannot pad; reuse EOS for that
        if llm_tokenizer.pad_token_id is None:
            llm_tokenizer.pad_token_id = llm_tokenizer.eos_token_id

        self.llm_tokenizer = llm_tokenizer
        self.nllb_tokenizer = nllb_tokenizer
        self.tokenize_kwargs = tokenize_kwargs

    def __call__(self, inputs) -> BatchEncoding:
        """
        Tokenize the `"text"` field of every example with both tokenizers.

        Returns:
            The LLM encoding, extended with every NLLB field as `nllb_<field>`.
        """
        documents = [example["text"] for example in inputs]
        batch = self.llm_tokenizer(documents, **self.tokenize_kwargs)
        nllb_encoding = self.nllb_tokenizer(documents, **self.tokenize_kwargs)
        # merge the NLLB view into the LLM batch under a distinguishing prefix
        for field, value in nllb_encoding.items():
            batch[f"nllb_{field}"] = value
        return batch


class DataCollatorForMultipleChoice:
    """
    Collate four-way multiple-choice examples into one prompt per example.

    Each example is rendered as
    `"<context> <question>\\n1 <choice>\\n2 <choice>\\n3 <choice>\\n4 <choice>"`
    and tokenized.  In addition to the tokenizer output the batch carries
    `mean_mask`, which marks for every choice the tokens that spell out that
    choice inside the prompt, and `labels`.

    `columns` maps `"context"`, `"question"` and `"label"` to example field
    names and `"choices"` to a sequence of four field names (it is indexed
    with 0..3 in `__call__`, despite the `dict[str, str]` annotation).
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        columns: dict[str, str],
        pad_to_multiple_of: Optional[int] = None,
        tokenize_kwargs: dict = {
            "return_tensors": "pt",
            "truncation": True,
            "padding": True,
            "return_attention_mask": True,
            "max_length": 1024,
        },
    ) -> None:
        """
        Args:
            tokenizer: Tokenizer used to encode the prompts; it is called with
                `return_offsets_mapping=True`.
            columns: Role-to-field-name mapping, see the class docstring.
            pad_to_multiple_of: Stored, but not read by `__call__`.
            tokenize_kwargs: Only its `"max_length"` entry is read by
                `__call__`; the remaining options are hard-coded there.
        """
        self.tokenizer = tokenizer
        self.tokenize_kwargs = tokenize_kwargs
        self.pad_to_multiple_of = pad_to_multiple_of
        self.columns = columns
        self.num_choices = len(self.columns["choices"])

    # NOTE: the two commented-out `__call__` variants below are earlier
    # implementations kept for reference; only the last `__call__` is live.
    # def __call__(self, features: list[dict]):
    # text = [
    #     f"{line[self.columns['context']]} {line[self.columns['question']]} "
    #     for line in features
    #     for _ in range(self.num_choices)
    # ]
    # text_pair = [
    #     line[choice] for line in features for choice in self.columns["choices"]
    # ]
    # assert len(text) == len(text_pair)
    # labels = torch.LongTensor(
    #     [int(line[self.columns["label"]]) - 1 for line in features]
    # )
    # batch = self.tokenizer(text, text_pair, **self.tokenize_kwargs)
    # batch["labels"] = torch.LongTensor(labels)
    # return batch
    # def __call__(self, features: list[dict]):
    #     text = [
    #         f"{line[self.columns['context']]} {line[self.columns['question']]}\n"
    #         for line in features
    #     ]
    #     for i in range(len(text)):
    #         for j, choice in enumerate(self.columns["choices"]):
    #             text[i] += f"{j} {features[i][choice]}\n"
    #     text = [t.strip() for t in text]
    #     labels = torch.LongTensor(
    #         [int(line[self.columns["label"]]) - 1 for line in features]
    #     )
    #     batch = self.tokenizer(text, **self.tokenize_kwargs)
    #     batch["labels"] = torch.LongTensor(labels)
    #     return batch
    def __call__(self, features: list[dict]):
        """
        Tokenize the prompts and locate the token span of every choice.

        Returns:
            The tokenizer output (without `offset_mapping`) plus
            - `mean_mask`: float tensor of shape (batch, seq_len, 4); entry
              `[b, t, c]` is 1 when token `t` belongs to choice `c`,
            - `labels`: `int(label) - 1` per example, i.e. the incoming
              labels are shifted down by one.

        Raises:
            AssertionError: If fewer than four choice spans were found for
                any example.
        """
        # one prompt per example: context, question and the numbered choices
        text = [
            f"{line[self.columns['context']]} {line[self.columns['question']]}\n1 {line[self.columns['choices'][0]]}\n2 {line[self.columns['choices'][1]]}\n3 {line[self.columns['choices'][2]]}\n4 {line[self.columns['choices'][3]]}"
            for line in features
        ]
        # raw choice strings per example, used as search targets below
        choices = [
            [
                line[self.columns["choices"][0]],
                line[self.columns["choices"][1]],
                line[self.columns["choices"][2]],
                line[self.columns["choices"][3]],
            ]
            for line in features
        ]
        batch = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            max_length=self.tokenize_kwargs["max_length"],
            truncation=True,
            return_offsets_mapping=True,
        )
        # per token (char_start, char_end) into the prompt string
        offsets = batch["offset_mapping"].tolist()
        N, L = batch["input_ids"].shape
        spans = []
        # built as (choice, batch, token); permuted to (batch, token, choice) below
        mask = torch.zeros((4, N, L), dtype=torch.bool)
        for idx, (t, row, row_choices) in enumerate(zip(text, offsets, choices)):
            # NOTE: rebinds N from batch size to the number of tokens in this row
            N = len(row)
            sub_spans = []
            # choices are matched last-to-first because the prompt is scanned backwards
            choice = 3
            string_ = ""
            # character end of the candidate span; None means "no candidate open"
            current_end = None
            # exclusive token index at which the candidate span ends
            current_end_token_idx = None
            choice_stripped = row_choices[choice].strip()
            # we go backwards
            for i in reversed(range(0, N)):
                # skip RHS padding
                # (any other token with a (0, 0) offset is skipped as well)
                if row[i][0] == 0 and row[i][1] == 0:
                    continue
                # advance end pointer to left
                if current_end is None:
                    current_end = row[i][1]
                    # appropriately take inclusive range
                    current_end_token_idx = i + 1
                # get current text
                string_ = t[row[i][0] : current_end]
                # strip whitespace for clean comparison
                string_stripped = string_.strip(" ")
                if string_stripped == choice_stripped:
                    # tokens i .. current_end_token_idx - 1 spell out this choice
                    mask[choice, idx, i:current_end_token_idx] = True
                    sub_spans.append((row[i][0], current_end))
                    choice -= 1
                    current_end = None
                    string_ = ""
                    # when choice has reached -1 this reads the last choice again;
                    # harmless, the loop is left right below
                    choice_stripped = row_choices[choice].strip(" ")
                # advance pointer to left since string has to end in choice
                elif not choice_stripped.endswith(string_stripped):
                    current_end = None
                    current_end_token_idx = None
                if choice == -1:
                    break
            # spans were collected for choice 4 first; restore choice order
            sub_spans = list(reversed(sub_spans))
            spans.append(sub_spans)
        # every example must have yielded a span for each of the four choices
        for s in spans:
            assert len(s) == 4
        mask = mask.permute(1, 2, 0)
        batch["mean_mask"] = mask.float()
        batch.pop("offset_mapping")
        batch["labels"] = torch.LongTensor(
            [int(line[self.columns["label"]]) - 1 for line in features]
        )
        return batch


def preprocess_adaptation(
    dataset: Dataset,
    llama_tokenizer: PreTrainedTokenizerFast,
    nllb_tokenizer: PreTrainedTokenizerFast,
    tokenize_kwargs: dict = {
        "truncation": True,
        "return_attention_mask": True,
        "max_length": 512,
    },
    text_column: str = "text",
    num_proc: int = 8,
) -> Dataset:
    """
    Tokenize a text column with both the Llama and the NLLB tokenizer.

    All original columns are removed.  The result holds the fields returned
    by the Llama tokenizer under their own names and the fields returned by
    the NLLB tokenizer under an `nllb_` prefix.

    Args:
    - dataset: The input dataset (a `DatasetDict` is handled as well).
    - llama_tokenizer: Tokenizer whose output keeps its field names.
    - nllb_tokenizer: Tokenizer whose output is prefixed with `nllb_`.
    - tokenize_kwargs: Keyword arguments passed to both tokenizer calls.
      The default is never mutated here.
    - text_column: Name of the column holding the raw text.
    - num_proc: Number of processes used by `map`.

    Returns:
    - The tokenized dataset.
    """
    remove_columns = dataset.column_names
    if isinstance(remove_columns, dict):
        # A DatasetDict reports {split: [column, ...]}.  `map` needs a flat
        # list of column names, so merge the per-split lists (unique, in
        # first-seen order) instead of passing a list of lists.
        remove_columns = list(
            dict.fromkeys(
                name for names in remove_columns.values() for name in names
            )
        )

    def tokenize_dataset(
        examples: dict[str, list[str]],
        llama_tokenizer,
        nllb_tokenizer,
        tokenize_kwargs,
        text_column,
    ):
        # both tokenizers see the same batch of texts and the same options
        llama_batch = llama_tokenizer(examples[text_column], **tokenize_kwargs)
        nllb_batch = nllb_tokenizer(examples[text_column], **tokenize_kwargs)
        out = {k: v for k, v in llama_batch.items()}
        out.update((f"nllb_{k}", v) for k, v in nllb_batch.items())
        return out

    return dataset.map(
        function=tokenize_dataset,
        fn_kwargs={
            "llama_tokenizer": llama_tokenizer,
            "nllb_tokenizer": nllb_tokenizer,
            "tokenize_kwargs": tokenize_kwargs,
            "text_column": text_column,
        },
        remove_columns=remove_columns,
        batched=True,
        num_proc=num_proc,
        batch_size=10_000,
    )


def preprocess_fn(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerFast,
    tokenize_kwargs: dict = {
        "add_special_tokens": False,
        "truncation": False,
        "return_attention_mask": False,
    },
    text_column: str = "text",
    num_proc: int = 12,
) -> Dataset:
    """
    Tokenize a text column of a dataset; no grouping is performed here.

    All original columns are removed and replaced by the fields the tokenizer
    returns for the given `tokenize_kwargs`.

    Args:
    - dataset: The input dataset (a `DatasetDict` is handled as well).
    - tokenizer: Tokenizer to be used.
    - tokenize_kwargs: Keyword arguments passed to the tokenizer call.
      The default is never mutated here.
    - text_column: Name of the column holding the raw text.
    - num_proc: Number of processes used by `map`.

    Returns:
    - The tokenized dataset.
    """
    remove_columns = dataset.column_names
    if isinstance(remove_columns, dict):
        # A DatasetDict reports {split: [column, ...]}.  `map` needs a flat
        # list of column names, so merge the per-split lists (unique, in
        # first-seen order) instead of passing a list of lists.
        remove_columns = list(
            dict.fromkeys(
                name for names in remove_columns.values() for name in names
            )
        )

    def tokenize_dataset(examples: dict[str, list[str]], tokenizer, tokenize_kwargs):
        # The complete tokenizer output is returned, so every field it emits
        # becomes a column.  (The former `{"input_ids": enc}["input_ids"]`
        # wrapper evaluated to exactly this object.)
        return tokenizer(examples[text_column], **tokenize_kwargs)

    return dataset.map(
        function=tokenize_dataset,
        fn_kwargs={
            "tokenizer": tokenizer,
            "tokenize_kwargs": tokenize_kwargs,
        },
        remove_columns=remove_columns,
        batched=True,
        num_proc=num_proc,
    )


def preprocess_bytes(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerFast,
    column: str = "text",
    num_proc: int = 12,
) -> Dataset:
    """
    Tokenize one column of a dataset without truncation or special tokens.

    The source column is removed; all other columns are kept and the fields
    returned by the tokenizer are added.

    Args:
    - dataset: The input dataset.
    - tokenizer: Tokenizer to be used.
    - column: Name of the column holding the raw text.
    - num_proc: Number of processes used by `map`.

    Returns:
    - The tokenized dataset.
    """
    # fixed options: full-length ids only, no attention mask, no special tokens
    encode_options = {
        "truncation": False,
        "return_attention_mask": False,
        "add_special_tokens": False,
    }

    def tokenize_bytes(batch: dict[str, list[str]]):
        return tokenizer(batch[column], **encode_options)

    return dataset.map(
        function=tokenize_bytes,
        remove_columns=column,
        batched=True,
        num_proc=num_proc,
    )


class DataCollatorForMC:
    """
    Pad pre-tokenized multiple-choice examples into a batch.

    Every example provides `input_ids`, `attention_mask`, `mean_mask` and
    `labels`; `choice_embeds` is optional.
    """

    def __init__(self, tokenizer, tokenize_kwargs: dict):
        """
        Args:
            tokenizer: Tokenizer whose `pad` method builds the batch.  Its
                `padding_side` is set to `"right"` as a side effect.
            tokenize_kwargs: Keyword arguments forwarded to `tokenizer.pad`.
        """
        self.tokenizer = tokenizer
        self.tokenize_kwargs = tokenize_kwargs
        # simplify mean mask: `mean_mask` is padded on the right below
        self.tokenizer.padding_side = "right"

    def __call__(self, inputs: list[dict]):
        """
        Returns:
            The padded tokenizer batch plus
            - `mean_mask`: per-example masks padded with zeros (batch first),
            - `labels`: long tensor of the example labels,
            - `choice_embeds`: stacked along dim 0, only when the first
              example carries that key.
        """
        token_fields = {
            field: [example[field] for example in inputs]
            for field in ("input_ids", "attention_mask")
        }
        batch = self.tokenizer.pad(token_fields, **self.tokenize_kwargs)

        per_example_masks = [torch.Tensor(example["mean_mask"]) for example in inputs]
        batch["mean_mask"] = pad_sequence(
            per_example_masks, batch_first=True, padding_value=0
        )
        batch["labels"] = torch.LongTensor([example["labels"] for example in inputs])

        if "choice_embeds" in inputs[0]:
            embeds = [example["choice_embeds"] for example in inputs]
            batch["choice_embeds"] = torch.stack(embeds, dim=0)
        return batch


class EmbeddedDataset(TorchDataset):
    """
    Dataset wrapper that attaches one extra entry to every item.

    Item `i` is `dataset[i]` with `tensors[i]` stored under `key`.
    """

    def __init__(self, dataset, tensors, key: str = "choice_embeds") -> None:
        """
        Args:
            dataset: Indexable collection whose items support item assignment.
            tensors: Indexable collection aligned with `dataset`.
            key: Name under which `tensors[i]` is stored in item `i`.
        """
        super().__init__()
        self.key = key
        self.tensors = tensors
        self.dataset = dataset

    def __len__(self):
        """Return the number of items of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return `dataset[i]` after storing `tensors[i]` in it under `key`."""
        item = self.dataset[i]
        # the item object returned by the wrapped dataset is modified in place
        item[self.key] = self.tensors[i]
        return item


def preprocess_multiple_choice(
    examples: dict,
    tokenizer: PreTrainedTokenizerFast,
    columns: dict[str, str],
    tokenize_kwargs: dict = {
        "max_length": 4096,
    },
    *args,
    **kwargs,
) -> dict:
    """
    Tokenize a batch of four-way multiple-choice examples and locate the
    token span of every choice.

    Each example is rendered as
    `"<context> <question>\\n1 <choice>\\n2 <choice>\\n3 <choice>\\n4 <choice>"`.
    If the padded batch is longer than `tokenize_kwargs["max_length"]`, the
    context/question part is truncated step by step until it fits.

    Args:
        examples: Batch in column format (field name -> list of values).
        tokenizer: Tokenizer; it is called with `return_offsets_mapping=True`.
        columns: Maps `"context"`, `"question"` and `"label"` to field names
            and `"choices"` to a sequence of four field names.
        tokenize_kwargs: Only its `"max_length"` entry is read.
        *args, **kwargs: Accepted and ignored.

    Returns:
        A dict of per-example lists with padding removed:
        `input_ids`, `attention_mask`, `mean_mask` (one row of four values
        per token, marking the choice the token belongs to), `labels`
        (`int(label) - 1`) and `spans` (character spans of the four choices).

    Raises:
        AssertionError: If fewer than four choice spans were found for any
            example.
    """
    # regroup the choice columns into one tuple of four choices per example
    choices = list(zip(*[examples[c] for c in columns["choices"]]))
    sanitized_choices = []
    for i, choices_ in enumerate(choices):
        cleaned_choices = []
        for choice in choices_:
            # missing or empty choices are replaced by the literal text "None"
            if not isinstance(choice, str) or choice == "":
                cleaned_choices.append("None")
            else:
                cleaned_choices.append(choice)
        # Round-trip every choice through the tokenizer and strip it
        # (presumably so the text matches what the tokenizer reproduces when
        # spans are searched below -- TODO confirm).
        # NOTE(review): `max_length=512` is passed without `truncation`;
        # whether anything is cut off depends on the tokenizer -- confirm.
        sanitized_choice = [
            s.strip()
            for s in tokenizer.batch_decode(
                tokenizer(cleaned_choices, max_length=512, add_special_tokens=False)[
                    "input_ids"
                ]
            )
        ]
        sanitized_choices.append(sanitized_choice)
    choices = sanitized_choices
    # full prompt: context, question and the four numbered choices
    text = [
        f"{context} {question}\n1 {choice_[0]}\n2 {choice_[1]}\n3 {choice_[2]}\n4 {choice_[3]}"
        for context, question, choice_ in zip(
            examples[columns["context"]], examples[columns["question"]], choices
        )
    ]
    batch = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        # max_length=tokenize_kwargs["max_length"],
        truncation=False,
        return_offsets_mapping=True,
    )
    N, L = batch["input_ids"].shape
    # token budget: the choices are never truncated, only context and question
    choice_text = [
        f"1 {choice_[0]}\n2 {choice_[1]}\n3 {choice_[2]}\n4 {choice_[3]}"
        for choice_ in choices
    ]
    choice_length = max([len(ids) for ids in tokenizer(choice_text)["input_ids"]])
    max_context_length = tokenize_kwargs["max_length"] - choice_length
    contexts = examples[columns["context"]]
    questions = examples[columns["question"]]
    # Shrink the context/question part until the padded batch fits.
    # NOTE(review): once max_context_length is 0 the rebuilt text no longer
    # changes; if the choices alone exceed max_length this loop looks like it
    # never terminates -- confirm.
    while L > tokenize_kwargs["max_length"]:
        if max_context_length > 0:
            context_questions_batch = tokenizer(
                text=contexts,
                text_pair=questions,
                max_length=max_context_length,
                truncation="longest_first",
            )
            # decode keeps special tokens; they are removed right below
            context_questions = tokenizer.batch_decode(
                context_questions_batch["input_ids"],
            )
            # Llama tokenizer adds bos token in between text and text_pair as well
            context_questions = [
                cq.replace("<|begin_of_text|>", " ").lstrip()
                for cq in context_questions
            ]
            text = [
                f"{context_question}\n1 {choice_[0]}\n2 {choice_[1]}\n3 {choice_[2]}\n4 {choice_[3]}"
                for context_question, choice_ in zip(context_questions, choices)
            ]
        else:
            # no budget left for context and question: keep the choices only
            text = [
                f"1 {choice_[0]}\n2 {choice_[1]}\n3 {choice_[2]}\n4 {choice_[3]}"
                for choice_ in choices
            ]
        batch = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            # max_length=tokenize_kwargs["max_length"],
            truncation=False,
            return_offsets_mapping=True,
        )
        N, L = batch["input_ids"].shape
        # tighten the budget by 500 tokens per round, never below zero
        max_context_length -= 500
        max_context_length = max(0, max_context_length)
    # per token (char_start, char_end) into the prompt string
    offsets = batch["offset_mapping"].tolist()
    spans = []
    # built as (choice, batch, token); permuted to (batch, token, choice) below
    mask = torch.zeros((4, N, L), dtype=torch.bool)
    for idx, (t, row, row_choices) in enumerate(zip(text, offsets, choices)):
        # NOTE: rebinds N from batch size to the number of tokens in this row
        N = len(row)
        sub_spans = []
        # choices are matched last-to-first because the prompt is scanned backwards
        choice = 3
        string_ = ""
        # character end of the candidate span; None means "no candidate open"
        current_end = None
        # exclusive token index at which the candidate span ends
        current_end_token_idx = None
        # INFO: have to use rstrip because newline token comes from left
        # (comparison is done on UTF-8 bytes)
        choice_stripped = row_choices[choice].strip().encode("utf-8")
        # we go backwards
        for i in reversed(range(0, N)):
            # skip RHS padding
            # (any other token with a (0, 0) offset is skipped as well)
            if row[i][0] == 0 and row[i][1] == 0:
                continue
            # advance end pointer to left
            if current_end is None:
                current_end = row[i][1]
                # appropriately take inclusive range
                current_end_token_idx = i + 1
            # get current text
            string_ = t[row[i][0] : current_end]
            # strip whitespace for clean comparison
            string_stripped = string_.rstrip().encode("utf-8")
            if (
                string_stripped == choice_stripped
                # # this accounts for string_stripped: _helium and helium
                or choice_stripped in string_stripped
            ):
                # limited lookahead for multibyte chars
                # longest multibyte char is 16
                # Looks at up to 15 tokens further left and moves the span
                # start to the left-most one whose text (up to current_end)
                # is still a prefix of the choice.  Rebinding `i` only
                # affects the rest of this iteration; the `for` loop
                # continues with its own next index.
                for k in range(1, 16):
                    j = i - k
                    if j < 0:
                        break
                    string_ = t[row[j][0] : current_end]
                    # strip whitespace for clean comparison
                    string_stripped = string_.strip().encode("utf-8")
                    if choice_stripped.startswith(string_stripped):
                        i = j
                # tokens i .. current_end_token_idx - 1 spell out this choice
                mask[choice, idx, i:current_end_token_idx] = True
                sub_spans.append((row[i][0], current_end))
                choice -= 1
                current_end = None
                string_ = ""
                # when choice has reached -1 this reads the last choice again;
                # harmless, the loop is left right below
                choice_stripped = row_choices[choice].strip().encode("utf-8")
            # advance pointer to left since string has to end in choice
            elif not string_stripped or not choice_stripped.endswith(string_stripped):
                current_end = None
                current_end_token_idx = None
            if choice == -1:
                break
        # spans were collected for choice 4 first; restore choice order
        sub_spans = list(reversed(sub_spans))
        spans.append(sub_spans)
    # every example must have yielded a span for each of the four choices
    for s in spans:
        assert len(s) == 4
    # pivot choice and batch dimension
    mask = mask.permute(1, 2, 0)
    batch["mean_mask"] = mask.float()
    batch.pop("offset_mapping")
    # strip the padding again: keep only positions the attention mask marks
    out = {"input_ids": [], "attention_mask": [], "mean_mask": []}
    for row_ids, row_attn_mask, row_mean_mask in zip(
        batch["input_ids"], batch["attention_mask"], batch["mean_mask"]
    ):
        mask_ = row_attn_mask.bool()
        out["input_ids"].append(row_ids[mask_].tolist())
        out["attention_mask"].append(row_attn_mask[mask_].tolist())
        out["mean_mask"].append(row_mean_mask[mask_, :].tolist())
    # incoming labels are shifted down by one
    out["labels"] = list(map(lambda x: int(x) - 1, examples[columns["label"]]))
    out["spans"] = spans
    return out


class CollatorForNLLBTranslation:
    """
    Pad pre-tokenized translation pairs and derive the training labels.

    Every example provides `input_ids` (encoder side) and
    `decoder_input_ids` (decoder side).
    """

    def __init__(
        self, tokenizer: PreTrainedTokenizerFast, max_length: int = 1024
    ) -> None:
        """
        Args:
            tokenizer: Tokenizer whose `pad` method builds both halves of the
                batch and whose `pad_token_id` marks positions to ignore.
            max_length: Forwarded to `tokenizer.pad` as `max_length`.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, inputs, *args: Any, **kwds: Any) -> dict[str, torch.Tensor]:
        """
        Returns:
            The padded encoder batch (`input_ids`, `attention_mask`) plus
            `decoder_input_ids`, `decoder_attention_mask` and `labels`.
            `labels` are the decoder ids shifted left by one position, with
            the freed last column and every pad position set to -100.
        """
        # attention masks are not read from the examples; `pad` creates them
        encoder_features = [{"input_ids": example["input_ids"]} for example in inputs]
        decoder_features = [
            {"input_ids": example["decoder_input_ids"]} for example in inputs
        ]

        pad_options = {
            "padding": True,
            "return_tensors": "pt",
            "max_length": self.max_length,
            "return_attention_mask": True,
        }
        batch = self.tokenizer.pad(encoder_features, **pad_options)
        decoder_batch = self.tokenizer.pad(decoder_features, **pad_options)
        batch["decoder_input_ids"] = decoder_batch["input_ids"]
        batch["decoder_attention_mask"] = decoder_batch["attention_mask"]

        # next-token targets: drop the first decoder id, close the gap at the
        # end with the ignore value
        decoder_ids = batch["decoder_input_ids"]
        ignore_column = torch.full(size=(len(decoder_ids), 1), fill_value=-100)
        shifted = torch.hstack([decoder_ids[:, 1:], ignore_column])
        # padding must not contribute to the loss either
        batch["labels"] = torch.where(
            shifted == self.tokenizer.pad_token_id, -100, shifted
        )
        return batch


def preprocess_for_translation(
    examples_: dict[str, list[str]],
    tokenizer: PreTrainedTokenizerFast,
    tokenize_kwargs: dict = {},
) -> dict[str, list]:
    """
    Tokenize a batch of translation pairs for an NLLB-style model.

    Source and target sentences are tokenized without special tokens; each id
    list is then wrapped as `[language code id, ..., eos id]`.

    Args:
        examples_: Batch in column format with the fields `source_sentence`,
            `source_lang`, `target_sentence` and `target_lang`.  It is
            modified in place.
        tokenizer: Tokenizer used for sentences and language codes.
        tokenize_kwargs: Extra keyword arguments for the sentence tokenizer
            calls.  The default is never mutated here.

    Returns:
        `examples_`, extended with `input_ids` (source side) and
        `decoder_input_ids` (target side).

    Raises:
        AssertionError: If a language code is not encoded as exactly one token.
    """
    # one tokenizer lookup per distinct language code instead of one per line
    lang_token_cache: dict[str, int] = {}

    def lang_token_id(code: str) -> int:
        if code not in lang_token_cache:
            ids = tokenizer.encode(code, add_special_tokens=False)
            assert len(ids) == 1
            lang_token_cache[code] = ids[0]
        return lang_token_cache[code]

    text_batch = tokenizer(
        examples_["source_sentence"], add_special_tokens=False, **tokenize_kwargs
    )
    for line, lang in zip(text_batch["input_ids"], examples_["source_lang"]):
        # source language codes are given without the script suffix
        line.insert(0, lang_token_id(lang + "_Latn"))
        line.append(tokenizer.eos_token_id)

    text_pair_batch = tokenizer(
        examples_["target_sentence"], add_special_tokens=False, **tokenize_kwargs
    )
    # NOTE(review): the target side always uses "spa_Latn"; the values of
    # `target_lang` are not consulted -- confirm this is intended.
    target_token_id = lang_token_id("spa_Latn")
    # zipping with `target_lang` is kept so the number of processed lines
    # stays tied to that column, as before
    for line, _lang in zip(text_pair_batch["input_ids"], examples_["target_lang"]):
        line.insert(0, target_token_id)
        line.append(tokenizer.eos_token_id)

    # attention masks are not stored; the collator creates them when padding
    examples_["input_ids"] = text_batch["input_ids"]
    examples_["decoder_input_ids"] = text_pair_batch["input_ids"]
    return examples_


def test():
    """
    Ad-hoc scratch script for trying out the collators and tokenizers.

    This is not an automated test: it asserts nothing, loads datasets and
    pretrained models/tokenizers by name, reads a PEFT adapter from
    `./test/` and only prints a few results.  Several intermediate values
    are computed and then discarded.
    """
    from transformers import AutoTokenizer
    from datasets import load_dataset

    # --- causal-LM collator on a few documents ---------------------------
    dataset = load_dataset("jbrinkma/pile-10k", split="train")

    lang = "eng_Latn"
    nllb_tokenizer = AutoTokenizer.from_pretrained(
        "facebook/nllb-200-distilled-600M", src_lang=lang
    )
    example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"

    # result is discarded (useful only when run interactively)
    nllb_tokenizer.decode(nllb_tokenizer(example_english_phrase)["input_ids"])
    nllb_tokenizer_kwargs = {
        "max_length": 1024,
        "truncation": True,
        "padding": True,
        "return_tensors": "pt",
    }
    llm_tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Meta-Llama-3-8B", add_eos_token=True
    )
    llm_tokenizer_kwargs = {
        "max_length": 1024,
        "truncation": True,
        "padding": True,
        "return_tensors": "pt",
    }
    # result is discarded
    llm_tokenizer(text=example_english_phrase + "<|end_of_text|>")
    text = [dataset[i] for i in range(16)]
    collator = DataCollatorForCausalLM(
        llm_tokenizer, llm_tokenizer_kwargs, nllb_tokenizer, nllb_tokenizer_kwargs
    )
    batch = collator(text)
    e = batch["eos_token_id"]
    # NOTE: importing torch inside the function makes `torch` a local name
    # for the whole function body
    import torch
    # N = torch.arange(len(e))
    # batch["input_ids"][N, e]

    # --- compare a fresh LoRA model with an adapter loaded from disk ------
    nusax = load_dataset("indonlp/NusaX-senti", "eng")

    from transformers import AutoModelForSequenceClassification
    from peft import get_peft_model, LoraConfig, TaskType
    from peft.peft_model import PeftModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("openai-community/gpt2")
    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules="all-linear",
    )
    peft_model = get_peft_model(model, peft_config)
    # requires an adapter saved under ./test/
    m = PeftModelForSequenceClassification.from_pretrained(model, "./test/")

    # print the names of parameters whose values differ between the two models
    for (n, p), (n_, p_) in zip(m.named_parameters(), peft_model.named_parameters()):
        if not torch.allclose(p.data, p_.data):
            print(n, n_)

    import torch
    from datasets import load_dataset
    import os
    from transformers import AutoTokenizer

    from omegaconf.dictconfig import DictConfig

    # --- run preprocess_adaptation on a FineWeb sample ---------------------
    cwd = os.getcwd()
    fw = load_dataset(
        "HuggingFaceFW/fineweb", name="sample-10BT", split="train", num_proc=16
    )
    llama_tokenizer = AutoTokenizer.from_pretrained(
        "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp"
    )
    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    x = "This is a very long test input\n1 match here\n2 match that"
    choice = "1 match here"
    # Assume I am using pytorch: How do I best get the offset of the first and the last ids in input_ids for choice?
    # both decode results are discarded
    nllb_tokenizer.decode(nllb_tokenizer(x)["input_ids"])

    llama_tokenizer.decode(llama_tokenizer(x)["input_ids"])

    tokenize_kwargs = DictConfig(
        {"truncation": True, "max_length": 512, "return_attention_mask": True}
    )

    dataset = preprocess_adaptation(
        fw,
        llama_tokenizer=llama_tokenizer,
        nllb_tokenizer=nllb_tokenizer,
        tokenize_kwargs=tokenize_kwargs,
    )

    import torch
    from transformers import AutoTokenizer

    # --- locate choice token ids inside input token ids --------------------
    # Initialize the tokenizer (replace 'your-model-name' with the actual model name you're using)

    numbers = ["1", "2", "3", "4"]
    # takes the id at position 1 of each encoding (presumably skipping a
    # leading special token -- TODO confirm)
    number_ids = [nllb_tokenizer(n)["input_ids"][1] for n in numbers]
    # Batch of input texts and choices
    inputs = [
        "This is a very long test input\n1 match here\n2 match that",
        "Another example input\nWith different match here\n2 match that",
    ]
    choices = [["1 match here", "2 match that"], ["different match here"]]

    # Tokenize the batch of input texts
    input_ids_batch = nllb_tokenizer(
        inputs, padding=True, truncation=True, return_tensors="pt"
    )["input_ids"]

    # Function to find start and end indices of a choice in an input
    def find_match_indices(input_ids, choice_ids):
        input_ids_tensor = torch.tensor(input_ids)
        choice_ids_tensor = torch.tensor(choice_ids)

        # slide a window of len(choice_ids) over the input and keep the start
        # positions at which every id in the window matches
        match_start_idx = (
            (input_ids_tensor.unfold(0, len(choice_ids_tensor), 1) == choice_ids_tensor)
            .all(dim=1)
            .nonzero(as_tuple=True)[0]
        )

        if match_start_idx.numel() > 0:
            # `.item()` only works for a single match; more than one match raises
            match_start_idx = match_start_idx.item()
            match_end_idx = match_start_idx + len(choice_ids) - 1
            return match_start_idx, match_end_idx
        else:
            return None, None

    # Iterate over each input and its corresponding choices
    for i, input_text in enumerate(inputs):
        input_ids = nllb_tokenizer(input_text)["input_ids"]
        print(f"Input {i+1}:")
        for choice in choices[i]:
            choice_ids = nllb_tokenizer(choice)["input_ids"]
            start_idx, end_idx = find_match_indices(input_ids, choice_ids)
            if start_idx is not None:
                print(
                    f"  Choice '{choice}' found at: Start index: {start_idx}, End index: {end_idx}"
                )
            else:
                print(f"  Choice '{choice}' not found")
