from typing import List, Optional
from dataclasses import dataclass, field

import os
import wandb
import math
import time
import logging

from transformers import MultiLingAdapterArguments, TrainingArguments

# Accepted values for the part of a switch input that precedes the first ":".
SWITCH_INPUT_CHOICES = ["minimal", "pfeiffer"]

logger = logging.getLogger(__name__)


# Task name -> (dataset, configuration) pair.
# NOTE(review): presumably the positional arguments for `datasets.load_dataset`;
# confirm against the callers.
LOAD_DATASET_ARGS = {
    glue_task: ("glue", glue_task)
    for glue_task in (
        "rte", "qqp", "cola", "mnli", "mrpc", "qnli", "sst2", "stsb",
    )
}
LOAD_DATASET_ARGS["trivia_qa"] = ("trivia_qa", "rc")


# Task names listed in the help text of the `task_name` argument.
TASKS = {
    "rte",
    "qqp",
    "cola",
    "mnli",
    "mrpc",
    "qnli",
    "sst2",
    "stsb",
    "squad",
}

# Task name -> evaluation metric stored in `metric_for_best_model`.
METRIC_FOR_BEST_MODEL = dict(
    mnli="eval_accuracy",
    qqp="eval_f1",
    qnli="eval_accuracy",
    sst2="eval_accuracy",
    cola="eval_matthews_correlation",
    stsb="eval_spearmanr",
    mrpc="eval_f1",
    rte="eval_accuracy",
    wnli="eval_accuracy",
    squad="eval_f1",
)


@dataclass
class SwitchArgsMixin:
    """Options controlling adapter switches, to be mixed into an arguments dataclass.

    ``__post_init__`` calls ``super().__post_init__()`` and may read ``seed``,
    ``task_name``, ``learning_rate`` and ``num_train_epochs`` from ``self``, so
    these must be supplied by the classes this mixin is combined with.
    """

    # Check directly with wandb.
    check_switches_from_wandb: bool = False

    # Put adapters at some fixed locations.
    adapters_at: List[int] = field(default_factory=list)

    # Drop the skip-connections of adapters.
    adapter_drop_skip_connections: bool = False
    adapter_drop_skip_connections_training_only: bool = False
    switch_drop_skip_connections: bool = False

    # Dropout rate.
    switch_dropout_rate: float = 0.0

    # Temperature control. When `temp_r` is left unset it is derived from the
    # other temperature values and the number of epochs in `__post_init__`.
    temp_N: int = 1
    temp_r: float = None
    temp_initial: float = 0.1
    temp_min: float = 0.1

    # Where to put switches.
    switches_at: List[int] = field(default_factory=list)

    # Fixed switch positions.
    fixed_configuration: List[int] = None

    # Fix last and first n layers.
    adapter_last_layers: int = None
    adapter_first_layers: int = None
    adapter_from_layer: int = None
    adapter_to_layer: int = None

    # If switches are used they use the same inputs.
    use_switches: bool = False
    switch_inputs: List[str] = field(default_factory=list)

    # Another way to define switches: one flag per layer index (0-11).
    # `__post_init__` keeps these flags and `switches_at` in sync.
    switch_at_0: bool = False
    switch_at_1: bool = False
    switch_at_2: bool = False
    switch_at_3: bool = False
    switch_at_4: bool = False
    switch_at_5: bool = False
    switch_at_6: bool = False
    switch_at_7: bool = False
    switch_at_8: bool = False
    switch_at_9: bool = False
    switch_at_10: bool = False
    switch_at_11: bool = False

    # Use a switch regularization.
    switch_regularization: str = None
    switch_regularization_weight: float = 0.01
    switch_regularization_bias: float = None
    switch_regularization_inputs_costs: List[float] = field(default_factory=list)

    # Default adapter.
    default_adapter: str = "rational"

    # Fix the switches
    fix_rational_switch: bool = False

    # Probability for soft fixed
    prob_for_soft_fixed: float = 0.9

    # Learning rate for probabilities.
    lr_for_switches: float = 0.05

    # Probability regularization weight.
    prob_reg_weight: float = 0.0
    prob_reg_power: float = 0.5

    # Learning rate for the rational adapters.
    lr_for_rational_activations: float = 0.01

    # Replace rational by identity.
    rational_adapter_non_linearity: str = "identity"
    default_adapter_non_linearity: str = "rational:one"

    # Limit the total number of input_1 selections
    limit_input_1: bool = False
    limit_input_1_after: int = None
    limit_input_1_after_weight: float = 0.1
    limit_input_1_after_scale: float = 10.0

    # Simple regularization.
    simple_regularization_weight: float = None

    def __post_init__(self):
        """Validate the switch options and fill in the derived ones.

        Raises:
            ValueError: If switches are enabled without `switch_inputs`, with
                an unknown switch input, or with a number of regularization
                costs that differs from the number of switch inputs.
            AssertionError: If the wandb lookup does not match exactly one
                run, the regularization weight is missing or zero, or the
                fixed configuration does not have 12 entries.
        """
        super().__post_init__()

        # Optionally recover the adapter positions from a previous wandb run
        # that matches the current configuration.
        if self.check_switches_from_wandb and len(self.adapters_at) == 0:
            filters = {
                'config.seed': self.seed,
                'config.task_name': self.task_name,
                'config.baseline': False,
                'config.baseline_bert': False,
                'config.baseline_leave_out_all': False,
                'config.adapter_drop_skip_connections': self.adapter_drop_skip_connections,
                'config.switch_regularization': str(self.switch_regularization)
            }
            api = wandb.Api()
            runs = api.runs(os.environ['WANDB_PROJECT'], filters=filters)
            assert len(runs) == 1, "We expect only one run with this configuration."
            hist = runs[0].history()
            for i in range(12):
                tag = f'train/layer.{i}.prob.1'
                # A layer gets an adapter when the last logged (non-NaN)
                # value of its series is above 0.5.
                if hist[tag][hist[tag].notna()].iloc[-1] > 0.5:
                    self.adapters_at.append(i)

        if self.use_switches:
            if len(self.switch_inputs) == 0:
                raise ValueError(
                    "Please provide the inputs to the switches with `switch_inputs`"
                )

            # Only the part before the first ":" is checked here.
            for switch_input in self.switch_inputs:
                if switch_input.split(":")[0] not in SWITCH_INPUT_CHOICES:
                    raise ValueError("Incorrect switch options")

            if self.switch_regularization is not None:
                a = len(self.switch_regularization_inputs_costs)
                b = len(self.switch_inputs)
                if a != b:
                    raise ValueError(
                        "The argument switch_regularization_inputs_costs "
                        f"should have {b} values."
                    )
                assert self.switch_regularization_weight, \
                    "Specify a regularization weight."

        # Build the 12-entry fixed configuration (1 = adapter, 0 = none) from
        # the first shortcut option that is set.
        if self.adapter_first_layers is not None:
            n = self.adapter_first_layers
            self.fixed_configuration = [1] * n + [0] * (12 - n)

        elif self.adapter_last_layers is not None:
            n = self.adapter_last_layers
            self.fixed_configuration = [0] * (12 - n) + [1] * n

        elif self.adapter_from_layer is not None and self.adapter_to_layer is not None:
            # Both bounds are inclusive.
            a = self.adapter_from_layer
            b = self.adapter_to_layer
            self.fixed_configuration = [0] * a + [1] * (b - a + 1) + [0] * (12 - b - 1)

        if self.fixed_configuration is not None:
            assert len(self.fixed_configuration) == 12, "We need 12 positions."

        # Unset learning rates fall back to the main learning rate.
        if self.lr_for_switches is None:
            self.lr_for_switches = self.learning_rate

        if self.lr_for_rational_activations is None:
            self.lr_for_rational_activations = self.learning_rate

        # Default value for temp_r.
        # NOTE(review): presumably the rate of an exponential decay that goes
        # from temp_initial to temp_min over half of the epochs - confirm.
        if self.temp_r is None:
            n = self.num_train_epochs / 2
            self.temp_r = -math.log(self.temp_min / self.temp_initial) / n

        for i in range(12):
            # If any flag is on put its layer index in the list.
            if getattr(self, f"switch_at_{i}") and i not in self.switches_at:
                self.switches_at.append(i)

            # Turn on the flags based on the list.
            setattr(self, f"switch_at_{i}", i in self.switches_at)


@dataclass
class ModelArguments:
    """
    Options that select the pretrained model, config and tokenizer to
    fine-tune, together with the baseline flags.
    """

    baseline: bool = False
    baseline_bert: bool = False
    baseline_leave_out_all: bool = False

    model_name_or_path: str = field(
        default="bert-base-uncased",
        metadata=dict(
            help="Path to pretrained model or model identifier from "
                 "huggingface.co/models"
        ),
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained config name or path if not the same as model_name"
        ),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained tokenizer name or path if not the same as model_name"
        ),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Path to directory to store the pretrained "
                 "models downloaded from huggingface.co"
        ),
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata=dict(
            help="Whether to use one of the fast tokenizer "
                 "(backed by the tokenizers library) or not."
        ),
    )
    model_revision: str = field(
        default="main",
        metadata=dict(
            help="The specific model version to use "
                 "(can be a branch name, tag name or commit id)."
        ),
    )
    use_auth_token: bool = field(
        default=False,
        metadata=dict(
            help="Will use the token generated when running "
                 "`transformers-cli login` "
                 "(necessary to use this script with private models)."
        ),
    )


@dataclass
class SimpleDataTrainingArguments:
    """Optional caps on the number of train, validation and test examples."""

    max_train_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="For debugging purposes or quicker training, "
                 "truncate the number of training examples to this value if set."
        ),
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="For debugging purposes or quicker training, "
                 "truncate the number of validation examples to this value if set."
        ),
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="For debugging purposes or quicker training, "
                 "truncate the number of test examples to this value if set."
        ),
    )


@dataclass
class BaseArgs(
    SwitchArgsMixin,
    MultiLingAdapterArguments,
    ModelArguments,
    SimpleDataTrainingArguments,
    TrainingArguments
):
    """Combine the switch, adapter, model, data and training arguments.

    Adds the task name and a few extra options, and changes some defaults of
    fields that are declared by the parent classes.
    """

    # Data arguments:
    task_name: str = field(
        default=None,
        metadata=dict(
            # TASKS is a set; sort it so the help text is the same on every run.
            help="The name of the task to train on: " + ", ".join(sorted(TASKS))
        )
    )

    # Save the rational plots.
    save_rational_plots: bool = False

    # Use the validation split for testing.
    low_resources: int = None

    # Some extra defaults.
    load_best_model_at_end: bool = True
    num_train_epochs: int = 10
    learning_rate: float = 1e-4
    evaluation_strategy: str = "epoch"
    save_total_limit: int = 2
    # NOTE(review): SwitchArgsMixin annotates this field as List[float];
    # confirm that restricting it to List[int] here is intended.
    switch_regularization_inputs_costs: List[int] = field(
        default_factory=lambda: [0, 2]
    )
    switch_inputs: List[str] = field(
        default_factory=lambda: ['minimal:identity', 'pfeiffer:rational:one']
    )

    def __post_init__(self):
        """Pick the best-model metric for the task, then run the parents' post-init."""
        # Set the metric first so the parents' post-init already sees it.
        if self.task_name in METRIC_FOR_BEST_MODEL:
            self.metric_for_best_model = METRIC_FOR_BEST_MODEL[self.task_name]
        super().__post_init__()


@dataclass
class QAArgs(BaseArgs):
    """Arguments for the question-answering scripts, on top of `BaseArgs`."""

    dataset_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="The name of the dataset to use (via the datasets library)."
        ),
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="The configuration name of the dataset "
                 "to use (via the datasets library)."
        ),
    )
    train_file: Optional[str] = field(
        default=None,
        metadata=dict(help="The input training data file (a text file)."),
    )

    validation_file: Optional[str] = field(
        default=None,
        metadata=dict(
            help="An optional input evaluation data file "
                 "to evaluate the perplexity on (a text file)."
        ),
    )
    test_file: Optional[str] = field(
        default=None,
        metadata=dict(
            help="An optional input test data file "
                 "to evaluate the perplexity on (a text file)."
        ),
    )
    overwrite_cache: bool = field(
        default=False,
        metadata=dict(help="Overwrite the cached training and evaluation sets"),
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata=dict(
            help="The number of processes to use for the preprocessing."
        ),
    )
    max_seq_length: int = field(
        default=384,
        metadata=dict(
            help="The maximum total input sequence length after tokenization. "
                 "Sequences longer than this will be truncated, "
                 "sequences shorter will be padded."
        ),
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata=dict(
            help="Whether to pad all samples to `max_seq_length`. "
                 "If False, will pad the samples dynamically when batching "
                 "to the maximum length in the batch "
                 "(which can be faster on GPU but will be slower on TPU)."
        ),
    )

    version_2_with_negative: bool = field(
        default=False,
        metadata=dict(
            help="If true, some of the examples do not have an answer."
        ),
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata=dict(
            help="The threshold used to select the null answer: "
                 "if the best answer has a score that is less than the score "
                 "of the null answer minus this threshold, "
                 "the null answer is selected for this example. "
                 "Only useful when `version_2_with_negative=True`."
        ),
    )
    doc_stride: int = field(
        default=128,
        metadata=dict(
            help="When splitting up a long document into chunks, "
                 "how much stride to take between chunks."
        ),
    )
    n_best_size: int = field(
        default=20,
        metadata=dict(
            help="The total number of n-best predictions "
                 "to generate when looking for an answer."
        ),
    )
    max_answer_length: int = field(
        default=128,
        metadata=dict(
            help="The maximum length of an answer that can be generated. "
                 "This is needed because the start and end predictions "
                 "are not conditioned on one another."
        ),
    )

    def __post_init__(self):
        """Resolve the dataset, expand the output directory and check the data files.

        Raises:
            ValueError: If neither a dataset name nor any data file is given.
            AssertionError: If a given data file is not a csv or json file.
        """
        # The squad task always uses the dataset of the same name.
        if self.task_name == 'squad':
            self.dataset_name = 'squad'

        # Replace the "%t" placeholder of the output directory by a timestamp.
        timestamp = int(1e7 * time.time())
        self.output_dir = self.output_dir.replace("%t", str(timestamp))

        data_files = {
            'train_file': self.train_file,
            'validation_file': self.validation_file,
            'test_file': self.test_file,
        }
        nothing_given = all(path is None for path in data_files.values())
        if self.dataset_name is None and nothing_given:
            raise ValueError(
                "Need either a dataset name or a training/validation file/test_file."
            )

        # Every file that was given must end in ".csv" or ".json".
        for arg, path in data_files.items():
            if path is None:
                continue
            extension = path.rsplit(".", 1)[-1]
            assert extension in ("csv", "json"), \
                f"`{arg}` should be a csv or a json file."

        super().__post_init__()


@dataclass
class GLUEArgs(BaseArgs):
    """Arguments for the GLUE scripts, on top of `BaseArgs`."""

    max_seq_length: int = field(
        default=128,
        metadata=dict(
            help="The maximum total input sequence length after tokenization. "
                 "Sequences longer than this will be truncated, "
                 "sequences shorter will be padded."
        ),
    )
    overwrite_cache: bool = field(
        default=False,
        metadata=dict(help="Overwrite the cached preprocessed datasets or not."),
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata=dict(
            help="Whether to pad all samples to `max_seq_length`. "
                 "If False, will pad the samples dynamically when batching "
                 "to the maximum length in the batch."
        ),
    )

    # Shuffle the samples
    shuffle_samples: bool = False

    def __post_init__(self):
        """Expand the output directory and normalise and check the task name.

        Raises:
            ValueError: If `task_name` is unset or not one of the known tasks.
        """
        # Replace the "%t" placeholder of the output directory by a timestamp.
        timestamp = int(1e7 * time.time())
        self.output_dir = self.output_dir.replace("%t", str(timestamp))

        # Known tasks with the names of their one or two text columns.
        # Only the task names are used in this method.
        task_to_keys = dict(
            cola=("sentence", None),
            mnli=("premise", "hypothesis"),
            mrpc=("sentence1", "sentence2"),
            qnli=("question", "sentence"),
            qqp=("question1", "question2"),
            rte=("sentence1", "sentence2"),
            sst2=("sentence", None),
            stsb=("sentence1", "sentence2"),
            wnli=("sentence1", "sentence2"),
        )

        # Task names are matched case-insensitively.
        if self.task_name is not None:
            self.task_name = self.task_name.lower()

        is_known_task = self.task_name in task_to_keys
        if not is_known_task:
            raise ValueError(
                "Unknown task_name, you should pick one in " + ",".join(task_to_keys)
            )

        super().__post_init__()
