import argparse
import os
import ipdb
from tqdm import tqdm
import string

def collect_unique_extractions(sentences, extractions, model_type):
    """Pair each input line with its de-duplicated extractions.

    Extractions on a prediction line are separated by ``<e>``. An extraction
    is kept only the first time it is seen for a given original sentence, so
    several input lines that share the same original sentence never repeat an
    extraction. Rows are produced in input order, and extractions within a
    line keep their order of appearance, which makes the result deterministic.

    Args:
        sentences: Raw input lines, one per example.
        extractions: Raw prediction lines, aligned one-to-one with
            ``sentences``.
        model_type: When ``'genoie'`` the whole input line is the original
            sentence; otherwise the original sentence is the text following
            the first ``<r>`` marker on the input line.

    Returns:
        A list of ``(input_line, extraction, original_sentence)`` tuples, where
        ``extraction`` carries a trailing ``' <e>'`` marker.

    Raises:
        ValueError: If the two inputs differ in length, or if a non-genoie
            input line contains no ``<r>`` marker.
    """
    if len(sentences) != len(extractions):
        raise ValueError(
            f"input/prediction line count mismatch: {len(sentences)} "
            f"sentences vs {len(extractions)} prediction lines"
        )

    seen_by_sentence = {}
    rows = []
    for line_no, (raw_sentence, raw_extractions) in enumerate(
            zip(sentences, extractions), start=1):
        sentence = raw_sentence.strip()
        if model_type == 'genoie':
            orig_sentence = sentence
        else:
            parts = sentence.split('<r>')
            if len(parts) < 2:
                raise ValueError(
                    f"input line {line_no} has no '<r>' marker: {sentence!r}"
                )
            orig_sentence = parts[1].strip()

        seen = seen_by_sentence.setdefault(orig_sentence, set())
        # Iterate the split result directly (no intermediate set) so the
        # output order does not depend on string hash randomisation.
        for ext in raw_extractions.strip().split('<e>'):
            ext = ext.strip()
            if not ext or ext in seen:
                continue
            seen.add(ext)
            rows.append((sentence, ext + ' <e>', orig_sentence))
    return rows


if __name__ == '__main__':
    parser = argparse.ArgumentParser('clean input file')
    parser.add_argument('--fp1', type=str, help='sentences')
    parser.add_argument('--fp2', type=str, help='predictions')
    parser.add_argument('--model_type', type=str, required=True, help='model type')
    parser.add_argument('--data_type', type=str, required=True, help='data type')
    parser.add_argument('--lang', type=str, required=True, help='lang')
    parser.add_argument('--out', type=str, help='output file')
    args = parser.parse_args()

    # NOTE: any --fp1/--fp2 given on the command line is overwritten below;
    # both paths are always derived from --lang/--model_type/--data_type.
    # --out is parsed but not read anywhere in this script.
    args.fp2 = f"../models/{args.lang}/{args.model_type}/{args.data_type}-data/"
    if args.model_type == "genoie":
        predicted_file = args.fp2 + '/test.predicted'
        args.fp1 = f"./carb/data/{args.lang}_test.input"
    else:
        predicted_file = args.fp2 + '/test.pre_predicted'
        args.fp1 = args.fp2 + '/test.input'

    # Read and validate the inputs before opening (and truncating) any output
    # file, so a failed run leaves previous outputs untouched.
    with open(args.fp1, 'r', encoding='utf-8') as f1, \
         open(predicted_file, 'r', encoding='utf-8') as f2:
        sentences = f1.readlines()
        extractions = f2.readlines()

    rows = collect_unique_extractions(sentences, extractions, args.model_type)

    # Three line-aligned outputs: the input line, one extraction, and the
    # original sentence that extraction belongs to.
    with open(args.fp2 + '/test_input.predicted', 'w', encoding='utf-8') as f3, \
         open(args.fp2 + '/test_extractions.predicted', 'w', encoding='utf-8') as f4, \
         open(args.fp2 + '/test_sentences.predicted', 'w', encoding='utf-8') as f5:
        for sentence, ext, orig_sentence in rows:
            f3.write(sentence + '\n')
            f4.write(ext + '\n')
            f5.write(orig_sentence + '\n')