import json
from torch.utils.data import Dataset
import random
import torch
from torch.utils.data import DataLoader
import os

class NLG_Dataset_Manager():
    """Load an NLG corpus and split its training set between a server and N clients.

    The directory ``path`` must contain ``train.jsonl``, ``valid.jsonl`` and
    ``test.jsonl``.  Each non-blank line is a JSON object with the keys
    ``"context"`` and ``"completion"``.

    The training samples are shuffled once with a ``random.Random(seed)``
    generator and then cut into contiguous slices:

    * loader index ``0`` holds the first ``num_train // (N + 1)`` samples
      (the "server" share);
    * loader indices ``1 .. N`` each hold an equal share of what is left.
      Samples that do not divide evenly are not assigned to any loader.

    Args:
        path: Directory holding the three ``.jsonl`` files.
        max_seq_length: Sequence length forwarded to every ``NLG_Dataset``.
        batch_size: Batch size of the training and validation loaders
            (the test loader always uses a batch size of 1).
        seed: Seed of the generator used to shuffle the training samples.
        N: Number of clients; must be ``>= 0``.

    Raises:
        ValueError: If ``N`` is negative.
    """

    def __init__(self, path, max_seq_length, batch_size, seed, N):
        if N < 0:
            raise ValueError(f"N must be non-negative, got {N}")

        self._max_seq_length = max_seq_length
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        self._N = N
        self._train_samples = self._read_NLG_file(os.path.join(path, "train.jsonl"))
        self._valid_samples = self._read_NLG_file(os.path.join(path, "valid.jsonl"))
        self._test_samples = self._read_NLG_file(os.path.join(path, "test.jsonl"))
        self._rng.shuffle(self._train_samples)

        self._num_train = len(self._train_samples)
        # Integer division avoids the float round trip of int(a / b).
        self._num_server = self._num_train // (self._N + 1)
        # With no clients there is nothing to share out (and nothing to divide by).
        if self._N > 0:
            self._num_client = (self._num_train - self._num_server) // self._N
        else:
            self._num_client = 0

        # These two lists are never filled; they are kept so that code reading
        # the attributes keeps working.
        self._train_samples_clients = []
        self._train_dataset_clients = []

        # Index 0 is the server share, indices 1..N are the client shares.
        self._train_loader_clients = [
            self._make_train_loader(self._train_samples[:self._num_server])
        ]
        for client_id in range(N):
            start = self._num_server + client_id * self._num_client
            stop = start + self._num_client
            self._train_loader_clients.append(
                self._make_train_loader(self._train_samples[start:stop])
            )

        valid_dataset = NLG_Dataset(self._valid_samples, self._max_seq_length)
        test_dataset = NLG_Dataset(self._test_samples, self._max_seq_length, is_test=True)
        self._valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=self._batch_size,
        )
        self._test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=1,
        )

    def _make_train_loader(self, samples):
        """Wrap ``samples`` in a shuffling training ``DataLoader``."""
        return DataLoader(
            dataset=NLG_Dataset(samples, self._max_seq_length),
            batch_size=self._batch_size,
            shuffle=True,
        )

    def _get_train_loader(self, client_id):
        """Return training loader ``client_id`` (0 is the server share, 1..N the clients)."""
        return self._train_loader_clients[client_id]

    def _get_valid_loader(self):
        """Return the validation loader."""
        return self._valid_loader

    def _get_test_loader(self):
        """Return the test loader (batch size 1)."""
        return self._test_loader

    def _read_NLG_file(self, path):
        """Read a JSON-lines file into a list of ``[context, completion]`` pairs.

        Blank lines (for example a trailing empty line) are skipped instead of
        being handed to the JSON parser.

        Args:
            path: Path of the ``.jsonl`` file.

        Returns:
            One ``[context, completion]`` list per record, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks ``"context"`` or ``"completion"``.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as reader:
            for line in reader:
                line = line.strip()
                if not line:
                    continue
                items = json.loads(line)
                samples.append([items['context'], items['completion']])
        return samples

class NLG_Dataset(Dataset):
    """Map-style dataset turning ``[context, completion]`` pairs into fixed-length tensors.

    Each element of ``data`` is indexed as ``sample[0]`` (context) and
    ``sample[1]`` (completion); both are concatenated with ``+`` and converted
    to ``torch.long`` tensors, so they are expected to be lists of integer
    token ids.

    Args:
        data: Sequence of ``[context, completion]`` pairs.
        max_seq_length: Length every ``input`` / ``target`` / ``mask`` is
            truncated or zero-padded to.
        is_test: When true, each item also carries the (unpadded) context as
            ``"query"`` together with ``"query_len"``.
    """

    # Longest query returned in test mode.  The value is the original
    # hard-coded expression; presumably a 512-token window minus a 64-token
    # generation budget minus one -- TODO confirm against the decoding code.
    _MAX_QUERY_LENGTH = 512 - 64 - 1

    def __init__(self, data, max_seq_length, is_test = False):
        self._data = data
        self._max_seq_length = max_seq_length
        self._num_sample = len(self._data)
        self._is_test = is_test

    def _padding_tokens(self, tokens, pad_token, direct, max_context_length = 0):
        """Truncate ``tokens`` and right-pad the result to ``max_seq_length``.

        Args:
            tokens: List to truncate and pad.
            pad_token: Value appended as padding.
            direct: If ``> 0`` the head of an over-long list is kept,
                otherwise its tail.
            max_context_length: Truncation limit; ``0`` means
                ``max_seq_length``.  Padding always targets
                ``max_seq_length``, whatever this limit is.

        Returns:
            ``(padded, kept_len)`` where ``kept_len`` is the number of
            original tokens that survived truncation.
        """
        if max_context_length == 0:
            max_context_length = self._max_seq_length
        if len(tokens) > max_context_length:
            if direct > 0:
                kept = tokens[:max_context_length]
            else:
                kept = tokens[-max_context_length:]
        else:
            kept = tokens
        token_len = len(kept)
        # A non-positive pad count yields an empty list, i.e. no padding.
        padded = kept + [pad_token] * (self._max_seq_length - token_len)
        return padded, token_len

    def __len__(self):
        """Return the number of samples."""
        return len(self._data)

    def __getitem__(self, item):
        """Build the tensors for sample ``item``.

        Returns:
            A dict with ``"id"``, ``"input"``, ``"target"`` (the input shifted
            left by one token) and ``"mask"`` (1.0 on target positions that
            hold completion tokens, 0.0 on context and padding positions).
            In test mode ``"query"`` and ``"query_len"`` are added.

        Raises:
            IndexError: If ``item`` is past the end of the data.  Raising
                ``IndexError`` (not ``AssertionError``) lets plain iteration
                over the dataset terminate normally.
        """
        if item >= self._num_sample:
            raise IndexError(
                f"index {item} is out of range for dataset of size {self._num_sample}"
            )

        sample = self._data[item]
        conditions = sample[0]
        completion = sample[1]
        sequence = conditions + completion

        _input, _input_len = self._padding_tokens(sequence, 0, 1)
        _target, _ = self._padding_tokens(sequence[1:], 0, 1)

        # Target position i predicts sequence[i + 1], so the first
        # len(conditions) - 1 targets are still context tokens.  The shift
        # always drops one token, hence at least one position is removed
        # from the count of ones even when the context is empty; otherwise
        # a padding target would be counted in the loss.
        num_context = len(conditions)
        _msk = [0.0] * (num_context - 1) + [1.0] * (_input_len - max(num_context, 1))
        _msk, _ = self._padding_tokens(_msk, 0, 1)

        output = {}
        output["id"] = torch.tensor(item, dtype=torch.long)
        if self._is_test:
            # The query is truncated but not padded, so its length varies
            # from sample to sample.
            query = conditions[:self._MAX_QUERY_LENGTH]
            output["query"] = torch.tensor(query, dtype=torch.long)
            output["query_len"] = torch.tensor(len(query), dtype=torch.long)
        output["input"] = torch.tensor(_input, dtype=torch.long)
        output["target"] = torch.tensor(_target, dtype=torch.long)
        output["mask"] = torch.tensor(_msk, dtype=torch.float)
        return output
    



