import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
###################################################################################
'''
This script contains the shallow adaptation layer and the CNN encoder model.
Output from the shallow adaptation layer is passed as input to the CNN encoder.
'''

class modCNN(nn.Module):
    """Shallow adaptation layer followed by a multi-width 1-D CNN encoder.

    The adaptation layer (``conv1w``) is a single-channel ``Conv1d`` with
    kernel size 2 and stride 2 applied to ``wv_mat``; its output is viewed as
    a ``(vocab_size + 2, word_dim)`` embedding table.  Token indices are looked
    up in that table, the embeddings of each sentence are flattened into one
    long single-channel signal, and one ``Conv1d`` per filter width slides
    over it in steps of ``word_dim`` (i.e. one token at a time).  Each feature
    map is ReLU-activated and max-pooled over the whole sentence, the pooled
    features are concatenated, dropout is applied, and a linear layer produces
    the class scores.

    Keyword Args:
        model: Stored on the instance; not used by this class itself.
        batch_size: Stored on the instance; not used by this class itself.
        max_sent_len: Number of tokens per (padded) sentence.
        word_dim: Dimensionality of one adapted word vector.
        vocab_size: Vocabulary size; the embedding table has
            ``vocab_size + 2`` rows (presumably two extra rows for special
            tokens -- TODO confirm with the data pipeline).
        class_size: Number of output classes.
        filters: Filter widths, in tokens, one per convolution branch.
        filter_num: Number of output channels for each entry of ``filters``.
        dropout_prob: Dropout probability applied before the final layer.
        wv_mat: Input to the adaptation layer.  It must be a valid input for
            a 1-in/1-out channel ``Conv1d`` whose output has exactly
            ``(vocab_size + 2) * word_dim`` elements -- presumably shaped
            ``(vocab_size + 2, 1, 2 * word_dim)``; verify against the caller.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
        AssertionError: If ``filters`` and ``filter_num`` differ in length.
    """

    def __init__(self, **kwargs):
        super(modCNN, self).__init__()

        self.model = kwargs["model"]
        self.batch_size = kwargs["batch_size"]
        self.max_sent_len = kwargs["max_sent_len"]
        self.word_dim = kwargs["word_dim"]
        self.vocab_size = kwargs["vocab_size"]
        self.class_size = kwargs["class_size"]
        self.filters = kwargs["filters"]
        self.filter_num = kwargs["filter_num"]
        self.dropout_prob = kwargs["dropout_prob"]
        self.in_channel = 1

        wv_mat = kwargs["wv_mat"]
        if isinstance(wv_mat, torch.Tensor) and not isinstance(wv_mat, nn.Parameter):
            # Register plain tensors as a buffer so they follow the module
            # through .to()/.cuda()/.double() and DataParallel replication;
            # a bare attribute would stay behind and make forward() fail with
            # a device/dtype mismatch.  Non-persistent keeps the state_dict
            # keys unchanged, so previously saved checkpoints still load.
            self.register_buffer("wv_mat", wv_mat, persistent=False)
        else:
            # Parameters are registered by nn.Module automatically; anything
            # else is stored untouched, as before.
            self.wv_mat = wv_mat

        assert len(self.filters) == len(self.filter_num), (
            "filters and filter_num must have the same length"
        )

        # Shallow adaptation layer: combines each adjacent pair of values of
        # wv_mat into one value.  The weight is left at its default random
        # initialisation; setting it to [0.5, 0.5] would instead make the
        # layer start out as a plain average of each pair (plus bias).
        self.conv1w = torch.nn.Conv1d(
            self.in_channel, self.in_channel, kernel_size=2, stride=2, bias=True
        )
        # Already the default for a fresh parameter; kept to make explicit
        # that the adaptation weights are trained.
        self.conv1w.weight.requires_grad = True

        # One convolution per filter width.  A sentence is fed as one flat
        # signal of length word_dim * max_sent_len, so a kernel spanning
        # `width` tokens covers word_dim * width values and advances by
        # word_dim values (one token) per step.
        for i, (width, count) in enumerate(zip(self.filters, self.filter_num)):
            conv = torch.nn.Conv1d(
                self.in_channel, count, self.word_dim * width, stride=self.word_dim
            )
            # Attribute names conv_0, conv_1, ... are part of the state_dict
            # layout; do not rename.
            setattr(self, f'conv_{i}', conv)

        self.fc = torch.nn.Linear(sum(self.filter_num), self.class_size)

    def get_conv(self, i):
        """Return the convolution module registered for filter index ``i``."""
        return getattr(self, f'conv_{i}')

    def forward(self, inp):
        """Compute class scores for a batch of token-index sentences.

        Args:
            inp: Integer index tensor whose entries index rows of the adapted
                embedding table.  It is reshaped so that each sample holds
                ``max_sent_len`` tokens -- presumably shaped
                ``(batch, max_sent_len)``.

        Returns:
            Tensor of shape ``(batch, class_size)`` with the output of the
            final linear layer (no softmax applied).
        """
        # Re-run the adaptation layer on every call so that gradients reach
        # conv1w and the table always reflects its current weights.
        adapted = self.conv1w(self.wv_mat).view(self.vocab_size + 2, self.word_dim)

        x = F.embedding(inp, adapted)
        # Flatten each sentence into a single-channel 1-D signal.
        x = x.view(-1, 1, self.word_dim * self.max_sent_len)

        pooled = []
        for i, (width, count) in enumerate(zip(self.filters, self.filter_num)):
            feature_map = F.relu(self.get_conv(i)(x))
            # The conv yields max_sent_len - width + 1 positions; pooling over
            # all of them leaves one value per output channel.
            pooled.append(
                F.max_pool1d(feature_map, self.max_sent_len - width + 1).view(-1, count)
            )

        x = torch.cat(pooled, 1)
        x = F.dropout(x, p=self.dropout_prob, training=self.training)
        x = self.fc(x)

        return x