import os
import sys
import argparse
import json

from multiprocessing import Pool

sys.path.append('.')
sys.path.append('..')
from logConfig import Log
from calRouge import str2char
from getlabel import calLabel

def getfiles(path):
    """Return the list of candidate input files for *path*.

    If *path* is a directory, every entry in it (files and sub-directories
    alike) is joined onto *path* and the result is returned in sorted order.
    Any other path is returned unchanged as a one-element list.
    """
    if not os.path.isdir(path):
        return [path]
    return sorted(os.path.join(path, entry) for entry in os.listdir(path))

def chunks(l, n):
    """Lazily split the sequence *l* into consecutive slices of length *n*.

    The final slice holds whatever remains and may be shorter than *n*.
    *n* must be non-zero: ``range`` rejects a zero step with ValueError
    when the generator is first advanced.
    """
    offsets = range(0, len(l), n)
    yield from (l[offset:offset + n] for offset in offsets)

def makeLabel(data, dockey, sumkey, language):
    """Attach a ``'label'`` field to every example in *data*.

    Each example is expected to be a dict whose *dockey* entry is a list of
    document sentences. Its *sumkey* entry is either a single summary string
    or a list of summary sentences. For ``language == 'zh'`` every sentence is
    first passed through ``str2char`` (presumably character-level tokenisation
    for ROUGE; see calRouge). The label is whatever ``calLabel`` returns for
    the document sentences and the newline-joined summary.

    The example dicts are modified in place; a new list holding the same
    objects is returned.
    """
    labelled = []
    for example in data:
        summary = example[sumkey]
        if isinstance(summary, str):
            # A bare string is treated as a one-sentence summary.
            summary = [summary]
        doc_sents = example[dockey]
        if language == 'zh':
            doc_sents = [str2char(sent, language) for sent in doc_sents]
            summary = [str2char(sent, language) for sent in summary]
        example['label'] = calLabel(doc_sents, "\n".join(summary))
        labelled.append(example)
    return labelled

def run(args):
    """Pool worker: label one chunk of examples and return it.

    Args:
        args: a ``(data, dockey, sumkey, language, pool_id)`` tuple, packed
            into one argument so it can be used with ``Pool.map``. ``pool_id``
            is the chunk index supplied by the caller; it is not used here.

    Returns:
        The list produced by ``makeLabel`` for ``data``.
    """
    import logging

    data, dockey, sumkey, language, pool_id = args
    # ``logger`` is only created inside the script's ``__main__`` guard.
    # Workers inherit it under the "fork" start method. Under "spawn" (the
    # default on Windows and macOS) the re-imported module has no such
    # global and this would raise NameError. In that case fall back to a
    # stdlib logger; messages then go wherever the logging module is
    # configured in the worker.
    log = globals().get('logger')
    if log is None:
        log = logging.getLogger(__name__)
    log.info('Task pid {0} is running, parent id is {1} with {2} examples'.format(os.getpid(), os.getppid(), len(data)))
    ndata = makeLabel(data, dockey, sumkey, language)
    log.info('Task {0} end.'.format(os.getpid()))
    return ndata

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', type=str, default='.', help='dataset dir')
    parser.add_argument('-p', type=int, default=8, help='process number')
    parser.add_argument('-d', type=str, default='document', help='document key name')
    parser.add_argument('-s', type=str, default='summary', help='summary key name')
    parser.add_argument('-l', type=str, default='zh', help='language')
    parser.add_argument('-g', type=str, default='label.log', help='log file')

    args = parser.parse_args()
    logger = Log.getLogger(sys.argv[0], args.g)
    logger.info(args)

    dockey, sumkey, lang = args.d, args.s, args.l

    inputfiles = getfiles(args.i)
    outputfiles = [name.replace(name.split('.')[-1], "label."+ name.split('.')[-1]) for name in inputfiles]

    for infile, outfile in zip(inputfiles, outputfiles):
        data = []
        filename = infile.split('/')[-1]
        if not ((filename.startswith("train") or filename.startswith("dev")) and filename.endswith("jsonl")):
            continue
        with open(infile) as f:
            for line in f:
                data.append(json.loads(line))

        logger.info("Make Label for %s with %d examples !", infile, len(data))
        data_chunks = list(chunks(data, int(len(data) / args.p)))

        n_pool = len(data_chunks)
        arg_lst = []
        for i in range(n_pool):
            arg_lst.append((data_chunks[i], dockey, sumkey, lang, i))
        pool = Pool(n_pool)
        results = pool.map(run, arg_lst)

        totals = []
        for i, chunk in enumerate(results):
            totals.extend(chunk)

        with open(outfile, 'w') as fout:
            if 'id' in totals[0].keys():
                logger.info("Sorting results by id !")
                sorted_totals = sorted(totals, key=lambda x: x['id'])
            for d in sorted_totals:
                fout.write(json.dumps(d, ensure_ascii=False) + "\n")

        logger.info("Save Label to %s with %d examples !", outfile, len(data))


    # python3 makelabel_para.py -l en -p 60 -i /home/tiger/WikiLingua