import torch
# from fieldlm.sublayer.extractor import Tensor_Extractor_Layer, Matrix_Extractor_Layer
# from fieldlm.sublayer.reducer import Tensor_Reducer_Layer


from ..sublayer.extractor import Tensor_Extractor_Layer, Matrix_Extractor_Layer
from ..sublayer.reducer import Tensor_Reducer_Layer

# Tensor_Reducer_Layer_Para = [ 
#     # STRUCTURE:
#         'Tensor_Reducer',    
#     # Meanings:
#         {
#             'InputMeaning':    'GrainVec_SeqAS_Token_SeqAS_Sent',
#             'OutputMeaning':   'TokenVec_SeqAS_Sent',
#             'Reshape_Restore':  None,
#         },
#     # NNName_NNPara
#         [
#             # NN Name
#             'Mean',
#             # NN Para
#             {'type': 'mean', 
#              'input_size':  None, 
#              'output_size': None,
#              'postprecess' :{
#                  'dropout' :[True, {'p':0.5, 'inplace':False}],
#                  # 'activiator': [True, 'relu'],
#                  'layernorm': [True, {'eps': 1e-05, "elementwise_affine":True}],
#                 }
             
#             }
#         ]
# ]


# (+) Don't change this for CNN and TFM Tensor Extractor =======================================
Tensor_Reducer_Layer_Para = [
    # [0] Structure-layer name.
    'Tensor_Reducer',
    # [1] Meanings of the tensors entering and leaving the layer.
    {
        'InputMeaning': 'GrainVec_SeqAS_Token_SeqAS_Sent',
        'OutputMeaning': 'TokenVec_SeqAS_Sent',
        'Reshape_Restore': 'GrainVec_SeqAs_Token',
    },
    # [2] Pair of [NN name, NN parameters].
    [
        'Max',
        {
            'type': 'Max',
            # Size placeholders: Indep_Layer supplies the real values when
            # it builds its residual reducer.
            'input_size': None,
            'output_size': None,
            # No post-processing is enabled for this reducer (the dropout /
            # activation / layernorm options are all switched off).
            'postprecess': {},
        },
    ],
]
    
    

# Structure_Layer_Name, Meanings, NNName_NNPara = Tensor_Reducer_Layer_Para
# tns_reducer_layer = Tensor_Reducer_Layer(Meanings, NNName_NNPara)
# tns_reducer_layer

# Post-processing switches, keyed by method name. Each value is a
# [use, config] pair; entries are visited in insertion order.
postprecess = {
    # config holds keyword arguments for torch.nn.Dropout.
    'dropout': [
        True,
        {'p': 0.5, 'inplace': False},
    ],
    # NOTE: Indep_Layer in this file only acts on 'dropout' and 'layernorm';
    # this entry has no matching branch there.
    'activiator': [
        True,
        'relu',
    ],
    # config holds keyword arguments for torch.nn.LayerNorm.
    'layernorm': [
        True,
        {'eps': 1e-05, "elementwise_affine": True},
    ],
}


class Indep_Layer(torch.nn.Module):
    """Sequential stack of extractor/reducer sub-layers with an optional residual branch.

    The stack is built from ``indep_layer_para``; every entry is a triple
    ``(Structure_Layer_Name, Meanings, NNName_NNPara)``. When the residual
    structure is enabled, the raw input is additionally passed through a
    separate ``Tensor_Reducer_Layer`` (configured from the module-level
    ``Tensor_Reducer_Layer_Para``), added to the output of the stack, and the
    sum is run through the enabled entries of the module-level ``postprecess``
    configuration.
    """

    def __init__(self, indep_layer_para, use_residual_structure = True):
        """Build the sub-layer stack and, optionally, the residual branch.

        Args:
            indep_layer_para: Iterable of ``(Structure_Layer_Name, Meanings,
                NNName_NNPara)`` triples. ``Structure_Layer_Name`` must be one
                of ``'Tensor_Extractor'``, ``'Tensor_Reducer'`` or
                ``'Matrix_Extractor'``. ``NNName_NNPara[1]['output_size']`` of
                the last entry becomes ``self.output_size``.
            use_residual_structure: When truthy, build the residual reducer
                and the post-processing modules applied after the addition.

        Raises:
            ValueError: If ``indep_layer_para`` yields no entries, because the
                output size cannot be determined.
        """
        # Function-scope import so the block does not depend on a file-level
        # import being present.
        import copy

        super(Indep_Layer, self).__init__()

        self.Stack_NN_Layers = torch.nn.ModuleList()

        last_nn_para = None
        for layer_para in indep_layer_para:
            Structure_Layer_Name, Meanings, NNName_NNPara = layer_para
            assert Structure_Layer_Name in ['Tensor_Extractor', 'Tensor_Reducer', 'Matrix_Extractor'], \
                'Unknown structure layer name: {!r}'.format(Structure_Layer_Name)
            if 'Tensor_Extractor' == Structure_Layer_Name:
                self.Stack_NN_Layers.append(Tensor_Extractor_Layer(Meanings, NNName_NNPara))
            elif 'Tensor_Reducer' == Structure_Layer_Name:
                self.Stack_NN_Layers.append(Tensor_Reducer_Layer(Meanings, NNName_NNPara))
            elif 'Matrix_Extractor' == Structure_Layer_Name:
                self.Stack_NN_Layers.append(Matrix_Extractor_Layer(Meanings, NNName_NNPara))
            last_nn_para = NNName_NNPara

        if last_nn_para is None:
            # Without at least one layer there is no 'output_size' to read.
            raise ValueError('indep_layer_para must contain at least one layer definition')

        # The output size of the whole stack is that of its last sub-layer.
        self.output_size = last_nn_para[1]['output_size']

        self.use_residual_structure = use_residual_structure
        # Truthiness test, consistent with the check performed in forward().
        if self.use_residual_structure:
            # Work on a private deep copy: the module-level template must not
            # be modified, and instances must not share one parameter dict.
            _, Meanings, NNName_NNPara = copy.deepcopy(Tensor_Reducer_Layer_Para)
            NNName_NNPara[1]['input_size']  = self.output_size
            NNName_NNPara[1]['output_size'] = self.output_size
            self.Residual_Structure = Tensor_Reducer_Layer(Meanings, NNName_NNPara)

            # Post-processing applied to (stack output + residual output).
            # The list is a plain Python list; the modules are registered
            # with this Module through the self.drop / self.layernorm
            # attribute assignments below.
            self.postprocess = []
            for method, use_config in postprecess.items():
                use, config = use_config
                if not use:
                    continue
                if method == 'dropout':
                    self.drop = torch.nn.Dropout(**config)
                    self.postprocess.append(self.drop)
                elif method == 'layernorm':
                    # https://pytorch.org/docs/stable/nn.html
                    self.layernorm = torch.nn.LayerNorm(self.output_size, **config)
                    self.postprocess.append(self.layernorm)
                # Any other method name has no handler here and is skipped.

    def forward(self, info, leng_tk, leng_tk_mask, leng_st, misc_info):
        """Run ``info`` through the stack and, if enabled, the residual branch.

        Args:
            info: Input passed to the first sub-layer (and to the residual
                reducer when enabled).
            leng_tk: Forwarded unchanged to every sub-layer.
            leng_tk_mask: Forwarded unchanged to every sub-layer.
            leng_st: Forwarded unchanged to every sub-layer.
            misc_info: Forwarded unchanged to every sub-layer.

        Returns:
            The output of the last sub-layer. With the residual structure
            enabled, this is the sum of that output and the residual reducer's
            output, after the post-processing modules.
        """
        if self.use_residual_structure:
            # Computed from the untouched input, before the stack rebinds `info`.
            info_orig = self.Residual_Structure(info,
                                        leng_tk = leng_tk,
                                        leng_tk_mask = leng_tk_mask,
                                        leng_st = leng_st,
                                        misc_info = misc_info)

        for layer in self.Stack_NN_Layers:
            info = layer(info,
                         leng_tk = leng_tk,
                         leng_tk_mask = leng_tk_mask,
                         leng_st = leng_st,
                         misc_info = misc_info)

        if self.use_residual_structure:
            info = info + info_orig
            for layer in self.postprocess:
                info = layer(info)

        return info
    
