# Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import json
import os
import random
from typing import Dict, List
import unicodedata
import einops
import types
from paddle.vision.transforms import ToTensor
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlenlp.peft import LoRAConfig, LoRAModel
from .vision_embedding import VisionEmbedding

from .ocr_predictor import OCRPredictor

from transformers import CLIPTokenizerFast

from .txt2img_dataset import TextProcessor, ImageProcessor
from paddlenlp.transformers import (
    CLIPTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    ChineseCLIPModel,
    ChineseCLIPProcessor,
    CLIPPretrainedModel,
)
from termcolor import colored
import gc
from ppdiffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DDIMScheduler,
    UNet2DConditionModel,
    is_ppxformers_available,
)
from ppdiffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from PIL import Image, ImageFont, ImageDraw
import logging
import numpy as np
from ppdiffusers.models.ema import LitEma
from .segmenter import UNetSegmenter
from ppdiffusers.models.attention_processor import (
    Attention,
    AttnProcessor,
    XFormersAttnProcessor,
    XFormersAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor,
)

from ppdiffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_5,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_5,
)

from ppdiffusers.training_utils import freeze_params, unfreeze_params, unwrap_model


# Name fragments of the UNet sub-layers whose cross-attention ("attn2")
# get_attention_scores method is wrapped by `unet_store_cross_attention_scores`;
# attention layers whose name contains none of these fragments are left alone.
UNET_LAYER_NAMES = [
    "down_blocks.0",
    "down_blocks.1",
    "down_blocks.2",
    "mid_block",
    "up_blocks.1",
    "up_blocks.2",
    "up_blocks.3",
]

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)

# The level is forced to DEBUG at import time, so the per-token debug messages
# emitted while configuring additional tokens are not filtered by this logger.
logger.setLevel(logging.DEBUG)

def read_json(file):
    """Parse the UTF-8 encoded JSON document at ``file`` and return the result."""
    with open(file, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


class GlyphDiffusionModel(nn.Layer):
    def __init__(self, model_args):
        """Assemble the tokenizers, text encoders, VAE, UNet and schedulers.

        Each component is loaded from its own ``*_name_or_path`` entry of
        ``model_args`` when that entry is set, and otherwise from the matching
        sub-folder of ``model_args.pretrained_model_name_or_path``.  The UNet,
        both text encoders and the VAE are cast to bfloat16.  The training-only
        pieces (cross-attention score capture, OCR recognition model and the
        trainable-parameter selection) are set up only when
        ``model_args.is_train`` is truthy.

        Args:
            model_args: configuration object; kept as ``self.config``.
        """
        super().__init__()
        self.config = model_args
        self.trainable_component = []

        def resolve(explicit_path, subfolder):
            # An explicitly configured location wins; otherwise fall back to the
            # sub-folder of the base pretrained checkpoint.
            if explicit_path is not None:
                return explicit_path
            return os.path.join(model_args.pretrained_model_name_or_path, subfolder)

        # ---- 1. tokenizers ---------------------------------------------------
        tokenizer_location = resolve(model_args.tokenizer_name_or_path, "tokenizer")
        self.tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path=tokenizer_location,
            model_max_length=model_args.model_max_length,
            from_diffusers=True,
        )
        # Fast tokenizer from `transformers`, loaded from the same location.
        self.torch_tokenizer = CLIPTokenizerFast.from_pretrained(
            pretrained_model_name_or_path=tokenizer_location,
        )
        tokenizer_2_location = resolve(
            model_args.tokenizer_2_name_or_path, "tokenizer_2"
        )
        self.tokenizer_2 = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path=tokenizer_2_location,
            model_max_length=model_args.model_max_length,
            from_diffusers=True,
        )
        self.tokenizer.model_max_length = model_args.model_max_length

        # ---- 3. text encoders ------------------------------------------------
        text_encoder_location = resolve(
            model_args.text_encoder_name_or_path, "text_encoder"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path=text_encoder_location,
            from_diffusers=True,
        )
        text_encoder_2_location = resolve(
            model_args.text_encoder_2_name_or_path, "text_encoder_2"
        )
        self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path=text_encoder_2_location,
            from_diffusers=True,
        )

        # Resize the positional embeddings when the configured maximum length
        # differs from what the checkpoints were built with.
        max_length = self.config.model_max_length
        self.text_encoder = self.__custom_text_encoder_max_len(
            self.text_encoder, length=max_length
        )
        self.text_encoder_2 = self.__custom_text_encoder_max_len(
            self.text_encoder_2, length=self.config.model_max_length
        )

        # ---- 5. vae ----------------------------------------------------------
        vae_location = resolve(model_args.vae_name_or_path, "vae")
        self.vae = AutoencoderKL.from_pretrained(vae_location, from_diffusers=True)

        # ---- 6. unet ---------------------------------------------------------
        unet_location = resolve(model_args.unet_name_or_path, "unet")
        self.unet = UNet2DConditionModel.from_pretrained(
            unet_location,
            low_cpu_mem_usage=False,
            device_map=None,
            # Only the default (base checkpoint) UNet is loaded with from_diffusers.
            from_diffusers=model_args.unet_name_or_path is None,
        )
        if self.config.lora_path is not None:
            self.unet = LoRAModel.from_pretrained(self.unet, self.config.lora_path)

        # ---- 7. schedulers ---------------------------------------------------
        scheduler_source = model_args.pretrained_model_name_or_path
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            scheduler_source, subfolder="scheduler"
        )
        self.eval_scheduler = DDIMScheduler.from_pretrained(
            model_args.pretrained_model_name_or_path, subfolder="scheduler"
        )
        self.eval_scheduler.set_timesteps(model_args.num_inference_steps)

        # ---- 8. ema ----------------------------------------------------------
        self.use_ema = model_args.use_ema
        if self.use_ema:
            self.trainable_component.append("ema_model")
            # `ema_model` is the same object as `self.unet` at this point.
            self.ema_model = self.unet
            self.model_ema = LitEma(self.ema_model)

        # ---- data processors -------------------------------------------------
        self.text_processor = TextProcessor()
        self.image_processor = ImageProcessor(
            image_size=self.config.resolution, random_crop=False
        )

        # ---- extra placeholder tokens for both tokenizer/encoder pairs --------
        self.config_additional_tokens(self.tokenizer, self.text_encoder)
        self.config_additional_tokens(self.tokenizer_2, self.text_encoder_2)

        # ---- vision embedding ------------------------------------------------
        # Only text_encoder_2 has its token embedding replaced; the equivalent
        # wiring for text_encoder is currently disabled.
        if model_args.use_vision_embedding:
            vision_embedding_2 = self.__config_vision_embedding(self.text_encoder_2)
            checkpoint_2 = self.config.vision_embedding_ckpt_2
            if checkpoint_2 is not None:
                state_dict = paddle.load(checkpoint_2)
                # Match the checkpoint's dtype before loading its weights.
                vision_embedding_2 = vision_embedding_2.astype(
                    state_dict["token_embedding.weight"].dtype
                )
                vision_embedding_2.set_state_dict(state_dict)
            self.text_encoder_2.text_model.token_embedding = vision_embedding_2

        # ---- memory efficient attention --------------------------------------
        wants_xformers = model_args.enable_xformers_memory_efficient_attention
        if wants_xformers and is_ppxformers_available():
            try:
                self.unet.enable_xformers_memory_efficient_attention()
            except Exception as e:
                logger.warning(
                    "Could not enable memory efficient attention. Make sure develop paddlepaddle is installed"
                    f" correctly and a GPU is available: {e}"
                )

        # ---- cast all networks to bfloat16 -----------------------------------
        for attr_name in ("unet", "text_encoder", "text_encoder_2", "vae"):
            setattr(self, attr_name, getattr(self, attr_name).to(dtype=paddle.bfloat16))

        # Everything below is needed for training only.
        if not model_args.is_train:
            return

        # object localization loss
        localization_flag = colored(self.config.use_glyph_localization, "yellow")
        logger.info(f"use object localization loss: {localization_flag}")
        if self.config.use_glyph_localization:
            self.cross_attention_scores = {}
            self.cross_glyph_localization_weight = (
                model_args.cross_glyph_localization_weight
            )
            self.unet = unet_store_cross_attention_scores(
                self.unet, self.cross_attention_scores
            )
            self.glyph_localization_loss_fn = F.mse_loss

        # ocr embed loss
        ocr_flag = colored(self.config.ocr_rec_loss_weight > 0, "yellow")
        logger.info(f"use ocr embed loss: {ocr_flag}")
        if self.config.ocr_rec_loss_weight > 0:
            self.ocr_predictor = OCRPredictor()
            self.ocr_rec_model = self.ocr_predictor.build_ocr_rec_model(
                ocr_rec_model_config=model_args.ocr_rec_config,
                ckpt=model_args.ocr_predictor_name_or_path,
            )
            self.ocr_predictor.ocr_rec_model = self.ocr_rec_model

        # decide which parameters are optimised
        self.__config_trainable_parameters()

    def __config_vision_embedding(self, text_encoder):
        """Build a ``VisionEmbedding`` around ``text_encoder``'s token embedding.

        The returned embedding is given a text-rendering callback taken from
        ``self.image_processor`` and its own ``OCRPredictor`` whose recognition
        model is built from the configured OCR settings and then frozen.

        Args:
            text_encoder: encoder whose ``text_model.token_embedding`` is wrapped.

        Returns:
            The configured ``VisionEmbedding``; the caller installs it.
        """
        wrapped_embedding = text_encoder.text_model.token_embedding
        vision_embedding = VisionEmbedding(
            token_embedding=wrapped_embedding,
            emb_type="ocr",
            glyph_channels=40 * 120,
        )
        # Callback used by the embedding to render text onto a whiteboard image.
        renderer = self.image_processor.render_text_to_whiteboard_center
        vision_embedding.process_visual_text = renderer

        vision_embedding.ocr_predictor = OCRPredictor()
        predictor = vision_embedding.ocr_predictor
        predictor.ocr_rec_model = predictor.build_ocr_rec_model(
            ocr_rec_model_config=self.config.ocr_rec_config,
            ckpt=self.config.ocr_predictor_name_or_path,
        )
        # The recognition model inside the embedding is never trained.
        freeze_params(vision_embedding.ocr_predictor.ocr_rec_model.parameters())
        return vision_embedding

    def __config_trainable_parameters(self):
        """Freeze/unfreeze sub-models and collect the parameters to optimise.

        Populates ``self.trainable_parameters`` (a list of parameter
        collections, one entry per group) and extends
        ``self.trainable_component`` with the names of the components involved.
        The VAE is always frozen.  The UNet is handled in one of three modes,
        chosen from ``self.config``: glyph layers only (``train_only_attn``),
        LoRA (``train_lora``) or full fine-tuning.
        """
        self.trainable_parameters = []

        # ---- vae: always frozen -----------------------------------------------
        freeze_params(self.vae.parameters())
        logger.info(
            f"VAE trainable parameters: {colored(self.vae.num_parameters(only_trainable=True), 'yellow')}"
        )

        # ---- text encoders ----------------------------------------------------
        # A non-default maximum length also makes the encoders trainable.
        if self.config.train_text_encoder or self.config.model_max_length != 77:
            self.trainable_parameters.append(self.text_encoder.parameters())
            self.trainable_parameters.append(self.text_encoder_2.parameters())
            self.trainable_component.append("text_encoder")
            self.trainable_component.append("text_encoder_2")
        else:
            freeze_params(self.text_encoder.parameters())
            freeze_params(self.text_encoder_2.parameters())
            # for random init ZH tokens
            self.trainable_component.append("text_encoder")
            self.trainable_component.append("text_encoder_2")
        if self.config.use_vision_embedding:
            # NOTE(review): when the encoders were frozen just above, these
            # embedding parameters presumably stay frozen too (they belong to
            # the encoders) -- confirm whether they should be unfrozen here.
            self.trainable_parameters.append(
                self.text_encoder.text_model.token_embedding.parameters()
            )
            self.trainable_parameters.append(
                self.text_encoder_2.text_model.token_embedding.parameters()
            )
            self.trainable_component.append("vision_embedding")
        logger.info(
            f"Text Encoder total parameters: {colored(self.text_encoder.num_parameters(only_trainable=False), 'yellow')}"
        )
        logger.info(
            f"Text Encoder 2 total parameters: {colored(self.text_encoder_2.num_parameters(only_trainable=False), 'yellow')}"
        )
        logger.info(
            f"Text Encoder trainable parameters: {colored(self.text_encoder.num_parameters(only_trainable=True), 'yellow')}"
        )
        logger.info(
            f"Text Encoder 2 trainable parameters: {colored(self.text_encoder_2.num_parameters(only_trainable=True), 'yellow')}"
        )

        # ---- unet -------------------------------------------------------------
        if self.config.train_only_attn:
            freeze_params(self.unet.parameters())
            unet_glyph_norm = [
                p for n, p in self.unet.named_parameters() if "glyph_norm" in n
            ]
            unet_glyph_cross_attn = [
                p for n, p in self.unet.named_parameters() if "glyph_attn" in n
            ]
            # Both groups must be unfrozen again after the blanket freeze above;
            # previously only the norm parameters were, so the glyph attention
            # parameters were handed to the optimiser while still frozen.
            for group in (unet_glyph_norm, unet_glyph_cross_attn):
                unfreeze_params(group)
                self.trainable_parameters.append(group)
            self.trainable_component.append("unet")
            logger.info("UNet: Train glyph cross attention parameters")
        elif self.config.train_lora:
            freeze_params(self.unet.parameters())  # only additional LoRA params
            self.add_lora_layers()  # must be called after the unet is frozen
            self.trainable_parameters.append(self.unet.parameters())
            self.trainable_component.append("unet")
            logger.info("UNet: Train LoRA parameters")
        else:
            self.trainable_parameters.append(self.unet.parameters())
            logger.info("UNet: Train all parameters")
            self.trainable_component.append("unet")
        logger.info(
            f"UNet trainable parameters: {colored(self.unet.num_parameters(only_trainable=True), 'yellow')}"
        )

        # ---- ocr recognition model --------------------------------------------
        if self.config.train_ocr_predictor and self.config.ocr_rec_loss_weight > 0:
            self.trainable_parameters.append(
                self.ocr_predictor.ocr_rec_model.parameters()
            )
            self.trainable_component.append("ocr_predictor")

        logger.info(f"trainable component: {self.trainable_component}")

    def get_trainable_parameters(self):
        """Return the parameter groups stored in ``self.trainable_parameters``.

        The attribute is created while the trainable parameters are configured,
        which ``__init__`` only does when the model is built for training.
        """
        parameter_groups = self.trainable_parameters
        return parameter_groups

    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt_batch,
        tags_dic,
        text_encoders,
        tokenizers,
        proportion_empty_prompts,
        is_train=True,
    ):
        """Encode a batch of captions with every (tokenizer, text encoder) pair.

        Args:
            prompt_batch: iterable of captions; each item is a string or a
                list/ndarray of alternative captions.
            tags_dic: tag information attached to each encoder's token
                embedding (as its ``tags_dic`` attribute) before encoding.
            text_encoders: encoders, paired positionally with ``tokenizers``.
            tokenizers: tokenizers, paired positionally with ``text_encoders``.
            proportion_empty_prompts: probability of replacing a caption with
                the empty string.
            is_train: when a caption offers alternatives, pick one at random
                if True, otherwise take the first.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)``: the penultimate hidden
            states of all encoders concatenated on the last axis, and the first
            output of the last encoder reshaped to ``[batch, -1]``.

        Raises:
            ValueError: if no (tokenizer, text encoder) pair is supplied.
        """
        captions = []
        for caption in prompt_batch:
            if random.random() < proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            # NOTE(review): captions of any other type are skipped, which
            # shrinks the batch -- confirm callers never pass them.

        prompt_embeds_list = []
        pooled_prompt_embeds = None
        bs_embed = None
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            text_inputs = tokenizer(
                captions,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            for key in text_inputs:
                text_inputs[key] = paddle.to_tensor(text_inputs[key])

            # Hand the tag information to the embedding of the encoder that is
            # about to run.  It used to be attached to `self.text_encoder`
            # only, so an embedding installed on another encoder never saw it.
            text_encoder.text_model.token_embedding.tags_dic = tags_dic

            encoder_output = text_encoder(
                text_inputs.input_ids,
                output_hidden_states=True,
            )
            # Only the pooled output of the final text encoder is kept; each
            # iteration overwrites the previous one.
            pooled_prompt_embeds = encoder_output[0]
            hidden_states = encoder_output.hidden_states[-2]
            bs_embed, seq_len, _ = hidden_states.shape
            prompt_embeds_list.append(hidden_states.reshape([bs_embed, seq_len, -1]))

        if not prompt_embeds_list:
            raise ValueError(
                "encode_prompt needs at least one (tokenizer, text_encoder) pair"
            )

        prompt_embeds = paddle.concat(prompt_embeds_list, axis=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.reshape([bs_embed, -1])
        return prompt_embeds, pooled_prompt_embeds

    def config_additional_tokens(self, tokenizer, text_encoder):
        """Add placeholder tokens to ``tokenizer`` and grow ``text_encoder`` to match.

        Two groups are added, in this order:

        1. ``<a>`` ... ``<Z>`` -- one token per ASCII letter, each initialised
           from the embedding of the letter it wraps.
        2. Every line of ``data/chinese_dict`` -- their embeddings are whatever
           ``resize_token_embeddings`` produces for new rows.

        Args:
            tokenizer: tokenizer that receives the new tokens.
            text_encoder: encoder whose token embedding is resized.

        Raises:
            ValueError: if the tokenizer did not add exactly as many tokens as
                requested (for example because some already existed).
        """
        import string

        # Read the dictionary up front with an explicit encoding and close the
        # file deterministically.
        # NOTE(review): lines are used verbatim, so each token keeps its
        # trailing newline -- confirm whether they should be stripped.
        with open('data/chinese_dict', 'r', encoding="utf-8") as dict_file:
            placeholder_tokens_zh = dict_file.readlines()

        letters = string.ascii_letters
        placeholder_tokens = ["<" + letter + ">" for letter in letters]
        num_vectors = len(placeholder_tokens)
        num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
        if num_added_tokens != num_vectors:
            raise ValueError(
                f"num_added_tokens ({num_added_tokens}) != num_vectors ({num_vectors})"
            )
        placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        text_encoder.resize_token_embeddings(len(tokenizer))

        # Initialise `<s>` from the embedding of `s`.
        with paddle.no_grad():
            token_embeds = text_encoder.get_input_embeddings()
            for letter, token, token_id in zip(
                letters, placeholder_tokens, placeholder_token_ids
            ):
                init_ids = tokenizer.encode(letter, add_special_tokens=False)["input_ids"]
                logger.debug(f'initialize token ({token}: {token_id}) using `{letter}`')
                # Average over the ids the letter is encoded to; for a
                # single-id encoding this is that id's embedding unchanged.
                token_embeds.weight[token_id] = token_embeds.weight[init_ids].mean(0)

        ### add randomly init ZH texts
        num_vectors_zh = len(placeholder_tokens_zh)
        num_added_tokens_zh = tokenizer.add_tokens(placeholder_tokens_zh)
        if num_added_tokens_zh != num_vectors_zh:
            raise ValueError(
                f"num_added_tokens_zh ({num_added_tokens_zh}) != num_vectors_zh ({num_vectors_zh})"
            )
        text_encoder.resize_token_embeddings(len(tokenizer))


    def forward(
        self,
        caption: List[str] = None,
        ocr_info: list = None,
        pixel_values: paddle.Tensor = None,
        condition: paddle.Tensor = None,
        encoder_attn_mask: paddle.Tensor = None,
        tags_dic: dict = None,
        img_mask: paddle.Tensor = None, # 0 for default, 1 for mask
        **kwargs,
    ):
        """Run one training step and return the loss together with a log dict.

        Args:
            caption: one prompt per sample (a string, or a list of
                alternatives); forwarded to ``encode_prompt``.
            ocr_info: per-sample OCR annotations handed to
                ``image_processor.crop_ocr_img``; only read when
                ``config.ocr_rec_loss_weight > 0``.
            pixel_values: image batch fed to the VAE encoder; documented by the
                original author as shape (bsz, 3, resolution, resolution).
            condition: not read by this method.
            encoder_attn_mask: passed as ``object_segmaps`` to the glyph
                localization loss; documented by the original author as shape
                (bsz, latent_size**2, max_seq_len).
            tags_dic: tag information forwarded to ``encode_prompt``; for the
                localization loss it is iterated per sample and the ``"span"``
                entry of every value is read.
            img_mask: mask multiplied with the latents and with the
                element-wise loss -- presumably latent-shaped, TODO confirm.
                Required when ``config.local_mse_loss_weight > 0``.
            **kwargs: ignored.

        Returns:
            ``(loss, loss_log)``: the total loss tensor and a dict of rounded
            scalar values for each loss term plus ``"step_loss"``.

        Raises:
            ValueError: on an unknown scheduler prediction type, or when the
                local MSE loss is enabled without ``img_mask``.
        """
        bsz = pixel_values.shape[0]
        self.train()

        with paddle.amp.auto_cast(enable=False):
            with paddle.no_grad():
                self.vae.eval()
                latents = self.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * self.vae.config.scaling_factor
                # Use Gaussian (offset) noise 80% of the time.  The draw is
                # uniform -- `randn` is a normal sample and does not give a
                # 0.8 probability.  Without a mask the masked-latent
                # alternative is impossible, so Gaussian noise is always used.
                if img_mask is None or np.random.rand() < 0.8:
                    noise = paddle.randn(
                        latents.shape
                    ) + self.config.noise_offset * paddle.randn(
                        shape=[latents.shape[0], latents.shape[1], 1, 1]
                    )  # offset noise
                else:
                    noise = latents * (1 - img_mask)
                timesteps = paddle.randint(
                    0, self.noise_scheduler.num_train_timesteps, (latents.shape[0],)
                ).astype("int64")
                noisy_latents = self.noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

            # Captions may be replaced by the empty string inside encode_prompt.
            text_encoders = [self.text_encoder, self.text_encoder_2]
            tokenizers = [self.tokenizer, self.tokenizer_2]
            prompt_emb, pooled_prompt_emb = self.encode_prompt(
                prompt_batch=caption,
                tags_dic=tags_dic,
                tokenizers=tokenizers,
                text_encoders=text_encoders,
                proportion_empty_prompts=self.config.proportion_empty_prompts,
            )

            added_cond_kwargs = {
                "text_embeds": pooled_prompt_emb,
                "time_ids": self._add_time_embs(
                    original_size=tuple(pixel_values.shape[-2:]),
                    crops_coords_top_left=(0, 0),
                    target_size=(1024, 1024),
                    bsz=bsz,
                ),
            }

            # Predict the noise residual; no attention masks are used in training.
            model_pred = self.unet(
                noisy_latents,
                timestep=timesteps,
                encoder_hidden_states=prompt_emb,
                attention_mask=None,
                encoder_attention_mask=None,
                glyph_mask_blocks=None,  # DEBUG: glyph diffusion add
                encoder_mask_blocks=None,  # DEBUG: glyph diffusion add
                added_cond_kwargs=added_cond_kwargs,
            ).sample

            prediction_type = self.noise_scheduler.config.prediction_type
            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(
                    f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
                )

            # ---- denoising loss -----------------------------------------------
            elementwise_mse = F.mse_loss(
                model_pred.cast("float32"), target.cast("float32"), reduction="none"
            )
            non_batch_axes = list(range(1, len(elementwise_mse.shape)))
            per_sample_mse = elementwise_mse.mean(non_batch_axes)
            if self.config.snr_gamma is not None:
                snr = self.compute_snr(timesteps)
                # Min-SNR weighting: min(snr, gamma) / snr for every sample.
                # The element-wise clip avoids the torch-style
                # `.min(axis=1)[0]`, which in paddle selects batch element 0
                # instead of the per-sample minimum.
                mse_loss_weights = paddle.clip(snr, max=self.config.snr_gamma) / snr
                per_sample_mse = per_sample_mse * mse_loss_weights
            mse_loss = per_sample_mse.mean()

            loss = mse_loss
            loss_log = {"mse": round(mse_loss.detach().item(), 4)}

            # ---- masked (text region) loss ------------------------------------
            if self.config.local_mse_loss_weight > 0:
                if img_mask is None:
                    raise ValueError(
                        "img_mask is required when local_mse_loss_weight > 0"
                    )
                text_local_mse_loss = (elementwise_mse * img_mask).mean(non_batch_axes)
                local_terms = []
                for i in range(bsz):
                    mask_i = img_mask[i]
                    # Mask element count over mask sum, capped at 5.
                    area_ratio = (
                        mask_i.shape[-1] * mask_i.shape[-2] * mask_i.shape[-3]
                        / mask_i.sum()
                    )
                    local_terms.append(
                        self.config.local_mse_loss_weight
                        * min(area_ratio, 5)
                        * text_local_mse_loss[i]
                    )
                tmp = sum(local_terms)
                loss = loss + tmp
                loss_log["latent"] = round(tmp.item(), 4)

            # ---- cross-attention localization loss ----------------------------
            if self.config.use_glyph_localization and self.cross_glyph_localization_weight > 0:
                attn_localization_loss = get_glyph_localization_loss(
                    self.cross_attention_scores,
                    object_segmaps=encoder_attn_mask,
                    visual_token_spans=[[t["span"] for t in tags.values()] for tags in tags_dic],
                    loss_fn=self.glyph_localization_loss_fn
                )
                # A plain float is logged as 0 instead of being rounded.
                if isinstance(attn_localization_loss, float):
                    loss_log["attn"] = 0
                else:
                    loss_log["attn"] = round(attn_localization_loss.detach().item(), 4)
                loss = loss + self.cross_glyph_localization_weight * attn_localization_loss

            # ---- ocr recognition loss -----------------------------------------
            if self.config.ocr_rec_loss_weight > 0:
                pred_x0 = self.noise_scheduler.get_x0_from_noise(
                    model_pred, timesteps, noisy_latents
                )
                pred_pixel_values = self.vae.decode(pred_x0).sample
                # NOTE(review): the ground-truth crops are collected but never
                # used afterwards -- confirm whether they can be dropped.
                crop_img_list_batch = []
                pred_crop_img_list_batch = []
                ocr_time_weights = []
                for i in range(bsz):
                    crop_img_list = self.image_processor.crop_ocr_img(pixel_values[i], ocr_info[i])
                    if crop_img_list is not None and crop_img_list != []:
                        crop_img_list_batch.extend(crop_img_list)
                    pred_crop_img_list = self.image_processor.crop_ocr_img(pred_pixel_values[i], ocr_info[i])
                    if pred_crop_img_list is not None and pred_crop_img_list != []:
                        pred_crop_img_list_batch.extend(pred_crop_img_list)
                        # Each crop is weighted by alphas_cumprod at its
                        # sample's timestep.
                        ocr_time_weights.extend(
                            [self.noise_scheduler.alphas_cumprod[timesteps[i]]] * len(pred_crop_img_list)
                        )

                ocr_rec_loss = 0
                if pred_crop_img_list_batch:
                    ocr_rec_loss = self.ocr_predictor(pred_crop_img_list_batch, weights=ocr_time_weights)
                loss_log["ocr"] = (
                    round(ocr_rec_loss.item(), 4) * self.config.ocr_rec_loss_weight
                    if isinstance(ocr_rec_loss, paddle.Tensor)
                    else 0
                )
                loss = loss + ocr_rec_loss * self.config.ocr_rec_loss_weight

            loss_log["step_loss"] = round(loss.detach().item(), 4)
            return loss, loss_log

    @paddle.no_grad()
    def decode_image(self, pixel_values=None, **kwargs):
        """Encode images with the VAE and decode the sampled latents back.

        A 3-D input is treated as a single image and given a batch axis; at
        most the first eight images of a batch are processed.  The model is
        switched to eval mode.  Extra keyword arguments are ignored.
        """
        self.eval()
        batch = pixel_values.unsqueeze(0) if pixel_values.ndim == 3 else pixel_values
        if batch.shape[0] > 8:
            batch = batch[:8]
        posterior = self.vae.encode(batch).latent_dist
        return self.vae.decode(posterior.sample()).sample

    @contextlib.contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into ``self.ema_model``.

        When EMA is enabled, the current weights are stored, the EMA weights
        are copied in for the duration of the ``with`` body, and the stored
        weights are restored afterwards, even if the body raises.  With EMA
        disabled this is a no-op context.  ``context`` is an optional label
        used in the printed switch/restore messages.
        """
        announce = context is not None
        if self.use_ema:
            ema = self.model_ema
            ema.store(self.ema_model.parameters())
            ema.copy_to(self.ema_model)
            if announce:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.ema_model.parameters())
                if announce:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self):
        """Update the EMA weights from ``self.ema_model``; no-op when EMA is off."""
        if not self.use_ema:
            return
        self.model_ema(self.ema_model)

    def _clear_cross_attention_scores(self):
        """Empty the captured cross-attention scores in place and run the GC.

        The dict object itself is kept, since the patched attention layers hold
        a reference to it; only its entries are removed.
        """
        if hasattr(self, "cross_attention_scores"):
            self.cross_attention_scores.clear()

        gc.collect()

    def _add_time_embs(
        self,
        original_size=(512, 512),
        crops_coords_top_left=(0, 0),
        target_size=(1024, 1024),
        dtype=paddle.float32,
        bsz=1,
    ):
        """Build the additional time-id tensor passed to the UNet.

        The ids are ``original_size + crops_coords_top_left + target_size``,
        repeated once per sample.

        Returns:
            Tensor of dtype ``dtype`` with ``bsz`` identical rows.

        Raises:
            ValueError: if the embedding width implied by the ids and the
                second text encoder's projection size differs from the input
                width of the UNet's ``add_embedding.linear_1``.
        """
        time_id_values = list(original_size + crops_coords_top_left + target_size)

        # Width the UNet will be fed vs. the width its first add-embedding
        # layer was built for.
        per_id_dim = self.unet.config.addition_time_embed_dim
        passed_add_embed_dim = (
            per_id_dim * len(time_id_values) + self.text_encoder_2.config.projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.weight.shape[0]
        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        return paddle.to_tensor(data=[time_id_values] * bsz, dtype=dtype)

    def add_lora_layers(self):
        """Wrap ``self.unet`` in a ``LoRAModel`` and leave only LoRA weights trainable.

        The LoRA rank comes from ``self.config.lora_rank``.  After wrapping,
        the LoRA parameters are re-initialised, everything except them is
        marked non-trainable, and the trainable-parameter summary is printed.
        """
        # Regex patterns selecting the UNet sub-layers that receive adapters.
        # (".*time_emb_proj.*" is intentionally not part of the list.)
        lora_target_patterns = [
            ".*to_q.*",
            ".*to_k.*",
            ".*to_v.*",
            ".*to_out.0.*",
            ".*proj_in.*",
            ".*proj_out.*",
            ".*ff.net.0.proj.*",
            ".*ff.net.2.*",
            ".*conv1.*",
            ".*conv2.*",
            ".*conv_shortcut.*",
            ".*downsamplers.0.conv.*",
            ".*upsamplers.0.conv.*",
        ]
        lora_config = LoRAConfig(
            r=self.config.lora_rank,
            target_modules=lora_target_patterns,
            merge_weights=False,  # keep the adapter weights separate
        )

        self.unet.config.tensor_parallel_degree = 1
        lora_unet = LoRAModel(self.unet, lora_config)
        self.unet = lora_unet
        self.reset_lora_parameters(self.unet, init_lora_weights=True)
        self.unet.mark_only_lora_as_trainable()
        self.unet.print_trainable_parameters()

    @paddle.no_grad()
    def reset_lora_parameters(self, unet, init_lora_weights=True):
        """(Re-)initialise the LoRA matrices of every LoRA layer inside ``unet``.

        Args:
            unet: model whose sub-layers (including itself) are scanned for
                layers whose class name is ``LoRALinear`` or ``LoRAConv2D``
                (case-insensitive).
            init_lora_weights: ``False`` leaves everything untouched; ``True``
                gives ``lora_A`` a Kaiming-uniform init; the string
                ``"gaussian"`` (any case) gives it a normal init with
                ``std = 1 / rank``.  ``lora_B`` is zeroed in both cases.

        Raises:
            ValueError: if a LoRA layer is found and ``init_lora_weights`` is
                none of the supported values.
        """
        import math

        if init_lora_weights is False:
            return

        use_kaiming = init_lora_weights is True
        # Only strings can name an init scheme; any other value falls through
        # to the ValueError below instead of failing on `.lower()`.
        use_gaussian = (
            isinstance(init_lora_weights, str)
            and init_lora_weights.lower() == "gaussian"
        )

        for name, module in unet.named_sublayers(include_self=True):
            module_name = module.__class__.__name__.lower()
            if module_name not in ("loralinear", "loraconv2d"):
                continue
            if use_kaiming:
                # initialize A the same way as the default for nn.Linear and B to zero
                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
                nn.init.kaiming_uniform_(
                    module.lora_A, a=math.sqrt(5), reverse=module_name == "loralinear"
                )
                logger.info(f"Initialized {name}'s LoRA parameters with Kaiming uniform init!")
            elif use_gaussian:
                # This model has no `r` attribute of its own: take the rank
                # from the LoRA layer when it exposes one, otherwise from the
                # configured rank.
                rank = getattr(module, "r", None) or self.config.lora_rank
                nn.init.normal_(module.lora_A, std=1 / rank)
                logger.info(f"Initialized {name}'s LoRA parameters with Gaussian init!")
            else:
                raise ValueError(f"Unknown initialization {init_lora_weights}!")
            nn.init.zeros_(module.lora_B)

    def compute_snr(self, timesteps):
        """
        Return the signal-to-noise ratio ``(alpha / sigma) ** 2`` at ``timesteps``.

        ``alpha`` and ``sigma`` are the square roots of ``alphas_cumprod`` and of
        ``1 - alphas_cumprod`` from the noise scheduler, looked up per timestep.
        Follows https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod

        def _at_timesteps(per_step_values):
            # Look the values up per timestep as float32, then append trailing
            # axes until the rank matches that of `timesteps`.
            # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
            picked = per_step_values[timesteps].astype(np.float32)
            while len(picked.shape) < len(timesteps.shape):
                picked = picked[..., None]
            return picked

        alpha = _at_timesteps(alphas_cumprod**0.5)
        sigma = _at_timesteps((1.0 - alphas_cumprod) ** 0.5).expand(timesteps.shape)

        return (alpha / sigma) ** 2

    def __custom_text_encoder_max_len(self, text_encoder: CLIPTextModel, length=77):
        """Resize ``text_encoder``'s positional embedding to ``length`` positions.

        Nothing is changed when ``length`` already equals
        ``text_encoder.config.max_position_embeddings``.  Otherwise the
        existing positional weights are tiled (and truncated) to ``length``
        rows, installed as a new embedding, and the ``position_ids`` buffer of
        the text model is re-registered for the new length.

        Args:
            text_encoder: encoder to modify in place.
            length: target number of positions.  It is now used for the resize
                itself, not only for the early-exit check.

        Returns:
            The same ``text_encoder`` object.
        """
        base_length = text_encoder.config.max_position_embeddings
        if length == base_length:
            return text_encoder

        base_weight = text_encoder.text_model.positional_embedding.weight
        full_copies, remainder = divmod(length, base_length)
        pieces = [base_weight] * full_copies
        if remainder:
            # Partial copy for the leftover positions.
            pieces.append(base_weight[:remainder])

        resized_embedding = nn.Embedding(length, text_encoder.config.hidden_size)
        resized_embedding.weight = nn.Parameter(paddle.concat(pieces))
        text_encoder.text_model.positional_embedding = resized_embedding
        text_encoder.text_model.register_buffer(
            "position_ids",
            paddle.arange(length).expand((1, -1)),
        )
        # NOTE(review): `text_encoder.config.max_position_embeddings` is left at
        # its old value -- confirm nothing downstream relies on it.
        return text_encoder

    def set_recompute(self, value=False):
        """Toggle recompute flags on every sub-layer and print each change.

        Layers exposing ``gradient_checkpointing`` get it set to ``value``.
        Layers exposing ``enable_recompute`` always get ``False``; ``value`` is
        not applied to them.
        """

        def _apply_flags(layer):
            if hasattr(layer, "enable_recompute"):
                # Recompute stays disabled here regardless of `value`.
                layer.enable_recompute = False
                print("Set", layer.__class__, "recompute", layer.enable_recompute)
            if hasattr(layer, "gradient_checkpointing"):
                layer.gradient_checkpointing = value
                print("Set", layer.__class__, "recompute", layer.gradient_checkpointing)

        self.apply(_apply_flags)


def unet_store_cross_attention_scores(
    unet: UNet2DConditionModel, attention_scores: Dict[str, paddle.Tensor]
):
    """Patch selected cross-attention layers so they record their attention maps.

    Every ``Attention`` sublayer whose name contains ``"attn2"`` and at least
    one entry of ``UNET_LAYER_NAMES`` gets its ``get_attention_scores``
    replaced by a wrapper that calls the original method and stores the
    result in ``attention_scores`` under the sublayer's name.

    The original method is kept on the module as ``old_get_attention_scores``.
    It is saved only once, so calling this function again on an already
    patched UNet rebinds the wrapper (and the target dict) without making the
    wrapper call itself.

    Args:
        unet: model whose sublayers are patched in place.
        attention_scores: mutable mapping filled as
            ``{sublayer_name: attention_probs}`` on each forward pass.

    Returns:
        The same ``unet`` object.
    """

    def make_new_get_attention_scores_fn(name):
        def new_get_attention_scores(module, query, key, attention_mask=None):
            """store attn scores"""
            # The incoming attention_mask is not forwarded; None is passed
            # to the original method instead.
            attention_probs: paddle.Tensor = module.old_get_attention_scores(
                query, key, None
            )
            attention_scores[name] = attention_probs
            return attention_probs

        return new_get_attention_scores

    for name, module in unet.named_sublayers():
        if not (isinstance(module, Attention) and "attn2" in name):
            continue
        if not any(layer in name for layer in UNET_LAYER_NAMES):
            continue
        # NOTE(review): presumably the xformers processor bypasses
        # get_attention_scores, hence the switch to the plain processor --
        # confirm against ppdiffusers.
        if isinstance(module.processor, XFormersAttnProcessor):
            module.set_processor(AttnProcessor())
        # Save the original only on first patch; otherwise a second call
        # would store the previous wrapper here and recurse forever.
        if not hasattr(module, "old_get_attention_scores"):
            module.old_get_attention_scores = module.get_attention_scores
        module.get_attention_scores = types.MethodType(
            make_new_get_attention_scores_fn(name), module
        )
    return unet


def get_glyph_localization_loss_for_one_layer(
    cross_attention_scores,  # rank 4: [bsz, n_head, size*size, n_text_tokens]
    object_segmaps,  # rank 3: [bsz, size'*size', n_text_tokens]
    visual_token_spans,  # per sample: iterable of (start, end) token spans
    loss_fn,  # callable such as F.mse_loss; must accept ``reduction``
):
    """Compute the glyph localization loss for one cross-attention layer.

    The segmentation maps are soft-maxed over their last axis, repeated for
    every attention head and bilinearly resized so their spatial axis matches
    the attention map. Only the text-token columns covered by
    ``visual_token_spans`` are kept (concatenated across the batch along the
    last axis) before ``loss_fn`` compares attention against segmentation.

    Args:
        cross_attention_scores: rank-4 attention tensor (unpacked as
            batch, heads, latent positions, text tokens).
        object_segmaps: rank-3 segmentation tensor (unpacked as batch,
            positions, text tokens).
        visual_token_spans: indexable by sample; each entry is iterated as
            ``(start, end)`` pairs. A pair containing ``-1`` is skipped.
        loss_fn: called as ``loss_fn(pred, target, reduction="mean")``.

    Returns:
        The value returned by ``loss_fn``, or the integer ``0`` when no
        valid span exists in the whole batch.
    """
    _, num_heads, num_noise_latents, num_text_tokens = cross_attention_scores.shape
    b, _, _ = object_segmaps.shape

    object_segmaps = F.softmax(object_segmaps, axis=-1)
    cross_attention_scores = cross_attention_scores.reshape(
        [b, num_heads, num_noise_latents, num_text_tokens]
    )
    # Add a head axis and repeat the maps for every attention head.
    object_segmaps = object_segmaps.unsqueeze(1).tile([1, num_heads, 1, 1])
    # Resize the spatial axis to the attention map's resolution; the token
    # axis keeps its current length.
    object_segmaps = F.interpolate(
        object_segmaps,
        size=(num_noise_latents, object_segmaps.shape[-1]),
        mode="bilinear",
    )

    def retrieve_visual_tokens(feature_map, visual_token_spans):
        """Slice out the span columns of every sample and join them.

        feature_map: [bsz, num_heads, size*size, n_text_tokens]
        visual_token_spans: per sample, an iterable of (start, end) pairs

        Returns a [num_heads, size*size, total_span_len] tensor, or None
        when every span is marked invalid with -1.
        """
        selected = []
        for sample_idx in range(feature_map.shape[0]):
            for span in visual_token_spans[sample_idx]:
                if span[0] == -1 or span[1] == -1:
                    continue
                selected.append(feature_map[sample_idx, :, :, span[0] : span[1]])
        if len(selected) == 0:
            return None
        return paddle.concat(selected, axis=-1)

    # Keep only the visual-token columns; the whole batch is concatenated
    # together so the loss is computed in a single call.
    cross_attention_scores = retrieve_visual_tokens(
        cross_attention_scores, visual_token_spans=visual_token_spans
    )
    object_segmaps = retrieve_visual_tokens(
        object_segmaps, visual_token_spans=visual_token_spans
    )
    # no visual tokens appear in prompt, no meaning to calculate loss
    if cross_attention_scores is None or object_segmaps is None:
        print("no visual tokens appear in prompt, no meaning to calculate loss")
        return 0

    # ``np.float`` was only an alias of the builtin ``float`` (float64) and
    # was removed in NumPy 1.24, so the target dtype is spelled explicitly.
    loss = loss_fn(
        cross_attention_scores.astype("float64"),
        object_segmaps.astype("float64"),
        reduction="mean",
    )
    return loss


def get_glyph_localization_loss(
    cross_attention_scores,
    object_segmaps,
    visual_token_spans,
    loss_fn,  # callable such as F.mse_loss (or a BalancedL1Loss-style loss)
):
    """Average the per-layer glyph localization loss over all stored layers.

    Args:
        cross_attention_scores: mapping of layer name to that layer's
            attention tensor; only the values are used.
        object_segmaps: forwarded unchanged to the per-layer loss.
        visual_token_spans: forwarded unchanged to the per-layer loss.
        loss_fn: forwarded unchanged to the per-layer loss.

    Returns:
        The sum of the per-layer losses divided by the number of layers, or
        the integer ``0`` when the mapping is empty (previously this raised
        ``ZeroDivisionError``).
    """
    num_layers = len(cross_attention_scores)
    if num_layers == 0:
        # Nothing was recorded, so there is nothing to average.
        return 0

    loss = 0
    for layer_scores in cross_attention_scores.values():
        loss += get_glyph_localization_loss_for_one_layer(
            layer_scores, object_segmaps, visual_token_spans, loss_fn
        )
    return loss / num_layers
