from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class IntNet(nn.Module):
    """Character-level CNN feature extractor with densely concatenated blocks.

    Character ids are embedded, passed through two parallel convolutions
    (kernel sizes 3 and 5), and then through a sequence of blocks. Each block
    applies a kernel-size-1 convolution down to ``hidden_size`` channels,
    followed by parallel kernel-3 and kernel-5 convolutions whose outputs are
    concatenated with the block's input along the channel axis.

    The following attributes are read from ``config``:
        char_intNet_embedding_dim: character embedding size.
        char_intNet_hidden_dim: channel count inside each block.
        char2idx: character vocabulary (only its length is used).
        intNet_cnn_layer: controls the number of blocks,
            ``int((intNet_cnn_layer - 1) / 2)``.
        intNet_kernel_type: multiplier used for channel bookkeeping.
        dropout: dropout probability applied to the character embeddings.

    After construction ``self.last_dim`` holds the channel count of the
    extracted features.
    """

    def __init__(self, config):
        super(IntNet, self).__init__()
        print("[Info] build char sequence feature extractor: IntNet ...")
        self.embedding_dim = config.char_intNet_embedding_dim
        self.hidden_size = config.char_intNet_hidden_dim
        self.alphabet_size = len(config.char2idx)
        self.cnn_layer = config.intNet_cnn_layer
        # NOTE(review): the forward path always concatenates exactly two
        # branches (kernel 3 and kernel 5), so the channel bookkeeping below
        # is only consistent when intNet_kernel_type == 2 -- confirm callers.
        self.kernel_type = config.intNet_kernel_type
        self.char_drop = nn.Dropout(config.dropout)
        self.char_embeddings = nn.Embedding(self.alphabet_size, self.embedding_dim)
        initial_weights = self.random_embedding(self.alphabet_size, self.embedding_dim)
        self.char_embeddings.weight.data.copy_(torch.from_numpy(initial_weights))

        # Initial parallel convolutions applied directly to the embeddings;
        # the paddings keep the sequence length unchanged.
        self.init_char_cnn_3 = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=3, padding=1)
        self.init_char_cnn_5 = nn.Conv1d(self.embedding_dim, self.embedding_dim, kernel_size=5, padding=2)

        self.cnn_list = nn.ModuleList()
        self.multi_cnn_list_3 = nn.ModuleList()
        self.multi_cnn_list_5 = nn.ModuleList()

        # last_dim tracks the channel count of the running feature map, which
        # grows with every block because block outputs are concatenated onto
        # the block input.
        self.last_dim = self.embedding_dim * self.kernel_type
        for _ in range(int((self.cnn_layer - 1) / 2)):
            self.cnn_list.append(nn.Conv1d(self.last_dim, self.hidden_size, kernel_size=1, padding=0))
            self.multi_cnn_list_3.append(nn.Conv1d(self.hidden_size, self.hidden_size, kernel_size=3, padding=1))
            self.multi_cnn_list_5.append(nn.Conv1d(self.hidden_size, self.hidden_size, kernel_size=5, padding=2))
            self.last_dim += self.hidden_size * self.kernel_type

    def random_embedding(self, vocab_size, embedding_dim):
        """Return a ``(vocab_size, embedding_dim)`` array of uniform samples.

        Values are drawn from ``U(-s, s)`` with ``s = sqrt(3 / embedding_dim)``.
        """
        scale = np.sqrt(3.0 / embedding_dim)
        # One vectorised draw replaces the former per-row Python loop.
        return np.random.uniform(-scale, scale, [vocab_size, embedding_dim])

    def _conv_features(self, char_ids):
        """Run the embedding and convolution stack without any pooling.

        input:
            char_ids: index tensor (num_words, word_length)
        output:
            tensor (num_words, channels, word_length)
        """
        embeds = self.char_drop(self.char_embeddings(char_ids))
        # Conv1d expects channels before the sequence axis.
        embeds = embeds.transpose(2, 1).contiguous()
        features = torch.cat(
            [F.relu(self.init_char_cnn_3(embeds)), F.relu(self.init_char_cnn_5(embeds))], 1)

        blocks = zip(self.cnn_list, self.multi_cnn_list_3, self.multi_cnn_list_5)
        for reduce_conv, conv_3, conv_5 in blocks:
            reduced = F.relu(reduce_conv(features))
            # New branch outputs go first, followed by everything computed so
            # far, matching the channel order the layers were sized for.
            features = torch.cat(
                [F.relu(conv_3(reduced)), F.relu(conv_5(reduced)), features], 1)
        return features

    def get_last_hiddens(self, input, seq_lengths):
        """
            input:
                input: index tensor, reshaped internally to
                    (batch_size * sent_len, word_length); expected shape is
                    (batch_size, sent_len, word_length)
                seq_lengths: accepted for interface compatibility; not used
            output:
                tensor (batch_size, sent_len, channels), the maximum of every
                feature channel over the character positions of each word
        """
        batch_size = input.size(0)
        sent_len = input.size(1)
        features = self._conv_features(input.view(batch_size * sent_len, -1))
        # Pooling window spans the whole character axis -> one value per channel.
        pooled = F.max_pool1d(features, features.size(2))
        return pooled.view(batch_size, sent_len, -1)

    def get_all_hiddens(self, input, seq_lengths):
        """
            input:
                input: index tensor (batch_size, word_length)
                seq_lengths: accepted for interface compatibility; not used
            output:
                tensor (batch_size, word_length, channels) holding the
                unpooled features of every character position
        """
        # Previously this called an undefined ``self.char_cnn`` and always
        # raised AttributeError; it now reuses the shared convolution stack.
        return self._conv_features(input).transpose(2, 1).contiguous()

    def forward(self, input, seq_lengths):
        """Return per-character features; see ``get_all_hiddens``."""
        return self.get_all_hiddens(input, seq_lengths)
