import json
import os
import re
import string
import sys

from tqdm import tqdm
from transformers import GPT2Tokenizer
from zhon import hanzi

# Module-level tokenizer shared by get_len() for GPT-2 token counts.
# 'gpt2_en_ckpt_origin' is whatever from_pretrained resolves it to
# (presumably a local checkpoint directory — TODO confirm).
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2_en_ckpt_origin')
def get_len(text, gpt2=True):
    """Return the length of ``text`` in tokens.

    Args:
        text: The string to measure. Any falsy value (``None``, ``""``)
            yields 0.
        gpt2: When true (the default), count the ids produced by
            ``gpt2_tokenizer.encode``. When false, count the pieces obtained
            by splitting on every single space, carriage return, newline or
            tab character.

    Returns:
        The token count as an ``int``.

    Note:
        In the non-GPT-2 mode consecutive separators produce empty pieces,
        and those are counted too (``"a  b"`` has length 3).
    """
    if not text:
        return 0
    if gpt2:
        return len(gpt2_tokenizer.encode(text))
    # Raw-string character class; matches the same four separator characters
    # as the alternation ' |\r|\n|\t'.
    return len(re.split(r'[ \r\n\t]', text))
# def traverse()

# Per-split story containers. They are declared here but nothing in the
# visible code appends to or reads from them.
train_list = []  # list of stories, one story is a dict
valid_list = []  # presumably the same layout as train_list — TODO confirm
test_list = []  # presumably the same layout as train_list — TODO confirm


def valid_sentence(s):
    """Report whether ``s`` contains only allowed characters.

    A character is allowed when it occurs in ``string.printable`` or in
    ``hanzi.punctuation``. Falsy input (``None`` or an empty string) is
    considered valid.

    Returns:
        ``True`` if no disallowed character is present, ``False`` otherwise.
    """
    if not s:
        return True
    for ch in s:
        if ch in string.printable:
            continue
        if ch in hanzi.punctuation:
            continue
        # Neither printable ASCII nor Chinese punctuation.
        return False
    return True

# Running total of scenes that contain at least one invalid description;
# updated by process().
cnt = 0


def _scene_is_valid(scene):
    """Return True when every description in ``scene`` passes valid_sentence."""
    for entry in scene["entries"]:
        if not valid_sentence(entry['description']):
            return False
        for key in ('cards_played_on_challenge', 'cards_for_pickup'):
            for card in entry[key]:
                if not valid_sentence(card['description']):
                    return False
    return True


def process(in_file_name):
    """Count the invalid scenes of one story file into the global ``cnt``.

    The file is parsed as JSON. A scene is invalid when the description of
    any of its entries, or of any card listed under an entry's
    ``cards_played_on_challenge`` / ``cards_for_pickup``, fails
    ``valid_sentence``. ``cnt`` is incremented once per invalid scene.

    Args:
        in_file_name: Path of the JSON story file to inspect.

    Raises:
        KeyError: If an expected key is missing from the parsed data.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    global cnt
    # JSON is UTF-8 by specification; do not depend on the locale default,
    # which differs across platforms.
    with open(in_file_name, encoding='utf-8') as in_f:
        story = json.load(in_f)

    # Scanning of a scene stops at its first invalid description.
    cnt += sum(1 for scene in story["scenes"] if not _scene_is_valid(scene))


def read_filenames(file_name):
    """Read a list of paths and return the set of their base names.

    Each line of the file is stripped of surrounding whitespace and reduced
    to the part after its last ``/``. Blank lines and lines that end in
    ``/`` produce no name and are skipped, so the empty string never ends up
    in the result.

    Args:
        file_name: Path of a text file containing one path per line.

    Returns:
        A ``set`` of non-empty file names.
    """
    res = set()
    with open(file_name, encoding='utf-8') as f:
        # Iterate lazily instead of materialising the file with readlines().
        for line in f:
            name = line.strip().split('/')[-1]
            if name:
                res.add(name)
    return res


# Base names of the files assigned to each dataset split.
train_set = read_filenames('train_filenames.txt')
valid_set = read_filenames('validation_filenames.txt')
test_set = read_filenames('test_filenames.txt')

# Sanity check: print the pairwise overlaps between the splits.
for left, right in ((train_set, valid_set),
                    (train_set, test_set),
                    (valid_set, test_set)):
    print(left & right)

# Total number of names, followed by the size of each split.
split_sizes = [len(train_set), len(valid_set), len(test_set)]
print(sum(split_sizes))
print(*split_sizes)
# exit()

# Open (create or truncate) the three per-split output files. Nothing in the
# visible code writes to them; the `with` statement guarantees the handles
# are closed once the directory walk finishes, even if process() raises.
with open('train_stories.json', 'w', encoding='utf-8') as train_file, \
        open('valid_stories.json', 'w', encoding='utf-8') as valid_file, \
        open('test_stories.json', 'w', encoding='utf-8') as test_file:
    # Walk the current working directory and validate every JSON file whose
    # base name is listed in one of the three splits.
    for root, dirs, files in tqdm(os.walk(os.getcwd())):
        for file_name in files:
            if not file_name.endswith('.json'):
                continue
            # All splits are handled identically, so one membership test
            # replaces the three duplicated branches.
            if (file_name in train_set
                    or file_name in valid_set
                    or file_name in test_set):
                process(os.path.join(root, file_name))

# Number of scenes with at least one invalid description.
print('cnt = ', cnt)


# print(cnt)
# if os.path.isdir(f):
