import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.nn.init as init
import pickle


# Small positive constant; not referenced by name anywhere in the visible file.
SMALL = 1e-08

# Module-level counter; no live code in the visible file reads or updates it.
global_num = 0

class VMASK(nn.Module):
	"""Learn a sentence-level keep/drop mask and apply it to word embeddings.

	The sequence axis is treated as ``NUM_SENTENCES`` consecutive sentences of
	``WORDS_PER_SENTENCE`` word positions each. Per-word two-class scores
	(index 1 is used as the "keep" class) are reduced to one score pair per
	sentence, turned into one keep weight per sentence, and that weight is
	broadcast back onto every word position of the sentence.

	Args:
		args: configuration object providing ``device``, ``mask_hidden_dim``,
			``activation`` (one of ``'tanh'``, ``'sigmoid'``, ``'relu'``,
			``'leaky_relu'``), ``embed_dim`` and ``lstm_hidden_layer``.

	Raises:
		KeyError: if ``args.activation`` is not one of the supported names.
	"""

	# Number of sentence slots the sequence axis is divided into.
	NUM_SENTENCES = 400
	# Number of consecutive word positions covered by one sentence slot.
	WORDS_PER_SENTENCE = 25

	def __init__(self, args):
		super(VMASK, self).__init__()

		self.device = args.device
		self.mask_hidden_dim = args.mask_hidden_dim
		self.activations = {'tanh': torch.tanh, 'sigmoid': torch.sigmoid, 'relu': torch.relu, 'leaky_relu': F.leaky_relu}
		self.activation = self.activations[args.activation]
		self.embed_dim = args.embed_dim

		# Sequence length the sentence mapping is built for (400 * 25 = 10000 by default).
		seq_len = self.NUM_SENTENCES * self.WORDS_PER_SENTENCE

		# LSTM over the per-word two-class scores.
		self.lstm = nn.LSTM(2, 2, num_layers=args.lstm_hidden_layer)
		# LSTM over the per-sentence scores (used by the sampling modes only).
		self.lstm_sentence = nn.LSTM(self.NUM_SENTENCES, self.NUM_SENTENCES, num_layers=args.lstm_hidden_layer)

		self.linear_layer = nn.Linear(self.embed_dim, self.mask_hidden_dim)
		self.hidden2p = nn.Linear(self.mask_hidden_dim, 2)

		# Maps the word axis onto the sentence axis. The attribute name (typo
		# included) is part of the state_dict keys, so it must not be renamed.
		self.sentecen_mappling = nn.Linear(seq_len, self.NUM_SENTENCES)

	def forward_sent_batch(self, embeds):
		"""Return two-class mask scores for every word embedding.

		Args:
			embeds: tensor whose last dimension is ``embed_dim``
				(used as seqlen, bsz, embed_dim by MASK_LSTM).

		Returns:
			Tensor with the same leading dimensions and a last dimension of 2.
		"""
		temps = self.activation(self.linear_layer(embeds))
		p = self.hidden2p(temps)  # seqlen, bsz, 2
		return p

	def _sentence_keep_weights(self, p, sample):
		"""Reduce per-word scores ``p`` to one keep weight per sentence.

		Args:
			p: per-word scores, expected as (seqlen, bsz, 2).
			sample: if True, run the sentence LSTM and draw a hard
				Gumbel-softmax sample; otherwise return the softmax
				probability of the keep class.

		Returns:
			Tensor of shape (NUM_SENTENCES, bsz, 1).
		"""
		lstm_out_p, _ = self.lstm(p)
		per_class = lstm_out_p.permute(1, 2, 0)  # bsz, 2, seqlen
		sentence_scores = self.sentecen_mappling(per_class)  # bsz, 2, NUM_SENTENCES

		if sample:
			# NOTE(review): lstm_sentence is not batch_first, so it iterates over
			# dim 0 of its input, which here is the batch axis - confirm intended.
			sentence_scores, _ = self.lstm_sentence(sentence_scores)
			logits = sentence_scores.permute(2, 0, 1)  # NUM_SENTENCES, bsz, 2
			weights = F.gumbel_softmax(logits, hard=True, dim=2)
		else:
			# NOTE(review): unlike the sampling branch, this path does not go
			# through lstm_sentence - confirm the train/eval difference is intended.
			logits = sentence_scores.permute(2, 0, 1)  # NUM_SENTENCES, bsz, 2
			weights = F.softmax(logits, dim=2)

		# Index 1 of the class axis is the "keep" class.
		return weights[:, :, 1:2]

	def _expand_to_words(self, sentence_keep):
		"""Repeat each sentence weight over the word positions of that sentence.

		Equivalent to multiplying by a (NUM_SENTENCES x seqlen) block matrix of
		ones, without materialising it. The previous loop-built matrix had its
		two ranges swapped, so only the first 16 sentences / 400 word positions
		ever received a weight and every later position was multiplied by zero.

		Args:
			sentence_keep: tensor of shape (NUM_SENTENCES, bsz, 1).

		Returns:
			Tensor of shape (NUM_SENTENCES * WORDS_PER_SENTENCE, bsz, 1).
		"""
		return torch.repeat_interleave(sentence_keep, self.WORDS_PER_SENTENCE, dim=0)

	def forward(self, x, p, flag):
		"""Apply the sentence mask derived from ``p`` to the embeddings ``x``.

		Args:
			x: word embeddings, expected as (seqlen, bsz, embed_dim) with
				seqlen == NUM_SENTENCES * WORDS_PER_SENTENCE.
			p: per-word scores as returned by ``get_statistics_batch``.
			flag: selects the masking mode:
				``'train'`` - sample a hard mask and keep the selected sentences;
				``'train_mask_important_words'`` - sample a hard mask and keep
				the complement (the selected sentences are zeroed);
				``'train_unmask'`` - return ``x`` unchanged;
				any other value - scale by the soft keep probability.

		Returns:
			Tensor with the same shape as ``x``.
		"""
		if flag == 'train_unmask':
			return x

		sampling = flag in ('train', 'train_mask_important_words')
		sentence_keep = self._sentence_keep_weights(p, sampling)
		word_keep = self._expand_to_words(sentence_keep)  # seqlen, bsz, 1

		if flag == 'train_mask_important_words':
			word_keep = 1 - word_keep

		return word_keep * x

	def get_statistics_batch(self, embeds):
		"""Return the per-word mask scores; thin wrapper around ``forward_sent_batch``."""
		return self.forward_sent_batch(embeds)


class LSTM(nn.Module):
	"""LSTM classifier with max-over-time pooling over the hidden states.

	Args:
		args: configuration object providing ``embed_num``, ``embed_dim``,
			``mode``, ``lstm_hidden_dim``, ``lstm_hidden_layer``, ``class_num``
			and ``dropout``.
		vectors: numpy array of pretrained embeddings copied into the
			embedding table.
	"""

	def __init__(self, args, vectors):
		super().__init__()
		self.args = args

		# Embedding table seeded from the pretrained vectors; index 1 is the padding row.
		self.embed = nn.Embedding(args.embed_num, args.embed_dim, padding_idx=1)
		pretrained = torch.from_numpy(vectors)
		self.embed.weight.data.copy_(pretrained)

		# The embeddings stay frozen only in 'static' mode.
		self.embed.weight.requires_grad = args.mode != 'static'

		table = self.embed.weight.data
		# Row 0 (<unk>) gets a small random vector, row 1 (<pad>) is zeroed.
		nn.init.uniform_(table[0], -0.05, 0.05)
		nn.init.constant_(table[1], 0)

		self.lstm = nn.LSTM(args.embed_dim, args.lstm_hidden_dim, num_layers=args.lstm_hidden_layer)
		# Re-initialise the input-hidden and hidden-hidden weights of the first layer.
		weight_ih, weight_hh = self.lstm.all_weights[0][:2]
		gain = np.sqrt(6.0)
		init.xavier_normal_(weight_ih, gain=gain)
		init.xavier_normal_(weight_hh, gain=gain)

		self.hidden2label = nn.Linear(args.lstm_hidden_dim, args.class_num)

		self.dropout = nn.Dropout(args.dropout)
		self.dropout_embed = nn.Dropout(args.dropout)

	def forward(self, x):
		"""Classify a batch of embedded sequences.

		Args:
			x: input to the LSTM, laid out as (seqlen, bsz, embed_dim).

		Returns:
			Softmax class probabilities of shape (bsz, class_num).
		"""
		hidden_states, _ = self.lstm(x)
		# (seqlen, bsz, hidden) -> (bsz, hidden, seqlen) so pooling runs over time.
		hidden_states = hidden_states.permute(1, 2, 0)

		squashed = torch.tanh(hidden_states)
		pooled = F.max_pool1d(squashed, squashed.size(2)).squeeze(2)
		pooled = torch.tanh(pooled)
		pooled = F.dropout(pooled, p=self.args.dropout, training=self.training)

		return F.softmax(self.hidden2label(pooled), 1)


class MASK_LSTM(nn.Module):
	"""LSTM classifier whose word embeddings are first passed through a VMASK.

	Args:
		args: configuration object; ``embed_dim``, ``device``, ``max_sent_len``
			and ``dropout`` are read here, and it is forwarded to both
			sub-modules.
		vectors: pretrained embedding matrix forwarded to the LSTM sub-module.
	"""

	def __init__(self, args, vectors):
		super().__init__()
		self.args = args
		self.embed_dim = args.embed_dim
		self.device = args.device
		self.max_sent_len = args.max_sent_len

		self.vmask = VMASK(args)
		self.lstmmodel = LSTM(args, vectors)

	def forward(self, batch, flag):
		"""Embed, mask and classify one batch.

		As a side effect ``self.infor_loss`` is set to the mean, over all
		positions, of ``p1 * log(p1) + p0 * log(p0)`` for the two-class softmax
		of the mask scores (i.e. the negative entropy of the mask distribution).

		Args:
			batch: object whose ``text`` attribute holds the token ids.
			flag: masking mode, passed through to the VMASK module.

		Returns:
			The output of the LSTM classifier on the masked embeddings.
		"""
		token_ids = batch.text.t()
		embedded = self.lstmmodel.embed(token_ids)
		embedded = F.dropout(embedded, p=self.args.dropout, training=self.training)
		# seqlen, bsz, embed-dim
		x = embedded.view(len(token_ids), embedded.size(1), -1)

		# Mask scores, masked embeddings, then classification.
		p = self.vmask.get_statistics_batch(x)
		masked = self.vmask(x, p, flag)
		output = self.lstmmodel(masked)

		def entropy_term(class_index):
			# prob * log(prob) for one class of the mask distribution
			prob = F.softmax(p, dim=2)[:, :, class_index]
			return prob * torch.log(prob + 1e-8)

		self.infor_loss = torch.mean(entropy_term(1) + entropy_term(0))

		return output
