import torch
from ..sublayer.expander import Expander_Layer
from .indep import Indep_Layer
from .interdep import Interdep_Layer


class SeqRepr(torch.nn.Module):
    """Sequence representation built from several input fields.

    Each field is first encoded on its own (its ``Expander_Layer`` followed by
    its ``Indep_Layer``). The per-field results are then stacked along a new
    field axis (dim 2) and mixed by a single shared ``Interdep_Layer``.

    Field names are matched case-insensitively: keys are upper-cased both when
    the layers are registered and when ``forward`` looks them up.
    """

    def __init__(self, expander_layer_para, indep_layer_para, interdep_layer_para, use_residual_structure=True):
        """Build the per-field and cross-field layers.

        Args:
            expander_layer_para: mapping (or iterable of ``(field, para)`` pairs)
                from field name to the parameters of that field's ``Expander_Layer``.
            indep_layer_para: mapping (or iterable of ``(field, para)`` pairs)
                from field name to the parameters of that field's ``Indep_Layer``.
            interdep_layer_para: parameters of the shared ``Interdep_Layer``.
            use_residual_structure: forwarded to every ``Indep_Layer`` and to the
                ``Interdep_Layer``.
        """
        super().__init__()
        # One expander per field.
        self.Expander_Layers = torch.nn.ModuleDict()
        for fld, para in self._iter_field_params(expander_layer_para):
            self.Expander_Layers[fld.upper()] = Expander_Layer(para)

        # One independent encoder per field.
        # BUGFIX: this loop used to iterate the mapping directly, which yields only
        # the keys, so unpacking each key string into (fld, para) failed (or silently
        # split a 2-character key into two characters).
        self.Indep_Layers = torch.nn.ModuleDict()
        for fld, para in self._iter_field_params(indep_layer_para):
            self.Indep_Layers[fld.upper()] = Indep_Layer(para, use_residual_structure)

        self.Interdep_Layer = Interdep_Layer(interdep_layer_para, use_residual_structure)

    @staticmethod
    def _iter_field_params(field_params):
        """Return ``(field, para)`` pairs from a mapping or an iterable of pairs."""
        items = getattr(field_params, 'items', None)
        return list(items()) if callable(items) else list(field_params)

    def permute(self, single_results):
        """Stack the per-field tensors along a new axis at dim 2.

        Tensors are taken in the mapping's iteration order and must all share one
        shape. Inputs of shape ``(d0, d1, *rest)`` give an output of shape
        ``(d0, d1, n_fields, *rest)``.
        """
        return torch.stack(list(single_results.values()), dim=2)

    @staticmethod
    def _build_interdep_lengths(leng_tk, leng_st_mask, fld_num, device):
        """Build the length tensor and mask for the stacked-field input.

        Every sentence position where ``leng_st_mask`` is False holds exactly
        ``fld_num`` items (one per field). Positions where it is True (presumably
        padded sentences) get length 0.

        Args:
            leng_tk: 2-D tensor used only as a shape/dtype template.
            leng_st_mask: boolean tensor with the same shape as ``leng_tk``.
            fld_num: number of stacked fields.
            device: device both outputs are moved to.

        Returns:
            ``(leng_tk, leng_tk_mask)``:
            - ``leng_tk`` matches the template's shape and dtype.
            - ``leng_tk_mask`` adds a trailing axis of size ``fld_num`` and is
              True wherever ``leng_st_mask`` is True.
        """
        new_leng_tk = (torch.ones_like(leng_tk).to(device) * fld_num).masked_fill(leng_st_mask, 0)
        batch_size, max_leng_st = new_leng_tk.shape
        filled = (
            torch.ones_like(new_leng_tk).to(device)
            .unsqueeze(-1)
            .expand(batch_size, max_leng_st, fld_num)
            .masked_fill(leng_st_mask.unsqueeze(-1), 0)
        )
        # Positions zeroed out by the mask become True (masked).
        return new_leng_tk, filled == 0

    def forward(self, info_dict, leng_st, misc_info=None):
        """Encode each field independently, then mix the fields together.

        Args:
            info_dict: mapping from field name (matched case-insensitively) to a
                ``(info, leng_tk, leng_tk_mask)`` triple for that field.
            leng_st: tensor passed through to the sub-layers. Its device is the
                target device for the rebuilt cross-field lengths and mask.
            misc_info: dict of extra inputs forwarded to every sub-layer. Despite
                the ``None`` default it is required, because
                ``misc_info['leng_st_mask']`` is read to build the cross-field mask.

        Returns:
            The ``Interdep_Layer`` output for the stacked field representations.
        """
        device = leng_st.device
        indep_results = {}
        # Stage 1: multiple grains, one field at a time.
        for fld, (fld_info, leng_tk, leng_tk_mask) in info_dict.items():
            fld = fld.upper()
            # Per the original author, the expander embeds grains, positions and token types.
            fld_info = self.Expander_Layers[fld](fld_info, leng_tk_mask=leng_tk_mask, misc_info=misc_info)
            # TODO: Indep_Layers still need more development (original author's note).
            fld_info = self.Indep_Layers[fld](fld_info, leng_tk, leng_tk_mask, leng_st, misc_info)
            indep_results[fld] = fld_info

        # Merge the per-field tensors into one tensor with a field axis at dim 2.
        info = self.permute(indep_results)

        # Stage 2: one grain, multiple fields. Rebuild the lengths and mask so that
        # each sentence position is a "sequence" of fld_num field items.
        # NOTE: ``leng_tk`` here is the last field's value from the loop above; it
        # is only used as a shape/dtype template.
        leng_tk, leng_tk_mask = self._build_interdep_lengths(
            leng_tk, misc_info['leng_st_mask'], info.size(2), device)

        # The resulting sentence representation combines information from all fields.
        return self.Interdep_Layer(info, leng_tk, leng_tk_mask, leng_st, misc_info)