import torch
import numpy as np
from ..nn.helper import orderSeq
from nlptext.folder import Folder
from nlptext.sentence import Sentence
from datetime import datetime
from pprint import pprint


# Fields for which get_Input_Target_Field attaches a {'tagScheme': ...} setting.
HyperFields = ['annoE', 'pos', 'pos_en', 'medpos']
# Fields for which get_input_fields_info never creates a 'gr_psn'
# (grain position) embedding, even when MISC['useGrPsn'] is set.
one_grain_fields = ['token', 'pos', 'medpos']

############################################################################ get train valid test from the whole corpus
def get_train_valid_test_from_nlptext(nlptext):
    """Split sentence indexes into train/valid/test according to folder names.

    Every folder ``Folder(i)`` for ``i < nlptext.GROUP['length']`` is inspected.
    A folder whose name contains 'train' feeds the train split, 'valid' or
    'dev' feeds the valid split, and 'test' feeds the test split (checked in
    that order); other folders are ignored. If the valid split ends up empty
    it reuses the test list, otherwise an empty test split reuses the valid
    list.

    Returns:
        tuple: (train_sent_idx, valid_sent_idx, test_sent_idx) lists of ints.
    """
    log_tag = '[fieldlm.utils.get_train_valid_test_from_nlptext]'
    train_sent_idx, valid_sent_idx, test_sent_idx = [], [], []

    for folder_idx in range(nlptext.GROUP['length']):
        folder = Folder(folder_idx)
        folder_name = folder.name

        # Pick the destination list (and its log label) for this folder.
        if 'train' in folder_name:
            label, bucket = 'Train:', train_sent_idx
        elif 'valid' in folder_name or 'dev' in folder_name:
            label, bucket = 'Valid:', valid_sent_idx
        elif 'test' in folder_name:
            label, bucket = 'Test :', test_sent_idx
        else:
            continue

        print(log_tag, label, folder_name)
        # IdxSentStartEnd is unpacked straight into range().
        bucket.extend(range(*folder.IdxSentStartEnd))

    # Fall back to the other held-out split when one of them is missing.
    if not valid_sent_idx:
        valid_sent_idx = test_sent_idx
    elif not test_sent_idx:
        test_sent_idx = valid_sent_idx

    return train_sent_idx, valid_sent_idx, test_sent_idx


def get_train_valid_test(total_sent_num, train_prop = 0.95, seed = 10):
    """Randomly split ``range(total_sent_num)`` into train/valid/test indexes.

    The global NumPy RNG is seeded with ``seed`` before shuffling, so the
    split is reproducible (and the global RNG state is reset as a side
    effect). ``train_prop`` of the sentences go to train; the remainder is
    halved, with test receiving the floor half and valid the rest.

    Returns:
        tuple: (train_sent_idx, valid_sent_idx, test_sent_idx) lists of ints.
    """
    np.random.seed(seed)

    shuffled = list(range(total_sent_num))
    np.random.shuffle(shuffled)

    # Boundaries inside the shuffled list: [train | test | valid].
    train_end = int(total_sent_num * train_prop)
    test_end = train_end + int((total_sent_num - train_end) / 2)

    train_sent_idx = shuffled[:train_end]
    test_sent_idx = shuffled[train_end:test_end]
    valid_sent_idx = shuffled[test_end:]
    return train_sent_idx, valid_sent_idx, test_sent_idx
##################################################



############################################################################ get information for input fields and target field 
def get_Input_Target_Field(field_combination, target_field, tagScheme):
    """Build the input-field settings, the target-field spec and a dir name.

    Args:
        field_combination: iterable of input field names.
        target_field: name of the field to predict.
        tagScheme: tag scheme stored for the target field and for every
            input field listed in ``HyperFields``.

    Returns:
        tuple: ``(Input_Fields, Target_Field, Field_Dir)`` where
        ``Input_Fields`` maps each field to its settings dict,
        ``Target_Field`` is ``[target_field, {'tagScheme': tagScheme}]`` and
        ``Field_Dir`` is the field names joined with '_'.
    """
    # Only hyper fields carry a tag scheme; every other field has no settings.
    Input_Fields = {
        fld: ({'tagScheme': tagScheme} if fld in HyperFields else {})
        for fld in field_combination
    }
    Target_Field = [target_field, {'tagScheme': tagScheme}]
    Field_Dir = '_'.join(field_combination)
    return Input_Fields, Target_Field, Field_Dir


def _with_special_tokens(idx2specialtokens, vocab):
    """Prepend the special tokens to ``vocab`` and build the reverse index."""
    idx2item = idx2specialtokens + vocab
    item2idx = dict(zip(idx2item, range(len(idx2item))))
    return idx2item, item2idx


def _shift_lookup(LKP, special_tokens_num, empty_fill = None):
    """Return a shifted copy of a token->grains lookup table and its max row length.

    Every grain id is offset by ``special_tokens_num`` and one single-grain
    row per special token is prepended. ``LKP`` itself is never modified.
    When ``empty_fill`` is given, empty rows are replaced by a copy of it.
    """
    shifted = []
    max_gr = 1
    for grains_seq in LKP:
        if empty_fill is not None and len(grains_seq) == 0:
            row = list(empty_fill)
        else:
            row = [i + special_tokens_num for i in grains_seq]
        max_gr = max(max_gr, len(row))
        shifted.append(row)
    return [[i] for i in range(special_tokens_num)] + shifted, max_gr


def _field_embed_para(grain_embed, embed_size, maxSentLeng, max_gr, useTkPsn, add_gr_psn):
    """Assemble the embedding settings (grain + optional position embeddings) of one field."""
    embeds = {'grain': grain_embed}
    if useTkPsn:
        embeds['tk_psn'] = {'input_size': maxSentLeng + 1,
                            'embedding_size': embed_size,
                            'init': 'random'}
    if add_gr_psn:
        embeds['gr_psn'] = {'input_size': max_gr + 1,
                            'embedding_size': embed_size,
                            'init': 'random'}
    return embeds


def get_input_fields_info(nlptext, Input_Fields, fldembed = None, **MISC):
    """Build vocabulary and embedding settings for every input field.

    Fields that have pretrained weights in ``fldembed.weights`` are handled
    first: their vectors are stacked under zero rows for the special tokens
    and their settings come from ``fldembed.Field_Settings``. The remaining
    fields take their vocabulary from ``nlptext`` and are randomly
    initialised, with the embedding size of the last pretrained field (200
    when there is none).

    Unlike the previous implementation, the lookup tables owned by the
    pretrained embeddings (``wv.LKP``) and by ``nlptext.getLookUp`` are no
    longer shifted in place, so calling this function repeatedly with the
    same objects gives the same result every time.

    Args:
        nlptext: corpus object providing ``CONTEXT_DEP_CHANNELS``,
            ``CONTEXT_IND_CHANNELS``, ``getGrainVocab``, ``getTrans`` and
            ``getLookUp``.
        Input_Fields: dict mapping field name -> settings dict.
        fldembed: optional pretrained embeddings exposing ``weights`` and
            ``Field_Settings``.
        **MISC: must contain 'idx2specialtokens', 'specialtokens2idx'
            (with a '</unk>' entry), 'maxSentLeng', 'useTkPsn', 'useGrPsn'.

    Returns:
        tuple: ``(INPUT_FIELDS, INPUT_EMBEDS)``, both dicts keyed by field.
    """
    log_tag = '[fieldlm.utils.batch.get_input_fields_info]'

    # (a): get the Weights from pretrained embeddings.
    Weights = fldembed.weights if fldembed else {}
    if len(Weights):
        print(log_tag + '//pretrained embed:', {fld: len(wv.index2word) for fld, wv in Weights.items()})

    # (b): field settings.
    Field_Settings = fldembed.Field_Settings if fldembed else {}
    INPUT_FIELDS = {}
    INPUT_EMBEDS = {}

    # MISC must contain the special tokens.
    idx2specialtokens = MISC['idx2specialtokens']
    specialtokens2idx = MISC['specialtokens2idx']
    special_tokens_num = len(idx2specialtokens)
    unk_id = specialtokens2idx['</unk>']
    tk_unk_id = unk_id
    maxSentLeng = MISC['maxSentLeng']

    useTkPsn = MISC['useTkPsn']
    useGrPsn = MISC['useGrPsn']
    embed_size = None

    # ---- pass 1: fields with pretrained vectors ----
    for fld, fld_settings in Input_Fields.items():
        print(log_tag + '//field:', fld)
        if fld not in Weights:
            continue

        wv = Weights[fld]
        max_gr = 1
        para = {
            'GU': _with_special_tokens(idx2specialtokens, wv.GU[0]),
            'unk_id': unk_id,
            'tk_unk_id': tk_unk_id,
        }

        # Zero vectors are used for the special tokens placed in front.
        pure_grain_num, embed_size = wv.vectors.shape
        vectors = np.vstack([np.zeros([special_tokens_num, embed_size]), wv.vectors])
        embed = {'init': vectors}
        embed['input_size'], embed['embedding_size'] = vectors.shape
        assert embed['input_size'] == pure_grain_num + special_tokens_num

        if fld in nlptext.CONTEXT_DEP_CHANNELS and fld != 'token':
            # hyper field: shift the tag ids past the special tokens.
            TRANS = nlptext.getTrans(fld, **fld_settings)
            para['TRANS'] = {k: v + special_tokens_num for k, v in TRANS.items()}

        elif fld in nlptext.CONTEXT_IND_CHANNELS and fld != 'token':
            # sub field: token vocabulary plus token -> grains lookup table.
            para['TU'] = _with_special_tokens(idx2specialtokens, wv.TU[0])
            # NOTE(review): tokens without grains map to id 1, presumably
            # the '</unk>' id -- confirm against the special-token layout.
            para['LKP'], max_gr = _shift_lookup(wv.LKP, special_tokens_num, empty_fill = [1])
            assert len(para['LKP']) == len(para['TU'][0])

        # The settings stored with the pretrained embeddings are recorded
        # for the field (the Input_Fields settings were used for getTrans).
        para.update(Field_Settings[fld])
        INPUT_FIELDS[fld] = para

        # TODO, adding more information here
        INPUT_EMBEDS[fld] = _field_embed_para(
            embed, embed_size, maxSentLeng, max_gr, useTkPsn,
            useGrPsn and fld not in one_grain_fields)

    # ---- pass 2: fields without pretrained vectors ----
    # They share the embedding size of the last pretrained field, if any.
    if not embed_size: embed_size = 200
    for fld, fld_settings in Input_Fields.items():
        if fld in Weights:
            continue

        max_gr = 1
        GU = nlptext.getGrainVocab(fld, **fld_settings)
        idx2grain, grain2idx = _with_special_tokens(idx2specialtokens, GU[0])
        para = {
            'GU': (idx2grain, grain2idx),
            'unk_id': unk_id,
            'tk_unk_id': tk_unk_id,
        }
        embed = {'init': 'random',
                 'input_size': len(idx2grain),
                 'embedding_size': embed_size}

        if fld in nlptext.CONTEXT_DEP_CHANNELS and fld != 'token':
            TRANS = nlptext.getTrans(fld, **fld_settings)
            para['TRANS'] = {k: v + special_tokens_num for k, v in TRANS.items()}

        elif fld in nlptext.CONTEXT_IND_CHANNELS and fld != 'token':
            # here is to deal with the sub fields.
            LKP, TU = nlptext.getLookUp(fld, **fld_settings)
            para['TU'] = _with_special_tokens(idx2specialtokens, TU[0])
            para['LKP'], max_gr = _shift_lookup(LKP, special_tokens_num)
            assert len(para['LKP']) == len(para['TU'][0])

        para.update(fld_settings)
        INPUT_FIELDS[fld] = para

        # TODO, adding more information here
        INPUT_EMBEDS[fld] = _field_embed_para(
            embed, embed_size, maxSentLeng, max_gr, useTkPsn,
            useGrPsn and fld not in one_grain_fields)

    return INPUT_FIELDS, INPUT_EMBEDS
                  
      
def get_target_field_info(nlptext, Target_Field, fldembed, **MISC):
    """Build the vocabulary information of the target (predicted) field.

    Three kinds of target are handled:

    * ``'token'``: the vocabulary is the pretrained token vocabulary of
      ``fldembed`` with the special tokens placed in front.
    * ``'category'``: the labels are collected from the folder names of the
      corpus (last path component, up to its first '.').
    * any other field: a tag field whose vocabulary is '</pad>' followed by
      the grain vocabulary from ``nlptext``; tag ids from ``getTrans`` are
      shifted by one to make room for '</pad>'.

    Args:
        nlptext: corpus object (``GROUP``, ``getGrainVocab``, ``getTrans``).
        Target_Field: ``[field_name, settings_dict]``.
        fldembed: pretrained embeddings; only read when the field is 'token'.
        **MISC: must contain 'specialtokens2idx' (with '</unk>'), and
            'idx2specialtokens' when the field is 'token'.

    Returns:
        list: ``[field, target_para, labels, tag_size]``.
    """
    field, para = Target_Field
    target_para = {}

    specialtokens2idx = MISC['specialtokens2idx']
    unk_id = specialtokens2idx['</unk>']

    if field == 'token':
        idx2specialtokens = MISC['idx2specialtokens']

        Weights = fldembed.weights
        assert 'token' in Weights, "fldembed has no pretrained 'token' weights"
        idx2grain = idx2specialtokens + Weights['token'].GU[0]
        grain2idx = dict(zip(idx2grain, range(len(idx2grain))))
        target_para['GU'] = idx2grain, grain2idx
        target_para['unk_id'] = unk_id
        labels = []
        tag_size = len(idx2grain)

    elif field == 'category':
        # One label per distinct folder base name, in sorted order.
        labels = sorted({
            Folder(i).name.split('/')[-1].split('.')[0]
            for i in range(nlptext.GROUP['length'])
        })
        target_para['GU'] = labels, {label: idx for idx, label in enumerate(labels)}
        tag_size = len(labels)

    else:
        # generally, we won't add start and end into the CRF model,
        # nor any other special token except '</pad>'.
        # even if the features come with start and end, they should be
        # removed before being fed into the crf.
        GU = nlptext.getGrainVocab(field, **para)
        idx2tag = ['</pad>'] + GU[0]
        # tag -> idx (the previous code built idx -> tag under this name).
        tag2idx = {tag: idx for idx, tag in enumerate(idx2tag)}
        # </pad> is in GU, no </unk> or other special tags in GU
        tag_size = len(idx2tag)
        target_para['GU'] = idx2tag, tag2idx

        # Shift by one because '</pad>' now occupies id 0.
        TRANS = nlptext.getTrans(field, **para)
        TRANS = {k: v + 1 for k, v in TRANS.items()}
        # Labels are the part of each tag before its first '-'.
        labels = sorted({tag.split('-')[0] for tag in idx2tag if '-' in tag})
        target_para['TRANS'] = TRANS
        target_para['unk_id'] = unk_id

    TARGET_FIELD = [field, target_para, labels, tag_size]
    return TARGET_FIELD
##################################################




############################################################################ masked language model
def generate_maskedtokens_idxes_prob(leng_st, proportion = 0.2, device = None):
    """Draw the token positions to mask in every sentence of a batch.

    For each sentence, ``int(length * proportion)`` distinct positions are
    sampled uniformly and returned in ascending order, together with one
    uniform random number in [0, 1) per selected position (used later to
    choose how the token is corrupted).

    The ``proportion`` argument is now honoured; it used to be overwritten
    by a hard-coded 0.2.

    Args:
        leng_st: 1-D integer tensor with the length of every sentence.
        proportion: fraction of each sentence to mask.
        device: device the returned tensors are moved to.

    Returns:
        tuple: ``(batch_masked_token_idxes, batch_masked_token_prob)``, two
        lists holding one 1-D tensor per sentence.
    """
    leng_mask_num = (leng_st.float() * proportion).long()

    batch_masked_token_idxes = []
    batch_masked_token_prob = []
    for sent_leng, mask_num in zip(leng_st.tolist(), leng_mask_num.tolist()):
        # A random permutation truncated to mask_num gives distinct positions.
        shuffled_positions = torch.randperm(sent_leng).to(device)
        masked_token_idxes = torch.sort(shuffled_positions[:mask_num])[0]
        batch_masked_token_idxes.append(masked_token_idxes)
        masked_token_prob = torch.rand(masked_token_idxes.size(0)).to(device)
        batch_masked_token_prob.append(masked_token_prob)

    return batch_masked_token_idxes, batch_masked_token_prob


def corrupt_batchsents(batch, batch_masked_tokens_idxes, batch_masked_tokens_probs, INPUT_FIELDS, **MISC):
    """Corrupt the selected tokens of every sentence for masked-LM training.

    For each selected position the accompanying random number decides the
    corruption: above 0.2 the token becomes '</mask>', below 0.1 it becomes
    a random token id drawn from ``[number of special tokens, 80% of the
    vocabulary size)``, otherwise it is left unchanged. Only the 'token'
    field is corrupted; hyper fields are not masked here.

    Args:
        batch: sentences exposing ``get_grain_idx``.
        batch_masked_tokens_idxes: per-sentence tensors of positions to mask.
        batch_masked_tokens_probs: per-sentence tensors of random numbers.
        INPUT_FIELDS: field settings; ``INPUT_FIELDS['token']`` is used.
        **MISC: must contain 'specialtokens2idx' with a '</mask>' entry.

    Returns:
        tuple: ``(cpt_batch, batch_target_masked_tokens_idxes)``: the
        corrupted ``Sentence`` objects and, per sentence, a LongTensor with
        the original token ids at the masked positions.
    """
    token_para = INPUT_FIELDS['token']
    idx2token = token_para['GU'][0]
    specialtokens2idx = MISC['specialtokens2idx']
    first_regular_id = len(specialtokens2idx)
    # Random replacements only come from the first 80% of the vocabulary,
    # ignoring some low-frequency words.
    sampling_upper = int(len(idx2token) * 0.8)
    mask_id = specialtokens2idx['</mask>']

    device = batch_masked_tokens_idxes[0].device
    cpt_batch, batch_target_masked_tokens_idxes = [], []

    for sent_idx, sent in enumerate(batch):
        positions = batch_masked_tokens_idxes[sent_idx].cpu().numpy()   # torch to np
        draws = batch_masked_tokens_probs[sent_idx].cpu().numpy()       # torch to np

        # Unknown words are already mapped to the unk token here.
        token_ids = np.array(sent.get_grain_idx('token', **token_para)[0]).squeeze(-1)
        chosen = token_ids[positions]

        # Keep the original ids: they are what the model has to predict.
        targets = torch.LongTensor(chosen.copy()).to(device)

        to_mask = draws > 0.2
        to_randomize = draws < 0.1
        chosen[to_mask] = mask_id
        chosen[to_randomize] = np.random.randint(
            first_regular_id, sampling_upper, size=(int(to_randomize.sum()),))
        token_ids[positions] = chosen

        corrupted_text = ' '.join(idx2token[i] for i in token_ids)
        cpt_batch.append(Sentence(sentence = corrupted_text))
        batch_target_masked_tokens_idxes.append(targets)

    return cpt_batch, batch_target_masked_tokens_idxes
##################################################



############################################################################ get batch information for one field
def get_field_4_batchsents(batch, field, MISC, **kwargs):
    """Collect and zero-pad the grain indexes of one field for a batch.

    ``st.get_grain_idx(field, **kwargs)`` is called for every sentence and
    is expected to return ``(info, leng_st, leng_tk, max_gr)``: the grain
    ids per token, the number of tokens, the number of grains per token and
    the largest number of grains of a token.

    When ``MISC['useStartEnd']`` is set and the field is not 'annoE', one
    '</start>' row and one '</end>' row (one grain each) are added around
    every sentence; 'annoE' sequences never get start/end rows.

    The previous ``if / elif useStartEnd / else: raise`` chain had an
    unreachable ``else`` branch; the two real cases are now a single flag.

    Args:
        batch: sentences (normal or corrupted) exposing ``get_grain_idx``.
        field: field name.
        MISC: dict with 'useStartEnd', and 'specialtokens2idx' when start
            and end rows are added.
        **kwargs: forwarded to ``get_grain_idx``.

    Returns:
        tuple: ``(info, leng_tk, leng_st)``: the padded grain-id array, the
        padded grains-per-token array and the list of sentence lengths
        (token count plus 2 whenever ``useStartEnd`` is set, also for
        'annoE').
    """
    useStartEnd = MISC['useStartEnd']
    extra_leng = 2 if useStartEnd else 0

    sent_infos = [st.get_grain_idx(field, **kwargs) for st in batch]

    mGr_sents = [i[3] for i in sent_infos]
    mGr = np.max(mGr_sents)

    leng_st = [i[1] + extra_leng for i in sent_infos]
    mTk = np.max(leng_st)

    # NOTICE: annoE sequences never have start and end.
    add_start_end = bool(useStartEnd) and field != 'annoE'
    if add_start_end:
        start_idx = MISC['specialtokens2idx']['</start>']
        end_idx   = MISC['specialtokens2idx']['</end>']

    info_rows = []
    leng_tk_rows = []
    for idx, sent_info in enumerate(sent_infos):
        grains, grain_counts = sent_info[0], sent_info[2]
        tk_pad = mTk - leng_st[idx]
        gr_pad = mGr - mGr_sents[idx]

        if add_start_end:
            # Start/end rows hold one grain, padded to this sentence's width.
            filler = [0] * (mGr_sents[idx] - 1)
            grains = [[start_idx] + filler] + grains + [[end_idx] + filler]
            grain_counts = [1] + grain_counts + [1]

        info_rows.append(np.pad(grains,
                                ((0, tk_pad), (0, gr_pad)),
                                'constant',
                                constant_values=0))
        leng_tk_rows.append(np.pad(grain_counts,
                                   (0, tk_pad),
                                   'constant',
                                   constant_values=0))

    info = np.array(info_rows)
    leng_tk = np.array(leng_tk_rows)
    return info, leng_tk, leng_st
##################################################

def get_category_4_batchsents(batch, field, MISC, **para):
    """Return the category id of every sentence in the batch.

    The category of a sentence is the base name of its folder: the last
    '/'-separated component of ``sent.Folder.name`` up to its first '.'.
    It is mapped to an id through ``para['GU'][1]``.

    Args:
        batch: sentences (normal or corrupted) exposing ``Folder.name``.
        field: unused; kept for signature parity with get_field_4_batchsents.
        MISC: unused.
        **para: target-field settings; ``para['GU'][1]`` maps category -> id.

    Returns:
        tuple: ``(tags, None, None)`` where ``tags`` is a NumPy array.
    """
    category2idx = para['GU'][1]

    category_ids = []
    for sent in batch:
        base_name = sent.Folder.name.rpartition('/')[2]
        category = base_name.partition('.')[0]
        category_ids.append(category2idx[category])

    tags = np.array(category_ids)
    return tags, None, None
    
############################################################################ get batch information for all fields
def get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device = None):
    """Turn a batch of sentences into the tensors of every input field and the target.

    Steps: sort the batch with ``orderSeq`` (batches of more than one
    sentence), optionally corrupt it for masked-LM training, build the
    padded tensors of every input field and finally those of the target
    field.

    Changes from the previous implementation: an unsupported target field
    or an empty ``INPUT_FIELDS`` now raises ``ValueError`` up front instead
    of failing later with ``NameError``; the redundant re-computation of
    ``leng_st`` for single-sentence batches and an unused dict were dropped.

    Args:
        batch: non-empty sequence of sentences.
        INPUT_FIELDS: dict field -> settings, as built by get_input_fields_info.
        TARGET_FIELD: list whose first two items are the target field name
            ('token', 'annoE' or 'category') and its settings dict.
        MISC: dict with 'useStartEnd', 'useMask', 'useTokenType', plus
            'maskProportion' when 'useMask' is set and whatever the helper
            functions read.
        device: device the tensors built here are moved to.

    Returns:
        tuple: ``(info_dict, tags, leng_st, misc_info)`` where
        ``info_dict[fld]`` is ``[info, leng_tk, info == 0]``.

    Raises:
        ValueError: if the target field is unsupported or ``INPUT_FIELDS``
            is empty.
    """
    # Hyper-Parameters
    useStartEnd  = MISC['useStartEnd']
    useMask      = MISC['useMask']
    useTokenType = MISC['useTokenType']

    target_name = TARGET_FIELD[0]
    if target_name not in ('token', 'annoE', 'category'):
        raise ValueError("unsupported target field: %r (expected 'token', 'annoE' or 'category')" % (target_name,))
    if not INPUT_FIELDS:
        # leng_st_mask below is derived from the last input field.
        raise ValueError('INPUT_FIELDS must contain at least one field')

    Hyper_Fields = batch[0].CONTEXT_DEP_CHANNELS

    # Lengths of Sentence
    leng_st = torch.LongTensor([st.length for st in batch]).to(device)
    if len(batch) > 1:
        batch, leng_st, _ = orderSeq(np.array(batch), leng_st)

    # If useStartEnd
    if useStartEnd:
        leng_st = leng_st + 2

    # If useMask
    misc_info = {}
    if useMask:
        maskProportion = MISC['maskProportion']
        # NOTE(review): `device` is not forwarded here, so the masked
        # index tensors stay on the default device -- confirm this is
        # what the consumers of misc_info expect.
        # NOTE(review): with useStartEnd the lengths used for sampling
        # already include the two extra positions -- confirm that
        # corrupt_batchsents can handle such positions.
        batch_masked_tokens_idxes, batch_masked_tokens_probs = generate_maskedtokens_idxes_prob(leng_st, proportion = maskProportion)
        cpt_batch, batch_target_masked_tokens_idxes = corrupt_batchsents(batch, batch_masked_tokens_idxes, batch_masked_tokens_probs, INPUT_FIELDS, **MISC)
        misc_info = {'batch_masked_tokens_idxes': batch_masked_tokens_idxes,
                     'batch_target_masked_tokens_idxes': batch_target_masked_tokens_idxes}

    # INPUT FIELDS
    # info_dict: {fld: [info, leng_tk, leng_tk_mask]}
    info_dict = {}
    last_leng_tk = None
    for fld, para in INPUT_FIELDS.items():
        # Hyper fields are never read from the corrupted sentences.
        source = cpt_batch if useMask and fld not in Hyper_Fields else batch
        info, leng_tk, _ = get_field_4_batchsents(source, fld, MISC, **para)
        info = torch.LongTensor(info).to(device)
        leng_tk = torch.LongTensor(leng_tk).to(device)
        info_dict[fld] = [info, leng_tk, info == 0]
        last_leng_tk = leng_tk

    # TARGET FIELD
    if target_name == 'category':
        tags, _, _ = get_category_4_batchsents(batch, target_name, MISC, **TARGET_FIELD[1])
    else:
        tags, _, _ = get_field_4_batchsents(batch, target_name, MISC, **TARGET_FIELD[1])

    tags = torch.LongTensor(tags).to(device).squeeze(-1)
    # Positions without any grain in the last input field are padding.
    misc_info['leng_st_mask'] = last_leng_tk == 0

    if useTokenType: pass # TODO: to update this information in the future

    return info_dict, tags, leng_st, misc_info
##################################################

####################################################################################################
def _fldseq_sublayer_tag(sublayer, strict):
    """Return the short name of one template sublayer, e.g. 'TEat_lstm-bi-2'.

    ``sublayer`` is ``[TYPE, Meaning, [NNName, NNPara]]``. With ``strict``
    set, a malformed ``[NNName, NNPara]`` entry raises; otherwise the
    name falls back to the type/structure prefix alone.
    """
    # (+) type is in ['Matrix_Extractor', 'Tensor_Reducer', 'Tensor_Extractor']
    TYPE = sublayer[0]
    if TYPE == 'Matrix_Extractor':
        t = 'ME'
    elif TYPE == 'Tensor_Reducer':
        t = 'TR'
    elif TYPE == 'Tensor_Extractor':
        t = 'TE'
    else:
        raise ValueError('must be TE, TR, ME')

    Reshape_Restore = sublayer[1]['Reshape_Restore']
    if Reshape_Restore == 'GrainVec_SeqAs_Token':
        struct_type = 'at_'
    elif Reshape_Restore == 'GrainVec_SeqAs_Sent':
        struct_type = 'as_'
    else:
        struct_type = '_'
    prefix = t + struct_type

    try:
        NNName, NNPara = sublayer[2]
        layer_type = NNPara['type'].lower()
    except (IndexError, KeyError, TypeError, ValueError, AttributeError):
        if strict:
            raise
        return prefix

    if layer_type in ['mean', 'sum', 'max']:
        return prefix + layer_type

    # Transformers name their depth differently from the other layers.
    depth_key = 'num_encoder_layers' if layer_type in ['tfm'] else 'n_layers'
    try:
        direction = NNPara['direction_type'].lower()
        depth = str(NNPara[depth_key])
    except (KeyError, TypeError, AttributeError):
        # Layers without direction/depth (e.g. a plain linear layer) are
        # named by their type only; show the layer to help diagnose it.
        pprint(sublayer)
        return prefix + layer_type
    return prefix + layer_type + '-' + direction + '-' + depth


def generate_fldseq_para(embed_layer_para, 
                         Indep_Template, 
                         Interdep_Template,
                         Output_Size= 200):
    """Assemble the field-sequence layer settings and a directory name for them.

    Fields named 'token' or 'pos' get a reduced independent template: a
    plain mean ``Tensor_Reducer`` followed by whatever comes after the first
    ``Tensor_Reducer`` of ``Indep_Template``. All other fields use
    ``Indep_Template`` unchanged.

    Changes from the previous implementation: a template without a
    ``Tensor_Reducer`` raises ``ValueError`` instead of ``NameError``; a
    layer whose name cannot be fully derived no longer re-uses the name of
    an earlier layer (it is named by its prefix, plus its type when known);
    the unused expander name is no longer computed.

    Args:
        embed_layer_para: dict field -> embedding settings.
        Indep_Template: list of sublayers ``[TYPE, Meaning, [NNName, NNPara]]``.
        Interdep_Template: list of sublayers of the same shape.
        Output_Size: size used by the one-grain mean reducer and appended to
            the interdependent name.

    Returns:
        tuple: ``(FieldSequence_Para, FldSeq_Dir)``.

    Raises:
        ValueError: if ``Indep_Template`` has no 'Tensor_Reducer' sublayer
            or a sublayer TYPE is not one of the three known types.
    """
    # (+) DON'T CHANGE THIS PART =======================================
    # This is equal to using embedding directly.
    # No post processing
    # this is for field token, pos
    started_idx = None
    for idx, Sublayer in enumerate(Indep_Template):
        if Sublayer[0] == 'Tensor_Reducer':
            started_idx = idx + 1
            break
    if started_idx is None:
        raise ValueError("Indep_Template must contain a 'Tensor_Reducer' sublayer")

    OneGrain_Tensor_Extractor = [
        # STRUCTURE:
            'Tensor_Reducer',
        # Meanings:
            {
                'InputMeaning':    'GrainVec_SeqAS_Token_SeqAS_Sent',
                'OutputMeaning':   'TokenVec_SeqAS_Sent',
                'Reshape_Restore':  None,
            },
        # NNName_NNPara
            [
                # NN Name
                'Mean',
                # NN Para
                {'type': 'mean',
                 'input_size':  Output_Size,
                 'output_size': Output_Size,
                 'postprecess' :{ # don't use dropout here.
                    }
                }
            ]
    ]

    Indep_Template_One_Grain = [OneGrain_Tensor_Extractor] + Indep_Template[started_idx: ]

    # (1) Indep_Layer_Template
    # NOTE(review): the module-level one_grain_fields also lists 'medpos';
    # confirm whether it should be treated as one-grain here as well.
    one_grain_flds = ['token', 'pos']
    Indep_Layer_Template = [
        [fld, Indep_Template_One_Grain if fld in one_grain_flds else Indep_Template]
        for fld in embed_layer_para
    ]

    # (2) Interdep_Layer_Template Configuration
    Interdep_Layer_Template = Interdep_Template

    # (3) Total Field Sequence Configuration
    FieldSequence_Para = {
        'expander_layer_para': embed_layer_para,
        'indep_layer_para': Indep_Layer_Template,
        'interdep_layer_para':Interdep_Layer_Template,
    }

    # Directory name: one tag per sublayer of each template.
    indep_tags = [_fldseq_sublayer_tag(sublayer, strict = True) for sublayer in Indep_Template]
    indep_name = 'INDEP' + '.' + '.'.join(indep_tags)

    interdep_tags = [_fldseq_sublayer_tag(sublayer, strict = False) for sublayer in Interdep_Template]
    interdep_name = 'INTER' + '.' + '.'.join(interdep_tags) + '.' + str(Output_Size)

    FldSeq_Dir = '-'.join([indep_name,  interdep_name])
    return FieldSequence_Para, FldSeq_Dir
    
##################################################