import base64
import gzip
import io
import random
from typing import List, Dict, Optional, Union
import numpy as np
from omegaconf import OmegaConf
import paddle
import paddle.distributed as dist
from paddle.io import IterableDataset, get_worker_info
from PIL import Image, ImageDraw
from io import BytesIO
from shapely.geometry import Polygon
from paddle.vision import transforms
import time
import datetime
import traceback
import re
from termcolor import colored

from paddlenlp.utils.log import logger

Image.MAX_IMAGE_PIXELS = 2300000000
import json
from paddle.vision import transforms
import math
from PIL import ImageFont, Image, ImageDraw
from paddlenlp.transformers import CLIPTokenizer, CLIPTextModel
import importlib
import cv2
import unicodedata

from paddlenlp.utils.log import logger


# Prompt templates with a single "{}" placeholder for the text/word to embed.
# NOTE(review): the code that formats these templates is not visible in this
# chunk; presumably callers use str.format with the target text.
random_word_templates = [
    "An image with text {} written on it.",
    "word {} in color white on a black background.",
    "Text {}",
]

# "Photo of text/word" prompt templates, some quoting the placeholder.
# NOTE(review): the "ch_" prefix presumably marks templates used for the
# Chinese-text data — confirm with the callers.
ch_word_templates = [
    "A photo of text {}",
    "A photo of word {}",
    "A photo of {}",
    'A photo of text "{}"',
    "A photo of text {} written on it.",
    'A photo of word "{}"',
]


def default_dump(obj):
    """Turn numpy values into plain Python values; return anything else unchanged.

    numpy arrays become (nested) lists and numpy integer/floating/bool scalars
    become the matching builtin via ``.item()``. Non-numpy objects are passed
    through untouched, so when this is used as ``json.dumps(default=...)`` an
    unsupported object will still make serialisation fail.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    return obj


def base64_to_image(base64_str):
    """Decode a base64-encoded image payload into a PIL image in RGB mode.

    Images that are already RGB are returned as opened; any other mode is
    converted with ``Image.convert("RGB")``.
    """
    decoded = Image.open(BytesIO(base64.b64decode(base64_str)))
    return decoded if decoded.mode == "RGB" else decoded.convert("RGB")


def image_to_base64(img, image_format="PNG"):
    """Encode an image as a base64 text string.

    Args:
        img: a PIL image (anything with a PIL-style ``save(fp, format=...)``).
        image_format: container format handed to ``img.save``. PNG is the
            default because it is lossless and, unlike JPEG, keeps alpha.

    Returns:
        The base64 encoding of the serialised image bytes, as a UTF-8 ``str``.
    """
    output_buffer = BytesIO()
    # The image has to be written into the buffer before reading it back;
    # skipping this step yields the encoding of zero bytes, i.e. "".
    img.save(output_buffer, format=image_format)
    return base64.b64encode(output_buffer.getvalue()).decode("utf-8")


def add_period(s, period):
    """Append ``period`` to ``s`` unless ``s`` is empty or already ends in punctuation.

    The recognised terminal characters are ASCII and full-width full stops,
    question marks, exclamation marks and commas.
    """
    terminal_marks = {".", "?", "!", "。", "！", "？", "，", ","}
    if not s or s[-1] in terminal_marks:
        return s
    return s + period


def augument_caption(caption: str):
    """Normalise a generated caption before it is used as a prompt.

    Steps:
        1. Remove the lead-in phrases "The image shows ", "The image is ",
           "This image shows " and "This is " wherever they occur.
        2. If the result is longer than 77 characters, keep only its first two
           ". "-separated sentences.
        3. Append ". " unless the caption already ends in punctuation
           (via ``add_period``).

    Returns:
        The cleaned caption string.
    """
    caption = (
        caption.replace("The image shows ", "")
        .replace("The image is ", "")
        .replace("This image shows ", "")
        # Include the trailing space like the other lead-ins; without it words
        # such as "This island" were cut down to "land".
        .replace("This is ", "")
    )
    if len(caption) > 77:
        # str.split always yields at least one piece, so joining the first two
        # covers both the one-sentence and multi-sentence cases.
        caption = ". ".join(caption.split(". ")[:2])
    return add_period(caption, ". ")


def add_ocr_texts(caption, ocr_texts):
    """Append the OCR texts to ``caption`` as a quoted, comma-separated list.

    The suffix is introduced by the Chinese phrase "上面写着" ("written on it
    says"), e.g. ``cap 上面写着: "a", "b"``.
    """
    quoted = ", ".join('"' + text + '"' for text in ocr_texts)
    return f"{caption} 上面写着: {quoted}"


def get_keywords(caption):
    """Return the contents of every double-quoted span in ``caption``, in order.

    Matching is non-greedy, so ``'"a" and "b"'`` yields ``["a", "b"]``.
    """
    return [match.group(1) for match in re.finditer(r'"(.*?)"', caption)]


def process_data(line, filename=None):
    """Parse one tab-separated dataset record.

    Columns read (0-based): 0 = image id, 2 = JSON with "ocr_info" and
    "caption" keys, 5 = the image payload string.

    Args:
        line: the raw record text.
        filename: source file name, used only in log messages.

    Returns:
        ``(img_str, ocr_info, caption, use_render_img)`` for records whose id
        contains "example" (``caption`` passed through ``augument_caption``,
        ``use_render_img`` always False), otherwise None. Parse failures are
        logged and also yield None.
    """
    use_render_img = False
    try:
        fields = line.strip().split("\t")
        img_id: str = fields[0]
        if "example" not in img_id:
            # Only "example" records are handled; skip the rest explicitly
            # instead of failing on unbound ocr_info/caption and logging a
            # spurious parse error with a traceback for every such line.
            return None
        text_json = json.loads(fields[2])
        img_str = fields[5]
        ocr_info = text_json["ocr_info"]
        caption = augument_caption(text_json["caption"])
        return img_str, ocr_info, caption, use_render_img
    except Exception:
        logger.warning(f"error when parse file {filename}")
        logger.warning(traceback.format_exc())
        return None


def parse_line(line, filename=None):
    """Turn one dataset record into a sample dict.

    Args:
        line: raw record text, parsed by ``process_data``.
        filename: source file name, used only in log messages.

    Returns:
        ``dict(image=<RGB PIL image>, ocr_info=..., caption=...,
        use_render_img=...)``, or None when ``process_data`` rejects the record
        or decoding the base64 image fails (the error is logged).
    """
    try:
        res = process_data(line, filename)
        if res is None:
            return None
        base64_str, ocr_info, caption, use_render_img = res
        pil_image = Image.open(io.BytesIO(base64.b64decode(base64_str))).convert(
            "RGB"
        )
        return dict(
            image=pil_image,
            ocr_info=ocr_info,
            caption=caption,
            use_render_img=use_render_img,
        )
    except Exception:
        logger.warning(f"error when parse file {filename}")
        # format_exc() returns the traceback text; print_exc() writes it to
        # stderr and returns None, which made this log line read "None".
        logger.warning(traceback.format_exc())
        return None


def worker_init_fn(_):
    """DataLoader worker initialiser: give this worker its own shard of every file list.

    The worker's global index is ``rank * num_workers + worker_id``, where
    ``rank`` comes from ``dist.get_rank()`` (logged under the label
    "local_rank"). The function then:
      * sets ``dataset.rng`` to a ``np.random.RandomState`` seeded with that index;
      * replaces each list in ``dataset.file_ids`` with this worker's contiguous
        slice of ``len(list) // (world_size * num_workers)`` entries;
      * reseeds numpy's global RNG from its current state plus ``worker_id``.

    NOTE(review): the ``len(list) % (world_size * num_workers)`` trailing files
    are never assigned to any worker, and a list shorter than the number of
    chunks leaves every worker with an empty list — confirm this is intended.
    """
    info = get_worker_info()
    dataset = info.dataset
    wid = info.id
    n_workers = info.num_workers
    rank = dist.get_rank()
    world = dist.get_world_size()
    global_wid = rank * n_workers + wid

    dataset.rng = np.random.RandomState(global_wid)
    total_chunks = world * n_workers
    for idx, ids in enumerate(dataset.file_ids):
        per_chunk = len(ids) // total_chunks
        lo = global_wid * per_chunk
        hi = (global_wid + 1) * per_chunk
        dataset.file_ids[idx] = ids[lo:hi]
        logger.info(
            f"dataset {idx}, local_rank: {rank}, worker_id: {wid}, worker_global_id: {global_wid}, file_range: ({lo}, {hi})"
        )
    np.random.seed(np.random.get_state()[1][0] + wid)


class ImageProcessor:
    """Image-side preprocessing for text-rendering training data.

    Holds the pixel transform pipeline, OCR extraction and filtering (via
    ``paddleocr``), and helpers that render text onto black canvases
    ("whiteboards") used as glyph conditions.
    """

    def __init__(self, image_size=1024, random_crop=False):
        """
        Args:
            image_size: side length in pixels of the square tensor produced by
                ``image_processing``.
            random_crop: use ``RandomCrop`` instead of ``CenterCrop``.
        """
        # Resize straight to a square, crop (size-wise a no-op, since the resize
        # already produced image_size x image_size), convert to a tensor and
        # apply Normalize(0.5, 0.5). ``crop_ocr_img`` inverts this with
        # x * 127.5 + 127.5, i.e. values are presumably in [-1, 1].
        self.image_processing = transforms.Compose(
            [
                transforms.Resize(
                    size=(image_size, image_size), interpolation="lanczos"
                ),
                transforms.RandomCrop(size=(image_size, image_size))
                if random_crop
                else transforms.CenterCrop(size=(image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(0.5, 0.5),
            ]
        )
        self.resolution = image_size
        # Substrings that disqualify an OCR text (URL fragments, symbols, "00"/"ooo" noise).
        self.filters = [
            "@",
            "com",
            "www",
            "°",
            "*",
            "00",
            "ooo",
            "×",
            "÷",
            "•",
            "…",
            "~",
            "^",
            "&",
        ]
        self.max_glyph_num = 10  # more detected texts than this -> drop the whole OCR result
        self.boarder = 0.01  # min distance of a box from each edge, as a fraction of w / h
        self.area_ratio = 0.05  # min box area as a fraction of the image area
        self.confidence = 0.6  # min OCR recognition confidence

    def filter_out(self, caption, ocr_info):
        """Return True if the whole OCR result should be discarded for ``caption``.

        The result is discarded when more than ``self.max_glyph_num`` texts were
        detected, or when none of the detected texts occurs (case-insensitively)
        in ``caption`` (this includes the case of no texts at all).

        Args:
            caption: the image caption.
            ocr_info: entries of the form ``[box, (text, confidence)]``.
        """
        texts = [entry[1][0] for entry in ocr_info]
        if len(texts) > self.max_glyph_num:  # too many texts
            return True
        caption_lower = caption.lower()
        # Drop the result if no recognised text appears in the caption.
        return not any(text.lower() in caption_lower for text in texts)

    def filter_out_single(
        self, box: list, text: str, confidence: float, image_size: tuple
    ):
        """Return True if a single OCR detection should be dropped.

        Args:
            box: four (x, y) corner points in clockwise order; the edge checks
                below treat them as [top-left, top-right, bottom-right, bottom-left].
            text: recognised text.
            confidence: recognition confidence.
            image_size: (width, height) of the image the box belongs to.
        """
        w, h = image_size
        return bool(
            confidence < self.confidence  # confidence below threshold
            or len(text) < 2  # shorter than 2 characters
            or any(fil in text for fil in self.filters)  # contains a filtered substring
            or text.isspace()  # whitespace only
            or text.isdigit()  # digits only
            # box area below self.area_ratio (5%) of the image area
            or Polygon(box).area / (h * w) < self.area_ratio
            # box closer than self.boarder (1%) of the width/height to any edge
            or (min(box[0][0], box[3][0]) < self.boarder * w)  # left
            or (min(box[0][1], box[1][1]) < self.boarder * h)  # top
            or (max(box[2][0], box[1][0]) > w - self.boarder * w)  # right
            or (max(box[2][1], box[3][1]) > h - self.boarder * h)  # bottom
        )

    def get_ocr_info(self, caption, img, lang, ocr_model_path, do_filter=True):
        """Run PaddleOCR on ``img`` and optionally filter the detections.

        NOTE: a new PaddleOCR model is constructed on every call.

        Args:
            caption: caption that ``filter_out`` checks the texts against.
            img: PIL image; ``img.size`` is read as (width, height).
            lang: PaddleOCR language code.
            ocr_model_path: directory holding the PP-OCRv4 det/rec and cls models.
            do_filter: apply ``filter_out`` to the whole result and
                ``filter_out_single`` to each detection.

        Returns:
            A list of ``[box, (text, confidence)]`` / ``[box, [text, confidence]]``
            entries, empty when nothing is detected or everything is filtered.
            If filtering raises, the error is logged and the entries as they
            stood at that point are returned.
        """
        from paddleocr import PaddleOCR

        ocr = PaddleOCR(
            lang=lang,
            use_angle_cls=True,
            use_gpu=False,
            det_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_det_server_infer",
            rec_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_rec_server_infer",
            cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_infer",
            show_log=False,
        )
        ocr_info = ocr.ocr(np.asarray(img))[0]
        if ocr_info is None or len(ocr_info) == 0:
            return []
        if do_filter:
            try:
                ocr_info = ocr_info if not self.filter_out(caption, ocr_info) else []
                ocr_info = [
                    [box, [text, confidence]]
                    for box, (text, confidence) in ocr_info
                    if not self.filter_out_single(box, text, confidence, img.size)
                ]
            except Exception:
                logger.warning("filter ocr info error")
                logger.warning(traceback.format_exc())
        return ocr_info

    def resize_and_pad_image(self, pil_image, image_size, background_color="black"):
        """Fit ``pil_image`` into an ``image_size`` canvas, scaling to the canvas height.

        The image is scaled (aspect ratio kept) to the target height. If it is
        then wider than the target it is squashed to exactly the target size; if
        it is narrower it is centred horizontally on an RGB canvas filled with
        ``background_color``. An empty (zero-width or zero-height) source gives
        a blank RGB canvas.

        Args:
            pil_image: image to fit.
            image_size: int (square) or (width, height).
            background_color: fill colour for padding / blank output.

        Raises:
            ValueError: if ``image_size`` is malformed or not strictly positive.
        """
        if isinstance(image_size, (tuple, list)) and len(image_size) == 2:
            image_width, image_height = image_size
        elif isinstance(image_size, int):
            image_width = image_height = image_size
        else:
            raise ValueError(
                f"Image size should be int or list/tuple of int not {image_size}"
            )
        if image_width <= 0 or image_height <= 0:
            # A zero target height would make the halving loop below never end.
            raise ValueError(f"Image size should be positive, got {image_size}")

        if pil_image.size[0] == 0 or pil_image.size[1] == 0:
            # Nothing to draw (e.g. an empty string was rendered); scaling to the
            # target height would divide by zero.
            logger.warning(f"pil img size: {pil_image.size}")
            return Image.new(
                mode="RGB", size=(image_width, image_height), color=background_color
            )

        # Cheap box-filter halving first, so the final bicubic resize starts
        # from less than twice the target height.
        while pil_image.size[1] >= 2 * image_height:
            pil_image = pil_image.resize(
                tuple(max(1, x // 2) for x in pil_image.size), resample=Image.BOX
            )

        scale = image_height / pil_image.size[1]
        pil_image = pil_image.resize(
            tuple(max(1, round(x * scale)) for x in pil_image.size),
            resample=Image.BICUBIC,
        )

        # shrink: too wide -> squash to exactly the target size
        if pil_image.size[0] > image_width:
            pil_image = pil_image.resize(
                (image_width, image_height), resample=Image.BICUBIC
            )
        # padding: too narrow -> centre horizontally
        if pil_image.size[0] < image_width:
            img = Image.new(
                mode="RGB", size=(image_width, image_height), color=background_color
            )
            width, _ = pil_image.size
            img.paste(pil_image, ((image_width - width) // 2, 0))
            pil_image = img
        return pil_image

    @staticmethod
    def _pil_to_flat_tensor(pil_image):
        """Convert ``pil_image`` to grayscale and flatten it to a float32 tensor of shape [1, H*W]."""
        img = np.array(pil_image.convert("L")).reshape(1, -1).astype(np.float32)
        return paddle.to_tensor(img)

    def get_glyph(self, tags_dic):
        """Attach a rendered glyph to every tag in ``tags_dic`` (in place).

        Each ``tag["raw_text"]`` is rendered centred on a 64x16 canvas and stored
        as ``tag["glyph"]``, a float32 tensor of shape [1, 64 * 16].

        Returns:
            The same ``tags_dic``.
        """
        for tag in tags_dic.values():
            tag["glyph"] = self._pil_to_flat_tensor(
                self.render_text_to_whiteboard_center(
                    tag["raw_text"], image_size=(64, 16)
                )
            )
        return tags_dic

    def get_single_ocr_emb(self, text):
        """Render ``text`` centred on a 64x64 canvas and return it as a [1, 64 * 64] float32 tensor."""
        return self._pil_to_flat_tensor(
            self.render_text_to_whiteboard_center(text, image_size=(64, 64))
        )

    def render_text_to_whiteboard_center(
        self,
        text: str,
        image_size=(320, 48),
        resolution=1024,
        font_path="data/fonts/SimSun.ttf",
        render_background_color="black",
        text_color="white",
    ):
        """Render ``text`` centred on a canvas of ``image_size`` (width, height).

        The text is drawn at font size ``resolution`` onto a tight canvas and
        then fitted with ``resize_and_pad_image``.

        Returns:
            An RGB PIL image of size ``image_size``.
        """
        font = ImageFont.truetype(font_path, encoding="utf-8", size=resolution)
        text_image = Image.new("L", (resolution, resolution), render_background_color)
        draw = ImageDraw.Draw(text_image)
        # textbbox returns (left, top, right, bottom) for text drawn at (0, 0).
        left, top, right, bottom = draw.textbbox(xy=(0, 0), text=text, font=font)
        text_image = Image.new("L", (right, bottom), render_background_color)
        draw = ImageDraw.Draw(text_image)
        # Shifting by half the bbox origin on a (right, bottom) canvas leaves
        # equal margins of left/2 and top/2 on opposite sides.
        draw.text((-left / 2, -top / 2), text, text_color, font=font, align="left")
        # resize the text image to fit the width and height of the box
        text_image = self.resize_and_pad_image(text_image, image_size=image_size)
        return text_image.convert("RGB")

    def render_text_to_ori_img(
        self,
        ori_img,
        font_path,
        text,
        text_color="white",
        top_left_cord=(0, 0),
        font_size=200,
    ):
        # TODO: update render script
        """Draw ``text`` onto ``ori_img`` in place at ``top_left_cord`` (no rotation support yet).

        Returns:
            ``ori_img``; on error the failure is logged and ``ori_img`` is
            returned as it is.
        """
        try:
            text = text.strip()
            font = ImageFont.truetype(font_path, encoding="utf-8", size=font_size)
            draw = ImageDraw.Draw(ori_img)
            draw.text(top_left_cord, text, fill=text_color, font=font, align="left")
            return ori_img
        except Exception:
            logger.warning("render text to ori img error")
            logger.warning(traceback.format_exc())
            return ori_img

    def render_text_to_whiteboard_given_bboxes(
        self,
        img: Image,
        ocr_info=None,
        image_size=(1024, 1024),
        font_path="data/fonts/SimSun.ttf",
        background_color="black",
        render_background_color="black",
        text_color="white",
        use_ocr=False,
        ocr_model_path=None,
    ):
        """Render OCR texts onto a blank canvas at the positions of their boxes.

        For every box, the text is rendered to fill the length of the box's first
        edge (box[0] -> box[1]). Its height follows from the text's aspect ratio.
        The rendering is rotated by that edge's angle and pasted with its
        top-left at box[0]. Texts that are empty after stripping, or whose
        box/size is degenerate, are skipped.

        Args:
            img: source image; only read when ``use_ocr`` is True.
            ocr_info: entries ``[box, (text, ...)]``; replaced by fresh OCR
                output when ``use_ocr`` is True.
            image_size: (width, height) of the output canvas.
            font_path: TrueType font used for rendering.
            background_color: canvas colour.
            render_background_color: background of the per-text renderings.
            text_color: text colour.
            use_ocr: run PaddleOCR on ``img`` instead of using ``ocr_info``.
            ocr_model_path: directory holding the PaddleOCR models.

        Returns:
            An RGB image of size ``image_size``. It is blank when there is no
            OCR text or when rendering fails (the error is logged).
        """
        if use_ocr:
            from paddleocr import PaddleOCR

            ocr = PaddleOCR(
                lang="ch",
                use_angle_cls=True,
                use_gpu=False,
                det_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_det_server_infer",
                rec_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_rec_server_infer",
                cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_slim_infer",
                show_log=False,
            )
            ocr_info = ocr.ocr(np.asarray(img), det=True, cls=True)[0]
        if ocr_info is None or len(ocr_info) == 0:
            # Same output size as the rendering path below.
            return Image.new(mode="RGB", size=image_size, color=background_color)
        bboxes = [x[0] for x in ocr_info]
        texts = [x[1][0] for x in ocr_info]
        try:
            background = Image.new("RGB", image_size, background_color)  # output image
            font = ImageFont.truetype(font_path, encoding="utf-8", size=image_size[0])
            # A single scratch canvas is enough for measuring text extents.
            measure = ImageDraw.Draw(
                Image.new("L", image_size, render_background_color)
            )
            for text, bbox in zip(texts, bboxes):
                text = text.strip()
                if not text:
                    continue
                ### get the ratio (w/h) of the text
                _, _, w, h = measure.textbbox(xy=(0, 0), text=text, font=font)
                if w <= 0 or h <= 0:
                    continue
                ratio = w / h
                ### get the width and height of the text from the box
                width = int(
                    math.sqrt(
                        (bbox[0][0] - bbox[1][0]) ** 2 + (bbox[0][1] - bbox[1][1]) ** 2
                    )
                )
                height = int(width / ratio)
                if width <= 0 or height <= 0:
                    # Degenerate box: nothing visible to render.
                    continue
                top_left_x = int(bbox[0][0])
                top_left_y = int(bbox[0][1])
                yaw = (
                    math.atan2(bbox[1][1] - bbox[0][1], bbox[1][0] - bbox[0][0])
                    * 180
                    / math.pi
                )
                text_image = self.render_text_to_whiteboard_center(
                    text,
                    image_size=(width, height),
                    font_path=font_path,
                    render_background_color=render_background_color,
                    text_color=text_color,
                )
                text_image = text_image.rotate(
                    angle=-yaw, expand=True, fillcolor=background_color
                )
                ### paste the text image to background
                background.paste(text_image, (top_left_x, top_left_y))
            return background
        except Exception:
            logger.warning("render text to whiteboard given bboxes error")
            logger.warning(traceback.format_exc())
            return Image.new("RGB", image_size, background_color)

    def resize_ocr_box(self, ori_img_size, resolution, box):
        """Rescale a four-point box from ``ori_img_size`` to a ``resolution`` x ``resolution`` image.

        Args:
            ori_img_size: (width, height) of the original image. x coordinates
                are divided by ``ori_img_size[0]`` and y by ``ori_img_size[1]``.
            resolution: target side length.
            box: four mutable [x, y] points; modified in place.

        Returns:
            The same ``box``.
        """
        for i in range(4):
            box[i][0] = box[i][0] * resolution / ori_img_size[0]
            box[i][1] = box[i][1] * resolution / ori_img_size[1]
        return box

    def crop_ocr_img(self, img_tensor, ocr_info, debug=False):
        """Crop the axis-aligned bounding rectangle of every OCR box out of ``img_tensor``.

        Args:
            img_tensor: CHW image tensor, presumably normalised to [-1, 1] as by
                ``image_processing`` (the debug dump assumes so).
            ocr_info: entries ``[box, (text, ...)]`` with box coordinates in the
                tensor's pixel space.
            debug: also save ``crop_img_{i}.png`` and ``crop_origin_img.png``.

        Returns:
            A list of ``{"text": str, "image": tensor}`` dicts, or None if
            anything fails (the error is logged).
        """

        def find_max_rect(box):
            """Return (left, top, right, bottom) of the box, clamped to be non-negative."""
            # Coordinates past the top/left edge would otherwise act as
            # negative (wrap-around) slice indices.
            left = max(0, int(min(p[0] for p in box)))
            right = max(0, int(max(p[0] for p in box)))
            top = max(0, int(min(p[1] for p in box)))
            bottom = max(0, int(max(p[1] for p in box)))
            return left, top, right, bottom

        def save_debug_image(tensor, path):
            """Undo the [-1, 1] normalisation of a CHW tensor and save it as an image file."""
            arr = (
                (tensor * 127.5 + 127.5)
                .detach()
                .cpu()
                .numpy()
                .clip(0, 255)
                .astype(np.uint8)
            )
            Image.fromarray(arr.transpose(1, 2, 0)).save(path)

        try:
            crop_img_list = []
            for i, entry in enumerate(ocr_info):
                text, bbox = entry[1][0], entry[0]
                left, top, right, bottom = find_max_rect(bbox)
                crop_img = img_tensor[:, top:bottom, left:right]
                crop_img_list.append({"text": text, "image": crop_img})
                if debug:
                    save_debug_image(crop_img, f"crop_img_{i}.png")
            if debug:
                save_debug_image(img_tensor, "crop_origin_img.png")
            return crop_img_list
        except Exception:
            logger.warning("crop ocr img error")
            logger.warning(traceback.format_exc())
            return None


class TextProcessor:
    """Text-side preprocessing: keyword extraction, tokenization, locating OCR
    texts in the prompt, and OCR-box-derived attention masks."""

    def extract_key_words(self, caption):
        """Return the whitespace-separated words inside double-quoted spans of ``caption``.

        Example: ``'a sign "OPEN NOW"'`` -> ``["OPEN", "NOW"]``.
        """
        words = []
        for match in re.findall(r'"(.*?)"', caption):
            words.extend(match.split())
        return words

    @staticmethod
    def replace_ocr_text(data, key_words):
        """Overwrite the OCR texts in ``data["ocr_info"]`` with ``key_words``, pairing both sorted.

        ``data["ocr_info"]`` (entries ``[box, (text, ...)]``) is sorted in place
        by text. The i-th entry's text is then replaced by the i-th smallest
        keyword, for the first ``min(len(ocr_info), len(key_words))`` entries.
        ``key_words`` itself is not modified. Declared as a staticmethod (it uses
        no instance state), so both ``processor.replace_ocr_text(data, kw)`` and
        ``TextProcessor.replace_ocr_text(data, kw)`` work.

        Returns:
            ``data`` (modified in place).
        """
        ocr_info = data["ocr_info"]
        # list.sort() sorts in place and returns None, so its result is not assigned.
        ocr_info.sort(key=lambda x: x[1][0])
        for entry, word in zip(ocr_info, sorted(key_words)):
            # OCR results may hold (text, confidence) tuples; make them mutable first.
            entry[1] = list(entry[1])
            entry[1][0] = word
        return data

    def process_char_tokenization(self, caption, t_list):
        """Rewrite quoted keywords of ``caption`` and every text in ``t_list`` as char tokens.

        Each character ``c`` becomes ``<c>``, e.g. ``"ABC" -> "<A><B><C>"``.
        Every occurrence in ``caption`` of each keyword found by ``get_keywords``
        is rewritten in a single pass. A keyword that repeats, or that is a
        substring of another keyword, is therefore never wrapped twice.

        Returns:
            ``(new_caption, new_t_list)``.
        """

        def to_char_tokens(text):
            return "".join("<" + ch + ">" for ch in text)

        t_list_ret = [to_char_tokens(text) for text in t_list]
        # Longest first so that, at any position, the longest keyword wins.
        keywords = sorted(
            {k for k in get_keywords(caption) if k}, key=len, reverse=True
        )
        if keywords:
            pattern = re.compile("|".join(re.escape(k) for k in keywords))
            caption = pattern.sub(lambda m: to_char_tokens(m.group(0)), caption)
        return caption, t_list_ret

    def text_processing(
        self,
        tokenizer,
        input_text,
        max_length=77,
        padding="max_length",
        return_offset_mapping=False,
        return_attention_mask=True,
    ):
        """Tokenize ``input_text`` (truncating to ``max_length``) and convert every field to a paddle tensor.

        Returns:
            ``(input_ids, attention_mask or None, offset_mapping or None)``.
        """
        inputs = tokenizer(
            input_text,
            return_tensors="np",
            padding=padding,
            max_length=max_length,
            truncation=True,
            return_attention_mask=return_attention_mask,
            return_offsets_mapping=return_offset_mapping,
        )
        inputs_tensor = {key: paddle.to_tensor(inputs[key]) for key in inputs}
        return (
            inputs_tensor["input_ids"],
            inputs_tensor["attention_mask"] if return_attention_mask else None,
            inputs_tensor["offset_mapping"] if return_offset_mapping else None,
        )

    def process_offset_mapping(
        self, ori_seq, t_list, input_ids, offset_mapping, tags_dic=None
    ):
        """Locate every text of ``t_list`` in the prompt, by characters and by tokens.

        Args:
            ori_seq: the prompt string that was tokenized.
            t_list: texts (e.g. OCR results) to look up in ``ori_seq``.
            input_ids: token ids, read as ``input_ids[0]`` (batch of one assumed).
            offset_mapping: per-token (start, end) character offsets, read as
                ``offset_mapping[0]``.
            tags_dic: dict to fill in place; a new one is created when omitted.

        Returns:
            ``tags_dic`` mapping ``text.lower().strip()`` to
            ``{"raw_text", "span", "ch_span"}``. ``ch_span`` is the [start, end)
            character range of the first case-insensitive occurrence in
            ``ori_seq``. ``span`` is ``[first_token_index - 1, last_token_index]``
            over the tokens that overlap that range (before the eos token).
            Texts that are not found, or not covered by any such token, keep
            ``[-1, -1]`` for both.
        """
        if tags_dic is None:
            tags_dic = {}
        ori_seq = ori_seq.lower()
        for t in t_list:
            tags_dic[t.lower().strip()] = {
                "raw_text": t,
                "span": [-1, -1],
                "ch_span": [-1, -1],
            }
        try:
            for t in t_list:
                if t == "":  # pass empty str
                    continue
                t_lower = t.lower()
                s_index = ori_seq.find(t_lower)
                if s_index == -1:
                    continue
                e_index = s_index + len(t_lower)
                # 1 for every prompt character that belongs to this text.
                char_mask = [0] * len(ori_seq)
                for i in range(s_index, e_index):
                    char_mask[i] = 1
                tokens_index = []
                for index, (token, offset) in enumerate(
                    zip(input_ids[0], offset_mapping[0])
                ):
                    if sum(char_mask[offset[0] : offset[1]]):  # token overlaps the text
                        tokens_index.append(index)
                    if token.item() == 49407:  # eos
                        break
                if not tokens_index:
                    # e.g. the text lies beyond the truncation point; keep the placeholders.
                    continue
                tags_dic[t.lower().strip()] = {
                    "raw_text": t,
                    "span": [tokens_index[0] - 1, tokens_index[-1]],
                    "ch_span": [s_index, e_index],
                }
        except Exception:
            logger.warning("process_offset_mapping error")
            logger.warning(traceback.format_exc())
        return tags_dic

    def get_attn_mask(
        self,
        ocr_info,
        offset_mapping,
        attention_mask,
        ori_seq,
        input_ids,
        ocr_model_path=None,
        max_seq_len=77,
        use_ocr=False,
        latent_size=128,
        img: Image = None,
        img_size=None,
        tags_dic=None,
        tokenization_level="BPE",
    ):
        """Build the prompt-token -> latent-pixel cross-attention mask from OCR boxes.

        Args:
            ocr_info: entries ``[box, (text, ...)]``; replaced by fresh PaddleOCR
                output when ``use_ocr`` is True.
            offset_mapping, input_ids: tokenizer outputs for ``ori_seq`` (used
                only when ``tags_dic`` is None).
            attention_mask: unused.
            ori_seq: the prompt string.
            ocr_model_path: directory holding the PaddleOCR models.
            max_seq_len: number of token rows.
            use_ocr: run OCR on ``img``.
            latent_size: side of the latent grid; each row has latent_size**2 entries.
            img: PIL image; required when ``use_ocr`` or when ``img_size`` is None.
            img_size: (width, height) like ``PIL.Image.size``, or an int for a
                square; defaults to ``img.size``.
            tags_dic: precomputed result of ``process_offset_mapping``.
            tokenization_level: "Char" when the prompt uses ``<c>`` char tokens.

        Returns:
            A uint8 paddle tensor of shape [max_seq_len, latent_size**2]. The token
            rows of each OCR text found in the prompt hold that text's box
            mask (1 = keep, 0 = abandon); all other rows are ones. On error the
            failure is logged and an all-ones tensor of the same shape is returned.
            NOTE(review): older comments described the output as [4096, 77], but
            ``ndarray.transpose(0, 1)`` on a 2-D array is an identity, so the
            success path has always produced [max_seq_len, latent_size**2];
            the fallback now matches it — confirm against the consumer.
        """
        try:
            if use_ocr:
                if img is None:
                    raise ValueError("img is None, please pass img when use_ocr=True")
                from paddleocr import PaddleOCR

                ocr = PaddleOCR(
                    use_angle_cls=True,
                    use_gpu=False,
                    lang="ch",
                    det_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_det_server_infer",
                    rec_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_rec_server_infer",
                    cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_slim_infer",
                    show_log=False,
                )
                ocr_info = ocr.ocr(np.asarray(img), det=True, cls=True)[0]
            if img_size is None:
                if img is None:
                    raise ValueError(
                        "img is None, please pass img when img_size is None"
                    )
                img_size = img.size
            res = np.ones(
                shape=(max_seq_len, latent_size**2), dtype=np.uint8
            )  # [77, 4096]
            if tags_dic is None:
                tags_dic = self.process_offset_mapping(
                    ori_seq,
                    [r[1][0] for r in (ocr_info or [])],
                    input_ids,
                    offset_mapping,
                    {},
                )
            if tags_dic is None:
                return None
            if ocr_info is not None:
                # PIL sizes are (width, height); numpy masks are (rows=height, cols=width).
                if isinstance(img_size, int):
                    mask_shape = (img_size, img_size)
                else:
                    mask_shape = (img_size[1], img_size[0])
                for entry in ocr_info:
                    t = entry[1][0]
                    if tokenization_level == "Char":
                        t = "".join(["<" + c + ">" for c in t])
                    # tags_dic keys are built as text.lower().strip().
                    key = t.lower().strip()
                    if key not in tags_dic:  # ocr text maybe not in prompt, then skip it
                        continue
                    start_idx, end_idx = tags_dic[key]["span"]
                    if start_idx < 0 or end_idx <= start_idx:
                        continue  # text was not located among the prompt tokens
                    img_mask = np.zeros(shape=mask_shape, dtype=np.uint8)
                    # cv2.fillPoly expects int32 point coordinates.
                    cv2.fillPoly(img_mask, [np.array(entry[0], dtype=np.int32)], 1)
                    img_mask = cv2.resize(
                        img_mask.astype(float), (latent_size, latent_size)
                    )
                    img_mask = (img_mask > 0).astype(np.uint8)
                    res[start_idx:end_idx] = img_mask.reshape(1, -1).repeat(
                        end_idx - start_idx, axis=0
                    )
            res = paddle.Tensor(res)  # [max_seq_len, latent_size**2]
        except Exception:
            logger.warning("get_attn_mask error")
            logger.warning(traceback.format_exc())
            return paddle.Tensor(
                np.ones(shape=(max_seq_len, latent_size**2), dtype=np.uint8)
            )
        return res

    def get_img_mask(self, ocr_info, img_size=1024, latent_size=128):
        """Rasterise all OCR boxes into a latent-resolution mask.

        Boxes are filled on an ``img_size`` x ``img_size`` canvas. The canvas is
        resized to ``latent_size`` with cv2's default interpolation and is not
        re-binarised, so edges may hold fractional values.

        Returns:
            A float32 tensor of shape [4, latent_size, latent_size] (the mask
            tiled four times). If rasterisation fails, the error is logged and
            an all-zero mask is used.
        """
        img_mask = np.zeros(shape=(img_size, img_size), dtype=np.uint8)
        try:
            if ocr_info is not None:
                for entry in ocr_info:
                    # cv2.fillPoly expects int32 point coordinates.
                    cv2.fillPoly(img_mask, [np.array(entry[0], dtype=np.int32)], 1)
        except Exception:
            logger.warning("get_img_mask error")
            logger.warning(traceback.format_exc())
            img_mask = np.zeros(shape=(img_size, img_size), dtype=np.uint8)
        img_mask = cv2.resize(img_mask.astype(float), (latent_size, latent_size))
        img_mask = paddle.Tensor(img_mask.astype(np.float32))
        return img_mask.unsqueeze(0).tile([4, 1, 1])  # [4, latent_size, latent_size]

    def process_glyph_embed(self, tags_dic, max_len=10):
        """Stack the per-tag glyph tensors into exactly ``max_len`` rows.

        Args:
            tags_dic: tags carrying a ``"glyph"`` tensor (one row of 64 * 16 each,
                as produced by ``ImageProcessor.get_glyph``).
            max_len: max number of glyph rows kept per sample.

        Returns:
            A tensor of shape [max_len, glyph_dim]. It is zero-padded (float32)
            when there are fewer rows, truncated when there are more, and all
            zeros with glyph_dim = 64 * 16 when ``tags_dic`` is empty.
        """
        glyph_dim = 64 * 16
        glyph_embed = [tag["glyph"] for tag in tags_dic.values()]
        glyph_embed = (
            paddle.concat(glyph_embed)
            if len(glyph_embed) > 0
            else paddle.zeros(shape=(max_len, glyph_dim), dtype=paddle.float32)
        )  # [len, 64 * 16]
        if glyph_embed.shape[0] < max_len:
            glyph_embed = paddle.concat(
                [
                    glyph_embed.astype(paddle.float32),
                    paddle.zeros(
                        shape=(max_len - glyph_embed.shape[0], glyph_embed.shape[-1]),
                        dtype=paddle.float32,
                    ),
                ]
            )
        elif glyph_embed.shape[0] > max_len:
            glyph_embed = glyph_embed[:max_len]
        return glyph_embed

    def get_glyph_attn_mask(self, tags_dic, glyph_embed, ocr_info, latent_size=128):
        """Build the glyph-row -> latent-pixel attention mask.

        tags_dic is stored in the same order as the OCR results, so walking
        ``ocr_info`` in order assigns consecutive glyph rows to each box. Each
        OCR text claims ``len(tags_dic[text]["glyph"])`` rows that may only
        attend to that text's box; unclaimed rows stay all ones.

        NOTE(review): boxes are rasterised directly onto a ``latent_size``
        canvas (the following resize is to the same size). They are therefore
        presumably expected in latent coordinates already — confirm against
        the caller.

        Returns:
            A uint8 tensor of shape [glyph_embed.shape[0], latent_size**2]. If
            anything fails, e.g. an OCR text missing from ``tags_dic``, the error
            is logged and ``paddle.zeros`` of the same shape is returned.
        """
        try:
            start = 0
            res = np.ones(
                shape=(glyph_embed.shape[0], latent_size**2), dtype=np.uint8
            )
            if ocr_info is not None:
                for entry in ocr_info:
                    t = entry[1][0]
                    img_mask = np.zeros(
                        shape=(latent_size, latent_size), dtype=np.uint8
                    )
                    # cv2.fillPoly expects int32 point coordinates.
                    cv2.fillPoly(img_mask, [np.array(entry[0], dtype=np.int32)], 1)
                    img_mask = cv2.resize(
                        img_mask.astype(float), (latent_size, latent_size)
                    )
                    img_mask = (img_mask > 0).astype(np.uint8)
                    span = len(tags_dic[t]["glyph"])
                    if start + span > res.shape[0]:
                        break
                    # This glyph text may only see the image region of its own box.
                    res[start : start + span] = img_mask.reshape(1, -1).repeat(
                        span, axis=0
                    )
                    start = start + span
            res = paddle.Tensor(res)  # [glyph rows, latent_size**2]
        except Exception:
            logger.error("get_glyph_attn_mask error")
            logger.error(f'ocr_info: {ocr_info}')
            logger.error(traceback.format_exc())
            res = paddle.zeros(shape=(glyph_embed.shape[0], latent_size**2))
        return res

    def str_to_tensor(self, string, batch_size=64) -> paddle.Tensor:
        """Encode ``string`` as an int64 tensor of Unicode code points.

        The code points are split into rows of ``batch_size``. The last row is
        right-padded with 0 so the tensor is rectangular, giving shape
        ``[ceil(len(string) / batch_size), batch_size]`` (shape ``[0]`` for "").
        Decode with ``tensor_to_str``, which drops the 0 padding, so NUL
        characters do not survive a round trip.
        """
        codes = [ord(ch) for ch in string]
        rows = [codes[i : i + batch_size] for i in range(0, len(codes), batch_size)]
        if rows:
            rows[-1] = rows[-1] + [0] * (batch_size - len(rows[-1]))
        return paddle.to_tensor(rows, dtype="int64")

    def tensor_to_str(self, x: paddle.Tensor) -> str:
        """Decode a tensor produced by ``str_to_tensor`` back into a string, skipping 0 padding."""
        return "".join(chr(code) for row in x.tolist() for code in row if code != 0)


def prepare_test_input(
    caption,
    text_processor: "TextProcessor",
    image_processor: "ImageProcessor" = None,
    tokenization_level="BPE",
    use_rendered_img=False,
    font_path="data/fonts/SimSun.ttf",
):
    """Build the inference-time input for a test prompt.

    Args:
        caption: the prompt text.
        text_processor: used for char-level tokenization.
        image_processor: currently unused; kept for interface compatibility.
        tokenization_level: "Char" rewrites quoted keywords as ``<c>`` tokens
            via ``text_processor.process_char_tokenization``.
        use_rendered_img: rendering text onto an image is not supported here.
        font_path: currently unused; kept for interface compatibility.

    Returns:
        ``{"caption": <possibly char-tokenized caption>}``.

    Raises:
        NotImplementedError: if ``use_rendered_img`` is True. Rendering needs an
            image and the text to draw, and this function receives neither.
    """
    if use_rendered_img:
        raise NotImplementedError(
            "prepare_test_input cannot render text: no image or text to render is passed"
        )
    if tokenization_level == "Char":
        caption, _ = text_processor.process_char_tokenization(caption=caption, t_list=[])

    return {
        "caption": caption,
    }


def prepare_input(
    image_processor: ImageProcessor,
    text_processor: TextProcessor,
    data,
    tokenizer,
    ocr_model_path="ocr_models",
    font_path="data/fonts/SimSun.ttf",
    latent_size=128,
    tokenization_level="BPE",
):
    """
    Build the model inputs for one training sample.

    Args:
        image_processor: provides image preprocessing, OCR, rendering, glyph
            lookup and OCR-box resizing.
        text_processor: provides tokenization, offset mapping and mask helpers.
        data: dict with "caption" (str), "image" (PIL image), "ocr_info"
            (list of OCR entries `[box, (text, ...)]`, or None / [] to run OCR)
            and optionally "use_render_img" (bool). `data` is modified in
            place: "ocr_info" entries are normalized (text lowercased and
            stripped), entries with empty text are dropped, boxes are resized,
            and "caption" is rewritten when tokenization_level == "Char".
        tokenizer: tokenizer exposing `model_max_length`.
        ocr_model_path: forwarded to OCR and attention-mask helpers.
        font_path: currently unused; the whiteboard rendering below uses a
            hard-coded font path (NOTE(review): callers pass different
            defaults, so wiring this through needs confirmation).
        latent_size: currently unused; the latent size is derived from
            `image_processor.resolution // 8`.
        tokenization_level: "Char" enables char-level caption processing.

    Returns:
        {
        "caption": str, the (possibly char-tokenized) caption,
        "pixel_values": output of image_processor.image_processing,
        "encoder_attn_mask": output of text_processor.get_attn_mask
            (cross attn: between img and text),
        "tags_dic": dict from process_offset_mapping, enriched by get_glyph,
        "img_mask": output of text_processor.get_img_mask,
        "ocr_info": filtered OCR entries, boxes resized to
            image_processor.resolution,
        "rendered_condition": OCR text rendered on a black whiteboard,
        }
    """
    use_rendered_img = data.get("use_render_img", False)
    if (data["ocr_info"] is None or data["ocr_info"] == []) and not use_rendered_img:
        data["ocr_info"] = image_processor.get_ocr_info(
            caption=data["caption"],
            img=data["image"],
            ocr_model_path=ocr_model_path,
            lang="ch",
            do_filter=False,
        )
    if data["ocr_info"] is None:
        # OCR skipped (rendered-image sample) or returned nothing: no texts.
        data["ocr_info"] = []

    # Normalize OCR texts and drop entries whose text is empty. Build a new
    # list instead of calling .remove() while iterating the same list, which
    # skipped (never normalized, never filtered) the entry after each removal.
    kept_ocr, t_list = [], []
    for o in data["ocr_info"]:
        o[1] = list(o[1])
        o[1][0] = o[1][0].lower().strip()
        if o[1][0] == "":
            continue
        kept_ocr.append(o)
        t_list.append(o[1][0])
    data["ocr_info"] = kept_ocr

    # 1. prepare pixel values
    pixel_values = image_processor.image_processing(data["image"])
    # 2. prepare input ids, offset mapping
    if tokenization_level == 'Char':
        data['caption'], t_list = text_processor.process_char_tokenization(caption=data['caption'], t_list=t_list)
    input_ids, attn_mask, offset_mapping = text_processor.text_processing(
        tokenizer=tokenizer,
        input_text=data["caption"],
        max_length=tokenizer.model_max_length,
        return_offset_mapping=True,
        return_attention_mask=True,
    )
    # 3. prepare tags_dic
    tags_dic = text_processor.process_offset_mapping(
        ori_seq=data["caption"],
        t_list=t_list,
        offset_mapping=offset_mapping,
        input_ids=input_ids,
        tags_dic={},
    )
    # 4. render the OCR texts as the condition image.
    # NOTE(review): font path is hard-coded and ignores `font_path`.
    rendered_condition = image_processor.render_text_to_whiteboard_given_bboxes(
        img=data['image'],
        ocr_info=data['ocr_info'],
        image_size=(1024, 1024),
        font_path="data/fonts/SimSun.ttf",
        background_color="black",
        render_background_color="black",
        text_color="white",
        use_ocr=False,
        ocr_model_path=None,
    )
    # 5. image-space mask of OCR regions
    img_mask = text_processor.get_img_mask(
        ocr_info=data["ocr_info"],
        img_size=image_processor.resolution,
        latent_size=image_processor.resolution // 8,
    )
    # 6. prepare glyph embedding
    tags_dic = image_processor.get_glyph(tags_dic)
    # 7. cross-attention mask between image latents and text tokens
    encoder_attn_mask = text_processor.get_attn_mask(
        ocr_info=data["ocr_info"],
        offset_mapping=offset_mapping,
        attention_mask=attn_mask,
        ori_seq=data["caption"],
        input_ids=input_ids,
        ocr_model_path=ocr_model_path,
        use_ocr=False,
        img=data["image"],
        tags_dic=tags_dic,
        tokenization_level=tokenization_level,
    )

    # 8. resize ocr_info boxes from the original image size to the model resolution
    for o in data["ocr_info"]:
        o[0] = image_processor.resize_ocr_box(
            ori_img_size=data["image"].size,
            resolution=image_processor.resolution,
            box=o[0],
        )
    return {
        "caption": data["caption"],
        "pixel_values": pixel_values,
        "encoder_attn_mask": encoder_attn_mask,  # cross attn: between img and text
        "tags_dic": tags_dic,
        "img_mask": img_mask,
        "ocr_info": data["ocr_info"],
        'rendered_condition': rendered_condition,
    }


class TextImagePairDataset(IterableDataset):
    """
    Infinite iterable dataset of text-image training samples.

    `file_list` is a text file. The first space-separated token on each of
    its lines is the path of another text file, which lists one data file
    per line; a ".gz" suffix marks a data file as gzip-compressed. Each such
    listing forms one sub-dataset. Iteration draws every sample from a
    randomly chosen sub-dataset, converts each parsed line with
    `prepare_input`, and passes the results through a small shuffle buffer.
    """

    def __init__(
        self,
        file_list,
        size,
        num_records,
        image_processing=None,
        buffer_size=10,
        shuffle_every_n_samples=5,
        random_crop=False,
        tokenizer=None,
        font_path="dataset/fonts/SimSun.ttf",
        ocr_model_path="ocr_models",
        tokenization_level="BPE",
    ):
        """
        Args:
            file_list: path of the top-level list file (see class docstring).
            size: image resolution, forwarded to `ImageProcessor(image_size=...)`.
            num_records: value reported by `__len__` (not derived from the data).
            image_processing: optional callable that replaces
                `self.image_processor.image_processing`.
            buffer_size: number of samples held by the shuffle buffer;
                values <= 0 disable buffering.
            shuffle_every_n_samples: reshuffle the buffer every n yielded
                samples; values <= 0 never reshuffle.
            random_crop: forwarded to `ImageProcessor`.
            tokenizer, font_path, ocr_model_path, tokenization_level:
                forwarded to `prepare_input`.

        Raises:
            ValueError: if no sub-dataset is listed, or a sub-dataset lists
                no data files.
        """
        self.ocr_model_path = ocr_model_path
        self.resolution = size
        self.tokenizer = tokenizer
        self.text_processor = TextProcessor()
        self.image_processor = ImageProcessor(image_size=size, random_crop=random_crop)
        self.font_path = font_path
        self.tokenization_level = tokenization_level
        if image_processing is not None:
            self.image_processor.image_processing = image_processing
        self.num_records = num_records
        self.file_list = []

        with open(file_list, "r") as top_file:
            list_lines = top_file.read().strip().split("\n")
        for list_line in list_lines:
            if not list_line.strip():
                continue  # tolerate blank lines instead of calling open("")
            sub_list_path = list_line.split(" ")[0]
            with open(sub_list_path, "r") as sub_file:
                filenames = [
                    name for name in sub_file.read().strip().split("\n") if name.strip()
                ]
            if not filenames:
                # An empty sub-dataset would make sample_loader loop forever
                # without yielding.
                raise ValueError(f"no data files listed in {sub_list_path}")
            self.file_list.append(filenames)
        if not self.file_list:
            raise ValueError(f"no sub-dataset lists found in {file_list}")

        logger.info([len(filenames) for filenames in self.file_list])
        self.file_ids = [np.arange(len(filelist)) for filelist in self.file_list]
        logger.info(
            f"original lengths of self.file_ids: {[len(f) for f in self.file_ids]}"
        )
        self.buffer_size = buffer_size
        self.shuffle_every_n_samples = shuffle_every_n_samples

    def sample_loader(self, file_ids, filenames):
        """
        Yield prepared samples from the files in `filenames`, forever.

        `file_ids` indexes into `filenames` and is shuffled in place at the
        start of every epoch. Lines that cannot be decoded (UTF-8, then
        GB18030), that `parse_line` rejects, or for which `prepare_input`
        returns None are skipped and counted.
        """
        while True:
            random.shuffle(file_ids)  # shuffle for each epoch
            for i in file_ids:
                filename = filenames[i].strip("\n")
                opener = gzip.open if filename.endswith(".gz") else open
                with opener(filename, "rb") as f:
                    skip_sample_num = 0
                    sample_num = 0
                    while True:
                        line = f.readline()
                        if line == b"" or line == "":  # end of file
                            break
                        # Count only real lines (EOF is not a sample).
                        sample_num += 1
                        try:
                            try:
                                line = line.decode(encoding="utf-8")
                            except UnicodeDecodeError:
                                line = line.decode(encoding="gb18030")
                        except UnicodeDecodeError:
                            logger.warning(f"error on file {filename}")
                            # format_exc returns the text; print_exc returns None.
                            logger.warning(traceback.format_exc())
                            skip_sample_num += 1
                            continue
                        ### read data
                        data = parse_line(line, filename)
                        if data is None:
                            skip_sample_num += 1
                            continue
                        data_dic = prepare_input(
                            data=data,
                            image_processor=self.image_processor,
                            text_processor=self.text_processor,
                            tokenizer=self.tokenizer,
                            ocr_model_path=self.ocr_model_path,
                            font_path=self.font_path,
                            latent_size=self.resolution,
                            tokenization_level=self.tokenization_level,
                        )
                        if data_dic is None:
                            skip_sample_num += 1
                            continue
                        yield data_dic
                logger.info(
                    f"{filename} have finish, total sample nums are {sample_num}, skip samples nums are {skip_sample_num}"
                )
            logger.info("finish a work epoch! The file_ids are:")
            logger.info(file_ids)

    def random_load_from_multi_dataset(self):
        """Yield samples forever, picking a random sub-dataset for each one."""
        logger.info(
            f"lengths of self.file_ids in random_load: {[len(f) for f in self.file_ids]}"
        )
        sample_loader_per_dataset = [
            iter(self.sample_loader(file_ids, filenames))
            for file_ids, filenames in zip(self.file_ids, self.file_list)
        ]
        while True:
            sample_loader = random.choice(sample_loader_per_dataset)
            yield next(sample_loader)

    def shuffle(self, iterator):
        """
        Yield items of `iterator` through a buffer of `self.buffer_size` items.

        The buffer is reshuffled every `self.shuffle_every_n_samples` yields
        (never if that value is <= 0). With `buffer_size <= 0`, items are
        passed through unchanged.
        """
        if self.buffer_size <= 0:
            # No buffer to draw from; popping from an empty list would fail.
            yield from iterator
            return
        buffer_list = []
        for _ in range(self.buffer_size):
            buffer_list.append(next(iterator))
        i = 0
        while True:
            if (
                self.shuffle_every_n_samples > 0
                and i % self.shuffle_every_n_samples == 0
            ):
                random.shuffle(buffer_list)
            yield buffer_list.pop()
            buffer_list.append(next(iterator))
            i += 1

    def __len__(self):
        """Return the configured `num_records` (the stream itself is infinite)."""
        return self.num_records

    def __iter__(self):
        """Return a shuffled, infinite iterator over prepared samples."""
        return self.shuffle(iter(self.random_load_from_multi_dataset()))
