import torch
# from fieldlm.nn.embedding import EmbeddingLayer
from ..nn.embedding import EmbeddingLayer

# Post-processing configuration read by Expander_Layer.__init__.
# Each entry maps a method name to ``[enabled, constructor_kwargs]``:
#   'dropout'   -> kwargs for torch.nn.Dropout
#   'layernorm' -> kwargs for torch.nn.LayerNorm (normalized size comes from
#                  the embedding size)
# Both are disabled by default. The (misspelled) name is kept because other
# code refers to it.
postprecess = dict(
    dropout=[False, {'p': 0.5, 'inplace': False}],
    layernorm=[False, {'eps': 1e-05, 'elementwise_affine': True}],
)


class Expander_Layer(torch.nn.Module):
    """Sum grain embeddings with optional token/grain position embeddings.

    Sub-layers are built from ``expander_layer_para``, a mapping from an
    embedding type (``'grain'``, ``'tk_psn'`` or ``'gr_psn'``) to the keyword
    arguments passed to ``EmbeddingLayer``. Each kwargs dict must contain an
    ``'embedding_size'`` key; the value from the last entry is stored as
    ``self.embed_size`` and used as the LayerNorm width.

    Args:
        expander_layer_para: mapping of embedding type -> EmbeddingLayer kwargs.
        postprocess_para: optional mapping ``method -> [use, kwargs]`` with
            ``method`` in ``{'dropout', 'layernorm'}``. Defaults to the
            module-level ``postprecess`` config, read at construction time.

    Raises:
        ValueError: for an unknown embedding type, for an enabled but unknown
            post-processing method, or when LayerNorm is requested but there
            is no embedding layer to take its size from.
    """

    _EMBED_TYPES = ('grain', 'tk_psn', 'gr_psn')

    def __init__(self, expander_layer_para, postprocess_para=None):
        super().__init__()
        layers = torch.nn.ModuleDict()
        self.embed_size = None
        for name, para in expander_layer_para.items():
            if name not in self._EMBED_TYPES:
                raise ValueError('no embedding type for:' + name + '; only grain, tk_psn, gr_psn are available')
            layers[name] = EmbeddingLayer(**para)
            # Embeddings are summed in forward, so they presumably share one
            # size; the last entry's size wins.
            self.embed_size = para['embedding_size']
        self.layers = layers

        if postprocess_para is None:
            postprocess_para = postprecess

        # Plain list on purpose: each module is also registered as an attribute
        # (self.drop / self.layernorm). Wrapping them in a ModuleList as well
        # would add a second set of state_dict keys and break old checkpoints.
        self.postprocess = []
        for method, (use, config) in postprocess_para.items():
            if not use:
                continue
            if method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                if self.embed_size is None:
                    raise ValueError('layernorm postprocess needs at least one embedding layer to infer its size')
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.embed_size, **config)
                self.postprocess.append(self.layernorm)
            else:
                raise ValueError('unknown postprocess method: ' + str(method))

    def forward(self, grain_idxes, leng_tk_mask, misc_info=None, tk_type_idex=None):
        """Embed ``grain_idxes`` and add any configured position embeddings.

        Args:
            grain_idxes: index tensor fed to the ``'grain'`` embedding layer.
            leng_tk_mask: mask with at least 3 dims ``(a, b, c)``; positions
                where it is True get position index 0 (presumably padding --
                confirm with callers; must be a bool tensor for masked_fill).
                ``'tk_psn'`` indices run 1..b along dim 1 and ``'gr_psn'``
                indices run 1..c along dim 2.
            misc_info: unused; kept for interface compatibility.
            tk_type_idex: unused; kept for interface compatibility.

        Returns:
            The summed embeddings. Post-processing (dropout/layernorm, if
            enabled) runs only when at least one position embedding was added.
        """
        embeddings = self.layers['grain'](grain_idxes)
        use_tk = 'tk_psn' in self.layers
        use_gr = 'gr_psn' in self.layers
        if not (use_tk or use_gr):
            # Per the original author's note, the grain embedding layer already
            # has its own dropout and layernorm, so none is re-applied here.
            return embeddings

        device = grain_idxes.device
        a, b, c = leng_tk_mask.shape[:3]

        if use_tk:
            # 1..b along dim 1, broadcast over dims 0 and 2; masked -> 0.
            tk_psn_idxes = (torch.arange(1, b + 1, device=device)
                            .view(1, b, 1)
                            .expand(a, b, c)
                            .masked_fill(leng_tk_mask, 0))
            embeddings = embeddings + self.layers['tk_psn'](tk_psn_idxes)

        if use_gr:
            # 1..c along dim 2, broadcast over dims 0 and 1; masked -> 0.
            gr_psn_idxes = (torch.arange(1, c + 1, device=device)
                            .view(1, 1, c)
                            .expand(a, b, c)
                            .masked_fill(leng_tk_mask, 0))
            embeddings = embeddings + self.layers['gr_psn'](gr_psn_idxes)

        for layer in self.postprocess:
            embeddings = layer(embeddings)
        return embeddings
        