import torch
import math
import torch.nn.functional as F
# from fieldlm.nn.helper import change_feat_location, _addindent, gelu
from .helper import change_feat_location, _addindent, gelu



# Default post-processing applied by LinearLayer after its linear projection.
# Each entry maps a step name to ``[enabled, config]``; enabled steps are
# appended (and later applied) in this dict's iteration order.
#   'dropout'   : keyword arguments for torch.nn.Dropout
#   'activator' : activation name ('relu', 'tanh' or 'gelu'; any other name
#                 falls back to identity)
#   'layernorm' : keyword arguments for torch.nn.LayerNorm, which is built
#                 over the layer's output size
postprecess = {
    'dropout' :[True, {'p':0.5, 'inplace':False}],
    # Previously misspelled 'activiator'; LinearLayer only matches
    # 'activator', so the activation was silently never applied.
    'activator': [True, 'relu'],
    'layernorm': [True, {'eps': 1e-05, "elementwise_affine":True}],
}


class LinearLayer(torch.nn.Module):
    """One linear projection followed by optional post-processing steps.

    Only ``struct_type='EXTRACTOR'`` is implemented:

    * ``input_type='INPUT-NML'``: a single ``Linear(input_size, output_size)``
      is applied over the last dimension of the input.
    * ``input_type='INPUT-SEP'``: the last dimension is split into two equal
      halves (presumably forward/backward features of a bidirectional
      encoder -- TODO confirm), the *same* shared
      ``Linear(input_size // 2, output_size // 2)`` is applied to each half,
      and the two results are concatenated back along the last dimension.

    After the projection, each enabled step of the post-processing config
    (dropout, activation, layer norm) is applied in the config's iteration
    order.
    """

    # Config keys treated as the activation step. 'activiator' is a legacy
    # misspelling still accepted so existing configs enable the activation.
    _ACTIVATOR_KEYS = ('activator', 'activiator')

    def __init__(self,
                 type = 'linear',
                 n_layers = 1,
                 input_type  = 'INPUT-NML',
                 struct_type = 'EXTRACTOR',
                 input_size  = 200,
                 output_size = 200,
                 postprocess_config = None,
                 ):
        """Build the projection and its post-processing pipeline.

        Args:
            type: unused; kept for interface compatibility.
            n_layers: number of layers; only 1 is supported.
            input_type: 'INPUT-NML' or 'INPUT-SEP' (see class docstring).
            struct_type: 'EXTRACTOR' or 'REDUCER'; only 'EXTRACTOR' is
                implemented.
            input_size: size of the last input dimension.
            output_size: size of the last output dimension. The default used
                to be the float ``200.``, which ``torch.nn.Linear`` rejects.
            postprocess_config: mapping ``{step_name: [enabled, config]}``
                with steps 'dropout', 'activator' and 'layernorm'. When None,
                the module-level ``postprecess`` dict is used.

        Raises:
            AssertionError: unsupported ``input_type``/``struct_type``,
                ``n_layers != 1``, or odd sizes with 'INPUT-SEP'.
            NotImplementedError: ``struct_type='REDUCER'``.
        """
        super(LinearLayer, self).__init__()

        assert input_type in ['INPUT-NML', 'INPUT-SEP'], input_type
        self.input_type = input_type
        assert struct_type in ['EXTRACTOR', 'REDUCER'], struct_type
        self.struct_type = struct_type

        assert n_layers == 1, 'only n_layers == 1 is supported'
        self.n_layers = n_layers

        self.input_size  = input_size
        self.output_size = output_size

        if self.struct_type == 'REDUCER':
            # The reducer variant is not implemented; fail loudly here rather
            # than with an AttributeError on the missing ``self.linear``.
            raise NotImplementedError("struct_type 'REDUCER' is not implemented")

        if self.input_type == 'INPUT-NML':
            self.lnr_input_size  = input_size
            self.lnr_output_size = output_size
        else:  # 'INPUT-SEP'
            assert input_size  % 2 == 0, 'INPUT-SEP needs an even input_size'
            assert output_size % 2 == 0, 'INPUT-SEP needs an even output_size'
            # One Linear is shared by both halves of the last dimension.
            self.lnr_input_size  = int(input_size)  // 2
            self.lnr_output_size = int(output_size) // 2
        self.linear = torch.nn.Linear(self.lnr_input_size, self.lnr_output_size)

        self.init_weights()

        if postprocess_config is None:
            postprocess_config = postprecess
        self._build_postprocess(postprocess_config)

    @staticmethod
    def _resolve_activator(name):
        """Map an activation name (case-insensitive) to a callable; unknown names give identity."""
        name = name.lower()
        if name == 'relu':
            return F.relu
        if name == 'tanh':
            # Same op as the deprecated F.tanh, without the deprecation warning.
            return torch.tanh
        if name == 'gelu':
            return gelu
        return lambda x: x

    def _build_postprocess(self, config):
        """Fill ``self.postprocess`` with the enabled steps, in ``config`` order."""
        self.postprocess = []
        for method, (use, method_config) in config.items():
            if not use:
                continue
            if method in self._ACTIVATOR_KEYS:
                self.activator = self._resolve_activator(method_config)
                self.postprocess.append(self.activator)
            elif method == 'dropout':
                # Stored as an attribute so it is registered as a sub-module.
                self.drop = torch.nn.Dropout(**method_config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.output_size, **method_config)
                self.postprocess.append(self.layernorm)
            # Any other step name is ignored.

    def init_weights(self):
        """Zero the bias and draw the weights uniformly from [-0.1, 0.1]."""
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def _apply_postprocess(self, info):
        """Run ``info`` through every enabled post-processing step in order."""
        for post_layer in self.postprocess:
            info = post_layer(info)
        return info

    def forward(self, info, *args, **kwargs):
        """Dispatch to the forward implementation selected by ``input_type``."""
        if self.input_type == 'INPUT-SEP':
            return self.forward_extractor_inputsep(info, *args, **kwargs)
        return self.forward_extractor_inputnml(info, *args, **kwargs)

    def forward_extractor_inputnml(self, info, *args, **kwargs):
        """Project the last dimension of ``info``, then post-process."""
        return self._apply_postprocess(self.linear(info))

    def forward_extractor_inputsep(self, info, *args, **kwargs):
        """Project each half of the last dimension with the shared Linear, re-concatenate, post-process."""
        info_fwd, info_bwd = info.chunk(2, -1)
        info = torch.cat([self.linear(info_fwd), self.linear(info_bwd)], -1)
        return self._apply_postprocess(info)

    def __repr__(self):
        """Return a module-style repr whose header shows the structure, sizes and input type."""
        # Extra repr lines are treated like sub-modules, one item per line.
        extra_repr = self.extra_repr()
        # ''.split('\n') would give [''], so only split a non-empty string.
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = [
            '(' + key + '): ' + _addindent(repr(module), 2)
            for key, module in self._modules.items()
        ]
        lines = extra_lines + child_lines

        main_str = '{}({}): ({}->{}) [INPUT] {}; ('.format(
            self._get_name(), self.struct_type.upper(),
            self.input_size, self.output_size, self.input_type.upper())
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                # Simple one-liner info, as most built-in modules use.
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'
        main_str += ')'
        return main_str