import os
import torch

# from ..nn.op import ReduceMeanLayer, ReduceSumLayer, ConcatenateLayer
# each of these layer classes can produce a number of identical nn layers
from ..nn.rnn import RNNLayer
from ..nn.cnn import CNNLayer
from ..nn.tfm import TFMLayer
from ..nn.linear import LinearLayer

from ..nn.helper import reshape_as_sent, restore_as_sent_extractor, restore_as_sent_reducer_fwd, restore_as_sent_reducer_bi
from ..nn.helper import reshape_as_token, restore_as_token_extractor, restore_as_token_reducer
# from ..nn.helper import reshape_untouch, restore_untouch
from ..nn.helper import _addindent


class Vector_Extractor_Layer(torch.nn.Module):
    '''Wrap a single fully connected (``LinearLayer``) extractor.

    Args:
        Meanings: dict with keys ``'InputMeaning'``, ``'OutputMeaning'`` and
            ``'Reshape_Restore'``. They are stored as attributes and shown by
            ``__repr__``; ``forward`` does not use them.
        NNName_NNPara: ``(name, params)`` pair. Only ``name == 'linear'``
            (case-insensitive) is supported; ``params`` is passed as keyword
            arguments to ``LinearLayer``.
        **kwargs: accepted for interface compatibility and ignored.

    Raises:
        ValueError: if ``name`` is not ``'linear'``.
    '''
    def __init__(self, Meanings, NNName_NNPara, **kwargs):
        super(Vector_Extractor_Layer, self).__init__()
        self.InputMeaning = Meanings['InputMeaning']
        self.OutputMeaning = Meanings['OutputMeaning']
        self.Reshape_Restore = Meanings['Reshape_Restore']

        name, layer = NNName_NNPara
        if name.lower() != 'linear':
            # Fail fast instead of leaving ``self.layer`` unset, which would
            # only surface later as an AttributeError inside ``forward``.
            raise ValueError('There is no layer ' + name)
        self.layer = LinearLayer(**layer)

    def forward(self, info):
        '''Apply the wrapped linear layer to ``info`` and return its output.'''
        return self.layer(info)

    def __repr__(self):
        '''Render like ``torch.nn.Module.__repr__``, with a header listing the
        input/output meanings and the reshape mode before the child modules.'''
        # Treat extra_repr like a sub-module, one item per line; an empty
        # string would split into [''], so guard against it.
        extra_repr = self.extra_repr()
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = []
        for key, module in self._modules.items():
            mod_str = _addindent(repr(module), 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        rule = '  ' + '--' * 20
        main_str = self._get_name() + ' (\n' + rule + '\n'
        main_str += ';\n'.join(['  Input   ==> ' + self.InputMeaning,
                                '  Output  ==> ' + self.OutputMeaning,
                                '  Reshape ==> ' + str(self.Reshape_Restore)])
        main_str += ';\n' + rule

        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str



class Matrix_Extractor_Layer(torch.nn.Module):
    '''Hold one layer (CNN, RNN, TFM or linear) chosen by name.

    Args:
        Meanings: mapping with the keys ``'InputMeaning'``, ``'OutputMeaning'``
            and ``'Reshape_Restore'``; kept as attributes and printed by
            ``__repr__``.
        NNName_NNPara: ``(name, params)`` pair. ``name`` is matched
            case-insensitively; ``params`` are forwarded as keyword arguments
            to the chosen layer class.

    Raises:
        ValueError: when ``name`` is not one of cnn / rnn / tfm / linear.
    '''

    def __init__(self, Meanings, NNName_NNPara):
        super(Matrix_Extractor_Layer, self).__init__()

        self.InputMeaning = Meanings['InputMeaning']
        self.OutputMeaning = Meanings['OutputMeaning']
        self.Reshape_Restore = Meanings['Reshape_Restore']

        layer_name, layer_kwargs = NNName_NNPara
        kind = layer_name.lower()
        # Pick the class first, then build it in a single place.
        if kind == 'cnn':
            layer_cls = CNNLayer
        elif kind == 'rnn':
            layer_cls = RNNLayer
        elif kind == 'tfm':
            layer_cls = TFMLayer
        elif kind == 'linear':
            layer_cls = LinearLayer
        else:
            raise ValueError('There is no layer ' + layer_name)
        self.layer = layer_cls(**layer_kwargs)

    def forward(self, info, leng_st=None, *args, **kwargs):
        '''Call the wrapped layer as ``layer(info, leng_st)``.

        Any extra positional or keyword arguments are accepted and ignored.
        '''
        return self.layer(info, leng_st)

    def __repr__(self):
        '''Module-style repr with a meaning/reshape header above the children.'''
        extra = self.extra_repr()
        # Guard the empty case: ''.split('\n') would give [''].
        extra_lines = extra.split('\n') if extra else []
        child_lines = [
            '({}): {}'.format(key, _addindent(repr(child), 2))
            for key, child in self._modules.items()
        ]

        rule = '  ' + '-' * 40
        header = [
            self._get_name() + ' (',
            rule,
            '  Input   ==> ' + self.InputMeaning + ';',
            '  Output  ==> ' + self.OutputMeaning + ';',
            '  Reshape ==> ' + str(self.Reshape_Restore) + ';',
            rule,
        ]
        text = '\n'.join(header)

        if len(extra_lines) == 1 and not child_lines:
            # Single extra_repr line and no children: append it inline.
            text += extra_lines[0]
        elif extra_lines or child_lines:
            text += '\n  ' + '\n  '.join(extra_lines + child_lines) + '\n'
        return text + ')'




class Tensor_Extractor_Layer(torch.nn.Module):
    '''Wrap one layer (CNN / RNN / TFM / linear), optionally reshaping the
    input before it and restoring the output after it.

    Args:
        Meanings: dict with keys ``'InputMeaning'``, ``'OutputMeaning'`` and
            ``'Reshape_Restore'``. ``Reshape_Restore`` must be ``None``,
            ``'GrainVec_SeqAs_Token'`` or ``'GrainVec_SeqAs_Sent'``.
        NNName_NNPara: ``(name, params)`` pair; ``name`` is matched
            case-insensitively and ``params`` are keyword arguments for the
            selected layer class.

    Raises:
        ValueError: on an unsupported ``Reshape_Restore`` value or layer name.
    '''
    # Accepted non-None values of Reshape_Restore.
    _RESHAPE_MODES = ('GrainVec_SeqAs_Token', 'GrainVec_SeqAs_Sent')

    def __init__(self, Meanings, NNName_NNPara):
        super(Tensor_Extractor_Layer, self).__init__()

        self.InputMeaning = Meanings['InputMeaning']
        self.OutputMeaning = Meanings['OutputMeaning']
        self.Reshape_Restore = Meanings['Reshape_Restore']
        # Explicit check rather than ``assert`` so validation still runs under
        # ``python -O``; otherwise an invalid mode would silently fall through
        # to the no-reshape branch of ``forward``.
        if self.Reshape_Restore is not None and self.Reshape_Restore not in self._RESHAPE_MODES:
            raise ValueError('Unsupported Reshape_Restore: ' + repr(self.Reshape_Restore))

        name, layer = NNName_NNPara
        if name.lower() == 'cnn':
            self.layer = CNNLayer(**layer)
        elif name.lower() == 'rnn':
            self.layer = RNNLayer(**layer)
        elif name.lower() == 'tfm':
            self.layer = TFMLayer(**layer)
        elif name.lower() == 'linear':
            self.layer = LinearLayer(**layer)
        else:
            raise ValueError('There is no layer ' + name)

        # reshape: deal with the input tensor
        # restore: deal with the output tensor
        # Neither is set when Reshape_Restore is None; forward then calls the
        # layer directly.
        if self.Reshape_Restore == 'GrainVec_SeqAs_Token':
            self.reshape = reshape_as_token
            self.restore = restore_as_token_extractor
        elif self.Reshape_Restore == 'GrainVec_SeqAs_Sent':
            self.reshape = reshape_as_sent
            self.restore = restore_as_sent_extractor

    def forward(self, info, leng_tk, leng_tk_mask, leng_st, misc_info):
        '''Run the wrapped layer, with reshape/restore around it if configured.

        With a reshape mode, ``info`` is reshaped, passed to
        ``layer(info, new_leng_st)`` and then restored. Without one, the layer
        is called with all length/misc arguments as keywords.
        '''
        # info is (bs, a, b, c) per the original author's note
        if self.Reshape_Restore in self._RESHAPE_MODES:
            info, new_leng_st, reverse_id, shape = self.reshape(info, leng_tk, leng_tk_mask, leng_st, misc_info)
            # info is (bs, x, c) after reshape_as_sent / reshape_as_token
            info = self.layer(info, new_leng_st)
            return self.restore(info, leng_tk, leng_tk_mask, leng_st, misc_info, shape, reverse_id)
        return self.layer(info, leng_tk=leng_tk, leng_tk_mask=leng_tk_mask,
                          leng_st=leng_st, misc_info=misc_info)

    def __repr__(self):
        '''Render like ``torch.nn.Module.__repr__``, with a header listing the
        input/output meanings and the reshape mode before the child modules.'''
        # Treat extra_repr like a sub-module, one item per line; an empty
        # string would split into [''], so guard against it.
        extra_repr = self.extra_repr()
        extra_lines = extra_repr.split('\n') if extra_repr else []
        child_lines = []
        for key, module in self._modules.items():
            mod_str = _addindent(repr(module), 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        rule = '  ' + '--' * 20
        main_str = self._get_name() + ' (\n' + rule + '\n'
        main_str += ';\n'.join(['  Input   ==> ' + self.InputMeaning,
                                '  Output  ==> ' + self.OutputMeaning,
                                '  Reshape ==> ' + str(self.Reshape_Restore)])
        main_str += ';\n' + rule

        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str