import json
from tqdm import tqdm
import random
from bert_score import score, BERTScorer
from nltk import data, sent_tokenize
import os
import traceback

import sys

sys.path.append('../../../')
from evaluation.metric import get_idf_sents, top_filter


def concat(card, entry_text):
    """Join a card description and an entry text with the ``</s>`` separator."""
    return '</s>'.join((card, entry_text))


def gen(in_file_name, out_file_name, min_score=0.4, max_score=0.7):
    """Build a binary (card, entry) matching dataset from scored scenes.

    The input JSON is a list of scenes. Each scene has a numeric ``'score'``
    and a list ``'entries'``; every entry has a ``'role'`` and a
    ``'description'``, and the last entry of a scene additionally has
    ``'cards'`` whose first element carries a ``'description'``.

    For every scene whose score lies within ``[min_score, max_score]``:

    * a positive sample (label 1) pairs the last entry's first card with the
      last entry's own description;
    * a negative sample (label 0) pairs the same card with a description
      drawn at random from the narrator pool or the character pool (each
      preferred with probability 0.5), never equal to the true description.

    The result is written to ``out_file_name`` as ``{'data': [...]}``.

    Args:
        in_file_name: Path of the scored scenes JSON file.
        out_file_name: Path of the dataset JSON file to write.
        min_score: Scenes scoring below this value are skipped.
        max_score: Scenes scoring above this value are skipped.
    """
    def invalid(scene):
        # A scene is kept only if its score falls inside the accepted band.
        return scene['score'] < min_score or scene['score'] > max_score

    def pick_other(pool, exclude):
        """Return a random element of the de-duplicated *pool* that differs
        from *exclude*, or None when no such element exists."""
        # The pool holds unique values, so at most one element can equal
        # `exclude`; with two or more elements the loop below must terminate.
        if not pool or (len(pool) == 1 and pool[0] == exclude):
            return None
        candidate = random.choice(pool)
        while candidate == exclude:
            candidate = random.choice(pool)
        return candidate

    print(f'in_file_name:{in_file_name}, out_file_name:{out_file_name}')
    with open(in_file_name, encoding='utf-8') as f:
        scenes = json.load(f)

    # Filter once so both passes below see exactly the same scenes.
    valid_scenes = [scene for scene in tqdm(scenes) if not invalid(scene)]

    character_entries = []
    narrator_entries = []
    for scene in valid_scenes:
        for entry in scene['entries']:
            if entry['role'] == 'narrator':
                narrator_entries.append(entry['description'])
            else:
                character_entries.append(entry['description'])

    # De-duplicate while keeping first-seen order: unlike list(set(...)), the
    # result does not depend on string hash randomization, so a seeded RNG
    # yields a reproducible dataset.
    character_entries = list(dict.fromkeys(character_entries))
    narrator_entries = list(dict.fromkeys(narrator_entries))
    print(f"character_entries num:{len(character_entries)}\nnarrator_entries num:{len(narrator_entries)}")

    res = []
    for scene in tqdm(valid_scenes):
        last_entry = scene['entries'][-1]
        true_card = last_entry['cards'][0]['description']
        true_entry = last_entry['description']
        res.append({'text': concat(true_card, true_entry), 'label': 1})

        # Prefer narrator or character pool with equal probability, but fall
        # back to the other pool when the preferred one is empty or offers
        # nothing except the true entry.
        pools = [narrator_entries, character_entries]
        if random.random() >= 0.5:
            pools.reverse()
        fake_entry = None
        for pool in pools:
            fake_entry = pick_other(pool, true_entry)
            if fake_entry is not None:
                break
        # Without any alternative description a negative sample would be
        # identical to the positive one, so it is omitted.
        if fake_entry is not None:
            res.append({'text': concat(true_card, fake_entry), 'label': 0})

    print(f"len:{len(res)}")
    with open(out_file_name, 'w', encoding='utf-8') as out_f:
        json.dump({'data': res}, out_f, ensure_ascii=False, indent=1)


def split(path):
    """Split a labelled dataset file into train/valid/test files.

    Reads ``{'data': [...]}`` from ``path``, separates samples whose
    ``'label'`` is 1 from all others, shuffles each group and cuts both at
    80% and 90%, so every output part keeps the label ratio of the input.
    The parts are shuffled and written next to the input as
    ``<stem>_train.json``, ``<stem>_valid.json`` and ``<stem>_test.json``,
    where ``<stem>`` is ``path`` without its extension.

    Args:
        path: Path of the dataset JSON file to split.
    """
    # splitext only strips a real extension of the final path component, so
    # extension-less paths and dots in directory names are handled correctly.
    file_name = os.path.splitext(path)[0]
    with open(path, encoding='utf-8') as f:
        samples = json.load(f)['data']

    pos_samples = [s for s in samples if s['label'] == 1]
    neg_samples = [s for s in samples if s['label'] != 1]
    print(f"pos_samples:{len(pos_samples)},neg_samples:{len(neg_samples)}")

    random.shuffle(pos_samples)
    random.shuffle(neg_samples)

    def boundaries(total):
        # Slice boundaries for the 80% / 10% / 10% partition of `total` items.
        return [0, int(total * 0.8), int(total * 0.9), total]

    pos_idxs = boundaries(len(pos_samples))
    neg_idxs = boundaries(len(neg_samples))

    for i, q in enumerate(['train', 'valid', 'test']):
        r = pos_samples[pos_idxs[i]:pos_idxs[i + 1]]
        r += neg_samples[neg_idxs[i]:neg_idxs[i + 1]]
        random.shuffle(r)
        pos_cnt = sum(1 for sample in r if sample['label'] == 1)
        neg_cnt = len(r) - pos_cnt
        print(f"pos_cnt:{pos_cnt}, neg_cnt:{neg_cnt}")
        out = f'{file_name}_{q}.json'
        with open(out, 'w', encoding='utf-8') as out_f:
            json.dump({'data': r}, out_f, ensure_ascii=False, indent=1)


def calculate_bert_score(in_path, out_path):
    """Annotate every scene in ``in_path`` with a ``'score'`` and save it.

    For each scene, the description of the last entry is split into
    sentences; each sentence is scored as a hypothesis against the
    description of the last entry's first card (the persona) using the
    module-level ``scorer``. The recall values are passed to ``top_filter``
    and the result is stored in ``scene['score']``. Scenes without any
    non-empty sentence, or whose scoring raises, get a score of 0.

    Args:
        in_path: Path of the scenes JSON file (a list of scenes).
        out_path: Path of the JSON file to write the annotated scenes to.

    Raises:
        NameError: If the module-level ``scorer`` has not been created.
    """
    # Resolve the global once, before the loop. Inside the per-scene handler
    # a missing scorer would be reported as a scoring failure for every
    # scene and silently produce an all-zero output file.
    bert_scorer = scorer

    with open(in_path, 'r', encoding='utf-8') as f:
        scenes = json.load(f)

    for scene in tqdm(scenes):
        last_entry = scene['entries'][-1]
        persona = last_entry['cards'][0]['description']
        hyps = list(sent_tokenize(last_entry['description']))
        # Every sentence is compared against the same persona text.
        refs = [persona] * len(hyps)
        scene_score = 0
        if any(hyps):
            try:
                _, recall, _ = bert_scorer.score(hyps, refs)
                # presumably reduces the per-sentence recalls to their
                # maximum -- confirm in evaluation.metric.top_filter
                scene_score = top_filter(recall)
            except Exception:
                # Best effort: report the failure and keep the default 0.
                # KeyboardInterrupt/SystemExit are deliberately not caught so
                # a long run can still be aborted.
                traceback.print_exc()
        scene['score'] = scene_score

    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(scenes, f, indent=4, separators=(',', ':'))


def merge(paths, out_path="merge_dynamic_persona_score.json"):
    """Concatenate several JSON list files into a single JSON file.

    Args:
        paths: Iterable of paths; each file's parsed JSON content is appended
            to one combined list in the given order.
        out_path: Path of the merged file to write. Defaults to the name the
            rest of this script expects.
    """
    merged = []
    for path in paths:
        # Context manager so each input handle is closed right after reading.
        with open(path, 'r', encoding='utf-8') as f:
            merged += json.load(f)
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(merged, f)


if __name__ == '__main__':
    # Stage 0 (disabled): merge the per-split scored files into one file.
    # merge([f"../../../data/{part}_dynamic_persona_score.json" for part in ['train', 'test', 'valid']])

    # Active stage: build the labelled dataset from the merged scores, then
    # cut it into train/valid/test files.
    merged_scores_path = 'merge_dynamic_persona_score.json'
    dataset_path = 'merge_0421.json'
    gen(merged_scores_path, dataset_path)
    split(dataset_path)

    # Disabled: compute the scores consumed above. calculate_bert_score()
    # reads the module-level `scorer` created here. The loop variable is
    # named `part` so that re-enabling this code does not shadow split().
    # scorer = BERTScorer(
    #         lang='en',
    #         model_type='sentence-transformers/roberta-large-nli-stsb-mean-tokens',
    #         rescale_with_baseline=False,
    #         idf=True,
    #         num_layers=24,
    #         nthreads=os.cpu_count(),
    #         idf_sents=get_idf_sents('../../../data/'),
    #         batch_size=64)
    # print('get scorer!')
    # for part in ['test', 'valid', 'train']:
    #     calculate_bert_score(f'../../../data/{part}_dynamic_persona.json', f'../../../data/{part}_dynamic_persona_score.json')

    # Disabled: generate one dataset per original split instead of merging.
    # for part in ['test', 'valid', 'train']:
    #     gen(f'../../../data/{part}_dynamic_persona_score.json', f'{part}_persona.json')
