import torch
from torch import nn
from src.models.encoders.charCNN import CharCNN
from src.models.encoders.charRNN import CharRNN

class CharWordEmbedding(nn.Module):
    """Concatenate word-level and character-level embeddings, then apply dropout.

    The word embedding produced by the most recent forward pass is stored on
    ``self.wordEmbedding``. When it participates in autograd, ``retain_grad()``
    is enabled on it, so after ``backward()`` callers can read
    ``self.wordEmbedding.grad``. An optional tensor ``self.eit`` is added to the
    word embedding before it is stored. NOTE(review): ``eit`` is presumably a
    perturbation (e.g. adversarial) set externally by the training loop —
    confirm against callers.
    """

    def __init__(self, layerUtil):
        super().__init__()
        # Word-level embedding module from the project's layer factory
        # (presumably an nn.Embedding — confirm in layerUtil).
        self.embedding = layerUtil.getEmbeddingParameter()
        # Character-level encoder; CharCNN(layerUtil) was previously used here
        # as an alternative.
        self.charEmbedding = CharRNN(layerUtil)
        # Dropout applied to the concatenated embedding.
        self.dropout = layerUtil.getDropOut()
        # Word embedding from the last forward pass, kept so its .grad can be
        # inspected after backward().
        self.wordEmbedding = None
        # Optional tensor added to the word embedding; None disables it.
        self.eit = None

    def forward(self, wordSeqTensors, charSeqTensors, charSeqLengths):
        """Embed a batch at word and character level and merge the results.

        Args:
            wordSeqTensors: word indices passed to ``self.embedding``.
            charSeqTensors: character inputs passed to ``self.charEmbedding``.
            charSeqLengths: character lengths passed to ``self.charEmbedding``.

        Returns:
            ``dropout(cat([word_embedding, char_embedding], dim=2))``. Both
            parts must therefore be at least 3-D and agree on every dim except
            dim 2 (presumably (batch, seq, features) — confirm with callers).
        """
        wordEmbedding = self.embedding(wordSeqTensors)
        if self.eit is not None:
            # Out-of-place: never mutate the embedding output, and allow eit
            # to broadcast in either direction.
            wordEmbedding = wordEmbedding + self.eit
        # retain_grad() raises RuntimeError on tensors that do not require grad
        # (e.g. under torch.no_grad() or with frozen embedding weights), so only
        # request it when a gradient can actually be produced.
        if wordEmbedding.requires_grad:
            wordEmbedding.retain_grad()
        self.wordEmbedding = wordEmbedding
        charLevelEmbedding = self.charEmbedding(charSeqTensors, charSeqLengths)
        mergedEmbedding = torch.cat([wordEmbedding, charLevelEmbedding], 2)
        return self.dropout(mergedEmbedding)