import json
import os
import random
import string

import nltk.sentiment.sentiment_analyzer
import numpy as np
import torch
from bert_score import BERTScorer
from nltk import sent_tokenize
from nltk.corpus import stopwords, wordnet
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm
from multiprocessing import Pool
from rake_nltk import Rake

# Seed the RNG so the random keyword picks below (random.choice / random.sample)
# are reproducible between runs.
# NOTE(review): multiprocessing.Pool workers used further down may all inherit
# this same RNG state — confirm whether that correlation matters.
random.seed(2020)
# Tokens the extractors ignore: NLTK's English stopwords plus every ASCII
# punctuation character (kept as a list; only membership tests are done on it).
stop = [*stopwords.words('english'), *string.punctuation]
# VADER scorer used to decide whether a single token is emotional.
sid = SentimentIntensityAnalyzer()
# Despite the name this is a WordNet lemmatizer, not a stemmer.
stemmer = WordNetLemmatizer()


def get_idf_sents():
    """Return the training texts used to compute BERTScore IDF weights.

    Reads ``train.json`` (a list of scenes; each scene has an ``entries`` list
    whose items carry a ``description``, and the last entry also lists
    ``cards`` with their own ``description``). The result is every entry
    description followed by every card description of the last entry.

    The list is cached in ``train_doc_idf.json`` in the working directory and
    served from there on later calls.
    """

    def collect(cache_name, split):
        # Serve from the on-disk cache when it already exists.
        if os.path.exists(cache_name):
            with open(cache_name, encoding='utf-8') as cache:
                return json.load(cache)
        with open('train.json', encoding='utf-8') as src:
            scenes = json.load(src)
            entry_texts, card_texts = [], []
            for scene in scenes:
                for entry in scene['entries']:
                    entry_texts.extend(split(entry['description']))
                for card in scene['entries'][-1]['cards']:
                    card_texts.extend(split(card['description']))
            texts = entry_texts + card_texts
            with open(cache_name, 'w', encoding='utf-8') as cache:
                json.dump(texts, cache, ensure_ascii=False)
            print('finish get_idf_sent')
            return texts

    def get_idf_docs():
        # One item per whole description.
        return collect('train_doc_idf.json', lambda description: [description])

    def get_idf_split_sent():
        # One item per sentence; currently unused, kept as an alternative.
        return collect('train_sent_idf.json', sent_tokenize)

    # return get_idf_split_sent()
    return get_idf_docs()


# text_collection = TextCollection([word_tokenize(sent) for sent in idf_sents])
# word_idfs = {}
# for word in text_collection:
#     word_idfs[word] = text_collection.idf(word)
#
# word_idfs = sorted(word_idfs.items(), key=lambda item: item[1])

def get_discard_words(enabled=False, top_k=100, idf_file='word_idf.json'):
    """Return words the extractors must never emit as keywords.

    Discarding is disabled by default. In that case an empty list is returned,
    exactly like the previous hard-coded ``return []``. The old loading code
    sat unreachable after that return and is now reachable via ``enabled``.

    Args:
        enabled: when true, load the discard list from ``idf_file``.
        top_k: number of words to keep from the start of the filtered list.
        idf_file: JSON list of ``[word, idf]`` pairs. Presumably sorted by
            ascending IDF (most common words first) — TODO confirm against the
            producer of that file.

    Returns:
        ``[]`` when disabled. Otherwise the first ``top_k`` words that are
        neither in the module ``stop`` list (case-insensitively) nor contraction
        fragments starting with ``'``.
    """
    if not enabled:
        return []

    def valid(word):
        return word.lower() not in stop and not word.startswith("'")

    with open(idf_file, encoding='utf-8') as f:
        pairs = json.load(f)
    return [word for word, _idf in pairs if valid(word)][:top_k]


# Words every extractor below skips (checked against both the token and its
# lemma). get_discard_words() is called with its defaults, which return [].
discard_words = get_discard_words()


def extract_emotion(sent):
    """Return the emotional words of ``sent`` in order of appearance.

    A token is emotional when all of the following hold:
      * it is not in ``stop``;
      * it starts with a letter;
      * VADER, scoring the token on its own, gives ``pos > 0.5`` or ``neg > 0.5``.
    The token is appended once per score key that passes.

    The previous version also POS-tagged the sentence but never used the tags.
    That work is dropped because it changed nothing in the result.
    """
    emotional_words = []
    for word in nltk.word_tokenize(sent):
        # Cheap filters first, so VADER only scores candidate tokens.
        if word in stop or not word[:1].isalpha():
            continue
        scores = sid.polarity_scores(word)
        for key in scores:
            if scores[key] > 0.5 and key in ('pos', 'neg'):
                emotional_words.append(word)
    return emotional_words


def extract_emotion_force(sents):
    """Pick one word from ``sents`` to act as the emotion keyword.

    Candidates are tokens that start with a letter and share the highest VADER
    ``pos + neg`` score. The running maximum starts at 0, so zero-scored tokens
    qualify when nothing scores higher. One candidate is chosen at random. If
    no token starts with a letter, any token is chosen at random.

    Returns:
        A one-element list, or ``[]`` when ``sents`` yields no tokens at all.
        Previously that case raised ``IndexError`` from ``random.choice``.
    """
    best_score = 0
    best_words = []
    all_words = []
    for sent in sents:
        # No POS tags are needed here (the old code computed them unused).
        for word in nltk.word_tokenize(sent):
            all_words.append(word)
            if not word[:1].isalpha():
                continue
            scores = sid.polarity_scores(word)
            score = scores['pos'] + scores['neg']
            if score > best_score:
                best_score = score
                best_words = [word]
            elif score == best_score:
                best_words.append(word)
    if best_words:
        return [random.choice(best_words)]
    if all_words:
        return [random.choice(all_words)]
    return []


def extract_event(sent):
    """Return the verbs of ``sent`` usable as event keywords, in order.

    A token is kept when all of the following hold:
      * its POS tag contains ``'VB'``;
      * neither the token nor its verb lemma is in ``stop`` or ``discard_words``;
      * it starts with a letter.
    The token itself (not the lemma) is returned.
    """
    def keep(word, tag):
        if 'VB' not in tag:
            return False
        lemma = stemmer.lemmatize(word, 'v')
        if lemma in stop or word in stop:
            return False
        if lemma in discard_words or word in discard_words:
            return False
        return word[0].isalpha()

    tagged = nltk.pos_tag(nltk.word_tokenize(sent))
    return [word for word, tag in tagged if keep(word, tag)]


def extract_event_force(sents):
    """Pick one word from ``sents`` to act as the event keyword.

    Preference order:
      1. a random verb (POS tag containing ``'VB'``) that starts with a letter;
      2. otherwise a random other token that starts with a letter;
      3. otherwise any random token.

    Returns:
        A one-element list, or ``[]`` when ``sents`` yields no tokens at all.
        Previously that case raised ``IndexError`` from ``random.choice``.
    """
    verbs = []
    other_words = []
    all_words = []
    for sent in sents:
        for word, tag in nltk.pos_tag(nltk.word_tokenize(sent)):
            all_words.append(word)
            if not word[0].isalpha():
                continue
            if 'VB' in tag:
                verbs.append(word)
            else:
                other_words.append(word)
    # The first non-empty pool wins.
    for candidates in (verbs, other_words, all_words):
        if candidates:
            return [random.choice(candidates)]
    return []


def filter_sents(text, cards, scorer, force, top, pos=None, std_factor=0.3):
    """Select the sentences of ``text`` that best match the persona ``cards``.

    Each sentence's score is the sum, over all cards, of the F values returned
    by ``scorer.score``. Each call pairs every sentence (hypothesis) with the
    card text (reference).

    Args:
        text: passage to split with ``sent_tokenize``.
        cards: card descriptions used as references.
        scorer: object whose ``score(hyps, refs)`` returns ``(P, R, F)`` with
            one F value per pair (e.g. ``bert_score.BERTScorer``).
        force: if no sentence passes the threshold, keep the best one anyway.
        top: keep only the single best-scoring sentence.
        pos: optional list. In ``top`` mode, the chosen sentence's 1-based
            position divided by the sentence count is appended to it.
        std_factor: a sentence is kept when its score is greater than
            ``mean + std_factor * std`` of the passage's scores. With a single
            sentence ``torch.std`` is NaN, so nothing passes.

    Returns:
        ``(n_selected, n_sents, selected_sents)``. In ``top`` mode ``n_selected``
        is 1. A passage with no sentences yields ``(0, 0, [])``; previously that
        case crashed in ``top``/``force`` mode on ``torch.max`` of an empty tensor.
    """
    sents = sent_tokenize(text)
    if not sents:
        return 0, 0, []

    scores = torch.zeros(len(sents))
    for card in cards:
        _, _, f = scorer.score(list(sents), [card] * len(sents))
        scores += f

    if top:  # only choose the sentence with max score
        _, idx = torch.max(scores, dim=0)
        best = idx.item()
        if pos is not None:
            pos.append((best + 1) / len(sents))
        return 1, len(sents), [sents[best]]

    threshold = torch.mean(scores).item() + std_factor * torch.std(scores).item()
    # Index the boolean mask directly. The old where(mask, scores, 0).nonzero()
    # silently dropped a passing sentence whose score was exactly 0.
    idxs = (scores > threshold).nonzero().squeeze(-1).tolist()
    selected_sents = [sents[i] for i in idxs]
    if force and not idxs:
        _, idx = torch.max(scores, dim=0)
        selected_sents.append(sents[idx.item()])

    return len(idxs), len(sents), selected_sents


def extract_one_file(in_file_name, out_file_name, scorer, force=False, top=False):
    """Extract emotion/event outline words for every scene of a JSON file.

    For each scene:
      * the last entry's description is reduced to the sentences most similar
        to that entry's cards (see ``filter_sents``);
      * emotion and event words are extracted from the selected sentences;
      * with ``force``, the ``*_force`` fallbacks guarantee one word of each
        kind. If a list is still empty, the scene is printed and the process
        exits with status 1.

    Side effects:
      * sets ``scene['outline']`` on each loaded scene;
      * prints summary statistics, including the mean relative position of
        the chosen sentence for ``top`` runs;
      * writes ``{'entry': <last entry>, 'select': <selected sentences>}``
        records to ``out_file_name``.
    """
    def abort(scene, selected_sents):
        # Dump the offending scene and stop with a failure status. The old bare
        # exit() raised SystemExit(None), i.e. exit status 0.
        print(scene)
        print(selected_sents)
        raise SystemExit(1)

    with open(in_file_name, encoding='utf-8') as f:
        print('extract {}'.format(in_file_name))
        a = json.load(f)

    select_tot = 0
    sum_tot = 0
    zero_tot = 0
    zero_emo_tot = 0
    zero_eve_tot = 0
    res = []
    pos = []  # filled by filter_sents in top mode only
    for scene in tqdm(a):
        last_entry = scene['entries'][-1]
        text = last_entry['description']
        cards = [card['description'] for card in last_entry['cards']]
        select, n_sents, selected_sents = filter_sents(text, cards, scorer, force, top, pos)
        select_tot += select
        if select == 0:
            zero_tot += 1
        sum_tot += n_sents

        emotion_words = []
        event_words = []
        for sent in selected_sents:
            emotion_words.extend(extract_emotion(sent))
            event_words.extend(extract_event(sent))
        if force and not emotion_words:
            emotion_words.extend(extract_emotion_force(selected_sents))
        if force and not event_words:
            event_words.extend(extract_event_force(selected_sents))

        if not emotion_words:
            if force:
                abort(scene, selected_sents)
            zero_emo_tot += 1
        if not event_words:
            if force:
                abort(scene, selected_sents)
            zero_eve_tot += 1
        scene['outline'] = {'emotion': emotion_words, 'event': event_words}
        res.append({'entry': last_entry, 'select': selected_sents})

    n_scenes = len(a) or 1  # keep the averages defined for an empty input file
    print('avg select:{},avg sum:{},avg zero:{}'.format(select_tot / n_scenes, sum_tot / n_scenes,
                                                        zero_tot / n_scenes))
    print('avg zero emo:{}, avg zero eve:{}'.format(zero_emo_tot / n_scenes, zero_eve_tot / n_scenes))
    if pos:
        # pos already holds one relative position per scene, so its plain mean is
        # the average. The old extra "/ len(pos)" divided by the count twice.
        print('avg pos for sent with max score', np.mean(pos))
    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(res, f, indent=1, ensure_ascii=False)


def extract_emotion_event_at_least_one(text, limit=False, high_frequency=False, low_frequency=False, per_sent=False,
                                       filter=False):
    """Extract emotion and event keywords from ``text``, guaranteeing at least one.

    Only tokens that start with a letter are considered. A token becomes a
    keyword, in its lower-cased WordNet lemma form, in either of two cases:
      * emotion word: VADER scores the token alone with ``pos > 0.5`` or
        ``neg > 0.5``;
      * event word: its POS tag maps to a WordNet verb or noun (unknown tags
        map to noun) and neither the token nor its lemma is in ``stop`` or
        ``discard_words``.

    Args:
        text: a string (split with ``nltk.sent_tokenize``) or a list of
            already split sentences.
        limit: cap each sentence's keywords at ``min(5, ceil(0.1 * n_tokens))``
            by random sampling that preserves order.
        high_frequency: keep only lemmas in the module-level
            ``high_frequency_words`` (must be defined by the caller's script).
        low_frequency: drop lemmas in ``high_frequency_words``.
        per_sent: return one keyword list per sentence instead of a flat list.
            A sentence without keywords contributes one randomly chosen word.
        filter: when a new keyword is not within 2 hops of the previous keyword
            in the module-level ``graph``, drop it with probability 0.9.

    Returns:
        ``(keywords, n_fallback)``. ``n_fallback`` counts how often a random word
        had to be chosen because nothing was extracted. Input without sentences
        yields ``([], 0)`` in both modes; the flat mode used to crash on
        ``sents[0]``.
    """
    from math import ceil

    def get_wordnet_pos(tag):
        # Map a Penn Treebank tag to a WordNet POS; anything else counts as noun.
        if tag.startswith("J"):
            return wordnet.ADJ
        elif tag.startswith("V"):
            return wordnet.VERB
        elif tag.startswith("N"):
            return wordnet.NOUN
        elif tag.startswith("R"):
            return wordnet.ADV
        else:
            return wordnet.NOUN

    def lemmatize(word, nltk_tag):
        tag = get_wordnet_pos(nltk_tag)
        return stemmer.lemmatize(word, tag).lower()

    def sample_sorted(words, num):
        # Randomly keep `num` of `words`, preserving their original order.
        chosen = sorted(random.sample(list(range(len(words))), num))
        return [words[i] for i in chosen]

    def extract_emotion_event(sent, last_word):
        words = nltk.word_tokenize(sent)
        tagged = nltk.pos_tag(words)
        # Slot 0 holds the previous keyword (or None) so the graph filter can
        # compare this sentence's first keyword with it. It is removed before returning.
        res = [last_word]

        def process(res):
            # Graph filter on the newest keyword (see `filter` above).
            if not filter or len(res) < 2:
                return
            now = res[-1]
            before = res[-2]
            if before is None:
                # Nothing extracted yet, so there is no neighbour to compare with.
                return
            if now in graph.get_hops_set([before], hop=2):
                return
            if random.random() < 0.9:
                res.pop(-1)

        for word, tag in tagged:
            if not word[0].isalpha():
                continue
            origin_word = lemmatize(word, tag)
            if high_frequency and origin_word not in high_frequency_words:
                continue
            if low_frequency and origin_word in high_frequency_words:
                continue

            # Emotion word: the single token is scored strongly pos or neg.
            select = False
            ss = sid.polarity_scores(word)
            for key in ss:
                if ss[key] > 0.5 and (key == 'pos' or key == 'neg'):
                    res.append(origin_word)
                    select = True
            if select:
                process(res)
                continue

            # Event word: a verb or noun that is not a stop/discarded word.
            if get_wordnet_pos(tag) in [wordnet.VERB, wordnet.NOUN]:
                if origin_word in stop or word in stop:
                    continue
                if origin_word in discard_words or word in discard_words:
                    continue
                res.append(origin_word)
                process(res)

        res.pop(0)
        max_len = min(5, ceil(len(words) * 0.1))
        if limit and max_len < len(res):
            return sample_sorted(res, max_len)
        return res

    def choose_one_word(sent):
        # Fallback keyword: a random token, preferring ones that start with a letter.
        words = nltk.word_tokenize(sent)
        tagged = nltk.pos_tag(words)
        new_tagged = [(word, tag) for word, tag in tagged if word[0].isalpha()]
        word, tag = random.choice(new_tagged if new_tagged else tagged)
        return [lemmatize(word, tag)]

    sents = text if isinstance(text, list) else nltk.sent_tokenize(text)
    if not sents:
        return [], 0

    res = []
    choose_one_word_cnt = 0
    for sent in sents:
        # Previous keyword for the graph filter. In per_sent mode res is a list
        # of per-sentence lists; otherwise it is a flat word list, where the old
        # res[-1][-1] picked the last *character* of the previous word.
        last_word = None
        if res:
            last_word = res[-1][-1] if per_sent else res[-1]
        w = extract_emotion_event(sent, last_word)
        if not per_sent:
            res.extend(w)
        elif w:
            res.append(w)
        else:
            choose_one_word_cnt += 1
            res.append(choose_one_word(sent))

    if per_sent:
        return res, choose_one_word_cnt
    if res:
        return res, 0
    return choose_one_word(sents[0]), 1


def extract_one_file_peak(in_file_name, out_file_name):
    """Attach an ``outline`` keyword list to every scene of a JSON file.

    The last entry's description is split into sentences. Every sentence except
    the one at ``scene['peak_idx']`` is mined for tokens that start with a
    letter and are either:
      * emotion words (VADER ``pos``/``neg`` > 0.5), or
      * verbs (tag contains ``'VB'``) whose token and verb lemma are not in
        ``stop`` or ``discard_words``.
    Tokens are kept as written. If nothing is found, the first token of the
    first sentence is used, and the number of such scenes is printed. The
    augmented scenes are written to ``out_file_name``.
    """
    def keywords_of(sent):
        found = []
        for word, tag in nltk.pos_tag(nltk.word_tokenize(sent)):
            if not word[0].isalpha():
                continue
            scores = sid.polarity_scores(word)
            hits = [key for key in scores if scores[key] > 0.5 and key in ('pos', 'neg')]
            if hits:
                found.extend([word] * len(hits))
                continue
            if 'VB' not in tag:
                continue
            lemma = stemmer.lemmatize(word, 'v')
            if lemma in stop or word in stop or lemma in discard_words or word in discard_words:
                continue
            found.append(word)
        return found

    with open(in_file_name, encoding='utf-8') as f:
        print('extract {}'.format(in_file_name))
        scenes = json.load(f)

        res = []
        fallback_cnt = 0
        for scene in tqdm(scenes):
            sents = nltk.sent_tokenize(scene['entries'][-1]['description'])
            outline = []
            for idx, sent in enumerate(sents):
                if idx != scene['peak_idx']:
                    outline.extend(keywords_of(sent))
            if not outline:
                outline.append(nltk.word_tokenize(sents[0])[0])
                fallback_cnt += 1
            scene['outline'] = outline
            assert len(outline) > 0
            res.append(scene)
    print('cnt:', fallback_cnt)
    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(res, f, indent=1, ensure_ascii=False)


def extract_context_target_bedding_ending_kw_work(scene):
    """Annotate one scene with keyword lists for its context and story parts.

    The context is the last entry's card descriptions followed by the earlier
    entries' descriptions, each followed by a space. The last entry's
    description is split into sentences around ``scene['peak_idx']``:
      * bedding: sentences before the peak;
      * target: the peak sentence;
      * ending: sentences after the peak.
    Bedding and ending use ``limit=True, low_frequency=True`` and are ``[]``
    when the peak is the first or last sentence respectively.

    Sets ``context_kws``, ``bedding_kws``, ``target_kws`` and ``ending_kws`` on
    ``scene`` (mutated in place) and returns it.
    """
    last_entry = scene['entries'][-1]
    context = ''.join(
        [card['description'] + ' ' for card in last_entry['cards']]
        + [entry['description'] + ' ' for entry in scene['entries'][:-1]]
    )

    peak_idx = scene['peak_idx']
    sents = nltk.sent_tokenize(last_entry['description'])
    bedding, target, ending = sents[:peak_idx], sents[peak_idx], sents[peak_idx + 1:]

    def side_kws(part):
        # Bedding/ending: capped keyword count, high-frequency lemmas dropped.
        kws, _ = extract_emotion_event_at_least_one(part, limit=True, high_frequency=False, low_frequency=True)
        return kws

    # Call order is kept fixed because extraction draws from the shared RNG.
    context_kws, _ = extract_emotion_event_at_least_one(context)
    bedding_kws = side_kws(bedding) if peak_idx != 0 else []
    ending_kws = side_kws(ending) if peak_idx != len(sents) - 1 else []
    target_kws, _ = extract_emotion_event_at_least_one(target)

    for kws in (context_kws, bedding_kws, target_kws, ending_kws):
        assert type(kws) == list

    scene['context_kws'] = context_kws
    scene['bedding_kws'] = bedding_kws
    scene['target_kws'] = target_kws
    scene['ending_kws'] = ending_kws
    return scene


def extract_context_target_bedding_ending_kw(in_file_name, out_file_name):
    """Run ``extract_context_target_bedding_ending_kw_work`` over a JSON file.

    Scenes are processed in a ``multiprocessing.Pool`` with one worker per CPU,
    keeping input order. The average number of keywords per scene for each part
    is printed, and the annotated scenes are written to ``out_file_name``.
    """
    with open(in_file_name, encoding='utf-8') as f:
        print('extract {}'.format(in_file_name))
        scenes = json.load(f)
    with Pool(os.cpu_count()) as pool:
        scenes = list(tqdm(pool.imap(extract_context_target_bedding_ending_kw_work, scenes), total=len(scenes)))

    n = len(scenes)
    totals = {key: sum(len(scene[key]) for scene in scenes)
              for key in ('context_kws', 'bedding_kws', 'target_kws', 'ending_kws')}
    print(f"avg context kws:{totals['context_kws'] / n}\navg bedding kws:{totals['bedding_kws'] / n}\n"
          f"avg target kws:{totals['target_kws'] / n}\navg ending kws:{totals['ending_kws'] / n}")

    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(scenes, f, indent=1, ensure_ascii=False)

    print('finish out to {}'.format(out_file_name))


def rake(r, text=None):
    """Rank the key phrases of ``text`` with a RAKE extractor.

    The original body used an undefined name ``text`` and therefore always
    raised ``NameError``. The text is now an explicit (optional, for signature
    compatibility) parameter.

    Args:
        r: an extractor exposing ``extract_keywords_from_text`` and
            ``get_ranked_phrases_with_scores`` (e.g. ``rake_nltk.Rake``).
        text: the text to analyse.

    Returns:
        Whatever ``r.get_ranked_phrases_with_scores()`` returns.

    Raises:
        TypeError: if ``text`` is not given.
    """
    if text is None:
        raise TypeError('rake() needs the text to extract keywords from')
    r.extract_keywords_from_text(text)
    return r.get_ranked_phrases_with_scores()


def min_high_low(high_file_name, low_file_name, high_out_name, low_out_name):
    """Truncate paired high/low keyword files to a common length per scene.

    Scenes of the two inputs are paired by position with ``zip``, so extra
    scenes in the longer file are left untouched. For each pair, the
    ``bedding_kws`` and ``ending_kws`` lists of both scenes are cut to the
    shorter of the two lengths. The average kept lengths (divided by the number
    of high scenes) are printed, and both lists are written out.
    """
    def load(path):
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    def dump(obj, path):
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, ensure_ascii=False, indent=1)

    high_res = load(high_file_name)
    low_res = load(low_file_name)

    kept = {'bedding_kws': 0, 'ending_kws': 0}
    for high_scene, low_scene in tqdm(zip(high_res, low_res)):
        for key in kept:
            common = min(len(high_scene[key]), len(low_scene[key]))
            high_scene[key] = high_scene[key][:common]
            low_scene[key] = low_scene[key][:common]
            kept[key] += common

    print(f"bedding_avg_len:{kept['bedding_kws'] / len(high_res)}, ending_avg_len:{kept['ending_kws'] / len(high_res)}")
    dump(high_res, high_out_name)
    dump(low_res, low_out_name)


def extract_kws_per_sent_work(scene):
    """Annotate one scene with persona, context and per-sentence story keywords.

    Inputs are built as follows:
      * ``card_text``: the last entry's card descriptions, each followed by a
        space;
      * context: ``card_text`` followed by the earlier entries' descriptions;
      * bedding / target / ending: the last entry's description split into
        sentences before, at and after ``scene['peak_idx']``.
    Bedding, target and ending get one keyword list per sentence
    (``per_sent=True``); bedding and ending are additionally ``limit``-ed.

    The keys ``context_kws``, ``bedding_kws``, ``target_kws``, ``ending_kws``
    and ``persona_kws`` are only added when absent, so existing values are
    kept. ``intersect_nodes`` is removed if present. The scene is mutated in
    place and returned.
    """
    last_entry = scene['entries'][-1]
    card_text = ''.join(card['description'] + ' ' for card in last_entry['cards'])
    context = card_text + ''.join(entry['description'] + ' ' for entry in scene['entries'][:-1])

    peak_idx = scene['peak_idx']
    sents = nltk.sent_tokenize(last_entry['description'])
    bedding = sents[:peak_idx]
    target = sents[peak_idx]
    ending = sents[peak_idx + 1:]

    try:
        persona_kws, _ = extract_emotion_event_at_least_one(card_text)
    except Exception:
        # Best effort: persona keywords are optional, so a failed extraction
        # (e.g. empty card text) yields []. The former bare `except:` also
        # swallowed KeyboardInterrupt/SystemExit inside Pool workers.
        persona_kws = []
    context_kws, _ = extract_emotion_event_at_least_one(context)
    if peak_idx != 0:
        bedding_kws, _ = extract_emotion_event_at_least_one(bedding, limit=True, high_frequency=False,
                                                            low_frequency=False, per_sent=True, filter=False)
    else:
        bedding_kws = []

    if peak_idx != len(sents) - 1:
        ending_kws, _ = extract_emotion_event_at_least_one(ending, limit=True, high_frequency=False,
                                                           low_frequency=False, per_sent=True, filter=False)
    else:
        ending_kws = []

    target_kws, _ = extract_emotion_event_at_least_one(target, per_sent=True)

    for kws in (persona_kws, context_kws, bedding_kws, target_kws, ending_kws):
        assert type(kws) == list

    # setdefault keeps keywords that an earlier run already stored on the scene.
    scene.setdefault('context_kws', context_kws)
    scene.setdefault('bedding_kws', bedding_kws)
    scene.setdefault('target_kws', target_kws)
    scene.setdefault('ending_kws', ending_kws)
    scene.setdefault('persona_kws', persona_kws)

    scene.pop('intersect_nodes', None)
    return scene


def extract_kws_per_sent(in_file_name, out_file_name):
    """Run ``extract_kws_per_sent_work`` over every scene of a JSON file.

    Scenes are processed in a ``multiprocessing.Pool`` with one worker per CPU,
    keeping input order. The average keyword count per scene is printed for
    each part (nested per-sentence lists are flattened for counting). The
    annotated scenes are written to ``out_file_name``.
    """
    with open(in_file_name, encoding='utf-8') as f:
        print('extract {}'.format(in_file_name))
        scenes = json.load(f)
    with Pool(os.cpu_count()) as pool:
        scenes = list(tqdm(pool.imap(extract_kws_per_sent_work, scenes), total=len(scenes)))

    def count_kws(kws):
        # Flat list -> its length; list of per-sentence lists -> total word count.
        if not kws:
            return 0
        if type(kws[0]) is not list:
            return len(kws)
        return sum(len(sent_kws) for sent_kws in kws)

    totals = dict.fromkeys(('context_kws', 'bedding_kws', 'target_kws', 'ending_kws', 'persona_kws'), 0)
    for scene in scenes:
        for key in totals:
            totals[key] += count_kws(scene[key])

    n = len(scenes)
    print(
        f"avg persona kws:{totals['persona_kws'] / n}\navg context kws:{totals['context_kws'] / n}\navg bedding kws:{totals['bedding_kws'] / n}\n"
        f"avg target kws:{totals['target_kws'] / n}\navg ending kws:{totals['ending_kws'] / n}")

    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(scenes, f, indent=1, ensure_ascii=False)

    print('finish out to {}'.format(out_file_name))


def del_intersectnodes(file_name):
    """Remove the ``intersect_nodes`` key from every scene of a JSON file, in place.

    Scenes without the key are left unchanged. Previously such a scene raised
    ``KeyError`` and nothing was written. This matches the guarded removal done
    in ``extract_kws_per_sent_work``.
    """
    with open(file_name, encoding='utf-8') as f:
        scenes = json.load(f)

    for scene in tqdm(scenes):
        scene.pop('intersect_nodes', None)

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(scenes, f, ensure_ascii=False, indent=1)


if __name__ == '__main__':
    # Script entry point. Only the first `for split ...` loop below runs: it is
    # followed by `exit()`, so every statement after that call is unreachable.
    # That code is kept as a record of earlier pipeline stages; re-enable a
    # stage by moving the `exit()`. The builtin `exit()` ends with status 0.
    # print(extract_emotion_force(['d']))
    # print(extract_event_force(['d']))
    # exit()

    # min_high_low('../data/test_add_node_onecard_high.json', '../data/test_add_node_onecard_low.json',
    #              '../data/test_add_node_onecard_high_min.json', '../data/test_add_node_onecard_low_min.json')
    # exit()
    import sys

    # Make the parent directory importable so `util.graph` resolves
    # (presumably the script is run from a sibling directory of `util/` — TODO confirm).
    sys.path.append('..')
    from util.graph import get_conceptnet, KnowledgeGraph

    # `graph` becomes a module-level global here. The `filter=True` path of
    # extract_emotion_event_at_least_one reads it (graph.get_hops_set).
    # NOTE(review): Pool workers only see it if they inherit this process's
    # memory (fork start method); under spawn it would be undefined — confirm.
    graph = get_conceptnet()
    print('finish load graph!')

    # Current stage: persona/context/bedding/target/ending keywords per sentence.
    for split in ['test', 'valid', 'train']:
        # extract_ending_kw(f'../data/{split}_add_node_onecard.json', f'../data/{split}_add_node_ending_onecard.json')
        # extract_context_target_bedding_ending_kw(f'../data/{split}_peak_onecard.json',
        #                                          f'../data/{split}_add_node_ending_onecard.json')
        # extract_context_target_bedding_ending_kw(f'../data/{split}_add_node_onecard.json',
        #                                          f'../data/{split}_add_node_onecard_low.json')
        extract_kws_per_sent(f'../data/{split}_dynamic.json',
                             f'../data/{split}_dynamic_persona.json')
        # del_intersectnodes(f'../data/{split}_dynamic.json')
    exit()
    # ---- Unreachable below this point (see the note at the top of this block). ----
    # Earlier stage: high/low-frequency keyword extraction. `high_frequency_words`
    # is the module-level global that extract_emotion_event_at_least_one reads
    # when high_frequency / low_frequency is set. It is loaded one word per line.
    high_frequency_words = []
    with open('../ConceptNet/high_frequency_words.txt', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            high_frequency_words.append(line.strip())

    high_frequency_words = set(high_frequency_words)

    for split in ['test', 'valid', 'train']:
        # extract_ending_kw(f'../data/{split}_add_node_onecard.json', f'../data/{split}_add_node_ending_onecard.json')
        # extract_context_target_bedding_ending_kw(f'../data/{split}_peak_onecard.json',
        #                                          f'../data/{split}_add_node_ending_onecard.json')
        extract_context_target_bedding_ending_kw(f'../data/{split}_add_node_onecard.json',
                                                 f'../data/{split}_add_node_onecard_low.json')
    exit()

    # Earliest stage: BERTScore-based sentence selection, then emotion/event words.
    # extract_one_file_peak('test_peak.json', 'test_peak_emotion_event.json')
    # extract_one_file_peak('valid_peak.json', 'valid_peak_emotion_event.json')
    # extract_one_file_peak('train_peak.json', 'train_peak_emotion_event.json')
    model_type = 'sentence-transformers/roberta-large-nli-stsb-mean-tokens'
    layers = 24
    # exit()

    # IDF weights for the scorer come from the training descriptions (cached on disk).
    idf_sents = get_idf_sents()
    print('finish get_idf_sents')
    scorer = BERTScorer(lang='en', model_type=model_type,
                        rescale_with_baseline=False, idf=True, num_layers=layers, nthreads=os.cpu_count(),
                        idf_sents=idf_sents, batch_size=64,  # masked_words=tokenized_stop_words
                        )
    print('finish load scorer')
    # exit()
    extract_one_file('test.json', 'test_emotion_event_top.json', scorer, top=True)
    # extract_one_file('test.json', 'test_emotion_event.json', scorer)
    # extract_one_file('valid.json', 'valid_emotion_event.json', scorer)
    # extract_one_file('train.json', 'train_emotion_event.json', scorer)
    # extract_one_file('test.json', 'test_force_emotion_event.json', scorer, force=True)
    # extract_one_file('valid.json', 'valid_force_emotion_event.json', scorer, force=True)
    # extract_one_file('train.json', 'train_force_emotion_event.json', scorer, force=True)
