import base64
import io
import os
import paddle
from paddlenlp.trainer import Trainer
import paddle.nn as nn
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from paddle.io import random_split, DataLoader, Dataset, Subset
import logging
from .txt2img_dataset import TextImagePairDataset, worker_init_fn

from paddlenlp.transformers.model_utils import _add_variant
from paddlenlp.utils import profiler
from paddlenlp.utils.log import logger

from ppdiffusers.training_utils import unwrap_model


# File name under which `GlyphDiffusionTrainer._save` writes the full model state
# dict (after `_add_variant` applies the configured weight-name suffix).
PADDLE_WEIGHTS_NAME = "model_state.pdparams"
# File name under which `GlyphDiffusionTrainer._save` stores the training arguments.
TRAINING_ARGS_NAME = "training_args.bin"

# NOTE(review): this rebinds the `logger` name imported from `paddlenlp.utils.log`
# earlier in this module, so every `logger.*` call below goes through the
# standard-library logger for this module instead — confirm this is intended.
logger = logging.getLogger(__name__)


def unwrap_model(model):
    """Return the innermost object reachable by following ``_layers`` attributes.

    Wrappers may be nested several levels deep, so keep descending until an
    object without a ``_layers`` attribute is reached; an object that has no
    ``_layers`` attribute is returned as-is.

    NOTE(review): this definition replaces the ``unwrap_model`` imported from
    ``ppdiffusers.training_utils`` at the top of the module.
    """
    innermost = model
    while hasattr(innermost, "_layers"):
        innermost = innermost._layers
    return innermost


class GlyphDiffusionTrainer(Trainer):
    """`Trainer` subclass that customises batching, loss handling and saving.

    Compared with the parent `Trainer`, this class:

    * always uses its own `collate_fn` as the data collator;
    * by default expects the model call to yield a ``(loss, loss_log)`` pair;
    * builds its own training `paddle.io.DataLoader`;
    * saves LoRA weights (also exported with kohya-ss style key names) or the
      individual trainable components, in addition to the full state dict.
    """

    def __init__(self, *args, **kwargs):
        """Build the trainer and install `collate_fn` as the data collator.

        All arguments are forwarded unchanged to the parent `Trainer`. Whatever
        collator the parent set up is then replaced by `self.collate_fn`.
        """
        super().__init__(*args, **kwargs)
        self.data_collator = self.collate_fn

    def collate_fn(self, features):
        """Merge a list of per-sample dicts into a single batch dict.

        Args:
            features: Sequence of dicts, each providing the keys ``caption``,
                ``tags_dic``, ``ocr_info``, ``encoder_attn_mask``, ``img_mask``
                and ``pixel_values``.

        Returns:
            A dict with the same six keys. ``caption``, ``tags_dic`` and
            ``ocr_info`` are Python lists with one entry per sample; the other
            three are cast to float32 per sample and combined with
            `paddle.stack`.
        """

        def stack_as_float32(key):
            # Cast every sample's entry to float32 before stacking.
            return paddle.stack([feature[key].astype(np.float32) for feature in features])

        caption = [feature["caption"] for feature in features]
        tags_dic = [feature["tags_dic"] for feature in features]
        ocr_info = [feature["ocr_info"] for feature in features]
        encoder_attn_mask = stack_as_float32("encoder_attn_mask")
        img_mask = stack_as_float32("img_mask")
        pixel_values = stack_as_float32("pixel_values")

        return {
            "caption": caption,
            "ocr_info": ocr_info,
            "pixel_values": pixel_values,
            "tags_dic": tags_dic,
            "img_mask": img_mask,
            "encoder_attn_mask": encoder_attn_mask,
        }

    def compute_loss(self, model, inputs, return_logs=True):
        """Run the model on one batch and return its output.

        Args:
            model: Callable invoked as ``model(**inputs)``.
            inputs: Mapping of keyword arguments for the model.
            return_logs: When true, the model output is unpacked as exactly two
                values and returned as ``(loss, loss_log)``. When false, the
                model output is returned untouched.

        Returns:
            ``(loss, loss_log)`` if ``return_logs`` is true, otherwise whatever
            the model returned.
        """
        outputs = model(**inputs)
        if not return_logs:
            return outputs
        loss, loss_log = outputs
        return loss, loss_log

    def get_train_dataloader(self):
        """
        Returns the training `paddle.io.DataLoader`.

        Iterable datasets are batched by the loader itself using `args.train_batch_size`. Map-style datasets are
        batched by the batch sampler obtained from `self._get_train_sampler()`; when that returns `None` the loader
        falls back to `batch_size` / `drop_last` batching.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            ValueError: If no `train_dataset` was provided.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset

        dataloader_params = {
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
        }

        if isinstance(train_dataset, paddle.io.IterableDataset):
            dataloader_params["batch_size"] = self.args.train_batch_size
            return DataLoader(train_dataset, **dataloader_params)

        # NOTE(review): `worker_init_fn` is only installed for map-style datasets,
        # exactly as before — confirm it is not also wanted for iterable datasets.
        dataloader_params["worker_init_fn"] = worker_init_fn

        # `paddle.io.DataLoader` accepts `batch_sampler`, not `sampler`; passing
        # `sampler=` raises a TypeError. When a batch sampler is supplied,
        # `batch_size` and `drop_last` must be left at their defaults.
        # NOTE(review): assumes `_get_train_sampler()` returns a batch sampler
        # (or None), as the PaddleNLP `Trainer` does — confirm for the installed
        # version.
        batch_sampler = self._get_train_sampler()
        if batch_sampler is not None:
            dataloader_params["batch_sampler"] = batch_sampler
        else:
            dataloader_params["batch_size"] = self.args.train_batch_size
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_module_kohya_state_dict(self, module, prefix: str = ""):
        """Convert a module's trainable weights to a kohya-ss style state dict.

        Args:
            module: Object exposing ``get_trainable_state_dict()`` and
                ``lora_config.lora_alpha``.
            prefix: Prepended to every key; a trailing ``"."`` is added when it
                is missing. ``None`` and ``""`` both mean "no prefix".

        Returns:
            Dict mapping converted key names to contiguous NumPy arrays. For
            each key, a trailing ``".weight"`` is dropped, ``lora_A`` /
            ``lora_B`` become ``lora_down.weight`` / ``lora_up.weight``, and
            all dots except the last two are turned into underscores. 2-D
            weights are transposed. Every ``lora_down`` entry also gets a
            sibling ``<name>.alpha`` entry holding ``lora_alpha``.
        """
        weight_suffix = ".weight"
        kohya_ss_state_dict = {}

        # Treat None like an empty prefix so the concatenation below cannot fail.
        prefix = prefix or ""
        if prefix and not prefix.endswith("."):
            prefix += "."

        for peft_key, weight in module.get_trainable_state_dict().items():
            # Remove the literal suffix only. `str.rstrip(".weight")` would
            # strip any trailing run of the characters ".weight" and can
            # therefore eat part of the parameter name.
            if peft_key.endswith(weight_suffix):
                peft_key = peft_key[: -len(weight_suffix)]

            kohya_key = prefix + peft_key
            kohya_key = kohya_key.replace("lora_A", "lora_down.weight")
            kohya_key = kohya_key.replace("lora_B", "lora_up.weight")
            # Keep the last two dots. Clamp at 0 because a negative count
            # makes `str.replace` replace every occurrence.
            dots_to_replace = max(kohya_key.count(".") - 2, 0)
            kohya_key = kohya_key.replace(".", "_", dots_to_replace)
            # transpose 2D weight for torch
            if weight.ndim == 2:
                weight = weight.T
            kohya_ss_state_dict[kohya_key] = np.ascontiguousarray(weight)

            # Set alpha parameter
            if "lora_down" in kohya_key:
                alpha_key = f'{kohya_key.split(".")[0]}.alpha'
                kohya_ss_state_dict[alpha_key] = np.ascontiguousarray(
                    paddle.to_tensor(module.lora_config.lora_alpha, dtype=weight.dtype)
                )

        return kohya_ss_state_dict

    def _save(self, output_dir=None, state_dict=None, merge_tensor_parallel=False):
        """Write the model checkpoint to ``output_dir``.

        What is written:

        * If the unwrapped model's ``config.train_lora`` is truthy: the unet
          via ``save_pretrained`` into ``<output_dir>/lora`` (without model
          config), plus ``sdxl_unet_lora.safetensors`` in the same directory
          containing the kohya-ss style LoRA state dict.
        * Otherwise: every component named in ``trainable_component`` via
          ``save_pretrained`` into ``<output_dir>/<component_name>``. The name
          ``'vision_embedding'`` is special-cased: the token-embedding state
          dicts of ``text_encoder`` and ``text_encoder_2`` are saved with
          `paddle.save` to ``<output_dir>/vision_embedding`` and
          ``<output_dir>/vision_embedding_2``.
        * Always: the full state dict under `PADDLE_WEIGHTS_NAME` (with the
          configured weight-name suffix applied).
        * If ``args.should_save``: the tokenizer (when set) and the training
          arguments under `TRAINING_ARGS_NAME`.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
                Created if it does not exist.
            state_dict: State dict to save; defaults to
                ``self.model.state_dict()``.
            merge_tensor_parallel: Accepted for signature compatibility; not
                used by this implementation.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)

        unwraped_model = unwrap_model(self.model)
        if unwraped_model.config.train_lora:
            lora_dir = os.path.join(output_dir, "lora")
            logger.info(f"Saving unet lora checkpoint to {output_dir}/lora")
            unwraped_model.unet.save_pretrained(lora_dir, save_model_config=False)
            # Imported lazily so safetensors is only needed when LoRA is trained.
            from safetensors.numpy import save_file

            lora_kohya_state_dict = self.get_module_kohya_state_dict(unwraped_model.unet, prefix="lora_unet")
            save_file(
                lora_kohya_state_dict,
                os.path.join(lora_dir, "sdxl_unet_lora.safetensors"),
                metadata={"format": "pt"},
            )
        else:
            logger.info(f"Saving checkpoint to {output_dir}/unet")
            for component_name in unwraped_model.trainable_component:
                if component_name == 'vision_embedding':
                    # Not a component attribute: save the token embeddings of
                    # both text encoders instead.
                    paddle.save(
                        unwraped_model.text_encoder.text_model.token_embedding.state_dict(),
                        os.path.join(output_dir, component_name),
                    )
                    paddle.save(
                        unwraped_model.text_encoder_2.text_model.token_embedding.state_dict(),
                        os.path.join(output_dir, f"{component_name}_2"),
                    )
                    continue
                component = getattr(unwraped_model, component_name)
                component.save_pretrained(os.path.join(output_dir, component_name))

        logger.info(f"Saving model checkpoint to {output_dir}")
        if state_dict is None:
            state_dict = self.model.state_dict()
        paddle.save(
            state_dict,
            os.path.join(
                output_dir,
                _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix),
            ),
        )
        if self.args.should_save:
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)
            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def training_step(
        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]
    ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, Any]]:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Layer`):
                The model to train.
            inputs (`Dict[str, Union[paddle.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model.

        Return:
            With pipeline parallelism (`args.pipeline_parallel_degree > 1`), whatever `training_pipeline_step`
            returns. Otherwise a pair `(loss, loss_log)`: the detached training loss for this batch and the log
            object produced by `compute_loss`, or `None` if `compute_loss` returned a bare loss.
        """
        if self.args.pipeline_parallel_degree > 1:
            return self.training_pipeline_step(model, inputs)

        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

        # Default to None so the return below is well-defined when
        # `compute_loss` yields a bare loss instead of a (loss, log) tuple.
        loss_log = None
        if isinstance(loss, tuple):
            loss, loss_log = loss

        if (
            self.args.gradient_accumulation_steps > 1
            and not self._enable_delay_scale_loss()
        ):
            # Average over accumulation steps unless scaling is delayed.
            loss = loss / self.args.gradient_accumulation_steps

        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        return loss.detach(), loss_log
