from __future__ import print_function
from six.moves import xrange
import six.moves.cPickle as pickle

import gzip
import os

import numpy
import theano

import StringIO


def read_data_xy(readfilename):
    """Read a dataset of whitespace-separated integers from a ``.txt`` file.

    Every non-blank line holds integers: all but the last are the sample's
    features and the last one is its label.

    :param readfilename: path of the text file to read.
    :returns: ``(x, y)`` where ``x`` is a list of int lists (one per line)
        and ``y`` is the list of int labels, in file order.
    :raises ValueError: if a token on a line is not an integer.
    """
    x = []  # list of feature lists, one per line
    y = []  # list of labels, one per line

    with open(readfilename, "r") as f:
        for tmp_line in f:
            # split() with no argument ignores repeated/trailing spaces and the
            # newline; split(' ') would yield '' / '\n' tokens int() rejects.
            values = [int(tok) for tok in tmp_line.split()]
            if not values:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            x.append(values[:-1])
            y.append(values[-1])

    return x, y


def produce_data(readfilenames, savefilename):
    """Convert the train/valid/test ``.txt`` files into a single pickle.

    :param readfilenames: sequence whose first three items are the paths of
        the train, valid and test files (in that order); each is parsed with
        ``read_data_xy``.
    :param savefilename: path of the pickle to write. The pickled object is
        ``((train_x, train_y), (valid_x, valid_y), (test_x, test_y))``.
        It is written uncompressed, whereas ``load_data`` opens its input
        with ``gzip.open``.
    """
    # Read from the parameter, not a module-level name, so the function
    # works when imported and called from other code.
    train_x, train_y = read_data_xy(readfilenames[0])
    valid_x, valid_y = read_data_xy(readfilenames[1])
    test_x, test_y = read_data_xy(readfilenames[2])

    data = ((train_x, train_y), (valid_x, valid_y), (test_x, test_y))

    # `with` guarantees the file is closed even if pickling fails.
    with open(savefilename, 'wb') as f:
        pickle.dump(data, f)


    

def prepare_data(seqs, labels, maxlen=None, win_size=3, mask_dtype=None):
    """Pad a batch of index sequences into a matrix and build a window mask.

    :param seqs: sequences (lists of ints), one per sample.
    :param labels: labels aligned with ``seqs``.
    :param maxlen: length-filtering threshold. The filter itself is disabled,
        so when not None the batch is only copied into fresh lists (``zip``
        truncates to the shorter of ``seqs`` / ``labels``). The padded width
        always comes from the longest sequence in the batch.
    :param win_size: window size. The mask has ``maxlen // win_size`` columns
        and sample ``i`` gets ``max(0, (len_i - 1) // win_size)`` leading ones.
    :param mask_dtype: dtype of the mask; defaults to ``theano.config.floatX``.
    :returns: ``(x, x_mask, labels, maxlen)`` where ``x`` is an int32 array of
        shape ``(n_samples, maxlen)`` right-padded with zeros, ``x_mask`` has
        shape ``(n_samples, maxlen // win_size)``, and ``maxlen`` is the
        longest sequence length. An empty batch yields
        ``(None, None, None, None)``.
    """
    if maxlen is not None:
        # The ``len(s) < maxlen`` filter is intentionally disabled; this only
        # copies the pairs (and drops unpaired trailing items, as zip does).
        pairs = list(zip(seqs, labels))
        seqs = [s for s, _ in pairs]
        labels = [y for _, y in pairs]

    if len(seqs) == 0:
        # Same arity as the normal return so callers can always unpack 4
        # values; also avoids numpy.max([]) raising on an empty batch.
        return None, None, None, None

    if mask_dtype is None:
        # Resolved lazily so an explicit dtype needs no theano config.
        mask_dtype = theano.config.floatX

    lengths = [len(s) for s in seqs]
    n_samples = len(seqs)
    maxlen = numpy.max(lengths)

    x = numpy.zeros((n_samples, maxlen), dtype='int32')
    # Floor division keeps the shape an integer on Python 3 as well.
    x_mask = numpy.zeros((n_samples, maxlen // win_size), dtype=mask_dtype)

    for idx, (s, length) in enumerate(zip(seqs, lengths)):
        x[idx, :length] = s
        # Clamp at 0: for an empty sequence (0 - 1) // win_size == -1 and the
        # slice ``:-1`` would wrongly mark all columns but the last.
        x_mask[idx, :max(0, (length - 1) // win_size)] = 1.

    return x, x_mask, labels, maxlen



def load_data(path, n_words=5000, valid_portion=0.2, maxlen=None,
              sort_by_len=False):
    """Load a gzip-compressed pickled ``(train, valid, test)`` dataset.

    :param path: gzip file containing a pickle of
        ``((train_x, train_y), (valid_x, valid_y), (test_x, test_y))``.
    :param n_words: vocabulary size; every word index ``>= n_words`` in any
        split is replaced by 1.
    :param valid_portion: unused; kept for interface compatibility.
    :param maxlen: when truthy, the training pairs are copied into lists.
        Length filtering is disabled, so no sample is dropped (apart from
        ``zip`` truncating to the shorter of ``train_x`` / ``train_y``).
    :param sort_by_len: if True, each split is reordered by ascending
        sequence length (stable sort), keeping x and y aligned.
    :returns: ``(train, valid, test)``, each an ``(x, y)`` tuple.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only load
    trusted files.
    """
    # `with` closes the file even if unpickling fails.
    with gzip.open(path, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f)

    if maxlen:
        # The ``len(x) < maxlen`` filter is intentionally disabled.
        pairs = list(zip(train_set[0], train_set[1]))
        train_set = ([x for x, _ in pairs], [y for _, y in pairs])

    def remove_unk(x):
        # Map out-of-vocabulary indices to 1.
        return [[1 if w >= n_words else w for w in sen] for sen in x]

    def sort_pairs_by_len(xs, ys):
        # Reorder xs and ys together by ascending len(xs[i]).
        order = sorted(range(len(xs)), key=lambda i: len(xs[i]))
        return [xs[i] for i in order], [ys[i] for i in order]

    train_set_x, train_set_y = train_set
    valid_set_x, valid_set_y = valid_set
    test_set_x, test_set_y = test_set

    train_set_x = remove_unk(train_set_x)
    valid_set_x = remove_unk(valid_set_x)
    test_set_x = remove_unk(test_set_x)

    if sort_by_len:
        test_set_x, test_set_y = sort_pairs_by_len(test_set_x, test_set_y)
        valid_set_x, valid_set_y = sort_pairs_by_len(valid_set_x, valid_set_y)
        train_set_x, train_set_y = sort_pairs_by_len(train_set_x, train_set_y)

    train = (train_set_x, train_set_y)
    valid = (valid_set_x, valid_set_y)
    test = (test_set_x, test_set_y)

    return train, valid, test

def read_embedding_file_to_get_matrix(filename, savefilename):
    """Parse a text file of numeric rows into a matrix and pickle it.

    Each non-blank line of ``filename`` is one whitespace-separated row of
    numbers; rows are stacked in file order.

    :param filename: input text file.
    :param savefilename: path where the resulting matrix is pickled.
    :returns: float64 ``numpy.ndarray`` of shape ``(n_rows, n_cols)``.
    :raises ValueError: if a token is not a number or rows differ in length.
    """
    with open(filename, "r") as file_obj:
        # Parse floats directly rather than numpy.loadtxt(StringIO(line)):
        # the StringIO module exists only on Python 2, and building a
        # loadtxt parser for every line is slow. Blank lines are skipped.
        embeddings = [[float(tok) for tok in line.split()]
                      for line in file_obj if line.strip()]

    matrix = numpy.asarray(embeddings, dtype=float)

    with open(savefilename, 'wb') as f:
        pickle.dump(matrix, f)

    return matrix

def read_gz_file(filename):
    """Return the object pickled inside the gzip-compressed file ``filename``.

    NOTE(review): unpickling can execute arbitrary code; only use this on
    trusted files.
    """
    # `with` closes the file even when pickle.load raises.
    with gzip.open(filename, 'rb') as f:
        return pickle.load(f)


if __name__ == '__main__':
    # Script entry point: convert the three index files (train, valid, test)
    # into one pickle via produce_data. Paths are relative to the current
    # working directory.
    # NOTE(review): check produce_data before renaming ``readfilename`` here;
    # it may read this module-level name instead of its own parameter.

    readfilename = ["../train_idx.txt",
                    "../valid_idx.txt",
                    "../test_idx.txt"]
    savefilename = '../mydata.pkl'
    produce_data(readfilename,savefilename)


    # Disabled step: the triple-quoted string below is an unused expression
    # holding the call that would build and pickle the embedding matrix.
    '''m_arr = read_embedding_file_to_get_matrix("../word_embed.txt",
                                              "../emb.pkl")
    print(m_arr.shape)'''