import torch
# TRAIN Configurations.

from .data_config import model_type
from .data_config import SeqRepr_Config_Name, use_residual_structure
from .data_config import use_pretrained_seqrepr, pretrained_seqrepr_path


############## device
# Accepted values: 'cuda' or 'cpu'.
# There is intentionally no CPU fallback here: if cuda is not available,
# let it raise an error instead of silently switching devices.
device_type = 'cuda'
# Derived from device_type above -- change device_type, not this line.
device = torch.device(device_type)
##############


# Dictionary that collects the training settings; the sections below add
# further keys to it.
TRAIN = {
    ############## initialization method
    # Holds the path imported from data_config when use_pretrained_seqrepr is
    # truthy, and False otherwise.
    'pretrained_seqrepr_path': (
        pretrained_seqrepr_path if use_pretrained_seqrepr else False
    ),
}

# INIT Methods from empty
# TRAIN['pretrained_seqrepr_path'] = 'model/LuohuCorpus/char/token_subcomp_pinyin_pos/EMBED.token.subcomp.pinyin.pos.200-INDEP.TRat_BI-MIXlstm2.ME_BI-MIXlstm1-INTER.TR_mean1.ME_BI-MIXlstm2.400/MASKLM.11153' 

############## resume flag and (+) batch size
TRAIN.update({
    # Continue training from the last breakpoint: if True, load the model
    # from the previous checkpoint and continue the training.
    # TODO: give it a path to load from before switching this to True.
    'continue_previous_trained_model': False,

    # (+) batch size
    'batchpoint_size_of_dp': 32,
})
##############


############## (+) logging, schedule and optimization settings
# Two presets; which one is applied depends on whether `'lm' in model_type`
# holds. Keys are inserted in the order listed.
# NOTE(review): only the downstream preset defines 'lrdecaypoint_size_of_bp';
# confirm that nothing reads that key when training language models.
if 'lm' in model_type:
    # For Language Models
    TRAIN.update({
        'train_prop': 0.99,
        'epochpoint_num_in_tp': 10,  # epoch number

        # for eval
        'checkpoint_size_of_bp_in_eval': 100,

        # for train
        # NOTE(review): the original comment here repeated the model-saving
        # description; the downstream preset describes this key as the number
        # of batches between log entries -- confirm which applies.
        'checkpoint_size_of_bp': 50,
        # number of batches to save the models. (also include end epoch)
        'modelpoint_size_of_bp': 500,

        'peak_lr': 0.01,
        'lr_warm_up': True,
        'lr_warm_up_steps': 2000,
        'lr_decay_rate': 0.1,  # decay if current val_loss > best_val_loss
        'num_no_improvements_for_lr_decay': 1,

        ############## (+) optimization method
        'optim_method': 'adam',

        ############## other bp hyperparameters
        'clip_grad': 0,
        'momentum': 0,
        'l2': 1e-8,
    })
else:
    # For Downstream Tasks
    TRAIN.update({
        'train_prop': 0.80,
        'epochpoint_num_in_tp': 50,  # epoch number

        # for eval
        # NOTE(review): the original note read "check model at the end epoch",
        # which does not obviously match an integer value -- confirm meaning.
        'checkpoint_size_of_bp_in_eval': 5,

        # for train
        'checkpoint_size_of_bp': 20,  # number of batches to log error and other things.
        'modelpoint_size_of_bp': 'End Epoch',  # check model at the end epoch
        # presumably: run the lr-decay check at the end of each epoch -- confirm
        'lrdecaypoint_size_of_bp': 'End Epoch',

        'peak_lr': 0.002,
        # NOTE(review): two orders of magnitude smaller than the LM preset's
        # 0.1 -- confirm this is intended.
        'lr_decay_rate': 0.001,
        'lr_warm_up': False,
        'lr_warm_up_steps': 2000,
        'num_no_improvements_for_lr_decay': 1,

        ############## (+) optimization method
        'optim_method': 'adam',

        ############## other bp hyperparameters
        'clip_grad': 0,
        'momentum': 0,
        'l2': 1e-8,
    })
##############

############## random_seed
# Only the seed value is recorded here; no RNG is seeded in this section.
TRAIN.update(random_seed=10)
##############

