import pickle
import random

import torch


class PTBLoader(object):
    '''Load pickled PTB-style splits and turn them into batched tensors.

    ``data_path`` is used as a file-name *prefix* (not a directory): the
    loader reads ``<data_path>-dict.pkl``, ``<data_path>-train.pkl``,
    ``<data_path>-dev.pkl`` and ``<data_path>-test.pkl``.

    Attributes:
        dictionary: the object stored in the dictionary pickle.
        train, valid, test: the unpickled split, or ``None`` when the
            split file could not be loaded.  ``batchify`` unpacks a split
            as an ``(idxs, distances)`` pair of equally long sequences of
            per-sentence integer sequences.
    '''

    def __init__(self, data_path):
        '''Load the dictionary (mandatory) and the three splits (optional).

        Args:
            data_path: prefix shared by all pickle files, see class docs.

        Raises:
            OSError: if the dictionary file cannot be opened.  Any error
                raised while unpickling the dictionary also propagates.
        '''
        dict_filepath = '%s-%s' % (data_path, 'dict.pkl')
        train_data_filepath = '%s-%s' % (data_path, 'train.pkl')
        valid_data_filepath = '%s-%s' % (data_path, 'dev.pkl')
        test_data_filepath = '%s-%s' % (data_path, 'test.pkl')

        print("loading dictionary ...")
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file.  Only point this loader at pickle files you trust.
        with open(dict_filepath, "rb") as file_data:
            self.dictionary = pickle.load(file_data)

        print("loading data ...")
        self.train = self._load_optional_split(train_data_filepath)
        self.valid = self._load_optional_split(valid_data_filepath)
        self.test = self._load_optional_split(test_data_filepath)

    @staticmethod
    def _load_optional_split(filepath):
        '''Return the unpickled content of ``filepath``, or ``None`` on failure.

        Splits are optional, so loading is best-effort.  Unlike a bare
        ``except:``, this does not swallow ``KeyboardInterrupt`` or
        ``SystemExit``.
        '''
        try:
            with open(filepath, 'rb') as file_data:
                return pickle.load(file_data)
        except OSError:
            # Missing or unreadable file: the split simply is not available.
            return None
        except Exception as err:
            # The file exists but could not be unpickled.  Keep the
            # best-effort contract, but do not hide the problem entirely.
            print("could not load %s (%r); treating split as missing"
                  % (filepath, err))
            return None

    def batchify(self, dataname, batch_size, cuda, shuffle=False):
        '''Flatten a split and reshape it into ``batch_size`` rows.

        All sentences of the split are concatenated into one token stream,
        trailing tokens that do not fill a complete column are dropped, and
        the stream is viewed as a ``(batch_size, n_tokens // batch_size)``
        tensor.  The same transformation is applied to indices and
        distances so the two outputs stay aligned.

        Args:
            dataname: ``'train'``, ``'valid'`` or ``'test'``.
            batch_size: number of rows of the returned tensors.
            cuda: if true, the tensors are moved to the GPU with ``.cuda()``.
            shuffle: if true, the sentence order is permuted (identically
                for indices and distances) before flattening.

        Returns:
            ``(idxs_batched, distances_batched)`` as ``LongTensor`` objects,
            or ``(None, None)`` when ``dataname`` is unknown or the
            requested split was not loaded.
        '''
        if dataname == 'train' and self.train is not None:
            idxs, distances = self.train
        elif dataname == 'valid' and self.valid is not None:
            idxs, distances = self.valid
        elif dataname == 'test' and self.test is not None:
            idxs, distances = self.test
        else:
            return None, None

        assert len(idxs) == len(distances)

        if shuffle:
            # One shared permutation keeps every sentence paired with its
            # own distances.
            sentence_idx = list(range(len(idxs)))
            random.shuffle(sentence_idx)
            idxs_shuffled = [idxs[i] for i in sentence_idx]
            distances_shuffled = [distances[i] for i in sentence_idx]
        else:
            idxs_shuffled = idxs
            distances_shuffled = distances

        def flat(nested):
            return [item for sublist in nested for item in sublist]

        idxs_flatted = flat(idxs_shuffled)
        distances_flatted = flat(distances_shuffled)

        def _batchify(data, bsz, start_idx=0):
            # Number of complete columns when the stream is cut into bsz rows.
            nbatch = data.size(0) // bsz
            # Keep nbatch * bsz tokens starting at start_idx; the rest
            # (the remainder) is dropped.
            data = data.narrow(0, start_idx, nbatch * bsz)
            # Spell out both dimensions: view(bsz, -1) is ambiguous, and
            # therefore an error, when the split contains no tokens.
            data = data.view(bsz, nbatch).contiguous()
            if cuda:
                data = data.cuda()
            return data

        if dataname == 'train':
            # For training, start at a random offset inside the remainder
            # so that different tokens get trimmed on different calls.
            # With no remainder there is nothing to choose from; the
            # previous randint(0, -1) raised ValueError in that case.
            remainder = len(idxs_flatted) % batch_size
            if remainder > 0:
                random_start_idx = random.randint(0, remainder - 1)
            else:
                random_start_idx = 0
        else:
            random_start_idx = 0

        idxs_batched = _batchify(
            torch.LongTensor(idxs_flatted), batch_size, random_start_idx)
        distances_batched = _batchify(
            torch.LongTensor(distances_flatted), batch_size, random_start_idx)

        return idxs_batched, distances_batched
