import os
import re
import sys
from contextlib import ExitStack
from random import shuffle

from tqdm import tqdm

def process():
    """Extract line-aligned article/headline pairs from a corpus directory.

    Command-line arguments (read from ``sys.argv``):
        sys.argv[1]: root directory, walked recursively for input files;
            files named ``.dircksum`` are skipped.
        sys.argv[2]: output directory (created, with parents, if missing).

    Two files are written into the output directory, one line per article:
        s1: the lines of the first ``<P>`` block following a headline, up to
            and including the first line that contains ``(. .)`` (or all of
            them if none does), joined by single spaces.
        s2: the line that follows ``<HEADLINE>``.

    Every emitted line has a leading ``"( "`` wrapper removed, each digit
    replaced by ``#`` and is lower-cased.

    Note: parser state is carried across file boundaries, exactly as lines
    are encountered; it is not reset at the start of each input file.
    """
    input_dir, output_dir = sys.argv[1], sys.argv[2]

    # Create the output directory if needed; unlike a swallowed OSError this
    # still surfaces real failures (permissions, path is a regular file, ...).
    os.makedirs(output_dir, exist_ok=True)

    # Parser states: waiting for a headline, reading the headline line,
    # waiting for the first paragraph, reading paragraph lines.
    NONE, HEAD, NEXT, TEXT = 0, 1, 2, 3

    digit_re = re.compile(r'\d')

    # FIX: Some parses are mis-parenthesized.
    def fix_paren(parse):
        # Drop a leading "( " together with the final character.
        if parse.startswith("( "):
            return parse[2:-1]
        return parse

    def remove_digits(parse):
        return digit_re.sub('#', parse)

    def normalize(stripped_line):
        return remove_digits(fix_paren(stripped_line)).lower()

    mode = NONE
    title_parse = ""
    article_parse = []

    with open(os.path.join(output_dir, 's1'), "w") as out_s1, \
            open(os.path.join(output_dir, 's2'), "w") as out_s2:
        directory_files = []
        for root, dirs, files in os.walk(input_dir):
            for file in files:
                if file != ".dircksum":
                    directory_files.append(os.path.join(root, file))

        for filename in tqdm(directory_files):
            # Context manager so each input handle is closed promptly instead
            # of leaking one descriptor per corpus file.
            with open(filename, 'r') as in_file:
                for l in in_file:
                    stripped = l.strip()

                    # The checks below are deliberately independent ``if``s
                    # evaluated in this order: a state set by one check may
                    # be consumed by a later check on the same line.
                    if mode == HEAD:
                        title_parse = normalize(stripped)
                        mode = NEXT

                    if mode == TEXT:
                        # Note: this also appends the closing "</P>" line.
                        article_parse.append(normalize(stripped))

                    if mode == NONE and stripped == "<HEADLINE>":
                        mode = HEAD

                    if mode == NEXT and stripped == "<P>":
                        mode = TEXT

                    if mode == TEXT and stripped == "</P>":
                        # Annotated gigaword has a poor sentence segmenter.
                        # Ensure there is at least a period.
                        #
                        # The last collected entry is always the closing
                        # "</P>" line, so drop it *before* selecting
                        # sentences.  (Previously the final element was
                        # dropped after the selection loop, which discarded
                        # the sentence containing the period whenever one
                        # was found.)
                        sentences = []
                        for sentence in article_parse[:-1]:
                            sentences.append(sentence)
                            if "(. .)" in sentence:
                                break

                        print(" ".join(sentences), file=out_s1)
                        print(title_parse, file=out_s2)
                        article_parse = []
                        mode = NONE

def split(path, total_number, test_size = 0.01, dev_size = 0.01):
    """Randomly split the line-aligned files ``s1`` and ``s2`` in ``path``.

    Line indices ``0 .. total_number - 1`` are shuffled once and partitioned
    into train / dev / test; the same assignment is applied to both files so
    that ``s1.<part>`` and ``s2.<part>`` stay aligned line by line.  The
    original line order is preserved within each partition.

    Args:
        path: directory containing ``s1`` and ``s2``; the six output files
            ``s{1,2}.{train,dev,test}`` are written there.
        total_number: number of line indices to distribute.  Lines whose
            index is ``>= total_number`` are not written to any partition.
        test_size: fraction of ``total_number`` reserved for the test set.
        dev_size: fraction of ``total_number`` reserved for the dev set.

    The train set receives ``int(total_number * (1 - test_size - dev_size))``
    indices, the dev set ``int(total_number * dev_size)`` and the test set
    whatever remains.  A progress counter is printed every 1000 lines.
    """
    indices = list(range(total_number))
    shuffle(indices)
    train_count = int(total_number * (1. - test_size - dev_size))
    dev_count = int(total_number * dev_size)
    dev_end = train_count + dev_count

    # Map each line index to its partition.  Filled test -> dev -> train so
    # that, should the slices ever overlap, train wins over dev over test.
    suffix_of = dict.fromkeys(indices[dev_end:], "test")
    suffix_of.update(dict.fromkeys(indices[train_count:dev_end], "dev"))
    suffix_of.update(dict.fromkeys(indices[:train_count], "train"))

    sources = ("s1", "s2")
    suffixes = ("train", "dev", "test")

    # ExitStack closes every output file even if reading or writing fails.
    with ExitStack() as stack:
        outputs = {}
        for suffix in suffixes:
            for source in sources:
                out_path = os.path.join(path, source + "." + suffix)
                outputs[source, suffix] = stack.enter_context(
                    open(out_path, 'w'))

        for source in sources:
            with open(os.path.join(path, source), 'r') as f:
                for i, l in enumerate(f):
                    if i % 1000 == 0:
                        print(i)
                    suffix = suffix_of.get(i)
                    if suffix is not None:
                        outputs[source, suffix].write(l)

if __name__ == "__main__":
    # Usage: <script> INPUT_DIR OUTPUT_DIR [TOTAL_LINES]
    #   INPUT_DIR   corpus root that process() walks (sys.argv[1])
    #   OUTPUT_DIR  where s1/s2 and their train/dev/test splits are written
    #   TOTAL_LINES optional number of article/headline pairs to split
    if len(sys.argv) < 3:
        sys.exit("usage: %s INPUT_DIR OUTPUT_DIR [TOTAL_LINES]" % sys.argv[0])

    process()
    outpath = sys.argv[2]
    # 8583557 is the previously hard-coded value, kept as the default so that
    # two-argument invocations behave exactly as before.
    # NOTE(review): presumably the pair count of the original corpus --
    # confirm, and pass TOTAL_LINES explicitly for any other dataset.
    total_lines = int(sys.argv[3]) if len(sys.argv) > 3 else 8583557
    split(outpath, total_lines)
