import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import dgl.function as fn
from dgl.nn.pytorch import GATConv

'''
	GAT block: a stack of graph attention (GATConv) layers.
	g: input graph (a batched graph)
	inputs: node features for g, shape (num_nodes, hin)
'''

class GAT(nn.Module):
	"""Stack of DGL ``GATConv`` layers applied to a (batched) graph.

	Layout: one input layer (input_size -> hidden_size per head, never residual),
	``num_layers - 1`` hidden layers, and a final output layer without activation.
	Between layers the per-head outputs are concatenated, so every layer after the
	first consumes ``hidden_size * heads`` features; the output layer's heads are
	averaged instead of concatenated.

	Args:
		num_layers: layers run before the output layer (input layer + hidden layers).
		input_size: feature size of the node features passed to forward().
		hidden_size: per-head output size of every layer.
		heads: attention heads per layer.
		activation: activation of every layer except the output layer.
		feat_drop, attn_drop, negative_slope, residual: forwarded to ``GATConv``.
	"""
	def __init__(self, num_layers, input_size, hidden_size, heads=1, activation=F.tanh, feat_drop=0, attn_drop=0.1, negative_slope=0.2, residual=False):
		super(GAT, self).__init__()
		self.num_layers = num_layers
		concat_size = hidden_size * heads  # width after concatenating all heads

		def make_layer(in_size, use_residual, act):
			return GATConv(in_size, hidden_size, heads, feat_drop, attn_drop, negative_slope, use_residual, act)

		# input projection (no residual), hidden layers, then the output projection
		stack = [make_layer(input_size, False, activation)]
		stack.extend(make_layer(concat_size, residual, activation) for _ in range(1, num_layers))
		stack.append(make_layer(concat_size, residual, None))
		self.gat_layers = nn.ModuleList(stack)

	def forward(self, g, inputs):
		"""Return node representations for graph ``g``.

		Assumes (per the head-concatenation above) that each ``GATConv`` returns
		(num_nodes, heads, hidden_size); the result is then (num_nodes, hidden_size).
		"""
		h = inputs
		# Only the first num_layers layers; the last one is the output projection.
		for _, layer in zip(range(self.num_layers), self.gat_layers):
			h = layer(g, h).flatten(1)  # merge the head axis into the feature axis
		return self.gat_layers[-1](g, h).mean(1)  # average the output heads

'''
	Encode an embedded code sequence with a bidirectional GRU and
	return the per-step encoder outputs and the final hidden state.

	x_embed: (b, l, embed), x_lengths: (b) --> x_outputs (b, l, 2h), x_hidden (layers, b, 2h)
	(x_hidden is (1, b, 2h) when the encoder has a single layer; l = max(x_lengths))
'''
class EncoderRNN(nn.Module) :
	"""Bidirectional GRU encoder for padded, already-embedded sequences.

	forward(x_embed, x_lengths) returns
		outputs: (b, max(x_lengths), 2 * hidden_size), zero-padded past each length;
		hidden:  (layers, b, 2 * hidden_size), the final forward and backward states
		         of every layer concatenated along the feature axis.
	"""
	def __init__(self, embed_size, hidden_size, layers, dropout) :
		super(EncoderRNN, self).__init__()
		self.gru = nn.GRU(
			embed_size,
			hidden_size,
			layers,
			batch_first=True,
			bidirectional=True,
			dropout=dropout,
		)

	def forward(self, x_embed, x_lengths) :
		# enforce_sorted=False lets callers pass batches in any length order.
		packed_in = pack_padded_sequence(x_embed, x_lengths, batch_first=True, enforce_sorted=False)
		packed_out, h_n = self.gru(packed_in)
		outputs = pad_packed_sequence(packed_out, batch_first=True)[0]  # pad with zero
		# h_n interleaves directions per layer: [l0 fwd, l0 bwd, l1 fwd, l1 bwd, ...]
		fwd, bwd = h_n[0::2], h_n[1::2]
		return outputs, torch.cat((fwd, bwd), dim=2)

'''
	Calculate the attention of the tgt input over a memory set.
	lengths holds the valid length of each memory sequence; it is used to build the mask.

	tgt : (b, lo, h1) memory: (b, lin, h2) lengths (b)

	attention score: h1 W h2
	memory positions at or beyond each sequence's length are masked out before the softmax
'''

class Attention(nn.Module) :
	"""Bilinear attention of a query sequence over a padded memory.

	score[b, i, j] = (W tgt[b, i]) . memory[b, j]; positions j >= lengths[b] are
	masked before the softmax.

	Args:
		size1: feature size h1 of ``tgt`` (and of the returned output).
		size2: feature size h2 of ``memory``.
	"""
	def __init__(self, size1, size2) :
		super(Attention, self).__init__()
		self.size1 = size1
		self.size2 = size2
		self.linear_w = nn.Linear(size1, size2)
		self.linear_out = nn.Linear(size1 + size2, size1)

	def forward(self, tgt, memory, lengths) :
		"""Attend from ``tgt`` to ``memory``.

		Args:
			tgt: (b, lo, h1) queries.
			memory: (b, lin, h2) padded memory.
			lengths: (b) number of valid memory positions per batch element
				(tensor or sequence of ints, on any device).

		Returns:
			output: (b, lo, h1) = linear_out([tgt; context]).
			weight: (b, lo, lin) attention distribution (zero on masked positions).
			context: (b, lo, h2) attention-weighted sum of memory.
		"""
		scores = torch.bmm(self.linear_w(tgt), memory.transpose(1, 2))  # (b, lo, lin)

		# Build the mask on the scores' device and sized to the real memory length,
		# so CPU runs work and memory padded beyond max(lengths) is handled.
		mask = self.gen_mask(lengths, maxlen=memory.size(1), device=scores.device)
		scores = scores.masked_fill(~mask, -1e15)

		weight = F.softmax(scores, dim=2)
		context = torch.bmm(weight, memory)  # (b, lo, lin) x (b, lin, h2) --> (b, lo, h2)

		output = self.linear_out(torch.cat([tgt, context], dim=2))  # (b, lo, h1)
		return output, weight, context

	def gen_mask(self, lengths, maxlen=None, device=None) :
		"""Return a (b, 1, maxlen) bool mask that is True where position < lengths[b].

		Args:
			lengths: (b) valid lengths (tensor or sequence of ints).
			maxlen: mask width; defaults to max(lengths).
			device: device of the returned mask; defaults to 'cuda', which is what
				this method always returned before the argument existed.
		"""
		lengths = torch.as_tensor(lengths)
		if maxlen is None:
			maxlen = int(lengths.max())
		if device is None:
			device = 'cuda'
		lengths = lengths.to(device)
		positions = torch.arange(maxlen, device=device)
		mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (b, maxlen)
		return mask.unsqueeze(1)

'''
	Decoder: attend from the decoder GRU outputs to the encoder outputs (and to the
	graph node states), then use the encoder attention weights to compute the
	pointer (copy) distribution.

	y_embed : (b, lo, embed) x_outputs (b, l, 2h) last hidden (1, b, 2h)
	x_idx (b, l) : index sequence of the input code
'''

class DecoderRNN(nn.Module) :
	"""GRU decoder with attention over code and graph states plus a copy (pointer) mechanism.

	Args:
		en_vocab: size of the input (copy) vocabulary; returned distributions have this size.
		de_vocab: size of the generation vocabulary; the generation distribution is
			zero-padded up to en_vocab entries (presumably de_vocab <= en_vocab).
		embed_size: size of the target embeddings fed to the GRU.
		hidden_size: GRU hidden size, also the feature size of the attended memories.
		layers: NOTE(review): unused -- the decoder GRU is always single-layer.
		dropout: passed to nn.GRU; it has no effect on a single-layer GRU.
	"""
	def __init__(self, en_vocab, de_vocab, embed_size, hidden_size, layers, dropout) :
		super(DecoderRNN, self).__init__()
		self.en_vocab = en_vocab
		self.de_vocab = de_vocab
		self.embed_size = embed_size
		self.hidden_size = hidden_size
		self.gru = nn.GRU(embed_size, hidden_size, batch_first=True, dropout=dropout)

		self.attn = Attention(hidden_size, hidden_size)
		self.pgen_c = nn.Linear(hidden_size, 1)
		self.pgen_s = nn.Linear(hidden_size, 1)
		self.pgen_y = nn.Linear(embed_size, 1)

		self.g_attn = Attention(hidden_size, hidden_size)

		self.linear_out = nn.Linear(hidden_size*2, de_vocab)

	def forward(self, y_embed, x_outputs, last_hidden, x_idx, lengths, split_gat_hidden, g_lenghts) :
		"""Decode ``y_embed`` and return the generate/copy mixture distribution.

		Args:
			y_embed: (b, lo, embed) embedded decoder inputs.
			x_outputs: (b, l, hidden) encoder outputs.
			last_hidden: initial GRU hidden state, (1, b, hidden).
			x_idx: (b, l) token ids of the input code (copy source).
			lengths: (b) valid lengths of x_outputs.
			split_gat_hidden: (b, g, hidden) zero-padded graph node states.
			g_lenghts: (b) number of nodes per graph.

		Returns:
			dec_res: (b, lo, en_vocab) = pgen * generate + (1 - pgen) * copy,
				smoothed by 1e-15 so a later log never sees an exact zero.
			y_hidden: final GRU hidden state.
		"""
		y_outputs, y_hidden = self.gru(y_embed, last_hidden)
		code_res, attn_weight, context = self.attn(y_outputs, x_outputs, lengths)
		graph_res, _, _ = self.g_attn(y_outputs, split_gat_hidden, g_lenghts)  # (b, lo, hidden)

		# Generation distribution over de_vocab, zero-padded to the en_vocab size.
		gen_logits = self.linear_out(torch.cat([code_res, graph_res], dim=2))
		gen_dist = F.softmax(gen_logits, dim=2)
		gen_dist = F.pad(gen_dist, (0, self.en_vocab - self.de_vocab))

		copy_dist = self.pointer(attn_weight, x_idx)

		pgen = torch.sigmoid(self.pgen_c(context) + self.pgen_s(y_outputs) + self.pgen_y(y_embed))

		dec_res = pgen * gen_dist + (1 - pgen) * copy_dist
		delta = 1e-15
		# Same as (1 - delta) * dec_res + delta * ones; stays on dec_res's device.
		dec_res = (1 - delta) * dec_res + delta
		return dec_res, y_hidden

	def pointer(self, attn, idx) :
		"""Scatter attention weights onto the vocabulary (pointer / copy distribution).

		Args:
			attn: (b, lo, lin) attention distribution over input positions.
			idx: (b, lin) int64 token ids of the input positions.

		Returns:
			(b, lo, en_vocab): for every token id, the summed attention of all input
			positions holding that id.  Because each attention row sums to 1, each
			output row already sums to 1, so no softmax is applied (a softmax here
			would flatten the copy distribution towards uniform).
		"""
		onehot_idx = F.one_hot(idx, num_classes=self.en_vocab).to(dtype=attn.dtype, device=attn.device)
		return torch.bmm(attn, onehot_idx)

'''
	Graph2Seq: input code sequence --> embedding layer --> encoder (combined with a GAT
	over the encoded graph nodes) --> decoder

'''

class Graph2Seq(nn.Module) :
	"""Seq2seq model whose encoder state is enriched with a GAT over a code graph.

	A shared embedding + BiGRU encoder encodes the input code ``x`` and, separately,
	the token sequence of every graph node (``graph_x``).  The final node encodings
	are refined by a GAT.  The center node's representation is fused with the code
	encoding to initialise the decoder, and all node states are exposed to the
	decoder's graph attention.

	Args:
		en_vocab, de_vocab: input / output vocabulary sizes.
		embed_size: token embedding size (shared by encoder and decoder inputs).
		hidden_size: per-direction encoder size; decoder and GAT use 2 * hidden_size.
		layers: encoder GRU layers.  NOTE(review): forward() assumes 1 (see there).
		dropout: GRU dropout.
		padding_idx: embedding padding index.
		gat_layers: GAT layers run before its output layer (previously hard-coded to 5).
	"""
	def __init__(self, en_vocab, de_vocab, embed_size, hidden_size, layers, dropout, padding_idx, gat_layers=5) :
		super(Graph2Seq, self).__init__()
		self.en_vocab = en_vocab
		self.de_vocab = de_vocab
		self.embedding = nn.Embedding(en_vocab, embed_size, padding_idx=padding_idx)
		self.encoder = EncoderRNN(embed_size, hidden_size, layers, dropout)
		self.decoder = DecoderRNN(en_vocab, de_vocab, embed_size, hidden_size * 2, layers, dropout)
		self.gat = GAT(gat_layers, hidden_size*2, hidden_size*2)

		self.reg_size = nn.Linear(hidden_size*4, hidden_size*2)

	def forward(self, x, x_len, y, g, graph_x, graph_x_len, center, mode) :
		"""Run the model.

		With ``mode == 'train'`` the decoder is teacher-forced on ``y[:, :-1]`` and its
		(b, y.size(1) - 1, en_vocab) distributions are returned.  Any other mode decodes
		greedily for ``y.size(1) - 1`` steps starting from ``y[:, 0]`` and returns the
		per-step distributions with the same shape.
		"""
		x_embed = self.embedding(x)
		x_outputs, x_hidden = self.encoder(x_embed, x_len)  # (b, l, 2h) (layers, b, 2h)

		graph_x_embed = self.embedding(graph_x)
		_, graph_x_hidden = self.encoder(graph_x_embed, graph_x_len)  # (layers, bg, 2h)
		# NOTE(review): squeeze(0) here and the concat below assume a 1-layer encoder.
		gat_outputs = self.gat(g, graph_x_hidden.squeeze(0))  # (bg, 2h)
		gat_hidden = gat_outputs[center]  # (b, 2h)

		num_nodes = self._num_nodes_per_graph(g)
		split_gat_hidden = self.list2batch(gat_outputs, num_nodes)  # (bg, 2h) --> (b, g, 2h) padded
		g_lenghts = torch.tensor(num_nodes, dtype=torch.int64)

		hidden = torch.cat([x_hidden, gat_hidden.unsqueeze(0)], dim=2)  # (1, b, 4h)
		hidden = self.reg_size(hidden)  # (1, b, 2h)

		if mode == 'train' :
			y_embed = self.embedding(y[:, :-1])
			dec_res, _ = self.decoder(y_embed, x_outputs, hidden, x, x_len, split_gat_hidden, g_lenghts)
			return dec_res

		# Greedy decoding: feed the argmax token back in at every step.
		batch_size, maxlen = y.size()
		y_input = y[:, 0].unsqueeze(1)
		# Allocate on the encoder's device rather than hard-coding .cuda().
		dec_list = x_outputs.new_zeros((maxlen - 1, batch_size, self.en_vocab))
		for step in range(maxlen - 1) :
			y_embed = self.embedding(y_input)
			dec_res, hidden = self.decoder(y_embed, x_outputs, hidden, x, x_len, split_gat_hidden, g_lenghts)
			y_input = torch.argmax(dec_res, dim=2)  # (b, 1), already on dec_res's device
			dec_list[step] = dec_res.squeeze(1)
		return dec_list.transpose(0, 1)

	@staticmethod
	def _num_nodes_per_graph(g) :
		"""Return the node count of every graph in the batched graph ``g`` as a list of ints."""
		counts = g.batch_num_nodes
		# Older DGL exposes a list attribute; newer releases expose a method
		# returning a tensor -- accept both.
		if callable(counts) :
			counts = counts()
		return [int(n) for n in counts]

	def list2batch(self, outputs, lengths) :
		"""Split stacked node states into a zero-padded batch.

		Args:
			outputs: (sum(lengths), d) node states of all graphs, stacked graph by graph.
			lengths: node count per graph (list of ints or 1-D tensor).

		Returns:
			(len(lengths), max(lengths), d) tensor on outputs' device and dtype,
			zero beyond each graph's node count.
		"""
		sizes = [int(n) for n in lengths]
		pieces = torch.split(outputs, sizes, dim=0)  # tuple([g, d]...)
		maxlen = max(sizes)
		padded = outputs.new_zeros((len(sizes), maxlen, outputs.size(1)))
		for row, piece in enumerate(pieces) :
			padded[row, :piece.size(0)] = piece
		return padded

			
