from transformers import Trainer
from transformers.trainer_utils import PredictionOutput, speed_metrics


from typing import Optional, Dict, List
import re
import time
import math

import torch

from datasets import Dataset

from args import BaseArgs


class CustomTrainerMixin(Trainer):
    """Trainer that reports a ``<prefix>_required_params`` metric.

    The value is computed by :meth:`number_of_required_parameters` from the
    trainable parameters of ``self.model`` and is added to the metrics
    produced by :meth:`predict` and :meth:`evaluate`.
    """

    args: BaseArgs

    # Regular expressions to detect the layer a parameter belongs to.
    # Adapter weights: captures the layer index and the adapter name.
    _RE_ADAPTERS = re.compile(
        r"^.*layer\.(?P<layer>[0-9]+).*adapters\.(?P<adapter>.*?)\..*$"
    )

    # Switch logits: captures the layer index and the comma-separated list of
    # option names the switch chooses between.
    _RE_SWITCH = re.compile(
        r"^.*layer\.(?P<layer>[0-9]+)\.output.*switch_layer\.(?P<name>.*?)\.switch_logits$"
    )

    def number_of_required_parameters(self) -> int:
        """Count the trainable adapter parameters the model currently needs.

        Only parameters with ``requires_grad`` set are considered. Adapter
        parameters are grouped per layer and per adapter name. If a layer has
        a trainable switch, only the adapter whose name is selected by the
        arg-max of the switch logits is counted for that layer (zero if no
        adapter of that name was found); otherwise all adapters of the layer
        are counted.

        Returns:
            The total number of scalar parameters counted over all layers.
        """
        # layer index -> adapter name -> number of scalar parameters.
        per_layer = {}

        # layer index -> name currently selected by the switch of that layer.
        positions = {}

        for name, param in self.model.named_parameters():

            if not param.requires_grad:
                continue

            res = self._RE_ADAPTERS.match(name)
            if res:
                # Add the size of the current parameter.
                sizes = per_layer.setdefault(int(res["layer"]), {})
                adapter = res["adapter"]
                sizes[adapter] = sizes.get(adapter, 0) + param.numel()

            res = self._RE_SWITCH.match(name)
            if res:
                options = res["name"].split(",")
                selected = int(torch.argmax(param, dim=-1))
                positions[int(res["layer"])] = options[selected]

        # Iterate over the layers that were actually found instead of assuming
        # a fixed depth, so models of any size are counted in full. A layer
        # that only has a switch contributes nothing and needs no entry here.
        num_params = 0
        for layer, sizes in per_layer.items():
            if layer in positions:
                num_params += sizes.get(positions[layer], 0)
            else:
                num_params += sum(sizes.values())
        return num_params

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
    ) -> PredictionOutput:
        """Run the parent ``predict`` and add the required-parameter count.

        Args:
            test_dataset: Dataset forwarded to the parent ``predict``.
            ignore_keys: Forwarded unchanged to the parent ``predict``.
            metric_key_prefix: Prefix of the metric keys; the extra entry is
                stored under ``f"{metric_key_prefix}_required_params"``.

        Returns:
            The result of the parent ``predict`` with the extra metric added
            to its ``metrics`` mapping.
        """
        results = super().predict(
            test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

        # Add the required params in the test run.
        tag = f"{metric_key_prefix}_required_params"
        results.metrics[tag] = self.number_of_required_parameters()

        return results

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Run the parent ``evaluate`` and extend, log and save its metrics.

        Args:
            eval_dataset: Dataset forwarded to the parent ``evaluate``. When
                given, its length is recorded as ``<prefix>_samples``.
            ignore_keys: Forwarded unchanged to the parent ``evaluate``.
            metric_key_prefix: Prefix used for the metric keys.

        Returns:
            The metrics of the parent ``evaluate`` together with
            ``<prefix>_required_params`` (and ``<prefix>_samples`` if an
            explicit dataset was passed).
        """
        prefix = metric_key_prefix

        # Evaluate the position of the switches.
        metrics = {f"{prefix}_required_params": self.number_of_required_parameters()}

        if eval_dataset is not None:
            metrics[f"{prefix}_samples"] = len(eval_dataset)

        # Call the original evaluation loop. Its results are merged last, so
        # they win if the parent reports one of the keys set above.
        metrics.update(
            super().evaluate(
                eval_dataset, ignore_keys, metric_key_prefix=metric_key_prefix
            )
        )

        # Log and persist the combined metrics, including the extra entries.
        self.log(metrics)
        self.log_metrics("eval", metrics)
        self.save_metrics("eval", metrics)
        return metrics


class GLUETrainer(CustomTrainerMixin):
    """Trainer for GLUE runs; adds nothing on top of ``CustomTrainerMixin``."""

    pass


class QuestionAnsweringTrainer(CustomTrainerMixin):
    """Trainer whose metrics are computed on post-processed predictions.

    The raw output of the evaluation/prediction loop is passed through
    ``post_process_function`` before ``compute_metrics`` is applied.
    """

    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        """Create the trainer.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            eval_examples: Default examples handed to ``post_process_function``
                by :meth:`evaluate` when none are passed explicitly.
            post_process_function: Callable that turns the examples, the
                dataset and the raw predictions into the input of
                ``compute_metrics``.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function

    def _run_timed_loop(self, dataloader, description, ignore_keys, metric_key_prefix):
        """Run the loop with metric computation disabled and time it.

        Returns:
            A tuple ``(output, metrics)`` with the loop output and the speed
            metrics computed for ``metric_key_prefix``.
        """
        if self.args.use_legacy_prediction_loop:
            eval_loop = self.prediction_loop
        else:
            eval_loop = self.evaluation_loop

        # Temporarily disable metric computation, the callers do it themselves
        # on the post-processed predictions.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None

        # Compute runtime stats.
        start_time = time.time()

        try:
            output = eval_loop(
                dataloader,
                description=description,
                # No point gathering the predictions if there are no metrics,
                # otherwise we defer to self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
            )
        finally:
            # Only the restore belongs here: ``output`` does not exist if the
            # loop raised, and touching it would mask the original exception.
            self.compute_metrics = compute_metrics

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        metrics = speed_metrics(
            metric_key_prefix,
            start_time,
            num_samples=output.num_samples,
            num_steps=math.ceil(output.num_samples / total_batch_size),
        )
        return output, metrics

    @staticmethod
    def _prefix_keys(metrics, metric_key_prefix):
        """Rename, in place, every key lacking the ``<prefix>_`` prefix."""
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

    def evaluate(
            self,
            eval_dataset=None,
            eval_examples=None,
            ignore_keys=None,
            metric_key_prefix: str = "eval"
    ):
        """Evaluate the model and compute metrics on post-processed output.

        Args:
            eval_dataset: Dataset to evaluate; defaults to ``self.eval_dataset``.
            eval_examples: Examples passed to ``post_process_function``;
                defaults to ``self.eval_examples``.
            ignore_keys: Forwarded to the evaluation loop.
            metric_key_prefix: Prefix applied to every metric key.

        Returns:
            A dict with the speed metrics, the result of ``compute_metrics``
            and ``<prefix>_required_params``. The dict is empty when either
            ``post_process_function`` or ``compute_metrics`` is ``None``.
        """
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        output, metrics = self._run_timed_loop(
            eval_dataloader, "Evaluation", ignore_keys, metric_key_prefix
        )

        if self.post_process_function is not None and self.compute_metrics is not None:
            eval_preds = self.post_process_function(
                eval_examples, eval_dataset, output.predictions
            )
            metrics.update(self.compute_metrics(eval_preds))

            # Add the required params in the evaluation run.
            tag = f"{metric_key_prefix}_required_params"
            metrics[tag] = self.number_of_required_parameters()

            # Prefix all keys with metric_key_prefix + '_'
            self._prefix_keys(metrics, metric_key_prefix)

            self.log(metrics)
        else:
            metrics = {}

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA
            # (compile, execute times, ops, etc.)
            # Imported lazily so the module is only needed when this debug
            # branch is actually taken.
            import torch_xla.core.xla_model as xm
            import torch_xla.debug.metrics as met

            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(
            self.args, self.state, self.control, metrics
        )
        return metrics

    def predict(
        self,
        predict_dataset,
        predict_examples,
        ignore_keys=None,
        metric_key_prefix: str = "test"
    ):
        """Predict on a dataset and compute metrics on post-processed output.

        Args:
            predict_dataset: Dataset to run the prediction loop on.
            predict_examples: Examples passed to ``post_process_function``.
            ignore_keys: Forwarded to the prediction loop.
            metric_key_prefix: Prefix applied to every metric key.

        Returns:
            The raw loop output when either ``post_process_function`` or
            ``compute_metrics`` is ``None``. Otherwise a ``PredictionOutput``
            built from the post-processed predictions, whose metrics hold the
            speed metrics, the result of ``compute_metrics`` and
            ``<prefix>_required_params``.
        """
        predict_dataloader = self.get_test_dataloader(predict_dataset)

        output, metrics = self._run_timed_loop(
            predict_dataloader, "Prediction", ignore_keys, metric_key_prefix
        )

        if self.post_process_function is None or self.compute_metrics is None:
            return output

        predictions = self.post_process_function(
            predict_examples, predict_dataset, output.predictions, "predict"
        )
        metrics.update(self.compute_metrics(predictions))

        # Add the required params in the test run.
        tag_required_params = f"{metric_key_prefix}_required_params"
        metrics[tag_required_params] = self.number_of_required_parameters()

        # Prefix all keys with metric_key_prefix + '_'
        self._prefix_keys(metrics, metric_key_prefix)

        return PredictionOutput(
            predictions=predictions.predictions,
            label_ids=predictions.label_ids,
            metrics=metrics
        )
