import paddle
import os
import yaml
from paddleocr.ppocr.data import create_operators, transform
from paddleocr.ppocr.modeling.architectures import build_model
from paddleocr.ppocr.losses import build_loss
from paddleocr.ppocr.postprocess import build_post_process
from paddleocr.ppocr.utils.save_load import load_model
import numpy as np
from PIL import Image
import paddle.nn.functional as F
import paddle.nn as nn
import traceback
#!/usr/bin/python3
# coding=utf-8

from flask import Flask, request, jsonify
from flask import Flask
import json
import cv2
import numpy as np
import math

# Module-level Flask application object; debug mode is enabled as soon as the
# module is imported.
# NOTE(review): no routes are registered on `app` in the visible part of this
# file -- confirm whether the Flask app is still needed.
app = Flask(__name__)
app.debug = True


def load_config(file_path):
    """
    Load config from yml/yaml file.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns:
        The parsed config (the top-level object of the YAML document).

    Raises:
        AssertionError: If the file extension is neither ``.yml`` nor ``.yaml``.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in [".yml", ".yaml"], "only support yaml files for now"
    # Context manager so the handle is closed even when parsing fails.
    with open(file_path, "rb") as fin:
        # NOTE: yaml.Loader can construct arbitrary Python objects; only load
        # config files from trusted sources with it.
        config = yaml.load(fin, Loader=yaml.Loader)
    return config


class OCRPredictor(object):
    def __init__(self, ocr_rec_model_config=None, ckpt=None):
        """Create a predictor, optionally building the recognition model.

        Args:
            ocr_rec_model_config (str, optional): Path to the recognition
                model's YAML config. When given, the model is built right away
                via ``build_ocr_rec_model`` and stored on
                ``self.ocr_rec_model``. When omitted, ``self.ocr_rec_model``
                stays ``None`` and the caller has to assign a model before
                running predictions.
            ckpt (str, optional): Forwarded to ``build_ocr_rec_model``; only
                used when ``ocr_rec_model_config`` is given.
        """
        self.ocr_rec_model = None
        if ocr_rec_model_config is not None:
            self.ocr_rec_model = self.build_ocr_rec_model(
                ocr_rec_model_config, ckpt=ckpt
            )

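    # img_list: list of CHW tensors. pred_imglist rescales each one with
    # img * 127.5 + 127.5, so inputs are presumably in [-1, 1] and become
    # 0-255 after that rescaling.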
    def pred_imglist(self, img_list, show_debug=True, is_ori=False):
        """Run the recognition model over a list of images in small batches.

        Args:
            img_list: Sequence of CHW image tensors. Each one is first rescaled
                with ``img * 127.5 + 127.5``, so inputs are presumably in
                [-1, 1] -- TODO confirm against the caller.
            show_debug (bool): When true, a de-normalized HWC copy of every
                batch image and a file name are prepared for dumping. The
                actual write is commented out, so nothing is saved.
            is_ori (bool): Appends ``"_ori"`` to the debug file name.

        Returns:
            tuple: ``(ctc, ctc_neck)`` -- the model's ``"ctc"`` and
            ``"ctc_neck"`` outputs, stacked along axis 0 in the same order as
            ``img_list``.

        Raises:
            AssertionError: If ``img_list`` is empty.
        """
        img_list = [(img * 127.5 + 127.5) for img in img_list]
        img_num = len(img_list)
        assert img_num > 0, "img_list must not be empty"

        # Aspect ratio (w / h) of every image, measured before any rotation.
        wh_ratios = [img.shape[2] / float(img.shape[1]) for img in img_list]
        # Images are processed in order of increasing aspect ratio. Plain
        # Python ints are used so list indexing never goes through a tensor.
        indices = np.argsort(np.array(wh_ratios)).tolist()

        batch_num = 4  # rec 4 imgs 1 time
        # Loop-invariant: every batch is padded to the configured ratio.
        _, imgH, imgW = self.rec_image_shape[:3]
        max_wh_ratio = imgW / imgH

        preds_all = [None] * img_num
        preds_neck_all = [None] * img_num
        for beg_img_no in range(0, img_num, batch_num):
            batch_indices = indices[beg_img_no : beg_img_no + batch_num]

            norm_imgs = []
            for idx in batch_indices:
                img = img_list[idx]
                h, w = img.shape[1:]
                if h > w * 1.2:
                    # Clearly taller than wide: rotate by 90 degrees
                    # (swap H/W, then flip the new H axis).
                    img = paddle.transpose(img, perm=[0, 2, 1]).flip(axis=[1])
                    img_list[idx] = img
                norm_imgs.append(
                    self.resize_norm_img(img, max_wh_ratio).unsqueeze(0)
                )
            norm_img_batch = paddle.concat(norm_imgs, axis=0)

            if show_debug:
                # NOTE: the image dump itself is disabled (imwrite below is
                # commented out); this only prepares what would be written.
                for i, idx in enumerate(batch_indices):
                    _img = (
                        # Inverse of resize_norm_img's (x / 255 - 0.5) / 0.5.
                        (norm_img_batch[i] * 127.5 + 127.5)
                        .detach()
                        .cpu()
                        .cast(paddle.float32)
                        .numpy()
                        .transpose(1, 2, 0)
                    )
                    # Reverse the channel axis (presumably RGB -> BGR for cv2).
                    _img = _img[:, :, ::-1]
                    file_name = f"{idx}"
                    file_name = file_name + "_ori" if is_ori else file_name
                    # cv2.imwrite(file_name + ".jpg", _img)

            preds = self.ocr_rec_model(norm_img_batch)
            # Scatter the batch results back to their original positions.
            for rno in range(preds["ctc"].shape[0]):
                idx = indices[beg_img_no + rno]
                preds_all[idx] = preds["ctc"][rno]
                preds_neck_all[idx] = preds["ctc_neck"][rno]

        return paddle.stack(preds_all, axis=0), paddle.stack(preds_neck_all, axis=0)

    def get_ctcloss(self, preds, gt_texts, weight):
        """Compute a length-normalized, weighted CTC loss for ``preds``.

        Args:
            preds: Logits laid out as NTC. A 2-D input gets a leading batch
                axis of size 1.
            gt_texts: Nested iterable of ground-truth text. Each innermost
                element is looked up in ``self.char2id``; unknown characters
                map to ``len(self.chars) - 1``.
            weight: Scalar or tensor multiplier applied to the loss.

        Returns:
            The un-reduced CTC loss divided by the input sequence length and
            multiplied by ``weight``.
        """
        if not isinstance(weight, paddle.Tensor):
            weight = paddle.to_tensor(weight)
        ctc_loss = nn.CTCLoss(reduction="none")
        if preds.dim() == 2:
            preds = preds.unsqueeze(0)
        # NTC --> TNC. paddle.transpose(perm=...) is the axis-reordering call
        # used elsewhere in this class; CTCLoss only supports float32/float64.
        log_probs = paddle.transpose(
            nn.functional.log_softmax(preds, axis=2), perm=[1, 0, 2]
        ).astype(paddle.float32)

        unknown_id = len(self.chars) - 1
        targets = []
        # Initialised once: lengths accumulate alongside `targets` instead of
        # being reset per group, and the name exists for empty `gt_texts`.
        target_lengths = []
        for gt_text in gt_texts:
            for t in gt_text:
                targets.extend(self.char2id.get(i, unknown_id) for i in t)
                target_lengths.append(len(t))

        # NOTE(review): all labels are packed into a single row, which only
        # lines up with a batch of one prediction -- confirm the expected
        # structure of `gt_texts` with the caller.
        targets = paddle.to_tensor(targets).astype(paddle.int32).unsqueeze(0)
        target_lengths = paddle.to_tensor(target_lengths).astype(paddle.int64)
        input_lengths = paddle.to_tensor(
            [log_probs.shape[0]] * (log_probs.shape[1])
        ).astype(paddle.int64)

        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss = loss / input_lengths.astype(paddle.float32) * weight
        return loss

    def __call__(self, batch, weights):
        """Return the weighted mean CTC loss over ``batch``.

        Args:
            batch: Sequence of dicts. Each item's ``"image"`` entry is passed
                to ``pred_imglist`` and its ``"text"`` entry to
                ``get_ctcloss``.
            weights: Per-sample multipliers, indexed like ``batch``.

        Returns:
            The sum of ``weights[i] * loss_i`` divided by
            ``len(batch) + 1e-5``. If anything fails, the error and traceback
            are printed and ``0.0`` is returned instead.
        """
        img_list = [b["image"] for b in batch]
        gt_texts_batch = [b["text"] for b in batch]
        try:
            preds, _ = self.pred_imglist(img_list)
            loss = 0

            for i, (pred, gt_texts) in enumerate(zip(preds, gt_texts_batch)):
                loss += weights[i] * self.get_ctcloss(pred, gt_texts, 1.0)
            return loss / (len(gt_texts_batch) + 1e-5)
        except Exception as e:
            # Best-effort fallback: report the failure and contribute zero
            # loss. The failed forward pass is deliberately not re-run here,
            # since that would raise again from inside the handler.
            print(e)
            print(traceback.format_exc())
            return 0.0

    def get_char_dict(self, character_dict_path):
        """Build the character table from a UTF-8 dictionary file.

        Every line of the file becomes one entry after leading/trailing CR and
        LF characters are removed. ``"sos"`` is placed in front of the entries
        and a single space (used as eos) after them.
        """
        with open(character_dict_path, "rb") as dict_file:
            entries = [raw.decode("utf-8").strip("\r\n") for raw in dict_file]
        return ["sos", *entries, " "]  # eos is space

    def get_text(self, order):
        """Translate a sequence of character ids into the string they spell."""
        return "".join(map(self.chars.__getitem__, order))

    def decode(self, mat):
        """Greedy CTC-style decode of a per-step score matrix.

        Takes the arg-max class for every row of ``mat``, drops consecutive
        repeats and drops class 0.

        Returns:
            tuple: ``(ids, positions)`` -- the kept class ids and the row
            indices they were taken from.
        """
        best_ids = mat.detach().cpu().numpy().argmax(axis=1)
        # True wherever the id differs from the previous step (first step kept).
        starts_run = np.ones(len(best_ids), dtype=bool)
        starts_run[1:] = np.diff(best_ids) != 0
        # Class 0 is the only ignored token.
        keep = starts_run & (best_ids != 0)
        return best_ids[keep], np.flatnonzero(keep)

    def build_ocr_rec_model(self, ocr_rec_model_config, ckpt=None):
        """Build the recognition model described by a YAML config.

        Side effects on ``self``: sets ``rec_image_shape`` to ``[3, 48, 320]``,
        and fills ``chars`` / ``char2id`` from the dictionary file named by
        ``Global.character_dict_path``. The built model is returned, not
        stored; this method does not assign ``self.ocr_rec_model``.

        Args:
            ocr_rec_model_config (str): Path of the YAML config, read with
                ``load_config``.
            ckpt (str, optional): Directory holding ``best_accuracy.pdparams``.
                When given, ``Global.pretrained_model`` in the loaded config is
                pointed at that file before ``load_model`` is called.

        Returns:
            The model returned by ``build_model(config["Architecture"])`` after
            ``load_model(config, model)`` has been called on it.
        """
        config = load_config(ocr_rec_model_config)
        global_config = config["Global"]
        post_process_class = build_post_process(config["PostProcess"], global_config)
        self.rec_image_shape = [3, 48, 320]

        self.chars = self.get_char_dict(global_config["character_dict_path"])
        self.char2id = {x: i for i, x in enumerate(self.chars)}
        # The head's output size is derived from the post-processor's character
        # list; the branches below write it into the config in the place each
        # architecture flavour reads it from.
        if hasattr(post_process_class, "character"):
            char_num = len(getattr(post_process_class, "character"))
            if config["Architecture"]["algorithm"] in [
                "Distillation",
            ]:  # distillation model
                for key in config["Architecture"]["Models"]:
                    if (
                        config["Architecture"]["Models"][key]["Head"]["name"]
                        == "MultiHead"
                    ):  # for multi head
                        if (
                            config["PostProcess"]["name"]
                            == "DistillationSARLabelDecode"
                        ):
                            char_num = char_num - 2
                        if (
                            config["PostProcess"]["name"]
                            == "DistillationNRTRLabelDecode"
                        ):
                            char_num = char_num - 3
                        out_channels_list = {}
                        out_channels_list["CTCLabelDecode"] = char_num
                        # update SARLoss params
                        if (
                            list(config["Loss"]["loss_config_list"][-1].keys())[0]
                            == "DistillationSARLoss"
                        ):
                            config["Loss"]["loss_config_list"][-1][
                                "DistillationSARLoss"
                            ]["ignore_index"] = (char_num + 1)
                            out_channels_list["SARLabelDecode"] = char_num + 2
                        elif any(
                            "DistillationNRTRLoss" in d
                            for d in config["Loss"]["loss_config_list"]
                        ):
                            out_channels_list["NRTRLabelDecode"] = char_num + 3

                        config["Architecture"]["Models"][key]["Head"][
                            "out_channels_list"
                        ] = out_channels_list
                    else:
                        config["Architecture"]["Models"][key]["Head"][
                            "out_channels"
                        ] = char_num
            elif (
                config["Architecture"]["Head"]["name"] == "MultiHead"
            ):  # for multi head
                if config["PostProcess"]["name"] == "SARLabelDecode":
                    char_num = char_num - 2
                if config["PostProcess"]["name"] == "NRTRLabelDecode":
                    char_num = char_num - 3
                out_channels_list = {}
                out_channels_list["CTCLabelDecode"] = char_num
                # update SARLoss params
                if list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss":
                    if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
                        config["Loss"]["loss_config_list"][1]["SARLoss"] = {
                            "ignore_index": char_num + 1
                        }
                    else:
                        config["Loss"]["loss_config_list"][1]["SARLoss"][
                            "ignore_index"
                        ] = (char_num + 1)
                    out_channels_list["SARLabelDecode"] = char_num + 2
                elif (
                    list(config["Loss"]["loss_config_list"][1].keys())[0] == "NRTRLoss"
                ):
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
            else:  # base rec model
                config["Architecture"]["Head"]["out_channels"] = char_num

            if config["PostProcess"]["name"] == "SARLabelDecode":  # for SAR model
                config["Loss"]["ignore_index"] = char_num - 1
        ocr_rec_model = build_model(config["Architecture"])
        # Optional checkpoint override: point the config at
        # <ckpt>/best_accuracy.pdparams before loading.
        if ckpt is not None:
            config["Global"]["pretrained_model"] = os.path.join(
                ckpt, "best_accuracy.pdparams"
            )
        # NOTE(review): presumably load_model reads Global.pretrained_model to
        # find the weights -- confirm against paddleocr's save_load module.
        load_model(config, ocr_rec_model)
        return ocr_rec_model

    # img: CHW
    def resize_norm_img(self, img, max_wh_ratio):
        """Resize a CHW image to the model height, normalize it and pad it.

        The image is bilinearly resized to the configured height while keeping
        its aspect ratio (the width is capped at ``height * max_wh_ratio``),
        mapped with ``(x / 255 - 0.5) / 0.5`` and copied into the left part of
        a zero tensor of the full target width.

        Raises:
            AssertionError: If the channel count of ``img`` differs from the
                one in ``self.rec_image_shape``.
        """
        imgC, imgH, _ = self.rec_image_shape
        assert imgC == img.shape[0], f"expect 3 channels, got img shape = {img.shape}"
        target_w = int((imgH * max_wh_ratio))

        src_h, src_w = img.shape[1:]
        aspect = src_w / float(src_h)
        # Aspect-preserving width, never wider than the padded canvas.
        resized_w = min(target_w, int(math.ceil(imgH * aspect)))

        scaled = F.interpolate(
            img.unsqueeze(0),
            size=(imgH, resized_w),
            mode="bilinear",
            align_corners=True,
        )
        scaled = (scaled / 255.0 - 0.5) / 0.5

        canvas = paddle.zeros((imgC, imgH, target_w), dtype=img.dtype)
        canvas[:, :, 0:resized_w] = scaled[0]
        return canvas

    def save_pretrained(self, model_path):
        """Save the recognition model's state dict under ``model_path``.

        The weights are written to ``<model_path>/best_accuracy.pdparams``.
        """
        state = self.ocr_rec_model.state_dict()
        weights_file = os.path.join(model_path, "best_accuracy.pdparams")
        paddle.save(state, weights_file)


from PIL import Image
import base64
from io import BytesIO


def base64_to_image(base64_str):
    """Decode a base64-encoded image and return it as an RGB PIL image."""
    decoded = Image.open(BytesIO(base64.b64decode(base64_str)))
    # Anything that is not already RGB is converted.
    return decoded if decoded.mode == "RGB" else decoded.convert("RGB")


if __name__ == "__main__":
    # Build the recognition model explicitly and attach it to the predictor;
    # build_ocr_rec_model returns the model rather than storing it.
    # Alternative config: config/ch_PP-OCRv4_rec_hgnet_custom.yml
    ocr_rec = OCRPredictor()
    ocr_rec.ocr_rec_model = ocr_rec.build_ocr_rec_model("config/en_rec_custom.yml")
    # ocr_model_path = '/disk1/liwenbo/models/ocr_models'
    # from paddleocr import PaddleOCR
    # ocr=PaddleOCR(use_angle_cls=True,use_gpu=False,
    #               lang='en',
    #             det_model_dir=f"{ocr_model_path}/en_PP-OCRv3_det_infer",
    #             rec_model_dir=f"{ocr_model_path}/en_PP-OCRv4_rec_infer",
    #             cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_slim_infer",
    #             show_log=False,
    #             )
    # # lines = open('/disk1/liwenbo/data/glyph_laion_1M/part-00000').readlines()
    # # img = json.loads(lines[0])['img_code']
    # # img = base64_to_image(img)
    # img = Image.open("error_IS BJP POISED.jpg")
    # rec_boxes = ocr.det_and_rec(img=np.asarray(img))
    # res = ocr.ocr(img=np.asarray(img), cls=True)

    # img_path = 'paddleocr/doc/imgs_words/ch/word_1.jpg'
    # img = open(img_path, 'rb').read()
    # data = {'image': img}
    # data = np.random.randn(877,778,3) # .transpose([2,0,1])
    # result = ocr_rec.predict(rec_boxes)
    # print("res: ", res)
    # print("result: ", result)
