from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments

@dataclass
class FewShotTrainingArguments(TrainingArguments):
    """Training arguments for the few-shot experiments.

    Extends ``TrainingArguments`` with options for prompt tuning, soft PET,
    clustering-based losses, pruning, and stochastic weight averaging (SWA).

    Help strings are built by implicit string-literal concatenation. Every
    fragment except the last therefore ends with a space, so the rendered
    ``--help`` text does not run words together.
    """
    # Prompt-tuning parameters.
    label_embeddings_as_centroids: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, uses label embeddings as centroids."},
    )
    prompt_tune: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, adds prompt tokens to the input and only tunes them."},
    )
    prompt_length: Optional[int] = field(
        default=20,
        metadata={"help": "Sets the number of tokens for prompt-tuning."},
    )
    init_prompt_from_vocab: Optional[bool] = field(
        default=True,
        metadata={"help": "If set, initializes the prompt tokens' embeddings "
                          "from the given pretrained model's vocabulary."},
    )
    prompt_init_range: Optional[float] = field(
        default=1e-4,
        metadata={"help": "Defines the initialization range."},
    )
    # Soft pet parameters.
    mask_position: Optional[str] = field(
        default=None,
        metadata={"help": "Defines the position of the mask in case of "
                          "having two sentences: `0`: [p,h,m],[]  `1`: [p,m,h],[]  "
                          "`2`: [p],[m,h]  `3`: [p],[h,m]"},
    )
    compute_time: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes the training time."},
    )
    compute_inference_time: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes the inference time."},
    )
    compute_memory: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes the memory."},
    )
    train_classifier: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, trains a classifier."},
    )
    vectorize_pet: Optional[bool] = field(
        default=False,
        metadata={"help": "If set and feasible (the lengths of the verbalizers are the same), "
                          "vectorizes the pet."},
    )
    with_augmentation: Optional[bool] = field(
        default=True,
        metadata={"help": "If set, augments the data for COPA."},
    )
    multiclass_ce_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, uses the multiclass cross-entropy loss."},
    )
    token_hinge_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes a multi-class classification hinge loss over the tokens."},
    )
    prototypical_eval: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, uses the prototypical evaluation during the inference."},
    )
    prototypical_similarity: Optional[str] = field(
        default="cos",
        metadata={"help": "This can be `cos` for cosine similarity or `euc` for euclidean one."},
    )
    token_classifier: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, requires having a classifier per token."},
    )
    extra_embd_initializer_range: Optional[float] = field(
        default=0.02,
        metadata={"help": "Defines the initialization range for the extra embedding added."},
    )
    train_in_batch: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, trains the model in batches."},
    )
    best_verbalizers_dir: Optional[str] = field(
        default="best_verbalizers/balanced/",
        metadata={"help": "Defines the path to the best verbalizers."},
    )
    find_best_verbalizers: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes the best verbalizers for "
                          "predicting examples of each label."},
    )
    decoding_strategy: Optional[str] = field(
        default="default",
        metadata={"help": "This can be `default` or `parallel`: to feed in the input "
                          "with masks only once to the encoder."},
    )
    soft_mask_labels_learning_rate: Optional[float] = field(default=1e-5)
    soft_pet_temperature: Optional[float] = field(
        default=1,
        metadata={"help": "temperature."},
    )
    eval_soft_pet_aggregation: Optional[str] = field(
        default="sum",
        metadata={"help": "Defines the aggregation used during evaluation."},
    )
    soft_pet_aggregation: Optional[str] = field(
        default=None,
        metadata={"help": "Defines the aggregation operation for the losses."},
    )
    extra_without_original: Optional[bool] = field(
        default=True,
        metadata={"help": "If set, does not consider the "
                          "total vocabulary and only considers the extra tokens given for optimization."},
    )
    extra_tokens_init: Optional[str] = field(
        default="tokens",
        metadata={"help": "`tokens`: initialize from random tokens, "
                          "`random`: initialize randomly, `verbalizers`: initialize from verbalizers."},
    )
    init_from_best_verbalizers: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, initializes from the best verbalizers."},
    )
    num_extra_tokens: Optional[int] = field(
        default=-1,
        metadata={"help": "Defines the number of extra tokens to be considered as "
                          "verbalizers in case of the `extra_tokens` option; if this is -1, "
                          "it is computed from the length of the verbalizers."},
    )
    soft_pet: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, computes the loss of the PET in the soft way "
                          "by minimizing the embeddings of the tokens."},
    )
    # Clustering parameters.
    verbalizers_with_special_tokens: Optional[bool] = field(
        default=True,
        metadata={"help": "By default, we add special tokens to the verbalizers."},
    )
    use_masks_embeddings: Optional[bool] = field(
        default=True,
        metadata={"help": "If set, only uses the mask "
                          "tokens' embeddings."},
    )
    hinge_p: Optional[int] = field(
        default=1,
        metadata={"help": "Defines the p for the hinge loss, can be 1 or 2."},
    )
    hinge_margin: Optional[int] = field(
        default=1,
        metadata={"help": "Defines the margin for the hinge loss."},
    )
    assignment_regularization: Optional[str] = field(
        default=None,
        metadata={"help": "We can regularize the soft assignment matrix in case "
                          "of using the automatic way of computing centroids. This can be [`l0`, `l1`]."},
    )
    assignment_regularization_weight: Optional[float] = field(
        default=1e-2,
        metadata={"help": "Defines the weight for the regularization term."},
    )
    clustering_params_learning_rate: Optional[float] = field(
        default=1e-2,
        metadata={"help": "Defines the learning rate for the uninitialized parameters "
                          "inside the clustering loss."},
    )
    clustering_logits_projection: Optional[str] = field(
        default=None,
        metadata={"help": "Defines a projection layer to be applied on clustering logits. "
                          "This can be None, linear, mlp."},
    )
    verbalizer_pooling_type: Optional[str] = field(
        default="mean",
        metadata={"help": "Defines the way to pool the token embeddings of verbalizers. "
                          "It can be [`mean`, `last`, `first`]."},
    )
    input_pooling_type: Optional[str] = field(
        default="last",
        metadata={"help": "Defines the way to pool the token embeddings of the inputs. "
                          "It can be [`mean`, `last`]."},
    )
    verbalizer_embd_type: Optional[str] = field(
        default="embd",
        metadata={"help": "It can be `embd` to compute the verbalizer embeddings from the "
                          "input embedding layer, `enc` to compute their embeddings from the encoder "
                          "output layer, or `automatic` to compute them from the whole vocabulary. "
                          "Note that in case of using `embd` we only use `mean` pooling and not "
                          "`last` pooling."},
    )
    centroids_type: Optional[str] = field(
        default="verbalizers_embd",
        metadata={"help": "Defines the way centroids are modeled; this can be `parametric` for "
                          "modeling centroids as parameters, `cte` to consider constant, randomly "
                          "initialized centroids, and `verbalizers_embd` to compute the centroids "
                          "based on the embeddings of the given verbalizers in each step."},
    )
    clustering_loss: Optional[str] = field(
        default=None,
        metadata={"help": "If not None, uses the clustering losses. It can be `ce` for "
                          "cross-entropy, `ncc`: for nearest-centroid classifier loss, `mul`: for "
                          "multiplying the input embeddings by the centroids embeddings and then "
                          "computing cross-entropy loss, `hinge`: computing a hinge loss, in a way "
                          "we minimize the CE with the correct labels and maximize it with the "
                          "incorrect labels with some margins, and `multi-hinge`: multi-class hinge "
                          "loss."},
    )
    verbalizer_ensemble: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, considers a weight matrix for computing an ensemble of verbalizer "
                          "candidates corresponding to each label to compute their centroids."},
    )
    weights_activation: Optional[str] = field(
        default="sigmoid",
        metadata={"help": "Defines the activation to be applied on the weights used to compute the "
                          "ensemble of verbalizer candidates. This can be ['relu', 'silu', 'swish', "
                          "'gleu', 'tanh', 'gelu_new', 'gelu_fast', 'quick_gelu', 'mish', 'linear', "
                          "'sigmoid'] or None."},
    )
    init_centroids: Optional[bool] = field(
        default=False,
        metadata={"help": "In case of using `parametric` or `cte` centroids, it can initialize them "
                          "by computing centroids of the candidate verbalizers."},
    )
    # Pruning's arguments.
    regularization: Optional[str] = field(
        default=None,
        metadata={"help": "Add L0 or L1 regularization to the mask scores."},
    )
    final_lambda: Optional[float] = field(
        default=0.0,
        metadata={"help": "Regularization intensity (used in conjunction with `regularization`)."},
    )
    mask_scores_learning_rate: Optional[float] = field(
        default=1e-2,
        metadata={"help": "The Adam initial learning rate of the mask scores."},
    )
    initial_threshold: Optional[float] = field(
        default=1.0,
        metadata={"help": "Initial value of the threshold (for scheduling)."},
    )
    final_threshold: Optional[float] = field(
        default=0.7,
        metadata={"help": "Final value of the threshold (for scheduling)."},
    )
    initial_warmup: Optional[int] = field(
        default=1,
        metadata={"help": "Run `initial_warmup` * `warmup_steps` steps of threshold warmup "
                          "during which threshold stays at its `initial_threshold` value (sparsity schedule)."},
    )
    final_warmup: Optional[int] = field(
        default=2,
        metadata={"help": "Run `final_warmup` * `warmup_steps` steps of threshold cool-down "
                          "during which threshold stays at its final_threshold value (sparsity schedule)."},
    )
    pruning_method: Optional[str] = field(
        default=None,
        metadata={"help": "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning,"
                          " topK = Movement pruning, sigmoied_threshold = Soft movement pruning), "
                          "None = not applying pruning."},
    )
    mask_init: Optional[str] = field(
        default="constant",
        metadata={"help": "Initialization method for the mask scores. Choices: constant, uniform, kaiming."},
    )
    mask_scale: Optional[float] = field(
        default=0.0,
        metadata={"help": "Initialization parameter for the chosen initialization method."},
    )
    # Parameters to use stochastic weight averaging.
    swa_alpha: Optional[float] = field(
        default=0,
        metadata={"help": "Defines the weight of the pretrained model when averaging "
                          "between the pretrained model's weight and the finetuned one as "
                          "swa_alpha*pretrained_model+(1-swa_alpha)*finetuned_model. Set 0 to "
                          "use default. This needs to be a number between (0, 1)."},
    )
    swa: Optional[bool] = field(
        default=False,
        metadata={"help": "If specified, uses the stochastic weight averaging optimizer."},
    )
    swa_start_step: Optional[int] = field(
        default=0,
        metadata={"help": "Defines the starting step "
                          "to start computing the average of the pretrained model and the model; "
                          "this is 0 by default."},
    )
    swa_lr: Optional[float] = field(
        default=0.05,
        metadata={"help": "Defines the swa learning rate for the scheduler."},
    )
    swa_annealing_epochs: int = field(
        default=5,
        metadata={"help": "Number of epochs in the annealing phase for swa scheduler."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # Trailing space: the two literals are concatenated implicitly.
            "help": "The model checkpoint for weights initialization. "
                    "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from gpt2, gpt2-large, gpt2-medium"},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override some existing default config settings when a model is trained from scratch. Example: "
                    "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                    "with private models)."
        },
    )

    def __post_init__(self):
        """Validate mutually exclusive options after dataclass initialization.

        Raises:
            ValueError: if ``config_overrides`` is given together with
                ``config_name`` or ``model_name_or_path``. Per the help text,
                overrides are only meant for training from scratch.
        """
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    new_sampling: Optional[bool] = field(default=True, metadata={"help": "If set, uses the data sampling inside the codes."})
    add_masks: Optional[bool] = field(default=True, metadata={"help": "If set, adds mask tokens to the input."})
    no_pattern: Optional[bool] = field(default=False, metadata={"help": "If set, removes the patterns."})
    # TODO: the following three arguments are to be cleaned up later on.
    data_dir: Optional[str] = field(default=None, metadata={"help": "Specifies the data directory."})
    data_seed: Optional[int] = field(default=100, metadata={"help": "Specifies the seed used to sample the data."})
    K: Optional[int] = field(default=16, metadata={"help": "Specifies the number of training samples."})
    task: Optional[str] = field(
        default=None,
        # Implicit concatenation instead of a backslash continuation inside the
        # literal, which would embed the next line's indentation in the text.
        metadata={"help": "In case of passing the training files, it is additionally "
                          "required to pass the name of the task."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                    "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
                    "value if set."
        },
    )
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={"help": "Maximum input sequence length, taking special tokens into account."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pattern_id: Optional[int] = field(
        default=0, metadata={"help": "Defines a zero-based pattern index from the four available patterns."}
    )
    separate_mask_and_inputs: bool = field(
        default=False,
        metadata={"help": "If selected, encode mask inputs separately."}
    )
    

@dataclass
class AdapterArguments:
    """
    Options for adapter tuning and for choosing which model parameters train.

    The first group configures the adapters themselves: whether they are
    tuned, layer norms around them, nonlinearity, reduction factors, and
    placement after attention / feed-forward. The second group selects which
    parts of the model are frozen or left trainable (embeddings, biases, LM
    head, layer norms).
    """
    # Adapter configuration.
    adapter_tune: Optional[bool] = field(
        default=False,
        metadata={"help": "If set to true, tune adapters."},
    )
    add_layer_norm_before_adapter: Optional[bool] = field(default=False)
    add_layer_norm_after_adapter: Optional[bool] = field(default=False)
    nonlinearity: Optional[str] = field(default="gelu_new")
    reduction_factor: Optional[int] = field(default=16)
    attn_reduction_factor: Optional[int] = field(default=16)
    tune_layernorms: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, tunes the layernorms."},
    )
    add_layer_norm_after_adapter_attn: Optional[bool] = field(default=False)
    add_layer_norm_before_adapter_attn: Optional[bool] = field(default=False)
    add_adapter_after_attention: Optional[bool] = field(default=True)
    add_adapter_after_feedforward: Optional[bool] = field(default=True)
    unfreeze_input_embd: Optional[bool] = field(
        default=False,
        metadata={"help": "If set unfreezes the input embeddings."},
    )
    # Attention-projection adapters.
    key_tune: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to tune adapters on attention keys."},
    )
    query_tune: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to tune adapters on attention queries."},
    )
    value_tune: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to tune adapters on attention values."},
    )
    remove_attn_upsampling: Optional[bool] = field(
        default=False,
        metadata={"help": "In case of adapters for attention, we can decide to remove the upsampling part."},
    )
    # Diagnostics and freezing/unfreezing of model parameters.
    print_params: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, prints all the parameters."},
    )
    freeze_embeddings: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, freezes the input embedding."},
    )
    freeze_model: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, freezes the model."},
    )
    tune_biases: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, tunes only biases."},
    )
    tune_lm_head: Optional[bool] = field(
        default=False,
        metadata={"help": "If set, tunes the lm-head also when tuning biases."},
    )
    unfreeze_clustering_params: Optional[bool] = field(default=False)
    
