import json
import argparse

def readTxt(fname):
    """Read a UTF-8 text file into a list of stripped lines.

    The file is opened in binary mode and each raw line is decoded as UTF-8,
    then stripped of leading/trailing whitespace (including the line ending).
    Blank lines are kept as empty strings. A short count message is printed.

    Args:
        fname: path of the file to read.

    Returns:
        list of str, one entry per line of the file.
    """
    with open(fname, 'rb') as handle:
        lines = [raw_line.decode('utf-8').strip() for raw_line in handle]
    print("Reading {} example from {}".format(len(lines), fname))
    return lines

def saveTxt(data, fname):
    """Write each item of ``data`` on its own line of a UTF-8 text file.

    Items are rendered with ``'{}'.format(item)`` and terminated by ``'\\n'``.
    The encoding is fixed to UTF-8 (rather than the platform locale default)
    so that files written here round-trip through ``readTxt``, which decodes
    UTF-8.

    Args:
        data: sized iterable of items to write (``len(data)`` is reported).
        fname: output path; an existing file is overwritten.
    """
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def saveJsonl(datas, fname):
    """Write each item of ``datas`` as one JSON document per line (JSONL).

    ``ensure_ascii=False`` keeps non-ASCII characters verbatim instead of
    ``\\uXXXX`` escapes, so the file is explicitly opened as UTF-8. Relying on
    the platform default encoding could raise ``UnicodeEncodeError`` or produce
    a file that is not valid UTF-8 JSONL.

    Args:
        datas: sized iterable of JSON-serialisable objects.
        fname: output path; an existing file is overwritten.
    """
    with open(fname, 'w', encoding='utf-8') as fout:
        for data in datas:
            fout.write(json.dumps(data, ensure_ascii=False) + '\n')
    print("write {} lines to {}".format(
        len(datas), fname
    ))

def LCS(a, b):
    """Return a longest common subsequence of ``a`` and ``b`` and its positions in ``a``.

    Elements are compared with ``==``. The matched elements of ``a`` are
    concatenated into the returned string, so ``a`` must be a ``str`` or a
    sequence of ``str`` (for a ``str``, the LCS is computed over characters).

    When several LCSs exist, ties are broken exactly as in the classic
    backward walk: take a diagonal match when it is on an optimal path,
    otherwise move up only if that is strictly better, else move left.

    Args:
        a: first sequence (``str`` or sequence of ``str``).
        b: second sequence, compared element-wise against ``a``.

    Returns:
        ``(lcs, positions)`` where ``lcs`` is the concatenation of the matched
        elements of ``a`` and ``positions`` is the ascending list of their
        indices in ``a``.
    """
    m = len(a)
    n = len(b)

    # f[i][j] = length of an LCS of a[:i] and b[:j]. Plain nested lists are
    # considerably faster than numpy scalar indexing inside Python loops.
    f = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m):
        prev_row = f[i]
        cur_row = f[i + 1]
        a_i = a[i]
        for j in range(n):
            best = max(cur_row[j], prev_row[j + 1])
            if a_i == b[j]:
                best = max(best, prev_row[j] + 1)
            cur_row[j + 1] = best

    # Backward walk from (m, n). Collect positions in reverse order and flip
    # once at the end, avoiding quadratic list/string prepending.
    positions = []
    i, j = m, n
    while i > 0 and j > 0:
        if a[i - 1] == b[j - 1] and f[i][j] == f[i - 1][j - 1] + 1:
            positions.append(i - 1)
            i -= 1
            j -= 1
        elif f[i - 1][j] > f[i][j - 1]:
            i -= 1
        else:
            j -= 1
    positions.reverse()

    res = "".join(a[p] for p in positions)
    return res, positions
 
def tokenLevelOracle(args):
    """Build LCS oracle records for parallel source/target files and save them as JSONL.

    Line ``k`` of the source file is paired with line ``k`` of the target file.
    For each pair, one record is written with keys ``source``, ``target``,
    ``oracle_str`` (the LCS string) and ``oracle_idx`` (the LCS positions in
    the source line).

    NOTE(review): the LCS is computed directly on each line string, i.e. over
    characters, not whitespace-separated tokens, despite the function name.
    Confirm that this is intended.

    Args:
        args: namespace with attributes ``s`` (source path), ``t`` (target
            path) and ``o`` (output JSONL path), as built by this script's
            argparse parser.
    """
    sources = readTxt(args.s)
    targets = readTxt(args.t)

    if len(sources) != len(targets):
        # zip() below stops at the shorter input; make the truncation visible
        # instead of silently dropping the unmatched tail.
        print("Warning: {} source lines vs {} target lines; only the first {} pairs are used".format(
            len(sources), len(targets), min(len(sources), len(targets))
        ))

    results = []
    for source, target in zip(sources, targets):
        lcs, lcs_pos = LCS(source, target)
        results.append({
            "source": source,
            "target": target,
            "oracle_str": lcs,
            "oracle_idx": lcs_pos
        })
    saveJsonl(results, args.o)

if __name__ == "__main__":
    # Explicit dispatch table instead of eval(): only known modes can be run,
    # and argparse rejects anything else with a clear error message.
    modes = {"tokenLevelOracle": tokenLevelOracle}

    parser = argparse.ArgumentParser()

    # All three paths are needed by tokenLevelOracle; without required=True a
    # missing path only failed later inside open(None) with a TypeError.
    parser.add_argument("-s", required=True, help="source path")
    parser.add_argument("-t", required=True, help="target path")
    parser.add_argument("-o", required=True, help="output path")
    parser.add_argument("-m", help="mode", default="tokenLevelOracle", choices=sorted(modes))

    args = parser.parse_args()

    modes[args.m](args)
