import numpy as np
import random
import torch
from datasets import load_dataset,load_from_disk

# Set seed for reproducibility
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG.

    Args:
        seed: Integer seed applied to every generator.
    """
    # ``random`` was imported by this module but previously left unseeded,
    # so any code drawing from it was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    """Holder that exposes a value under the ``input_ids`` attribute.

    The value is stored by reference; it is neither copied nor validated.
    """

    def __init__(self, input_ids):
        # Keep the caller's object as-is (same identity).
        self.input_ids = input_ids

# Load the local wikitext2 train/test parquet files (tokenization happens in process_data)
def get_wikitext2(seq_len, tokenizer):
    """Load the local WikiText-2 (raw, v1) train and test parquet files.

    ``seq_len`` and ``tokenizer`` are accepted so all loaders share one
    signature; they are not used here.

    Returns:
        ``(train, test)``: the objects returned by ``load_dataset`` for the
        train file and the test file, in that order.
    """
    split_files = (
        'data/wikitext-2-raw-v1/train-00000-of-00001.parquet',
        'data/wikitext-2-raw-v1/test-00000-of-00001.parquet',
    )
    # Loaded in order: train first, then test.
    train_split, test_split = [load_dataset('parquet', data_files=path) for path in split_files]
    return train_split, test_split

def get_ptb(seq_len, tokenizer):
    """Load the local Penn Treebank train and validation text files.

    Each file is read with the ``'text'`` builder. An alternative is the Hub
    dataset ``load_dataset('ptb_text_only', 'penn_treebank', split=...)``
    with splits ``'train'`` and ``'validation'``.

    ``seq_len`` and ``tokenizer`` are unused; they keep the loader
    signatures uniform.

    Returns:
        ``(train, validation)`` as returned by ``load_dataset``.
    """
    split_files = ('data/ptb/ptb.train.txt', 'data/ptb/ptb.valid.txt')
    train_split, valid_split = [load_dataset('text', data_files=path) for path in split_files]
    return train_split, valid_split


def get_arxiv(seq_len, tokenizer):
    """Return the first 10 records of the local arXiv JSONL shard.

    ``seq_len`` and ``tokenizer`` are unused; they keep the loader
    signatures uniform.
    """
    shard = load_dataset('json', data_files='data/arxiv/arxiv_001.jsonl')
    records = shard['train']
    return records.select(range(10))

def get_c4(seq_len, tokenizer):
    """Return the first 5000 rows of the C4 validation set saved on disk.

    ``seq_len`` and ``tokenizer`` are unused; they keep the loader
    signatures uniform.
    """
    return load_from_disk('data/c4_validation').select(range(5000))

def process_data(samples, tokenizer, seq_len, split=None, field_name=None, choose=False):
    """Tokenize one text column and cut it into non-overlapping ``seq_len`` blocks.

    All entries of ``field_name`` are joined with two newlines and tokenized
    as a single sequence. Trailing tokens that do not fill a whole block are
    dropped.

    Args:
        samples: Object indexable by column name. When ``split == 'train'``
            it is first indexed with ``'train'``.
        tokenizer: Callable accepting ``return_tensors='pt'``. Its result must
            have an ``input_ids`` tensor, and only row 0 of it is used.
        seq_len: Tokens per block; must be positive.
        split: ``'train'`` to read ``samples['train']``. Any other value reads
            ``samples`` directly.
        field_name: Name of the text column.
        choose: When True, keep only the blocks whose index is in the fixed
            subset below, in ascending block order. Otherwise keep every block.

    Returns:
        ``IndexDataset(tensors=...)`` built from a ``(num_blocks, seq_len)``
        tensor.

    Raises:
        ValueError: if ``seq_len`` is not positive, or if no block is produced
            (text shorter than ``seq_len``, or no chosen index in range).
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    rows = samples['train'] if split == 'train' else samples
    test_ids = tokenizer("\n\n".join(rows[field_name]), return_tensors='pt').input_ids[0]
    nsamples = test_ids.numel() // seq_len

    # Fixed subset of block indices used when choose=True. Stored as a set so
    # each membership test is O(1) rather than a scan of a 100-element list.
    chosen_idx = {
        496, 417, 684, 809, 717, 574, 396, 22, 223, 418, 799, 515, 583, 155, 302, 45, 808, 57,
        430, 812, 517, 521, 52, 677, 525, 437, 315, 722, 261, 68, 662, 552, 793, 10, 224, 154,
        789, 221, 788, 209, 368, 218, 619, 380, 609, 815, 608, 112, 444, 450, 262, 725, 391, 820,
        578, 15, 821, 16, 17, 586, 594, 631, 565, 208, 12, 11, 632, 222, 225, 553, 4, 656,
        661, 671, 672, 537, 676, 678, 680, 518, 742, 516, 741, 510, 740, 685, 243, 694, 810, 435,
        733, 731, 51, 710, 713, 714, 172, 715, 716, 593,
    }

    # Ascending block order, matching a left-to-right walk over the sequence.
    keep = [i for i in range(nsamples) if not choose or i in chosen_idx]
    if not keep:
        # torch.stack([]) used to fail here with an opaque RuntimeError.
        raise ValueError(
            f"no {seq_len}-token blocks to evaluate "
            f"({test_ids.numel()} tokens, choose={choose})"
        )

    # Row i equals test_ids[i*seq_len:(i+1)*seq_len], the same block the old
    # per-block slicing produced.
    blocks = test_ids[: nsamples * seq_len].reshape(nsamples, seq_len)
    if choose:
        blocks = blocks[keep]

    # NOTE(review): IndexDataset is neither defined nor imported in this
    # file's visible code; confirm it is provided elsewhere.
    return IndexDataset(tensors=blocks)



# Function to select the appropriate loader based on dataset name
def get_loaders(name, tokenizer, seq_len=2048, batch_size=8):
    """Return a non-shuffled DataLoader of fixed-length token blocks for an eval set.

    The dataset is chosen by substring match on ``name``. The first match wins,
    checked in this order: wikitext2, ptb, wikipedia, arxiv, c4.

    Args:
        name: Dataset identifier; must contain one of the keys above.
        tokenizer: Passed through to the loader and to ``process_data``.
        seq_len: Tokens per block.
        batch_size: Blocks per DataLoader batch.

    Returns:
        ``torch.utils.data.DataLoader`` over the tokenized blocks, with
        ``shuffle=False``.

    Raises:
        ValueError: if ``name`` matches none of the supported datasets. Before
            this check, the function failed with UnboundLocalError.
    """
    if 'wikitext2' in name:
        # load_dataset puts a single data file under the 'train' split key,
        # hence split='train' for the test file.
        _, test_data = get_wikitext2(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'train', 'text')
    elif 'ptb' in name:
        _, test_data = get_ptb(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, 'train', 'text')
    elif 'wikipedia' in name:
        # NOTE(review): get_wikipedia is not defined in this file's visible
        # code; confirm it is provided elsewhere.
        test_data = get_wikipedia(seq_len, tokenizer)
        test_dataset = process_data(
            test_data, tokenizer, seq_len, split=None, field_name='text', choose=True
        )
    elif 'arxiv' in name:
        test_data = get_arxiv(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, split=None, field_name='text')
    elif 'c4' in name:
        test_data = get_c4(seq_len, tokenizer)
        test_dataset = process_data(test_data, tokenizer, seq_len, split=None, field_name='text')
    else:
        raise ValueError(
            f"Unsupported dataset {name!r}; expected a name containing one of "
            "'wikitext2', 'ptb', 'wikipedia', 'arxiv', 'c4'"
        )

    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

