"""Parse sequences in the dataset. Create a new vocabulary."""

import os
import json
import stanza
import argparse
import collections

# English stanza pipeline, built once at import time and used by
# parse_str_item for tokenization.
# NOTE(review): constructing the pipeline presumably loads model files, which
# makes importing this module slow; consider lazy initialization if other
# modules import helpers from here.
nlp = stanza.Pipeline('en')


def parse_str_item(s, vocab_counter=None):
    """Tokenize a single string into lowercased word tokens.

    Args:
        s: Text to tokenize; surrounding whitespace is stripped before it is
            handed to the stanza pipeline.
        vocab_counter: Optional counter (e.g. ``collections.Counter``); when
            given, it is updated in place with the returned tokens.

    Returns:
        List of stanza word texts, each stripped and lowercased, with tokens
        that are empty after stripping dropped.
    """
    tokens = []
    for sentence in nlp(s.strip()).sentences:
        for word in sentence.words:
            normalized = word.text.strip().lower()
            if normalized:
                tokens.append(normalized)

    if vocab_counter is not None:
        vocab_counter.update(tokens)
    return tokens

def parse_str_list(string_list, vocab_counter=None):
    """Tokenize every string in ``string_list`` with ``parse_str_item``.

    Returns one token list per input string, in input order. When
    ``vocab_counter`` is given, it accumulates the tokens of every string.
    """
    return [parse_str_item(text, vocab_counter) for text in string_list]

def parse_fielded_list(fielded_list, vocab_counter=None):
    """Tokenize the value of each ``(attr, value)`` pair.

    Attribute names are passed through untouched; each value is tokenized
    with ``parse_str_item``. Returns a list of ``(attr, tokens)`` tuples in
    input order. When ``vocab_counter`` is given, it accumulates the value
    tokens.
    """
    return [
        (attr, parse_str_item(value, vocab_counter))
        for attr, value in fielded_list
    ]


def parse_datafile(infile: str, outfile: str, vocab_counter):
    """Parse the in-file dataset, write into the out-file, update the vocab-counter.

    ``infile`` is read as JSON lines, one instance object per line. Blank
    lines (e.g. a trailing empty line at end of file) are skipped. For each
    instance, ``source`` (a list of strings), ``target`` (a string), and
    ``table_parent`` (a list of ``(attr, value)`` pairs) are tokenized, while
    ``table_id``, ``sub_sent_id``, and ``operations`` are copied unchanged.
    The parsed instances are written to ``outfile`` as JSON lines, in input
    order.

    All input is parsed before ``outfile`` is opened, so a failure while
    parsing leaves any existing ``outfile`` untouched.

    Args:
        infile: Path of the input JSON-lines file.
        outfile: Path of the output JSON-lines file (overwritten).
        vocab_counter: Counter updated in place with every token produced.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if an instance lacks one of the expected keys.
    """
    output_instances = []
    with open(infile, 'r', encoding='utf-8') as fr:
        for line in fr:
            record = line.strip()
            if not record:
                # json.loads('') raises; tolerate blank / trailing lines.
                continue
            inins = json.loads(record)
            # Key order matters: it fixes the order in which vocab_counter
            # is updated (source, then target, then table_parent).
            output_instances.append({
                'table_id': inins['table_id'],
                'sub_sent_id': inins['sub_sent_id'],
                'source': parse_str_list(inins['source'], vocab_counter),
                'target': parse_str_item(inins['target'], vocab_counter),
                'table_parent': parse_fielded_list(inins['table_parent'], vocab_counter),
                'operations': inins['operations'],
            })

    with open(outfile, 'w', encoding='utf-8') as fw:
        fw.writelines(json.dumps(outins) + '\n' for outins in output_instances)

    print(f'from [{infile}] to [{outfile}]')



def main():
    """Command-line entry point: tokenize every file of one dataset split.

    Reads each regular file in ``<dataset_root>/<input_subdir>/<file_subdir>``
    and writes its parsed counterpart, under the same filename, to
    ``<dataset_root>/<output_subdir>/<file_subdir>``. With ``--create_vocab``,
    it also writes the ``--vocab_size`` most frequent tokens, one
    ``word count`` pair per line, to ``--vocab_path``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset_root', type=str, default='../data')
    parser.add_argument('--input_subdir', type=str, default='table-split')
    parser.add_argument('--output_subdir', type=str, default='table-split-parsed')
    parser.add_argument('--file_subdir', type=str, required=True)

    parser.add_argument('--vocab_size', type=int, default=30000)
    parser.add_argument('--vocab_path', type=str, default='./hitab_vocab.txt')
    parser.add_argument('--create_vocab', action='store_true')

    args = parser.parse_args()

    input_dir = os.path.join(args.dataset_root, args.input_subdir, args.file_subdir)
    output_dir = os.path.join(args.dataset_root, args.output_subdir, args.file_subdir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    print(f'from: {input_dir}')
    print(f'  to: {output_dir}')

    # os.listdir returns entries in arbitrary order and includes
    # subdirectories. Keep regular files only, and sort them so that tokens
    # are counted in a fixed order. Counter.most_common breaks ties by first
    # occurrence, so this makes the vocab file reproducible across runs and
    # filesystems.
    filenames = sorted(
        name for name in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, name))
    )
    print(f'with filenames: {filenames}')

    vocab_counter = collections.Counter()
    for name in filenames:
        parse_datafile(
            infile=os.path.join(input_dir, name),
            outfile=os.path.join(output_dir, name),
            vocab_counter=vocab_counter,
        )

    if args.create_vocab:
        print("Writing vocab file...")
        with open(args.vocab_path, 'w', encoding='utf-8') as fw:
            for word, count in vocab_counter.most_common(args.vocab_size):
                fw.write(f'{word} {count}\n')
        print("Finished writing vocab file")

    print('BYE :)')



if __name__ == "__main__":
    main()
