"""Tweaked version of corresponding AllenNLP file"""

# Original code adapted from AllenNLP.
# TODO: add the upstream AllenNLP license header.


import datetime
import logging
import math
import os
import time
import traceback
from typing import Dict, Optional, List, Tuple, Union, Iterable, Any
 
import torch
import torch.optim.lr_scheduler
from allennlp.common.checks import ConfigurationError, parse_cuda_device
from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
from allennlp.training import util as training_util
from allennlp.training.checkpointer import Checkpointer
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.metric_tracker import MetricTracker
from allennlp.training.momentum_schedulers import MomentumScheduler
from allennlp.training.moving_average import MovingAverage
from allennlp.training.optimizers import Optimizer
from allennlp.training.tensorboard_writer import TensorboardWriter
from allennlp.training.trainer_base import TrainerBase

from tqdm import tqdm
#from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter

logger = logging.getLogger(__name__)


class Trainer(TrainerBase):
    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        iterator: DataIterator,
        train_dataset: Iterable[Instance],
        validation_dataset: Optional[Iterable[Instance]] = None,
        patience: Optional[int] = None,
        validation_metric: str = "-loss",
        validation_iterator: DataIterator = None,
        shuffle: bool = True,
        num_epochs: int = 20,
        accumulated_batch_count: int = 1,
        serialization_dir: Optional[str] = None,
        num_serialized_models_to_keep: int = 20,
        keep_serialized_model_every_num_seconds: int = None,
        checkpointer: Checkpointer = None,
        model_save_interval: float = None,
        cuda_device: Union[int, List] = -1,
        grad_norm: Optional[float] = None,
        grad_clipping: float = 5.0,
        learning_rate_scheduler: Optional[LearningRateScheduler] = None,
        momentum_scheduler: Optional[MomentumScheduler] = None,
        summary_interval: int = 100,
        histogram_interval: int = None,
        should_log_parameter_statistics: bool = True,
        should_log_learning_rate: bool = False,
        log_batch_size_period: Optional[int] = None,
        moving_average: Optional[MovingAverage] = None,
        cold_step_count: int = 0,
        cold_lr: float = 1e-3,
        cuda_verbose_step=None,
        local_rank=0,
        ) -> None:
        """
        Stores the training configuration and sets up checkpointing and logging.

        Parameters
        ----------
        model : ``Model``
            The model to train. It is always kept as ``self.parallel_model``.
            When CUDA is available and the object exposes a ``module``
            attribute (as ``DataParallel`` / ``DistributedDataParallel``
            wrappers do), the wrapped module becomes ``self.model``; otherwise
            ``self.model`` is the object itself.
        optimizer : ``torch.optim.Optimizer``
            Stepped once per training batch. Other methods of this class also
            call ``set_cold_start()`` / ``close_cold_start()`` on it and read
            its ``lr`` attribute, so a project-specific optimizer is assumed.
        scheduler, iterator, validation_iterator, shuffle, cold_lr,
        cuda_verbose_step, model_save_interval, log_batch_size_period
            Stored on the instance; not read by any method of this class.
        train_dataset, validation_dataset
            Iterables of batches consumed by ``_train_epoch`` and
            ``_validation_loss`` respectively.
        patience : ``Optional[int]``
            Early-stopping patience handed to ``MetricTracker``. Must be a
            positive integer or ``None`` (early stopping disabled).
        validation_metric : ``str``
            Metric name prefixed with ``+`` or ``-``; the prefix is stripped
            to obtain the key looked up in the validation metrics.
        num_epochs : ``int``
            Number of epochs ``train`` iterates over.
        accumulated_batch_count : ``int``
            Stored, but ``_train_epoch`` resets it to 1.
        serialization_dir : ``Optional[str]``
            Directory for checkpoints and logs. Required on the logging rank
            because the learning-rate/loss summary writers are created
            underneath it.
        num_serialized_models_to_keep, keep_serialized_model_every_num_seconds
            Used to build a default ``Checkpointer``; must be left at their
            defaults when ``checkpointer`` is given.
        checkpointer : ``Checkpointer``
            Custom checkpointer used instead of the default one.
        cuda_device : ``Union[int, List]``
            Forwarded to ``TrainerBase.__init__``.
        grad_norm : ``Optional[float]``
            Passed to ``training_util.rescale_gradients`` by ``rescale_gradients``.
        grad_clipping : ``float``
            Threshold used for gradient clipping during training.
        learning_rate_scheduler, momentum_scheduler
            Only their ``state_dict`` is saved/restored with checkpoints.
        summary_interval, histogram_interval, should_log_parameter_statistics,
        should_log_learning_rate
            Forwarded to ``TensorboardWriter``; activation logging is enabled
            when ``histogram_interval`` is not ``None``.
        moving_average : ``Optional[MovingAverage]``
            If given, applied after every batch and swapped in for validation
            and checkpointing.
        cold_step_count : ``int``
            When positive, ``cold_start_setting`` freezes parameters whose name
            contains ``"bert_model"`` until the epoch with this index.
        local_rank
            Process rank. Ranks ``0`` and ``-1`` (or any rank when CUDA is
            unavailable) create the tensorboard writers and do the logging.

        Raises
        ------
        ConfigurationError
            For an invalid ``patience``, for conflicting checkpointer
            arguments, or when ``serialization_dir`` is ``None`` on the
            logging rank.
        """
        super().__init__(serialization_dir, cuda_device)
        self.serialization_dir = serialization_dir
        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        self.parallel_model = model
        if torch.cuda.is_available():
            # Unwrap a (Distributed)DataParallel wrapper if there is one; a bare
            # model has no ``module`` attribute and is used directly.
            self.model = getattr(model, "module", model)
        else:
            self.model = model
        self.iterator = iterator
        self._validation_iterator = validation_iterator
        self.shuffle = shuffle
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_data = train_dataset
        self._validation_data = validation_dataset
        self.accumulated_batch_count = accumulated_batch_count
        self.cold_step_count = cold_step_count
        self.cold_lr = cold_lr
        self.cuda_verbose_step = cuda_verbose_step
        self.local_rank = local_rank

        if patience is None:  # no early stopping
            if validation_dataset:
                logger.warning(
                    "You provided a validation dataset but patience was set to None, "
                    "meaning that early stopping is disabled"
                )
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError(
                '{} is an invalid value for "patience": it must be a positive integer '
                "or None (if you want to disable early stopping)".format(patience)
            )

        # For tracking is_best_so_far and should_stop_early
        self._metric_tracker = MetricTracker(patience, validation_metric)
        # Get rid of + or -
        self._validation_metric = validation_metric[1:]

        self._num_epochs = num_epochs
        if checkpointer is not None:
            # We can't easily check if these parameters were passed in, so check against their default values.
            # We don't check against serialization_dir since it is also used by the parent class.
            if num_serialized_models_to_keep != 20 \
                    or keep_serialized_model_every_num_seconds is not None:
                raise ConfigurationError(
                    "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
                    "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
                )
            self._checkpointer = checkpointer
        else:
            self._checkpointer = Checkpointer(
                serialization_dir,
                keep_serialized_model_every_num_seconds,
                num_serialized_models_to_keep,
            )
        self._model_save_interval = model_save_interval
        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping
        self._learning_rate_scheduler = learning_rate_scheduler
        self._momentum_scheduler = momentum_scheduler
        self._moving_average = moving_average

        self._batch_num_total = 0
        self._log_batch_size_period = log_batch_size_period
        self._last_log = 0.0  # time of last logging
        self.total_train_step = 0

        # Only one process writes tensorboard summaries.
        is_logging_rank = self.local_rank in [0, -1] or not torch.cuda.is_available()
        if is_logging_rank:
            if serialization_dir is None:
                # The summary writers below live under serialization_dir, so
                # fail with an explicit message instead of a TypeError.
                raise ConfigurationError(
                    "serialization_dir must be provided: the tensorboard summary writers "
                    "are created under '<serialization_dir>/lr_record'."
                )
            lr_record_dir = os.path.join(serialization_dir, "lr_record")
            self.tensorboard_writer_train = SummaryWriter(os.path.join(lr_record_dir, "train"))
            self.tensorboard_writer_eval = SummaryWriter(os.path.join(lr_record_dir, "eval"))
            self._tensorboard = TensorboardWriter(
                get_batch_num_total=lambda: self._batch_num_total,
                serialization_dir=serialization_dir,
                summary_interval=summary_interval,
                histogram_interval=histogram_interval,
                should_log_parameter_statistics=should_log_parameter_statistics,
                should_log_learning_rate=should_log_learning_rate,
            )
            # Enable activation logging.
            if histogram_interval is not None:
                self._tensorboard.enable_activation_logging(self.model)

    def rescale_gradients(self) -> Optional[float]:
        """
        Delegates to ``training_util.rescale_gradients`` with this trainer's
        model and configured ``grad_norm``, and returns its result unchanged.
        """
        model, max_norm = self.model, self._grad_norm
        return training_util.rescale_gradients(model, max_norm)
    
    def batch_loss(
        self, batch_group: TensorDict, for_training: bool
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """
        Runs the model on one batch and returns its loss together with the
        full model output.

        Parameters
        ----------
        batch_group : ``TensorDict``
            A single batch as a mapping from argument name to tensor (or nested
            container of tensors); it is unpacked as keyword arguments to the model.
        for_training : ``bool``
            Not used by this implementation; kept so existing callers keep working.

        Returns
        -------
        loss : ``torch.Tensor``
            ``output_dict["loss"]``.
        output_dict : ``Dict[str, Any]``
            Everything the model returned.

        Raises
        ------
        KeyError
            If the model output has no ``"loss"`` entry.
        """
        if torch.cuda.is_available():
            # ``local_rank == -1`` is accepted elsewhere in this class as a
            # logging rank, but a negative index is not a valid CUDA device,
            # so fall back to the process' current device in that case.
            if self.local_rank >= 0:
                device_index = self.local_rank
            else:
                device_index = torch.cuda.current_device()
            # move_to_device recurses into nested containers and leaves
            # non-tensor values untouched.
            batch_input = nn_util.move_to_device(batch_group, device_index)
        else:
            batch_input = batch_group
        # NOTE(review): the forward pass goes through ``self.model`` (the
        # unwrapped module), not ``self.parallel_model`` — confirm that
        # gradient synchronisation across processes still happens as intended.
        output_dict = self.model(**batch_input)
        return output_dict["loss"], output_dict
 
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.

        For every batch of ``self.train_data`` this computes the loss,
        back-propagates, clips the gradient norm (when ``grad_clipping`` is
        set), steps the optimizer, updates the moving average (if any) and, on
        the logging rank, writes the loss values and learning rate to tensorboard.

        Parameters
        ----------
        epoch : ``int``
            Index of the current epoch. Not used by the body.

        Returns
        -------
        Dict[str, float]
            The result of ``training_util.get_metrics`` for the summed loss of
            this epoch and the number of batches processed in this epoch
            (model metrics are reset).
        """
        train_loss = 0.0
        # Set the model to "train" mode.
        self.model.train()
        num_gpus = len(self._cuda_devices)

        logger.info("Training")

        if self._batch_num_total is None:
            self._batch_num_total = 0

        print("Training on gpu: ", num_gpus )
        self.optimizer.zero_grad()
        self._last_log = time.time()
        # Gradient accumulation is not implemented in this loop (the optimizer
        # steps after every batch), so the configured value is forced to 1.
        self.accumulated_batch_count = 1

        # Only one process writes tensorboard summaries.
        is_logging_rank = self.local_rank in [0, -1] or not torch.cuda.is_available()
        recorded_loss_items = ['loss', 'mlm_loss', 'wsd_loss']

        batches_this_epoch = 0
        # train_data has already been divided among the GPUs.
        for batch_group in self.train_data:
            batches_this_epoch += 1
            self.total_train_step += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            loss, output_dict = self.batch_loss(batch_group, for_training=True)
            loss_value = loss.detach().item()
            self.optimizer.zero_grad()
            loss.backward()
            # ``grad_clipping`` may legitimately be None (e.g. ``from_params``
            # defaults to it); clip_grad_norm_ cannot handle a None max norm.
            if self._grad_clipping is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self._grad_clipping)
            self.optimizer.step()

            if self.total_train_step % 500 == 0 and self.total_train_step > 0:
                print("current: ", self.total_train_step )

            train_loss += loss_value
            # Write the losses and learning rate to tensorboard.
            if is_logging_rank:
                self._tensorboard.add_train_scalar("loss", loss_value)
                # assumes the optimizer exposes an ``lr`` attribute — TODO confirm
                self.tensorboard_writer_train.add_scalar(
                    "lr", self.optimizer.lr, self.total_train_step
                )
                for loss_name in recorded_loss_items:
                    loss_v = output_dict[loss_name].detach().item()
                    self.tensorboard_writer_train.add_scalar(
                        "loss/" + loss_name, loss_v, self.total_train_step
                    )
            # Drop references to the batch tensors before the next iteration.
            del batch_group, loss
            if batches_this_epoch % 50 == 0:
                torch.cuda.empty_cache()

            # Update moving averages
            if self._moving_average is not None:
                self._moving_average.apply(batch_num_total)

        # ``train_loss`` only covers this epoch, so it must be averaged over the
        # batches of this epoch, not over the global step counter.
        return training_util.get_metrics(
            self.model, train_loss, batches_this_epoch, reset=True
        )

    def _validation_loss(self) -> Tuple[float, int]:
        """
        Computes the validation loss. Returns it and the number of batches.

        The model is put in eval mode and, when a moving average is configured,
        evaluated with the averaged parameters; the original parameters are
        restored afterwards on every rank, even if validation raises. On the
        logging rank the mean ``loss``, ``mlm_loss`` and ``wsd_loss`` are
        written to tensorboard.

        Returns
        -------
        val_loss : ``float``
            Sum of the per-batch losses.
        batches_this_epoch : ``int``
            Number of batches that produced a loss (0 if there were none).
        """
        logger.info("Validating")
        self.model.eval()

        # Replace parameter values with the shadow values from the moving averages.
        if self._moving_average is not None:
            self._moving_average.assign_average_value()

        recorded_loss_items = ['loss', 'mlm_loss', 'wsd_loss']
        record_loss = {loss_name: 0 for loss_name in recorded_loss_items}
        val_loss = 0
        batches_this_epoch = 0
        try:
            for batch_group in self._validation_data:
                loss, output_dict = self.batch_loss(batch_group, for_training=False)
                if loss is not None:
                    batches_this_epoch += 1
                    val_loss += loss.detach().item()
                    for loss_name in recorded_loss_items:
                        record_loss[loss_name] += output_dict[loss_name].detach().item()

            if self.local_rank in [0, -1] or not torch.cuda.is_available():
                # Avoid dividing by zero for an empty validation set without
                # altering the batch count that is returned to the caller.
                denominator = max(batches_this_epoch, 1)
                self._tensorboard.add_validation_scalar("loss", val_loss / denominator)
                for loss_name in recorded_loss_items:
                    self.tensorboard_writer_eval.add_scalar(
                        "loss/" + loss_name,
                        record_loss[loss_name] / denominator,
                        self.total_train_step,
                    )
        finally:
            # Now restore the original parameter values. This must happen on
            # every rank, otherwise non-logging ranks would keep training with
            # the averaged weights.
            if self._moving_average is not None:
                self._moving_average.restore()
        return val_loss, batches_this_epoch

    def cold_start_setting(self):
        """
        Prepares the cold-start phase.

        Does nothing unless ``self.cold_step_count`` is positive. Otherwise it
        calls ``set_cold_start()`` on the optimizer and turns off
        ``requires_grad`` for every model parameter whose name contains
        ``"bert_model"``.
        """
        if not self.cold_step_count > 0:
            return

        self.optimizer.set_cold_start()
        # Freeze the main BERT model.
        for param_name, parameter in self.model.named_parameters():
            if "bert_model" not in param_name:
                continue
            parameter.requires_grad = False

    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.

        Runs up to ``self._num_epochs`` epochs. Each epoch trains on
        ``self.train_data``, validates when a validation set was given (stopping
        early when the metric tracker says so), logs the metrics to tensorboard
        on the logging rank, dumps ``metrics_epoch_<epoch>.json`` into the
        serialization directory and saves a checkpoint on the logging rank.
        The summary writers are closed when training ends, including when it
        ends with an exception.

        Returns
        -------
        Dict[str, Any]
            The metrics of the last completed epoch: ``epoch``, ``best_epoch``
            and the ``training_*``, ``validation_*`` and ``best_validation_*``
            entries.
        """
        training_util.enable_gradient_clipping(self.model, self._grad_clipping)
        logger.info("Beginning training.")

        # Only one process writes tensorboard summaries and checkpoints.
        is_logging_rank = self.local_rank in [0, -1] or not torch.cuda.is_available()

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        metrics: Dict[str, Any] = {}
        training_start_time = time.time()
        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        self.cold_start_setting()
        print("start training......")
        try:
            for epoch in range(self._num_epochs):
                if epoch > 0 and torch.cuda.is_available():
                    # presumably train_data is a DataLoader with a distributed
                    # sampler; skip quietly when there is no such sampler.
                    sampler = getattr(self.train_data, "sampler", None)
                    if hasattr(sampler, "set_epoch"):
                        sampler.set_epoch(epoch)

                if epoch == self.cold_step_count and epoch != 0:
                    # End of the cold-start phase: make every parameter trainable again.
                    self.optimizer.close_cold_start()
                    for _, param in self.model.named_parameters():
                        param.requires_grad = True
                    print("cold start finished: all parameters are trainable again")

                epoch_start_time = time.time()
                print("new epoch at %d for rank %d... "%( epoch, self.local_rank) )
                train_metrics = self._train_epoch(epoch)
                # clear cache before validation
                torch.cuda.empty_cache()
                if self._validation_data is not None:
                    with torch.no_grad():
                        # We have a validation set, so compute all the metrics on it.
                        val_loss, num_batches = self._validation_loss()
                        val_metrics = training_util.get_metrics(
                            self.model, val_loss, num_batches, reset=True
                        )

                        # Check validation metric for early stopping
                        this_epoch_val_metric = val_metrics[self._validation_metric]
                        self._metric_tracker.add_metric(this_epoch_val_metric)
                        if self._metric_tracker.should_stop_early():
                            logger.info("Ran out of patience.  Stopping training.")
                            break

                if is_logging_rank:
                    self._tensorboard.log_metrics(
                        train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
                    )

                # Create overall metrics dict
                metrics["epoch"] = epoch
                for key, value in train_metrics.items():
                    metrics["training_" + key] = value
                for key, value in val_metrics.items():
                    metrics["validation_" + key] = value

                if self._metric_tracker.is_best_so_far():
                    # Update all the best_ metrics.
                    # (Otherwise they just stay the same as they were.)
                    metrics["best_epoch"] = epoch
                    for key, value in val_metrics.items():
                        metrics["best_validation_" + key] = value

                    self._metric_tracker.best_epoch_metrics = val_metrics

                if self._serialization_dir:
                    dump_metrics(
                        os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
                    )

                if is_logging_rank:
                    print(epoch, self.local_rank)
                    self._save_checkpoint(epoch)

                epoch_elapsed_time = time.time() - epoch_start_time
                logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

                if epoch < self._num_epochs - 1:
                    training_elapsed_time = time.time() - training_start_time
                    estimated_time_remaining = training_elapsed_time * (
                        (self._num_epochs) / float(epoch + 1) - 1
                    )
                    formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                    logger.info("Estimated training time remaining: %s", formatted_time)
        finally:
            # Make sure pending events are flushed to disk and files are closed
            # properly, also when training is interrupted by an exception.
            if is_logging_rank:
                self.tensorboard_writer_train.close()
                self.tensorboard_writer_eval.close()
        return metrics

    def _save_checkpoint(self, epoch: Union[int, str]) -> None:
        """
        Hands the current model and training state to the checkpointer.

        When a moving average is configured, the averaged parameter values are
        swapped in before saving and the original values are swapped back
        afterwards, so the checkpoint holds the averaged weights.

        Parameters
        ----------
        epoch : Union[int, str], required.
            The epoch of training.  If the checkpoint is saved in the middle
            of an epoch, the parameter is a string with the epoch and timestamp.
        """
        averaging = self._moving_average
        if averaging is not None:
            # Save the shadow (averaged) values rather than the live ones.
            averaging.assign_average_value()

        # State that must survive a restart.
        training_states = {
            "metric_tracker": self._metric_tracker.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "batch_num_total": self._batch_num_total,
        }

        # Schedulers are optional; persist whichever ones are present.
        optional_schedulers = (
            ("learning_rate_scheduler", self._learning_rate_scheduler),
            ("momentum_scheduler", self._momentum_scheduler),
        )
        for state_key, scheduler in optional_schedulers:
            if scheduler is not None:
                training_states[state_key] = scheduler.state_dict()

        self._checkpointer.save_checkpoint(
            model_state=self.model.state_dict(),
            epoch=epoch,
            training_states=training_states,
            is_best_so_far=self._metric_tracker.is_best_so_far(),
        )

        if averaging is not None:
            # Put the live parameter values back so training is unaffected.
            averaging.restore()

    def _restore_checkpoint(self) -> int:
        """
        Loads the most recent checkpoint, if any, into the model, optimizer,
        schedulers and metric tracker so that training can be continued.

        This is meant for resuming training only. To load weights for inference
        or into a different computation graph, use the native Pytorch call
        `` model.load_state_dict(torch.load("/path/to/model/weights.th"))``.

        When the checkpointer returns no training state, nothing is modified
        and 0 is returned.

        Returns
        -------
        epoch: int
            The epoch at which to resume training, which is one after the epoch
            recorded in the saved training state.
        """
        model_state, training_state = self._checkpointer.restore_checkpoint()

        # Nothing saved yet: start from the first epoch.
        if not training_state:
            return 0

        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(training_state["optimizer"])

        # Restore whichever optional schedulers are both configured and saved.
        optional_schedulers = (
            ("learning_rate_scheduler", self._learning_rate_scheduler),
            ("momentum_scheduler", self._momentum_scheduler),
        )
        for state_key, scheduler in optional_schedulers:
            if scheduler is not None and state_key in training_state:
                scheduler.load_state_dict(training_state[state_key])
        training_util.move_optimizer_to_cuda(self.optimizer)

        tracker = self._metric_tracker
        if "metric_tracker" in training_state:
            # Current format: a serialized ``MetricTracker``.
            tracker.load_state_dict(training_state["metric_tracker"])
        elif "val_metric_per_epoch" in training_state:
            # Older format: a list of per-epoch validation metrics.
            tracker.clear()
            tracker.add_metrics(training_state["val_metric_per_epoch"])
        else:
            # Oldest format: nothing was tracked.
            tracker.clear()

        saved_epoch = training_state["epoch"]
        if not isinstance(saved_epoch, int):
            # Mid-epoch checkpoints store "<epoch>.<timestamp>" as a string.
            saved_epoch = int(saved_epoch.split(".")[0])
        next_epoch = saved_epoch + 1

        # Older checkpoints have no batch_num_total; keep the current value then.
        restored_total = training_state.get("batch_num_total")
        if restored_total is not None:
            self._batch_num_total = restored_total

        return next_epoch

    # Requires custom from_params.
    @classmethod
    def from_params(  # type: ignore
        cls,
        model: Model,
        serialization_dir: str,
        iterator: DataIterator,
        train_data: Iterable[Instance],
        validation_data: Optional[Iterable[Instance]],
        params,
        validation_iterator = None,
    ) -> "Trainer":
        """
        Builds a ``Trainer`` from a configuration object.

        Parameters
        ----------
        model : ``Model``
            The model to train; moved to the (first) configured CUDA device
            before the optimizer is built.
        serialization_dir : ``str``
            Directory for checkpoints and logs.
        iterator : ``DataIterator``
            Passed through to the constructor.
        train_data, validation_data
            Training and (optional) validation data passed to the constructor.
        params
            Configuration object providing ``pop``, ``pop_int``, ``pop_float``,
            ``pop_bool`` and ``assert_empty``. Keys read: ``patience``,
            ``validation_metric``, ``shuffle``, ``num_epochs``, ``cuda_device``,
            ``grad_norm``, ``grad_clipping``, ``learning_rate_scheduler``,
            ``momentum_scheduler``, ``optimizer`` (required), ``moving_average``,
            ``checkpointer`` or ``num_serialized_models_to_keep`` /
            ``keep_serialized_model_every_num_seconds``, ``model_save_interval``,
            ``summary_interval``, ``histogram_interval``,
            ``should_log_parameter_statistics``, ``should_log_learning_rate``
            and ``log_batch_size_period``. Any other key makes
            ``assert_empty`` fail.
        validation_iterator
            Optional iterator for validation, passed through to the constructor.

        Returns
        -------
        Trainer
            A trainer whose ``scheduler`` argument is ``None``; the scheduler
            built from ``learning_rate_scheduler`` is passed separately.

        Raises
        ------
        ConfigurationError
            If both a ``checkpointer`` block and the individual checkpointer
            keys are present in ``params``.
        """
        patience = params.pop_int("patience", None)
        validation_metric = params.pop("validation_metric", "-loss")
        shuffle = params.pop_bool("shuffle", True)
        num_epochs = params.pop_int("num_epochs", 20)
        cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
        grad_norm = params.pop_float("grad_norm", None)
        grad_clipping = params.pop_float("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)
        momentum_scheduler_params = params.pop("momentum_scheduler", None)

        if isinstance(cuda_device, list):
            model_device = cuda_device[0]
        else:
            model_device = cuda_device
        if model_device >= 0:
            # Moving model to GPU here so that the optimizer state gets constructed on
            # the right device.
            model = model.cuda(model_device)

        parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
        if "moving_average" in params:
            moving_average = MovingAverage.from_params(
                params.pop("moving_average"), parameters=parameters
            )
        else:
            moving_average = None

        if lr_scheduler_params:
            lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
        else:
            lr_scheduler = None
        if momentum_scheduler_params:
            momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params)
        else:
            momentum_scheduler = None

        if "checkpointer" in params:
            if "keep_serialized_model_every_num_seconds" in params \
                    or "num_serialized_models_to_keep" in params:
                raise ConfigurationError(
                    "Checkpointer may be initialized either from the 'checkpointer' key or from the "
                    "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
                    " but the passed config uses both methods."
                )
            checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
        else:
            num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
            keep_serialized_model_every_num_seconds = params.pop_int(
                "keep_serialized_model_every_num_seconds", None
            )
            checkpointer = Checkpointer(
                serialization_dir=serialization_dir,
                num_serialized_models_to_keep=num_serialized_models_to_keep,
                keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
            )
        model_save_interval = params.pop_float("model_save_interval", None)
        summary_interval = params.pop_int("summary_interval", 100)
        histogram_interval = params.pop_int("histogram_interval", None)
        should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
        should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
        log_batch_size_period = params.pop_int("log_batch_size_period", None)

        params.assert_empty(cls.__name__)
        # Everything is passed by keyword: the constructor takes ``scheduler``
        # as its third positional parameter, so passing the iterator and the
        # datasets positionally would bind them to the wrong parameters.
        return cls(
            model=model,
            optimizer=optimizer,
            # No torch scheduler is configured here; the AllenNLP scheduler is
            # passed through ``learning_rate_scheduler`` below.
            scheduler=None,
            iterator=iterator,
            train_dataset=train_data,
            validation_dataset=validation_data,
            patience=patience,
            validation_metric=validation_metric,
            validation_iterator=validation_iterator,
            shuffle=shuffle,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            cuda_device=cuda_device,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=lr_scheduler,
            momentum_scheduler=momentum_scheduler,
            checkpointer=checkpointer,
            model_save_interval=model_save_interval,
            summary_interval=summary_interval,
            histogram_interval=histogram_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
            log_batch_size_period=log_batch_size_period,
            moving_average=moving_average,
        )
