import argparse
import os
import ipdb
import random
from tqdm import tqdm
import time

# Seed the global RNG when the module loads. No random draws are visible in
# this script; the call is kept so module-level behaviour is unchanged.
random.seed(0)


def get_relations(x):
    """Split one prediction line into its non-empty relation strings.

    Args:
        x: A line holding zero or more relations separated by ``<r>``.

    Returns:
        The whitespace-stripped pieces in their original order, with pieces
        that are empty after stripping dropped.
    """
    stripped = (piece.strip() for piece in x.split('<r>'))
    return [piece for piece in stripped if piece]


def _build_stage2_inputs(relations, sentences):
    """Pair every relation found on a prediction line with its sentence.

    Args:
        relations: Prediction lines, one per sentence.
        sentences: Sentence lines, aligned index-by-index with ``relations``.

    Returns:
        A ``(test_relations, count_relations)`` tuple. ``test_relations`` has
        one ``"<relation> <r> <sentence>"`` string per relation;
        ``count_relations`` has, per sentence, the number of relations found
        for it, as a string.

    Raises:
        ValueError: If the two inputs do not have the same number of lines.
    """
    # An explicit exception instead of ``assert`` so the check still runs
    # under ``python -O``.
    if len(relations) != len(sentences):
        raise ValueError(
            f"line count mismatch: {len(relations)} prediction lines vs "
            f"{len(sentences)} sentence lines"
        )

    test_relations = []
    count_relations = []
    for predicted, sentence in zip(relations, sentences):
        found = get_relations(predicted)
        count_relations.append(str(len(found)))
        sentence = sentence.strip()
        # Relations are already stripped by get_relations().
        test_relations.extend(f"{rel} <r> {sentence}" for rel in found)
    return test_relations, count_relations


def main(argv=None):
    """Build ``test.input`` and ``test.count`` from predictions and sentences.

    Reads the prediction file (``--fp1``) and the sentence file (``--fp2``),
    pairs each predicted relation with its sentence, and writes the pairs to
    ``test.input`` and the per-sentence relation counts to ``test.count``
    inside the ``--out`` directory.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If the two input files have different line counts.
    """
    parser = argparse.ArgumentParser('clean input file')
    parser.add_argument('--fp1', type=str,
                        help='prediction file (default derived from --lang/--data_type)')
    parser.add_argument('--fp2', type=str,
                        help='sentence file (default derived from --lang)')
    parser.add_argument('--data_type', type=str, required=True, help='model type')
    parser.add_argument('--lang', type=str, required=True, help='language')
    parser.add_argument('--out', type=str,
                        help='output directory (default derived from --lang/--data_type)')
    args = parser.parse_args(argv)

    # Fall back to the conventional locations only when a path was not given;
    # previously these assignments clobbered whatever the user passed.
    if not args.fp1:
        args.fp1 = f"../models/{args.lang}/gen2oie_s1/{args.data_type}-data/test.predicted"
    if not args.fp2:
        args.fp2 = f"./carb/data/{args.lang}_test.input"
    if not args.out:
        args.out = f"../models/{args.lang}/gen2oie_s2/{args.data_type}-data/"

    # Explicit UTF-8 so non-ASCII text does not depend on the platform's
    # default encoding.
    with open(args.fp1, 'r', encoding='utf-8') as f:
        relations = f.readlines()

    with open(args.fp2, 'r', encoding='utf-8') as f:
        sentences = f.readlines()

    test_relations, count_relations = _build_stage2_inputs(relations, sentences)

    # Create the output directory up front so the run does not fail at the
    # very end because it is missing.
    os.makedirs(args.out, exist_ok=True)
    with open(os.path.join(args.out, 'test.input'), 'w', encoding='utf-8') as f:
        f.write("\n".join(test_relations).strip())
    with open(os.path.join(args.out, 'test.count'), 'w', encoding='utf-8') as f:
        f.write("\n".join(count_relations).strip())


if __name__ == '__main__':
    main()
