import sys
import os
import logging
import argparse
import torch
import json
import shutil

from pprint import pprint
from datetime import datetime
from nlptext.base import BasicObject
from nlptext.folder import Folder
from fieldembed import FieldEmbedding
from fieldlm.utils import batch as batch_tools


# option 1
# (-) don't need to modify
from script_train.embed_config import FldEmbed
from script_train.seqrepr_config import Config
from script_train.clfier_config import Classifier_Para
from script_train.model_config import Model, train



# (+) need to modify: Data_Dir, min_token_freq, input_fields, target_field, tagScheme, MISC
from script_train.data_config import Data_Dir, min_token_freq, input_fields, target_field, tagScheme, MISC, FldEmbed_Dir, SIZE

# (+) need to modify: use_residual_structure, SeqRepr_Config_Name
from script_train.data_config import use_residual_structure, SeqRepr_Config_Name, model_type

# (+) need to modify: TRAIN, device
from script_train.train_config import TRAIN, device


def _print_section(title):
    """Print a banner that introduces one group of configuration values.

    Args:
        title: Section name shown between the two rule segments.
    """
    # Both rule segments use '='; the previous banners mixed '=' on the left
    # with '+' on the right, which looked like a typo.
    print('=' * 20 + ' ' + title + ' ' + '=' * 20)


def _load_corpus():
    """Print the run header and load the pickled corpus.

    Returns:
        The ``BasicObject`` class itself, after ``INIT_FROM_PICKLE`` has
        loaded ``Data_Dir``; it is passed around as the corpus handle.
    """
    banner = '+' * 60
    print('\n' + banner)
    print('[   MODEL  ] The model type for this task  : ' + model_type)
    print('[INPUT DATA] The corpus used for this task : ' + Data_Dir)
    print('[INPUT DATA] The minimal token frequency   : ' + str(min_token_freq))
    print(banner + '\n')
    BasicObject.INIT_FROM_PICKLE(Data_Dir, min_token_freq=min_token_freq)
    return BasicObject


def _split_sentences(nlptext):
    """Randomly assign the corpus sentences to train / valid / test splits.

    Args:
        nlptext: Corpus handle returned by ``_load_corpus``.

    Returns:
        ``[train_sent_idx, valid_sent_idx, test_sent_idx]`` exactly as
        produced by ``batch_tools.get_train_valid_test``.
    """
    train_prop = TRAIN['train_prop']
    train_sent_idx, valid_sent_idx, test_sent_idx = batch_tools.get_train_valid_test(
        nlptext.SENT['length'], train_prop=train_prop, seed=TRAIN['random_seed'])
    print('[SPLIT TRAIN VALID TEST]: load train valid test from RANDOM ASSIGNMENT of ' + str(train_prop))
    print('[SPLIT TRAIN VALID TEST] The num of train sentences:', len(train_sent_idx))
    print('[SPLIT TRAIN VALID TEST] The num of valid sentences:', len(valid_sent_idx))
    print('[SPLIT TRAIN VALID TEST] The num of test  sentences:', len(test_sent_idx))
    print('\n')
    return [train_sent_idx, valid_sent_idx, test_sent_idx]


def _log_input_config():
    """Print the pretrained-embedding, field and ``MISC`` settings of this run."""
    if FldEmbed_Dir:
        print('[INPUT DATA] The pretrain field embedding path is: \n\t', FldEmbed_Dir)
    else:
        print('[INPUT DATA] The pretrain field embedding is NOT USED!')
    print('\n')

    _print_section('Input Fields and Output Field')
    print('[INPUT DATA] The "input fields" in this task : ', input_fields)
    print('[INPUT DATA] The "target field" in this task : ', target_field)
    print('[INPUT DATA] The tagScheme for "target field": ', tagScheme)
    print('\n')

    _print_section('Special Tokens')
    print('[INPUT DATA] All special tokens: ', MISC['idx2specialtokens'])
    print('\n')

    _print_section('Start and End Tokens')
    print('[INPUT DATA] use </start> and </end> special tokens:', MISC['useStartEnd'])
    print('[INPUT DATA] use </start> special token only:       ', MISC['useStartOnly'])
    print('\n')

    _print_section('Mask Information in Input Data')
    print('[INPUT DATA] Use Mask in Input Data or not    :     ', MISC['useMask'])
    print('[INPUT DATA] The maskProportion for Input Data:     ', MISC['maskProportion'])
    print('\n')

    _print_section('Max Sentence Length')
    print('[INPUT DATA] The maximum sentence length:           ', MISC['maxSentLeng'])
    print('\n')

    _print_section('Position Embeddings')
    print('[INPUT DATA] Use Token Position Embeddings:         ', MISC['useTkPsn'])
    print('[INPUT DATA] Use Grain Position Embeddings:         ', MISC['useGrPsn'])
    print('\n')


def _build_model(nlptext):
    """Build field metadata, the sequence-representation config and the model.

    Side effects: writes ``output_dim``/``n_class`` into the shared
    ``Classifier_Para`` dict and ``use_residual_structure`` into ``MISC``
    (both imported from the config modules).

    Args:
        nlptext: Corpus handle returned by ``_load_corpus``.

    Returns:
        ``(model, input_fields_info, target_field_info, field_dir)``, where
        ``model`` has already been moved to ``device`` and ``field_dir`` is
        the third value returned by ``batch_tools.get_Input_Target_Field``.
    """
    in_fields, out_field, field_dir = batch_tools.get_Input_Target_Field(
        input_fields, target_field, tagScheme)
    pprint(in_fields)

    # Pretrained field embeddings are built with the SIZE from data_config.
    fldembed = FldEmbed(FldEmbed_Dir, SIZE) if FldEmbed_Dir else None

    input_fields_info, embed_fields = batch_tools.get_input_fields_info(
        nlptext, in_fields, fldembed, **MISC)
    target_field_info = batch_tools.get_target_field_info(
        nlptext, out_field, fldembed, **MISC)

    # The seq-repr config carries its own size. It gets its own name so it no
    # longer shadows the data_config SIZE used for the embeddings above.
    indep_template_prl_grain, interdep_template, repr_size = Config[SeqRepr_Config_Name]
    fldseq_para, _ = batch_tools.generate_fldseq_para(embed_fields,
                                                      indep_template_prl_grain,
                                                      interdep_template,
                                                      repr_size)

    Classifier_Para['output_dim'] = repr_size
    # The last element of the target-field info is used as the class count.
    Classifier_Para['n_class'] = target_field_info[-1]

    # MISC is extended only AFTER being splatted into the two field-info
    # helpers above; presumably they do not accept this keyword, so keep
    # this ordering (verify before moving it earlier).
    MISC['use_residual_structure'] = use_residual_structure
    model = Model(fldseq_para, Classifier_Para, MISC).to(device)
    return model, input_fields_info, target_field_info, field_dir


def _model_save_path(field_dir, n_class):
    """Return the directory the trained model is saved to.

    Args:
        field_dir: Field-dependent sub-directory from ``get_Input_Target_Field``.
        n_class: Number of target classes (goes into the classifier dir name).
    """
    cls_dir = model_type.upper() + '.' + str(n_class) + '.RS' + str(use_residual_structure)
    # NOTE(review): str.replace rewrites EVERY 'data' substring of Data_Dir
    # (e.g. a folder named 'dataset' too), not just a leading 'data' path
    # component. Kept unchanged so existing model paths stay stable.
    model_dir = Data_Dir.replace('data', 'model')
    return os.path.join(model_dir, field_dir, SeqRepr_Config_Name, cls_dir)


def main():
    """Load the corpus, build the configured model and train it."""
    # Previously printed at import time; now only printed when run as a script.
    print(SeqRepr_Config_Name)

    nlptext = _load_corpus()
    data = _split_sentences(nlptext)
    _log_input_config()

    model, input_fields_info, target_field_info, field_dir = _build_model(nlptext)

    model_path = _model_save_path(field_dir, Classifier_Para['n_class'])
    TRAIN['path_to_save_current_model'] = model_path
    print('[ASSOCIATED PATH] save the model to:')
    print(model_path)

    print(model)
    train(model, data, input_fields_info, target_field_info, TRAIN, MISC, device=device)


if __name__ == '__main__':
    main()
