import json
from tqdm import tqdm
from nltk import sent_tokenize, word_tokenize
from bert_score import BERTScorer
import torch
import os


def get_idf_sents():
    """Return the list of training texts handed to the scorer as IDF corpus.

    Two nested builders exist; only the document-level one is used.  Each
    builder first looks for a JSON cache on disk and, when it is missing,
    derives the corpus from the training file and writes the cache.

    Returns:
        A list of strings: every entry description of every scene, followed
        by the description of every card attached to the last entry of each
        scene.
    """
    def get_idf_docs():
        """Build (or load) the corpus with one string per description."""
        cache_path = '../data/train_doc_idf.json'
        if os.path.exists(cache_path):
            with open(cache_path, encoding='utf-8') as cached:
                return json.load(cached)

        with open('../data/train.json', encoding='utf-8') as source:
            scenes = json.load(source)

        entry_texts = []
        card_texts = []
        for scene in scenes:
            entry_texts.extend(
                entry['description'] for entry in scene['entries'])
            # Only the cards of the final entry of a scene are collected.
            card_texts.extend(
                card['description'] for card in scene['entries'][-1]['cards'])

        corpus = entry_texts + card_texts
        with open(cache_path, 'w', encoding='utf-8') as cache:
            json.dump(corpus, cache, ensure_ascii=False)
        print('finish get_idf_sent')
        return corpus

    def get_idf_split_sent():
        """Build (or load) the corpus with one string per sentence.

        Currently unused.  Note that it reads and writes paths relative to
        the working directory itself, not ``../data``.
        """
        cache_path = 'train_sent_idf.json'
        if os.path.exists(cache_path):
            with open(cache_path, encoding='utf-8') as cached:
                return json.load(cached)

        with open('train.json', encoding='utf-8') as source:
            scenes = json.load(source)

        entry_sents = []
        card_sents = []
        for scene in scenes:
            for entry in scene['entries']:
                entry_sents.extend(sent_tokenize(entry['description']))
            for card in scene['entries'][-1]['cards']:
                card_sents.extend(sent_tokenize(card['description']))

        corpus = entry_sents + card_sents
        with open(cache_path, 'w', encoding='utf-8') as cache:
            json.dump(corpus, cache, ensure_ascii=False)
        print('finish get_idf_sent')
        return corpus

    # Switch to get_idf_split_sent() for a sentence-level corpus instead.
    return get_idf_docs()


def process(scorer):
    """Keep only the best-matching card of every multi-card scene.

    For each of the ``test``, ``valid`` and ``train`` splits this reads
    ``../data/{split}_peak.json`` and writes
    ``../data/{split}_peak_onecard.json``.  Within a scene only the last
    entry is inspected: every card description is scored against each
    sentence of that entry's description, the card with the highest
    sentence-level recall wins, ``entry['cards']`` is replaced by that single
    card and ``scene['peak_idx']`` is set to the index of the matching
    sentence.

    Scenes whose last entry has exactly one card are left untouched.  A
    scene with an empty card list is not skipped and makes ``max`` raise
    ``ValueError``.

    Args:
        scorer: Object exposing ``score(hyps, refs)`` that returns three
            tensors, of which the second (recall) is used.
    """
    def max_score(card_text, text):
        """Return ``(recall, sentence_index)`` of the best sentence in *text*."""
        sentences = sent_tokenize(text)
        # The card text is the reference for every candidate sentence.
        references = [card_text] * len(sentences)
        _precision, recall, _f1 = scorer.score(sentences, references)
        best = torch.max(recall, dim=0)
        return best.values.item(), best.indices.item()

    def process_one_file(in_file_name, out_file_name):
        """Rewrite one split file so that each scene ends with one card."""
        with open(in_file_name, encoding='utf-8') as source:
            scenes = json.load(source)

        multi_card_scenes = 0
        for scene in tqdm(scenes):
            last_entry = scene['entries'][-1]
            description = last_entry['description']
            candidates = last_entry['cards']
            if len(candidates) == 1:
                continue
            multi_card_scenes += 1

            # One (recall, sentence index, card index) triple per card.
            ranked = [
                max_score(card['description'], description) + (card_pos,)
                for card_pos, card in enumerate(candidates)
            ]
            _best, sent_idx, card_idx = max(ranked, key=lambda item: item[0])

            scene['peak_idx'] = sent_idx
            last_entry['cards'] = [candidates[card_idx]]

        print(f"{in_file_name}, more than one card num:{multi_card_scenes}")

        with open(out_file_name, 'w', encoding='utf-8') as sink:
            json.dump(scenes, sink, ensure_ascii=False, indent=1)

    for split in ('test', 'valid', 'train'):
        process_one_file(f'../data/{split}_peak.json',
                         f'../data/{split}_peak_onecard.json')


def check(scorer):
    """Verify the one-card files against a freshly computed peak sentence.

    For each of the ``test``, ``valid`` and ``train`` splits this reads
    ``../data/{split}_peak_onecard.json`` and, for every scene, checks that
    the last entry carries exactly one card and that the sentence of the
    entry description with the highest recall against that card has the
    index stored in ``scene['peak_idx']``.

    The checks are explicit ``raise`` statements rather than ``assert`` so
    that they still run when Python is started with ``-O``.

    Args:
        scorer: Object exposing ``score(hyps, refs)`` that returns three
            tensors, of which the second (recall) is used.

    Raises:
        AssertionError: If a scene does not have exactly one card or its
            stored ``peak_idx`` differs from the recomputed index.
    """
    def max_score(card_text, text):
        """Return ``(recall, sentence_index)`` of the best sentence in *text*."""
        hyps = sent_tokenize(text)
        # Every sentence is compared against the same card text.
        refs = [card_text] * len(hyps)
        _p, r, _f = scorer.score(hyps, refs)
        score, idx = torch.max(r, dim=0)
        return score.item(), idx.item()

    def check_one_file(in_file_name):
        """Validate every scene of one split file."""
        with open(in_file_name, encoding='utf-8') as f:
            scenes = json.load(f)

        checked = 0
        for scene_no, scene in enumerate(tqdm(scenes)):
            entry = scene['entries'][-1]
            text = entry['description']
            cards = entry['cards']
            if len(cards) != 1:
                raise AssertionError(
                    f"{in_file_name}: scene {scene_no} has {len(cards)} "
                    f"cards, expected exactly 1")
            checked += 1

            _score, idx = max_score(cards[0]['description'], text)
            if idx != scene['peak_idx']:
                raise AssertionError(
                    f"{in_file_name}: scene {scene_no} stores peak_idx "
                    f"{scene['peak_idx']} but the best sentence is {idx}")

        # Every scene in the file passed both checks at this point.
        print(f"{in_file_name}, checked scenes:{checked}")

    for s in ['test', 'valid', 'train']:
        check_one_file(f'../data/{s}_peak_onecard.json')


def copy_entry_peak():
    """Propagate the one-card result into the other per-split data files.

    For each of the ``test``, ``valid`` and ``train`` splits, the last entry
    and the ``peak_idx`` of every scene in
    ``../data/{split}_peak_onecard.json`` are copied, position by position,
    into the scenes of ``../data/{split}_peak_context_target_kw.json`` and
    ``../data/{split}_add_node.json``.  The results are written next to the
    originals with ``_onecard`` inserted before the extension; the original
    files are not modified.
    """
    def copy_one_file(in_file_name, out_file_name):
        """Merge *in_file_name* into *out_file_name* and save as ``*_onecard.json``.

        Despite its name, ``out_file_name`` is read, not written: it supplies
        the scenes that receive the copied fields.
        """
        with open(in_file_name, encoding='utf-8') as source:
            donors = json.load(source)

        # Strip the extension and append the marker for the output name.
        stem = out_file_name[:out_file_name.rindex('.')]
        merged_path = stem + '_onecard.json'

        with open(out_file_name, encoding='utf-8') as target:
            receivers = json.load(target)

        # Scenes are matched purely by their position in the two lists.
        for position, scene in enumerate(receivers):
            donor = donors[position]
            scene['entries'][-1] = donor['entries'][-1]
            scene['peak_idx'] = donor['peak_idx']

        with open(merged_path, 'w', encoding='utf-8') as sink:
            json.dump(receivers, sink, ensure_ascii=False, indent=1)

        print(f"finish copy from {in_file_name} to {merged_path}")

    for split in ('test', 'valid', 'train'):
        onecard_path = f"../data/{split}_peak_onecard.json"
        copy_one_file(onecard_path, f"../data/{split}_peak_context_target_kw.json")
        copy_one_file(onecard_path, f"../data/{split}_add_node.json")

if __name__ == '__main__':
    # Checkpoint name and layer count handed to BERTScorer below.
    model_type = 'sentence-transformers/roberta-large-nli-stsb-mean-tokens'
    layers = 24

    # IDF weighting is enabled and fitted on the texts from get_idf_sents().
    # (A masked_words=tokenized_stop_words option was tried and disabled.)
    scorer = BERTScorer(
        lang='en',
        model_type=model_type,
        rescale_with_baseline=False,
        idf=True,
        num_layers=layers,
        nthreads=os.cpu_count(),
        idf_sents=get_idf_sents(),
        batch_size=64,
    )

    # Pipeline stages; uncomment the ones to run.
    # process(scorer)
    # check(scorer)
    copy_entry_peak()
