import os
import json
import argparse
import pysbd
import numpy as np
import sys

sys.path.append(".")

from ioFn import *
from logConfig import Log
from multiprocessing import Pool

def sentSplitFn(s: str, args, seg=None):
    """Split ``s`` into sentences.

    If ``args.d`` is set, ``s`` is split on that literal delimiter. Otherwise
    ``seg`` must be given; it is any object with a ``segment(str)`` method
    (e.g. a pysbd Segmenter), and its result is returned unchanged.
    """
    delimiter = args.d
    if delimiter is None:
        # No delimiter configured: fall back to the supplied segmenter.
        assert seg is not None
        return seg.segment(s)
    return s.split(delimiter)

def convertPairsToJsonlFn(documents: list, summaries: list, sent_split=None, lg='en', delimiter=None):
    """Pair documents with summaries and build jsonl-ready records.

    Args:
        documents: raw document strings.
        summaries: raw summary strings aligned with ``documents``. Pairs are
            formed with ``zip``, so any surplus on the longer side is dropped.
        sent_split: how each document is split into sentences.
            ``"full_stop"`` splits on ``'.'``. ``"pysbd"`` uses a pysbd
            Segmenter for language ``lg``. ``"delimiter"`` splits on
            ``delimiter``. Any other value (including None) keeps the document
            as one unsplit string.
        lg: language code passed to pysbd (only used in ``"pysbd"`` mode).
        delimiter: separator for ``"delimiter"`` mode. When None, the
            module-level ``args.d`` set by the ``__main__`` block is used
            (legacy behaviour). That global does not exist in processes that
            did not run ``__main__``, so pass the delimiter explicitly there.

    Returns:
        A list of ``{"id": i, "document": ..., "summary": raw_summary}``
        dicts, where ``i`` is the 0-based position of the pair within this
        call. Summaries are never split.
    """
    seg = None
    if sent_split == "pysbd":
        # Built only when needed, so the other modes do not depend on pysbd
        # accepting `lg`.
        seg = pysbd.Segmenter(language=lg, clean=False)

    datas = []
    for i, (raw_doc, raw_summ) in enumerate(zip(documents, summaries)):
        if sent_split == "full_stop":
            doc_sents = raw_doc.split('.')
        elif sent_split == "pysbd":
            doc_sents = seg.segment(raw_doc)
        elif sent_split == "delimiter":
            if delimiter is None:
                # Legacy fallback to the CLI namespace created in __main__.
                delimiter = args.d
            doc_sents = raw_doc.split(delimiter)
        else:
            doc_sents = raw_doc
        datas.append(
            {
                "id": i,
                "document": doc_sents,
                "summary": raw_summ
            }
        )
    return datas


def jsonl2TxtData(args):
    """Write the summaries of a jsonl file to a plain-text file.

    Reads ``os.path.join(args.i, args.jsonl)`` and passes the list of
    summaries to ``saveTxt``. The output file name is ``args.jsonl`` with its
    trailing ``"jsonl"`` replaced by ``args.sk`` (the dot is kept, e.g.
    ``train.jsonl`` -> ``train.sum``). The file is placed in ``args.o``, or
    next to the input when ``args.o`` is None. The ``document`` field is
    ignored; only summaries are exported.

    A list-valued summary is joined with ``args.d``. With
    ``--add-sent-delimiter`` each summary is re-segmented with pysbd
    (language ``args.l``) and its sentences are joined with ``args.d``.

    Raises:
        NotImplementedError: if ``args.jsonl`` is None (directory input is not
            supported yet).
    """
    if args.jsonl is None:
        raise NotImplementedError("temporarily do not support the case that input is a directory")

    seg = pysbd.Segmenter(language=args.l, clean=False) if args.add_sent_delimiter else None

    input_file = os.path.join(args.i, args.jsonl)
    prefix = args.jsonl[:-len("jsonl")]
    summary_file = prefix + args.sk

    summaries = []
    for data in readJsonl(input_file):
        summ = data['summary']
        # Check the summary's own type. convertPairsToJsonlFn stores documents
        # as sentence lists but summaries as plain strings, and str.join on a
        # string would put the delimiter between every character.
        if isinstance(summ, list):
            summ = args.d.join(summ)
        if seg is not None:
            summ = args.d.join(seg.segment(summ))
        summaries.append(summ)

    output_dir = args.o if args.o is not None else args.i
    summary_out = os.path.join(output_dir, summary_file)
    saveTxt(summaries, summary_out)
    
def dumpSents(args):
    """Write every non-blank document sentence of a jsonl file to ``args.o``.

    Each record's ``document`` is expected to be a list of sentences, as
    produced by ``convertPairsToJsonlFn`` with a ``sent_split`` mode.
    Sentences are stripped and written one per line, and blank ones are
    skipped. A document stored as a single unsplit string is treated as one
    sentence.

    Raises:
        ValueError: if ``args.i`` does not end with ``.jsonl``.
    """
    inputfile = args.i
    if not inputfile.endswith(".jsonl"):
        raise ValueError("input file must be a .jsonl file: {}".format(inputfile))
    datas = readJsonl(inputfile)
    # Explicit UTF-8: the data may be in any language (see the -l option), and
    # the platform default encoding cannot represent all of it.
    with open(args.o, 'w', encoding='utf-8') as fout:
        for data in datas:
            sents = data['document']
            if isinstance(sents, str):
                # Iterating an unsplit string would emit one character per line.
                sents = [sents]
            for sent in sents:
                sent = sent.strip()
                if sent:
                    fout.write(sent + '\n')

def getStats(args):
    """Print average sentence counts for paired document/summary text files.

    Files returned by ``getfiles(args.i)`` that end with ``args.dk`` or
    ``args.sk`` are paired by the prefix left after removing that suffix. For
    each pair, in sorted prefix order, the mean number of sentences per line
    is printed for documents and for summaries. Sentences are split on
    ``args.d``, or with a pysbd Segmenter for ``args.l`` when ``args.d`` is
    None. ``args.o`` is created if given, but nothing is written to it.
    """
    files = getfiles(args.i)
    document_key = args.dk
    summary_key = args.sk
    if args.o is not None:
        os.makedirs(args.o, exist_ok=True)

    # sentSplitFn needs a segmenter when no delimiter is configured.
    seg = pysbd.Segmenter(language=args.l, clean=False) if args.d is None else None

    document_files = [name[:-len(document_key)] for name in files if name.endswith(document_key)]
    summary_files = [name[:-len(summary_key)] for name in files if name.endswith(summary_key)]
    # Sorted so the report order does not depend on set iteration order.
    overlap_prefixs = sorted(set(document_files).intersection(summary_files))

    for prefix in overlap_prefixs:
        documents = readTxt(prefix + document_key)
        summaries = readTxt(prefix + summary_key)
        num_sents_doc = [len(sentSplitFn(doc, args, seg)) for doc in documents]
        num_sents_sum = [len(sentSplitFn(summ, args, seg)) for summ in summaries]
        avg_doc_sent_num = np.average(num_sents_doc)
        avg_sum_sent_num = np.average(num_sents_sum)
        print("data: {} avg_sent_doc: {:.2f} avg_sent_sum: {:.2f}".format(
                prefix, avg_doc_sent_num, avg_sum_sent_num
            )
        )

def runTxt2Jsonl(args):
    """Pool worker: convert one chunk of (document, summary) pairs to records.

    Args:
        args: a ``(datas, sent_split, lg, split_id)`` tuple, packed into one
            argument so it can go through ``Pool.map``. ``datas`` is a list
            of ``(document, summary)`` pairs. ``sent_split`` and ``lg`` are
            forwarded to ``convertPairsToJsonlFn``. ``split_id`` is the chunk
            index and is not used here.

    Returns:
        The records produced by ``convertPairsToJsonlFn`` for this chunk.
    """
    datas, sent_split, lg, _split_id = args
    # `logger` is created in the __main__ block, so only forked workers
    # inherit it. Under the spawn/forkserver start methods the worker
    # re-imports this module without running __main__, so create one here.
    log = globals().get("logger")
    if log is None:
        log = Log.getLogger(__name__)
    log.info('Task pid {0} is running, parent id is {1} with {2} examples'.format(os.getpid(), os.getppid(), len(datas)))
    documents = [pair[0] for pair in datas]
    summaries = [pair[1] for pair in datas]
    ndata = convertPairsToJsonlFn(documents, summaries, sent_split, lg)
    log.info('Task {0} end.'.format(os.getpid()))
    return ndata

def chunks(l, n):
    """Yield successive ``n``-sized slices of the sequence ``l``.

    The last slice holds the remainder and may be shorter than ``n``. An empty
    ``l`` yields nothing.

    Raises:
        ValueError: if ``n`` < 1, raised when iteration starts. Without this
            check, 0 fails inside ``range`` with an opaque message and a
            negative size silently yields nothing.
    """
    if n < 1:
        raise ValueError("chunk size must be a positive integer, got {}".format(n))
    for start in range(0, len(l), n):
        yield l[start:start + n]

def txt2JsonlData(args):
    """Convert paired document/summary text files into jsonl files.

    Files returned by ``getfiles(args.i)`` that end with ``args.dk`` or
    ``args.sk`` are paired by the prefix left after removing that suffix. For
    each pair, in sorted prefix order:

    1. The two files are read and zipped line by line into
       (document, summary) examples.
    2. The examples are split across at most ``args.p`` worker processes
       running ``runTxt2Jsonl``.
    3. The results are written with ``dumpJsonl`` to ``<prefix>jsonl``. The
       file goes next to the inputs when ``args.o`` is None, otherwise into
       ``args.o`` under the prefix's basename.

    Record ids are the 0-based example index within each output file.
    """
    files = getfiles(args.i)
    document_key = args.dk
    summary_key = args.sk
    output_dir = args.o
    sent_split = args.sent_split
    lg = args.l
    n_proc = max(1, args.p)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    document_files = [name[:-len(document_key)] for name in files if name.endswith(document_key)]
    summary_files = [name[:-len(summary_key)] for name in files if name.endswith(summary_key)]
    # Sorted so processing/log order does not depend on set iteration order.
    overlap_prefixs = sorted(set(document_files).intersection(summary_files))

    for prefix in overlap_prefixs:
        documents = readTxt(prefix + document_key)
        summaries = readTxt(prefix + summary_key)

        datas = list(zip(documents, summaries))

        logger.info("Make Label for %s with %d examples !", prefix, len(datas))

        jsonl_datas = []
        if datas:
            # Ceil division: gives at most n_proc chunks and never a chunk size
            # of 0. The old int(len / p) was 0 whenever there were fewer
            # examples than processes.
            chunk_size = (len(datas) + n_proc - 1) // n_proc
            arg_lst = [
                (chunk, sent_split, lg, i)
                for i, chunk in enumerate(chunks(datas, chunk_size))
            ]
            # Context manager so the worker processes are torn down after each prefix.
            with Pool(len(arg_lst)) as pool:
                results = pool.map(runTxt2Jsonl, arg_lst)
            for chunk in results:
                jsonl_datas.extend(chunk)
            # convertPairsToJsonlFn numbers records from 0 within each chunk;
            # renumber so ids are unique across the whole output file.
            for idx, record in enumerate(jsonl_datas):
                record["id"] = idx

        if output_dir is None:
            # write output to the input directory
            output_file = prefix + "jsonl"
        else:
            output_name = os.path.basename(prefix)
            output_file = os.path.join(output_dir, output_name + "jsonl")
        print("write the data in {} to {}".format(prefix, output_file))
        dumpJsonl(jsonl_datas, output_file)

def loadLabelFromRef(args):
    """Copy the ``label`` field from a reference jsonl onto an input jsonl.

    Records are matched by position: the i-th record of ``args.i`` receives
    the ``label`` of the i-th record of ``args.r``. The updated input records
    are written to ``args.o`` with ``dumpJsonl``.

    Raises:
        ValueError: if ``args.r`` is None or the two files contain a
            different number of records.
    """
    # Explicit checks rather than assert: asserts are stripped under -O, and a
    # length mismatch would then go undetected.
    if args.r is None:
        raise ValueError("a file that contains oracle reference must be provided")

    input_datas = readJsonl(args.i)
    reference_datas = readJsonl(args.r)

    if len(input_datas) != len(reference_datas):
        raise ValueError(
            "input has {} records but reference has {}".format(len(input_datas), len(reference_datas))
        )

    for input_item, ref_item in zip(input_datas, reference_datas):
        input_item['label'] = ref_item['label']

    dumpJsonl(input_datas, args.o)

def oracleLabel2Txt(args):
    """Turn oracle sentence labels into plain-text oracle summaries.

    For each record of the jsonl file ``args.i``, the ``document`` sentences
    at the indices listed in ``label`` (taken in label order) are joined with
    ``args.d``. The resulting strings are saved to ``args.o`` via ``saveTxt``.
    """
    txt_datas = []
    for record in readJsonl(args.i):
        sents = record['document']
        picked = record['label']
        txt_datas.append(args.d.join([sents[idx] for idx in picked]))
    saveTxt(txt_datas, args.o)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # general args
    # NOTE(review): the default mode "makeJsonlData" is not defined in this
    # file; presumably it comes from `from ioFn import *` — confirm.
    parser.add_argument("-m", help="mode", type=str, default="makeJsonlData")

    # args for makeJsonlData
    parser.add_argument("-i", help="input idrectory or file", type=str)
    parser.add_argument("--jsonl", help="the name of input jsonl. If not given (by default), all *.jsonl in the input directory will be treated as input files", default=None)
    parser.add_argument("-o", help="output directory / file", type=str, default=None)
    parser.add_argument("-p", help="process number", type=int, default=30)
    parser.add_argument(
        "-l", help="language", type=str, default='en'
    )
    parser.add_argument("-dk", help="document key", type=str, default="doc")
    parser.add_argument("-sk", help="summary key", type=str, default="sum")
    parser.add_argument(
        "--sent-split",
        help="the method of spliting sentences. If None (by default), do not split sentence",
        type=str, default=None,
        choices=["pysbd", "full_stop", "delimiter"]
    )

    # args for jsonl2TxtData
    parser.add_argument(
        '-d', type=str, help="sent delimiter", default="<q>"
    )
    parser.add_argument(
        '--add-sent-delimiter', action="store_true",
        help="If True, documents/summaries will be splited into sentences, "
            "and args.d will be used to join these sentences"
    )

    # args for loadLabelFromRef
    parser.add_argument(
        '-r', type=str, help="file that contains oracle label", default=None
    )

    # `args` and `logger` stay module globals on purpose: convertPairsToJsonlFn
    # (delimiter mode) reads `args.d`, and runTxt2Jsonl reads `logger`.
    args = parser.parse_args()

    logger = Log.getLogger(sys.argv[0])
    logger.info(args)

    # Look the mode up by name instead of eval()-ing a string built from a
    # command-line argument. Any callable in the module namespace (including
    # names brought in by `from ioFn import *`) can still be selected.
    mode_fn = globals().get(args.m)
    if not callable(mode_fn):
        parser.error("unknown mode: {}".format(args.m))
    mode_fn(args)
