import pandas as pd
import os
import time
import sys
import torch
import numpy as np
import pickle
from tabulate import tabulate
from six.moves import cPickle
from nlptext.sentence import Sentence

from .utils import evals as eval_tools
from .utils import batch as batch_tools
from .utils.train import add_summary_value, lr_decay, build_optimizer, lr_decay_scale

from .module.seqrepr import SeqRepr

from .nn import softmax as sfm
from .nn import crf
from .nn.linear import LinearLayer
from .nn.helper import orderSeq

# import tensorflow as tf

# TensorFlow is an optional dependency: when the import fails, `tf` is set to
# None and a notice is printed. Code further down this file checks `tf` before
# writing TensorBoard summaries.
try:
    import tensorflow as tf
except ImportError:
    print("Tensorflow not installed; No tensorboard logging.")
    tf = None

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% MODEL

############################################# CLFIER
class SeqLabelCRF(torch.nn.Module):
    """Classification head: a linear projection to tag scores followed by a CRF.

    Parameters:
        Classifier_Para (dict): must provide 'output_dim' (input size of the
            projection) and 'n_class' (number of tags).
        MISC (dict): must provide 'useStartEnd'. The flag is stored on the
            module, but no start/end-specific handling is applied here.
    """

    def __init__(self, Classifier_Para, MISC):
        super(SeqLabelCRF, self).__init__()
        # Stored for reference only; the start/end stripping logic is disabled.
        self.useStartEnd = MISC['useStartEnd']

        self.output_dim = Classifier_Para['output_dim']
        self.n_class = Classifier_Para['n_class']

        # Projection from the sequence representation to per-tag scores.
        self.hidden2tag = torch.nn.Linear(self.output_dim, self.n_class)

        # CRF layer working on the projected scores.
        self.classifier = crf.CRF(self.n_class)

    def _masked_emission(self, seqrepr, misc_info):
        """Project `seqrepr` to tag scores and zero the masked positions.

        Returns the emission tensor together with `misc_info['leng_st_mask']`.
        Positions where the mask is set are overwritten with 0 in place.
        NOTE(review): the mask presumably flags padding positions -- confirm
        against the code that builds `misc_info`.
        """
        scores = self.hidden2tag(seqrepr)
        pad_mask = misc_info['leng_st_mask']
        scores.masked_fill_(pad_mask.unsqueeze(-1), value=0)
        return scores, pad_mask

    def loss_function(self, seqrepr, targets, leng_st, misc_info):
        """Return the CRF loss of `targets` given `seqrepr`.

        `leng_st` is accepted for interface symmetry but is not used; the
        mask comes from `misc_info['leng_st_mask']`. The CRF receives the
        complement of that mask (`mask == 0`).
        """
        scores, pad_mask = self._masked_emission(seqrepr, misc_info)
        return self.classifier.loss_function(scores, targets, pad_mask == 0)

    def decode(self, seqrepr, leng_st, misc_info):
        """Return the tags decoded by the CRF for `seqrepr`.

        `leng_st` is not used; see `loss_function` for how the mask is applied.
        """
        scores, pad_mask = self._masked_emission(seqrepr, misc_info)
        return self.classifier.decode(scores, pad_mask == 0)
    

############################################# WHOLE MODEL
class SeqLabel(torch.nn.Module):
    """Whole sequence-labeling model: a `SeqRepr` encoder plus a `SeqLabelCRF` head.

    Parameters:
        SeqRepr_Para (dict): keyword arguments forwarded to `SeqRepr`.
        Classifier_Para (dict): settings forwarded to `SeqLabelCRF`.
        MISC (dict): must provide 'use_residual_structure' (passed to
            `SeqRepr`) and whatever `SeqLabelCRF` reads from it.
    """

    def __init__(self, SeqRepr_Para, Classifier_Para, MISC):
        super(SeqLabel, self).__init__()
        # seqrepr layer
        self.use_residual_structure = MISC['use_residual_structure']
        self.seqrepr = SeqRepr(**SeqRepr_Para, use_residual_structure=self.use_residual_structure)
        # classifier layer
        self.clfier = SeqLabelCRF(Classifier_Para, MISC)

    @staticmethod
    def _resolve_map_location(map_location):
        """Translate the legacy 'gpu' alias into a location `torch.load` accepts.

        'gpu' is not a valid torch device string, so it is mapped to 'cuda'
        when CUDA is available and to 'cpu' otherwise. Any other value is
        returned untouched.
        """
        if isinstance(map_location, str) and map_location == 'gpu':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return map_location

    @staticmethod
    def _weight_path(root_path, part, name):
        """Return `<root_path>/<part>[_<name>].pth`."""
        suffix = '_' + name if name != '' else ''
        return os.path.join(root_path, part + suffix + '.pth')

    def forward(self, info_dict, leng_st, misc_info):
        """Return the sequence representation computed by the encoder."""
        seqrepr = self.seqrepr(info_dict, leng_st, misc_info)
        return seqrepr

    def loss_function(self, seqrepr, targets, leng_st, misc_info):
        """Return the classifier loss for an already computed `seqrepr`."""
        loss = self.clfier.loss_function(seqrepr, targets, leng_st, misc_info)
        return loss

    def decode(self, seqrepr, leng_st, misc_info):
        """Return the tags predicted by the classifier for `seqrepr`."""
        pred_tags = self.clfier.decode(seqrepr, leng_st, misc_info)
        return pred_tags

    def load_pretrain_seqrepr(self, root_path, name = '', map_location = 'gpu'):
        """Load only the encoder weights from `<root_path>/seqrepr[_<name>].pth`.

        `map_location` is forwarded to `torch.load` (previously it was
        silently ignored); the legacy value 'gpu' is resolved to an
        available device.
        """
        path = self._weight_path(root_path, 'seqrepr', name)
        state = torch.load(path, map_location=self._resolve_map_location(map_location))
        self.seqrepr.load_state_dict(state)
        print("[LOAD PRE-TRAINED SEQREQR] seqrepr loaded from {}".format(path))

    def load_model(self, root_path, name = '', map_location = 'gpu'):
        """Load encoder and classifier weights written by `save_model`.

        Reads `<root_path>/seqrepr[_<name>].pth` and
        `<root_path>/clfier[_<name>].pth`. The legacy `map_location` value
        'gpu' is resolved to an available device before calling `torch.load`.
        """
        location = self._resolve_map_location(map_location)

        path = self._weight_path(root_path, 'seqrepr', name)
        self.seqrepr.load_state_dict(torch.load(path, map_location=location))
        print("[LOAD PRE-TRAINED SEQREQR]    seqrepr model    loaded from {}".format(path))

        path = self._weight_path(root_path, 'clfier', name)
        self.clfier.load_state_dict(torch.load(path, map_location=location))
        print("[LOAD PRE-TRAINED CLFIER]     clfier  model    loaded from {}".format(path))

    def save_model(self, root_path, name = ''):
        """Save encoder and classifier state dicts under `root_path`.

        Writes `seqrepr[_<name>].pth` and `clfier[_<name>].pth`.
        """
        path = self._weight_path(root_path, 'seqrepr', name)
        torch.save(self.seqrepr.state_dict(), path)
        print("[SAVE PRE-TRAINED SEQREQR]    seqrepr model    saved to {}".format(path))

        path = self._weight_path(root_path, 'clfier', name)
        torch.save(self.clfier.state_dict(), path)
        print("[SAVE PRE-TRAINED CLFIER]     clfier  model    saved to {}".format(path))


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% TRAIN
# what should we pay attention to when training models?
# (1) input and target
# (2) model construction and randomization
# (3) the proportion of trainable parameters.
# (4) optimization methods
# (5) the learning rate evolution.
# (6) when to terminate the training phase.

############################################# model eval
def seqlabel_eval(model, sents, INPUT_FIELDS, TARGET_FIELD, TRAIN, MISC, device = None):
    # (+) MISC 
    useStartEnd = MISC['useStartEnd']; start_end_length = 2 if useStartEnd else 0
    maxSentLeng = MISC['maxSentLeng']


    # (+) Dictionary
    labels = TARGET_FIELD[-2]
    idx2tag= TARGET_FIELD[1]['GU'][0]
    
    # (+) Check Point 
    checkpoint_size_of_bp = TRAIN['checkpoint_size_of_bp_in_eval'] 
    

    # Model Point
    loss_in_mp = 0
    data_point_num_in_ep = len(sents)
    
    # Check Point
    overall_results = 0

    # Batch Point 
    batchpoint_size_of_dp = TRAIN['batchpoint_size_of_dp']
    batchpoint_num_in_ep  = int(data_point_num_in_ep/batchpoint_size_of_dp) + 1 # train batchunit number


    loss_in_cp = 0
    data_point_num_in_cp = 0
    batchpoint_num_in_cp = 0

    loss_in_ep = 0
    batchpoint_num_in_cp = 0

    # Token Point
    tokenpoint_right_num_in_cp = 0
    tokenpoint_total_num_in_cp = 0

    tokenpoint_right_num_in_ep = 0
    tokenpoint_total_num_in_ep = 0


    # time records
    start_time = time.time()
    temp_start = start_time


    log_list = []
    
    # set model in eval model, and turn off drop
    model.eval()

    
    
    with torch.no_grad():
        # prepare information about sents which are used for evaluation.

        for batchpoint_idx_in_ep in range(batchpoint_num_in_ep):
            
            # NOW WE ARE IN A BATCH of the `epochpoint_idx_in_tp` th epoch

            # (+) batch's datapoint start and end idx.
            # get batch_idx, each element is a sentence local index
            data_point_startidx_in_bs = batchpoint_idx_in_ep * batchpoint_size_of_dp
            tmp   = data_point_startidx_in_bs + batchpoint_size_of_dp
            data_point_endidx_in_bs   = tmp if tmp <= data_point_num_in_ep else data_point_num_in_ep
            # print(data_point_startidx_in_bs, data_point_endidx_in_bs)
            # (+) batch input 
            # convert batch_idx to batch of nlptext.sentence.Sentence
            batch = sents[data_point_startidx_in_bs: data_point_endidx_in_bs]
            if type(batch) == int: 
                batch = [batch]
            batch = [Sentence(i) for i in batch]

            # (+) check batch's validation
            ##########################################################
            batch = [i for i in batch if i.length + start_end_length <= maxSentLeng]
            # batch = [i for i in batch if i.length + start_end_length <= maxSentLeng and i.length >= 3]
            datapoint_num_in_bp = len(batch)
            if datapoint_num_in_bp == 0: 
                print('zero batch')
                continue
            ##########################################################

            # (+) prepare the input data and target data
            leng_st = torch.LongTensor([st.length for st in batch]).to(device)
            if len(batch) > 1: 
                batch, leng_st, reverse_index = orderSeq(np.array(batch), leng_st)

            # (+) load input information
            # info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            try:
                info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            except:
                print(batch)
                # TODO: also report the error information
                continue

            # (+) seqrepr
            seqrepr = model(info_dict, leng_st, misc_info)
            # (+) loss_in_bp
            loss_in_bp = model.loss_function(seqrepr, targets, leng_st, misc_info)
            # (+) preds
            preds = model.decode(seqrepr, leng_st, misc_info)
            # print(targets.shape)
            # print(preds.shape)

            loss_in_cp  += loss_in_bp.item()
            loss_in_ep  += loss_in_bp.item()

            data_point_num_in_cp += datapoint_num_in_bp
            batchpoint_num_in_cp += 1

            # (+) Accuracy Rate
            right, whole = eval_tools.predict_check(targets, preds, leng_st - start_end_length)
            tokenpoint_right_num_in_cp += right
            tokenpoint_total_num_in_cp += whole

            tokenpoint_right_num_in_ep += right
            tokenpoint_total_num_in_ep += whole

            # (+) Extracted Entities
            # f1-score in this mini batch
            anno_entities = eval_tools.extract_entities(targets, leng_st, idx2tag)
            pred_entities = eval_tools.extract_entities(preds, leng_st, idx2tag) 
            

            result = eval_tools.match_anno_pred_result(anno_entities, pred_entities, labels)
            overall_results = overall_results + result.sum()
            
            assert len(anno_entities) == len(pred_entities) and len(anno_entities) == len(batch) 

            for idx, sent in enumerate(batch):
                # sent = batch[idx]
                pred_SET = pred_entities[idx]
                anno_SET = anno_entities[idx]
                error = eval_tools.logErrors(sent, pred_SET, anno_SET)
                log_list.append(error)


            # (Enter Check Point)
            if (batchpoint_idx_in_ep + 1) % checkpoint_size_of_bp == 0 or (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:
                # start logging the loss only.
                temp_time  = time.time()
                temp_cost  = temp_time - temp_start
                temp_start = temp_time

                speed_dp_per_sec_in_cp = round(data_point_num_in_cp/temp_cost, 3)
                data_point_num_in_cp = 0
                
                loss_per_bp_in_cp  = round(loss_in_cp/batchpoint_num_in_cp, 3)
                loss_in_cp = 0
                batchpoint_num_in_cp = 0

                acc_in_cp = round(tokenpoint_right_num_in_cp / tokenpoint_total_num_in_cp, 3)
                tokenpoint_right_num_in_cp = 0
                tokenpoint_total_num_in_cp = 0

                # (+) log information for this batchpoint
                print('\t'.join(['[MODEL POINT: eval]', 
                                 'BatchPoint: {}/{}'.format(batchpoint_idx_in_ep+1, batchpoint_num_in_ep),
                                 'DataPoint: {}/{};'.format((batchpoint_idx_in_ep) * batchpoint_size_of_dp + len(batch), data_point_num_in_ep),
                                 'Speed: {:.3f} dp/s;'.format(speed_dp_per_sec_in_cp),
                                 'Accuary: {};'.format(acc_in_cp),
                                 'Loss: {:.3f} /bp'.format(loss_per_bp_in_cp)]))

                sys.stdout.flush()

            # (Enter Last Batch Point, i.e., Model Point, Epoch Point)
            if (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:
                loss_per_bp_in_ep = round(loss_in_ep/batchpoint_num_in_ep, 3)
                # loss_in_ep = 0
                # batchpoint_num_in_ep = 0

                acc_in_ep = round(tokenpoint_right_num_in_ep / tokenpoint_total_num_in_ep, 3)
                # tokenpoint_right_num_in_ep = 0
                # tokenpoint_total_num_in_ep = 0

                # write these in train function
                f1_all = eval_tools.calculate_F1_Score(overall_results, labels).fillna(0)
                f1 = f1_all.loc['E']['F1']
                # print(preds)
                # print(pred_entities)


                LogError = pd.concat(log_list).reset_index(drop = True)

            # -------- jump out of current batch point 
        # -------- jump out of the eval epoch
    # -------- jump out of torch.no_grad()
    

    return loss_per_bp_in_ep, acc_in_ep, f1_all, f1, LogError


############################################# model train
def seqlabel_train(model, data, INPUT_FIELDS, TARGET_FIELD, TRAIN, MISC, device=None):
    """Train the Sequence Labeling Tasks with this function.

    Parameters:
        model (SeqLabel): the initialized model which contains a seqrepr and clfier.
                          the trainable parameters in them may be randomly initialized or pretrained.
        data (nlptext based sentence idx): There must be a NLPText BasicObject outside to make sure 
                                           that these sentences can be initialized.
        INPUT_FIELDS (dictionary): a dict whose keys are selected fields.
                                   for each field, its keys contain all the field information of the selected corpus.
        TARGET_FIELD (list): the first element is field annoE (annotation entity), the rest elements are annoE's information.
        TRAIN (dictionary): the training settings used to train the model. 
                            the important ones are optimizier and learning rate.
        MISC (dictioanry): additional information. The ones used inside this function are useStartEnd, maxSentLeng.
                           MISC is also utilized in function `get_fieldsinfo_4_batchsents` to get field information.
        device (string):  cpu or cuda.
    """
    # (+) Random Seed Setting
    seed = TRAIN['random_seed']
    torch.manual_seed(seed); np.random.seed(seed)
    
    # (+) MISC 
    useStartEnd = MISC['useStartEnd']; start_end_length = 2 if useStartEnd else 0
    maxSentLeng = MISC['maxSentLeng']

    # (+) Path to record information of current model
    path_to_save_current_model = TRAIN['path_to_save_current_model']


    best_log_test_path  = os.path.join(path_to_save_current_model, 'best_log_test_path.csv')
    best_pfm_test_path  = os.path.join(path_to_save_current_model, 'best_pfm_test_path.csv')
    best_log_valid_path = os.path.join(path_to_save_current_model, 'best_log_valid_path.csv')
    best_pfm_valid_path = os.path.join(path_to_save_current_model, 'best_pfm_valid_path.csv')

    current_log_test_path  = os.path.join(path_to_save_current_model, 'current_log_test_path.csv')
    current_pfm_test_path  = os.path.join(path_to_save_current_model, 'current_pfm_test_path.csv')
    current_log_valid_path = os.path.join(path_to_save_current_model, 'current_log_valid_path.csv')
    current_pfm_valid_path = os.path.join(path_to_save_current_model, 'current_pfm_valid_path.csv')
   

    if not os.path.exists(path_to_save_current_model): 
        os.makedirs(path_to_save_current_model)
    print('[TRAINING :: INFO] Logging to {}'.format(path_to_save_current_model))

    # (+) Record Information in tf summary writer
    tf_summary_writer = tf and tf.summary.FileWriter(path_to_save_current_model)
    # print(tf_summary_writer)

    # (+) Ways to Initialize Model. Use Pretrained Seqrepr 
    # our model contains seqrepr and clfier. 
    # before training, two ways to initialize seqrepr.
    # one is randomly initializing.
    # another is using pretrained seqrepr.
    pretrained_seqrepr_path         = TRAIN['pretrained_seqrepr_path']
    continue_previous_trained_model = TRAIN['continue_previous_trained_model']

    if pretrained_seqrepr_path:
        # (+) model initial parameters
        model.load_pretrain_seqrepr(pretrained_seqrepr_path)

        # (+) which part to frozen
        for para in model.seqrepr.parameters():
            para.require_grad = False
        # (+) optimizer
        # lr_decay_num_per_epoch = TRAIN['lr_decay_num_per_epoch']
        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN)
        # current_lr = lr
        # current_lr = 0.001
        # clip_grad = TRAIN['clip_grad']
        # (+) history
        history = {}
        # (+) Get Current Point for Where to Start
        start_epochpoint_idx_in_tp = 0
        start_batchpoint_idx_in_ep = 0
        # (+) logging
        print('\n'+ '+'*60)
        print('[MODEL :: INITIALIZATION] load pretrained seqrepr from:', pretrained_seqrepr_path)
        print('+'*60, '\n')
    elif continue_previous_trained_model:
        # (+) model initial parameters
        model.load_model(path_to_save_current_model)
        # (+) optimizer
        optimizer.load_state_dict(torch.load(os.path.join(path_to_save_current_model, 'optimizer.pth')))
        # (+) history
        with open(os.path.join(path_to_save_current_model, 'history' + '.pkl'), 'rb') as f: 
            history = cPickle.load(f)
        # (+) Get Current Point for Where to Start
        # pay attention to handle this 
        start_epochpoint_idx_in_tp = history['history_epochpoint_idx_in_tp']
        start_batchpoint_idx_in_ep = history['history_batchpoint_idx_in_ep']
        # (+) logging
        print('\n'+ '+'*60)
        print('[MODEL :: INITIALIZATION] load whole model from:', path_to_save_current_model)
        print('+'*60, '\n')
    else:

        # (+) model initial parameters
        # model
        # (+) from part to Frozen
        # position embedding, is also important
        for fld, expander_layer in model.seqrepr.Expander_Layers.items():
            print(fld)
            if fld.lower() == 'medpos':
                print('pass medpos')
                continue
            for para in expander_layer.layers.grain.parameters():
                para.require_grad = False

        # (+) optimizer
        # lr_decay_num_per_epoch = TRAIN['lr_decay_num_per_epoch']
        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN)
        print('[MODEL :: Learning Rate] Current Learning Rate is:', current_lr)
        # (+) history
        history = {}
        # (+) Get Current Point for Where to Start
        start_epochpoint_idx_in_tp = 0
        start_batchpoint_idx_in_ep = 0
        # (+) logging
        print('\n'+ '+'*60)
        print('[MODEL :: INITIALIZATION] frozen field grain embeddings, and initialize all the parameters in indep, interdep, and clfier')
        print('+'*60, '\n')



    # (+) Data
    train_sent_idx, valid_sent_idx, test_sent_idx = data
    
    # (+) Logging and Saving for Train Data
    # dp, bp, cp, mp, ep, tp
    batchpoint_size_of_dp = TRAIN['batchpoint_size_of_dp']

    data_point_num_in_ep  = len(train_sent_idx) # train datapoint number
    batchpoint_num_in_ep  = int(data_point_num_in_ep/batchpoint_size_of_dp) + 1 # train batchunit number
    epochpoint_num_in_tp  = TRAIN['epochpoint_num_in_tp'] # or trainpoint_size_of_epc
    
    checkpoint_size_of_bp = TRAIN['checkpoint_size_of_bp'] 
    modelpoint_size_of_bp = 'Epoch End'
    assert modelpoint_size_of_bp == 'Epoch End'
    # for language model
    # modelpoint_size_of_bp = TRAIN['modelpoint_size_of_bp'] # we don't use this one.
    # assert modelpoint_size % checkpoint_size == 0 


    # at the beginning of the train procession
    best_record = None
    best_f1_all_val  = None
    best_f1_all_test = None
    best_f1_val  = None
    best_f1_test = None

    Fine_Tune_Flag_Embed = False
    Fine_Tune_Flag_INTER = False
    Fine_Tune_Flag_INDEP = False


    num_of_no_improvements_para = 0


    # history_batchpoint_idx
    step_num = 0
    last_valid_loss = 1000000
    for epochpoint_idx_in_tp in range(start_epochpoint_idx_in_tp, epochpoint_num_in_tp):
        # (+) start_batchpoint_idx_in_ep and train_sent_idx_for_ep
        if epochpoint_idx_in_tp == start_epochpoint_idx_in_tp and epochpoint_idx_in_tp != 0:
            start_batchpoint_idx_in_ep = start_batchpoint_idx_in_ep
            # TODO
            train_sent_idx_for_ep = pickle.load() 
        else:
            start_batchpoint_idx_in_ep = 0
            train_sent_idx_for_ep = np.random.shuffle(train_sent_idx)

        # np.random.shuffle(train_sent_idx)
        print('\n' +  '=='*30)
        print('[TRAINING :: EPOCH POINT] START AT EPOCH', epochpoint_idx_in_tp+1, '/', epochpoint_num_in_tp)

        # (+) time records
        epoch_start = time.time()
        temp_start  = epoch_start

        
        loss_in_cp = 0
        data_point_num_in_cp = 0
        batchpoint_num_in_cp = 0


        # Model Point
        loss_in_mp = 0 
        batchpoint_num_in_mp = 0


        # (+) TRAINING STAGE FOR AN EPOCH
        # 1 to last n-1 checkout points
        for batchpoint_idx_in_ep in range(start_batchpoint_idx_in_ep, batchpoint_num_in_ep):
            # NOW WE ARE IN A BATCH of the `epochpoint_idx_in_tp` th epoch

            # (+) batch's datapoint start and end idx.
            # get batch_idx, each element is a sentence local index
            data_point_startidx_in_bs = batchpoint_idx_in_ep * batchpoint_size_of_dp
            tmp   = data_point_startidx_in_bs + batchpoint_size_of_dp
            data_point_endidx_in_bs   = tmp if tmp <= data_point_num_in_ep else data_point_num_in_ep
            # print(data_point_startidx_in_bs, data_point_endidx_in_bs)
            # (+) batch input 
            # convert batch_idx to batch of nlptext.sentence.Sentence
            batch = train_sent_idx[data_point_startidx_in_bs: data_point_endidx_in_bs]
            if type(batch) == int: 
                batch = [batch]
            batch = [Sentence(i) for i in batch]

            # (+) check batch's validation
            ##########################################################
            batch = [i for i in batch if i.length + start_end_length <= maxSentLeng]
            # batch = [i for i in batch if i.length + start_end_length <= maxSentLeng and i.length >= 3]
            datapoint_num_in_bp = len(batch)
            if datapoint_num_in_bp == 0: 
                print('zero batch')
                continue
            ##########################################################

            # (+) prepare the input data and target data
            leng_st = torch.LongTensor([st.length for st in batch]).to(device)
            if len(batch) > 1: 
                batch, leng_st, reverse_index = orderSeq(np.array(batch), leng_st)



            # sent = batch[0]
            # print(sent.get_grain_str('token'))
            # print(sent.get_grain_str('medpos'))


            # (+) load input information
            info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            # try:
            #     info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            # except Exception as e:
            #     print('loss calculate')
            #     print(e, file = sys.stderr)
            #     print(batch)
            #     continue


            # (+) learning rate warming up
            if TRAIN['lr_warm_up'] and step_num <= TRAIN['lr_warm_up_steps']:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] + TRAIN['peak_lr'] / TRAIN['lr_warm_up_steps']
                step_num = step_num + 1

            # (+) model preparation
            model.train()
            model.zero_grad()


            # print('[fieldlm.seqlabel.seqlabel_train]//misc_info', misc_info)
            # (+) calculate seqrepr
            seqrepr = model(info_dict, leng_st, misc_info)

            # (+) calculate loss
            loss_in_bp = model.loss_function(seqrepr, targets, leng_st, misc_info)
            loss_in_bp.backward()

            # (+) optimizer update parameters
            # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad) # this is a problem.
            optimizer.step()


            
            # try:
            #     # (+) calculate seqrepr
            #     seqrepr = model(info_dict, leng_st, misc_info)

            #     # (+) calculate loss
            #     loss_in_bp = model.loss_function(seqrepr, targets, leng_st, misc_info)
            #     loss_in_bp.backward()

            #     # (+) optimizer update parameters
            #     # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad) # this is a problem.
            #     optimizer.step()

            # except Exception as e:
            #     print('loss calculate')
            #     print(e, file = sys.stderr)
            #     print(batch)
            #     continue


            # (+) record
            # bp in tp
            batchpoint_idx_in_tp = batchpoint_idx_in_ep + 1 + batchpoint_num_in_ep *  epochpoint_idx_in_tp

            # the observer for checkout_point
            loss_in_cp += loss_in_bp.item()
            data_point_num_in_cp  += datapoint_num_in_bp # datapoint_num in current checkpoint
            batchpoint_num_in_cp  += 1

            # the observer for checkout_point
            loss_in_mp += loss_in_bp.item()
            batchpoint_num_in_mp  += 1

            # CheckPoint (Enter a Checkpoint if Batchpoint is Special)
            if (batchpoint_idx_in_ep + 1) % checkpoint_size_of_bp == 0 or (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:
                # start logging the loss only.
                temp_time  = time.time()
                temp_cost  = temp_time - temp_start
                temp_start = temp_time

                speed_dp_per_sec_in_cp = round(data_point_num_in_cp/temp_cost, 3)
                data_point_num_in_cp = 0
                
                loss_per_bp_in_cp  = round(loss_in_cp/batchpoint_num_in_cp, 3)
                loss_in_cp = 0
                batchpoint_num_in_cp = 0


                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']

                # (+) log information for this batchpoint
                print('\t'.join(['[CHECK POINT]', 
                                 'Epoch: {}'.format(epochpoint_idx_in_tp+1),
                                 'BatchPoint: {}/{}'.format(batchpoint_idx_in_ep+1, batchpoint_num_in_ep),
                                 'DataPoint: {}/{};'.format((batchpoint_idx_in_ep) * batchpoint_size_of_dp + len(batch), data_point_num_in_ep),
                                 'Speed: {:.3f} dp/s;'.format(speed_dp_per_sec_in_cp),
                                 'Loss: {:.3f} /bp'.format(loss_per_bp_in_cp)]))

                # (+) check whether the loss explores or not.
                if loss_per_bp_in_cp > 1e8 or str(loss_per_bp_in_cp) == "nan":
                    raise ValueError('The Loss is Explored!')

                # (+) save to tf_summary_writer
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'learning_rate',  current_lr,  batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'loss_per_bp_in_cp',  loss_per_bp_in_cp,  batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'speed_dp_per_sec_in_cp', speed_dp_per_sec_in_cp, batchpoint_idx_in_tp)
                    tf_summary_writer.flush()

                sys.stdout.flush()
                # ------jump out of one checkpoint

            # ModelPoint (Enter Valid (and perhaps Test) and Model Saving if Checkpoint is Special)
            if (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:

                # (+) Save current model to model_path + 'current_model'
                print('\n[MODEL POINT :: Save Current Model] Save Current Model to:')
                print('\n' +  'pretrained_seqrepr_path is:', pretrained_seqrepr_path)
                print('TRAIN', TRAIN)
                model.save_model(path_to_save_current_model, name = 'current')
                optimizer_path = os.path.join(path_to_save_current_model, 'optimizer_current.pth')
                torch.save(optimizer.state_dict(), optimizer_path)


                # (+) Write current situation to history
                history['history_epochpoint_idx_in_tp'] = epochpoint_idx_in_tp
                history['history_batchpoint_idx_in_ep'] = batchpoint_idx_in_ep

                with open(os.path.join(path_to_save_current_model, 'history' + '.pkl'), 'wb') as f: 
                    cPickle.dump(history, f)


                # (+) Valid: Current Model in Validation
                print('\n[MODEL POINT :: Valid] EVAL validation data')
                loss_val, acc_val, f1_all_val, f1_val, LogError_val = seqlabel_eval(model, valid_sent_idx, INPUT_FIELDS, TARGET_FIELD, TRAIN,  MISC, device=device)
                print(tabulate(f1_all_val, headers=f1_all_val.columns, tablefmt='psql', numalign='left'))

                # Write the validation f1 score
                if tf is not None:
                    for entity_type in range(len(f1_all_val)):
                        add_summary_value(tf_summary_writer, 
                                          f1_all_val.index[entity_type] + '-' + f1_all_val.columns[-1] + '_valid', 
                                          f1_all_val.iloc [entity_type, -1], 
                                          batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'loss_val',  loss_val,  batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'f1_val',    f1_val,    batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'acc_val',   acc_val,   batchpoint_idx_in_tp)
                    tf_summary_writer.flush()


                # (+) Test: Current Model in Validation
                # in some case, it may be later.
                print('\n[MODEL POINT :: Test] EVAL test data')
                loss_test, acc_test, f1_all_test, f1_test, LogError_test = seqlabel_eval(model, test_sent_idx, INPUT_FIELDS, TARGET_FIELD, TRAIN,  MISC, device=device)
                print(tabulate(f1_all_test, headers = f1_all_test.columns, tablefmt = 'psql', numalign = 'left'))
                # Write the test f1 score
                if tf is not None:
                    for entity_type in range(len(f1_all_test)):
                        add_summary_value(tf_summary_writer, 
                                          f1_all_test.index[entity_type] + '-' + f1_all_test.columns[-1] + '_test', 
                                          f1_all_test.iloc [entity_type, -1], 
                                          batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'loss_test',  loss_test,  batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'f1_test',    f1_test,    batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'acc_test',   acc_test,   batchpoint_idx_in_tp)
                    tf_summary_writer.flush()
                
                # (+) Is Current Model also the BEST MODEL so far?
                print('\n[MODEL POINT :: TRAIN VALID TEST INFO] Record Overall Information')

                loss_per_bp_in_mp = round(loss_in_mp/batchpoint_num_in_mp, 3)
                loss_in_mp = 0
                batchpoint_num_in_mp = 0
                
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'loss_per_bp_in_mp',       loss_per_bp_in_mp,       batchpoint_idx_in_tp)

                # print(path_to_save_current_model, '\n\n')
                print('[MODEL POINT :: Loss] Current Average    Loss Train :         {:.4f} '.format(loss_per_bp_in_mp) + 'loss_per_bp_in_mp')
                print('[MODEL POINT :: Loss] Current Average    Loss Valid :         {:.4f} '.format(loss_val))
                print('[MODEL POINT :: Loss] Current Average    Loss  Test :         {:.4f} '.format(loss_test))
                print('[MODEL POINT ::  F1 ] Current Average    F1   Valid :         {:.4f} '.format(f1_val))
                print('[MODEL POINT ::  F1 ] Current Average    F1    Test :         {:.4f} '.format(f1_test))
                

                LogError_test.to_csv(current_log_test_path, index = False, sep = '\t')
                f1_all_test.  to_csv(current_pfm_test_path, index = True,  sep = '\t')

                LogError_val.to_csv(current_log_valid_path, index = False, sep = '\t')
                f1_all_val.  to_csv(current_pfm_valid_path, index = True,  sep = '\t')



                # (+) Hold Current Record according to the predetermined criterion.
                current_record = f1_val
                # current_record = loss_val

                # if best_record is None or best_record > current_record: 
                if best_record is None or best_record < current_record: 
                    # current record is better than the best historical record
                    best_record = current_record
                    
                    best_f1_all_val  = f1_all_val
                    best_f1_all_test = f1_all_test

                    best_f1_val  = f1_val
                    best_f1_test = f1_test
                    print('\n\n'+'--'*20)
                    
                    print('[MODEL POINT :: Best Info] New  Best         Valid F1:        {:.4f} '.format(best_f1_val))
                    print('[MODEL POINT :: Best Info] New  Associated   Test  F1:        {:.4f} '.format(best_f1_test))

                    # save model 
                    print('[MODEL POINT :: Model Selection] Now SAVE A BETTER MODEL to:')
                    model.save_model(path_to_save_current_model)
                    optimizer_path = os.path.join(path_to_save_current_model, 'optimizer.pth')
                    torch.save(optimizer.state_dict(), optimizer_path)

                    LogError_test.to_csv(best_log_test_path, index = False, sep = '\t')
                    f1_all_test.  to_csv(best_pfm_test_path, index = True,  sep = '\t')

                    LogError_val.to_csv(best_log_valid_path, index = False, sep = '\t')
                    f1_all_val.  to_csv(best_pfm_valid_path, index = True,  sep = '\t')
                    # num_of_no_improvements = 0
                    # num_of_no_improvements_para = 0
                    print('--'*20, '\n\n')
                    
                else:
                    print('\n\n','--'*20)
                    print('[MODEL POINT :: Model Selection] Current model is not the best')
                    print('[MODEL POINT :: Best Info] Last  Best        Valid F1:        {:.4f} '.format(best_f1_val))
                    print('[MODEL POINT :: Best Info] Last  Associated  Test  F1:        {:.4f} '.format(best_f1_test))

                if last_valid_loss < loss_val:
                    num_of_no_improvements_para += 1
                    # num_of_no_improvements_para_for_lr_decay = 
                    # TO decay the learning rate
                    if num_of_no_improvements_para % TRAIN['num_no_improvements_for_lr_decay'] == 0:
                        optimizer, current_lr = lr_decay_scale(optimizer, lr_decay_rate, current_lr)
                        print('\n')
                        print("[LEARNING RATE :: modify lr] Learning Rate is Set as:", current_lr)
                        print('--'*20, '\n\n')

                if tf is not None:
                    add_summary_value(tf_summary_writer, 'num_of_no_improvements_para',  num_of_no_improvements_para,  batchpoint_idx_in_tp)
                    # add_summary_value(tf_summary_writer, 'num_of_no_improvements_para',  num_of_no_improvements_para,  batchpoint_idx_in_tp)

                # ------ jump out of one valid and test
                last_valid_loss = loss_val

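            # Fine-Tune Part
            # NOTE(review): the loops below assign `para.require_grad`; the torch
            # attribute that controls gradient tracking is `requires_grad` — confirm
            # these assignments actually unfreeze the intended parameters.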
            if not pretrained_seqrepr_path:
                if epochpoint_idx_in_tp == int(epochpoint_num_in_tp*0.25):
                    if Fine_Tune_Flag_Embed == False:
                        for para in model.seqrepr.parameters():
                            para.require_grad = True
                        current_lr = TRAIN['peak_lr'] * 0.2
                        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN, lr = current_lr)
                        num_of_no_improvements_para = 0
                        Fine_Tune_Flag_Embed = True
                        print("[LEARNING RATE :: Fine-Tinue]: Fine-Tuning Parameters: EMBED")

            else:
                if epochpoint_idx_in_tp == int(epochpoint_num_in_tp*0.25*0.5):
                    if Fine_Tune_Flag_INTER == False:
                        for para in model.seqrepr.Interdep_Layer.parameters():
                            para.require_grad = True
                        current_lr = TRAIN['peak_lr'] * 0.5
                        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN, lr = current_lr)
                        num_of_no_improvements_para = 0
                        Fine_Tune_Flag_INTER = True
                        print("[LEARNING RATE :: Fine-Tinue]: Fine-Tuning Parameters: INTERDEP")
                    # ------jump out of one learning rate modification

                if epochpoint_idx_in_tp == int(epochpoint_num_in_tp*0.5*0.5):
                    if Fine_Tune_Flag_INDEP == False:
                        for para in model.seqrepr.Indep_Layers.parameters():
                            para.require_grad = True
                        current_lr = TRAIN['peak_lr'] * 0.4
                        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN, lr = current_lr)
                        num_of_no_improvements_para = 0
                        Fine_Tune_Flag_INDEP = True
                        print("[LEARNING RATE :: Fine-Tinue]: Fine-Tuning Parameters: INDEP")

                    # ------jump out of one learning rate modification
                
                if epochpoint_idx_in_tp == int(epochpoint_num_in_tp*0.75*0.5):
                    if Fine_Tune_Flag_Embed == False:
                        for para in model.seqrepr.Expander_Layers.parameters():
                            para.require_grad = True
                        current_lr = TRAIN['peak_lr'] * 0.3
                        optim_method, current_lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN, lr = current_lr)
                        num_of_no_improvements_para = 0
                        Fine_Tune_Flag_Embed = True
                        print("[LEARNING RATE :: Fine-Tinue]: Fine-Tuning Parameters: EMBED")

            # ------jump out of one batchpoint
        print('\n'+'==' * 50, '\n\n')
        # ------jump out of one epoch
    # ------jump out of one training

    # # V: testing stage                  
    print('\n[TRAIN END :: Fianl Result] Show the Final Results')
    print('\nFor valid')
    print(tabulate(best_f1_all_val,  headers = best_f1_all_val.columns, tablefmt = 'psql', numalign = 'left'))
    print('\nFor test')
    print(tabulate(best_f1_all_test, headers = best_f1_all_test.columns, tablefmt = 'psql', numalign = 'left'))
    
    print('==' * 50, '\n\n')

    return model
