from .logger import setup_logger
from .txt2img_dataset import base64_to_image, parse_line, prepare_input, prepare_test_input
from .model import GlyphDiffusionModel
from paddlenlp.trainer.integrations import (
    rewrite_logs,
)
import paddle
import re
import paddle.amp.auto_cast as autocast
import numpy as np
import importlib
import json
import sys, contextlib
from PIL import Image
from paddle.vision.transforms import ToTensor
import io
import base64
import logging
from ppdiffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
from ppdiffusers.pipelines.glyph_diffusion_xl import GlyphDiffusionXLPipeline
from paddlenlp.trainer.integrations import (
    INTEGRATION_TO_CALLBACK,
    VisualDLCallback,
    rewrite_logs,
)

from paddlenlp.utils.log import logger


class VisualDLCallbackWithImage(VisualDLCallback):
    """VisualDL callback that also renders and logs sample images.

    On top of the scalar logging done through the VisualDL writer, every
    ``args.image_logging_steps`` optimisation steps this callback builds a
    ``GlyphDiffusionXLPipeline`` from the model being trained, generates
    images for the prompts listed in ``args.log_file_list``, saves them as
    PNG files under ``args.logging_dir`` and pushes them to VisualDL.
    """

    # Quality tags appended to every caption before generation.
    _A_PROMPT = "best quality, extremely detailed,4k, HD, supper legible text,  clear text edges,  clear strokes, neat writing, no watermarks"
    # Negative prompt shared by every generation call.
    _N_PROMPT = "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture"

    def autocast_smart_context_manager(self, args):
        """Return the context manager under which images are generated.

        Args:
            args: Training arguments; ``fp16``, ``bf16`` and
                ``fp16_opt_level`` are read.

        Returns:
            An ``auto_cast`` context when ``args.fp16`` or ``args.bf16`` is
            set, otherwise a no-op context manager.
        """
        if args.fp16 or args.bf16:
            amp_dtype = "float16" if args.fp16 else "bfloat16"
            ctx_manager = autocast(
                True,
                custom_black_list=[
                    "reduce_sum",
                    "c_softmax_with_cross_entropy",
                ],
                level=args.fp16_opt_level,
                dtype=amp_dtype,
            )
        else:
            ctx_manager = (
                contextlib.nullcontext()
                if sys.version_info >= (3, 7)
                else contextlib.suppress()
            )

        return ctx_manager

    @staticmethod
    def _is_image_logging_step(args, state):
        """Tell whether the current global step is an image-logging step."""
        return (
            args.image_logging_steps > 0
            and state.global_step % args.image_logging_steps == 0
        )

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Run the model's batch-end hook and request a log on image steps.

        Args:
            args: Training arguments (``image_logging_steps`` is read).
            state: Trainer state (``global_step`` is read).
            control: Trainer control object; ``should_log`` is set to ``True``
                in place on image-logging steps so that ``on_log`` fires.
            model: Model being trained; its ``on_train_batch_end`` method is
                called when it defines one.
        """
        if hasattr(model, "on_train_batch_end"):
            model.on_train_batch_end()
        if self._is_image_logging_step(args, state):
            control.should_log = True

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write scalar logs and, on image-logging steps, sample images.

        Only the world-process-zero rank does any work.

        Args:
            args: Training arguments.
            state: Trainer state (``global_step``, ``is_world_process_zero``).
            control: Trainer control object (unused here).
            logs: Mapping of scalar names to values; ``None`` is treated as
                an empty mapping.
            **kwargs: ``model`` (the model being trained) and, optionally,
                ``inputs`` (the current batch) are read.
        """
        if not state.is_world_process_zero:
            return

        model: GlyphDiffusionModel = kwargs.get("model", None)
        image_logs = {}
        if model is not None and self._is_image_logging_step(args, state):
            image_logs = self._collect_image_logs(
                args, state, model, kwargs.get("inputs", None)
            )

        if self.vdl_writer is None:
            self._init_summary_writer(args)
        if self.vdl_writer is None:
            return

        # ``logs`` defaults to None in the signature; rewrite_logs needs a mapping.
        scalar_logs = rewrite_logs(logs if logs is not None else {})
        for k, v in scalar_logs.items():
            if isinstance(v, (int, float)):
                self.vdl_writer.add_scalar(k, v, state.global_step)
            else:
                logger.warning(
                    "Trainer is attempting to log a value of "
                    f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                    "This invocation of VisualDL's writer.add_scalar() "
                    "is incorrect so we dropped this attribute."
                )
        # log images
        for tag, images in image_logs.items():
            self.vdl_writer.add_image(tag, images, state.global_step, dataformats="NHWC")
        self.vdl_writer.flush()

    def _collect_image_logs(self, args, state, model, inputs):
        """Generate the sample images for the current step.

        Every file listed in ``args.log_file_list`` except the last one is
        parsed with ``parse_line``/``prepare_input`` and logged next to its
        reference image (tag prefix ``train`` for the first file, ``test``
        for the others). The last listed file is read as one raw caption per
        line and logged under the ``custom`` prefix.

        Returns:
            Dict mapping a VisualDL tag to an image array laid out as NHWC.
        """
        image_logs = {}
        step = state.global_step

        if inputs is not None and "pixel_values" in inputs:
            # Presumably pixel_values are normalised to [-1, 1]; this maps
            # them back to [0, 255] and moves channels last for VisualDL.
            image_logs["train_data"] = (
                (inputs["pixel_values"] * 127.5 + 127.5)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
        else:
            logger.warning(
                "no 'inputs' with 'pixel_values' passed to on_log; "
                "skipping the train_data image"
            )

        pipe = GlyphDiffusionXLPipeline(
            vae=model.vae,
            text_encoder=model.text_encoder,
            text_encoder_2=model.text_encoder_2,
            tokenizer=model.tokenizer,
            tokenizer_2=model.tokenizer_2,
            unet=model.unet,
            scheduler=model.eval_scheduler,
        )
        with self.autocast_smart_context_manager(args):
            logger.info("start logging images")
            # Blank entries are skipped so a stray empty line neither triggers
            # open("") nor steals the "custom" slot from the real last file.
            with open(args.log_file_list) as list_file:
                files = [entry.strip() for entry in list_file if entry.strip()]

            last_index = len(files) - 1
            for file_index, file in enumerate(files):
                is_custom = file_index == last_index
                if is_custom:
                    prefix = "custom"
                elif file_index == 0:
                    prefix = "train"
                else:
                    prefix = "test"

                sample_lines = self._read_head_lines(file, args.max_logging_samples)
                for idx, line in enumerate(sample_lines):
                    if is_custom:
                        data_dic = prepare_test_input(
                            caption=line.strip(),
                            text_processor=model.text_processor,
                            image_processor=model.image_processor,
                            tokenization_level=args.tokenization_level,
                        )
                        separator = ". "
                        ori_img = None
                    else:
                        data = parse_line(line, filename=file)
                        data_dic = prepare_input(
                            data=data,
                            text_processor=model.text_processor,
                            image_processor=model.image_processor,
                            tokenizer=model.torch_tokenizer,
                            ocr_model_path=args.ocr_model_path,
                            font_path=args.font_path,
                            tokenization_level=args.tokenization_level,
                        )
                        separator = " , "
                        ori_img = data_dic["pixel_values"]

                    prompt = data_dic["caption"]
                    logger.info(f"logging image for prompt: {prompt}")

                    img_add = self._generate(
                        pipe, args, prompt + separator + self._A_PROMPT, self._N_PROMPT
                    )
                    generated = (
                        (img_add * 255).permute(1, 2, 0).cpu().astype(paddle.float32)
                    )
                    if ori_img is None:
                        canvas = generated
                    else:
                        # Cast to float32 like the generated image so both
                        # concat operands share one dtype under AMP.
                        reference = (
                            (ori_img * 127.5 + 127.5)
                            .cpu()
                            .permute(1, 2, 0)
                            .astype(paddle.float32)
                        )
                        # axis=1 of an HWC tensor: reference left, sample right.
                        canvas = paddle.concat([reference, generated], axis=1)

                    tag = prefix + "/" + prompt
                    image_logs[tag] = self._to_uint8_batch(canvas)
                    # NOTE(review): with several "test" files the same file
                    # name is reused and earlier PNGs are overwritten.
                    Image.fromarray(image_logs[tag].squeeze(0)).save(
                        f"{args.logging_dir}/{prefix}_log_{step}_step_{idx}.png"
                    )
            logger.info("finish logging images")
        return image_logs

    @staticmethod
    def _read_head_lines(path, limit):
        """Return at most ``limit`` leading lines of ``path``, closing the file."""
        head = []
        with open(path) as handle:
            for idx, line in enumerate(handle):
                if idx >= limit:
                    break
                head.append(line)
        return head

    @staticmethod
    def _generate(pipe, args, prompt, negative_prompt):
        """Run ``pipe`` once with a generator seeded from ``args.seed``."""
        # Re-seeding for every sample keeps generations deterministic.
        generator = paddle.Generator(device="gpu").manual_seed(args.seed)
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=args.resolution,
            width=args.resolution,
            guidance_scale=7.5,
            generator=generator,
            output_type="pd",
        )["images"].squeeze(0)

    @staticmethod
    def _to_uint8_batch(image_hwc):
        """Turn an HWC tensor in 0..255 scale into a (1, H, W, C) uint8 array."""
        pixels = image_hwc.astype(paddle.float32).unsqueeze(0).clip(0, 255).numpy()
        # Round before the integer cast; rounding after it is a no-op and
        # silently truncates every pixel value.
        return pixels.round().astype(np.uint8)
