import torch
from torch.utils.data import DataLoader

from dataset.language_dataset import language_dataset
from dataset.language_multitask_dataset import language_multitask_dataset
from dataset.language_text2text_dataset import language_text2text_dataset
from dataset.language_text2text_multitask_dataset import language_text2text_multitask_dataset


def create_dataset(dataset, config):
    """Build the train and test datasets for the requested dataset kind.

    Args:
        dataset: Name selecting the dataset class. One of ``'language'``,
            ``'multitask'``, ``'language_text2text'`` or
            ``'language_text2text_multitask'``.
        config: Mapping providing ``'train_file'``, ``'test_file'`` and
            ``'file_root'``. Each file entry is passed to the selected
            dataset class together with ``file_root``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple, both instances of the
        selected dataset class.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'language':
        dataset_cls = language_dataset
    elif dataset == 'multitask':
        dataset_cls = language_multitask_dataset
    elif dataset == 'language_text2text':
        dataset_cls = language_text2text_dataset
    elif dataset == 'language_text2text_multitask':
        dataset_cls = language_text2text_multitask_dataset
    else:
        # Fail loudly: an implicit None here would only surface later as an
        # unpacking error at the call site, far from the actual typo.
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of 'language', "
            "'multitask', 'language_text2text', 'language_text2text_multitask'"
        )

    file_root = config['file_root']
    train_dataset = dataset_cls(config['train_file'], file_root)
    test_dataset = dataset_cls(config['test_file'], file_root)
    return train_dataset, test_dataset
    

def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets to be sampled.
        shuffles: Iterable of shuffle flags, paired positionally with
            ``datasets``.
        num_tasks: Passed to every sampler as ``num_replicas``.
        global_rank: Passed to every sampler as ``rank``.

    Returns:
        A list of samplers in the same order as ``datasets``. Pairing uses
        ``zip``, so the result is as long as the shorter input.
    """
    return [
        torch.utils.data.DistributedSampler(
            data,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for data, do_shuffle in zip(datasets, shuffles)
    ]


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset from parallel argument sequences.

    All six arguments are iterated in lockstep with ``zip``; the i-th entry
    of each configures the i-th loader.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for no sampler.
        batch_size: Batch size for each loader.
        num_workers: Worker count for each loader.
        is_trains: Truthy for training loaders, falsy for evaluation loaders.
        collate_fns: Collate function for each loader.

    Returns:
        A list of ``DataLoader`` objects. Training loaders drop the last
        incomplete batch and shuffle only when no sampler was supplied;
        evaluation loaders neither shuffle nor drop. ``pin_memory`` is
        always enabled.
    """

    def _build(data, sampler, bs, n_worker, is_train, collate_fn):
        training = bool(is_train)
        return DataLoader(
            data,
            batch_size=bs,
            shuffle=training and sampler is None,
            sampler=sampler,
            num_workers=n_worker,
            collate_fn=collate_fn,
            pin_memory=True,
            drop_last=training,
        )

    per_loader_args = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    return [_build(*args) for args in per_loader_args]