import argparse
import random

def readTxt(file, encoding="utf-8"):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        file: Path of the file to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so non-ASCII
            corpora decode identically regardless of the platform locale.

    Returns:
        A list with one string per line, each passed through ``str.strip()``
        (leading/trailing whitespace, including the newline, is removed).
    """
    with open(file, "r", encoding=encoding) as fin:
        datas = [line.strip() for line in fin]
    print("read {} lines from {}".format(len(datas), file))
    return datas

def saveTxt(datas, file, encoding="utf-8"):
    """Write each string in *datas* to *file*, one per line (overwriting the file).

    Args:
        datas: Sized sequence of strings; each gets a trailing newline.
        file: Destination path.
        encoding: Text encoding to write with. Defaults to UTF-8 so the output
            does not depend on the platform locale.
    """
    with open(file, "w", encoding=encoding) as fout:
        fout.writelines(data + "\n" for data in datas)
    print("save {} lines to {}".format(len(datas), file))

def getChunks(l, n):
    """Yield successive n-sized chunks from l.

    ``l`` may be any sliceable sequence; the final chunk holds the remainder
    and may be shorter than ``n``.

    Raises:
        ValueError: if ``n`` < 1 (raised on first iteration). Without this a
            negative size would silently yield nothing.
    """
    if n < 1:
        raise ValueError("chunk size must be a positive integer, got {}".format(n))
    for start in range(0, len(l), n):
        yield l[start:start + n]

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--input-prefix", type=str, nargs="+"
    )

    parser.add_argument(
        "--output-prefix", type=str
    )
    parser.add_argument(
        "-b", type=int, default=8, help="batch size"
    )

    args = parser.parse_args()

    datas = []
    for prefix in args.input_prefix:
        source_file = prefix + ".doc"
        target_file = prefix + ".sum"
        sources = readTxt(source_file)
        targets = readTxt(target_file)
        data = list(zip(sources, targets))
        data = sorted(data, key=lambda x: len(x[0].split()))
        datas.append(data)
    batch_num = sum(len(data) for data in datas) / args.b
    assert int(batch_num) == batch_num
    batch_num = int(batch_num)
    data_chunks = []
    for items in datas:
        chunk_size = len(items) // batch_num
        chunks = list(getChunks(items, chunk_size))
        data_chunks.append(chunks)
    
    merged_data = []
    for i in range(batch_num):
        merged_chunk = []
        for chunk in data_chunks:
            merged_chunk.extend(chunk[i])
        random.shuffle(merged_chunk)
        merged_data.extend(merged_chunk)
    sources = []
    targets = []
    for items in merged_data:
        sources.append(items[0])
        targets.append(items[1])
    saveTxt(sources, args.output_prefix + ".doc")
    saveTxt(targets, args.output_prefix + ".sum")

    # python3 mergeFileByLength.py --input-prefix /mnt/bd/lab-wxz/clt/langReverse/MSPM/cc100.dev.zh.spm /mnt/bd/lab-wxz/clt/langReverse/MSPM/cc100.dev.fr.spm -b 8 --output-prefix tmp.out