'''
Date: 2021-05-28 15:59:27
LastEditors: Wu Xianze (wuxianze.0@bytedance.com)
LastEditTime: 2021-05-28 16:05:09
'''
import re
import os
import argparse
import json
from nltk.tokenize import sent_tokenize

def readTxt(fname):
    """
    Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    Blank lines are kept as empty strings, so the result has one entry per line.
    """
    with open(fname, 'rb') as handle:
        lines = [raw_line.decode('utf-8').strip() for raw_line in handle]
    print("Reading {} example from {}".format(len(lines), fname))
    return lines

def saveTxt(data, fname):
    """
    Write each item of *data* to *fname* on its own line.

    Items are formatted with ``str.format`` (so non-string items such as ints
    are allowed). The file is written as UTF-8 to match readTxt, which decodes
    UTF-8. Previously the platform default encoding was used, which corrupts
    or rejects non-ASCII text on non-UTF-8 locales.
    """
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def readJson(fname):
    """
    Read a JSON-lines file (one JSON value per line) into a list.

    The file is decoded as UTF-8, consistent with readTxt. Whitespace-only
    lines (e.g. a stray blank line at the end of the file) are skipped
    instead of making json.loads raise.
    """
    data = []
    with open(fname, encoding='utf-8') as fin:
        for line in fin:
            if not line.strip():
                continue
            data.append(json.loads(line))
    print("Reading {} example from {}".format(len(data), fname))
    return data

def normalize(s: str) -> str:
    """
    Canonicalize text for lookup: drop backslashes, lowercase, and collapse
    every run of whitespace into a single space (trimming both ends).
    """
    without_backslashes = s.replace("\\", "")
    lowered = without_backslashes.lower()
    return " ".join(lowered.split())

def getMatchbyContent(data, info, lang=None):
    """
    Look up the translation of each line of *data* by its normalized source text.

    Args:
        data: source lines to translate (e.g. the lines written by ``split``).
        info: translation records; each is read via the keys 'src_text' and
            'trg_text', and additionally 'src_lang' when *lang* is given.
        lang: if not None, only records whose 'src_lang' equals *lang* are
            used, plus records whose 'src_text' starts with 'EMPTY'.

    Returns:
        (result, miss): ``result[i]`` is the translation for ``data[i]``, or
        '' when no record matched. Placeholder lines ('EMPTY', matched
        case-insensitively) map back to 'EMPTY'. ``miss`` holds one '' per
        unmatched line.
    """
    info_dict = {}
    for e in info:
        src_text = normalize(e['src_text'])
        if lang is not None and e['src_lang'] != lang and not e['src_text'].startswith('EMPTY'):
            continue
        # Placeholder parts keep the 'EMPTY' sentinel rather than whatever the
        # translator produced for it, so merge() can drop them. The previous
        # value None was neither counted as a miss nor filtered by merge(), and
        # ended up written to disk as the literal text "None".
        info_dict[src_text] = 'EMPTY' if src_text == "empty" else e['trg_text']

    result = [info_dict.get(normalize(d), '') for d in data]
    miss = [r for r in result if r == '']
    print("Match {} examples and Miss {} examples".format(len(result) - len(miss), len(miss)))
    return result, miss

def checkTrans(args):
    """
    Align translations with reference lines and write them out.

    Reads translation records from args.i (JSON lines) and reference lines
    from args.r, matches them by normalized source text, and writes one line
    per reference line to args.o. Unmatched lines are written as empty lines.
    """
    translations = readJson(args.i)
    reference_lines = readTxt(args.r)
    matched, _unmatched = getMatchbyContent(reference_lines, translations)
    saveTxt(matched, args.o)

def checkTrans2(args):
    """
    Like checkTrans, but unmatched lines are written as 'EMPTY' instead of ''.

    Reads translation records from args.i (JSON lines) and reference lines
    from args.r, matches them by normalized source text, and writes one line
    per reference line to args.o.
    """
    translations = readJson(args.i)
    reference_lines = readTxt(args.r)
    matched, _unmatched = getMatchbyContent(reference_lines, translations)
    filled = ["EMPTY" if text == "" else text for text in matched]
    saveTxt(filled, args.o)

def readGeneralCSV(fname, symbol=','):
    """
    Read a delimited text file whose first row is a header into a list of dicts.

    Rows whose field count differs from the header's are skipped; their length
    and the header length are printed.

    Args:
        fname: path of the file.
        symbol: field delimiter passed to ``csv.reader``.

    Returns:
        (data, keys): one dict per valid row mapping header -> value, and the
        header row itself. An empty file yields ([], []); previously this
        raised UnboundLocalError because ``keys`` was never assigned.
    """
    import csv
    data = []
    # 'utf-8-sig' decodes plain UTF-8 and also strips a leading BOM (common in
    # spreadsheet exports), which would otherwise be glued onto the first key.
    with open(fname, 'r', newline='', encoding='utf-8-sig') as fin:
        reader = csv.reader(fin, delimiter=symbol)
        keys = next(reader, [])
        for parts in reader:
            if len(parts) != len(keys):
                print(len(parts), len(keys))
                continue
            data.append(dict(zip(keys, parts)))
    print("Reading {} example from {}".format(len(data), fname))
    return data, keys


def split(args):
    """
    Split every line of args.i into args.n parts of consecutive sentences.

    Each input line is sentence-tokenized with nltk's ``sent_tokenize`` and its
    sentences are divided, in order, into args.n contiguous chunks whose sizes
    differ by at most one. A chunk with no sentences is written as 'EMPTY', so
    every input line produces exactly args.n output lines in args.o.

    Raises:
        ValueError: if args.n is smaller than 1.
    """
    inputfile = args.i
    outputfile = args.o
    partnumber = args.n
    if partnumber < 1:
        raise ValueError("-n must be a positive integer, got {}".format(partnumber))

    datas = readTxt(inputfile)
    outputs = []
    for line in datas:
        sents = sent_tokenize(line)
        # Balanced partition: the first `extra` chunks take one extra sentence.
        # The old chunk size (n // k + 1) overfilled the early chunks and left
        # trailing chunks empty whenever n was a multiple of k.
        base, extra = divmod(len(sents), partnumber)
        start = 0
        for part_i in range(partnumber):
            end = start + base + (1 if part_i < extra else 0)
            partsentences = sents[start:end]
            start = end
            outputs.append(' '.join(partsentences) if partsentences else 'EMPTY')

    print("split {} instances into {} sentences".format(len(datas), len(outputs)))
    saveTxt(outputs, outputfile)

def splitText(args):
    """
    Split the lines of args.i into args.n contiguous files.

    The files are named '<args.i>.part0' ... '<args.i>.part{n-1}', and the
    line order is preserved across them. Part sizes differ by at most one
    line. The old size (len // n + 1) left the last file empty whenever the
    line count was a multiple of n.

    Raises:
        ValueError: if args.n is smaller than 1.
    """
    inputfile = args.i
    partnumber = args.n
    if partnumber < 1:
        raise ValueError("-n must be a positive integer, got {}".format(partnumber))

    datas = readTxt(inputfile)
    base, extra = divmod(len(datas), partnumber)
    start = 0
    for parti in range(partnumber):
        end = start + base + (1 if parti < extra else 0)
        saveTxt(datas[start:end], inputfile + '.part{}'.format(parti))
        start = end

def merge(args):
    """
    Reassemble instances from parts: every args.n consecutive lines of args.i form one instance.

    'EMPTY' placeholder parts are dropped and the remaining parts are joined
    with spaces. With --strict, an instance that does not have all args.n
    parts non-'EMPTY' is written as an empty line and counted as missed.
    Trailing lines that do not fill a complete group are ignored, with a
    printed warning (previously they were dropped silently).

    Raises:
        ValueError: if args.n is smaller than 1.
    """
    inputfile = args.i
    outputfile = args.o
    partnumber = args.n
    if partnumber < 1:
        raise ValueError("-n must be a positive integer, got {}".format(partnumber))

    datas = readTxt(inputfile)
    num_groups, leftover = divmod(len(datas), partnumber)
    if leftover:
        print("Warning: ignoring {} trailing lines that do not form a complete group of {}".format(
            leftover, partnumber
        ))

    outputs = []
    missed_num = 0
    for i in range(num_groups):
        sents = datas[i * partnumber:(i + 1) * partnumber]
        partsents = [sent for sent in sents if sent != 'EMPTY']
        if args.strict and len(partsents) != partnumber:
            partsents = []
            missed_num += 1
        outputs.append(' '.join(partsents).strip())

    print("merge {} sents into {} instances, empty instances {}".format(
        len(datas), len(outputs), missed_num
    ))
    saveTxt(outputs, outputfile)


def csv2txt(args):
    """
    Export up to the first four columns of a delimited file as separate text files.

    Reads args.i with delimiter args.split. For each of the first four header
    columns (fewer if the file has fewer; previously this raised IndexError),
    appends that column's values, one per line, to '<column>.txt' in the same
    directory as args.i. Each output file is opened once rather than once per
    row.
    """
    datas, keys = readGeneralCSV(args.i, args.split)
    if not datas:
        return
    inputdir = os.path.dirname(args.i)

    for key in keys[:4]:
        # Append mode is kept from the original: existing files are extended, not overwritten.
        with open(os.path.join(inputdir, "{}.txt".format(key)), 'a', encoding='utf-8') as fout:
            fout.writelines(data[key] + '\n' for data in datas)

if __name__ == "__main__":
    # Sub-commands selectable with -m. Dispatching through an explicit table
    # (instead of eval() on the command-line string) means -m can only run one
    # of these functions, and argparse rejects anything else with a usage error.
    modes = {
        'split': split,
        'splitText': splitText,
        'checkTrans': checkTrans,
        'checkTrans2': checkTrans2,
        'merge': merge,
        'csv2txt': csv2txt,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', type=str, default="data.txt", help='original file')
    parser.add_argument('-o', type=str, default="cleardaata.txt", help='output file')
    parser.add_argument('-n', type=int, default=5, help="the number of parts per instance")
    parser.add_argument('-l', type=str, default='en', help='language')
    parser.add_argument('-m', type=str, default='split', choices=sorted(modes), help='mode')
    parser.add_argument('-r', type=str, default="reference.txt", help='reference line')
    parser.add_argument('--split', type=str, default='\t')
    parser.add_argument('--strict', action="store_true", help='strict mode')
    args = parser.parse_args()

    modes[args.m](args)

    # # usage
    # # Before translation: split each long text into several parts
    # python3 splitSent.py -i ${RAW_TEXT} -o ${SPLITED_SENTS} -n 5 -l de

    # # After translation
    # python3 splitSent.py -m checkTrans2 \
    #     -i ${TRANSLATED_JSON} -r ${SPLITED_SENTS} \
    #     -o ${TRANSLATED_SENTS} \
    #     -l ${lang}
    # python3 splitSent.py -m merge -i ${TRANSLATED_SENTS} -o ${MERGED_SENTS} -n 5 -l de

    # python3 splitSent.py -m checkTrans \
    #     -i ${TRANSLATED_JSON} -r ${SPLITED_SENTS} \
    #     -o ${TRANSLATED_SENTS} \
    #     -l ${lang}