import re
import argparse
from multiprocessing import Pool
from nltk.tokenize import sent_tokenize
from rouge import Rouge

# Scorer instance from the third-party `rouge` package, created at import time.
# NOTE(review): nothing in the visible code references it (getOracle scores
# with getRouge1 instead) — confirm whether it is still needed.
rouge_scorer = Rouge()

def sentSplitFn(s: str, args):
    """Split a document string into sentence pieces.

    Args:
        s: The document text.
        args: Namespace-like object; only ``args.d`` is read.  When it is
            not ``None`` the text is cut on that literal delimiter,
            otherwise on every ``.``, ``?`` or ``!`` character.

    Returns:
        The list of pieces exactly as produced by the split: pieces are not
        stripped and empty strings are kept.
    """
    delimiter = args.d
    if delimiter is None:
        # No explicit delimiter: fall back to sentence-final punctuation.
        return re.split(r"[.?!]", s)
    return s.split(delimiter)

def readTxt(fname):
    """Load a text file as a list of stripped lines.

    The file is read in binary mode and every line is decoded as UTF-8 and
    stripped of surrounding whitespace (blank lines become empty strings).
    The number of lines read is printed.

    Args:
        fname: Path of the file to read.

    Returns:
        A list with one string per line of the file.
    """
    with open(fname, 'rb') as fin:
        data = [raw_line.decode('utf-8').strip() for raw_line in fin]
    print("Reading {} example from {}".format(len(data), fname))
    return data

def saveTxt(data, fname):
    """Write one item per line to ``fname`` as UTF-8 text.

    Each item is rendered with ``'{}'.format(item)``, so non-string items
    (for example lists) are written in their ``str()`` form.  The number of
    items written is printed.

    Args:
        data: Sized collection of items to write.
        fname: Path of the output file; it is overwritten if it exists.
    """
    # Encode explicitly as UTF-8 so the output matches what readTxt decodes,
    # independent of the platform's default text encoding.
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def getRouge1(target, hypo):
    """Compute a smoothed unigram-overlap F-score between two strings.

    Both strings are tokenized on whitespace.  Every token occurrence in
    ``hypo`` that also appears anywhere in ``target`` counts as an overlap
    (counts are not clipped by the number of occurrences in ``target``).
    Precision and recall divide that count by the hypothesis and target
    lengths respectively; ``1e-3`` is added to each denominator so empty
    inputs do not divide by zero.

    Args:
        target: Reference text.
        hypo: Hypothesis text.

    Returns:
        The F-score scaled by 100, as a float.
    """
    target_tokens = target.split()
    hypo_tokens = hypo.split()

    # Set lookup keeps the overlap count linear in the token counts instead
    # of scanning the target token list once per hypothesis token.
    target_vocab = set(target_tokens)
    n_overlap = sum(1 for token in hypo_tokens if token in target_vocab)

    p = n_overlap / (len(hypo_tokens) + 1e-3)
    r = n_overlap / (len(target_tokens) + 1e-3)
    return (2 * p * r / (p + r + 1e-3)) * 100


def getOracle(targets: list, sources: list):
    """Greedily score how well the source sentences can cover the targets.

    The target sentences are joined with ``"."`` into one reference string.
    Source sentences are visited in order; each non-blank sentence is
    tentatively appended (joined with ``". "``) to the sentences selected so
    far, and the addition is kept only if it strictly raises the
    lower-cased ``getRouge1`` score against the reference.

    Args:
        targets: Sentences forming the reference.
        sources: Candidate sentences, in document order.

    Returns:
        The best score reached; ``0`` if no sentence improved on zero.
    """
    reference = ".".join(targets).lower()
    best_score = 0
    selected = ""
    for sentence in sources:
        if not sentence.strip():
            # Blank pieces carry no tokens; skip them.
            continue
        trial = sentence if selected == "" else selected + ". " + sentence
        trial_score = getRouge1(reference, trial.lower())
        if trial_score > best_score:
            selected, best_score = trial, trial_score
    return best_score

def getPseudoSummFn(document: str, args):
    """Turn one document into a (source, target) pair if it scores in range.

    The document is split with ``sentSplitFn``; the first ``args.M`` pieces
    become the target and the remaining pieces the source.  The pair is
    kept only when the ``getOracle`` score lies within the inclusive range
    ``[args.b[0], args.b[1]]``.

    Args:
        document: Raw document text.
        args: Namespace-like object providing ``M``, ``b`` and whatever
            ``sentSplitFn`` reads.

    Returns:
        ``(source, target, score)`` when the score is in range, otherwise
        ``(None, None, score)``.
    """
    sentences = sentSplitFn(document, args)
    lead = sentences[:args.M]
    remainder = sentences[args.M:]
    score = getOracle(lead, remainder)
    if not (args.b[0] <= score <= args.b[1]):
        return None, None, score
    return remainder, lead, score


def getPseudoSumm(args):
    """Build pseudo-summarization pairs for every document in ``args.i``.

    Documents are read with ``readTxt`` and processed in a pool of
    ``args.t`` worker processes via ``getPseudoSummFn``.  The throughput
    (documents per second) is printed, then the sources and targets of the
    documents that were kept are written to ``args.os`` and ``args.ot``
    with ``saveTxt``.

    NOTE(review): each kept source/target is a list of sentence pieces, so
    saveTxt writes the list's ``str()`` form on each line — confirm that is
    the intended output format.
    """
    import time

    datas = readTxt(args.i)
    began = time.time()
    with Pool(args.t) as pool:
        pending = pool.starmap_async(
            getPseudoSummFn, [(data, args) for data in datas]
        )
        results = pending.get()
    print("speed: {:.1f}".format(len(datas) / (time.time() - began)))

    # A None source marks a document whose score fell outside the bin.
    kept = [item for item in results if item[0] is not None]
    saveTxt([item[0] for item in kept], args.os)
    saveTxt([item[1] for item in kept], args.ot)

def buildArgParser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``-l``, ``-i``,
        ``-d``, ``-os``, ``-ot``, ``-M``, ``-m``, ``-t`` and ``-b``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', help='lang', type=str, default='zh')
    parser.add_argument('-i', help='input', type=str, default='input.txt')
    parser.add_argument('-d', help='delimitor', type=str, default=None)
    # NOTE(review): both output defaults equal the input default, so running
    # with no -os/-ot overwrites input.txt — confirm the intended defaults.
    parser.add_argument('-os', help='output source', type=str, default='input.txt')
    parser.add_argument('-ot', help='output target', type=str, default='input.txt')
    parser.add_argument('-M', help='select the first M sentences as references', type=int, default=3)
    parser.add_argument('-m', help='mode', type=str, default="getPseudoSumm")
    parser.add_argument('-t', help='number of thread', type=int, default=20)
    # Two numeric bounds (e.g. `-b 40 60`).  `type=list` would split the raw
    # argument string into single characters, which cannot be compared with
    # the numeric score.
    parser.add_argument('-b', help='rouge bin', type=float, nargs=2,
                        default=[40, 60], metavar=('LOW', 'HIGH'))
    return parser


if __name__ == "__main__":
    parser = buildArgParser()
    args = parser.parse_args()

    # Resolve the mode by name among module-level callables rather than
    # eval()-ing the command-line string, which would execute arbitrary
    # expressions passed via -m.
    mode_fn = globals().get(args.m)
    if not callable(mode_fn):
        parser.error("unknown mode: {}".format(args.m))
    mode_fn(args)

# Example usage (quote the delimiter so the shell does not treat < and > as redirections):
# python3 createData.py -i /home/tiger/mlsum/de/dev.de.doc -os output.src -ot output.tgt -m getPseudoSumm -d '<q>'