import pandas as pd
import os
import time
import sys
import torch
import numpy as np
from tabulate import tabulate
from six.moves import cPickle
from nlptext.sentence import Sentence
import pandas as pd

from .utils import evals as eval_tools
from .utils import batch as batch_tools
from .utils.train import add_summary_value, lr_decay, build_optimizer, lr_decay_scale

from .module.seqrepr import SeqRepr

from .nn import softmax as sfm
from .nn import crf
from .nn.linear import LinearLayer
from .nn.helper import orderSeq
from .sublayer.reducer import Matrix_Reducer_Layer

# import tensorflow as tf

try:
    import tensorflow as tf
except ImportError:
    print("Tensorflow not installed; No tensorboard logging.")
    tf = None

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% MODEL

############################################# CLFIER  
class SeqClfier(torch.nn.Module):
    """Classification head that sits on top of a sequence representation.

    The head chains three stages: a ``Matrix_Reducer_Layer`` that reduces the
    sequence representation, a linear projection to ``n_class`` scores, and a
    ``SoftmaxSumLayer`` used both for the loss and for decoding.
    """

    def __init__(self, Classifier_Para, MISC):
        """Build the head.

        Parameters:
        Classifier_Para (dict): reads the keys 'output_dim', 'n_class' and
                                'Matrix_Reducer_Layer' (a pair ``(subtype, name_layer)``).
        MISC (dict): reads the key 'useStartEnd'.
        """
        super(SeqClfier, self).__init__()

        # stored only; nothing in this class reads it back
        self.useStartEnd = MISC['useStartEnd']

        self.output_dim = Classifier_Para['output_dim']
        self.n_class = Classifier_Para['n_class']

        # NOTE: submodules are created in the same order as before so that
        # parameter ordering and random initialisation stay unchanged.
        reducer_subtype, reducer_name = Classifier_Para['Matrix_Reducer_Layer']
        self.matrix_reducer = Matrix_Reducer_Layer(reducer_subtype, reducer_name)

        # projection from the reduced representation to one score per class
        self.hidden2tag = torch.nn.Linear(self.output_dim, self.n_class)

        # softmax based scorer / decoder
        self.softmax_layer = sfm.SoftmaxSumLayer()

    def _emission(self, seqrepr, leng_st, misc_info):
        """Reduce `seqrepr` and project it to class scores."""
        reduced = self.matrix_reducer(seqrepr, leng_st, misc_info)
        return self.hidden2tag(reduced)

    def loss_function(self, seqrepr, targets, leng_st, misc_info):
        """Return the softmax-layer loss divided by the emission's first dimension.

        Parameters:
        seqrepr: sequence representation handed to the matrix reducer.
        targets: gold labels handed to the softmax layer together with the emission.
        leng_st, misc_info: forwarded untouched to the matrix reducer.
        """
        scores = self._emission(seqrepr, leng_st, misc_info)
        summed_loss = self.softmax_layer(scores, targets)
        # normalise by the number of rows in the emission
        # (presumably one row per classified item -- TODO confirm)
        return summed_loss / scores.size(0)

    def decode(self, seqrepr, leng_st, misc_info):
        """Return whatever ``softmax_layer.decode`` predicts for the emission of `seqrepr`."""
        return self.softmax_layer.decode(self._emission(seqrepr, leng_st, misc_info))
  
############################################# WHOLE MODEL
class SeqCls(torch.nn.Module):
    """Whole model: a ``SeqRepr`` encoder followed by a ``SeqClfier`` head.

    ``forward`` only computes the sequence representation; the loss and the
    predictions are obtained from it through ``loss_function`` and ``decode``.
    The two parts are saved to / loaded from separate files named
    ``seqrepr[_<name>].pth`` and ``clfier[_<name>].pth``.
    """

    def __init__(self, SeqRepr_Para, Classifier_Para, MISC):
        """Build both parts.

        Parameters:
        SeqRepr_Para (dict): keyword arguments expanded into ``SeqRepr(**SeqRepr_Para)``.
        Classifier_Para (dict), MISC (dict): forwarded to ``SeqClfier``.
        """
        super(SeqCls, self).__init__()
        # seqrepr layer
        self.seqrepr = SeqRepr(**SeqRepr_Para)
        # classifier layer
        self.clfier = SeqClfier(Classifier_Para, MISC)

    def forward(self, info_dict, leng_st, misc_info):
        """Return the sequence representation produced by ``self.seqrepr``."""
        return self.seqrepr(info_dict, leng_st, misc_info)

    def loss_function(self, seqrepr, targets, leng_st, misc_info):
        """Delegate the loss computation to the classifier head."""
        return self.clfier.loss_function(seqrepr, targets, leng_st, misc_info)

    def decode(self, seqrepr, leng_st, misc_info):
        """Delegate the prediction to the classifier head."""
        return self.clfier.decode(seqrepr, leng_st, misc_info)

    @staticmethod
    def _checkpoint_path(root_path, part, name):
        """Return ``<root_path>/<part>[_<name>].pth``."""
        suffix = '_' + name if name != '' else ''
        return os.path.join(root_path, part + suffix + '.pth')

    @staticmethod
    def _resolve_map_location(map_location):
        """Translate the legacy 'gpu' tag into something ``torch.load`` accepts.

        'gpu' is not a device string torch understands, so passing it through
        makes ``torch.load`` fail. It is mapped to 'cuda' when CUDA is
        available and to 'cpu' otherwise; every other value is returned as is.
        """
        if isinstance(map_location, str) and map_location == 'gpu':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return map_location

    def load_pretrain_seqrepr(self, root_path, name = '', map_location = 'gpu'):
        """Load only the seqrepr weights from ``<root_path>/seqrepr[_<name>].pth``.

        Parameters:
        root_path (str): directory containing the checkpoint.
        name (str): optional suffix of the checkpoint file name.
        map_location: forwarded to ``torch.load``; the legacy value 'gpu'
                      means "cuda if available, else cpu".
        """
        path = self._checkpoint_path(root_path, 'seqrepr', name)
        # map_location used to be ignored here; honour it so that a checkpoint
        # written on a GPU machine can be restored on a CPU-only one.
        state = torch.load(path, self._resolve_map_location(map_location))
        self.seqrepr.load_state_dict(state)
        print("[LOAD PRE-TRAINED SEQREQR] seqrepr loaded from {}".format(path))

    def load_model(self, root_path, name = '', map_location = 'gpu'):
        """Load seqrepr and clfier weights written by ``save_model``.

        Parameters:
        root_path (str): directory containing both checkpoint files.
        name (str): optional suffix of the checkpoint file names.
        map_location: forwarded to ``torch.load``; the legacy value 'gpu'
                      means "cuda if available, else cpu".
        """
        location = self._resolve_map_location(map_location)

        path = self._checkpoint_path(root_path, 'seqrepr', name)
        self.seqrepr.load_state_dict(torch.load(path, location))
        print("[LOAD PRE-TRAINED SEQREQR]    seqrepr model    loaded from {}".format(path))

        path = self._checkpoint_path(root_path, 'clfier', name)
        self.clfier.load_state_dict(torch.load(path, location))
        print("[LOAD PRE-TRAINED CLFIER]     clfier  model    loaded from {}".format(path))

    def save_model(self, root_path, name = ''):
        """Write the state dicts of seqrepr and clfier under `root_path`.

        Parameters:
        root_path (str): existing directory to write into.
        name (str): optional suffix appended to both file names.
        """
        path = self._checkpoint_path(root_path, 'seqrepr', name)
        torch.save(self.seqrepr.state_dict(), path)
        print("[SAVE PRE-TRAINED SEQREQR]    seqrepr model    saved to {}".format(path))

        path = self._checkpoint_path(root_path, 'clfier', name)
        torch.save(self.clfier.state_dict(), path)
        print("[SAVE PRE-TRAINED CLFIER]     clfier  model    saved to {}".format(path))


#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% TRAIN
# what should we pay attention to when training models?
# (1) input and target
# (2) model construction and randomization
# (3) the proportion of trainable parameters.
# (4) optimization methods
# (5) the learning rate evolution.
# (6) when to terminate the training phase.

#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% TRAIN
# what should we pay attention to when training models?
# (1) input and target
# (2) model construction and randomization
# (3) the proportion of trainable parameters.
# (4) optimization methods
# (5) the learning rate evolution.
# (6) when to terminate the training phase.

############################################# model eval
def _seqcls_tag_report(pred_target_pairs, idx2tag):
    """Build the per-tag and averaged precision / recall / F1 report.

    Parameters:
    pred_target_pairs (list): ``[target_tag, pred_tag]`` pairs, one per evaluated item.
    idx2tag (sequence): the tags that each get a row in the report.

    Returns:
    dict: one entry per tag with the keys 'matched_size', 'precise_size'
          (how often the tag was predicted), 'recall_size' (how often the tag
          is the target), 'precise', 'recall' and 'F1'; plus
          'macro' (unweighted mean of the per-tag scores) and
          'micro' (scores pooled over all pairs, i.e. the overall accuracy,
          together with the pooled sizes).
    """
    from collections import Counter

    # one pass over the pairs instead of rescanning the list for every tag
    target_count = Counter(pair[0] for pair in pred_target_pairs)
    pred_count = Counter(pair[1] for pair in pred_target_pairs)
    matched_count = Counter(pair[0] for pair in pred_target_pairs if pair[0] == pair[1])

    results = {}
    for tag in idx2tag:
        matched_size = matched_count[tag]
        # precision is measured against what was predicted,
        # recall against what should have been found.
        precise_size = pred_count[tag]
        recall_size = target_count[tag]
        precise = matched_size / precise_size if precise_size > 0 else 0
        recall = matched_size / recall_size if recall_size > 0 else 0
        f1 = precise * recall * 2 / (precise + recall) if (precise + recall) > 0 else 0
        results[tag] = {
            'matched_size': matched_size,
            'precise_size': precise_size,
            'recall_size': recall_size,
            'precise': precise,
            'recall': recall,
            'F1': f1,
        }

    score_names = ['precise', 'recall', 'F1']
    tag_num = len(idx2tag)
    total_size = len(pred_target_pairs)
    total_matched = sum(matched_count.values())
    accuracy = total_matched / total_size if total_size > 0 else 0

    # macro: every tag weighs the same, whatever its frequency
    results['macro'] = {
        score: (sum(results[tag][score] for tag in idx2tag) / tag_num if tag_num > 0 else 0)
        for score in score_names
    }

    # micro: pool all decisions; with exactly one label per item
    # precision == recall == F1 == accuracy
    results['micro'] = {score: accuracy for score in score_names}
    results['micro']['matched_size'] = total_matched
    results['micro']['precise_size'] = total_size
    results['micro']['recall_size'] = total_size
    return results


def _seqcls_eval_batch(model, batch, INPUT_FIELDS, TARGET_FIELD, MISC, device):
    """Run the model on one non-empty batch of sentences.

    Returns:
    tuple or None: ``(loss_value, preds, targets)``, or None when the batch
                   inputs could not be built (the error is printed).
    """
    # (+) prepare the input data and target data
    leng_st = torch.LongTensor([st.length for st in batch]).to(device)
    if len(batch) > 1:
        batch, leng_st, _ = orderSeq(np.array(batch), leng_st)

    # (+) load input information; a broken batch is reported and skipped
    try:
        info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
    except Exception as err:
        print(batch)
        print('[MODEL POINT: eval] batch skipped because of {}: {}'.format(type(err).__name__, err))
        return None

    seqrepr = model(info_dict, leng_st, misc_info)
    loss_in_bp = model.loss_function(seqrepr, targets, leng_st, misc_info)
    preds = model.decode(seqrepr, leng_st, misc_info)
    return loss_in_bp.item(), preds, targets


def seqcls_eval(model, sents, INPUT_FIELDS, TARGET_FIELD, TRAIN, MISC, device = None):
    """Evaluate `model` on `sents` and report loss, accuracy and per-tag scores.

    Parameters:
    model: object providing ``eval()``, ``__call__(info_dict, leng_st, misc_info)``,
           ``loss_function(...)`` and ``decode(...)``.
    sents: sliceable collection whose items are accepted by ``Sentence(...)``.
    INPUT_FIELDS: forwarded to ``batch_tools.get_fieldsinfo_4_batchsents``.
    TARGET_FIELD: ``TARGET_FIELD[1]['GU'][0]`` is used as the index-to-tag table;
                  the whole object is also forwarded to ``get_fieldsinfo_4_batchsents``.
    TRAIN (dict): reads 'checkpoint_size_of_bp_in_eval' (log every that many
                  batches) and 'batchpoint_size_of_dp' (batch size).
    MISC (dict): reads 'useStartEnd' and 'maxSentLeng'; sentences whose length
                 (plus 2 when useStartEnd is set) exceeds maxSentLeng are dropped.
    device: passed to ``.to(device)`` and to ``get_fieldsinfo_4_batchsents``.

    Returns:
    tuple: ``(loss_per_bp_in_ep, acc_in_ep, results)`` where the loss is the
           mean loss over the batches that were actually evaluated, the
           accuracy is right / total predictions (both rounded to 3 digits)
           and `results` is the dict described in ``_seqcls_tag_report``.
           With nothing to evaluate the loss and accuracy are 0.0.
    """
    # (+) MISC
    start_end_length = 2 if MISC['useStartEnd'] else 0
    maxSentLeng = MISC['maxSentLeng']

    # (+) Dictionary
    idx2tag = TARGET_FIELD[1]['GU'][0]
    pred_target_pairs = []

    # (+) Check Point / Batch Point sizes
    checkpoint_size_of_bp = TRAIN['checkpoint_size_of_bp_in_eval']
    batchpoint_size_of_dp = TRAIN['batchpoint_size_of_dp']

    data_point_num_in_ep = len(sents)
    # ceil division: no phantom empty batch when the data size is a multiple
    # of the batch size
    batchpoint_num_in_ep = (data_point_num_in_ep + batchpoint_size_of_dp - 1) // batchpoint_size_of_dp

    # counters for the current checkpoint window (cp) and the whole pass (ep)
    loss_in_cp = 0
    data_point_num_in_cp = 0
    batchpoint_num_in_cp = 0
    tokenpoint_right_num_in_cp = 0
    tokenpoint_total_num_in_cp = 0

    loss_in_ep = 0
    batchpoint_done_in_ep = 0
    tokenpoint_right_num_in_ep = 0
    tokenpoint_total_num_in_ep = 0

    # time records
    temp_start = time.time()

    # set model in eval mode (turns off dropout)
    model.eval()

    with torch.no_grad():
        for batchpoint_idx_in_ep in range(batchpoint_num_in_ep):
            # (+) batch's datapoint start and end idx.
            data_point_startidx_in_bs = batchpoint_idx_in_ep * batchpoint_size_of_dp
            data_point_endidx_in_bs = min(data_point_startidx_in_bs + batchpoint_size_of_dp, data_point_num_in_ep)

            # (+) batch input, restricted to sentences that fit maxSentLeng
            batch = [Sentence(i) for i in sents[data_point_startidx_in_bs: data_point_endidx_in_bs]]
            batch = [st for st in batch if st.length + start_end_length <= maxSentLeng]
            datapoint_num_in_bp = len(batch)

            outcome = None
            if datapoint_num_in_bp == 0:
                print('zero batch')
            else:
                outcome = _seqcls_eval_batch(model, batch, INPUT_FIELDS, TARGET_FIELD, MISC, device)

            if outcome is not None:
                loss_value, preds, targets = outcome

                loss_in_cp += loss_value
                loss_in_ep += loss_value
                data_point_num_in_cp += datapoint_num_in_bp
                batchpoint_num_in_cp += 1
                batchpoint_done_in_ep += 1

                # (+) Accuracy Rate
                right = int((preds == targets).sum().item())
                whole = int(targets.size(0))

                for target_idx, pred_idx in zip(targets.tolist(), preds.tolist()):
                    pred_target_pairs.append([idx2tag[target_idx], idx2tag[pred_idx]])

                tokenpoint_right_num_in_cp += right
                tokenpoint_total_num_in_cp += whole
                tokenpoint_right_num_in_ep += right
                tokenpoint_total_num_in_ep += whole

            # (Enter Check Point) -- only when the window holds at least one batch
            at_checkpoint = ((batchpoint_idx_in_ep + 1) % checkpoint_size_of_bp == 0
                             or (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep)
            if at_checkpoint and batchpoint_num_in_cp > 0:
                temp_time  = time.time()
                temp_cost  = temp_time - temp_start
                temp_start = temp_time

                # guard against a zero interval on coarse clocks
                speed_dp_per_sec_in_cp = round(data_point_num_in_cp / max(temp_cost, 1e-9), 3)
                data_point_num_in_cp = 0

                loss_per_bp_in_cp = round(loss_in_cp / batchpoint_num_in_cp, 3)
                loss_in_cp = 0
                batchpoint_num_in_cp = 0

                if tokenpoint_total_num_in_cp > 0:
                    acc_in_cp = round(tokenpoint_right_num_in_cp / tokenpoint_total_num_in_cp, 3)
                else:
                    acc_in_cp = 0.0
                tokenpoint_right_num_in_cp = 0
                tokenpoint_total_num_in_cp = 0

                # (+) log information for this checkpoint
                print('\t'.join(['[MODEL POINT: eval]', 
                                 'BatchPoint: {}/{}'.format(batchpoint_idx_in_ep+1, batchpoint_num_in_ep),
                                 'DataPoint: {}/{};'.format((batchpoint_idx_in_ep) * batchpoint_size_of_dp + len(batch), data_point_num_in_ep),
                                 'Speed: {:.3f} dp/s;'.format(speed_dp_per_sec_in_cp),
                                 'Accuary: {};'.format(acc_in_cp),
                                 'Loss: {:.3f} /bp'.format(loss_per_bp_in_cp)]))

                sys.stdout.flush()

    # (Epoch summary) computed after the loop so that it is always available,
    # even when the last batch was empty or had to be skipped.
    if batchpoint_done_in_ep > 0:
        loss_per_bp_in_ep = round(loss_in_ep / batchpoint_done_in_ep, 3)
    else:
        loss_per_bp_in_ep = 0.0

    if tokenpoint_total_num_in_ep > 0:
        acc_in_ep = round(tokenpoint_right_num_in_ep / tokenpoint_total_num_in_ep, 3)
    else:
        acc_in_ep = 0.0

    results = _seqcls_tag_report(pred_target_pairs, idx2tag)

    return loss_per_bp_in_ep, acc_in_ep, results


############################################# model train
def seqcls_train(model, data, INPUT_FIELDS, TARGET_FIELD, TRAIN, MISC, device=None):
    """Train the Sequence Labeling Tasks with this function.

    Parameters:
    model (SeqLabel): the initialized model which contains a seqrepr and clfier.
                      the trainable parameters in them may be randomly initialized or pretrained.
    data (nlptext based sentence idx): There must be a NLPText BasicObject outside to make sure 
                                       that these sentences can be initialized.
    INPUT_FIELDS (dictionary): a dict whose keys are selected fields.
                               for each field, its keys contain all the field information of the selected corpus.
    TARGET_FIELD (list): the first element is field annoE (annotation entity), the rest elements are annoE's information.
    TRAIN (dictionary): the training settings used to train the model. 
                        the important ones are optimizier and learning rate.
    MISC (dictioanry): additional information. The ones used inside this function are useStartEnd, maxSentLeng.
                       MISC is also utilized in function `get_fieldsinfo_4_batchsents` to get field information.
    device (string):  cpu or cuda.
    """
    # (+) Random Seed Setting
    seed = TRAIN['random_seed']
    torch.manual_seed(seed); np.random.seed(seed)
    
    # (+) MISC 
    useStartEnd = MISC['useStartEnd']; start_end_length = 2 if useStartEnd else 0
    maxSentLeng = MISC['maxSentLeng']

    # (+) Path to record information of current model
    path_to_save_current_model = TRAIN['path_to_save_current_model']
    if not os.path.exists(path_to_save_current_model): os.mkdirs(path_to_save_current_model)
    print('[TRAINING :: INFO] Logging to {}'.format(path_to_save_current_model))

    # (+) Record Information in tf summary writer
    tf_summary_writer = tf and tf.summary.FileWriter(path_to_save_current_model)
    # print(tf_summary_writer)

    # (+) Ways to Initialize Model. Use Pretrained Seqrepr 
    # our model contains seqrepr and clfier. 
    # before training, two ways to initialize seqrepr.
    # one is randomly initializing.
    # another is using pretrained seqrepr.
    pretrained_seqrepr_path         = TRAIN['pretrained_seqrepr_path']
    continue_previous_trained_model = TRAIN['continue_previous_trained_model']

    if pretrained_seqrepr_path:
        # (+) model initial parameters
        model.load_pretrain_seqrepr(pretrained_seqrepr_path)

        # (+) which part to frozen
        for para in model.seqrepr.parameters():
            para.require_grad = False
        # (+) optimizer
        # lr_decay_num_per_epoch = TRAIN['lr_decay_num_per_epoch']
        optim_method, lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN)
        # current_lr = lr
        # current_lr = 0.001
        # clip_grad = TRAIN['clip_grad']
        # (+) history
        history = {}
        # (+) Get Current Point for Where to Start
        start_epochpoint_idx_in_tp = 0
        start_batchpoint_idx_in_ep = 0
        # (+) logging
        print('\n'+ '+'*60)
        print('[MODEL :: INITIALIZATION] load pretrained seqrepr from:', pretrained_seqrepr_path)
        print('+'*60, '\n')
    elif continue_previous_trained_model:
        # (+) model initial parameters
        model.load_model(path_to_save_current_model)
        # (+) optimizer
        optimizer.load_state_dict(torch.load(os.path.join(path_to_save_current_model, 'optimizer.pth')))
        # (+) history
        with open(os.path.join(path_to_save_current_model, 'history' + '.pkl'), 'rb') as f: 
            history = cPickle.load(f)
        # (+) Get Current Point for Where to Start
        # pay attention to handle this 
        start_epochpoint_idx_in_tp = history['history_epochpoint_idx_in_tp']
        start_batchpoint_idx_in_ep = history['history_batchpoint_idx_in_ep']
        # (+) logging
        print('\n'+ '+'*60)
        print('[MODEL :: INITIALIZATION] load whole model from:', path_to_save_current_model)
        print('+'*60, '\n')
    else:
        # (+) model initial parameters
        # model
        # (+) optimizer
        # lr_decay_num_per_epoch = TRAIN['lr_decay_num_per_epoch']
        optim_method, lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN)
        # current_lr = lr
        # current_lr = 0.001
        # clip_grad = TRAIN['clip_grad']
        # (+) history
        history = {}
        # (+) Get Current Point for Where to Start
        start_epochpoint_idx_in_tp = 0
        start_batchpoint_idx_in_ep = 0
        # (+) logging
        print('\n'+'+'*60)
        print('[MODEL :: INITIALIZATION] totally randomize all the parameters in seqrepr and clfier')
        print('+'*60, '\n')


    # (+) TODO: which part to update and which part to frozen

    # (+) Data
    train_sent_idx, valid_sent_idx, test_sent_idx = data
    
    # (+) Logging and Saving for Train Data
    # dp, bp, cp, mp, ep, tp
    batchpoint_size_of_dp = TRAIN['batchpoint_size_of_dp']

    data_point_num_in_ep  = len(train_sent_idx) # train datapoint number
    batchpoint_num_in_ep  = int(data_point_num_in_ep/batchpoint_size_of_dp) + 1 # train batchunit number
    epochpoint_num_in_tp  = TRAIN['epochpoint_num_in_tp'] # or trainpoint_size_of_epc
    
    checkpoint_size_of_bp = TRAIN['checkpoint_size_of_bp'] 
    modelpoint_size_of_bp = 'Epoch End'
    assert modelpoint_size_of_bp == 'Epoch End'
    # for language model
    # modelpoint_size_of_bp = TRAIN['modelpoint_size_of_bp'] # we don't use this one.
    # assert modelpoint_size % checkpoint_size == 0 


    # at the beginning of the train procession
    best_record = None
    # best_f1_all_val  = None
    # best_f1_all_test = None
    # best_f1_val  = None
    # best_f1_test = None
    best_loss_val  = None
    best_loss_test = None

    num_of_no_improvements = 0
    num_of_no_improvements_para = 0

    already_unfrozen = 0

    # history_batchpoint_idx
    for epochpoint_idx_in_tp in range(start_epochpoint_idx_in_tp, epochpoint_num_in_tp):
        # (+) start_batchpoint_idx_in_ep and train_sent_idx_for_ep
        if epochpoint_idx_in_tp == start_epochpoint_idx_in_tp and epochpoint_idx_in_tp != 0:
            start_batchpoint_idx_in_ep = start_batchpoint_idx_in_ep
            # TODO
            train_sent_idx_for_ep = pickle.load() 
        else:
            start_batchpoint_idx_in_ep = 0
            train_sent_idx_for_ep = np.random.shuffle(train_sent_idx)

        # np.random.shuffle(train_sent_idx)
        print('\n' +  '=='*30)
        print('[TRAINING :: EPOCH POINT] START AT EPOCH', epochpoint_idx_in_tp+1, '/', epochpoint_num_in_tp)

        # (+) time records
        epoch_start = time.time()
        temp_start  = epoch_start

        
        loss_in_cp = 0
        data_point_num_in_cp = 0
        batchpoint_num_in_cp = 0


        # Model Point
        loss_in_mp = 0 
        batchpoint_num_in_mp = 0


        # (+) TRAINING STAGE FOR AN EPOCH
        # 1 to last n-1 checkout points
        for batchpoint_idx_in_ep in range(start_batchpoint_idx_in_ep, batchpoint_num_in_ep):
            # NOW WE ARE IN A BATCH of the `epochpoint_idx_in_tp` th epoch

            # (+) batch's datapoint start and end idx.
            # get batch_idx, each element is a sentence local index
            data_point_startidx_in_bs = batchpoint_idx_in_ep * batchpoint_size_of_dp
            tmp   = data_point_startidx_in_bs + batchpoint_size_of_dp
            data_point_endidx_in_bs   = tmp if tmp <= data_point_num_in_ep else data_point_num_in_ep
            # print(data_point_startidx_in_bs, data_point_endidx_in_bs)
            # (+) batch input 
            # convert batch_idx to batch of nlptext.sentence.Sentence
            batch = train_sent_idx[data_point_startidx_in_bs: data_point_endidx_in_bs]
            if type(batch) == int: batch = [batch]
            batch = [Sentence(i) for i in batch]

            # (+) check batch's validation
            ##########################################################
            batch = [i for i in batch if i.length + start_end_length <= maxSentLeng]
            datapoint_num_in_bp = len(batch)
            if datapoint_num_in_bp == 0: print('zero batch'); continue
            ##########################################################

            # (+) prepare the input data and target data
            leng_st = torch.LongTensor([st.length for st in batch]).to(device)
            if len(batch) > 1: batch, leng_st, reverse_index = orderSeq(np.array(batch), leng_st)

            # (+) load input information
            # info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            try:
                info_dict, targets, leng_st, misc_info = batch_tools.get_fieldsinfo_4_batchsents(batch, INPUT_FIELDS, TARGET_FIELD, MISC, device=device)
            except:
                print(batch)
                # TODO: also report the error information
                continue


            # (+) learning rate warming up
            if TRAIN['lr_warm_up'] and (batchpoint_idx_in_ep + 1) <= 2000 and epochpoint_idx_in_tp > 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] + 5e-6


            # (+) model preparation
            model.train()
            model.zero_grad()

            # (+) calculate seqrepr
            seqrepr = model(info_dict, leng_st, misc_info)

            # (+) calculate loss
            loss_in_bp = model.loss_function(seqrepr, targets, leng_st, misc_info)
            loss_in_bp.backward()

            # (+) optimizer update parameters
            # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad) # this is a problem.
            optimizer.step()
            
            # (+) record
            # bp in tp
            batchpoint_idx_in_tp = batchpoint_idx_in_ep + 1 + batchpoint_num_in_ep *  epochpoint_idx_in_tp

            # the observer for checkout_point
            loss_in_cp += loss_in_bp.item()
            data_point_num_in_cp  += datapoint_num_in_bp # datapoint_num in current checkpoint
            batchpoint_num_in_cp  += 1

            # the observer for checkout_point
            loss_in_mp += loss_in_bp.item()
            batchpoint_num_in_mp  += 1

            # CheckPoint (Enter a Checkpoint if Batchpoint is Special)
            if (batchpoint_idx_in_ep + 1) % checkpoint_size_of_bp == 0 or (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:
                # start logging the loss only.
                temp_time  = time.time()
                temp_cost  = temp_time - temp_start
                temp_start = temp_time

                speed_dp_per_sec_in_cp = round(data_point_num_in_cp/temp_cost, 3)
                data_point_num_in_cp = 0
                
                loss_per_bp_in_cp  = round(loss_in_cp/batchpoint_num_in_cp, 3)
                loss_in_cp = 0
                batchpoint_num_in_cp = 0

                # (+) log information for this batchpoint
                print('\t'.join(['[CHECK POINT]', 
                                 'Epoch: {}'.format(epochpoint_idx_in_tp+1),
                                 'BatchPoint: {}/{}'.format(batchpoint_idx_in_ep+1, batchpoint_num_in_ep),
                                 'DataPoint: {}/{};'.format((batchpoint_idx_in_ep) * batchpoint_size_of_dp + len(batch), data_point_num_in_ep),
                                 'Speed: {:.3f} dp/s;'.format(speed_dp_per_sec_in_cp),
                                 'Loss: {:.3f} /bp'.format(loss_per_bp_in_cp)]))

                # (+) check whether the loss explores or not.
                if loss_per_bp_in_cp > 1e8 or str(loss_per_bp_in_cp) == "nan":
                    raise ValueError('The Loss is Explored!')

                # (+) save to tf_summary_writer
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'loss_per_bp_in_cp',  loss_per_bp_in_cp,  batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'speed_dp_per_sec_in_cp', speed_dp_per_sec_in_cp, batchpoint_idx_in_tp)
                    tf_summary_writer.flush()

                sys.stdout.flush()
                # ------jump out of one checkpoint

            # ModelPoint (Enter Valid (and perhaps Test) and Model Saving if Checkpoint is Special)
            if (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:

                # (+) Save current model to model_path + 'current_model'
                print('\n[MODEL POINT :: Save Current Model] Save Current Model to:')
                model.save_model(path_to_save_current_model, name = 'current')
                optimizer_path = os.path.join(path_to_save_current_model, 'optimizer_current.pth')
                torch.save(optimizer.state_dict(), optimizer_path)


                # (+) Write current situation to history
                history['history_epochpoint_idx_in_tp'] = epochpoint_idx_in_tp
                history['history_batchpoint_idx_in_ep'] = batchpoint_idx_in_ep

                with open(os.path.join(path_to_save_current_model, 'history' + '.pkl'), 'wb') as f: 
                    cPickle.dump(history, f)


                # (+) Valid: Current Model in Validation
                headers = ['precise_size', 'recall_size', 'matched_size', 'precise', 'recall', 'F1']
                print('\n[MODEL POINT :: Valid] EVAL validation data')
                loss_val, acc_val, acc_all_val = seqcls_eval(model, valid_sent_idx, INPUT_FIELDS, TARGET_FIELD, TRAIN,  MISC, device=device)
                # print(tabulate(f1_all_val, headers=f1_all_val.columns, tablefmt='psql', numalign='left'))
                results = pd.DataFrame(acc_all_val).T
                results = results[headers]
                print(tabulate(results, headers=headers, tablefmt='psql', numalign='left'))
                # Write the validation f1 score
                # print(acc_all_val)
                if tf is not None:
                    for catogery in acc_all_val:
                        for score_type in  acc_all_val[catogery]:
                            # print(catogery, score_type)
                            # print(acc_all_val[catogery][score_type])
                            add_summary_value(tf_summary_writer, 
                                              'Z-'+catogery + '-' + score_type,
                                              acc_all_val[catogery][score_type], 
                                              batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'loss_val',  loss_val,  batchpoint_idx_in_tp)
                    # add_summary_value(tf_summary_writer, 'f1_val',    f1_val,    batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'acc_val',   acc_val,   batchpoint_idx_in_tp)
                    tf_summary_writer.flush()


                # (+) Test: Current Model in Validation
                # in some case, it may be later.
                print('\n[MODEL POINT :: Test] EVAL test data')
                loss_test, acc_test, acc_all_test = seqcls_eval(model, test_sent_idx, INPUT_FIELDS, TARGET_FIELD, TRAIN,  MISC, device=device)
                # print(tabulate(f1_all_test, headers = f1_all_test.columns, tablefmt = 'psql', numalign = 'left'))
                results = pd.DataFrame(acc_all_test).T
                results = results[headers]
                print(tabulate(results, headers=headers, tablefmt='psql', numalign='left'))
                # Write the validation f1 score
                # print(acc_all_test)
                if tf is not None:
                    for catogery in acc_all_test:
                        for score_type in  acc_all_test[catogery]:
                            add_summary_value(tf_summary_writer, 
                                              'Z-'+catogery + '-' + score_type,
                                              acc_all_test[catogery][score_type], 
                                              batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'loss_test',  loss_test,  batchpoint_idx_in_tp)
                    # add_summary_value(tf_summary_writer, 'f1_test',    f1_test,    batchpoint_idx_in_tp)
                    add_summary_value(tf_summary_writer, 'acc_test',   acc_test,   batchpoint_idx_in_tp)
                    tf_summary_writer.flush()
                
                # (+) Is Current Model also the BEST MODEL so far?
                print('\n[MODEL POINT :: TRAIN VALID TEST INFO] Record Overall Information')

                loss_per_bp_in_mp = round(loss_in_mp/batchpoint_num_in_mp, 3)
                loss_in_mp = 0
                batchpoint_num_in_mp = 0
                
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'loss_per_bp_in_mp',       loss_per_bp_in_mp,       batchpoint_idx_in_tp)

                # print(path_to_save_current_model, '\n\n')
                print('[MODEL POINT :: Loss] Current Average    Loss    Train :         {:.4f} '.format(loss_per_bp_in_mp) + 'loss_per_bp_in_mp')
                print('[MODEL POINT :: Loss] Current Average    Loss    Valid :         {:.4f} '.format(loss_val))
                print('[MODEL POINT :: Loss] Current Average    Loss     Test :         {:.4f} '.format(loss_test))
                print('[MODEL POINT ::  F1 ] Current Average    Accuary Valid :         {:.4f} '.format(acc_val))
                print('[MODEL POINT ::  F1 ] Current Average    Accuary  Test :         {:.4f} '.format(acc_test))


                # (+) Hold Current Record according to the predetermined criterion.
                # current_record = f1_val
                current_record = loss_val

                # if best_record is None or best_record < current_record: 
                if best_record is None or best_record > current_record: 
                
                    # current record is better (lower validation loss) than the best historical record
                    best_record = current_record
                    
                    best_loss_val  = loss_val
                    best_loss_test = loss_test

                    print('\n\n'+'--'*20)
                    
                    print('[MODEL POINT :: Best Info] New  Best         Valid Loss:        {:.4f} '.format(best_loss_val))
                    print('[MODEL POINT :: Best Info] New  Associated   Test  Loss:        {:.4f} '.format(best_loss_test))

                    # save model 
                    print('[MODEL POINT :: Model Selection] Now SAVE A BETTER MODEL to:')
                    model.save_model(path_to_save_current_model)
                    optimizer_path = os.path.join(path_to_save_current_model, 'optimizer.pth')
                    torch.save(optimizer.state_dict(), optimizer_path)
                    num_of_no_improvements = 0
                    num_of_no_improvements_para = 0
                    print('--'*20, '\n\n')
                    
                else:
                    print('\n\n','--'*20)
                    print('[MODEL POINT :: Model Selection] Current model is not the best')
                    print('[MODEL POINT :: Best Info] Last  Best        Valid Loss:        {:.4f} '.format(best_loss_val))
                    print('[MODEL POINT :: Best Info] Last  Associated  Test  Loss:        {:.4f} '.format(best_loss_test))

                    num_of_no_improvements = num_of_no_improvements + 1
                    num_of_no_improvements_para = num_of_no_improvements_para+ 1

                    # print(tabulate(best_f1_val.loc[['E']], headers = test_f1score.columns, tablefmt = 'psql', numalign = 'left'))
                    # print(tabulate(best_f1_test.loc[['E']], headers = test_f1score.columns, tablefmt = 'psql', numalign = 'left'))
                    
                    print('--'*20, '\n\n')
                # ------ jump out of one valid and test


            # (+) TODO learning rate modification in each bp
            if (batchpoint_idx_in_ep + 1) == batchpoint_num_in_ep:
                # (+) Change the Learning Rate.
                # Zeyu: HOW to decay the learning rate?
                # B: decay the learning rate with regard to the decrease of validation loss.
                # if val_old_loss - val_loss <= 0.01:
                #     optimizer, current_lr = lr_decay_scale(optimizer, lr_decay_rate, current_lr)
                # val_old_loss = val_loss
                if num_of_no_improvements >= 3 : 
                    optimizer, lr = lr_decay_scale(optimizer, lr_decay_rate*2, lr)
                    print("[LEARNING RATE :: modify lr] Learning Rate is Set as:", lr)
                    num_of_no_improvements = 0
                else:
                    optimizer, lr = lr_decay_scale(optimizer, lr_decay_rate, lr)
                    print("[LEARNING RATE :: modify lr] Learning Rate is Set as:", lr)
                # -------------------------------------------------------------
                # Zeyu: Two phases: (1) starting from zero, gradually increase the learning rate; (2) decay the learning rate.
                

                # Zeyu: HOW to decay the learning rate?
                # A: decay the learning rate at each n-th point in the whole epoch.
                # if (batch_id + 1) % int(total_batch / lr_decay_num_per_epoch) == 0:
                #     optimizer, current_lr = lr_decay_scale(optimizer, lr_decay_rate, current_lr)
                # -------------------------------------------------------------
                # ------jump out of one learning rate modification

            if num_of_no_improvements_para >= 10 or epochpoint_idx_in_tp == int(epochpoint_num_in_tp/2):
                if already_unfrozen == 0:
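                    # (+) Unfreeze the seqrepr parameters and rebuild the optimizer (done at most once).
                    # NOTE(review): the loop below assigns `require_grad`; the torch.Tensor attribute is
                    # `requires_grad`, so this assignment likely has no effect on training -- confirm and fix.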
                    # Zeyu: HOW to decay the learning rate?
                    # B: decay the learning rate with regard to the decrease of validation loss.
                    # if val_old_loss - val_loss <= 0.01:
                    #     optimizer, current_lr = lr_decay_scale(optimizer, lr_decay_rate, current_lr)
                    # val_old_loss = val_loss
                    for para in model.seqrepr.parameters():
                        para.require_grad = True

                    optim_method, lr, lr_decay_rate, optimizer = build_optimizer(model, TRAIN, lr = lr)
                    num_of_no_improvements_para = 0
                    already_unfrozen = 1
                else:
                    pass

                # ------jump out of one learning rate modification

            # ------jump out of one batchpoint
        print('\n'+'==' * 50, '\n\n')
        # ------jump out of one epoch
    # ------jump out of one training


    return model
