import json
from nltk.corpus import stopwords
from nltk import sent_tokenize
from tqdm import tqdm

from bert_score import BERTScorer
from transformers import AutoTokenizer
from itertools import chain
from summa import summarizer
from math import ceil

# Base English stop-word list from NLTK.
origin_stop_words = stopwords.words('english')
# stop_words += ['she', "'s", "'re"]
# Every stop word appears four times, in this order: bare, with a leading
# space, with a trailing space, and with a space on both sides.
stop_words = [
    prefix + word + suffix
    for prefix, suffix in (('', ''), (' ', ''), ('', ' '), (' ', ' '))
    for word in origin_stop_words
]
tokenizer = AutoTokenizer.from_pretrained('roberta-large')
# Union of the token ids obtained by encoding each stop-word variant on its
# own, without special tokens.
tokenized_stop_words = {
    token_id
    for word in stop_words
    for token_id in tokenizer.encode(word, add_special_tokens=False)
}
print(tokenized_stop_words)


def get_idf_sents(file_name):
    """Collect the texts passed to the scorer as ``idf_sents``.

    ``file_name`` is a JSON file holding a list of scenes. For each scene the
    first entry is skipped; the descriptions of all remaining entries are
    concatenated (with no separator) into one story text, and the descriptions
    of all cards on non-narrator entries are concatenated into one card text.
    A scene contributes only when both texts are longer than 3 characters.

    Returns a list with all kept story texts followed by all kept card texts.
    """
    with open(file_name, encoding='utf-8') as f:
        scenes = json.load(f)

    story_texts = []
    card_texts = []
    for scene in scenes:
        story_parts = []
        card_parts = []
        for entry in scene['entries'][1:]:
            story_parts.append(entry['description'])
            # Narrator entries add to the story text only; their cards (if
            # any) are never looked at.
            if entry['role'] != 'narrator':
                card_parts.extend(card['description'] for card in entry['cards'])

        story = ''.join(story_parts)
        cards = ''.join(card_parts)
        if len(story) > 3 and len(cards) > 3:
            story_texts.append(story)
            card_texts.append(cards)

    return story_texts + card_texts


def get_textrank_summary(text, ratio=0.4):
    """Summarize ``text`` with summa's ``summarizer.summarize``.

    ``ratio`` is forwarded unchanged as the second positional argument, and
    the summarizer's return value is handed back as-is.
    """
    summary = summarizer.summarize(text, ratio)
    return summary


def get_bertscore(ref, hyps, scorer):
    """Score every candidate in ``hyps`` against the single reference ``ref``.

    The reference is repeated once per candidate so both lists given to
    ``scorer.score`` have equal length. ``scorer.score`` must return three
    values; only the third is returned (for the BERTScorer built in
    ``extract_one_file`` these are precision, recall and F-score).
    """
    repeated_ref = [ref for _ in range(len(hyps))]
    _precision, _recall, f_score = scorer.score(hyps, repeated_ref)
    return f_score


def _select_card_relevant_sentences(sents, cards, scorer, ratio):
    """Join, in original order, the sentences scoring highest against the cards."""
    total_scores = None  # tensor with one accumulated score per sentence
    for card in cards:
        card_scores = get_bertscore(card['description'], sents, scorer)
        total_scores = card_scores if total_scores is None else total_scores + card_scores

    # Never ask topk for more items than there are sentences (ratio > 1).
    k = min(len(sents), ceil(len(sents) * ratio))
    _, top_indices = total_scores.topk(k, largest=True)
    # Re-sort by position so the outline reads in the original sentence order.
    return ' '.join(sents[i] for i in sorted(top_indices.tolist()))


def extract_one_file(in_file_name, out_file_name, character_bertscore=True, ratio=0.4):
    """Add an ``'outline'`` to every non-initial entry and write the result.

    Reads a JSON list of scenes from ``in_file_name``. In each scene, every
    entry after the first receives an ``'outline'`` key; the whole structure
    is then dumped to ``out_file_name`` with ``indent=1``.

    How the outline is produced:

    * Narrator entries, entries without cards, and all entries when
      ``character_bertscore`` is false: TextRank summary of the description.
    * Otherwise: the description is split into sentences, each sentence is
      scored against every card description (scores summed over cards), and
      the ``ceil(ratio * n_sentences)`` best sentences are kept in their
      original order, joined with single spaces. A description that yields
      no sentences gets an empty outline.

    Args:
        in_file_name: path of the input JSON file.
        out_file_name: path of the JSON file to write.
        character_bertscore: whether entries with cards use BERTScore-based
            sentence selection instead of TextRank.
        ratio: fraction of the text to keep; passed to TextRank and used for
            the number of selected sentences.
    """
    # Constructing the scorer means loading the model and an extra pass over
    # the input file for the IDF sentences, so skip it when it is never used.
    scorer = None
    if character_bertscore:
        # NOTE(review): `masked_words` is not a keyword I can find in upstream
        # bert_score; presumably a patched BERTScorer is installed - confirm.
        scorer = BERTScorer(lang='en', model_type='roberta-large', rescale_with_baseline=True, idf=True,
                            idf_sents=get_idf_sents(in_file_name), masked_words=tokenized_stop_words)

    with open(in_file_name, encoding='utf-8') as f:
        scenes = json.load(f)

    for scene in tqdm(scenes):
        for entry in scene['entries'][1:]:
            # The role check comes first so narrator entries are never asked
            # for a 'cards' key.
            use_textrank = (entry['role'] == 'narrator'
                            or not entry['cards']
                            or not character_bertscore)
            if use_textrank:
                entry['outline'] = get_textrank_summary(entry['description'], ratio)
                continue

            sents = sent_tokenize(entry['description'])
            if not sents:
                # Nothing to score or select; avoid calling the scorer with
                # empty candidate lists.
                entry['outline'] = ''
                continue
            entry['outline'] = _select_card_relevant_sentences(sents, entry['cards'], scorer, ratio)

    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(scenes, f, indent=1)


if __name__ == '__main__':
    # Alternative run using the default character_bertscore=True (disabled):
    # extract_one_file('test.json', 'test_with_outline.json')
    # extract_one_file('valid.json', 'valid_with_outline.json')
    # extract_one_file('train.json', 'train_with_outline.json')

    # Active run: character_bertscore=False for each split, in this order.
    textrank_only_jobs = (
        ('test.json', 'test_with_outline_only_textrank_summary.json'),
        ('valid.json', 'valid_with_outline_only_textrank_summary.json'),
        ('train.json', 'train_with_outline_only_textrank_summary.json'),
    )
    for source_path, target_path in textrank_only_jobs:
        extract_one_file(source_path, target_path, False)
