import argparse
import random
from multiprocessing import Pool
from rouge import Rouge
from pysbd import Segmenter

# Scorer instance from the third-party `rouge` package, created at import time.
# NOTE(review): no function visible in this file references it; the scoring
# done here goes through the hand-rolled getRouge1 instead -- confirm whether
# other modules import this name before removing it.
rouge_scorer = Rouge()

def sentSplitFn(s: str, args, seg):
    """Split ``s`` into a list of sentences.

    Args:
        s: Text to split.
        args: Object exposing ``d``; when ``d`` is not None it is used as a
            literal delimiter for ``str.split``.
        seg: Object exposing ``segment(text)``; only consulted when
            ``args.d`` is None.

    Returns:
        The list of sentence strings.
    """
    delimiter = args.d
    if delimiter is None:
        # No explicit delimiter was configured: defer to the segmenter.
        return seg.segment(s)
    return s.split(delimiter)

def sentJoinFn(sents: list, args):
    """Join ``sents`` into one string.

    The separator is ``args.d`` when it is not None, otherwise a single
    space.
    """
    if args.d is None:
        return " ".join(sents)
    return args.d.join(sents)

def readTxt(fname):
    """Load ``fname`` and return its lines as stripped UTF-8 strings.

    The file is opened in binary mode, so lines are delimited by ``\\n``
    only; each raw line is decoded as UTF-8 and stripped of leading and
    trailing whitespace. The number of lines read is printed.
    """
    with open(fname, 'rb') as fin:
        data = [raw_line.decode('utf-8').strip() for raw_line in fin]
    print("Reading {} example from {}".format(len(data), fname))
    return data

def saveTxt(data, fname):
    """Write each item of ``data`` to ``fname``, one per line, as UTF-8.

    Args:
        data: Sized iterable of items; each is rendered with ``str.format``.
        fname: Path of the file to create or overwrite.

    The number of items written is printed.
    """
    # Pin the encoding: readTxt decodes UTF-8, and the platform default
    # encoding may be unable to represent non-ASCII text.
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def getRouge1(target, hypo):
    """Return a smoothed unigram-overlap F-score between two strings, scaled by 100.

    Both strings are tokenized on whitespace. Every hypothesis token that
    occurs anywhere in the target counts as an overlap; repeated tokens are
    counted each time they appear (counts are not clipped against the
    target). A constant 1e-3 is added to each denominator so that empty
    inputs yield 0.0 instead of raising ZeroDivisionError.

    Args:
        target: Reference text.
        hypo: Hypothesis text.

    Returns:
        ``2 * p * r / (p + r + 1e-3) * 100`` as a float.
    """
    target_tokens = target.split()
    hypo_tokens = hypo.split()

    # Set membership keeps the overlap count identical while avoiding a
    # linear scan of the target for every hypothesis token.
    target_vocab = set(target_tokens)
    overlap_count = sum(1 for token in hypo_tokens if token in target_vocab)

    p = overlap_count / (len(hypo_tokens) + 1e-3)
    r = overlap_count / (len(target_tokens) + 1e-3)
    return (2 * p * r / (p + r + 1e-3)) * 100

def getOracle(targets: list, sources: list):
    """Greedily select source sentences and return the best score reached.

    Source sentences are visited in order. Each non-blank sentence is
    tentatively appended to the current candidate (joined with ". ") and
    kept only if the lowercased candidate scores strictly higher against
    the lowercased target under ``getRouge1``.

    Args:
        targets: List of reference sentence strings.
        sources: List of source sentence strings.

    Returns:
        The highest score obtained, or 0 when no sentence improves on 0.
    """
    # NOTE(review): targets are joined with "." while candidates are joined
    # with ". "; under whitespace tokenization the former fuses the tokens
    # on either side of a sentence boundary. Left as-is because changing it
    # would alter returned scores -- confirm the intended separator.
    target_lower = ".".join(targets).lower()  # loop-invariant, computed once
    candidate = ""
    max_rouge = 0
    for sent in sources:
        if sent.strip() == "":
            continue
        next_candidate = sent if candidate == "" else candidate + ". " + sent
        score = getRouge1(target_lower, next_candidate.lower())
        if score > max_rouge:
            candidate = next_candidate
            max_rouge = score
    return max_rouge

def getPseudoSummFn(document: str, args, seg):
    """Cut a random sentence window from ``document`` and sample a pseudo summary.

    The document is split with ``sentSplitFn``. A window length between 5
    and 30 sentences (inclusive) is drawn; if the document has fewer
    sentences than that, ``[None, None]`` is returned. Otherwise a
    contiguous window at a random offset becomes the source, and
    ``int(window_length * 0.2)`` sentences sampled from it without
    replacement become the target, in the order ``random.sample`` returns
    them.

    Returns:
        ``[source, target]`` strings built with ``sentJoinFn``, or
        ``[None, None]`` when the document is too short.
    """
    sentences = sentSplitFn(document, args, seg)
    # The window length is drawn before the length check, so one random
    # number is consumed even for documents that end up rejected.
    window_len = random.randint(5, 30)
    summary_len = int(window_len * 0.2)
    if len(sentences) < window_len:
        return [None, None]

    offset = random.randint(0, len(sentences) - window_len)
    window = sentences[offset:offset + window_len]
    picked = random.sample(window, summary_len)
    return [sentJoinFn(window, args), sentJoinFn(picked, args)]

def getPseudoSumm(args):
    """Build pseudo source/target files from the documents in ``args.i``.

    Each input line is handed to ``getPseudoSummFn`` in a pool of
    ``args.t`` worker processes, sharing one ``Segmenter`` created for
    language ``args.l``. Lines for which no pair was produced (source is
    None) are dropped; the remaining sources go to ``args.os`` and targets
    to ``args.ot``. Throughput in input lines per second is printed.
    """
    import time

    documents = readTxt(args.i)
    began = time.time()
    seg = Segmenter(language=args.l)
    jobs = [(doc, args, seg) for doc in documents]
    with Pool(args.t) as pool:
        # Blocks until every job has finished and returns results in order.
        results = pool.starmap(getPseudoSummFn, jobs)
    print("speed: {:.1f}".format(len(documents) / (time.time() - began)))

    # Keep only the pairs whose source was actually produced.
    kept = [(src, tgt) for src, tgt in results if src is not None]
    sources = [src for src, _ in kept]
    targets = [tgt for _, tgt in kept]

    saveTxt(sources, args.os)
    saveTxt(targets, args.ot)

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the function
    # named by -m, passing it the parsed arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', help='lang', type=str, default='zh')
    parser.add_argument('-i', help='input', type=str, default='input.txt')
    parser.add_argument('-d', help='delimitor', type=str, default=None)
    parser.add_argument('-os', help='output source', type=str, default='output.src')
    parser.add_argument('-ot', help='output target', type=str, default='output.tgt')
    parser.add_argument('-m', help='mode', type=str, default="getPseudoSumm")
    parser.add_argument('-t', help='number of thread', type=int, default=20)
    args = parser.parse_args()

    # Look the mode up by name among this module's globals rather than
    # eval()-ing a string built from the command line, which would execute
    # arbitrary expressions.
    mode_fn = globals().get(args.m)
    if not callable(mode_fn):
        parser.error("unknown mode: {}".format(args.m))
    mode_fn(args)

# Example usage (quote the delimiter so the shell does not treat < and > as redirections):
# python3 createData.py -i /home/tiger/mlsum/de/dev.de.doc -os output.src -ot output.tgt -m getPseudoSumm -d '<q>'
