#!/usr/bin/env python
# coding=utf-8
import paddle
import numpy as np
import json
from tqdm import tqdm
import os
import re
from PIL import Image
from paddleocr import PaddleOCR
from ppdiffusers.pipelines.glyph_diffusion_xl import (
    GlyphDiffusionXLPipeline,
    GlyphDiffusionXLAttendAndExcitePipeline
)
from glob import glob
from paddlenlp.trainer import PdArgumentParser
import paddle.amp.auto_cast as autocast
import contextlib
import sys
from ldm.ldm_args import ModelArguments
from ldm.txt2img_dataset import prepare_input, prepare_test_input
from ldm.model import GlyphDiffusionModel

from dataclasses import dataclass, field

def getKey(dic, value):
    """Return the keys of ``dic`` whose mapped value contains ``value``, sorted ascending.

    Membership is tested with ``value in dic[key]``, so the mapped values may be
    any container supporting ``in`` (list, set, str, ...). Each key appears at
    most once in the result.
    """
    matching_keys = {k for k in dic if value in dic[k]}
    return sorted(matching_keys)

@dataclass
class EvalArguments:
    """Evaluation options, parsed from the command line together with ``ModelArguments``."""

    eval_path: str = field(
        default=None,
        metadata={"help": "Path to the prompt file (UTF-8 text, one prompt per line). Required."},
    )
    output_path: str = field(
        default="./output", metadata={"help": "Path to the output directory."}
    )
    seed: int = field(
        default=0, metadata={"help": "Random seed for deterministic generation."}
    )
    resolution: int = field(
        default=1024, metadata={"help": "Resolution of the generated image."}
    )
    # The limit counts prompts (lines of eval_path); each prompt can produce
    # num_images_per_prompt images.
    max_eval_num: int = field(
        default=100, metadata={"help": "Maximum number of prompts to evaluate."}
    )
    mix_precision: str = field(
        default=None,
        metadata={"help": "Mixed-precision mode: 'fp16' or 'bf16'. Leave unset to run in float32."},
    )
    fp16_opt_level: str = field(
        default="O2",
        metadata={"help": "AMP level passed to paddle.amp.auto_cast when mix_precision is set (e.g. 'O1', 'O2')."},
    )
    num_images_per_prompt: int = field(
        default=1,
        metadata={"help": "Number of images to generate per prompt."},
    )
    # NOTE(review): not read anywhere in this file — confirm whether it is still needed.
    recompute: bool = field(
        default=True,
    )


def autocast_smart_context_manager(args):
    """Return the autocast context used while building the pipeline and generating.

    Args:
        args: object exposing ``mix_precision`` (falsy, ``"fp16"`` or ``"bf16"``)
            and ``fp16_opt_level`` (AMP level forwarded to ``paddle.amp.auto_cast``).

    Returns:
        A ``paddle.amp.auto_cast`` context for ``"fp16"`` (float16) or ``"bf16"``
        (bfloat16); a no-op ``contextlib.nullcontext()`` when ``mix_precision``
        is unset/empty.

    Raises:
        ValueError: if ``mix_precision`` is set to any other value. Without this
            check, such a value would silently select bfloat16 autocast.
    """
    if not args.mix_precision:
        # nullcontext exists on every Python that has dataclasses (3.7+), which
        # this file already imports, so no older-version fallback is needed.
        return contextlib.nullcontext()

    amp_dtypes = {"fp16": "float16", "bf16": "bfloat16"}
    try:
        amp_dtype = amp_dtypes[args.mix_precision]
    except KeyError:
        raise ValueError(
            f"mix_precision must be 'fp16', 'bf16' or unset, got {args.mix_precision!r}"
        ) from None

    return autocast(
        True,
        # Ops on paddle AMP's custom black list are executed in float32.
        custom_black_list=[
            "reduce_sum",
            "c_softmax_with_cross_entropy",
        ],
        level=args.fp16_opt_level,
        dtype=amp_dtype,
    )


# Style/quality suffix appended to every caption, and the shared negative prompt.
# Kept verbatim: editing the wording changes the text fed to the model.
POSITIVE_PROMPT_SUFFIX = "typography, illustration, vibrant, photo, 3d render, typography, cinematic, best quality, extremely detailed,4k, HD, supper legible text,  clear text edges,  clear strokes, neat writing, no watermarks"
NEGATIVE_PROMPT = "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture"


def resolve_weight_dtype(mix_precision):
    """Map ``--mix_precision`` to the paddle dtype the pipeline weights are cast to.

    ``"bf16"`` -> ``paddle.bfloat16``, ``"fp16"`` -> ``paddle.float16``,
    anything else (including ``None``) -> ``paddle.float32``.
    """
    if mix_precision == "bf16":
        return paddle.bfloat16
    if mix_precision == "fp16":
        return paddle.float16
    return paddle.float32


def shard_prompts(lines, max_eval_num, rank, world_size):
    """Select this rank's prompts as ``(global_index, line)`` pairs.

    At most ``max_eval_num`` prompts are taken from the front of ``lines`` in
    total across all ranks (a non-positive limit selects nothing). They are then
    dealt round-robin, so rank ``r`` gets indices ``r, r + world_size, ...``.
    ``global_index`` is the position in ``lines``. It is unique across ranks,
    so output file names do not collide.
    """
    capped = lines[: max(max_eval_num, 0)]
    return list(enumerate(capped))[rank::world_size]


def main():
    """Generate images for the prompts in ``--eval_path`` and save them as PNG files.

    Prompts are sharded across distributed ranks. Each image is written to
    ``<output_path>/images/<prompt_index>_<image_index>_<seed>.png``.

    Raises:
        ValueError: if ``--eval_path`` is not given.
    """
    parser = PdArgumentParser((ModelArguments, EvalArguments))
    model_args, eval_args = parser.parse_args_into_dataclasses()
    # Fail before the (expensive) model load rather than on open(None) afterwards.
    if not eval_args.eval_path:
        raise ValueError("--eval_path is required: a UTF-8 text file with one prompt per line.")
    model_args.resolution = eval_args.resolution
    model = GlyphDiffusionModel(model_args)

    # Creating the nested "images" directory also creates output_path itself.
    image_dir = os.path.join(eval_args.output_path, "images")
    os.makedirs(image_dir, exist_ok=True)

    with open(eval_args.eval_path, "r", encoding="utf-8") as f:
        all_prompts = f.readlines()
    n_gpus = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    my_prompts = shard_prompts(all_prompts, eval_args.max_eval_num, local_rank, n_gpus)

    with autocast_smart_context_manager(eval_args):
        weight_dtype = resolve_weight_dtype(eval_args.mix_precision)
        pipe = GlyphDiffusionXLPipeline(
            vae=model.vae,
            text_encoder=model.text_encoder,
            text_encoder_2=model.text_encoder_2,
            tokenizer=model.tokenizer,
            tokenizer_2=model.tokenizer_2,
            unet=model.unet,
            scheduler=model.eval_scheduler,
        )
        pipe.unet = pipe.unet.to(dtype=weight_dtype)
        pipe.vae = pipe.vae.to(dtype=weight_dtype)
        pipe.text_encoder = pipe.text_encoder.to(dtype=weight_dtype)
        pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=weight_dtype)

        for prompt_index, raw_prompt in tqdm(my_prompts):
            prompt = raw_prompt.strip()
            data_dic = prepare_test_input(
                text_processor=model.text_processor,
                image_processor=model.image_processor,
                caption=prompt,
                tokenization_level=model_args.tokenization_level,
            )
            caption = data_dic["caption"]
            # A fresh generator per prompt: each prompt's random stream starts
            # from the same --seed regardless of what ran before it.
            generator = paddle.Generator(device="gpu").manual_seed(eval_args.seed)
            gen_pil_imgs = pipe(
                prompt=caption + " , " + POSITIVE_PROMPT_SUFFIX,
                negative_prompt=NEGATIVE_PROMPT,
                height=eval_args.resolution,
                width=eval_args.resolution,
                num_images_per_prompt=eval_args.num_images_per_prompt,
                guidance_scale=7.5,
                generator=generator,
                output_type="pil",
            )["images"]
            for j, img in enumerate(gen_pil_imgs):
                img.save(os.path.join(image_dir, f"{prompt_index}_{j}_{eval_args.seed}.png"))
    print("Done!")


if __name__ == "__main__":
    main()
