# TODO: log the time here in the future.
import time
import torch
from ..nn.helper import orderSeq, restoreSeq
from ..nn.op import ReduceMeanLayer, ReduceSumLayer, RecuderMaxLayer, ConcatenateLayer
from ..nn.rnn import RNNLayer
from ..nn.cnn import CNNLayer
from ..nn.tfm import TFMLayer
from ..nn.linear import LinearLayer

from ..nn.helper import reshape_as_sent, restore_as_sent_extractor, restore_as_sent_reducer_fwd, restore_as_sent_reducer_bi
from ..nn.helper import reshape_as_token, restore_as_token_extractor, restore_as_token_reducer
# from ..nn.helper import reshape_untouch, restore_untouch
from ..nn.helper import _addindent



class Matrix_Reducer_Layer(torch.nn.Module):
    """Wrap a single reducer sub-layer chosen by name.

    Args:
        Meanings: dict with keys ``'InputMeaning'`` and ``'OutputMeaning'``
            (labels only shown by ``__repr__``) and ``'Reshape_Restore'``,
            which must be ``None`` for this layer.
        NNName_NNPara: ``(name, params)`` pair. ``name`` (case-insensitive)
            selects the sub-layer: ``'mean'``, ``'sum'``, ``'max'``,
            ``'concat'``, ``'rnn'`` or ``'tfm'``. ``params`` is passed to that
            sub-layer's constructor as keyword arguments.

    Raises:
        AssertionError: if ``Meanings['Reshape_Restore']`` is not ``None``.
        ValueError: if ``name`` is not a supported layer name.
    """

    def __init__(self, Meanings, NNName_NNPara):
        super(Matrix_Reducer_Layer, self).__init__()
        # always reduce b, the -2 dimension
        # (bs, a, b_input) --> (bs, a, b_output)
        self.InputMeaning    = Meanings['InputMeaning']
        self.OutputMeaning   = Meanings['OutputMeaning']
        self.Reshape_Restore = Meanings['Reshape_Restore']
        # No reshape/restore step is supported for the matrix reducer.
        assert self.Reshape_Restore is None

        name, layer = NNName_NNPara

        # Built inside __init__ (not at class level) so the layer classes are
        # only looked up when a layer is actually constructed.
        layer_classes = {
            'mean':   ReduceMeanLayer,
            'sum':    ReduceSumLayer,
            'max':    RecuderMaxLayer,
            'concat': ConcatenateLayer,
            'rnn':    RNNLayer,
            'tfm':    TFMLayer,
        }
        layer_cls = layer_classes.get(name.lower())
        if layer_cls is None:
            raise ValueError('There is no layer ' + name)
        self.layer = layer_cls(**layer)

    def forward(self, info, leng_st = None, *args, **kwargs):
        """Run the wrapped sub-layer on ``info``.

        ``leng_st`` and any extra positional or keyword arguments are forwarded
        unchanged. The sub-layer's result is returned as-is.
        """
        info = self.layer(info, leng_st, *args, **kwargs)
        return info

    def __repr__(self):
        """Describe the module: a meanings header followed by its sub-modules.

        Follows the layout of ``torch.nn.Module.__repr__``, but first prints a
        header with ``InputMeaning``, ``OutputMeaning`` and ``Reshape_Restore``.
        """
        # We treat the extra repr like the sub-module, one item per line.
        extra_repr = self.extra_repr()
        # ''.split('\n') would yield [''], so only split a non-empty string.
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = ['(' + key + '): ' + _addindent(repr(module), 2)
                       for key, module in self._modules.items()]
        lines = extra_lines + child_lines

        rule = '  ' + '--' * 20
        # str() so that repr() never raises when a meaning is not a string.
        header = ';\n'.join(['  Input   ==> ' + str(self.InputMeaning),
                             '  Output  ==> ' + str(self.OutputMeaning),
                             '  Reshape ==> ' + str(self.Reshape_Restore)])
        main_str = self._get_name() + ' (\n' + rule + '\n' + header + ';\n' + rule

        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str



class Tensor_Reducer_Layer(torch.nn.Module):
    """Wrap a reducer sub-layer chosen by name, optionally with reshape/restore.

    Args:
        Meanings: dict with keys ``'InputMeaning'`` and ``'OutputMeaning'``
            (labels only shown by ``__repr__``) and ``'Reshape_Restore'``,
            which is one of ``None``, ``'GrainVec_SeqAs_Token'`` or
            ``'GrainVec_SeqAs_Sent'``.
        NNName_NNPara: ``(name, params)`` pair. ``name`` (case-insensitive)
            selects the sub-layer: ``'mean'``, ``'sum'``, ``'max'``,
            ``'concat'``, ``'linear'``, ``'rnn'`` or ``'tfm'``. ``params`` is
            passed to its constructor as keyword arguments. For ``'rnn'`` and
            ``'tfm'``, ``params['direction_type']`` is also read.

    Raises:
        AssertionError: if ``Reshape_Restore`` is not an allowed value, if
            ``'max'`` is used without ``'GrainVec_SeqAs_Token'``, or if
            ``'tfm'`` has a ``direction_type`` other than ``'FWD'``.
        ValueError: if ``name`` is not a supported layer name.
    """

    def __init__(self, Meanings, NNName_NNPara):
        super(Tensor_Reducer_Layer, self).__init__()
        # always reduce b, the -2 dimension
        # (bs, a, b, c_inp) --> (bs, a, c_outp)
        self.InputMeaning    = Meanings['InputMeaning']
        self.OutputMeaning   = Meanings['OutputMeaning']
        self.Reshape_Restore = Meanings['Reshape_Restore']
        assert self.Reshape_Restore is None or self.Reshape_Restore in ['GrainVec_SeqAs_Token', 'GrainVec_SeqAs_Sent']

        name, layer = NNName_NNPara
        key = name.lower()
        # The string 'None' (not None) marks "no direction", so the
        # 'BI' substring test below works for every layer type.
        DIRECTION = 'None'

        # Built inside __init__ (not at class level) so the layer classes are
        # only looked up when a layer is actually constructed.
        layer_classes = {
            'mean':   ReduceMeanLayer,
            'sum':    ReduceSumLayer,
            'max':    RecuderMaxLayer,
            'concat': ConcatenateLayer,
            'linear': LinearLayer,
            'rnn':    RNNLayer,
            'tfm':    TFMLayer,
        }
        layer_cls = layer_classes.get(key)
        if layer_cls is None:
            raise ValueError('There is no layer ' + name)
        self.layer = layer_cls(**layer)

        # Per-layer constraints, checked after construction as before.
        if key == 'max':
            assert self.Reshape_Restore == 'GrainVec_SeqAs_Token'
        elif key in ('rnn', 'tfm'):
            DIRECTION = layer['direction_type']
            if key == 'tfm':
                assert layer['direction_type'] == 'FWD'

        # reshape: deal with the input tensor
        # restore: deal with the output tensor
        if self.Reshape_Restore == 'GrainVec_SeqAs_Token':
            self.reshape = reshape_as_token
            self.restore = restore_as_token_reducer
        elif self.Reshape_Restore == 'GrainVec_SeqAs_Sent':
            self.reshape = reshape_as_sent
            # Bidirectional sequence layers need the '_bi' restore variant.
            self.restore = restore_as_sent_reducer_fwd if 'BI' not in DIRECTION else restore_as_sent_reducer_bi
        elif self.Reshape_Restore is not None:
            # Unreachable while the assert above is active; kept as a guard
            # when assertions are disabled (python -O).
            raise ValueError('No Good Reshape and Restore Function')

    def forward(self, info, leng_tk, leng_tk_mask, leng_st, misc_info):
        """Apply the sub-layer, wrapped in reshape/restore when configured.

        When ``Reshape_Restore`` is set, ``self.reshape`` produces the
        sub-layer input, a new length argument, ``reverse_id`` and ``shape``.
        The sub-layer then runs on that input and length. Finally,
        ``self.restore`` maps the result back using ``shape`` and
        ``reverse_id``. Otherwise, the sub-layer receives every argument as a
        keyword.

        NOTE(review): the exact shapes and meaning of ``leng_tk``,
        ``leng_tk_mask``, ``leng_st`` and ``misc_info`` are defined by the
        ``..nn.helper`` functions and are not visible here.
        """
        # info is (bs, a, b, c) per the original author's note
        if self.Reshape_Restore in ('GrainVec_SeqAs_Token', 'GrainVec_SeqAs_Sent'):
            info, new_leng_st, reverse_id, shape = self.reshape(info, leng_tk, leng_tk_mask, leng_st, misc_info)
            # info is (bs, x, c) if reshape is as_sent or as_token
            info = self.layer(info, new_leng_st)
            info = self.restore(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id)
        else:
            info = self.layer(info, leng_tk = leng_tk, leng_tk_mask = leng_tk_mask, leng_st = leng_st, misc_info = misc_info)
        return info

    def __repr__(self):
        """Describe the module: a meanings header followed by its sub-modules.

        Follows the layout of ``torch.nn.Module.__repr__``, but first prints a
        header with ``InputMeaning``, ``OutputMeaning`` and ``Reshape_Restore``.
        """
        # We treat the extra repr like the sub-module, one item per line.
        extra_repr = self.extra_repr()
        # ''.split('\n') would yield [''], so only split a non-empty string.
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = ['(' + key + '): ' + _addindent(repr(module), 2)
                       for key, module in self._modules.items()]
        lines = extra_lines + child_lines

        rule = '  ' + '--' * 20
        # str() so that repr() never raises when a meaning is not a string.
        header = ';\n'.join(['  Input   ==> ' + str(self.InputMeaning),
                             '  Output  ==> ' + str(self.OutputMeaning),
                             '  Reshape ==> ' + str(self.Reshape_Restore)])
        main_str = self._get_name() + ' (\n' + rule + '\n' + header + ';\n' + rule

        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str