import csv
import numpy as np
import torch

class DataHelper:
    """Load text-classification CSV data and encode it as fixed-length id arrays.

    Text is tokenized on whitespace; every token present in ``vocab_dict`` is
    replaced by its id, and the resulting sequence is truncated or
    zero-padded to ``sequence_max_length`` entries.
    """

    def __init__(self, vocab_dict, sequence_max_length=1024):
        """Store the vocabulary and the fixed sequence length.

        Args:
            vocab_dict: Mapping from token string to a numeric id.
            sequence_max_length: Length of every encoded sequence.
        """
        self.vocab_dict = vocab_dict
        self.sequence_max_length = sequence_max_length

    def char2vec(self, text):
        """Encode ``text`` as a vector of vocabulary ids.

        Despite the name, the text is split on whitespace, so the units
        looked up in ``vocab_dict`` are whitespace-separated tokens.
        Tokens missing from the vocabulary are skipped (they do not consume
        a slot); unused trailing slots stay 0.

        Args:
            text: The string to encode.

        Returns:
            A float64 ``numpy`` array of length ``sequence_max_length``.
        """
        limit = self.sequence_max_length
        data = np.zeros(limit)
        filled = 0
        for token in text.split():
            # Checked before writing so a zero-length sequence never indexes
            # past the end of ``data``.
            if filled >= limit:
                break
            if token in self.vocab_dict:
                data[filled] = self.vocab_dict[token]
                filled += 1
        return data

    def load_csv_file(self, filename, num_classes, s1, train=True, one_hot=False):
        """Read a headerless CSV file into encoded-text and label arrays.

        The first column of each row is parsed as an integer class and
        shifted by -1 (so labels written 1..N in the file become 0..N-1).
        The last remaining column is taken as the text and encoded with
        :meth:`char2vec`.

        Args:
            filename: Path of the CSV file.
            num_classes: Number of classes; only used for the one-hot width.
            s1: Number of rows to allocate. Rows missing from the file are
                left as zeros; a file with more than ``s1`` rows raises
                ``IndexError``.
            train: Unused; kept for interface compatibility.
            one_hot: If true, labels are one-hot rows of width
                ``num_classes`` instead of a single class-index column.

        Returns:
            A tuple ``(all_data, labels)`` of integer arrays with shapes
            ``(s1, sequence_max_length)`` and ``(s1, 1)`` (or
            ``(s1, num_classes)`` when ``one_hot`` is true).
        """
        label_width = num_classes if one_hot else 1
        # ``np.int`` was removed in NumPy 1.24; builtin ``int`` is what it
        # aliased, so the resulting dtype is unchanged.
        all_data = np.zeros(shape=(s1, self.sequence_max_length), dtype=int)
        labels = np.zeros(shape=(s1, label_width), dtype=int)
        # newline='' lets the csv module handle line endings itself.
        with open(filename, newline='') as f:
            reader = csv.DictReader(f, fieldnames=['class'], restkey='fields')
            for i, row in enumerate(reader):
                class_index = int(row['class']) - 1
                if one_hot:
                    # Set a single cell rather than rebinding ``one_hot``,
                    # which previously clobbered the flag after row one.
                    labels[i, class_index] = 1
                else:
                    labels[i] = class_index
                all_data[i] = self.char2vec(row['fields'][-1])
        return all_data, labels

    def init_embeddings(self, embeddings):
        """Fill ``embeddings`` in place with values from U(-b, b).

        The bound is ``b = sqrt(3 / embeddings.size(1))``.

        Args:
            embeddings: A ``torch`` tensor with at least two dimensions.
        """
        bias = np.sqrt(3.0 / embeddings.size(1))
        torch.nn.init.uniform_(embeddings, -bias, bias)

    def load_dataset(self, dataset_path, train_len, text_len):
        """Load the train and test splits found under ``dataset_path``.

        ``dataset_path`` is used as a plain string prefix, so it must end
        with a path separator. The files read are ``classes.txt`` (one class
        name per line), ``train.csv`` and ``test.csv``.

        Args:
            dataset_path: Prefix prepended to each file name.
            train_len: Number of rows to allocate for ``train.csv``.
            text_len: Number of rows to allocate for ``test.csv``.

        Returns:
            ``(train_data, train_label, test_data, test_label)`` as produced
            by :meth:`load_csv_file`.
        """
        with open(dataset_path + "classes.txt") as f:
            classes = [line.strip() for line in f]
        num_classes = len(classes)
        train_data, train_label = self.load_csv_file(
            dataset_path + 'train.csv', num_classes, train_len)
        test_data, test_label = self.load_csv_file(
            dataset_path + 'test.csv', num_classes, text_len, train=False)
        return train_data, train_label, test_data, test_label

    def batch_iter(self, data, batch_size, num_epochs, shuffle=True):
        """Yield ``(batch_data, label)`` pairs covering ``data`` once.

        Each row of ``data`` is split column-wise at
        ``sequence_max_length``: the leading columns become the integer
        ``batch_data`` and the remaining columns become ``label``.

        Args:
            data: 2-D ``numpy`` array whose rows are encoded text followed
                by label column(s).
            batch_size: Maximum number of rows per batch; the final batch
                may be smaller.
            num_epochs: Unused; a single pass over ``data`` is produced.
            shuffle: If true, rows are randomly permuted before batching.

        Yields:
            Tuples ``(batch_data, label)``. Nothing is yielded when
            ``data`` is empty.
        """
        data_size = len(data)
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        # Stepping by batch_size gives ceil(data_size / batch_size) batches
        # and, unlike the float-division form, none for empty input.
        for start_index in range(0, data_size, batch_size):
            batch = shuffled_data[start_index:start_index + batch_size]
            batch_data, label = np.split(
                batch, [self.sequence_max_length], axis=1)
            yield np.array(batch_data, dtype=int), label



if __name__ == "__main__":
    # Script entry point: only announces that the module was run directly.
    print("start")
