import torch
import torch.nn.functional as F
from fieldlm.nn.helper import reverse_tensor, _addindent, gelu

# from .helper import reverse_tensor, _addindent, gelu


# Default post-processing recipe used as the ``postprecess`` default argument
# by the reducer layers in this module. Each entry maps a method name to
# ``[enabled, kwargs]``; the kwargs are forwarded to ``torch.nn.Dropout`` /
# ``torch.nn.LayerNorm`` respectively, in the dict's insertion order.
postprecess = dict(
    dropout=[True, dict(p=0.5, inplace=False)],
    layernorm=[True, dict(eps=1e-05, elementwise_affine=True)],
)


####################################################################################################
class ReduceSumLayer(torch.nn.Module):
    '''Collapse the second-to-last axis of ``info`` by summation.

    Optional post-processing layers (dropout and/or layer norm) are built from
    ``postprecess``, a mapping ``name -> [enabled, kwargs]``, and applied to the
    summed tensor in the mapping's iteration order.
    '''

    def __init__(self, type, input_size, output_size, postprecess = postprecess):
        super(ReduceSumLayer, self).__init__()
        self.type = type
        self.input_size = input_size
        self.output_size = output_size
        self.struct_type = 'REDUCER'

        # method name -> (attribute the sub-module is stored under, factory)
        factories = {
            'dropout': ('drop', lambda kw: torch.nn.Dropout(**kw)),
            'layernorm': ('layernorm',
                          lambda kw: torch.nn.LayerNorm(self.output_size, **kw)),
        }
        self.postprocess = []
        for name, (enabled, kwargs) in postprecess.items():
            # ``== False`` on purpose (not ``not enabled``): e.g. ``None`` does
            # NOT disable an entry, only values equal to False do.
            if enabled == False or name not in factories:
                continue
            attr, build = factories[name]
            layer = build(kwargs)
            # setattr on an nn.Module registers the layer as a sub-module.
            setattr(self, attr, layer)
            self.postprocess.append(layer)

    def forward(self, info, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        reduced = info.sum(dim=-2)
        for layer in self.postprocess:
            reduced = layer(reduced)
        return reduced

    def __repr__(self):
        # Mirrors nn.Module.__repr__, but the header also shows the struct type
        # and the input->output sizes.
        raw_extra = self.extra_repr()
        extra = raw_extra.split('\n') if raw_extra else []
        children = [
            '({}): {}'.format(name, _addindent(repr(sub), 2))
            for name, sub in self._modules.items()
        ]
        header = '{}({}): ({}->{}) ('.format(
            self._get_name(), self.struct_type.upper(),
            self.input_size, self.output_size)
        if len(extra) == 1 and not children:
            body = extra[0]
        elif extra or children:
            body = '\n  ' + '\n  '.join(extra + children) + '\n'
        else:
            body = ''
        return header + body + ')'
    

# NNName_NNPara = [
#                     # NN Name
#                     'SUM',
#                     # NN Para
#                     {'type': 'sum', 
#                      'input_size':  200, 
#                      'output_size': 200,
#                     }
#                 ]
# NNName_NNPara

# NNName, NNPara = NNName_NNPara
# NNPara

# reducesum_layer = ReduceSumLayer(**NNPara)
# reducesum_layer

# print(layer_input.shape)
# layer_output = reducesum_layer(layer_input, leng_tk, misc_info)
# print(layer_output.shape)





####################################################################################################
class ReduceMeanLayer(torch.nn.Module):
    '''Average ``info`` over its second-to-last axis using explicit lengths.

    Only For 2D, can be used after reshaping 3D to 2D.

    ``forward`` sums ``info`` over dim -2 and divides by ``leng_tk``. Lengths
    of 0 are treated as 1 so that empty rows yield a zero vector instead of
    NaN. If ``leng_tk`` is None, a plain mean over dim -2 is taken.

    NOTE(review): the inline shape comments suggest ``info`` is
    ``(bs_a*, b, c)`` and ``leng_tk`` is ``(bs_a*,)``; confirm against callers.
    '''
    def __init__(self, type, input_size, output_size, postprecess = postprecess):

        super(ReduceMeanLayer, self).__init__()
        self.type = type
        self.input_size = input_size
        self.output_size = output_size
        self.struct_type = 'REDUCER'
        # (+) postprocess here
        self.postprocess = []
        for method, use_config in postprecess.items():
            use, config = use_config

            # NOTE(review): postprocessing is force-disabled for this layer.
            # ``use`` is overwritten, so no dropout/layernorm is ever built and
            # ``self.postprocess`` stays empty whatever ``postprecess`` says.
            # Confirm this is intentional before removing the override.
            use = False

            if use == False:
                continue
            if method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.output_size, **config)
                self.postprocess.append(self.layernorm)

    def forward(self, info, leng_tk = None, misc_info = None, *args, **kwargs):
        # info (bs_a*, b, c) --> (bs_a*, c)
        # ``misc_info`` is accepted for interface compatibility but not used:
        # the masking step that consumed misc_info['leng_st_mask'] is disabled,
        # so it is no longer dereferenced (a None misc_info used to crash here).
        if leng_tk is None:
            # No lengths supplied: every position along dim -2 counts.
            info = torch.mean(info, -2)
        else:
            lengths = leng_tk.unsqueeze(-1).float()
            # Out-of-place on purpose: when leng_tk is already float32,
            # ``.float()`` returns the same tensor and ``.unsqueeze()`` a view of
            # it, so an in-place write here would modify the caller's leng_tk.
            lengths = lengths.masked_fill(lengths == 0., 1.0)
            info = torch.sum(info, -2) / lengths  # (bs_a*, b, c) --> (bs_a*, c)
        # NOTE(review): padded rows were once zeroed with
        # info.masked_fill(misc_info['leng_st_mask'].unsqueeze(-1), 0);
        # that step is intentionally left disabled.
        for post_layer in self.postprocess:
            info = post_layer(info)
        return info

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '(' + self.struct_type.upper() + '): ' + '(' + str(self.input_size) + '->' + str(self.output_size) +') '+ '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str
        
# NNName_NNPara = [
#                     # NN Name
#                     'MEAN',
#                     # NN Para
#                     {'type': 'mean', 
#                      'input_size':  200, 
#                      'output_size': 200,
#                     }
#                 ]
# NNName_NNPara

# NNName, NNPara = NNName_NNPara
# NNPara

# mean_layer = ReduceMeanLayer(**NNPara)
# mean_layer

# print(layer_input.shape)
# layer_output = mean_layer(layer_input, leng_tk, misc_info)
# print(layer_output.shape)



####################################################################################################
class RecuderMaxLayer(torch.nn.Module):
    '''Max-pool a 3-D ``info`` tensor over its middle axis: (BS, A, B) -> (BS, B).

    Only For 2D (output), can be used after reshaping 3D to 2D.
    ``input_size`` must equal ``output_size`` because max-pooling does not
    change the feature width.
    '''
    def __init__(self, type, input_size, output_size, postprecess = postprecess):
        super(RecuderMaxLayer, self).__init__()
        assert input_size == output_size
        self.type = type
        self.input_size = input_size
        self.output_size = output_size
        self.struct_type = 'REDUCER'
        # (+) postprocess here: entries are [enabled, kwargs]; only values
        # equal to False disable an entry.
        self.postprocess = []
        for method, use_config in postprecess.items():
            use, config = use_config
            if use == False: continue
            if method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.output_size, **config)
                self.postprocess.append(self.layernorm)

    def forward(self, info, *args, **kwargs):
        # (BS, A, B) --> (BS, B, A) --> (BS, B, 1) --> (BS, B)
        # The unpack also enforces a 3-D input.
        _, seq_len, _ = info.shape
        info = torch.transpose(info, -1, 1).contiguous()
        # Squeeze only the pooled axis: a bare .squeeze() also removed the batch
        # axis when BS == 1 (and the feature axis when B == 1).
        info = F.max_pool1d(info, seq_len).squeeze(-1)
        for post_layer in self.postprocess:
            info = post_layer(info)
        return info

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '(' + self.struct_type.upper() + '): ' + '(' + str(self.input_size) + '->' + str(self.output_size) +') '+ '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str
        
# NNName_NNPara = [
#                     # NN Name
#                     'ReduceMax',
#                     # NN Para
#                     {'type': 'concat', 
#                      'input_size':  200, 
#                      'output_size': 200,
#                     }
#                 ]
# NNName_NNPara

# NNName, NNPara = NNName_NNPara
# NNPara

# rdcmax_layer = RecuderMaxLayer(**NNPara)
# rdcmax_layer
        
# # Only 2D 
# print(layer_output.shape)
# layer_output_further = rdcmax_layer(layer_output, leng_tk, misc_info)
# print(layer_output_further.shape)








####################################################################################################
class ConcatenateLayer(torch.nn.Module):
    '''Concatenate the last two axes of ``info``: (..., Y, Z) -> (..., Y*Z).

    When layer norm post-processing is enabled it is built with
    ``output_size``, so ``output_size`` should equal Y*Z.
    '''
    def __init__(self, type, input_size, output_size, postprecess = postprecess):
        super(ConcatenateLayer, self).__init__()
        self.type = type
        self.input_size = input_size
        self.output_size = output_size
        self.struct_type = 'REDUCER'
        # (+) postprocess here: entries are [enabled, kwargs]; only values
        # equal to False disable an entry.
        self.postprocess = []
        for method, use_config in postprecess.items():
            use, config = use_config
            if use == False: continue
            if method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.output_size, **config)
                self.postprocess.append(self.layernorm)

    def forward(self, info, *args, **kwargs):
        # (..., Y, Z) --> (..., Y*Z)
        y, z = info.shape[-2:]
        # reshape rather than view: identical (a view) for contiguous input, but
        # copies instead of raising when ``info`` is non-contiguous.
        info = info.reshape(*info.shape[:-2], y * z)
        for post_layer in self.postprocess:
            info = post_layer(info)
        return info

    def __repr__(self):
        # We treat the extra repr like the sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '(' + self.struct_type.upper() + '): ' + '(' + str(self.input_size) + '->' + str(self.output_size) +') '+ '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str
        
      
# NNName_NNPara = [
#                     # NN Name
#                     'Concat',
#                     # NN Para
#                     {'type': 'concat', 
#                      'input_size':  200, 
#                      'output_size': 200 * 18, # This one is important
#                     }
#                 ]
# NNName_NNPara

# NNName, NNPara = NNName_NNPara
# NNPara

# concat_layer = ConcatenateLayer(**NNPara)
# concat_layer


# print(layer_input.shape)
# layer_output = concat_layer(layer_input, leng_tk, misc_info)
# print(layer_output.shape)




####################################################################################################
# try:
#     from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
# except ImportError:
#     # logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
#     print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
#     class LayerNorm(torch.nn.Module):
#         def __init__(self, hidden_size, eps=1e-12):
#             """Construct a layernorm module in the TF style (epsilon inside the square root).
#             """
#             super(LayerNorm, self).__init__()
#             self.weight = torch.nn.Parameter(torch.ones(hidden_size))
#             self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
#             self.variance_epsilon = eps

#         def forward(self, x):
#             u = x.mean(-1, keepdim=True)
#             s = (x - u).pow(2).mean(-1, keepdim=True)
#             x = (x - u) / torch.sqrt(s + self.variance_epsilon)
#             return self.weight * x + self.bias

            