import torch
import json
from transformers import AutoTokenizer
import random


def generate_tokenized_data(*, dataset=None, tokenizer=None, seq_length=512, shuffle=False):
    """Pack dataset texts into fixed-length token sequences.

    Texts are read from the ``"text"`` field of each record, gathered into a
    character-bounded buffer, tokenized as a batch, joined with the
    tokenizer's EOS token after every text, and cut into chunks of exactly
    ``seq_length`` tokens. The trailing partial chunk of each buffer is
    dropped.

    Args:
        dataset: Iterable of mappings that each have a ``"text"`` entry. When
            ``None``, it is loaded from
            ``./data/complete_qa_final_filtered_preprocessed_train.json``.
        tokenizer: Callable taking a list of strings (plus the keyword
            arguments ``add_special_tokens`` and ``truncation``) and returning
            a mapping with an ``"input_ids"`` list of token-id lists; it must
            expose ``eos_token_id``. When ``None``, the
            ``meta-llama/Llama-2-7b-hf`` tokenizer is loaded.
        seq_length: Number of tokens in every emitted example.
        shuffle: If true, shuffle the examples produced from each buffer
            before appending them to the result.

    Returns:
        A list of dicts, each with ``"input_ids"`` and ``"labels"`` holding
        separate ``torch.LongTensor`` objects built from the same
        ``seq_length`` token ids.

    Raises:
        ValueError: If ``seq_length`` is not positive.
    """
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")

    if dataset is None:
        with open(
            "./data/complete_qa_final_filtered_preprocessed_train.json",
            "r",
            encoding="utf-8",
        ) as input_file:
            dataset = json.load(input_file)

    if tokenizer is None:
        # NOTE(review): "token" looks like a placeholder credential -- confirm
        # how the real Hugging Face token is meant to be supplied.
        HF_TOKEN = "token"
        tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-hf", use_auth_token=HF_TOKEN, use_fast=True
        )

    dataset_text_field = "text"
    chars_per_token = 3.76
    num_of_sequences = 2
    # Character budget of one buffer: a rough estimate of how many characters
    # are needed to yield ``num_of_sequences`` full sequences.
    max_buffer_size = seq_length * chars_per_token * num_of_sequences
    concat_token_id = tokenizer.eos_token_id

    iterator = iter(dataset)
    more_examples = True
    result = []
    while more_examples:
        # Collect raw texts until the character budget is met or data runs out.
        buffer, buffer_len = [], 0
        while buffer_len < max_buffer_size:
            try:
                text = next(iterator)[dataset_text_field]
            except StopIteration:
                more_examples = False
                break
            buffer.append(text)
            buffer_len += len(text)

        # The dataset can be exhausted exactly when the previous buffer
        # filled up; there is then nothing left to tokenize.
        if not buffer:
            break

        tokenized_inputs = tokenizer(
            buffer, add_special_tokens=True, truncation=False
        )["input_ids"]

        # Flatten the batch into one token stream, one EOS after each text.
        all_token_ids = []
        for tokenized_input in tokenized_inputs:
            all_token_ids.extend(tokenized_input)
            all_token_ids.append(concat_token_id)

        # Only full-length chunks are kept; the remainder is discarded.
        full_length = len(all_token_ids) - len(all_token_ids) % seq_length
        examples = [
            all_token_ids[start:start + seq_length]
            for start in range(0, full_length, seq_length)
        ]
        if shuffle:
            random.shuffle(examples)

        result.extend(
            {
                "input_ids": torch.LongTensor(example),
                "labels": torch.LongTensor(example),
            }
            for example in examples
        )
    return result
# Runs the tokenization pipeline as soon as this module is executed or
# imported, binding whatever generate_tokenized_data() returns to the
# module-level name `result`.
result = generate_tokenized_data()