import logging
import os

# Absolute path of the current user's home directory ('~' expanded once at import time).
homedir: str = os.path.expanduser('~')


def add_decoder_parameters_group(parser):
    """Register the "Decoder parameters" argument group on ``parser``.

    Adds:
      * ``--decoder``: one of ``crf`` / ``softmax``; no default, so it is
        ``None`` when not given.
      * ``--decoder_dropout``: float, default ``0.2``.
    """
    decoder_group = parser.add_argument_group(title="Decoder parameters")
    decoder_group.add_argument(
        '--decoder',
        choices=['crf', 'softmax'],
        help='Whether to use softmax or crf',
    )
    decoder_group.add_argument(
        '--decoder_dropout',
        type=float,
        default=0.2,
        help='Dropout value in the decoder.',
    )


def add_eval_parameters_group(parser):
    """Register the "Evaluation parameters" argument group on ``parser``.

    Adds:
      * ``--num_samples``: int, default 8 -- number of tasks averaged over
        during validation.
      * ``--num_queries``: int, default 128 -- number of test cases per task
        during validation.
    """
    group = parser.add_argument_group(title="Evaluation parameters")
    group.add_argument('--num_samples', type=int, default=8,
                       help='Number of tasks to average over during validation.')
    group.add_argument('--num_queries', type=int, default=128,
                       help='Number of test cases per task during validation.')


def add_bert_encoder_parameters_group(parser):
    """Register the "BERT encoder parameters" argument group on ``parser``.

    ``--bert_model_path`` (str) defaults to the value of the
    ``SM_CHANNEL_BERT`` environment variable, read when this function is
    called (``None`` if the variable is unset).
    """
    bert_default_path = os.environ.get('SM_CHANNEL_BERT')
    bert_group = parser.add_argument_group(title="BERT encoder parameters")
    bert_group.add_argument(
        '--bert_model_path',
        type=str,
        default=bert_default_path,
        help='Path to bert model',
    )


def add_fsl_parameters_group(parser):
    """Register the few-shot-learning task sampling argument group on ``parser``.

    Adds:
      * ``-k``: int, default 10 -- number of training examples in a task.
      * ``-n``: int, default 4 -- number of slots in a task, excluding "other".
    """
    # Group title previously read "campling"; fixed typo in user-facing help.
    group = parser.add_argument_group(title="FSL task sampling parameters")
    group.add_argument('-k', default=10, type=int, help='Number of training examples in a task.')
    group.add_argument('-n', default=4, type=int, help='Number of slots in a task, excluding "other".')


def add_meta_learning_optimizer_parameters_group(parser):
    """Register the "Meta learning optimization parameters" group on ``parser``.

    Adds:
      * ``--meta_gamma``: float, default 0.0 -- meta weight decay.
      * ``--meta_num_updates``: int, default 1024 -- steps per meta-epoch.
      * ``--meta_max_epochs``: int, default 40 -- max meta-training epochs.
      * ``--meta_learning_rate``: float, default 1e-4 -- base meta learning rate.
      * ``--meta_encoder_learning_rate``: float, default 1e-5 -- encoder
        learning rate for meta-learning.
    """
    group = parser.add_argument_group(title="Meta learning optimization parameters")
    group.add_argument('--meta_gamma', type=float, default=0.0, help='Meta weight decay (regularization)')
    group.add_argument('--meta_num_updates', type=int, default=1024, help='Number of steps per meta-epoch')
    group.add_argument('--meta_max_epochs', type=int, default=40, help='Maximum number of meta train epochs')
    group.add_argument('--meta_learning_rate', type=float, default=0.0001, help='Base learning rate for meta-learning')
    # Help text previously said "Encore"; fixed typo.
    group.add_argument('--meta_encoder_learning_rate', type=float, default=0.00001,
                       help='Encoder learning rate for meta-learning')


def add_optimizer_parameters_group(parser):
    """Register the "Optimization parameters" argument group on ``parser``.

    Adds:
      * ``--max_epochs``: int, default 10 -- max training epochs.
      * ``--num_updates``: int, default 128 -- steps per epoch.
      * ``--learning_rate``: float, default 1e-3 -- base learning rate.
      * ``--gamma``: float, default 0.0 -- weight decay.
      * ``--encoder_learning_rate``: float, default 1e-4 -- encoder fine-tuning
        rate; per the help text, 0 deactivates fine-tuning.
    """
    group = parser.add_argument_group(title="Optimization parameters")
    group.add_argument('--max_epochs', type=int, default=10, help='Maximum number of training epochs')
    group.add_argument('--num_updates', type=int, default=128, help='Number of steps per epoch')
    group.add_argument('--learning_rate', type=float, default=0.001, help='Base learning rate')
    group.add_argument('--gamma', type=float, default=0.0, help='Weight decay (regularization)')
    group.add_argument('--encoder_learning_rate', default=0.0001, type=float,
                       help='Finetune the encoder using this learning rate. Set to 0 to deactivate finetuning.')


def add_early_stopping_parameters_group(parser):
    """Register the "Early stopping parameters" argument group on ``parser``.

    Adds ``--delta`` (float, default 0.0) and ``--patience`` (int, default 5).
    """
    stopping_group = parser.add_argument_group(title="Early stopping parameters")
    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        ('--delta', float, .0, 'Minimum delta to consider it an improvement'),
        ('--patience', int, 5, 'Number of non improving epochs before stopping'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        stopping_group.add_argument(flag, type=value_type, default=default_value, help=help_text)


def add_reptile_parameters_group(parser):
    """Register the "Reptile parameters" argument group on ``parser``.

    Adds ``--num_steps`` (int, default 5).
    """
    parser.add_argument_group(title="Reptile parameters").add_argument(
        '--num_steps',
        type=int,
        default=5,
        help='Number of steps for the first order approximation',
    )


def add_computation_parameters_group(parser):
    """Register the "Computation parameters" argument group on ``parser``.

    Adds ``--cpu`` (boolean flag, default False), ``--batch_size``
    (int, default 64) and ``--seed`` (int, default 42).
    """
    computation_group = parser.add_argument_group(title="Computation parameters")
    arguments = [
        ('--cpu', dict(action='store_true', help='whether using CPU')),
        ('--batch_size', dict(type=int, default=64, help='Number of sentences in each batch')),
        ('--seed', dict(type=int, default=42, help='Seed')),
    ]
    for flag, options in arguments:
        computation_group.add_argument(flag, **options)


def add_logging_parameters_group(parser):
    group = parser.add_argument_group(title="Logging parameters")
    group.add_argument('--log_file', type=str, default=None, help='Path to the log file')
    group.add_argument('--log_validation', type=bool, default=False, help='Log finetuning during validation')
    group.add_argument('--verbose', type=int, default=logging.INFO, help='Verbosity of logging')


def add_read_parameters_group(parser):
    """Register the "Read parameters" argument group on ``parser``.

    Adds:
      * ``--alphabets_folder``: str, defaults to env ``SM_CHANNEL_ALPHABET``.
      * ``--max_token_length``: int, default 40 (padding length).
      * ``--trained_model``: str, defaults to env ``SM_CHANNEL_MODEL``.

    Environment defaults are read when this function is called and are
    ``None`` if the variable is unset.
    """
    alphabet_dir = os.environ.get('SM_CHANNEL_ALPHABET')
    model_path = os.environ.get('SM_CHANNEL_MODEL')

    read_group = parser.add_argument_group(title="Read parameters")
    read_group.add_argument(
        '--alphabets_folder', type=str, default=alphabet_dir,
        help='Folder to read alphabets files',
    )
    read_group.add_argument(
        '--max_token_length', type=int, default=40,
        help='Padding length',
    )
    read_group.add_argument(
        '--trained_model', type=str, default=model_path,
        help='Trained model',
    )


def add_write_parameters_group(parser):
    """Register the "Output and model directories" argument group on ``parser``.

    Adds ``--output_folder`` (default: env ``SM_OUTPUT_DATA_DIR``) and
    ``--model_folder`` (default: env ``SM_MODEL_DIR``), both str. The
    environment is read when this function is called; unset variables
    give ``None``.
    """
    write_group = parser.add_argument_group(title="Output and model directories")
    # (flag, environment variable supplying the default, help text)
    for flag, env_var, help_text in (
        ('--output_folder', 'SM_OUTPUT_DATA_DIR', 'Folder to store the learned weights'),
        ('--model_folder', 'SM_MODEL_DIR', 'Folder to read or store the learned weights'),
    ):
        write_group.add_argument(flag, type=str, default=os.environ.get(env_var), help=help_text)
