import argparse
import os
import ipdb
import random
import stanza
from tqdm import tqdm
import time

# Seed the stdlib RNG so any use of `random` is reproducible across runs.
# NOTE(review): nothing in the visible part of this script draws from `random` —
# confirm whether the seed is still needed.
random.seed(0)

if __name__ == '__main__':
    # Script entry point: parse CLI options, build the stanza tagging pipeline
    # and load the raw input sentences (one whitespace-tokenized sentence per line).
    parser = argparse.ArgumentParser('clean input file')
    parser.add_argument('--fp', type=str,
                        help='input file (default: ./carb/data/<lang>_test.input)')
    parser.add_argument('--lang', type=str, required=True, help='language')
    parser.add_argument('--out', type=str,
                        help='output file (default: ./carb/data/<lang>_test_s1.input)')
    args = parser.parse_args()

    # When --lang contains 'ud_', only the second '_'-separated field is handed
    # to stanza as the language; otherwise the value is used unchanged.
    if 'ud_' in args.lang:
        proc_lang = args.lang.split('_')[1]
    else:
        proc_lang = args.lang
    # The pipeline is configured for pre-tokenized input without sentence
    # splitting, running only the tokenize and pos processors.
    nlp = stanza.Pipeline(proc_lang, tokenize_pretokenized=True, tokenize_no_ssplit=True, processors='tokenize,pos')

    # Fall back to the conventional per-language paths only when the caller did
    # not supply --fp / --out (previously the CLI values were always overwritten).
    if args.fp is None:
        args.fp = f"./carb/data/{args.lang}_test.input"
    if args.out is None:
        args.out = f"./carb/data/{args.lang}_test_s1.input"

    # Explicit encoding: do not depend on the platform locale for multilingual text.
    # NOTE(review): assumes the data files are UTF-8 — confirm.
    with open(args.fp, 'r', encoding='utf-8') as f:
        sentences = f.readlines()

    def make_input(sent_tgt):
        """Tag a pre-tokenized sentence and serialize it with its POS tags.

        Args:
            sent_tgt: list of token strings making up one sentence.

        Returns:
            One string in which each token is rendered as
            ``# <upos> ## <token>`` and the pieces are joined by single spaces.
            An empty token list yields an empty string.

        Raises:
            ValueError: if the pipeline output does not line up one-to-one
                with ``sent_tgt`` (different token count or token text).
        """
        # Nothing to tag: skip the pipeline call for an empty sentence.
        if not sent_tgt:
            return ""

        # Flatten the per-sentence lists returned by the pipeline into one
        # list of word dicts.
        tagged = [word
                  for sent_part in nlp(" ".join(sent_tgt)).to_dict()
                  for word in sent_part]

        # Explicit checks instead of `assert ..., ipdb.set_trace()`: they still
        # run under `python -O` and do not block on an interactive debugger.
        if len(tagged) != len(sent_tgt):
            raise ValueError(
                f"tagger returned {len(tagged)} tokens for "
                f"{len(sent_tgt)} input tokens: {sent_tgt!r}"
            )

        pieces = []
        for position, (word, token) in enumerate(zip(tagged, sent_tgt)):
            if word['text'] != token:
                raise ValueError(
                    f"token mismatch at position {position}: "
                    f"tagger produced {word['text']!r}, input was {token!r}"
                )
            # To emit tokens without POS tags, append token.strip() alone.
            pieces.append('# ' + word['upos'].strip() + ' ## ' + token.strip())
        return " ".join(pieces).strip()

    # Tag every input line; each line is stripped and split on whitespace
    # into tokens before being handed to make_input.
    formatted = [make_input(line.strip().split()) for line in tqdm(sentences)]

    # Context manager guarantees the output file is flushed and closed even if
    # a write fails; explicit encoding avoids depending on the platform locale.
    with open(args.out, 'w', encoding='utf-8') as test_input:
        # One serialized sentence per output line, in input order.
        test_input.writelines(row.strip() + '\n' for row in formatted)