import json
import difflib
from typing import List
import numpy as np
import jieba.analyse as ja
import pandas as pd


class Idc10:
    """Translate ICD-10 codes into disease names and normalise those names.

    A code is first resolved to a name through an Excel-backed lookup table
    and then matched, by keyword similarity, to the closest disease name
    found among the keys of the reverse index.
    """

    def __init__(self, index, rindex, icd):
        """Load the lookup resources and pre-compute keywords per disease.

        Args:
            index: Path to a JSON file. Its parsed content is kept on
                ``self.index``; no method of this class reads it.
            rindex: Path to a JSON file holding an object keyed by disease
                name. ``get_des_by_level`` splits each value on ``'|'``.
            icd: Value passed as ``io`` to ``pandas.read_excel``. The sheet's
                columns are named ``idx`` and ``name``; ``idx`` becomes the
                index of ``self.icd``.
        """
        # Context managers close the handles deterministically. JSON is
        # UTF-8 by specification, so do not depend on the locale default.
        with open(index, encoding="utf-8") as fh:
            self.index = json.load(fh)
        with open(rindex, encoding="utf-8") as fh:
            self.rindex = json.load(fh)

        self.icd = pd.read_excel(io=icd, names=["idx", "name"])
        self.icd = self.icd.set_index("idx")
        self.disease = list(self.rindex.keys())
        # keywords[k] holds the keyword list of disease[k].
        self.keywords = [self._extract_keywords(name) for name in self.disease]

    @staticmethod
    def _extract_keywords(text):
        """Return the jieba keywords of ``text``, or ``[text]`` if there are none."""
        return ja.extract_tags(text) or [text]

    def _standardization(self, diseases):
        """Replace every name with the most similar known disease name.

        Args:
            diseases: Iterable of disease names; ``None`` entries are allowed.

        Returns:
            A list of the same length. ``None`` inputs stay ``None``; every
            other entry is the element of ``self.disease`` with the highest
            similarity score (the first one on ties).
        """
        ret = []
        for disease in diseases:
            if disease is None:
                ret.append(None)
                continue

            # Same fallback as in __init__: a name that yields no keywords is
            # compared as a whole instead of defaulting to the first disease.
            dis_keyword = self._extract_keywords(disease)
            scores = self._get_sim_scores(dis_keyword)
            ret.append(self.disease[int(np.argmax(scores))])

        return ret

    def _get_sim_scores(self, dis_keyword):
        """Score ``dis_keyword`` against the keywords of every known disease.

        The score is the mean of two ``SequenceMatcher.quick_ratio`` values:
        one comparing the first keywords, one comparing the concatenation of
        the first two keywords.

        Args:
            dis_keyword: List of keyword strings for the query.

        Returns:
            A list of floats aligned with ``self.keywords``. An empty
            ``dis_keyword`` yields all zeros, so the result is always a list.
        """
        if not dis_keyword:
            return [0.0] * len(self.keywords)

        # Loop-invariant pieces of the query.
        head = dis_keyword[0]
        head_pair = "".join(dis_keyword[:2])

        scores = []
        for keyword in self.keywords:
            r1 = difflib.SequenceMatcher(None, head, keyword[0]).quick_ratio()
            r2 = difflib.SequenceMatcher(
                None, head_pair, "".join(keyword[:2])
            ).quick_ratio()
            scores.append(r1 * 0.5 + r2 * 0.5)

        return scores

    def _lookup_name(self, code):
        """Return the name stored for ``code`` or ``code + '*'``, else ``None``."""
        for label in (code, f"{code}*"):
            try:
                return self.icd.loc[label, "name"]
            except KeyError:
                # Label absent from the table; try the next candidate.
                continue
        return None

    def ids2diseases(self, ids: List[str]):
        """Resolve ICD-10 codes to the names stored in the lookup table.

        Each code is upper-cased, reduced to the part after the last ``'-'``
        and truncated to three characters before the lookup. If that label is
        missing, the same label with a trailing ``'*'`` is tried.

        Args:
            ids: Codes to resolve. Entries that are not ``str`` (including
                ``None``) produce ``None``.

        Returns:
            A list aligned with ``ids`` holding the looked-up name, or
            ``None`` when neither label exists.
        """
        ret = []
        for i in ids:
            if not isinstance(i, str):
                ret.append(None)
                continue

            # For a range such as "I47-I49" only the last code is used.
            code = i.upper().split("-")[-1][:3]
            ret.append(self._lookup_name(code))

        return ret

    def get_diseases(self, ids: List[str]):
        """Resolve codes to names and normalise them to known disease names."""
        diseases = self.ids2diseases(ids)
        return self._standardization(diseases)

    def get_des_by_level(self, diseases, level):
        """Return one ``'|'``-separated field of each disease's reverse-index entry.

        Args:
            diseases: Iterable of disease names that are keys of
                ``self.rindex``; ``None`` entries are passed through.
            level: Position of the wanted field after splitting on ``'|'``.

        Returns:
            A list aligned with ``diseases``.

        Raises:
            KeyError: If a name is not a key of ``self.rindex``.
            IndexError: If ``level`` is out of range for an entry.
        """
        return [
            None if disease is None else self.rindex[disease].split('|')[level]
            for disease in diseases
        ]


if __name__ == "__main__":
    # Manual smoke run: build the model from the bundled resource files and
    # print the normalised disease names for a few sample codes.
    resource_paths = (
        "resources/ICD-10.json",
        "resources/ICD-10-reverse.json",
        "resources/ICD-10.xlsx",
    )
    model = Idc10(*resource_paths)
    sample_codes = ["N25.9", "I47-I49", "G44.x"]
    print(model.get_diseases(sample_codes))
