from trainer import *

class QwenTrainer(Trainer):
    """``Trainer`` whose evaluation loop scores generated text instead of logits.

    ``evaluation_loop`` is overridden so that every eval batch is run through
    ``model.generate``. The prompts, generated continuations and labels are
    decoded to strings and handed to ``compute_metrics``. Per the original
    author's note, this is used to return BLEU or ROUGE scores.
    """

    # Upper bound on tokens generated per sample during evaluation. This was
    # previously hard-coded to 32. Override it on a subclass or instance to
    # change it.
    eval_max_new_tokens: int = 32

    @staticmethod
    def _strip_prompt(pred: str, prompt: str) -> str:
        """Remove a leading ``prompt`` from the decoded ``pred`` and clean it up.

        Only the prompt *prefix* is removed. The previous ``str.replace`` also
        deleted any later repetition of the prompt inside the generated answer.
        If ``pred`` does not start with ``prompt``, the whole prediction is kept.
        Literal ``<|endoftext|>`` markers are dropped and surrounding whitespace
        is stripped.
        """
        if pred.startswith(prompt):
            pred = pred[len(prompt):]
        return pred.replace("<|endoftext|>", "").strip()

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """Generation-based evaluation loop, shared by ``evaluate()`` and ``predict()``.

        Adapted from ``Trainer.evaluation_loop``. Instead of a loss/logits
        ``prediction_step``, each batch is passed to ``model.generate`` and the
        decoded strings are collected, so it works with or without a loss.

        Args:
            dataloader: Eval/test dataloader. Batches must provide ``input_ids``,
                ``attention_mask`` and ``labels``; label positions ``< 0`` are
                treated as ignored. An optional ``tasks`` entry is forwarded to
                ``generate``.
            description: Label used in the log header.
            prediction_loss_only: Accepted for interface compatibility; unused.
            ignore_keys: Accepted for interface compatibility; unused.
            metric_key_prefix: Prefix added to every returned metric key.

        Returns:
            An ``EvalLoopOutput`` with these fields:

            - ``predictions`` and ``label_ids``: lists of decoded strings for the
              samples seen by *this* process. No cross-process gathering is done.
            - ``metrics``: the result of ``compute_metrics``, prefixed with
              ``metric_key_prefix``; empty if ``compute_metrics`` is unset.
        """
        args = self.args

        # If eval is called without train, handle DeepSpeed model prep here.
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        # TODO: this is currently hard-coded for Qwen2ForCausalLM.
        # NOTE(review): assumes the (possibly wrapped) model exposes
        # `generation_config` and `generate` directly; confirm for DDP/FSDP setups.
        model.generation_config.pad_token_id = self.tokenizer.pad_token_id

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # For the rest of this function `model` is the outside model,
            # whether it was wrapped or not.
            if model is not self.model:
                self.model_wrapped = model

            # Backward compatibility.
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # If full fp16/bf16 eval is wanted and this evaluate/predict call is not
        # made while `train` is running, cast to the right dtype and device first.
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Do this before wrapping.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Decoded strings for every sample seen by this process.
        all_preds: List[str] = []
        all_labels: List[str] = []
        all_inputs: List[str] = []
        # Counted manually because an iterable dataset may not know its length.
        observed_num_examples = 0

        pad_token_id = self.tokenizer.pad_token_id
        for inputs in dataloader:
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size

            # TODO: skip the tokenize-then-decode round trip and pass raw strings
            # directly; this may require changing data_collate_fn (train.py).
            # Use an explicit no_grad so no autograd graph is built during
            # evaluation, even if a custom `generate` (this one accepts `tasks`)
            # does not disable gradients itself.
            with torch.no_grad():
                model_pred = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=self.eval_max_new_tokens,
                    tasks=inputs.get("tasks"),
                )
            inputs_str = self.tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)
            model_pred_str = self.tokenizer.batch_decode(model_pred, skip_special_tokens=True)
            for pred, input_str in zip(model_pred_str, inputs_str):
                if not pred.startswith(input_str):
                    # The prompt is presumably echoed at the start of the
                    # generated sequence. When the decoded prediction does not
                    # start with the decoded prompt, the whole prediction is kept.
                    logger.warning(
                        "Prediction does not start with its input; keeping the full prediction. "
                        f"input_str={input_str!r} pred={pred!r}"
                    )
                all_preds.append(self._strip_prompt(pred, input_str))
                all_inputs.append(input_str)

            # Replace ignored label positions (< 0) with the pad id so they can be
            # decoded. They are presumably then dropped by skip_special_tokens.
            labels = torch.where(inputs["labels"] < 0, pad_token_id, inputs["labels"])
            all_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)
            assert len(all_preds) == len(all_labels), f"[ZH ERROR] len(all_preds) != len(all_labels): {len(all_preds)} / {len(all_labels)}"
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

        # After all calls to `.gather_function`, reset to `gather_for_metrics`:
        self.gather_function = self.accelerator.gather_for_metrics
        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop.
            delattr(self, "_past")

        # Number of samples.
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # The instance check is weird and does not actually check for the type,
        # but whether the dataset has the right methods. Therefore we need to make
        # sure it also has the attribute.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        elif has_length(dataloader):
            num_samples = self.num_examples(dataloader)
        else:  # both len(dataloader.dataset) and len(dataloader) fail
            num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Metrics! This project's compute_metrics takes three arguments:
        # (EvalPrediction, decoded inputs, per-rank result file path).
        if self.compute_metrics is not None:
            result_path = os.path.join(
                args.eval_output_dir,
                "evaluate_result",
                f"result-step{self.state.global_step}",
                f"pred_result_rank{args.local_rank}.json",
            )
            metrics = self.compute_metrics(
                EvalPrediction(predictions=all_preds, label_ids=all_labels),
                all_inputs,
                result_path,
            )
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors.
        metrics = denumpify_detensorize(metrics)

        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'.
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
