import os
import tensorflow_datasets as tfds
from collections import defaultdict


# Maximum number of text lines collected per language; the output files are
# named "first2m" (first 2,000,000 lines).
sents_limit = 2000000
# Larger language list kept for reference; swap it in to process more languages.
# lan_list = ["ar", "de", "el", "en", "es", "hi", "ru", "th", "tr", "vi", "zh-cn"]
lan_list = ["fr", "en", "zh-cn"]
# Minimum line length (characters, measured after newline removal) to keep.
min_line_len = 10
# Language code -> list of extracted paragraph lines, in dataset order.
text_wiki = defaultdict(list)


def _extract_paragraph_lines(article, min_len=min_line_len):
    """Return the paragraph text lines of one decoded wiki40b article.

    The article text is split on ``_START_PARAGRAPH_`` markers. Within each
    chunk only the part before the next ``_START_`` marker (e.g. a following
    ``_START_SECTION_`` header) is paragraph text; that part is split on
    ``_NEWLINE_`` into lines. Raw newlines are removed, and lines shorter than
    ``min_len`` characters are dropped.

    Args:
        article: decoded article text.
        min_len: minimum number of characters a line must have to be kept.

    Returns:
        List of paragraph lines (str), in article order.
    """
    lines = []
    for chunk in article.split("_START_PARAGRAPH_"):
        # Keep only the text before any further structural marker, instead of
        # discarding the whole chunk when a section header follows the
        # paragraph. The leading article/section header chunk reduces to
        # (near-)empty text here and is removed by the length filter.
        body = chunk.split("_START_", 1)[0]
        for line in body.split("_NEWLINE_"):
            line = line.replace("\n", "")
            if len(line) >= min_len:
                lines.append(line)
    return lines


for lan in lan_list:
    dataset = tfds.as_numpy(tfds.load(f"wiki40b/{lan}", split="train"))
    collected = text_wiki[lan]
    for example in dataset:
        collected.extend(_extract_paragraph_lines(example["text"].decode("utf-8")))
        if len(collected) >= sents_limit:
            break
    # The limit is checked per article, so the final article can overshoot it;
    # trim so exactly the first `sents_limit` lines are kept.
    del collected[sents_limit:]

OUTDIR = "/opt/tiger/sumtest/wiki40b"
os.makedirs(OUTDIR, exist_ok=True)

# Output file names use "zh" instead of the tfds config name "zh-cn";
# every other language code is used as-is.
out_lang_names = {"zh-cn": "zh"}

for lan in lan_list:
    outlan = out_lang_names.get(lan, lan)
    out_path = os.path.join(OUTDIR, "{}.first2m.txt".format(outlan))
    # Explicit UTF-8: the default text encoding is locale-dependent and can
    # raise UnicodeEncodeError on non-Latin text such as Chinese. `with`
    # guarantees the file is closed even if a write fails.
    with open(out_path, "w", encoding="utf-8") as fout:
        # One text entry per output line; any embedded newline is turned into
        # a tab so it cannot split an entry across lines.
        fout.writelines(
            text.replace("\n", "\t") + "\n" for text in text_wiki[lan]
        )
