from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import jieba
from transformers.trainer_utils import EvalPrediction

def compute_bleu(eval_prediction: EvalPrediction):
    """Placeholder metric hook for an evaluation step.

    Reads ``predictions`` and ``label_ids`` from *eval_prediction* but does
    not compute anything yet; the function always returns ``None``.

    TODO: decode predictions/labels into text, segment them (e.g. with
    jieba), and score them with ``corpus_bleu``.
    """
    # Attribute reads are kept (same order as before) so a malformed input
    # still fails here; the values are unused until the metric is written.
    predictions, labels = eval_prediction.predictions, eval_prediction.label_ids
    return None

if __name__ == '__main__':
    # One stopword per line. The file name suggests the HIT Chinese stopword
    # list (assumption -- confirm). Decode explicitly as UTF-8: the platform
    # default encoding (e.g. GBK or cp1252 on Windows) can fail on or
    # mis-decode Chinese text.
    stopwords_path = "hit_stopwords.txt"
    with open(stopwords_path, encoding="utf-8") as f:
        # frozenset: O(1) membership tests in word_cut, and immutable, so it
        # is safe to use as a default argument value below.
        stopwords = frozenset(line.strip() for line in f)

    def word_cut(input_text, stopwords=stopwords):
        """Segment *input_text* with jieba and drop tokens found in *stopwords*.

        Returns the remaining tokens as a list, in their original order.
        """
        return [token for token in jieba.lcut(input_text) if token not in stopwords]

    candidate_texts = ["这不是好好的吗？"]
    reference_texts = ["有吗？哪里碎了？这不是好好的吗？"]
    # corpus_bleu takes one token list per hypothesis and, for each
    # hypothesis, a *list* of reference token lists (hence the extra []).
    candidate, reference = [], []
    for candidate_text, reference_text in zip(candidate_texts, reference_texts):
        candidate.append(word_cut(candidate_text))
        reference.append([word_cut(reference_text)])

    print(candidate)
    print(reference)

    # weights=(1, 0, 0, 0) scores unigram precision only (BLEU-1). method1
    # smoothing only touches zero-count higher-order precisions, which carry
    # zero weight here, so the score is unchanged but NLTK stops warning.
    smoothing = SmoothingFunction().method1
    bleu = corpus_bleu(
        reference,
        candidate,
        weights=(1, 0, 0, 0),
        smoothing_function=smoothing,
    )
    print("BLEU score:", bleu)