import numpy as np
import pathlib
from PIL import Image
import paddle
from paddle.vision.transforms import Compose, Resize, CenterCrop, ToTensor
from tqdm import tqdm
from paddleocr import PaddleOCR
import os

# Lower-case file extensions (without the leading dot) that this module treats
# as images, both when globbing directories and when filtering directory
# listings by the text after the last "." in a file name.
IMAGE_EXTENSIONS = {
    "bmp",
    "jpg",
    "jpeg",
    "pgm",
    "png",
    "ppm",
    "tif",
    "tiff",
    "webp",
}


def _glob_image_files(directory):
    """Return the files directly inside ``directory`` whose extension is in IMAGE_EXTENSIONS."""
    directory = pathlib.Path(directory)
    return [
        file for ext in IMAGE_EXTENSIONS for file in directory.glob(f"*.{ext}")
    ]


def get_img_files(path):
    """Collect image files from one directory or from several directories.

    Only the top level of each directory is searched (no recursion), and a
    file is selected when its name matches ``*.<ext>`` for one of the
    extensions in ``IMAGE_EXTENSIONS``.

    Args:
        path: A single directory (anything ``pathlib.Path`` accepts) or a
            list/tuple of such directories.

    Returns:
        A list of ``pathlib.Path`` objects, sorted across all directories.
    """
    # isinstance (rather than an exact ``type(...) is list`` check) also
    # accepts tuples and list subclasses.
    if isinstance(path, (list, tuple)):
        files = [file for p in path for file in _glob_image_files(p)]
    else:
        files = _glob_image_files(path)
    return sorted(files)


def get_key_words(text: str):
    """Return the whitespace-separated words found inside double-quoted spans.

    Every non-greedy ``"..."`` span in ``text`` is located and split on
    whitespace; the resulting words are returned in order of appearance.

    Args:
        text: The string to scan (e.g. one prompt line).

    Returns:
        A list of words; empty when ``text`` has no complete double-quoted
        span or the spans contain only whitespace.
    """
    import re

    words = []
    # Keywords are the parts of the text enclosed in double quotes ("...").
    for quoted in re.findall(r'"(.*?)"', text):
        words.extend(quoted.split())
    return words


def get_p_r_acc(ori, pred):
    """Compute word-level precision, recall and exact-match accuracy.

    Both inputs are normalised with ``strip().lower()`` before comparison.
    Matching is done on multisets: a ground-truth word that occurs k times
    can be matched by at most k predicted occurrences.

    Args:
        ori: Iterable of ground-truth strings.
        pred: Iterable of predicted strings.

    Returns:
        A tuple ``(p, r, acc)`` where ``p`` is matched / number of
        predictions, ``r`` is matched / number of ground-truth words (both
        denominators are offset by 1e-8, so an empty side yields 0.0), and
        ``acc`` is 1 when the sorted-and-concatenated predictions equal the
        sorted-and-concatenated ground truth, else 0.
    """
    from collections import Counter

    pred = [word.strip().lower() for word in pred]
    gt = [word.strip().lower() for word in ori]

    # Multiset intersection: counts min(#pred, #gt) per distinct word, which is
    # exactly what pairing off and removing one occurrence at a time yields.
    matched = sum((Counter(pred) & Counter(gt)).values())

    precision = matched / (len(pred) + 1e-8)
    recall = matched / (len(gt) + 1e-8)

    # NOTE(review): comparing the concatenated strings (not the sorted lists)
    # ignores empty entries and word boundaries, e.g. ["ab", "c"] and
    # ["a", "bc"] compare equal. Kept as-is so reported accuracy does not
    # change; confirm whether this leniency is intended.
    acc = 1 if "".join(sorted(pred)) == "".join(sorted(gt)) else 0
    return precision, recall, acc


def calculate_ocr_acc(
    ori_path,
    pred_path,
    texts_path,
    batch_size,
    lang,
    num_workers=1,
    ocr_model_path=None,
    resolution=1024,
):
    """Score OCR text of predicted images against quoted prompt keywords.

    NOTE: this module defines ``calculate_ocr_acc`` a second time further
    down; at import time that later definition replaces this one.

    The files in ``ori_path`` and ``pred_path`` are paired by their integer
    file stem. Each predicted image is resized and center-cropped to
    ``resolution``, run through PaddleOCR, and the recognised lines are
    compared by ``get_p_r_acc`` with the keywords ``get_key_words`` extracts
    from the line of ``texts_path`` at the same index.

    Args:
        ori_path: Directory of reference images. Its listing is only used for
            pairing, the count check and logging; the images are not read.
        pred_path: Directory of images to evaluate. File names must look like
            ``<int>.<ext>`` because they are sorted by ``int(stem)``.
        texts_path: Text file whose i-th line holds the prompt for the i-th
            image (keywords are the double-quoted parts).
        batch_size: Only used to print a warning when it exceeds the number
            of images; images are processed one at a time.
        lang: Accepted but not used; the OCR language is hard-coded to
            ``'ch'`` (the model directories below are the ``ch_*`` models).
        num_workers: Accepted but not used.
        ocr_model_path: Directory containing the three PaddleOCR model
            folders referenced below.
        resolution: Target size for the resize / center-crop.

    Returns:
        A dict with keys ``cnt`` (images evaluated), ``p``, ``r``, ``f`` and
        ``acc`` (means over the evaluated images; all zero when no image was
        evaluated).

    Raises:
        AssertionError: If the two directories hold a different number of
            files.
    """
    ocr_model = PaddleOCR(
        use_angle_cls=True,
        lang='ch',
        use_gpu=False,
        det_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_det_server_infer",
        rec_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_rec_server_infer",
        cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_slim_infer",
        show_log=False,
    )

    ori_files = os.listdir(ori_path)
    pred_files = [
        f for f in os.listdir(pred_path) if f.split(".")[-1] in IMAGE_EXTENSIONS
    ]
    ori_files.sort(key=lambda x: int(x.split(".")[0]))
    pred_files.sort(key=lambda x: int(x.split(".")[0]))
    # Context manager so the prompt file is closed deterministically.
    with open(texts_path, "r") as text_file:
        raw_texts = text_file.readlines()

    # Raised explicitly so the check survives ``python -O``; the exception
    # type is kept as AssertionError for backward compatibility.
    if len(ori_files) != len(pred_files):
        raise AssertionError("The number of images is not equal")

    if batch_size > len(ori_files):
        print(
            (
                "Warning: batch size is bigger than the data size. "
                "Setting batch size to data size"
            )
        )
        batch_size = len(ori_files)

    transforms = Compose([Resize(resolution), CenterCrop(resolution)])

    # Evaluate at most ``max_num`` pairs. (The previous ``total_sample >
    # max_num`` check let one extra sample through.)
    max_num = 500
    pairs = list(zip(ori_files, pred_files))[:max_num]

    results = {"cnt": 0, "p": 0, "r": 0, "acc": 0}
    for i, (ori_img_path, pred_img_path) in enumerate(tqdm(pairs, total=len(pairs))):
        # The reference image is intentionally not loaded or OCR'd: the
        # ground truth comes from the prompt keywords, so any OCR result for
        # it would be discarded.
        with Image.open(os.path.join(pred_path, pred_img_path)) as img:
            pred_img = transforms(img.convert("RGB"))
        pred_res = ocr_model.ocr(np.array(pred_img))[0]

        # ``line[1][0]`` is taken as the recognised text of each OCR line; a
        # ``None`` result (nothing detected) is scored as one empty string.
        pred_texts = [line[1][0] for line in pred_res] if pred_res is not None else [""]
        ori_texts = get_key_words(raw_texts[i])
        print(
            f"ori_path: {ori_img_path}, pred_path: {pred_img_path}, ori_texts: {ori_texts}, pred_texts: {pred_texts}"
        )
        p, r, acc = get_p_r_acc(ori_texts, pred_texts)
        results["cnt"] += 1
        results["p"] += p
        results["r"] += r
        results["acc"] += acc

    # Avoid ZeroDivisionError when no image was evaluated.
    denom = results["cnt"] or 1
    results["p"] /= denom
    results["r"] /= denom
    results["f"] = (
        2 * results["p"] * results["r"] / (results["p"] + results["r"] + 1e-8)
    )
    results["acc"] /= denom

    print(results)
    return results


def calculate_ocr_acc(
    ori_path,
    pred_path,
    texts_path,
    batch_size,
    lang,
    num_workers=1,
    ocr_model_path=None,
    resolution=1024,
):
    """Score OCR text of generated images against quoted prompt keywords.

    Each image in ``pred_path`` is resized and center-cropped to
    ``resolution``, run through PaddleOCR, and the recognised lines are
    compared by ``get_p_r_acc`` with the keywords ``get_key_words`` extracts
    from the line of ``texts_path`` at the same index.

    Args:
        ori_path: Accepted for signature compatibility but not used.
        pred_path: Directory of images to evaluate. Only files whose text
            after the last "." is in ``IMAGE_EXTENSIONS`` are used, and names
            must look like ``<int>.<ext>`` because they are sorted by
            ``int(stem)``.
        texts_path: Text file whose i-th line holds the prompt for the i-th
            image (keywords are the double-quoted parts).
        batch_size: Only used to print a warning when it exceeds the number
            of images; images are processed one at a time.
        lang: Accepted but not used; the OCR language is hard-coded to
            ``'ch'`` (the model directories below are the ``ch_*`` models).
        num_workers: Accepted but not used.
        ocr_model_path: Directory containing the three PaddleOCR model
            folders referenced below.
        resolution: Target size for the resize / center-crop.

    Returns:
        A dict with keys ``cnt`` (images evaluated), ``p``, ``r``, ``f`` and
        ``acc`` (means over the evaluated images; all zero when no image was
        evaluated).
    """
    ocr_model = PaddleOCR(
        use_angle_cls=True,
        lang='ch',
        use_gpu=False,
        det_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_det_server_infer",
        rec_model_dir=f"{ocr_model_path}/ch_PP-OCRv4_rec_server_infer",
        cls_model_dir=f"{ocr_model_path}/ch_ppocr_mobile_v2.0_cls_slim_infer",
        show_log=False,
    )

    pred_files = [
        f for f in os.listdir(pred_path) if f.split(".")[-1] in IMAGE_EXTENSIONS
    ]
    pred_files.sort(key=lambda x: int(x.split(".")[0]))
    # Context manager so the prompt file is closed deterministically.
    with open(texts_path, "r") as text_file:
        raw_texts = text_file.readlines()

    if batch_size > len(pred_files):
        print(
            (
                "Warning: batch size is bigger than the data size. "
                "Setting batch size to data size"
            )
        )
        batch_size = len(pred_files)

    transforms = Compose([Resize(resolution), CenterCrop(resolution)])

    # Evaluate at most ``max_num`` images. (The previous ``total_sample >
    # max_num`` check let one extra sample through.)
    max_num = 500
    pred_files = pred_files[:max_num]

    results = {"cnt": 0, "p": 0, "r": 0, "acc": 0}
    for i, pred_img_path in enumerate(tqdm(pred_files, total=len(pred_files))):
        # Close the underlying file as soon as the RGB copy has been made.
        with Image.open(os.path.join(pred_path, pred_img_path)) as img:
            pred_img = transforms(img.convert("RGB"))
        pred_res = ocr_model.ocr(np.array(pred_img))[0]

        # ``line[1][0]`` is taken as the recognised text of each OCR line; a
        # ``None`` result (nothing detected) is scored as one empty string.
        pred_texts = [line[1][0] for line in pred_res] if pred_res is not None else [""]
        ori_texts = get_key_words(raw_texts[i])
        print(
            f"pred_path: {pred_img_path}, ori_texts: {ori_texts}, pred_texts: {pred_texts}"
        )
        p, r, acc = get_p_r_acc(ori_texts, pred_texts)
        results["cnt"] += 1
        results["p"] += p
        results["r"] += r
        results["acc"] += acc

    # Avoid ZeroDivisionError when no image was evaluated.
    denom = results["cnt"] or 1
    results["p"] /= denom
    results["r"] /= denom
    results["f"] = (
        2 * results["p"] * results["r"] / (results["p"] + results["r"] + 1e-8)
    )
    results["acc"] /= denom

    print(results)
    return results
