import json
import re
from tqdm import tqdm
from transformers import GPT2Tokenizer
from copy import deepcopy

# Shared GPT-2 tokenizer that get_len() uses to count tokens. It is created
# once, at import time, from the checkpoint path given below.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('../gpt2_en_ckpt_origin')


def get_len(text, gpt2=True):
    """Measure the length of ``text``.

    Falsy input (``None`` or an empty string) always measures 0.  With
    ``gpt2`` set, the length is the number of ids produced by the
    module-level GPT-2 tokenizer; otherwise it is the number of pieces
    obtained by splitting on every single space, CR, LF or tab (so two
    adjacent separators contribute an empty piece that is still counted).
    """
    if not text:
        return 0
    if not gpt2:
        pieces = re.split(' |\r|\n|\t', text)
        return len(pieces)
    token_ids = gpt2_tokenizer.encode(text)
    return len(token_ids)


def get_card(card):
    """Return a new dict holding only the id, name and description of ``card``.

    Any other key on the input (e.g. 'pretty_namespace') is dropped.
    """
    kept_fields = ('card_id', 'name', 'description')
    return {field: card[field] for field in kept_fields}


def get_characters(story):
    """Index the characters of ``story`` by their ``character_seq_id``.

    Each value carries the character's name, description and the trimmed
    (see get_card) hand cards; cards whose description is falsy are left out.
    """
    by_seq_id = {}
    for member in story['characters']:
        profile = {
            'name': member['name'],
            'description': member['description'],
            'cards': [],
        }
        for hand_card in member['current_hand_cards']:
            if hand_card['description']:
                profile['cards'].append(get_card(hand_card))
        by_seq_id[member['character_seq_id']] = profile
    return by_seq_id


def get_entry(entry):
    """Reduce a scene entry to its role, description and attached cards.

    The card list is the cards played on the challenge followed by the cards
    offered for pickup, each trimmed with get_card; cards with a falsy
    description are skipped.
    """
    summary = {
        'role': entry['role'],
        'description': entry['description'],
        'cards': [],
    }
    for source_key in ('cards_played_on_challenge', 'cards_for_pickup'):
        summary['cards'] += [
            get_card(card) for card in entry[source_key] if card['description']
        ]
    return summary


def split_file(in_file_name, out_file_name):
    """Cut each scene into one sample per cast character and dump them as JSON.

    The input file holds a JSON list of stories.  For every scene and every
    cast character id known to get_characters(), the scene entries are walked
    in order, keeping narrator entries and the entries of that character.
    A running length (measured with get_len, i.e. GPT-2 tokens by default)
    starts with the descriptions of the character's hand cards (each card_id
    counted once) and grows with every kept entry description plus, for
    non-narrator entries after the first, the descriptions of cards not seen
    before.  The walk stops at the first entry that would push the running
    length above 1000.

    A sample ``{'entries', 'character', 'extra_cards'}`` is emitted only when
    it has more than one entry and at least one non-narrator entry with cards.

    Args:
        in_file_name: path of the stories JSON file to read.
        out_file_name: path of the JSON file the samples are written to.

    Side effects: prints the number of samples and writes ``out_file_name``.
    """
    res = []
    len_limit = 1000
    # Load everything up front so the input handle is closed before the
    # (slow) tokenisation loop starts.
    with open(in_file_name, encoding='utf-8') as f:
        stories = json.load(f)

    for story in tqdm(stories):
        characters = get_characters(story)
        for scene in story['scenes']:
            for chaid in scene['cast_character_seq_ids']:
                if chaid not in characters:
                    continue
                # The membership test above guarantees this lookup succeeds,
                # so no exception handling is needed around it.
                sample = {'entries': [], 'character': characters[chaid], 'extra_cards': []}
                character_entry_appear_with_cards = False
                seen_card_ids = set()
                total_len = 0
                # Budget the character's own hand cards first.
                for card in sample['character']['cards']:
                    if card['card_id'] not in seen_card_ids:
                        seen_card_ids.add(card['card_id'])
                        total_len += get_len(card['description'])

                for idx, entry in enumerate(scene['entries']):
                    # Every scene is expected to open with a narrator entry.
                    assert idx != 0 or entry['role'] == 'narrator'
                    if entry['character_seq_id'] != chaid and entry['role'] != 'narrator':
                        continue
                    entry_len = get_len(entry['description'])
                    if entry_len == 0:
                        continue
                    total_len += entry_len
                    if total_len > len_limit:
                        break

                    w = get_entry(entry)
                    new_cards = []
                    over_budget = False
                    if idx > 0 and w['role'] != 'narrator':  # neglect narrator entry cards as they are not used
                        for card in w['cards']:
                            if card['card_id'] in seen_card_ids:
                                continue
                            seen_card_ids.add(card['card_id'])
                            total_len += get_len(card['description'])
                            if total_len > len_limit:
                                over_budget = True
                                break
                            new_cards.append(card)
                    if over_budget:
                        # The entry's cards do not fit: drop the entry and stop.
                        break

                    sample['entries'].append(w)
                    sample['extra_cards'].extend(new_cards)
                    if w['role'] != 'narrator' and len(w['cards']) > 0:
                        character_entry_appear_with_cards = True

                if len(sample['entries']) > 1 and character_entry_appear_with_cards:
                    res.append(sample)

    print(len(res))
    with open(out_file_name, 'w', encoding='utf-8') as out_f:
        json.dump(res, out_f, indent=1)


def split_file_ordered_card(in_file_name, out_file_name):
    """Variant of the per-character scene splitter with a 1015 budget.

    The input file holds a JSON list of stories.  For every scene and every
    cast character id known to get_characters(), the scene entries are walked
    in order, keeping narrator entries and the entries of that character.
    The running length (measured with get_len) starts with the character's
    hand-card descriptions (each card_id counted once).  Every kept entry adds
    its description length plus 1 (presumably for a separator token -- TODO
    confirm), and every entry after the first -- narrator or not -- also adds
    the descriptions of cards not seen before.  The walk stops at the first
    entry that would push the running length above 1015.

    A sample ``{'entries', 'character', 'extra_cards'}`` is emitted only when
    it has more than two entries.

    Args:
        in_file_name: path of the stories JSON file to read.
        out_file_name: path of the JSON file the samples are written to.

    Side effects: prints the number of samples and writes ``out_file_name``.
    """
    res = []
    len_limit = 1015
    # Load everything up front so the input handle is closed before the
    # (slow) tokenisation loop starts.
    with open(in_file_name, encoding='utf-8') as f:
        stories = json.load(f)

    for story in tqdm(stories):
        characters = get_characters(story)
        for scene in story['scenes']:
            for chaid in scene['cast_character_seq_ids']:
                if chaid not in characters:
                    continue
                # The membership test above guarantees this lookup succeeds,
                # so no exception handling is needed around it.
                sample = {'entries': [], 'character': characters[chaid], 'extra_cards': []}
                seen_card_ids = set()
                total_len = 0
                # Budget the character's own hand cards first.
                for card in sample['character']['cards']:
                    if card['card_id'] not in seen_card_ids:
                        seen_card_ids.add(card['card_id'])
                        total_len += get_len(card['description'])

                for idx, entry in enumerate(scene['entries']):
                    # Every scene is expected to open with a narrator entry.
                    assert idx != 0 or entry['role'] == 'narrator'
                    if entry['character_seq_id'] != chaid and entry['role'] != 'narrator':
                        continue
                    entry_len = get_len(entry['description'])
                    if entry_len == 0:
                        continue
                    total_len += entry_len + 1
                    if total_len > len_limit:
                        break

                    w = get_entry(entry)
                    new_cards = []
                    over_budget = False
                    if idx > 0:  # neglect first entry cards as they are not used
                        for card in w['cards']:
                            if card['card_id'] in seen_card_ids:
                                continue
                            seen_card_ids.add(card['card_id'])
                            total_len += get_len(card['description'])
                            if total_len > len_limit:
                                over_budget = True
                                break
                            new_cards.append(card)
                    if over_budget:
                        # The entry's cards do not fit: drop the entry and stop.
                        break

                    sample['entries'].append(w)
                    sample['extra_cards'].extend(new_cards)

                if len(sample['entries']) > 2:
                    res.append(sample)

    print(len(res))
    with open(out_file_name, 'w', encoding='utf-8') as out_f:
        json.dump(res, out_f, indent=1)


def split_file_one_character_entry(in_file_name, out_file_name):
    """Emit one growing-prefix sample per character entry that carries cards.

    The input file holds a JSON list of stories.  For every scene and every
    cast character id known to get_characters(), the scene entries are walked
    in order, keeping narrator entries and the entries of that character.
    The running length (measured with get_len) is the sum of the kept entry
    descriptions only; the walk stops at the first entry that would push it
    above 1000.

    Each kept entry is appended to the sample.  Whenever the entry belongs to
    the character, has at least one card, and the running length plus all of
    its card descriptions still fits in the budget, a deep copy of the sample
    so far (ending with that entry) is added to the output.  Card lengths are
    only used for this check and are never added to the running length.
    ``'extra_cards'`` is always an empty list in the emitted samples.

    Args:
        in_file_name: path of the stories JSON file to read.
        out_file_name: path of the JSON file the samples are written to.

    Side effects: prints the number of samples and writes ``out_file_name``.
    """
    res = []
    len_limit = 1000
    # Load everything up front so the input handle is closed before the
    # (slow) tokenisation loop starts.
    with open(in_file_name, encoding='utf-8') as f:
        stories = json.load(f)

    for story in tqdm(stories):
        characters = get_characters(story)
        for scene in story['scenes']:
            for chaid in scene['cast_character_seq_ids']:
                if chaid not in characters:
                    continue
                # The membership test above guarantees this lookup succeeds,
                # so no exception handling is needed around it.
                sample = {'entries': [], 'character': characters[chaid], 'extra_cards': []}
                total_len = 0
                for idx, entry in enumerate(scene['entries']):
                    # Every scene is expected to open with a narrator entry.
                    assert idx != 0 or entry['role'] == 'narrator'
                    if entry['character_seq_id'] != chaid and entry['role'] != 'narrator':
                        continue
                    entry_len = get_len(entry['description'])
                    if entry_len == 0:
                        continue
                    total_len += entry_len
                    if total_len > len_limit:
                        break

                    w = get_entry(entry)
                    # Every kept entry joins the running context, whether or
                    # not it triggers a snapshot below.
                    sample['entries'].append(w)
                    if entry['role'] == 'narrator' or len(w['cards']) == 0:
                        continue

                    # Would this entry's cards still fit next to the context?
                    projected_len = total_len
                    cards_fit = True
                    for card in w['cards']:
                        projected_len += get_len(card['description'])
                        if projected_len > len_limit:
                            cards_fit = False
                            break
                    if cards_fit:
                        # Snapshot: later appends must not alter this sample.
                        res.append(deepcopy(sample))

    print(len(res))
    with open(out_file_name, 'w', encoding='utf-8') as out_f:
        json.dump(res, out_f, indent=1)


def parse_name(in_file_name, out_file_name):
    """Write the character name for every card-carrying entry that fits the budget.

    The input file holds a JSON list of stories.  For every scene and every
    cast character id known to get_characters(), the scene entries are walked
    in order, considering narrator entries and the entries of that character.
    The running length (measured with get_len) is the sum of the considered
    entry descriptions; the walk stops at the first entry that would push it
    above 1000.

    For each non-narrator entry the character id is taken from the part of
    ``entry['role']`` after the last ':' and resolved to a name through
    get_characters().  The name is added to the output when the entry has at
    least one card and the running length plus all of its card descriptions
    stays within the budget.  The output is a flat JSON list of names, one
    per qualifying entry, in traversal order.

    Args:
        in_file_name: path of the stories JSON file to read.
        out_file_name: path of the JSON file the name list is written to.

    Raises:
        KeyError: if the id parsed from a role is not a known character.

    Side effects: prints the number of names and writes ``out_file_name``.
    """
    res = []
    len_limit = 1000
    # Load everything up front so the input handle is closed before the
    # (slow) tokenisation loop starts.
    with open(in_file_name, encoding='utf-8') as f:
        stories = json.load(f)

    for story in tqdm(stories):
        characters = get_characters(story)
        for scene in story['scenes']:
            for chaid in scene['cast_character_seq_ids']:
                if chaid not in characters:
                    continue
                total_len = 0
                for idx, entry in enumerate(scene['entries']):
                    # Every scene is expected to open with a narrator entry.
                    assert idx != 0 or entry['role'] == 'narrator'
                    if entry['character_seq_id'] != chaid and entry['role'] != 'narrator':
                        continue
                    entry_len = get_len(entry['description'])
                    if entry_len == 0:
                        continue
                    total_len += entry_len
                    if total_len > len_limit:
                        break
                    if entry['role'] == 'narrator':
                        # Narrator text only consumes budget; it has no name.
                        continue

                    # Keep the id parsed from the role in its own variable:
                    # rebinding ``chaid`` here would change which character
                    # the filter above selects on the following entries.
                    role_character_id = entry['role'].split(':')[-1]
                    character_name = characters[role_character_id]['name']

                    cards = get_entry(entry)['cards']
                    if len(cards) == 0:
                        continue

                    # Would this entry's cards still fit next to the context?
                    projected_len = total_len
                    cards_fit = True
                    for card in cards:
                        projected_len += get_len(card['description'])
                        if projected_len > len_limit:
                            cards_fit = False
                            break
                    if cards_fit:
                        res.append(character_name)

    print(len(res))
    with open(out_file_name, 'w', encoding='utf-8') as out_f:
        json.dump(res, out_f, indent=1)

'''
                            chaid = entry['role'].split(':')[-1]
                            # print('ori role = ', entry['role'])
                            w['role'] = characters[chaid]['name']
                            # print('new role = ', entry['role'])                            
                            # print('entry role = ', entry['role'])
'''

if __name__ == '__main__':
    # Alternative runs, kept for reference:
    # split_file('test_stories.json', 'test.json')
    # split_file('valid_stories.json', 'valid.json')
    # split_file('train_stories.json', 'train.json')

    # split_file_one_character_entry('../data/test_stories.json', '../data/test_character.json')
    # split_file_one_character_entry('../data/valid_stories.json', '../data/valid_character.json')
    # split_file_one_character_entry('../data/train_stories.json', '../data/train_character.json')

    # NOTE(review): parse_name writes to the same '*_character.json' paths as
    # the split_file_one_character_entry calls above -- confirm the overwrite
    # is intended.
    name_jobs = (
        ('../data/test_stories.json', '../data/test_character.json'),
        ('../data/valid_stories.json', '../data/valid_character.json'),
        ('../data/train_stories.json', '../data/train_character.json'),
    )
    for stories_path, names_path in name_jobs:
        parse_name(stories_path, names_path)