import json
from collections import Counter

from nltk import word_tokenize
from nltk.corpus import stopwords
from tqdm import tqdm

from utils import TextType, process_text

# English stop words, held in a set so per-token membership tests are cheap.
stop_words = {*stopwords.words('english')}

def stat_persona_in_output(file_name, mode: TextType, times=False):
    """Report how much of each sample's persona vocabulary appears in its output.

    Loads a JSON collection of samples from ``file_name``. For every sample the
    ``'cards'`` field is passed through ``process_text(..., TextType.cards)``,
    tokenized, and filtered against ``stop_words``; the ``'generated'`` field is
    passed through ``process_text(..., mode)`` and tokenized. The per-sample
    score is ``hits / number_of_persona_keywords``; the mean over all samples
    is printed and returned.

    Args:
        file_name: Path to a UTF-8 encoded JSON file whose samples each provide
            ``'cards'`` and ``'generated'`` entries.
        mode: ``TextType`` handed to ``process_text`` for the generated text.
        times: When false, every persona keyword occurrence scores 1 if that
            keyword appears anywhere in the output tokens. When true, every
            persona keyword occurrence scores the number of times the keyword
            occurs in the output tokens, so a sample's score may exceed 1.

    Returns:
        The mean per-sample ratio as a float, or ``0.0`` when the file holds
        no samples.
    """
    with open(file_name, encoding='utf-8') as f:
        samples = json.load(f)

    sum_ratio = 0.0
    for sample in tqdm(samples):
        persona = process_text(sample['cards'], TextType.cards)
        output = process_text(sample['generated'], mode)

        persona_kws = [
            word for word in word_tokenize(persona) if word not in stop_words
        ]
        if not persona_kws:
            # No keywords means there is nothing to divide by. The sample adds
            # 0 to the sum but still counts in the final denominator, so the
            # reported mean stays comparable across result files.
            print(
                f"file:{file_name}, sample has no persona keywords after "
                f"stop-word removal; counted as ratio 0"
            )
            continue

        # Count output tokens once so each keyword lookup is O(1) instead of a
        # scan over the whole token list.
        output_counts = Counter(word_tokenize(output))
        if times:
            cnt = sum(output_counts[w] for w in persona_kws)
        else:
            cnt = sum(1 for w in persona_kws if w in output_counts)

        sum_ratio += cnt / len(persona_kws)

    # An empty sample collection would otherwise divide by zero here.
    ratio = sum_ratio / len(samples) if samples else 0.0
    print(f"file:{file_name}, ratio:{ratio}")
    return ratio


if __name__ == '__main__':
    # Each result file is paired with the TextType passed as `mode` when its
    # 'generated' field is processed.
    file_mode = {
        '../result/final/kg_combine4_reweight_outline_onecard_version17_1.0.json': TextType.plan_tbe,
        '../result/final/plan_ahead_bte_onecard_1.0.json': TextType.bte,
        '../result/final/gpt2_baseline_bte_onecard_1.0_persona.json': TextType.bte,
    }

    # times=True: score by number of keyword occurrences in the output.
    for result_path, text_type in file_mode.items():
        stat_persona_in_output(result_path, text_type, times=True)
