from glob import glob
import json
import os
import numpy as np
import argparse
from clip_score import calculate_clipscore
from fid_score import calculate_fid_given_paths
import torch
from ocr_acc import calculate_ocr_acc
from transformers import CLIPModel, CLIPProcessor
from termcolor import colored

# Disable the HuggingFace `tokenizers` library's internal parallelism for the
# whole process. NOTE(review): presumably set to avoid the fork-after-parallelism
# warning/deadlock when worker processes are forked later (e.g. the OCR metric
# is called with num_workers=8) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _chunked_clip_score(
    model, processor, texts, images_path, batch_size, chunk_size=1000
):
    """Compute the CLIP score over all (text, image) pairs, in fixed-size chunks.

    Chunking bounds how many pairs are handed to ``calculate_clipscore`` at once.

    Assumes ``calculate_clipscore`` returns the *mean* score of the pairs it is
    given (an earlier unchunked version of this code used its return value
    directly as the final score) — TODO confirm against ``clip_score.py``.
    Per-chunk means are combined with a length-weighted average, so a short
    trailing chunk carries proportionally less weight. When ``len(texts)`` is an
    exact multiple of ``chunk_size`` this equals the plain mean of chunk scores.

    Args:
        model: CLIP model forwarded to ``calculate_clipscore``.
        processor: CLIP processor forwarded to ``calculate_clipscore``.
        texts: prompts; ``texts[i]`` is paired with ``images_path[i]``.
        images_path: image file paths, same length as ``texts``.
        batch_size: forwarded to ``calculate_clipscore``.
        chunk_size: maximum number of pairs scored per call.

    Returns:
        The overall score as a Python float, or NaN when ``texts`` is empty.
    """
    total = len(texts)
    if total == 0:
        return float("nan")
    weighted_sum = 0.0
    for start in range(0, total, chunk_size):
        stop = start + chunk_size
        chunk_texts = texts[start:stop]
        chunk_score = calculate_clipscore(
            model=model,
            processor=processor,
            texts=chunk_texts,
            images_path=images_path[start:stop],
            batch_size=batch_size,
        )
        weighted_sum += float(chunk_score) * len(chunk_texts)
    return weighted_sum / total


def _resolve_output_path(output_path):
    """Return the JSON file to write results to, creating its parent directory.

    A path ending in ``.json`` is used as-is; any other path is treated as a
    directory and ``eval.json`` is placed inside it.
    """
    if not output_path.endswith(".json"):
        output_path = os.path.join(output_path, "eval.json")
    parent = os.path.dirname(output_path)
    # dirname is "" for a bare file name in the CWD; makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return output_path


def main(
    all_data,
    clip_model_name_or_path,
    ocr_model_path,
    inceptionv3_path,
    output_path,
    eval_choices,
    batch_size=4,
    device="cuda",
    resolution=1024,
):
    """Run the selected metrics on every pipeline/dataset and save them as JSON.

    Args:
        all_data: ``{pipeline: {"datasets": {dataset_name: info}}}``. Each
            ``info`` dict provides ``"raw_path"``, ``"gen_path"``, ``"texts"``
            and optionally ``"lang"`` (defaults to ``"en"`` for OCR).
        clip_model_name_or_path: passed to ``CLIPModel.from_pretrained`` and
            ``CLIPProcessor.from_pretrained``.
        ocr_model_path: forwarded to ``calculate_ocr_acc``.
        inceptionv3_path: currently unused (not forwarded to the FID call).
        output_path: a ``.json`` file path, or a directory in which
            ``eval.json`` is written. Missing parent directories are created.
        eval_choices: any subset of ``"clip_score"``, ``"fid_score"``,
            ``"ocr_acc"``.
        batch_size: forwarded to every metric.
        device: currently unused.
        resolution: forwarded to ``calculate_ocr_acc``.

    Results are stored as ``{pipeline: {dataset_name: {metric: value}}}``.
    """
    all_results = {}
    # Loaded lazily on first use and shared by all datasets, instead of being
    # re-instantiated for each one.
    clip_model = None
    clip_processor = None
    for pipe, pipe_info in all_data.items():
        print(colored("evaluating pipeline: ", "green"), pipe)
        result_pipe = {}
        for dataset_name, data_info in pipe_info["datasets"].items():
            print(colored("evaluating dataset: ", "green"), dataset_name)
            result_dataset = {}
            result_pipe[dataset_name] = result_dataset
            if "clip_score" in eval_choices:
                if clip_model is None:
                    clip_model = CLIPModel.from_pretrained(clip_model_name_or_path)
                    clip_processor = CLIPProcessor.from_pretrained(
                        clip_model_name_or_path
                    )
                with open(data_info["texts"], encoding="utf-8") as f:
                    texts = f.readlines()
                gen_path = data_info["gen_path"]
                # Images are named "<index>.<ext>"; sort numerically because
                # prompts and images are paired by position.
                image_names = sorted(
                    os.listdir(gen_path), key=lambda x: int(x.split(".")[0])
                )
                images_path = [os.path.join(gen_path, x) for x in image_names]
                # Score only the pairs present in both lists.
                n_pairs = min(len(texts), len(images_path))
                texts = texts[:n_pairs]
                images_path = images_path[:n_pairs]
                clip_score = _chunked_clip_score(
                    model=clip_model,
                    processor=clip_processor,
                    texts=texts,
                    images_path=images_path,
                    batch_size=batch_size,
                )
                result_dataset["clip_score"] = clip_score
                print(colored("clip score:", "red"), clip_score)
            if "fid_score" in eval_choices:
                # NOTE: inceptionv3_path, device and resolution are not
                # forwarded; calculate_fid_given_paths uses its own defaults.
                fid_score = calculate_fid_given_paths(
                    ori_path=data_info["raw_path"],
                    gen_path=data_info["gen_path"],
                    batch_size=batch_size,
                )
                print(colored("fid score:", "red"), fid_score)
                result_dataset["fid_score"] = fid_score
            if "ocr_acc" in eval_choices:
                ocr_res = calculate_ocr_acc(
                    ori_path=data_info["raw_path"],
                    pred_path=data_info["gen_path"],
                    texts_path=data_info["texts"],
                    batch_size=batch_size,
                    lang=data_info.get("lang", "en"),
                    num_workers=8,
                    ocr_model_path=ocr_model_path,
                    resolution=resolution,
                )
                print(colored("ocr acc:", "red"), ocr_res)
                result_dataset["ocr_result"] = ocr_res
        all_results[pipe] = result_pipe
    print(all_results)
    output_path = _resolve_output_path(output_path)
    with open(output_path, "w") as fw:
        json.dump(all_results, fw)
    print("result saved to: ", output_path)


def arg_parser():
    """Build the CLI parser for a single pipeline/dataset evaluation run.

    Value options are registered from a table in declaration order. They are
    followed by one boolean ``store_true`` switch per metric
    (``--fid_score``, ``--clip_score``, ``--ocr_acc``).
    """
    parser = argparse.ArgumentParser()
    value_options = (
        ("--pipeline", {"type": str, "required": True}),
        ("--dataset", {"type": str, "required": True}),
        ("--raw_path", {"type": str, "required": False}),
        ("--gen_path", {"type": str, "default": None, "required": True}),
        ("--texts", {"type": str, "required": True}),
        ("--lang", {"type": str, "default": "en"}),
        ("--batch_size", {"type": int, "default": 64}),
        ("--output_path", {"type": str, "default": "./output"}),
        ("--resolution", {"type": int, "default": 1024}),
    )
    for flag, options in value_options:
        parser.add_argument(flag, **options)
    for metric in ("fid_score", "clip_score", "ocr_acc"):
        parser.add_argument(f"--{metric}", action="store_true")
    return parser


if __name__ == "__main__":
    # CLI entry point: build a single-pipeline / single-dataset config from the
    # command-line flags and run the requested metrics.
    device = "cuda"
    parser = arg_parser()
    args = parser.parse_args()
    # FID compares generated images against the reference images in raw_path,
    # so fail early with a CLI error instead of deep inside the metric.
    if args.fid_score and args.raw_path is None:
        parser.error("--raw_path is required when --fid_score is set")
    all_data = {
        args.pipeline: {
            "datasets": {
                args.dataset: {
                    "raw_path": args.raw_path,
                    "gen_path": args.gen_path,
                    "texts": args.texts,
                    "lang": args.lang,
                }
            }
        }
    }
    metric_flags = (
        ("fid_score", args.fid_score),
        ("ocr_acc", args.ocr_acc),
        ("clip_score", args.clip_score),
    )
    eval_choices = [name for name, enabled in metric_flags if enabled]

    clip_model_name_or_path = "openclip_vit_large_patch14"
    inceptionv3_path = "inceptionv3"
    ocr_model_path = "ocr_models"
    main(
        all_data=all_data,
        clip_model_name_or_path=clip_model_name_or_path,
        ocr_model_path=ocr_model_path,
        inceptionv3_path=inceptionv3_path,
        output_path=args.output_path,
        eval_choices=eval_choices,
        # Honour --batch_size (previously hard-coded to 64, its default).
        batch_size=args.batch_size,
        device=device,
        resolution=args.resolution,
    )
