import torch
import numpy as np
import torch.nn.functional as F
# from fieldlm.nn.helper import change_feat_location, reverse_tensor, gelu, _addindent, get_leng_mask
# from fieldlm.nn.transformer import Transformer
from .helper import change_feat_location, reverse_tensor, gelu, _addindent, get_leng_mask
from .transformer import Transformer


# Default post-processing config for TFMLayer: a mapping of
# ``method -> [use, config]``, applied in mapping order after the transformer.
# Empty means no post-processing. Recognized methods, with example entries:
#     'dropout':   [True, {'p': 0.5, 'inplace': False}],      # kwargs for torch.nn.Dropout
#     'activator': [True, 'relu'],                             # 'relu' | 'tanh' | 'gelu'; other names -> identity
#     'layernorm': [True, {'eps': 1e-05, 'elementwise_affine': True}],  # kwargs for torch.nn.LayerNorm
# The variable keeps its historical spelling because TFMLayer uses it as a default argument.
postprecess = {
}


class TFMLayer(torch.nn.Module):
    """Transformer-based sequence layer with configurable input split and direction.

    Wraps one ``Transformer`` (``'FWD'`` / ``'MIX'``) or two (``'BI-SEP'``) and
    applies optional post-processing (activation, dropout, layer-norm) to the
    transformer output.

    Args:
        type: free-form tag stored on ``self.type``; not used internally.
        input_type: ``'INPUT-NML'`` feeds the whole input to every transformer.
            ``'INPUT-SEP'`` splits the last dimension into two halves, one per
            direction (only meaningful for ``'BI-SEP'``).
        direction_type: one of the following.
            ``'FWD'``: one transformer built with ``src_mask_flag=True``.
            ``'MIX'`` (alias ``'BI-MIX'``): one transformer built with
            ``src_mask_flag=False``.
            ``'BI-SEP'``: a forward transformer and a backward transformer; the
            backward one runs on the ``reverse_tensor``-reversed input, and the
            two outputs are concatenated on the last dimension.
        struct_type: only ``'EXTRACTOR'`` (return the per-token output
            unchanged) is implemented.
        input_size, output_size: feature sizes. ``output_size // n_directions``
            must equal the per-transformer input size.
        nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
        tfm_dropout, tfm_activation: forwarded to ``Transformer``.
        batch_first: accepted for interface compatibility; currently unused.
        postprecess: mapping ``method -> [use, config]``. ``method`` is one of
            ``'activator'`` (alias ``'activiator'``), ``'dropout'`` or
            ``'layernorm'``. Enabled entries are applied in mapping order.

    Raises:
        AssertionError: on invalid type names or incompatible sizes.
        ValueError: on an unknown post-processing method (and, under ``-O``
            where asserts are stripped, on invalid type names).
    """

    # Keys accepted for the activation post-process step. The misspelling is
    # accepted too, so configs that use it are not silently ignored.
    _ACTIVATOR_KEYS = ('activator', 'activiator')

    def __init__(self,
                 type = 'tfm',
                 input_type = 'INPUT-NML',  # ['INPUT-NML', 'INPUT-SEP']
                 direction_type = 'FWD',    # ['FWD', 'MIX' (alias 'BI-MIX'), 'BI-SEP']
                 struct_type = 'EXTRACTOR', # ['EXTRACTOR']; 'REDUCER' is not implemented

                 input_size = 200, output_size = 200, # d_model
                 nhead = 8,
                 num_encoder_layers=6,
                 num_decoder_layers=0,
                 dim_feedforward=1024,
                 # src/tgt mask flags are derived from direction_type.

                 tfm_dropout = 0.1,
                 tfm_activation = 'relu',

                 batch_first = True,
                 postprecess = postprecess):

        super(TFMLayer, self).__init__()

        # 'BI-MIX' is the name used for the unmasked single-transformer mode in
        # the signature comments; treat it as an alias of 'MIX'.
        if direction_type == 'BI-MIX':
            direction_type = 'MIX'

        assert input_type in ['INPUT-NML', 'INPUT-SEP']
        assert direction_type in ['FWD', 'MIX', 'BI-SEP']
        assert struct_type in ['EXTRACTOR']

        self.type = type
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers

        self.input_size = input_size
        self.output_size = output_size

        # (+) input_type: size seen by each transformer.
        self.input_type = input_type
        if input_type == 'INPUT-SEP':
            assert input_size % 2 == 0
            self.tfm_input_size = input_size // 2
        elif input_type == 'INPUT-NML':
            self.tfm_input_size = input_size
        else:
            raise ValueError('Not a valid transformer input type, must in [INPUT-SEP, INPUT-NML]!')

        # (+) direction_type
        self.direction_type = direction_type
        self.n_directions = 2 if direction_type == 'BI-SEP' else 1

        # (+) struct_type
        self.struct_type = struct_type

        # (+) output size: each direction contributes hidden_size features.
        assert output_size % self.n_directions == 0
        self.hidden_size = int(output_size // self.n_directions)
        assert self.hidden_size == self.tfm_input_size
        # NOTE(review): for 'FWD'/'MIX' combined with 'INPUT-SEP', the asserts
        # force hidden_size == input_size // 2. However, op_allnormal feeds the
        # full input (input_size features) to that transformer. This looks like
        # a shape mismatch unless Transformer projects its input; confirm.

        def build_transformer(src_mask_flag):
            # All transformers here share the same hyper-parameters; only the
            # src-mask flag differs between modes.
            return Transformer(d_model = self.hidden_size,
                               nhead = nhead,
                               num_encoder_layers = self.num_encoder_layers,
                               num_decoder_layers = self.num_decoder_layers,
                               dim_feedforward = dim_feedforward,
                               dropout = tfm_dropout,
                               activation = tfm_activation,
                               src_mask_flag = src_mask_flag,
                               )

        # (+) build transformer(s) and pick the matching forward op.
        if direction_type == 'FWD':
            # Masked: each token sees only the tokens up to itself.
            self.transformer = build_transformer(src_mask_flag = True)
            self.transformer_op = self.op_allnormal
        elif direction_type == 'MIX':
            # Unmasked: each token sees all tokens in the sentence.
            self.transformer = build_transformer(src_mask_flag = False)
            self.transformer_op = self.op_allnormal
        elif direction_type == 'BI-SEP':
            # Both are masked; the backward one is run on reversed input, so its
            # mask effectively looks at the following tokens. Creation order
            # (fwd, then bwd) is kept for reproducible parameter initialisation.
            self.transformer_fwd = build_transformer(src_mask_flag = True)
            self.transformer_bwd = build_transformer(src_mask_flag = True)
            if input_type == 'INPUT-SEP':
                self.transformer_op = self.op_bisep_inputsep
            elif input_type == 'INPUT-NML':
                self.transformer_op = self.op_bisep_inputnml
            else:
                raise ValueError('Not a valid transformer input type, must in [INPUT-SEP, INPUT-NML]!')
        else:
            raise ValueError("Not a valid transformer direction type, must in ['FWD', 'MIX' ('BI-MIX'), 'BI-SEP']!")

        # (+) EXTRACTOR (REDUCER is not implemented)
        if struct_type == 'EXTRACTOR':
            self.output = self.transformer_extractor
        else:
            raise ValueError('Not a valid struct type, only EXTRACTOR is implemented!')

        # (+) post-processing steps, applied in the order given by `postprecess`.
        # Modules are also stored as attributes (self.drop / self.layernorm) so
        # that torch.nn.Module registers their parameters and train/eval state.
        self.postprocess = []
        for method, (use, config) in postprecess.items():
            if not use:
                continue
            if method in self._ACTIVATOR_KEYS:
                self.activator = self._resolve_activator(config)
                self.postprocess.append(self.activator)
            elif method == 'dropout':
                self.drop = torch.nn.Dropout(**config)
                self.postprocess.append(self.drop)
            elif method == 'layernorm':
                # https://pytorch.org/docs/stable/nn.html
                self.layernorm = torch.nn.LayerNorm(self.output_size, **config)
                self.postprocess.append(self.layernorm)
            else:
                raise ValueError(
                    "Not a valid postprocess method: {!r}; must be one of "
                    "'activator', 'dropout', 'layernorm'!".format(method))

    @staticmethod
    def _resolve_activator(name):
        """Map an activation name (case-insensitive) to a callable.

        'relu', 'tanh' and 'gelu' are recognized; any other name maps to the
        identity function.
        """
        name = name.lower()
        if name == 'relu':
            return F.relu
        if name == 'tanh':
            return torch.tanh  # F.tanh is deprecated in favour of torch.tanh
        if name == 'gelu':
            return gelu
        return lambda x: x

    def op_allnormal(self, info, leng_st, leng_st_mask):
        """Run the single transformer ('FWD' / 'MIX') over `info`.

        `leng_st` is unused here; it is kept so all ops share one signature.
        Returns ``(output, None)``; no hidden summary is computed.
        """
        return self.transformer(info, leng_st_mask), None

    def _run_bisep(self, info_fwd, info_bwd, leng_st, leng_st_mask):
        """Shared 'BI-SEP' path: run both transformers and concatenate outputs.

        `info_bwd` must already be reversed with `reverse_tensor`. Its output is
        reversed again (presumably restoring the original token order) before
        it is concatenated with the forward output on the last dimension.
        Returns ``(output, None)``.
        """
        out_fwd = self.transformer_fwd(info_fwd, leng_st_mask)
        out_bwd = self.transformer_bwd(info_bwd, leng_st_mask)
        out_bwd = reverse_tensor(out_bwd, leng_st)
        return torch.cat([out_fwd, out_bwd], -1), None

    def op_bisep_inputsep(self, info, leng_st, leng_st_mask):
        """'BI-SEP' with 'INPUT-SEP': the first half of the last dim feeds the
        forward transformer, and the second half (reversed) feeds the backward
        one."""
        info_fwd, info_bwd = info.chunk(2, -1)
        info_bwd = reverse_tensor(info_bwd, leng_st)
        return self._run_bisep(info_fwd, info_bwd, leng_st, leng_st_mask)

    def op_bisep_inputnml(self, info, leng_st, leng_st_mask):
        """'BI-SEP' with 'INPUT-NML': both transformers see the full input; the
        backward one sees it reversed."""
        info_bwd = reverse_tensor(info, leng_st)
        return self._run_bisep(info, info_bwd, leng_st, leng_st_mask)

    def transformer_extractor(self, info, hidden):
        """EXTRACTOR output: return the per-token transformer output unchanged."""
        return info

    def forward(self, info, leng_st):
        """Run the configured transformer op, the output step and post-processing.

        Args:
            info: input features (last dimension is input_size). Layout is
                presumably (batch, seq, feature); confirm against callers.
            leng_st: per-sequence lengths; passed to `get_leng_mask` and
                `reverse_tensor`.
        """
        leng_st_mask = get_leng_mask(leng_st)
        info, hidden = self.transformer_op(info, leng_st, leng_st_mask)
        info = self.output(info, hidden)
        for layer in self.postprocess:
            info = layer(info)
        return info

    def __repr__(self):
        """torch.nn.Module-style repr, with a header summarising the layer config."""
        # Extra repr is treated like a sub-module, one item per line.
        extra_repr = self.extra_repr()
        # An empty string would split into [''], so guard it.
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = ['(' + key + '): ' + _addindent(repr(module), 2)
                       for key, module in self._modules.items()]
        lines = extra_lines + child_lines

        main_str = '{}({}): ({}->{}) [INPUT] {}; [DIRECTION] {}('.format(
            self._get_name(),
            self.struct_type.upper(),
            str(self.input_size),
            str(self.output_size),
            self.input_type.upper(),
            self.direction_type.upper(),
        )
        if lines:
            # Simple one-liner info, which most builtin Modules use.
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str
        