import numpy as np
import torch
import os
import re
import json
import collections
import random
from collections import Counter
import scipy.sparse as sp
import dgl

class DataLoader(object) :
	"""Batch iterator over a JSON dataset of source-code / comment records.

	The file at ``data_path`` is a JSON list of dicts. This class reads the
	keys "body", "all_body", "all_name", "comment", "body_len", "src", "tgt"
	and "center" from each record. Text is tokenised by whitespace. Token ids
	come from a vocabulary that is loaded from ``vocab_path``. If that file
	does not exist, the vocabulary is built from the dataset and written
	there, but only in 'train' mode.
	"""

	# Reserved tokens; a newly built vocabulary assigns them ids 0..3 in this order.
	SPECIAL_TOKENS = ("<PAD>", "<UNK>", "<S>", "</S>")

	# Record type yielded by gen_batch_data (built once instead of per call).
	Batch = collections.namedtuple('batch', ['src', 'src_len', 'tgt', 'graph', 'graph_x', 'graph_x_len', 'center'])

	def __init__(self, data_path, vocab_path, batch_size, truncate, mode='train') :
		"""Load the vocabulary and the dataset.

		Args:
			data_path: path of the JSON dataset.
			vocab_path: path of the JSON vocabulary (read if present, else built in 'train' mode).
			batch_size: number of records per yielded batch.
			truncate: maximum number of token ids kept per sequence in get_idx.
			mode: 'train' enables vocabulary building and per-call shuffling.
		"""
		self.batch_size = batch_size
		self.mode = mode
		self.truncate = truncate
		self.word2idx, self.idx2word, self.en_vocab, self.de_vocab = self.load_vocab(vocab_path, data_path)
		self.dataset = self.load_dataset(data_path)

	def load_vocab(self, vocab_path, data_path) :
		"""Load the vocabulary from ``vocab_path``, or build it in 'train' mode.

		Returns:
			(word2idx, idx2word, encoder_vocab_size, decoder_vocab_size).
			``idx2word`` always has int keys, whether loaded or freshly built.

		Raises:
			FileNotFoundError: ``vocab_path`` does not exist and mode is not 'train'.
		"""
		if os.path.exists(vocab_path) :
			with open(vocab_path, 'r', encoding='utf-8') as f :
				v = json.load(f)
			# JSON object keys are always strings; restore the int ids so lookups
			# behave the same as for a freshly built vocabulary.
			idx2word = {int(idx): word for idx, word in v["idx2word"].items()}
			return v["word2idx"], idx2word, v["en_vocab"], v["de_vocab"]
		if self.mode != 'train' :
			raise FileNotFoundError(
				"vocabulary file {} not found; it is only built automatically in 'train' mode".format(vocab_path))
		return self._build_vocab(vocab_path, data_path)

	def _build_vocab(self, vocab_path, data_path) :
		"""Build a vocabulary from the dataset at ``data_path`` and save it to ``vocab_path``.

		Ids are dense and assigned in this order:
		1. the special tokens (0..3);
		2. comment tokens, in order of first appearance (the decoder vocabulary);
		3. any remaining "all_body"/"all_name" tokens (extending it to the encoder vocabulary).
		"""
		with open(data_path, encoding='utf-8') as f :
			dataset = json.load(f)
		src, tgt = [], []
		for data in dataset :
			src += data["all_body"]
			src += data["all_name"]
			tgt.append(data["comment"])

		word2idx, idx2word = {}, {}

		def add(word) :
			# Skip known words (including special tokens that appear literally in
			# the data) so each id is used exactly once and both maps stay in sync.
			if word not in word2idx :
				idx = len(word2idx)
				word2idx[word] = idx
				idx2word[idx] = word

		for token in self.SPECIAL_TOKENS :
			add(token)
		for word in ' '.join(tgt).split() :
			add(word)
		# Decoder ids are exactly [0, decoder_vocab_size).
		decoder_vocab_size = len(word2idx)
		for word in ' '.join(src).split() :
			add(word)
		encoder_vocab_size = len(word2idx)

		vocab = {
			"en_vocab": encoder_vocab_size,
			"de_vocab": decoder_vocab_size,
			"word2idx": word2idx,
			"idx2word": idx2word,
		}
		with open(vocab_path, "w", encoding='utf-8') as f :
			json.dump(vocab, f, indent=4)

		return word2idx, idx2word, encoder_vocab_size, decoder_vocab_size

	def load_dataset(self, data_path) :
		"""Return the parsed JSON content of ``data_path``."""
		with open(data_path, encoding='utf-8') as f :
			return json.load(f)

	def gen_batch_data(self) :
		"""Yield ``Batch`` tuples covering up to ``batch_size`` consecutive records each.

		Fields of each batch:
			src, src_len: padded ids / lengths of each record's "body".
			tgt: padded ids of "<S> " + comment + " </S>".
			graph: ``dgl.batch`` of one graph per record.
			graph_x, graph_x_len: padded ids / lengths of every "all_body" entry
				in the batch, concatenated in record order.
			center: each record's "center", shifted by the number of
				"all_body" rows contributed by earlier records in the batch.
		"""
		# NOTE: this reseeds the process-wide ``random`` module on every call.
		# The shuffle is in place, so successive calls still produce different
		# (but reproducible) orders.
		random.seed(1997)
		if self.mode == 'train' :
			random.shuffle(self.dataset)

		for start in range(0, len(self.dataset), self.batch_size) :
			src, tgt, all_body = [], [], []
			graph_list, center = [], []
			for data in self.dataset[start: start + self.batch_size] :
				src.append(data["body"])
				tgt.append("<S> " + data["comment"] + " </S>")

				# One node per entry of "body_len"; edges "src"[0][i] -> "tgt"[0][i].
				g = dgl.DGLGraph()
				g.add_nodes(len(data["body_len"]))
				g.add_edges(data["src"][0], data["tgt"][0])
				graph_list.append(g)

				# Must count length before adding current all_body to the array,
				# so "center" indexes into the batch-wide graph_x rows.
				center.append(data["center"] + len(all_body))
				all_body += data["all_body"]

			x, x_len = self.get_idx(src, True)
			y, _ = self.get_idx(tgt, False)
			graph_x, graph_x_len = self.get_idx(all_body, True)

			yield self.Batch(src=x, src_len=x_len, tgt=y, graph=dgl.batch(graph_list),
					graph_x=graph_x, graph_x_len=graph_x_len, center=center)

	def get_idx(self, inputs, is_src) :
		"""Convert whitespace-tokenised strings into a padded id matrix.

		Unknown words map to "<UNK>". Each sequence is cut to ``self.truncate``
		ids and right-padded with the "<PAD>" id.

		Args:
			inputs: list of strings.
			is_src: unused; source and target are encoded identically. It is
				kept for call-site compatibility.

		Returns:
			(pad_x, len_x): a float64 array of shape (len(inputs), max length)
			and the list of unpadded (post-truncation) lengths. Empty
			``inputs`` yields a (0, 0) array and an empty list.
		"""
		unk_idx = self.word2idx["<UNK>"]
		padding_idx = self.word2idx["<PAD>"]

		unpad_x, len_x = [], []
		for s in inputs :
			ids = [self.word2idx.get(word, unk_idx) for word in s.split()][:self.truncate]
			unpad_x.append(ids)
			len_x.append(len(ids))

		maxlen = max(len_x, default=0)
		# float64 matches the dtype this method has always returned.
		pad_x = np.full([len(inputs), maxlen], padding_idx, dtype=float)
		for row, ids in enumerate(unpad_x) :
			pad_x[row, :len(ids)] = ids
		return pad_x, len_x


# def in_place(new_data, max_graph_node, max_in_len, max_out_len):
# 	for key in ['body', 'name']:
# 		ans_l = []
# 		for data in new_data['all_{}'.format(key)]:
# 			d_ind = []
# 			for tgt_w in new_data['comment'].split()[1:]:
# 				w_ind = []
# 				for src_w in data.split():
# 					if src_w == tgt_w:
# 						w_ind.append(1)
# 					else:
# 						w_ind.append(0)
# 				w_ind = w_ind + [0] * (max_in_len - len(w_ind)) if len(w_ind) < max_in_len else w_ind[:max_in_len]
# 				d_ind.append(w_ind)
# 			d_ind = d_ind + [[0] * max_in_len] * (max_out_len - len(d_ind)) if len(d_ind) < max_out_len else d_ind[:max_out_len]
# 			ans_l.append(d_ind)
# 		ans_l = ans_l + [[[0] * max_in_len] * max_out_len] * (max_graph_node - len(ans_l)) if len(ans_l) < max_graph_node else ans_l[:max_graph_node]
# 		new_data['{}_ind'.format(key)] = ans_l

# 	return new_data["body_ind"], new_data["name_ind"]


if __name__ == '__main__':
	# Smoke test: one pass over the training split, printing each batch's center offsets.
	# `truncate` is a required argument (omitting it raised TypeError); 100 is an
	# arbitrary demo value -- adjust to the model's maximum sequence length.
	test_Dataset = DataLoader("./simple/train.json", "./simple/vocab.json", batch_size=5, truncate=100, mode='train')
	for batch in test_Dataset.gen_batch_data() :
		print("generate batch {}".format(batch.center))







