import argparse
import functools
import os

import seqio
import t5

def count_line(in_fname):
    """Return the number of lines in the text file at ``in_fname``.

    The file is iterated line by line instead of being loaded with
    ``readlines()``, so memory use stays flat regardless of file size.

    Args:
        in_fname: Path of the file to count; it is opened in text mode.

    Returns:
        The line count as an ``int`` (``0`` for an empty file). A final
        line without a trailing newline still counts as one line.

    Raises:
        OSError: If the file cannot be opened for reading.
    """
    with open(in_fname, "r") as infile:
        # Iterating the file object yields one line at a time.
        return sum(1 for _ in infile)


def main(args):
    """Register the ``clang8.en`` seqio task over a TSV file and print it.

    The TSV at ``args.tsv_path`` is exposed as a single split named
    ``"dummy_split"``. Each line is parsed into ``inputs`` / ``targets``
    fields, passed through ``seqio.preprocessors.tokenize_and_append_eos``
    with the default T5 vocabulary, and every resulting example is printed,
    followed by the vocabulary object itself.

    Args:
        args: Parsed command-line namespace. Only ``tsv_path`` and
            ``sequence_length`` are read by this function.

    Raises:
        ValueError: If ``args.tsv_path`` is not set. The pipeline only ever
            reads the TSV path, so ``--source_path`` alone cannot be used.
    """
    # Fail early with a clear message instead of crashing later on a ``None``
    # path; an explicit raise also survives ``python -O`` (unlike ``assert``).
    if not args.tsv_path:
        raise ValueError(
            "--tsv_path is required: this script only reads the TSV file "
            "(--source_path alone is not supported)"
        )

    split = "dummy_split"
    cl_tsv_path = {split: args.tsv_path}
    # Use a distinct loop name so ``split`` is not rebound before it is
    # passed to ``get_dataset`` below.
    num_cl_examples = {
        split_name: count_line(fname)
        for split_name, fname in cl_tsv_path.items()
    }

    # Build the default vocabulary once and share it between both features
    # and the final print.
    vocab = t5.data.get_default_vocabulary()
    output_features = {
        "inputs": seqio.Feature(vocabulary=vocab, add_eos=True),
        "targets": seqio.Feature(vocabulary=vocab, add_eos=True),
    }

    seqio.TaskRegistry.add(
        "clang8.en",
        source=seqio.TextLineDataSource(
            split_to_filepattern=cl_tsv_path,
            num_input_examples=num_cl_examples,
        ),
        preprocessors=[
            functools.partial(
                t5.data.preprocessors.parse_tsv,
                field_names=["inputs", "targets"],
            ),
            seqio.preprocessors.tokenize_and_append_eos,
        ],
        metric_fns=[],
        output_features=output_features,
    )
    task = seqio.get_mixture_or_task("clang8.en")

    # The same length is applied to both features.
    sequence_length = {
        "inputs": args.sequence_length,
        "targets": args.sequence_length,
    }

    ds = task.get_dataset(sequence_length, split, shuffle=False)
    for d in ds:
        print(d)

    print(vocab)



def get_arguments(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches
            the previous zero-argument behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``tsv_path``,
        ``source_path``, ``tokenizer_model``, ``output`` and
        ``sequence_length``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, on a missing
            ``--output``, or after printing ``--help``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tsv_path', type=str, default=None, help='tsv path')
    parser.add_argument('--source_path', type=str, default=None, help='text source path')
    parser.add_argument('--tokenizer_model', type=str, default="spiece.model",
        help='path to the SentencePiece tokenizer model')
    parser.add_argument('--output', type=str, required=True, help="vocab output path")
    # The help text previously said "desired vocabulary size", but the value
    # is used as the sequence length for both features.
    parser.add_argument('--sequence_length', type=int, default=1024,
        help="sequence length used for both inputs and targets")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main().
    main(get_arguments())
