import json
import argparse
from collections import defaultdict

def readTxt(fname):
    """Read a UTF-8 text file and return its lines with surrounding whitespace removed.

    The file is opened in binary mode, so lines are split on ``b'\\n'`` only.
    Each raw line is decoded as UTF-8 and then ``str.strip``-ed, which also removes
    a trailing ``'\\r'`` from CRLF files. Empty lines are kept as ``""``.

    Args:
        fname: path of the text file to read.

    Returns:
        list[str]: one stripped string per line of the file.
    """
    with open(fname, 'rb') as handle:
        lines = [raw.decode('utf-8').strip() for raw in handle]
    print("Reading {} example from {}".format(len(lines), fname))
    return lines

def saveTxt(data, fname):
    """Write each item of ``data`` to ``fname`` as one UTF-8 encoded line.

    Every item is rendered with ``'{}\\n'.format(item)``, so non-string items are
    written via their default string formatting. Any existing file is overwritten.

    Args:
        data: sized iterable of items to write; ``len(data)`` is reported afterwards.
        fname: destination path.
    """
    # Encode explicitly as UTF-8 so the output matches readTxt's decoding no matter
    # what the platform's default (locale) encoding is.
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def readJsonl(fname):
    """Load a JSON Lines file into a list of parsed records.

    Each line is handed to ``json.loads`` as raw bytes (UTF-8 is detected
    automatically). Whitespace-only lines are skipped instead of raising
    ``json.JSONDecodeError``; a stray blank line is a common artifact in .jsonl files.

    Args:
        fname: path of the JSON Lines file.

    Returns:
        list: one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(fname, 'rb') as fin:
        data = [json.loads(line) for line in fin if line.strip()]
    print("Reading {} example from {}".format(len(data), fname))
    return data

def matchTransSTReference(trans: list, references: list):
    """Look up a translation for every reference line.

    A lookup table is built from ``trans``, mapping each record's ``"src_text"`` to
    its ``"trg_text"``. When a source text appears more than once, the later record
    wins. Each reference line is then looked up in that table. Hits are
    lower-cased, and misses become ``""``.

    Args:
        trans: records (dicts) with ``"src_text"`` and ``"trg_text"`` keys.
        references: reference lines to align against.

    Returns:
        tuple[list, list]: ``(sources, targets)``, where ``sources`` is a copy of the
        references and ``targets[k]`` is the lower-cased translation of
        ``sources[k]`` (or ``""`` if none was found).
    """
    lookup = {record["src_text"]: record["trg_text"] for record in trans}
    sources = list(references)
    # Iterate the materialized copy so any iterable of references is consumed once.
    targets = [lookup.get(ref, "").lower() for ref in sources]
    matched = sum(1 for target in targets if target != "")
    print("match trans s.t. reference, match {} / {}".format(matched, len(targets)))
    return sources, targets

def concatTrans(args):
    """Align translation files to their references and write the complete rows.

    For each pair ``(args.i[k], args.r[k])`` the JSONL translations are matched
    against the reference lines with ``matchTransSTReference``. Each pair yields two
    columns: the reference lines and the matched, lower-cased targets. The columns
    are laid out as ``[src_0, trg_0, src_1, trg_1, ...]``. Row ``j`` is kept only if
    the ``j``-th entry of every column is non-blank. Column ``c`` of the kept rows is
    written to ``args.o[c]``.

    Args:
        args: namespace with list attributes ``i`` (JSONL translation files),
            ``r`` (reference text files, one per ``i`` entry) and ``o`` (output
            files, two per ``i`` entry). ``None`` is treated as an empty list.

    Raises:
        ValueError: if the numbers of ``i``/``r``/``o`` files are inconsistent, or
            if the reference files do not all have the same number of lines.
    """
    input_files = args.i or []
    reference_files = args.r or []
    output_files = args.o or []

    # Validate up front: zip() below would otherwise silently drop unmatched
    # inputs/references, or leave extra output files empty.
    if len(input_files) != len(reference_files):
        raise ValueError(
            "got {} -i files but {} -r files; they must pair one-to-one".format(
                len(input_files), len(reference_files)))
    if len(output_files) != 2 * len(input_files):
        raise ValueError(
            "expected {} -o files (source and target per -i file), got {}".format(
                2 * len(input_files), len(output_files)))

    inputs = [readJsonl(filename) for filename in input_files]
    references = [readTxt(filename) for filename in reference_files]

    # Rows are indexed jointly across all reference files, so they must be
    # line-aligned; mismatched lengths used to raise IndexError or drop rows.
    ref_lengths = [len(reference) for reference in references]
    if len(set(ref_lengths)) > 1:
        raise ValueError(
            "reference files have different line counts: {}".format(
                list(zip(reference_files, ref_lengths))))

    matched_trans = []
    for tran, reference in zip(inputs, references):
        sources, targets = matchTransSTReference(tran, reference)
        matched_trans.extend([sources, targets])

    num_rows = len(matched_trans[0]) if matched_trans else 0
    # A row is dropped if any column is blank (unmatched translation or empty line).
    newdatas = [
        row for row in zip(*matched_trans)
        if all(item.strip() != "" for item in row)
    ]
    print("keep {} / {} instances".format(len(newdatas), num_rows))

    for column, filename in enumerate(output_files):
        saveTxt([row[column] for row in newdatas], filename)

if __name__ == "__main__":
    # Command-line entry point: every option is mandatory, because concatTrans
    # iterates over all three lists (an omitted option would otherwise be None).
    parser = argparse.ArgumentParser(
        description="Align JSONL translations with reference files and keep only "
                    "the rows that are non-empty in every column.")
    parser.add_argument('-i', type=str, nargs="+", required=True,
                        help="JSONL translation files with src_text/trg_text fields")
    parser.add_argument('-o', type=str, nargs="+", required=True,
                        help="output text files, two per -i file (source, target)")
    parser.add_argument('-r', type=str, nargs="+", required=True,
                        help="reference text files, one per -i file")
    args = parser.parse_args()

    concatTrans(args)
