from util import *
# from allennlp.modules.elmo import batch_to_ids

class SimpleDataIterator():
    """Serve shuffled, fixed-size batches from a DataFrame, epoch after epoch.

    ``next_batch`` reads the 'indexed', 'label' and 'len' columns. When the
    rows left in the current pass cannot fill a whole batch, the frame is
    reshuffled, the cursor rewinds and ``epochs`` is incremented.
    """

    def __init__(self, df, epochs):
        self.df = df
        self.size = len(self.df)
        self.epochs = epochs
        self.shuffle()

    def shuffle(self):
        """Shuffle all rows, renumber them 0..size-1 and rewind the cursor."""
        self.df = self.df.sample(frac=1).reset_index(drop=True)
        self.cursor = 0

    def next_batch(self, n):
        """Return the next ``n`` rows as ``(indexed, labels, lens)`` lists.

        :param n: batch size
        :return: tuple of three lists taken from the 'indexed', 'label' and
                 'len' columns, in the same row order
        """
        # Start a new pass when fewer than n rows remain, so batches are never short.
        if self.cursor + n > self.size:
            self.epochs += 1
            self.shuffle()
        # Positional slice: exactly n rows (the removed `.ix` indexer was used here before).
        res = self.df.iloc[self.cursor:self.cursor + n]
        self.cursor += n

        return res['indexed'].tolist(), res['label'].tolist(), res['len'].tolist()


class PaddedDataIterator(SimpleDataIterator):
    """SimpleDataIterator variant whose batches are padded arrays.

    The 'indexed' column is padded into one array via ``sub_array``. The
    'len' column is passed through ``hierarchy_lengths``, presumably to get
    per-level lengths for nested documents (TODO confirm in util).
    """

    def next_batch(self, n):
        """Return ``(x, lens, labels)`` for the next ``n`` rows.

        :param n: batch size
        :return: x from ``sub_array`` over 'indexed', lens from
                 ``hierarchy_lengths`` over 'len', labels as a numpy array
        """
        # Start a new pass when fewer than n rows remain.
        if self.cursor + n > self.size:
            self.epochs += 1
            self.shuffle()
        # Positional slice of exactly n rows (replaces the removed `.ix` indexer).
        res = self.df.iloc[self.cursor:self.cursor + n]
        self.cursor += n

        indexed = res['indexed'].tolist()
        shape = batch_shape(indexed)
        words_len = res['len'].tolist()

        x = sub_array(indexed, shape)
        lens = hierarchy_lengths(words_len)

        return x, lens, np.array(res['label'].tolist())


class BucketedDataIterator():
    """Length-bucketed batch iterator over an indexed-text DataFrame.

    Rows are sorted by their 'len' column and cut into ``num_buckets``
    contiguous buckets, so each batch holds documents of similar length.
    Each ``next_batch`` call does the following:

    - pick a random unfinished bucket and serve its next ``n`` rows;
    - cap word ids at ``oov_index``;
    - truncate documents longer than ``max_len`` to their head and tail;
    - pad the batch;
    - optionally build character-index inputs.

    The batch is returned ordered by document length, longest first.

    Columns read from ``df``: 'len', 'indexed', 'label'.
    ``dic`` is iterated as ``word -> entry`` where ``entry['cnt']`` is read.
    """

    def __init__(self, df, dic, num_buckets=5, epochs=0, char_embedding=True,
                 char_dict=None, oov_index=None, min_cnt=5, max_len=1000):
        self.dic = dic
        self.inv = inverse_dictionary(dic)
        # Sorting by length makes each contiguous bucket span a narrow length range.
        self.df = df.sort_values('len').reset_index(drop=True)
        self.max_len = max_len
        self.char_embedding = char_embedding
        self.char_dict = char_dict or self.make_char_dic(dic, min_cnt)
        # `is None` instead of `or`: an explicit oov_index of 0 must be honoured.
        self.oov_index = (oov_index if oov_index is not None
                          else self.set_oov_index(dic, min_cnt))
        self.num_buckets = num_buckets

        total = len(self.df)
        # Average bucket size, kept as a public attribute; slicing below uses integer bounds.
        self.size = total / num_buckets
        # Integer boundaries split *all* rows into near-equal buckets (float
        # bounds under Python 3 mis-sliced or dropped the remainder rows).
        bounds = [b * total // num_buckets for b in range(num_buckets + 1)]
        self.dfs = [self.df.iloc[bounds[b]:bounds[b + 1]] for b in range(num_buckets)]

        # cursor[i] / finished[i] track progress through the ith bucket.
        self.cursor = np.zeros(num_buckets, dtype=int)
        self.finished = np.zeros(num_buckets, dtype=int)

        self.shuffle()

        self.epochs = epochs

    def __len__(self):
        """Total number of rows across all buckets."""
        return len(self.df)

    def shuffle(self):
        """Shuffle rows inside every bucket and rewind all bucket cursors.

        Buckets keep their length ranges; only the order within each bucket is
        randomised. Empty buckets are marked finished at once because they have
        nothing to serve.
        """
        for i in range(self.num_buckets):
            self.dfs[i] = self.dfs[i].sample(frac=1).reset_index(drop=True)
            self.cursor[i] = 0
            self.finished[i] = int(len(self.dfs[i]) == 0)

    def _take_rows(self, n):
        """Return the next n rows of a random unfinished bucket, starting a new epoch if needed."""
        if self.finished.all():
            self.epochs += 1
            self.shuffle()
            if self.finished.all():
                raise ValueError('BucketedDataIterator has no rows to serve')
        i = np.random.randint(0, self.num_buckets)
        while self.finished[i] == 1:
            i = (i + 1) % self.num_buckets
        bucket = self.dfs[i]
        start = self.cursor[i]
        if start + n >= len(bucket):
            # Last batch of this bucket: slide the window back so it still holds
            # n rows (re-serving a few) whenever the bucket is large enough.
            start = max(len(bucket) - n, 0)
            self.finished[i] = 1
        self.cursor[i] = start + n
        return bucket.iloc[start:start + n]

    def _truncate(self, indexed, words_len):
        """In place: keep the head and tail halves of documents longer than max_len."""
        head = self.max_len // 2
        tail = self.max_len - head
        for idx, length in enumerate(words_len):
            if length > self.max_len:
                seq = indexed[idx]
                indexed[idx] = seq[:head] + seq[-tail:]
                words_len[idx] = self.max_len

    def _char_inputs(self, indexed):
        """Return padded per-word char ids and padded per-word char lengths for a batch."""
        chars = [[self.wordidx_to_charidx(index) for index in text] for text in indexed]
        char_shape = batch_shape(chars)
        char_lens = [[len(word) for word in text] for text in chars]
        # len(char_dict) + 3 is one past the ids wordidx_to_charidx can emit.
        padded_chars = sub_array(chars, char_shape, len(self.char_dict) + 3)
        padded_lens = sub_array(char_lens, batch_shape(char_lens), 0).tolist()
        return padded_chars, padded_lens

    def next_batch(self, n):
        """Return the next batch of (up to) ``n`` documents, longest first.

        :param n: batch size
        :return: ``(x, chars, lens, char_lens, labels)`` when char_embedding is
                 set, otherwise ``(x, lens, labels)``. x, chars and labels are
                 numpy arrays, lens is a list of ints and char_lens is a list
                 of lists.
        :raises ValueError: if the iterator holds no rows at all
        """
        res = self._take_rows(n)

        indexed = [self.convert_oov(line, self.oov_index) for line in res['indexed'].tolist()]
        words_len = res['len'].tolist()
        labels = res['label'].tolist()
        self._truncate(indexed, words_len)

        # Ids are capped at oov_index, so oov_index + 1 cannot collide with a
        # real id; presumably sub_array uses it as the pad value.
        x = sub_array(indexed, batch_shape(indexed), self.oov_index + 1)
        if self.char_embedding:
            chars, char_lens = self._char_inputs(indexed)

        # Longest first; ties keep their original order (stable sort).
        order = sorted(range(len(words_len)), key=words_len.__getitem__, reverse=True)
        x_sorted = np.array([x[k] for k in order])
        lens = [words_len[k] for k in order]
        labels_sorted = np.array([labels[k] for k in order])
        if self.char_embedding:
            return (x_sorted, np.array([chars[k] for k in order]), lens,
                    [char_lens[k] for k in order], labels_sorted)
        return x_sorted, lens, labels_sorted

    def sort(self, x, lens, labels):
        """Unimplemented placeholder; always returns None.

        :param x: list
        :param lens: list
        :param labels: list
        """

    def set_oov_index(self, dic, min_cnt):
        """Count vocabulary entries whose 'cnt' exceeds ``min_cnt``.

        The count is used as the OOV index: convert_oov caps every word id at
        it. This presumably assumes ids are assigned by descending frequency.
        """
        return sum(1 for word in dic if dic[word]['cnt'] > min_cnt)

    def convert_oov(self, indices, oov_index):
        """Cap every id in ``indices`` at ``oov_index``.

        :param indices: list
        :param oov_index: int
        :return: list of capped ids
        """
        return np.minimum(np.array(indices), oov_index).tolist()

    def make_char_dic(self, dic, min_cnt):
        """Map characters seen more than ``min_cnt`` times to ids 0..k-1.

        Characters are counted over vocabulary words with a positive 'cnt'.
        Ids are assigned in descending order of character frequency.
        """
        counter = collections.Counter()
        for word in dic:
            if dic[word]['cnt'] > 0:
                counter.update(list(word))
        frequent = [char for char, cnt in counter.most_common() if cnt > min_cnt]
        return {char: idx for idx, char in enumerate(frequent)}

    def wordidx_to_charidx(self, wordidx):
        """Spell the word with id ``wordidx`` as character ids.

        Layout: ``[len(char_dict)+1, <char ids>, len(char_dict)+2]``. Characters
        missing from char_dict map to ``len(char_dict)``. The word text is read
        as ``self.inv[wordidx][0]``.
        """
        unknown = len(self.char_dict)
        word = self.inv[wordidx][0]
        spelled = [self.char_dict.get(character, unknown) for character in word]
        return [unknown + 1] + spelled + [unknown + 2]


class AllenDataIterator():
    """Length-bucketed iterator that emits word strings converted by ``batch_to_ids``.

    Rows are sorted by 'len' and cut into ``num_buckets`` contiguous buckets.
    Each ``next_batch`` call does the following:

    - serve ``n`` rows from a random unfinished bucket;
    - truncate documents longer than ``max_len``;
    - map ids back to word strings via ``self.inv`` and pass them to ``batch_to_ids``.

    When ``glove`` is set, it also returns padded, OOV-capped word ids.

    NOTE(review): the allennlp ``batch_to_ids`` import at the top of the file
    is commented out, so the name must come from ``util``. Confirm it does,
    otherwise next_batch raises NameError.
    """

    def __init__(self, df, dic, num_buckets=5, epochs=0, glove=False, oov_index=None,
                 min_cnt=5, max_len=250):
        self.dic = dic
        self.max_len = max_len
        # Sorting by length makes each contiguous bucket span a narrow length range.
        self.df = df.sort_values('len').reset_index(drop=True)
        self.inv = inverse_dictionary(dic)
        self.glove = glove
        # `is None` instead of `or`: an explicit oov_index of 0 must be honoured.
        self.oov_index = (oov_index if oov_index is not None
                          else self.set_oov_index(dic, min_cnt))
        self.num_buckets = num_buckets

        total = len(self.df)
        # Average bucket size, kept as a public attribute; slicing below uses integer bounds.
        self.size = total / num_buckets
        # Integer boundaries split *all* rows into near-equal buckets.
        bounds = [b * total // num_buckets for b in range(num_buckets + 1)]
        self.dfs = [self.df.iloc[bounds[b]:bounds[b + 1]] for b in range(num_buckets)]

        # cursor[i] / finished[i] track progress through the ith bucket.
        self.cursor = np.zeros(num_buckets, dtype=int)
        self.finished = np.zeros(num_buckets, dtype=int)

        self.shuffle()

        self.epochs = epochs

    def __len__(self):
        """Total number of rows across all buckets."""
        return len(self.df)

    def shuffle(self):
        """Shuffle rows inside every bucket and rewind all bucket cursors.

        Empty buckets are marked finished at once because they have nothing to serve.
        """
        for i in range(self.num_buckets):
            self.dfs[i] = self.dfs[i].sample(frac=1).reset_index(drop=True)
            self.cursor[i] = 0
            self.finished[i] = int(len(self.dfs[i]) == 0)

    def _take_rows(self, n):
        """Return the next n rows of a random unfinished bucket, starting a new epoch if needed."""
        if self.finished.all():
            self.epochs += 1
            self.shuffle()
            if self.finished.all():
                raise ValueError('AllenDataIterator has no rows to serve')
        i = np.random.randint(0, self.num_buckets)
        while self.finished[i] == 1:
            i = (i + 1) % self.num_buckets
        bucket = self.dfs[i]
        start = self.cursor[i]
        if start + n >= len(bucket):
            # Last batch of this bucket: slide the window back so it still holds n rows.
            start = max(len(bucket) - n, 0)
            self.finished[i] = 1
        self.cursor[i] = start + n
        return bucket.iloc[start:start + n]

    def _truncate(self, indexed, words_len):
        """In place: keep the head and tail halves of documents longer than max_len."""
        head = self.max_len // 2
        tail = self.max_len - head
        for idx, length in enumerate(words_len):
            if length > self.max_len:
                seq = indexed[idx]
                indexed[idx] = seq[:head] + seq[-tail:]
                words_len[idx] = self.max_len

    def next_batch(self, n):
        """Return the next batch of (up to) ``n`` documents, longest first.

        :param n: batch size
        :return: ``(ids, word_x, lens, labels)`` when glove is set, else
                 ``(ids, lens, labels)``. ids is the output of batch_to_ids on
                 the word strings, word_x and labels are numpy arrays, and lens
                 is a list of ints.
        :raises ValueError: if the iterator holds no rows at all
        """
        res = self._take_rows(n)

        indexed = res['indexed'].tolist()
        words_len = res['len'].tolist()
        self._truncate(indexed, words_len)
        if self.glove:
            converted = [self.convert_oov(line, self.oov_index) for line in indexed]
            # Ids are capped at oov_index, so oov_index + 1 cannot collide with a real id.
            word_x = sub_array(converted, batch_shape(converted), self.oov_index + 1)

        labels = res['label'].tolist()
        # Word strings are looked up from the uncapped ids.
        words = [[self.inv[word][0] for word in doc] for doc in indexed]

        # Longest first; ties keep their original order (stable sort).
        order = sorted(range(len(words_len)), key=words_len.__getitem__, reverse=True)
        x = batch_to_ids([words[k] for k in order])
        lens = [words_len[k] for k in order]
        labels_sorted = np.array([labels[k] for k in order])
        if self.glove:
            return x, np.array([word_x[k] for k in order]), lens, labels_sorted
        return x, lens, labels_sorted

    def sort(self, x, lens, labels):
        """Unimplemented placeholder; always returns None.

        :param x: list
        :param lens: list
        :param labels: list
        """

    def set_oov_index(self, dic, min_cnt):
        """Count vocabulary entries whose 'cnt' exceeds ``min_cnt`` (used as the OOV index)."""
        return sum(1 for word in dic if dic[word]['cnt'] > min_cnt)

    def convert_oov(self, indices, oov_index):
        """Cap every id in ``indices`` at ``oov_index``.

        :param indices: list
        :param oov_index: int
        :return: list of capped ids
        """
        return np.minimum(np.array(indices), oov_index).tolist()


def drop_df(df, drop_ratio):
    """Drop the shortest ``drop_ratio`` fraction of rows.

    Rows are ordered by their 'len' column in ascending order, with ties kept
    in their existing order. The first ``int(len(df) * drop_ratio)`` rows are
    then removed. The result gains a 'maxlen' column equal to 'len'.

    Previously 'maxlen' was the *global* maximum, identical on every row. That
    made the sort a no-op, so arbitrary rows were dropped instead of the
    shortest ones. The caller's DataFrame is no longer modified.

    :param df: DataFrame with a 'len' column
    :param drop_ratio: fraction in [0, 1] of rows to discard
    :return: the remaining rows, in ascending length order
    """
    ranked = (df.assign(maxlen=df['len'])
                .sort_values('maxlen', kind='mergesort')
                .reset_index(drop=True))
    drop_index = int(len(ranked) * drop_ratio)
    return ranked.iloc[drop_index:]


def validation_set(df, validation_ratio):
    """Randomly split ``df`` into disjoint train and validation frames.

    The validation frame gets ``int(len(df) * validation_ratio)`` rows and
    the train frame gets every other row. Previously ``[:validation_size-1]``
    dropped one row from both splits. With a size of 0 it became ``[:-1]``,
    which made validation overlap train.

    :param df: DataFrame to split
    :param validation_ratio: fraction in [0, 1] of rows for validation
    :return: ``(train_df, validation_df)``
    """
    validation_size = int(len(df) * validation_ratio)
    shuffled = df.sample(frac=1).reset_index(drop=True)
    validation_df = shuffled.iloc[:validation_size]
    train_df = shuffled.iloc[validation_size:]

    return train_df, validation_df


# class BucketedDataIterator():
#     def __init__(self, df,dic, num_buckets = 5,epochs=0,heirarchy=True,char_embedding=True, char_pad_num = 5):
#         self.df = df
#         self.dic = dic
#         self.maxlen = [max(list_unroll(i)) for i in self.df['len']]
#         self.df['maxlen'] = self.maxlen
#         self.df = df.sort_values('maxlen').reset_index(drop=True)
#         self.size = len(self.df) / num_buckets
#         self.dfs = []
#         self.heirarchy = heirarchy
#         self.char_embedding = char_embedding
#         self.char_dict = make_char_dic()
#         self.char_pad_num = char_pad_num
#         for bucket in range(num_buckets):
#             self.dfs.append(self.df.ix[bucket * self.size: (bucket + 1) * self.size - 1])
#         '''self.size = len(self.df) // num_buckets
#         self.remainder = len(self.df) % num_buckets
#         self.dfs = []
#         self.sizes = np.array([self.size] * num_buckets)
#         self.sizes[-1]+=self.remainder
#         stored = 0
#         for bucket in range(num_buckets):
#             self.dfs.append(self.df.ix[stored: stored + self.sizes[bucket]])
#             stored += self.sizes[bucket]'''
#         self.num_buckets = num_buckets
#
#         # cursor[i] will be the cursor for the ith bucket
#         self.cursor = np.array([0] * num_buckets)
#         self.finished = np.array([0] * num_buckets)
#
#         self.shuffle()
#
#         self.epochs = epochs
#
#     def __len__(self):
#         return len(self.df)
#
#     def shuffle(self):
#         #sorts dataframe by sequence length, but keeps it random within the same length
#         for i in range(self.num_buckets):
#             self.dfs[i] = self.dfs[i].sample(frac=1).reset_index(drop=True)
#             self.cursor[i] = 0
#             self.finished[i] = 0
#
#     def next_batch(self, n):
#         if np.sum(self.finished) == self.num_buckets:
#             self.epochs += 1
#             self.shuffle()
#         i = np.random.randint(0, self.num_buckets)
#         while self.finished[i] == 1:
#             i = (i + 1) % self.num_buckets
#         if self.cursor[i] + n - 1 >=self.size:
#             self.cursor[i] = self.size - n + 1
#             self.finished[i] = 1
#         res = self.dfs[i].ix[self.cursor[i]:self.cursor[i] + n - 1]
#         self.cursor[i] += n
#         '''if self.cursor[i] >= self.sizes[i]:
#             self.finished[i] = 1'''
#         '''if self.cursor[i] >= self.size:
#             self.finished[i] = 1'''
#
#         indexed = res['indexed'].tolist()
#         if self.char_embedding:
#             chars_indexed = res['chars_indexed'].tolist()
#             if not self.heirarchy:
#                 chars_temp = []
#                 for batch in chars_indexed:
#                     chars_temp.append(list_unroll_but_last(batch))
#                 chars_indexed = chars_temp
#             char_shape = batch_shape(chars_indexed)
#             char_shape[-1] += self.char_pad_num
#             chars = padded_sub_array(chars_indexed,char_shape,len(self.char_dict),self.char_pad_num)
#         words_len = res['len'].tolist()
#         if not self.heirarchy:
#             indexed_temp = []
#             lens = []
#             for batch in indexed:
#                 indexed_temp.append(list_unroll(batch))
#             for batch in words_len:
#                 lens.append(sum(batch))
#             indexed = indexed_temp
#         else:
#             lens = hierarchy_lengths(words_len)
#         try:
#             shape = batch_shape(indexed)
#         except:
#             print(indexed)
#         x = sub_array(indexed, shape,len(self.dic)).tolist()
#
#         temp = list(zip(x, lens, res['label'].tolist()))
#         temp.sort(key=lambda x:x[1],reverse=True)
#         x = []
#         lens = []
#         labels = []
#         for i in range(len(temp)):
#             x.append(temp[i][0])
#             lens.append(temp[i][1])
#             labels.append(temp[i][2])
#         if self.char_embedding:
#             return np.array(x), np.array(chars), lens, np.array(labels)
#         return np.array(x), lens, np.array(labels)
#
#     def sort(self, x,lens,labels):
#         """
#         :param x: list
#         :param lens: list
#         :param labels: list
#         :return:
#         """
#
#
#
# def drop_df(df,drop_ratio):
#     maxlen = [max(list_unroll(i)) for i in df['len']]
#     df['maxlen'] = maxlen
#     df = df.sort_values('maxlen').reset_index(drop=True)
#     drop_index = int(len(df) * drop_ratio)
#     df = df[drop_index:]
#     return df
#
#
# def validation_set(df,validation_ratio):
#     len_df = len(df)
#     validation_size = int(len_df * validation_ratio)
#     df = df.sample(frac=1).reset_index(drop=True)
#     validation_df = df[:validation_size-1]
#     train_df = df[validation_size:]
#
#     return train_df, validation_df


# class BucketedDataIterator():
#     def __init__(self, df,dic, num_buckets = 5,epochs=0,char_embedding=True,char_dic=None):
#         self.df = df
#         self.dic = dic
#         self.df = df.sort_values('len').reset_index(drop=True)
#         self.size = len(self.df) / num_buckets
#         self.dfs = []
#         self.char_embedding = char_embedding
#         self.char_dict = char_dic
#         for bucket in range(num_buckets):
#             self.dfs.append(self.df.ix[bucket * self.size: (bucket + 1) * self.size - 1])
#         '''self.size = len(self.df) // num_buckets
#         self.remainder = len(self.df) % num_buckets
#         self.dfs = []
#         self.sizes = np.array([self.size] * num_buckets)
#         self.sizes[-1]+=self.remainder
#         stored = 0
#         for bucket in range(num_buckets):
#             self.dfs.append(self.df.ix[stored: stored + self.sizes[bucket]])
#             stored += self.sizes[bucket]'''
#         self.num_buckets = num_buckets
#
#         # cursor[i] will be the cursor for the ith bucket
#         self.cursor = np.array([0] * num_buckets)
#         self.finished = np.array([0] * num_buckets)
#
#         self.shuffle()
#
#         self.epochs = epochs
#
#     def __len__(self):
#         return len(self.df)
#
#     def shuffle(self):
#         #sorts dataframe by sequence length, but keeps it random within the same length
#         for i in range(self.num_buckets):
#             self.dfs[i] = self.dfs[i].sample(frac=1).reset_index(drop=True)
#             self.cursor[i] = 0
#             self.finished[i] = 0
#
#     def next_batch(self, n):
#         if np.sum(self.finished) == self.num_buckets:
#             self.epochs += 1
#             self.shuffle()
#         i = np.random.randint(0, self.num_buckets)
#         while self.finished[i] == 1:
#             i = (i + 1) % self.num_buckets
#         if self.cursor[i] + n - 1 >=self.size:
#             self.cursor[i] = self.size - n + 1
#             self.finished[i] = 1
#         res = self.dfs[i].ix[self.cursor[i]:self.cursor[i] + n - 1]
#         self.cursor[i] += n
#         '''if self.cursor[i] >= self.sizes[i]:
#             self.finished[i] = 1'''
#         '''if self.cursor[i] >= self.size:
#             self.finished[i] = 1'''
#
#         indexed = res['indexed'].tolist()
#         words_len = res['len'].tolist()
#
#         if self.char_embedding:
#             word_indices = res['indexed_char'].tolist()
#             chars_indexed = [[pad_sow(self.char_dict[indices],68,69) for indices in text] for text in word_indices]
#             # chars_indexed = [[self.char_dict[indices] for indices in text] for text in word_indices]
#             char_shape = batch_shape(chars_indexed)
#             chars = padded_sub_array(chars_indexed,char_shape,70)
#             char_lens = res['char_len'].tolist()
#             char_lens_shape = batch_shape(char_lens)
#             char_lens = sub_array(char_lens,char_lens_shape,0).tolist()
#
#         try:
#             shape = batch_shape(indexed)
#         except:
#             print(indexed)
#         x = sub_array(indexed, shape,self.dic['PADDING']).tolist()
#
#         if self.char_embedding:
#             temp = list(zip(x, words_len, res['label'].tolist(), chars, char_lens))
#             chars = []
#             char_lens = []
#         else:
#             temp = list(zip(x, words_len, res['label'].tolist()))
#         temp.sort(key=lambda x:x[1],reverse=True)
#         x = []
#         lens = []
#         labels = []
#         for i in range(len(temp)):
#             x.append(temp[i][0])
#             lens.append(temp[i][1])
#             labels.append(temp[i][2])
#             if self.char_embedding:
#                 chars.append(temp[i][3])
#                 char_lens.append(temp[i][4])
#         if self.char_embedding:
#             return np.array(x), np.array(chars), lens, char_lens, np.array(labels)
#         return np.array(x), lens, np.array(labels)
#
#     def sort(self, x,lens,labels):
#         """
#         :param x: list
#         :param lens: list
#         :param labels: list
#         :return:
#         """

if __name__ == '__main__':
    # Manual smoke check: load the pickled test split and vocabulary via
    # util.load_file, then print one small bucketed batch.
    test_frame = load_file('test.pkl')
    vocab = load_file('dictionary.pkl')

    iterator = BucketedDataIterator(test_frame, vocab)
    batch_x, batch_chars, batch_lens, batch_char_lens, batch_labels = iterator.next_batch(3)
    print(batch_x)
    print(batch_chars)