# fieldlm.nn.embedding
import os
import torch
import numpy as np


# Default post-processing configuration for EmbeddingLayer.
# Each entry maps a method name to ``[enabled, kwargs]``; ``kwargs`` are passed
# straight to the corresponding torch layer constructor. Both methods are
# disabled by default.
# NOTE: "postprecess" is a misspelling of "postprocess"; the name is kept
# because other modules may import it under this spelling.
postprecess = {
    'dropout': [False, {'p': 0.5, 'inplace': False}],
    'layernorm': [False, {'eps': 1e-05, 'elementwise_affine': True}],
}


class EmbeddingLayer(torch.nn.Module):
    """Embedding lookup followed by optional dropout and/or layer norm.

    The embedding table comes from one of three sources, checked in order:

    1. ``init`` is a ``numpy.ndarray`` or ``torch.Tensor``: copied and used as
       pretrained weights (converted to float32).
    2. ``init`` is a path (``str``/``bytes``/``os.PathLike``) to an existing
       file, other than the default sentinel ``'init'``: loaded with
       ``numpy.load`` and used as pretrained weights (converted to float32).
    3. Anything else (the default ``'init'``, ``None``, a non-existent path):
       a randomly initialized ``torch.nn.Embedding`` with ``padding_idx=0``.

    Args:
        input_size: Number of rows in the embedding table.
        embedding_size: Dimension of each embedding vector.
        init: Pretrained weights, a path to a file readable by ``numpy.load``
            holding a single array, or the sentinel ``'init'`` for random
            initialization.
        freeze: Passed to ``Embedding.from_pretrained``; when True the
            pretrained weights are not trained. Not applied to a randomly
            initialized table.
        postprecess: Mapping ``method -> [enabled, kwargs]``. Supported methods
            are ``'dropout'`` (``torch.nn.Dropout(**kwargs)``) and
            ``'layernorm'`` (``torch.nn.LayerNorm(embedding_size, **kwargs)``).
            Unknown method names are ignored. Enabled layers are applied in
            the mapping's iteration order.

    Raises:
        ValueError: If pretrained weights do not have shape
            ``(input_size, embedding_size)``.
    """

    def __init__(self,
                 input_size,
                 embedding_size,
                 init='init',
                 freeze=False,
                 postprecess=postprecess):

        super().__init__()

        weight = self._pretrained_weight(init)
        if weight is not None:
            expected = (input_size, embedding_size)
            if tuple(weight.shape) != expected:
                # Explicit raise rather than `assert` so the check also runs
                # under `python -O`.
                raise ValueError(
                    f'pretrained embedding has shape {tuple(weight.shape)}, '
                    f'expected {expected}')
            self.embedding = torch.nn.Embedding.from_pretrained(weight, freeze=freeze)
        else:
            # NOTE(review): `freeze` is not applied to a randomly initialized
            # table, and padding_idx=0 is only set on this branch (the
            # pretrained branch passes no padding_idx) — confirm both are
            # intended.
            self.embedding = torch.nn.Embedding(input_size, embedding_size, padding_idx=0)

        # Each enabled layer is stored under its own attribute (self.drop /
        # self.layernorm), which registers it as a submodule and fixes its
        # state_dict keys; this plain list only records the order in which
        # forward() applies them.
        self.postprocess = []
        for method, (use, config) in postprecess.items():
            if not use:
                continue
            if method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
                # Normalizes over the last dimension, of size embedding_size.
                self.layernorm = torch.nn.LayerNorm(embedding_size, **config)
                self.postprocess.append(self.layernorm)
            # Any other method name is silently ignored.

    @staticmethod
    def _pretrained_weight(init):
        """Return ``init`` as a fresh float32 weight tensor, or None for random init."""
        if isinstance(init, torch.Tensor):
            return init.detach().to(dtype=torch.float32, copy=True)
        if isinstance(init, np.ndarray):
            # torch.tensor always copies, so training never writes back into
            # the caller's array.
            return torch.tensor(init, dtype=torch.float32)
        # 'init' is the random-initialization sentinel; never treat it as a
        # path, even if a file with that name exists in the working directory.
        is_path = isinstance(init, (str, bytes, os.PathLike)) and init != 'init'
        if is_path and os.path.isfile(init):
            return torch.tensor(np.load(init), dtype=torch.float32)
        return None

    def forward(self, info):
        """Embed integer index tensor ``info`` and apply the enabled post-processing layers.

        Returns a tensor of shape ``(*info.shape, embedding_size)``.
        """
        info = self.embedding(info)
        for post_layer in self.postprocess:
            info = post_layer(info)
        return info