import gzip
import json
import lzma
import pickle
import random
import sys
from typing import Tuple, List, Callable

import numpy as np
import torch
import transformers
from datasets import Dataset, IterableDataset
from torch.utils.data import Sampler, RandomSampler, BatchSampler, DataLoader, SequentialSampler
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, PreTrainedTokenizerFast

import datasets

import tqdm

from meta_adapters.STEP.grammar_gen import ProductionRule


class AllInOneBatch(Sampler):
    """Batch sampler that yields exactly one batch containing every position of ``source``.

    Iterating it produces a single list ``[0, 1, ..., len(source) - 1]``.
    """

    def __init__(self, source: "Dataset"):
        self.source = source

    def __iter__(self):
        # A single batch with all positions, in order. Using len() avoids decoding every
        # row of the source just to count them.
        yield list(range(len(self.source)))

    def __len__(self):
        # DataLoader.__len__ asks its batch sampler for the number of batches; always one here.
        return 1


class TaskSampler(Sampler):
    """
    Batch sampler whose batches each consist of examples from a single task.

    Every element of ``source`` must support ``el[task_key]`` and ``el["is_train"]``
    (truthy = train split, falsy = test split). Every task must occur in both splits.

    Each yielded batch is a list of positions in ``source``: the positions of one train
    batch, followed by the positions of one test batch of the same task.

    Args:
        source: the examples to group by task.
        across_task_sampler: called with the list of task keys; must yield indices into it,
            which decide the order in which tasks are visited (e.g. ``RandomSampler``).
        within_task_sampler: called with the list of train positions of a task; must yield
            lists of indices into that list.
        within_task_sampler_test: same as ``within_task_sampler`` for the test positions.
        task_key: the field that identifies the task of an example.

    Example:
    TaskSampler(dataset, SequentialSampler,
                lambda s: BatchSampler(SequentialSampler(s), batch_size=2, drop_last=False),
                lambda s: AllInOneBatch(s))

    Raises:
        ValueError: if a task has train examples but no test examples or vice versa, or if a
            within-task sampler yields something other than a list.
    """

    def __init__(self, source, across_task_sampler, within_task_sampler, within_task_sampler_test, task_key="task"):
        self.source = source
        self.task_key = task_key
        self.across_task_sampler = across_task_sampler
        self.within_task_sampler = within_task_sampler
        self.within_task_sampler_test = within_task_sampler_test

        # split name -> task -> positions in ``source`` of that task's examples in that split
        self.task2id = {"train": dict(), "test": dict()}
        for i, el in enumerate(source):
            split = "train" if el["is_train"] else "test"
            self.task2id[split].setdefault(el[self.task_key], []).append(i)

        if self.task2id["train"].keys() != self.task2id["test"].keys():
            raise ValueError("There is a task for which there is a train but not a test split or vice versa")

        self.task_ids = list(self.task2id["train"].keys())

    def __iter__(self):
        for task_index in self.across_task_sampler(self.task_ids):
            task = self.task_ids[task_index]
            train_positions = self.task2id["train"][task]
            test_positions = self.task2id["test"][task]
            # zip stops at the shorter of the two samplers, so a test sampler that yields a
            # single batch limits each task to a single combined batch.
            for idx_train, idx_test in zip(self.within_task_sampler(train_positions),
                                           self.within_task_sampler_test(test_positions)):
                # TODO: we might want to handle train and test differently here...
                if not isinstance(idx_train, list) or not isinstance(idx_test, list):
                    raise ValueError("Expect batches")
                yield [train_positions[i] for i in idx_train] + [test_positions[i] for i in idx_test]

    def __len__(self):
        # Number of tasks; this equals the number of yielded batches only when every task
        # produces exactly one batch (e.g. when the test sampler is AllInOneBatch).
        return len(self.task_ids)


def load_tsv(fname, expect_first_line = None, lenient: bool = False):
    """Yield the non-empty lines of a tab-separated file as lists of fields.

    Only the line terminator is removed from each line, so empty leading/trailing fields
    are preserved (``"a\\t"`` -> ``["a", ""]``).

    Args:
        fname: path of the file to read.
        expect_first_line: if given, the expected header line (compared after stripping
            surrounding whitespace). A matching header is skipped.
        lenient: if True and the first line is not the expected header, treat it as a
            data line instead of raising.

    Raises:
        ValueError: if ``expect_first_line`` is given, ``lenient`` is False and the first line
            (or an empty file) does not match it.
    """
    with open(fname) as f:
        if expect_first_line is not None:
            # readline() returns "" for an empty file instead of raising StopIteration
            # (which would surface as RuntimeError inside a generator).
            first_line = f.readline()
            if first_line.strip() != expect_first_line:
                if not lenient:
                    raise ValueError(f"First line must be: '{expect_first_line}'")
                # Not a header: parse it exactly like every other data line.
                line = first_line.rstrip("\r\n")
                if line:
                    yield line.split("\t")
        for line in f:
            line = line.rstrip("\r\n")
            if line:
                yield line.split("\t")

def prepare_meta_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int) -> DataLoader:
    """Load a meta-learning TSV into a DataLoader that yields task-specific batches.

    The file must start with the header ``input<TAB>output<TAB>task<TAB>is_train``; the last
    column is parsed with ``int`` and a non-zero value marks a training example. Tasks are
    visited in random order; each batch holds up to ``batch_size`` randomly chosen training
    examples of one task followed by all of that task's test examples. The ``is_train``
    column is kept in the tokenized examples.
    """
    text_columns = ["input", "output", "task"]

    def tokenize_example(example):
        encoded = tokenizer(example["input"])
        if "output" in example:
            encoded["labels"] = tokenizer(text_target=example["output"])["input_ids"]
        return encoded

    columns = {"is_train": [], "input": [], "output": [], "task": []}
    for row in load_tsv(path, "input\toutput\ttask\tis_train"):
        for value, column in zip(row, text_columns):
            columns[column].append(value)
        columns["is_train"].append(int(row[-1]))
    raw = Dataset.from_dict(columns)

    def train_batches(positions):
        return BatchSampler(RandomSampler(positions), batch_size=batch_size, drop_last=False)

    # The sampler only needs task/is_train, so it is built from the untokenized dataset;
    # positions are unchanged by the map below.
    sampler = TaskSampler(raw, RandomSampler, train_batches, AllInOneBatch)
    tokenized = raw.map(tokenize_example, batched=False, remove_columns=["input", "output", "task"])
    return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=sampler)


def weighted_dict_choice(remaining):
    """Return a random key of ``remaining``, chosen with probability proportional to its value.

    Keys whose weight is zero (or negative) are never returned. Returns ``None`` when no
    key has a positive weight (including when ``remaining`` is empty).
    """
    candidates = [(key, weight) for key, weight in remaining.items() if weight > 0]
    if not candidates:
        return None
    cutoff = random.random() * sum(weight for _, weight in candidates)
    cumulative = 0
    for key, weight in candidates:
        cumulative += weight
        # Strict ">" so that a cutoff of exactly 0.0 selects the first positive-weight key.
        if cumulative > cutoff:
            return key
    # Floating-point rounding can leave the running sum marginally below the cutoff.
    return candidates[-1][0]


class BatchSamplerWithSameLength(Sampler):
    def __init__(self, data, batch_size, key="input_ids") -> None:
        super().__init__()
        self.data = data
        self.l2indices = dict()
        self.batch_size = batch_size
        for i, dp in enumerate(data):
            length = len(dp[key])
            if length not in self.l2indices:
                self.l2indices[length] = []
            self.l2indices[length].append(i)

    def __len__(self) -> int:
        return sum((len(bin_) + self.batch_size - 1) // self.batch_size for bin_ in self.l2indices)

    def __iter__(self):
        for length, indices in self.l2indices.items():
            random.shuffle(indices)
        position = {length: 0 for length in self.l2indices}
        remaining = {length: len(self.l2indices[length]) for length in self.l2indices}

        while sum(remaining.values()) > 0:
            # print("Remaining", remaining)
            length = weighted_dict_choice(remaining)
            # print("Chosen", length)
            new_position = min(position[length] + self.batch_size, len(self.l2indices[length]))
            yield self.l2indices[length][position[length]:new_position]
            remaining[length] = remaining[length] - (new_position - position[length])


def prepare_task_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int, random_order: bool = True, lenient: bool=False,
                         same_length_batches: bool = False, force_output_right_padding: bool = False, prompt:str = "") -> DataLoader:
    """Load an ``input<TAB>output`` TSV file into a seq2seq DataLoader.

    Args:
        path: TSV file whose first line must be the header ``input<TAB>output``
            (with ``lenient=True`` a non-matching first line is treated as data).
        tokenizer: tokenizer used for inputs and outputs.
        batch_size: maximum number of examples per batch.
        random_order: shuffle examples each epoch instead of reading them in file order.
        lenient: forwarded to ``load_tsv``.
        same_length_batches: batch with ``BatchSamplerWithSameLength`` so every batch holds
            inputs of a single tokenized length (that sampler's own shuffling is then used
            and ``random_order`` has no effect).
        force_output_right_padding: pad ``labels`` on the right even if the tokenizer pads
            on the left; inputs keep the tokenizer's padding side.
        prompt: string prepended to every input before tokenization.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches. Every label sequence ends
        with ``tokenizer.eos_token_id``.
    """
    def mapper(examples):
        d = tokenizer(prompt + examples["input"])
        if "output" in examples:
            labels = tokenizer(text_target=examples["output"])["input_ids"]
            if not labels or labels[-1] != tokenizer.eos_token_id:
                # gpt neo tokenizer doesn't add EOS token, so do this explicitly.
                labels.append(tokenizer.eos_token_id)
            d["labels"] = labels
        return d

    keys = ["input", "output"]
    d = {k: [] for k in keys}
    for row in load_tsv(path, "input\toutput", lenient=lenient):
        for x, k in zip(row, keys):
            d[k].append(x)
    dataset = Dataset.from_dict(d)

    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])
    if same_length_batches:
        ts = BatchSamplerWithSameLength(dataset, batch_size=batch_size)

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)

    def right_pad_labels_collator(features):
        # DataCollatorForSeq2Seq pads labels on tokenizer.padding_side at collate time.
        # Pre-padding them on the right to a common length leaves it nothing to pad,
        # so the labels stay right-padded while inputs follow the tokenizer setting.
        if "labels" in features[0]:
            target_len = max(len(f["labels"]) for f in features)
            for f in features:
                missing = target_len - len(f["labels"])
                f["labels"] = list(f["labels"]) + [seq2seq_collator.label_pad_token_id] * missing
        return seq2seq_collator(features)

    collate_fn = right_pad_labels_collator if force_output_right_padding else seq2seq_collator
    return DataLoader(dataset, collate_fn=collate_fn, batch_sampler=ts)



def prepare_emphasis_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int, random_order: bool = True, lenient: bool=True) -> DataLoader:
    """Load an ``input<TAB>output`` TSV whose inputs carry an emphasis annotation.

    Each input cell has the form ``<text>;<emphasized text>`` (exactly one ``;``). An input
    token is marked as emphasized when its id also occurs in the tokenization of the
    emphasized text; BOS/EOS tokens are never emphasized.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches plus ``emph_mask``, a bool
        tensor of shape ``(batch, padded_input_length)`` aligned with ``input_ids`` (padded
        on the same side as the tokenizer pads). Labels always end with the EOS id.
    """
    def mapper(examples):
        inp, emph = examples["input"].split(";")
        inp = inp.strip()
        emph = emph.strip()
        d = tokenizer(inp)
        # Set of emphasized token ids; BOS/EOS are excluded so they are never marked.
        emph_tokens = set(tokenizer(emph)["input_ids"]) - {tokenizer.bos_token_id, tokenizer.eos_token_id}
        d["emph"] = [tok in emph_tokens for tok in d["input_ids"]]

        if "output" in examples:
            labels = tokenizer(text_target=examples["output"])["input_ids"]
            if not labels or labels[-1] != tokenizer.eos_token_id:
                # gpt neo tokenizer doesn't add EOS token, so do this explicitly.
                labels.append(tokenizer.eos_token_id)
            d["labels"] = labels
        return d

    keys = ["input", "output"]
    d = {k: [] for k in keys}
    for row in load_tsv(path, "input\toutput", lenient=lenient):
        for x, k in zip(row, keys):
            d[k].append(x)
    dataset = Dataset.from_dict(d)

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        emph = [f.pop("emph") for f in features]
        d = seq2seq_collator(features)
        seq_len = d["input_ids"].shape[1]
        emph_mask = torch.zeros((len(emph), seq_len), dtype=torch.bool)  # shape (batch, seq_len)
        for row, flags in enumerate(emph):
            if not flags:
                continue
            flags = torch.tensor(flags, dtype=torch.bool)
            # Place the flags where the tokens ended up after padding, so the mask stays
            # aligned with input_ids for left-padding tokenizers too.
            if tokenizer.padding_side == "left":
                emph_mask[row, seq_len - len(flags):] = flags
            else:
                emph_mask[row, :len(flags)] = flags
        d["emph_mask"] = emph_mask
        return d

    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])
    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)



def load_cogs_dataset(path:str, tokenizer: AutoTokenizer, batch_size: int, random_order: bool = True, lenient: bool=False,
                      same_length_batches: bool = False, prompt: str = "") -> DataLoader:
    """Load a headerless COGS-style TSV into a seq2seq DataLoader.

    Rows are ``input<TAB>output`` with an optional third ``gen_type`` column. Either all
    rows or no rows may have the ``gen_type`` column; when present, each batch carries a
    ``gen_type`` list of strings.

    Args:
        path: TSV file without a header line.
        tokenizer: tokenizer used for inputs and outputs.
        batch_size: maximum number of examples per batch.
        random_order: shuffle examples each epoch instead of reading them in file order.
        lenient: accepted for signature compatibility; it is not used by this loader.
        same_length_batches: batch with ``BatchSamplerWithSameLength`` so every batch holds
            inputs of a single tokenized length.
        prompt: string prepended to every input before tokenization.

    Raises:
        ValueError: if only some rows have a ``gen_type`` column.
    """
    def mapper(examples):
        d = tokenizer(prompt + examples["input"])
        if "output" in examples:
            labels = tokenizer(text_target=examples["output"])["input_ids"]
            if not labels or labels[-1] != tokenizer.eos_token_id:
                # gpt neo tokenizer doesn't add EOS token, so do this explicitly.
                labels.append(tokenizer.eos_token_id)
            d["labels"] = labels
        return d

    keys = ["input", "output", "gen_type"]
    d = {k: [] for k in keys}
    for row in load_tsv(path):
        for x, k in zip(row, keys):
            d[k].append(x)
    if not d["gen_type"]:
        # Two-column file: drop the empty column, otherwise Dataset.from_dict rejects the
        # mismatched column lengths.
        del d["gen_type"]
    elif len(d["gen_type"]) != len(d["input"]):
        raise ValueError("Either all or no rows should have a gen_type column.")
    dataset = Dataset.from_dict(d)

    collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # gen_type is a string, which the seq2seq collator cannot tensorize; pass it through.
        gen_types = []
        for x in features:
            if "gen_type" in x:
                gen_types.append(x.pop("gen_type"))
        d = collator(features)
        if len(gen_types) > 0:
            assert len(gen_types) == len(features), "Either all or no sentences should have a gen_type."
            d["gen_type"] = gen_types
        return d

    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])
    if same_length_batches:
        ts = BatchSamplerWithSameLength(dataset, batch_size=batch_size)
    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)



def prepare_task_dataset_jsonl(path:str, tokenizer: AutoTokenizer, batch_size: int, random_order: bool = True) -> DataLoader:
    """Load a JSONL file of ``{"input": ..., "output": ...}`` objects into a seq2seq DataLoader.

    Each of ``input``/``output`` may be:
      * a string, which is tokenized normally;
      * a list of token strings ("pre-tokenized"), which are mapped to ids and wrapped with
        the tokenizer's special tokens;
      * a list of token ids, used as-is.
    Blank lines in the file are skipped.
    """
    def mapper(examples):
        if isinstance(examples["input"], str):
            d = tokenizer(examples["input"])
        elif isinstance(examples["input"][0], str):
            # pre-tokenized but not mapped to ints yet
            i = tokenizer.convert_tokens_to_ids(examples["input"])
            assert len(i) == len(examples["input"])
            i = tokenizer.build_inputs_with_special_tokens(i)
            d = {"input_ids": i,
                 "attention_mask": [1] * len(i)}
        else:
            # already token ids
            d = {"input_ids": examples["input"], "attention_mask": [1] * len(examples["input"])}
        if "output" in examples:
            if isinstance(examples["output"], str):
                d["labels"] = tokenizer(text_target=examples["output"])["input_ids"]
            elif isinstance(examples["output"][0], str):
                # pre-tokenized
                o = tokenizer.convert_tokens_to_ids(examples["output"])
                assert len(o) == len(examples["output"])
                d["labels"] = tokenizer.build_inputs_with_special_tokens(o)
            else:
                d["labels"] = examples["output"]
        return d

    keys = ["input", "output"]
    d = {k: [] for k in keys}
    with open(path) as f:
        for row in f:
            if not row.strip():
                # Tolerate empty lines (e.g. an extra newline at the end of the file),
                # which json.loads would reject.
                continue
            j = json.loads(row)
            for k in keys:
                d[k].append(j[k])

    dataset = Dataset.from_dict(d)

    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])
    return DataLoader(dataset, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=ts)


def prepare_multi_meta(path:str, tokenizer: AutoTokenizer, train_batch_size: int, test_batch_size: int) -> List[Tuple[DataLoader, DataLoader]]:
    """Build one ``(train_loader, test_loader)`` pair per task from a meta-learning TSV.

    The file must start with the header ``input<TAB>output<TAB>task<TAB>is_train``; rows whose
    ``is_train`` parses to a non-zero int go to the task's train split, the rest to its test
    split. Both loaders of a task shuffle their examples. Disables the ``datasets`` progress
    bar globally as a side effect.

    Raises:
        ValueError: if some task has no train rows or no test rows.
    """
    datasets.disable_progress_bar()

    def tokenize_batch(batch):
        targets = batch["output"] if "output" in batch else None
        return tokenizer(batch["input"], text_target=targets)

    task2data = dict()
    for source_text, target_text, task, is_train in load_tsv(path, "input\toutput\ttask\tis_train"):
        if task not in task2data:
            task2data[task] = {"train": {"input": [], "output": []}, "test": {"input": [], "output": []}}
        split = task2data[task]["train" if int(is_train) else "test"]
        split["input"].append(source_text)
        split["output"].append(target_text)

    dataloaders = []
    for task in tqdm.tqdm(task2data):
        train_data = Dataset.from_dict(task2data[task]["train"])
        test_data = Dataset.from_dict(task2data[task]["test"])
        if len(train_data) == 0:
            raise ValueError(f"Task {task} has no train data")
        if len(test_data) == 0:
            raise ValueError(f"Task {task} has no test data")

        loaders = []
        for split_data, split_batch_size in ((train_data, train_batch_size), (test_data, test_batch_size)):
            tokenized = split_data.map(tokenize_batch, batched=True, remove_columns=["input", "output"])
            batches = BatchSampler(RandomSampler(tokenized), batch_size=split_batch_size, drop_last=False)
            loaders.append(DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches))
        dataloaders.append(tuple(loaders))

    return dataloaders


def fst_to_vector(fst_tokenizer, num_states, fst: List[Tuple[int, str, str, int]]) -> np.array:
    """Encode an FST's transitions as an int64 matrix with one row per transition.

    Each transition is ``(source_state, input_symbol, output_symbol, target_state)`` with
    an optional fifth element (a final-state indicator). Columns are: source state, id of
    the input symbol, id of the output symbol, target state and, for 5-tuples, the fifth
    element. Every symbol must tokenize to exactly one id, and states must be smaller than
    ``num_states - 1`` because the last state index is reserved for padding.
    """
    width = len(fst[0])
    assert width == 4 or width == 5

    def single_id(symbol):
        ids = fst_tokenizer(symbol)["input_ids"]
        assert len(ids) == 1
        return ids[0]

    encoded = np.zeros((len(fst), width), dtype=np.int64)
    for row, transition in enumerate(fst):
        source, in_symbol, out_symbol, target = transition[:4]
        assert source < num_states-1 #last state is reserved for padding
        assert target < num_states-1
        encoded[row, 0] = source
        encoded[row, 1] = single_id(in_symbol)
        encoded[row, 2] = single_id(out_symbol)
        encoded[row, 3] = target
        if len(transition) == 5:
            # final state indicator
            encoded[row, 4] = transition[4]
    return encoded


def batch_fsts(fst_reps: List[np.array], num_states, max_len=None) -> np.array:
    """Stack per-example FST encodings into one padded int64 array.

    The result has shape ``(len(fst_reps), rows, columns)`` where ``rows`` is ``max_len`` if
    given (longer encodings are truncated) and otherwise the largest number of transitions
    in the batch; ``columns`` is taken from the first transition of the first encoding.
    Padding rows have both state columns (0 and 3) set to ``num_states - 1`` and zeros elsewhere.
    """
    rows = max_len if max_len is not None else max(len(rep) for rep in fst_reps)
    columns = len(fst_reps[0][0])
    batched = np.zeros((len(fst_reps), rows, columns), dtype=np.int64)
    # Set states to a padding index (last state)
    padding_state = num_states - 1
    batched[:, :, 0] = padding_state
    batched[:, :, 3] = padding_state
    for example, rep in enumerate(fst_reps):
        for row, transition in enumerate(rep[:rows]):
            batched[example, row] = transition
    return batched


def load_fst_jsonl(path: str, tokenizer: AutoTokenizer, fst_tokenizer_path: str, batch_size:int, num_states: int, random_order: bool = True,
                   max_len: int = None, max_n:int=None, map_f = None, filter_f = None):
    """Load a JSONL file of examples paired with finite-state transducers into a DataLoader.

    Each non-blank line is a JSON object with ``"input"``, ``"output"``, ``"FST"`` and
    optionally ``"task_id"``. The FST (after ``map_f``) is encoded with ``fst_to_vector``
    using a fast tokenizer loaded from ``fst_tokenizer_path``.

    Args:
        max_len: number of transition rows every batch's ``fst_rep`` is padded/truncated to;
            ``None`` pads to the largest FST in the batch.
        max_n: if given, load at most this many examples (counted after filtering).
        map_f: optional transformation applied to each raw ``"FST"`` value before encoding.
        filter_f: optional predicate on the parsed JSON object; objects for which it is
            falsy are skipped.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches plus ``fst_rep`` (int64
        tensor from ``batch_fsts``) and, when the file provides task ids, ``task_ids``.

    Raises:
        ValueError: if only some of the loaded examples have a ``"task_id"``.
    """
    fst_tokenizer = PreTrainedTokenizerFast(tokenizer_file=fst_tokenizer_path)

    def mapper(examples):
        d = tokenizer(examples["input"])
        if "output" in examples:
            d["labels"] = tokenizer(text_target=examples["output"])["input_ids"]
        return d

    if map_f is None:
        map_f = lambda x: x

    data = {"input": [], "output": [], "fst_rep": [], "task_ids": []}
    with open(path) as f:
        for line in f:
            # Check the limit before adding, so exactly max_n examples are kept.
            if max_n is not None and len(data["input"]) >= max_n:
                break
            if not line.strip():
                continue
            d = json.loads(line)
            if filter_f is not None and not filter_f(d):
                continue
            data["input"].append(d["input"])
            data["output"].append(d["output"])
            if "task_id" in d:
                data["task_ids"].append(d["task_id"])
            data["fst_rep"].append(fst_to_vector(fst_tokenizer, num_states, map_f(d["FST"])))

    if len(data["task_ids"]) == 0:
        del data["task_ids"]
    elif len(data["task_ids"]) != len(data["input"]):
        raise ValueError("Either all or no examples must have a task_id.")

    dataset = Dataset.from_dict(data)
    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # Take out the non-token fields before the seq2seq collator pads the rest.
        fst_reps = [x.pop("fst_rep") for x in features]
        task_ids = [x.pop("task_ids") for x in features] if "task_ids" in features[0] else None
        d = seq2seq_collator(features)
        d["fst_rep"] = torch.from_numpy(batch_fsts(fst_reps, num_states, max_len=max_len))
        if task_ids is not None:
            d["task_ids"] = torch.from_numpy(np.array(task_ids))
        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)


def encode_grammar(rules: "list[ProductionRule]", max_rhs_length, tokenizer, epsilon_id) -> tuple[np.array, np.array]:
    """Encode production rules as an int64 matrix plus a non-terminal mask.

    Row ``i`` describes ``rules[i]`` with ``3 + max_rhs_length`` columns:
      * column 0: ``rule.fint``
      * column 1: ``rule.lhs`` (always flagged as a non-terminal)
      * column 2: token id of ``rule.map_term`` (``epsilon_id`` for the empty string)
      * columns 3..: the rhs symbols; string symbols are terminals mapped to token ids
        (``""`` -> ``epsilon_id``), any other symbol is stored as-is and flagged as a
        non-terminal. Unused rhs slots stay 0 and unflagged.

    Returns:
        ``(encoding, nt_mask)`` where ``nt_mask`` is a bool array of the same shape that is
        True for non-terminal entries.

    Raises:
        ValueError: if a rule has more than ``max_rhs_length`` rhs symbols, or if a terminal
            does not map to a single integer id via ``tokenizer.convert_tokens_to_ids``.
    """
    def terminal_id(symbol):
        if symbol == "":
            return epsilon_id
        token_id = tokenizer.convert_tokens_to_ids(symbol)
        if not isinstance(token_id, int):
            raise ValueError(f"Terminal {symbol!r} does not map to a single token id (got {token_id!r})")
        return token_id

    width = 3 + max_rhs_length
    d = np.zeros((len(rules), width), dtype=np.int64)
    nt_mask = np.zeros((len(rules), width), dtype=bool)
    nt_mask[:, 1] = True  # the lhs is always a non-terminal
    for i, rule in enumerate(rules):
        if len(rule.rhs) > max_rhs_length:
            raise ValueError(f"Rule {i} has {len(rule.rhs)} rhs symbols, more than max_rhs_length={max_rhs_length}")
        d[i, 0] = rule.fint
        d[i, 1] = rule.lhs
        d[i, 2] = terminal_id(rule.map_term)
        for j, symbol in enumerate(rule.rhs, start=3):
            if isinstance(symbol, str):
                d[i, j] = terminal_id(symbol)
                nt_mask[i, j] = False
            else:
                d[i, j] = symbol
                nt_mask[i, j] = True
    return d, nt_mask


def create_filter_grammar_has_any_anchor(anchors: List[str]) -> Callable[..., bool]:
    """Return a predicate that accepts a data point iff its grammar uses one of ``anchors``.

    The predicate takes a dict with a ``"grammar"`` entry (a list of rules) and returns True
    as soon as any symbol on any rule's right-hand side is in ``anchors``.
    """
    anchor_set = set(anchors)

    def has_anchor(dp):
        return any(symbol in anchor_set for rule in dp["grammar"] for symbol in rule.rhs)

    return has_anchor


def load_grammar_pickle(path: str, tokenizer: transformers.AutoTokenizer, batch_size:int, epsilon_str: str, random_order: bool = True,
                    max_rhs_length: int = 3,
                   max_len: int = None, max_n:int=None,
                        filter_f = None,
                        grammar_map_f = None, output_replace_eos: tuple[str, str] = None, reverse_file: bool = False):
    """Load (input, output) pairs together with their encoded grammar from an lzma pickle.

    The pickle holds a list of data points; each is a dict with ``"grammar"`` (a list of
    production rules), ``"task_id"`` and ``"data"`` (an iterable of ``(input, output)``
    pairs). Every pair becomes one example carrying its data point's grammar encoding
    (see ``encode_grammar``).

    NOTE(security): ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        epsilon_str: token whose id encodes the empty string in the grammar.
        max_rhs_length: maximum number of right-hand-side symbols per rule.
        max_len: number of rule rows each batch is padded/truncated to; ``None`` uses the
            largest grammar in the batch.
        max_n: if given, load at most this many examples.
        filter_f: optional predicate on a raw data point; falsy results are skipped.
        grammar_map_f: optional function applied to a data point's rules before encoding.
        output_replace_eos: optional ``(old, new)`` pair replaced in string outputs before
            tokenization.
        reverse_file: process the data points in reverse order.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches plus ``cfg_rep`` (int64),
        ``nt_mask`` (bool; padding rows are True) and ``task_ids``.
    """
    def mapper(examples):

        if isinstance(examples["input"], str):
            d = tokenizer(examples["input"])
            if "output" in examples:
                if output_replace_eos is not None:
                    text = examples["output"].replace(output_replace_eos[0], output_replace_eos[1])
                else:
                    text = examples["output"]
                d["labels"] = tokenizer(text_target=text)["input_ids"]
        else:
            # Already "pre-tokenized", i.e. converted into a list and every element is a string
            # that can be mapped to an id
            assert isinstance(examples["input"][0], str)
            i = tokenizer.convert_tokens_to_ids(examples["input"])
            assert len(i) == len(examples["input"])
            i = tokenizer.build_inputs_with_special_tokens(i)
            d = {"input_ids": i,
                 "attention_mask": [1] * len(i)}
            if "output" in examples:
                o = tokenizer.convert_tokens_to_ids(examples["output"])
                assert len(o) == len(examples["output"])
                d["labels"] = tokenizer.build_inputs_with_special_tokens(o)

        return d

    epsilon_id = tokenizer.convert_tokens_to_ids(epsilon_str)

    data = {"input": [], "output": [], "cfg_rep": [], "nt_mask": [], "task_ids": []}

    def limit_reached():
        # Checked before adding an example, so exactly max_n examples are kept.
        return max_n is not None and len(data["input"]) >= max_n

    with lzma.open(path, "rb") as f:
        load_data = pickle.load(f)
    if reverse_file:
        load_data.reverse()
    for dp in load_data:
        if limit_reached():
            break
        if filter_f is not None and not filter_f(dp):
            continue
        rules: list[ProductionRule] = dp["grammar"]
        if grammar_map_f is not None:
            rules = grammar_map_f(rules)

        # The encoding depends only on the grammar, so compute it once per data point
        # (lazily, so data points without pairs are never encoded).
        encoded = None
        for inp, out in dp["data"]:
            if limit_reached():
                break
            if encoded is None:
                encoded = encode_grammar(rules, max_rhs_length, tokenizer, epsilon_id)
            data["task_ids"].append(dp["task_id"])
            data["input"].append(inp)
            data["output"].append(out)
            data["cfg_rep"].append(encoded[0])
            data["nt_mask"].append(encoded[1])

    dataset = Dataset.from_dict(data)
    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # Take out the non-token fields before the seq2seq collator pads the rest.
        cfg_reps = [x.pop("cfg_rep") for x in features]
        nt_masks = [x.pop("nt_mask") for x in features]
        task_ids = [x.pop("task_ids") for x in features] if "task_ids" in features[0] else None
        d = seq2seq_collator(features)

        if max_len is not None:
            max_rules = max_len
        else:
            max_rules = max(len(el) for el in cfg_reps)

        cfg_rep = np.zeros((len(cfg_reps), max_rules, 3 + max_rhs_length), dtype=np.int64)
        # Padding rows are flagged as non-terminals (True).
        nt_mask = np.ones((len(cfg_reps), max_rules, 3 + max_rhs_length), dtype=bool)
        for i, (rep, mask) in enumerate(zip(cfg_reps, nt_masks)):
            for j in range(min(max_rules, len(rep))):
                cfg_rep[i, j, :] = rep[j]
                nt_mask[i, j, :] = mask[j]

        d["cfg_rep"] = torch.from_numpy(cfg_rep)
        d["nt_mask"] = torch.from_numpy(nt_mask)

        if task_ids is not None:
            d["task_ids"] = torch.from_numpy(np.array(task_ids))

        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)




def load_ud_grammar_pickle(path: str, tokenizer: transformers.AutoTokenizer, batch_size:int, random_order: bool = True,
                        max_n:int=None,
                        filter_f = None,
                        grammar_map_f = None, output_replace_eos: tuple[str, str] = None, reverse_file: bool = False):
    """Load (input, output) pairs with per-rule labels from an lzma pickle.

    The pickle holds a list of data points; each is a dict with ``"grammar"`` (a list of
    rules), ``"task_id"`` and ``"data"`` (an iterable of ``(input, output)`` pairs). Every
    pair becomes one example with ``ud_labels = [rule.lhs ...]`` and
    ``function_ids = [rule.fint ...]`` of its data point's (possibly mapped) rules.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        max_n: if given, load at most this many examples.
        filter_f: optional predicate on a raw data point; falsy results are skipped.
        grammar_map_f: optional function applied to a data point's rules.
        output_replace_eos: optional ``(old, new)`` pair replaced in string outputs before
            tokenization.
        reverse_file: process the data points in reverse order.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches plus int64 tensors
        ``ud_labels`` (padded with 0), ``function_ids`` (padded with 1) and ``task_ids``.
    """
    def mapper(examples):

        if isinstance(examples["input"], str):
            d = tokenizer(examples["input"])
            if "output" in examples:
                if output_replace_eos is not None:
                    text = examples["output"].replace(output_replace_eos[0], output_replace_eos[1])
                else:
                    text = examples["output"]
                d["labels"] = tokenizer(text_target=text)["input_ids"]
        else:
            # Already "pre-tokenized", i.e. converted into a list and every element is a string
            # that can be mapped to an id
            assert isinstance(examples["input"][0], str)
            i = tokenizer.convert_tokens_to_ids(examples["input"])
            assert len(i) == len(examples["input"])
            i = tokenizer.build_inputs_with_special_tokens(i)
            d = {"input_ids": i,
                 "attention_mask": [1] * len(i)}
            if "output" in examples:
                o = tokenizer.convert_tokens_to_ids(examples["output"])
                assert len(o) == len(examples["output"])
                d["labels"] = tokenizer.build_inputs_with_special_tokens(o)

        return d

    data = {"input": [], "output": [], "ud_labels": [], "function_ids": [], "task_ids": []}

    def limit_reached():
        # Checked before adding an example, so exactly max_n examples are kept.
        return max_n is not None and len(data["input"]) >= max_n

    with lzma.open(path, "rb") as f:
        load_data = pickle.load(f)
    if reverse_file:
        load_data.reverse()
    for dp in load_data:
        if limit_reached():
            break
        if filter_f is not None and not filter_f(dp):
            continue
        rules: list[ProductionRule] = dp["grammar"]
        if grammar_map_f is not None:
            rules = grammar_map_f(rules)

        for inp, out in dp["data"]:
            if limit_reached():
                break
            data["task_ids"].append(dp["task_id"])
            data["input"].append(inp)
            data["output"].append(out)
            data["ud_labels"].append([rule.lhs for rule in rules])
            data["function_ids"].append([rule.fint for rule in rules])

    dataset = Dataset.from_dict(data)
    if random_order:
        sampler = RandomSampler(dataset)
    else:
        sampler = SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # Take out the non-token fields before the seq2seq collator pads the rest.
        ud_labels = [x.pop("ud_labels") for x in features]
        function_ids = [x.pop("function_ids") for x in features]
        task_ids = [x.pop("task_ids") for x in features] if "task_ids" in features[0] else None

        d = seq2seq_collator(features)

        max_rules = max(len(el) for el in ud_labels)

        ud_labels_tensor = np.zeros((len(ud_labels), max_rules), dtype=np.int64)
        function_ids_tensor = np.ones((len(function_ids), max_rules), dtype=np.int64)
        for i in range(len(ud_labels)):
            ud_labels_tensor[i, : len(ud_labels[i])] = ud_labels[i]
            function_ids_tensor[i, : len(function_ids[i])] = function_ids[i]

        d["ud_labels"] = torch.from_numpy(ud_labels_tensor)
        d["function_ids"] = torch.from_numpy(function_ids_tensor)

        if task_ids is not None:
            d["task_ids"] = torch.from_numpy(np.array(task_ids))

        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)




def load_ud_grammar_pickle_json(path: str, tokenizer: transformers.AutoTokenizer, batch_size:int, random_order: bool = True,
                        max_n:int=None, output_replace_eos: tuple[str, str] = None):
    """Stream examples with per-rule labels from a gzip-compressed JSONL file.

    Each line is a UTF-8 JSON object with ``"input"``, ``"output"``, ``"ud_labels"``,
    ``"function_ids"`` (same length as ``ud_labels``) and ``"task_id"``. Data is read lazily
    through an ``IterableDataset``; with ``random_order`` it is shuffled with a
    10,000-example buffer.

    Args:
        max_n: if given, read at most this many lines.
        output_replace_eos: optional ``(old, new)`` pair replaced in string outputs before
            tokenization.

    Returns:
        A DataLoader yielding ``DataCollatorForSeq2Seq`` batches plus int64 tensors
        ``ud_labels`` (padded with 0), ``function_ids`` (padded with 1) and ``task_ids``.

    Raises:
        ValueError: (while iterating) if a line's ``function_ids`` and ``ud_labels`` differ in length.
    """
    def mapper(examples):

        if isinstance(examples["input"], str):
            d = tokenizer(examples["input"])
            if "output" in examples:
                if output_replace_eos is not None:
                    text = examples["output"].replace(output_replace_eos[0], output_replace_eos[1])
                else:
                    text = examples["output"]
                d["labels"] = tokenizer(text_target=text)["input_ids"]
        else:
            # Already "pre-tokenized", i.e. converted into a list and every element is a string
            # that can be mapped to an id
            assert isinstance(examples["input"][0], str)
            i = tokenizer.convert_tokens_to_ids(examples["input"])
            assert len(i) == len(examples["input"])
            i = tokenizer.build_inputs_with_special_tokens(i)
            d = {"input_ids": i,
                 "attention_mask": [1] * len(i)}
            if "output" in examples:
                o = tokenizer.convert_tokens_to_ids(examples["output"])
                assert len(o) == len(examples["output"])
                d["labels"] = tokenizer.build_inputs_with_special_tokens(o)

        return d

    def generator():
        # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
        with gzip.open(path, "rt", encoding="utf-8") as file_obj:
            for i, line in enumerate(file_obj):
                # Checked before yielding, so at most max_n examples are produced.
                if max_n is not None and i >= max_n:
                    break
                dp = json.loads(line)
                if len(dp["function_ids"]) != len(dp["ud_labels"]):
                    raise ValueError(f"Line {i}: function_ids and ud_labels differ in length")
                dp["task_ids"] = dp.pop("task_id")
                yield dp

    dataset = IterableDataset.from_generator(generator)

    if random_order:
        dataset = dataset.shuffle(buffer_size=10_000)
    dataset = dataset.map(mapper, batched=False, remove_columns=["input", "output"])

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)
    def collator_fn(features):
        # Take out the non-token fields before the seq2seq collator pads the rest.
        ud_labels = [x.pop("ud_labels") for x in features]
        function_ids = [x.pop("function_ids") for x in features]
        task_ids = [x.pop("task_ids") for x in features] if "task_ids" in features[0] else None

        d = seq2seq_collator(features)

        max_rules = max(len(el) for el in ud_labels)

        ud_labels_tensor = np.zeros((len(ud_labels), max_rules), dtype=np.int64)
        function_ids_tensor = np.ones((len(function_ids), max_rules), dtype=np.int64)
        for i in range(len(ud_labels)):
            ud_labels_tensor[i, : len(ud_labels[i])] = ud_labels[i]
            function_ids_tensor[i, : len(function_ids[i])] = function_ids[i]

        d["ud_labels"] = torch.from_numpy(ud_labels_tensor)
        d["function_ids"] = torch.from_numpy(function_ids_tensor)

        if task_ids is not None:
            d["task_ids"] = torch.from_numpy(np.array(task_ids))

        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_size=batch_size)


def load_grammar_pickle_with_trees(path: str, tokenizer: transformers.AutoTokenizer, batch_size: int,
                                   epsilon_str: str, random_order: bool = True,
                                   max_rhs_length: int = 3,
                                   max_len: int = None, max_n: int = None,
                                   replacement_sep_token: str = None):
    """Build a DataLoader over an lzma-compressed pickle of grammar tasks with tree matrices.

    The unpickled object is iterated as a sequence of task records. Each record
    must provide ``"task_id"``, ``"grammar"`` (a list of ``ProductionRule``) and
    ``"data"``, an iterable of ``(input_ids, labels, tree_matrix)`` triples.

    :param path: path to the lzma-compressed pickle file.
    :param tokenizer: tokenizer used to look up special-token ids, to encode the
        grammar (via ``encode_grammar``) and to pad token sequences.
    :param batch_size: number of examples per batch (the last batch may be smaller).
    :param epsilon_str: token whose id is passed to ``encode_grammar``
        (presumably marking empty productions — confirm in ``encode_grammar``).
    :param random_order: sample examples randomly if True, sequentially otherwise.
    :param max_rhs_length: passed to ``encode_grammar``; every encoded rule has
        ``3 + max_rhs_length`` entries.
    :param max_len: if given, the grammar encoding of every batch is padded or
        truncated to exactly this many rules; otherwise it is padded to the
        largest grammar in the batch.
    :param max_n: if given, load at most this many examples in total (across
        all tasks).
    :param replacement_sep_token: if given, each tree matrix is placed starting
        right after the first occurrence of this token in ``input_ids``;
        otherwise it is placed at position 0.
    :return: a ``DataLoader`` whose batches contain the seq2seq-collated token
        fields plus ``cfg_rep`` (int64, ``(batch, rules, 3 + max_rhs_length)``),
        ``nt_mask`` (bool, same shape) and ``tree`` (int64,
        ``(batch, seq_len, seq_len)``) tensors.
    """
    epsilon_id = tokenizer.convert_tokens_to_ids(epsilon_str)
    # Only look the separator up when one was actually requested.
    replacement_sep_id = (
        None if replacement_sep_token is None
        else tokenizer.convert_tokens_to_ids(replacement_sep_token)
    )

    data = {"input_ids": [], "labels": [], "cfg_rep": [], "nt_mask": [], "task_ids": [], "tree": []}
    # NOTE: pickle.load can execute arbitrary code; only use this on trusted files.
    with lzma.open(path, "rb") as f:
        load_data = pickle.load(f)

    n_loaded = 0
    for dp in load_data:
        if max_n is not None and n_loaded >= max_n:
            break
        task_id = dp["task_id"]
        rules: list[ProductionRule] = dp["grammar"]
        # The grammar encoding depends only on the task, so compute it once per task.
        task_cfg_rep, task_nt_mask = encode_grammar(rules, max_rhs_length, tokenizer, epsilon_id)
        for input_ids, labels, tree in dp["data"]:
            if max_n is not None and n_loaded >= max_n:
                break
            data["task_ids"].append(task_id)
            data["input_ids"].append(input_ids)
            data["labels"].append(labels)
            data["cfg_rep"].append(task_cfg_rep)
            data["tree"].append(tree)
            data["nt_mask"].append(task_nt_mask)
            n_loaded += 1

    dataset = Dataset.from_dict(data)
    sampler = RandomSampler(dataset) if random_order else SequentialSampler(dataset)
    ts = BatchSampler(sampler, batch_size=batch_size, drop_last=False)

    seq2seq_collator = DataCollatorForSeq2Seq(tokenizer)

    def collator_fn(features):
        # Remove the grammar/tree fields before the seq2seq collator sees the
        # features; they are padded manually below.
        cfg_reps = [x.pop("cfg_rep") for x in features]
        nt_masks = [x.pop("nt_mask") for x in features]
        trees = [x.pop("tree") for x in features]

        # Batch and pad the tokens
        d = seq2seq_collator(features)

        max_rules = max_len if max_len is not None else max(len(el) for el in cfg_reps)
        width = 3 + max_rhs_length
        # Padding rule slots keep 0 in cfg_rep and True in nt_mask.
        cfg_rep = np.zeros((len(cfg_reps), max_rules, width), dtype=np.int64)
        nt_mask = np.ones((len(cfg_reps), max_rules, width), dtype=bool)
        for i, (rep, mask) in enumerate(zip(cfg_reps, nt_masks)):
            # Grammars with more than max_rules rules are truncated.
            for j in range(min(max_rules, len(rep))):
                cfg_rep[i, j, :] = rep[j]
                nt_mask[i, j, :] = mask[j]

        d["cfg_rep"] = torch.from_numpy(cfg_rep)
        d["nt_mask"] = torch.from_numpy(nt_mask)

        # One square (seq_len x seq_len) grid per example; each tree matrix is
        # copied into it at a diagonal offset.
        max_l_trees = max(len(x["input_ids"]) for x in features)
        tree_tensor = np.zeros((len(trees), max_l_trees, max_l_trees), dtype=np.int64)
        for i, (t, f) in enumerate(zip(trees, features)):
            if replacement_sep_id is not None:
                offset = f["input_ids"].index(replacement_sep_id) + 1
            else:
                offset = 0
            tree_tensor[i, offset:offset + len(t), offset:offset + len(t)] = t

        d["tree"] = torch.from_numpy(tree_tensor)

        # NOTE(review): the dataset column is named "task_ids", so this branch
        # looks unreachable; "task_ids" appears to be batched by the seq2seq
        # collator instead — confirm before removing.
        if "task_id" in features[0]:
            d["task_ids"] = torch.from_numpy(np.array([x["task_id"] for x in features]))

        return d

    return DataLoader(dataset, collate_fn=collator_fn, batch_sampler=ts)




def write_tsv(fname, data):
    """Write ``(x, y)`` string pairs to *fname* as tab-separated lines.

    Each pair becomes one line ``x<TAB>y``. No header row is written and the
    fields are not escaped, so they must not themselves contain tabs or
    newlines. An existing file is overwritten.
    """
    with open(fname, "w") as out:
        for left, right in data:
            out.writelines((left, "\t", right, "\n"))


class RandomSplit:
    """Randomly split a two-column TSV file into a training part and a held-out part.

    The rows are shuffled in place with the global :mod:`random` state. The
    first ``num_train`` shuffled rows form the training split and the remaining
    rows form the held-out ("rest") split. Both splits are tokenized and wrapped
    in sequential (non-shuffling) DataLoaders padded by ``DataCollatorForSeq2Seq``.
    """

    def __init__(self, path: str, tokenizer: AutoTokenizer, num_train:int, train_batch_size, test_batch_size = None, lenient=True):
        """
        :param path: file read with ``load_tsv`` using the ``input``/``output`` column spec.
        :param tokenizer: tokenizer for inputs and ``text_target`` labels.
        :param num_train: number of shuffled rows placed in the training split.
        :param train_batch_size: batch size of the training loader; also used for
            the rest loader when *test_batch_size* is None.
        :param test_batch_size: optional batch size of the rest loader.
        :param lenient: forwarded to ``load_tsv``.
        """
        def mapper(examples):
            d = tokenizer(examples["input"])
            if "output" in examples:
                d["labels"] = tokenizer(text_target=examples["output"])["input_ids"]
            return d

        keys = ["input", "output"]
        data = list(load_tsv(path, "input\toutput", lenient=lenient))
        # Logging one draw from the global RNG makes it easy to check that runs were
        # seeded identically. It also advances the RNG before the shuffle, so keep
        # it here to keep splits reproducible.
        print("Random number to verify seed", random.randint(0, 100_000_000), file=sys.stderr)
        random.shuffle(data)
        self.train_data = data[:num_train]
        self.rest_data = data[num_train:]

        train_dataset = Dataset.from_list([dict(zip(keys, row)) for row in self.train_data])
        rest_dataset = Dataset.from_list([dict(zip(keys, row)) for row in self.rest_data])

        self.train_loader = self._build_loader(train_dataset, mapper, tokenizer, train_batch_size)
        rest_batch_size = train_batch_size if test_batch_size is None else test_batch_size
        self.rest_loader = self._build_loader(rest_dataset, mapper, tokenizer, rest_batch_size)

    @staticmethod
    def _build_loader(dataset, mapper, tokenizer, batch_size):
        """Tokenize *dataset* with *mapper*; return a sequential, seq2seq-padded DataLoader over it."""
        batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=False)
        tokenized = dataset.map(mapper, batched=True, remove_columns=["input", "output"])
        return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batch_sampler)

    def save_split(self, pathname):
        """Write the splits to ``<pathname>_train.tsv`` and ``<pathname>_test.tsv``."""
        write_tsv(pathname+"_train.tsv", self.train_data)
        write_tsv(pathname+"_test.tsv", self.rest_data)

    def get_train_loader(self):
        """Return the DataLoader over the training split."""
        return self.train_loader

    def get_rest_loader(self):
        """Return the DataLoader over the held-out split."""
        return self.rest_loader


from datasets import load_dataset


def data_loader_from_dataset(dataset, tokenizer,
                             input_template:str, output_template: str,
                             batch_size: int, random_order: bool = True):
    """Tokenize a templated seq2seq view of *dataset* and wrap it in a DataLoader.

    Each row is rendered with ``input_template.format(**row)`` (model input) and
    ``output_template.format(**row)`` (target). The input is tokenized normally
    and the target with ``text_target=`` into ``labels``. All original columns
    are removed.

    :param dataset: a dataset exposing ``map`` and ``column_names`` (e.g. a
        Hugging Face ``Dataset``) whose rows supply the template fields.
    :param tokenizer: tokenizer used for encoding and by the padding collator.
    :param input_template: ``str.format`` template for the model input.
    :param output_template: ``str.format`` template for the target.
    :param batch_size: examples per batch; the last batch may be smaller.
    :param random_order: use ``RandomSampler`` if True, else ``SequentialSampler``.
    :return: a ``DataLoader`` that pads batches with ``DataCollatorForSeq2Seq``.
    """
    def encode_row(row):
        source_text = input_template.format(**row)
        target_text = output_template.format(**row)
        encoded = tokenizer(source_text)
        encoded["labels"] = tokenizer(text_target=target_text)["input_ids"]
        return encoded

    columns_to_drop = dataset.column_names
    tokenized = dataset.map(encode_row, batched=False, remove_columns=columns_to_drop)

    base_sampler = RandomSampler(tokenized) if random_order else SequentialSampler(tokenized)
    batches = BatchSampler(base_sampler, batch_size=batch_size, drop_last=False)
    return DataLoader(tokenized, collate_fn=DataCollatorForSeq2Seq(tokenizer), batch_sampler=batches)



class ComposeDataLoaders:
    """Iterate over a sequence of data loaders one after another (e.g. for curriculum learning).

    Every call to ``iter()`` returns an iterator over the *current* loader.
    Loader ``j`` is used for ``counts[j]`` consecutive iterations (one each by
    default) before moving on to loader ``j + 1``. Once the last loader is
    reached, it is used for every further iteration.

    Example: with loaders ``[a, b, c]`` and no counts, successive iterations go
    over ``a, b, c, c, c, ...``.
    """
    def __init__(self, dataloaders, counts = None):
        """
        :param dataloaders: sequence of iterables (typically DataLoaders).
        :param counts: optional number of iterations to spend on each loader; must
            have the same length as *dataloaders*. Defaults to one per loader.
        :raises ValueError: if *counts* and *dataloaders* differ in length.
        """
        self.dataloaders = dataloaders
        self.index = 0
        if counts is None:
            counts = [1] * len(dataloaders)
        elif len(counts) != len(dataloaders):
            raise ValueError(
                f"counts has {len(counts)} entries but there are {len(dataloaders)} dataloaders"
            )
        # Copy so that progressing through the curriculum never mutates the caller's list.
        self.counts = list(counts)

    @staticmethod
    def create(**kwargs):
        """Alternate constructor taking keyword arguments only."""
        return ComposeDataLoaders(**kwargs)

    def __len__(self):
        """Return the length of the loader that the next iteration will use."""
        return len(self.dataloaders[self.index])

    def __iter__(self):
        r = iter(self.dataloaders[self.index])
        self.counts[self.index] -= 1
        # `<= 0` (not `== 0`) so a non-positive count cannot pin us to one loader forever.
        if self.counts[self.index] <= 0 and self.index + 1 < len(self.dataloaders):
            self.index += 1
        return r

