from dataclasses import dataclass
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm

@dataclass
class DatasetInfo:
    """Summary metadata returned alongside a generated dataset.

    NOTE(review): how each field is consumed is not visible in this file;
    the per-field notes below only record what ``get_synthetic_dataset``
    stores in them.
    """

    # Set to ``sequence_length`` by ``get_synthetic_dataset``.
    max_seq_length: int
    # Set to ``sequence_length`` by ``get_synthetic_dataset``
    # (presumably the input vocabulary size — confirm against the consumer).
    n_tokens: int
    # Set to ``sequence_length + num_jumps + 1`` by ``get_synthetic_dataset``.
    n_classes: int
    # Set to ``True`` by ``get_synthetic_dataset``.
    multiclass: bool


def get_synthetic_dataset(num_samples: int, sequence_length: int, num_jumps: int):
    """Build the synthetic jump dataset as a ``TensorDataset`` plus its metadata.

    Args:
        num_samples: Number of sequences to generate.
        sequence_length: Length of every sequence.
        num_jumps: Number of jumps encoded in every sequence.

    Returns:
        A ``(dataset, info)`` pair. ``dataset`` is a
        ``torch.utils.data.TensorDataset`` holding two int64 tensors, built
        from the two arrays returned by ``_generate_data`` (inputs first,
        targets second). ``info`` is the matching ``DatasetInfo``.
    """
    generated_arrays = _generate_data(num_samples, sequence_length, num_jumps)

    # Convert both NumPy arrays (inputs, targets) to int64 tensors in order.
    input_tensor, target_tensor = (
        torch.from_numpy(array).long() for array in generated_arrays
    )

    info = DatasetInfo(
        max_seq_length=sequence_length,
        n_tokens=sequence_length,
        n_classes=sequence_length + num_jumps + 1,
        multiclass=True,
    )
    return torch.utils.data.TensorDataset(input_tensor, target_tensor), info


def _generate_data(num_samples: int, sequence_length: int, num_jumps: int):
    assert sequence_length % (num_jumps + 1) == 0

    block_size = sequence_length // (num_jumps + 1)

    rng = np.random.default_rng(np.random.randint(0, 2**32))

    # generate A and place the root tokens at the end
    indices = np.arange(0, sequence_length).reshape(num_jumps+1, block_size, 1).repeat(num_samples, axis=2)
    jumps = rng.permuted(indices, axis=1).reshape(sequence_length, num_samples).transpose()
    # print("jumps:", jumps[0:2])
    
    # find root index
    # blocks = [jumps[:, :block_size]]
    # for i in range(block_size, sequence_length, block_size):
    #     b = np.take_along_axis(blocks[-1], jumps[:, i:i+block_size] % block_size, axis=1)
    #     blocks.append(b)
    # root_indices = np.concatenate(blocks, axis=1)

    # blocks = []
    # for i in tqdm(range(0, sequence_length, block_size)):
    #     idx = jumps[:, i:i+block_size]
    #     if i == 0:
    #         b = np.zeros((num_samples, block_size, sequence_length), dtype=int)
    #     else:
    #         b = np.take_along_axis(blocks[-1], idx[:, :, None] % block_size, axis=1)
    #     np.put_along_axis(b, idx[:, :, None], 1, axis=2)
    #     blocks.append(b)
    # trajectories = np.concatenate(blocks, axis=1)

    # none_token = block_size
    none_token = -100
    blocks = []
    for i in range(num_jumps+1):
        indices = jumps[:, i*block_size:(i+1)*block_size]
        if i == 0:
            b = np.concatenate((jumps[:, :block_size, None], np.full((num_samples, block_size, num_jumps), none_token, dtype=int)), axis=2)
        else:
            b = np.take_along_axis(blocks[-1], indices[:, :, None] % block_size, axis=1)
        b[:, :, i] = indices % block_size
        blocks.append(b)
    trajectories = np.concatenate(blocks, axis=1)

    # trajectories = []
    # idx = np.full(num_samples, block_size+sequence_length-1, dtype=int)
    # values = jumps[np.arange(num_samples), idx - block_size]
    # for _ in range(num_jumps):
    #     idx, values = values, jumps[np.arange(num_samples), values - block_size]
    #     trajectories.append(values % block_size)
    # trajectories = np.stack(trajectories, axis=1)

    # for i in range(num_samples):
    #     print(f"Sample {i}:")
    #     print("      ", np.arange(block_size, block_size + sequence_length))
    #     print("jumps:", jumps[i])
    #     # print("Final index:", root_indices[i])
    #     print("Final index:", trajectories[i])
    #     print()
    
    return jumps, trajectories


if __name__ == "__main__":
    # Demo run: seed NumPy's legacy global RNG first so the generated data
    # (whose generator is seeded from that RNG) is the same on every run.
    np.random.seed(42)
    demo_kwargs = dict(num_samples=30, sequence_length=12, num_jumps=3)
    _demo_dataset, demo_info = get_synthetic_dataset(**demo_kwargs)
    print(demo_info)