import re
import argparse
import random
# from scipy.stats import poisson
from numpy.random import geometric
from pysbd import Segmenter
from nltk.tokenize import sent_tokenize
from sacremoses import MosesTokenizer

# Two-letter language code -> language name passed to nltk's sent_tokenize.
# 'en' is included because sentenceShuffle and maskSpan default to lang='en'.
nltk_langtag = {
    'en': 'english',
    'de': 'german',
    'es': 'spanish',
    'fr': 'french',
    'it': 'italian',
    'pt': 'portuguese',
    'nl': 'dutch',
}

# Placeholder token that replaces each masked span in maskSpan.
MASKTOKEN = '[mask]'

class WrappedTokenizer:
    """Language-aware tokenizer.

    Chinese, Japanese and Korean text is split so that every character in
    the language's code-point range becomes its own token; every other
    language is split on whitespace.
    """

    # Languages that are tokenized per character rather than on whitespace.
    CHAR_LEVEL_LANGS = ("zh", "ja", "ko")

    def __init__(self, lang):
        """Remember ``lang`` and build a Moses tokenizer for non-CJK languages."""
        self.lang = lang
        if lang not in self.CHAR_LEVEL_LANGS:
            # Kept as an attribute; tokenize() currently splits on whitespace
            # and does not call it.
            self.mt = MosesTokenizer(lang=lang)

    def splitChars(self, sent, lang):
        """Split ``sent`` so each character in the range for ``lang`` is a token.

        Text between matched characters is kept as a single token. Every
        token is stripped and lower-cased; tokens that are empty after
        stripping (runs of spaces, tabs, ...) are dropped.

        Args:
            sent: text to split.
            lang: 'zh', 'ja' or 'ko'; any other value uses a broad CJK range.

        Returns:
            list of non-empty token strings.
        """
        if lang == 'zh':
            pattern = u"([\u4e00-\u9fa5])"
        elif lang == 'ja':
            # NOTE(review): this range stops at U+4E00, so ideographs above it
            # are not split individually for 'ja' -- confirm this is intended.
            pattern = u"([\u0800-\u4e00])"
        elif lang == 'ko':
            pattern = u"([\uac00-\ud7ff])"
        else:   # Chinese, Japanese and Korean non-symbol characters
            pattern = u"([\u2e80-\u9fff])"
        # Strip BEFORE filtering: a whitespace-only piece such as "  " or "\t"
        # would otherwise survive the filter and become an empty token.
        stripped = (piece.strip().lower() for piece in re.split(pattern, sent))
        return [piece for piece in stripped if piece]

    def tokenize(self, sentence: str):
        """Tokenize ``sentence`` according to the tokenizer's language.

        Return:
            tokens: list
        """
        if self.lang in self.CHAR_LEVEL_LANGS:
            return self.splitChars(sentence, self.lang)
        return sentence.split()

def readTxt(fname):
    """Read ``fname`` and return its lines as a list of strings.

    The file is opened in binary mode; each line is decoded as UTF-8 and
    stripped of leading and trailing whitespace. The number of lines read
    is printed.
    """
    with open(fname, 'rb') as fin:
        data = [raw_line.decode('utf-8').strip() for raw_line in fin]
    print("Reading {} example from {}".format(len(data), fname))
    return data

def saveTxt(data, fname):
    """Write every item of ``data`` to ``fname``, one per line.

    Items are formatted with ``str.format``, so non-string items are
    written using their string form. The number of items is printed.

    Args:
        data: sized iterable of items to write.
        fname: path of the output file (overwritten if it exists).
    """
    # Encode explicitly as UTF-8 so the output does not depend on the
    # platform's default encoding and matches what readTxt decodes.
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def splitChineseSentence(sentence):
    """Split Chinese text into sentences, keeping terminators attached.

    The text is split on sentence-ending punctuation (optionally followed
    by up to two closing quotes) and on a full-width colon that introduces
    a quotation or ends the text. Each terminator is appended to the
    sentence before it; a terminator with nothing before it becomes an
    item of its own. Empty pieces are discarded.
    """
    boundary = re.compile('([﹒﹔﹖﹗．；。！？]["’”」』]{0,2}|：(?=["‘“「『]{1,2}|$))')
    sentences = []
    for piece in boundary.split(sentence):
        if not piece:
            # The split yields empty strings around adjacent terminators.
            continue
        if sentences and boundary.match(piece):
            sentences[-1] += piece
        else:
            sentences.append(piece)
    return sentences

def sentenceShuffle(paragraph: str, lang='en', return_str=False):
    """Split ``paragraph`` into sentences and return them in random order.

    Args:
        paragraph: text to split and shuffle.
        lang: language code. 'zh' uses splitChineseSentence; any other code
            is mapped to an nltk language name via ``nltk_langtag``.
        return_str: when true, return the shuffled sentences joined with a
            single space; otherwise return them as a list.

    Returns:
        list of sentences, or one string when ``return_str`` is true.

    Raises:
        KeyError: if ``lang`` is neither 'zh', 'en', nor a key of
            ``nltk_langtag``.
    """
    if lang == 'zh':
        sents = splitChineseSentence(paragraph)
    else:
        langtag = nltk_langtag.get(lang)
        if langtag is None:
            # 'en' is this function's default, so it must always resolve;
            # every other unknown code keeps raising KeyError.
            if lang != 'en':
                raise KeyError(lang)
            langtag = 'english'
        sents = sent_tokenize(paragraph, langtag)
    # Sampling all items without replacement yields a shuffled copy.
    samples = random.sample(sents, len(sents))
    return " ".join(samples) if return_str else samples

def maskSpan(paragraph: str, p=0.2, masked_ratio=0.15, lang='en', return_str=True):
    """Replace random token spans of ``paragraph`` with MASKTOKEN.

    refer to SpanBERT https://arxiv.org/pdf/1907.10529.pdf

    The text is tokenized with the module-level ``tokenizer``. Span lengths
    are drawn from a geometric distribution and capped at 10 tokens and at
    the remaining masking budget. Each chosen span is collapsed into a
    single MASKTOKEN. Masking stops once the budget of
    ``int(len(tokens) * masked_ratio)`` tokens has been used up.

    Args:
        paragraph: text to mask.
        p: success probability of the geometric span-length distribution.
        masked_ratio: fraction of the original tokens to mask.
        lang: language code; only affects how tokens are re-joined
            ('zh' joins without spaces, anything else with a space).
        return_str: when true return a string, otherwise the token list.

    Returns:
        the masked text as a string, or the list of tokens.
    """
    max_span_len = 10
    max_draws = 6  # one initial draw plus up to five retries

    tokens = tokenizer.tokenize(paragraph)
    remainedMaskToken = int(len(tokens) * masked_ratio)
    while remainedMaskToken > 0:
        sentLen = len(tokens)
        for _ in range(max_draws):
            # Every draw, including retries, is capped by the span limit and
            # the remaining budget. The retry path previously used max(),
            # which forced retried spans to be at least 10 tokens long.
            spanLength = min(int(geometric(p)), max_span_len, remainedMaskToken)
            start = random.randrange(0, max(0, sentLen - spanLength) + 1)
            if MASKTOKEN not in tokens[start:start + spanLength]:
                break
        # If every draw overlapped an existing mask, the last draw is used
        # anyway so the loop always makes progress.
        tokens = tokens[:start] + [MASKTOKEN] + tokens[start + spanLength:]
        remainedMaskToken -= spanLength

    if not return_str:
        return tokens
    separator = "" if lang in ['zh'] else " "
    return separator.join(tokens)

def noiseV1(args):
    """
    sentence shuffling + span-mask like span-bert

    Reads one paragraph per line from ``args.i``, shuffles the sentences of
    each paragraph, masks spans in every sentence and writes one noised
    paragraph per line to ``args.o``.

    Args:
        args: namespace with ``i`` (input path), ``o`` (output path),
            ``l`` (language code), ``p`` (geometric distribution parameter)
            and optionally ``r`` (ratio of masked tokens per sentence).
    """
    # Forward the masked-token ratio, which was previously never passed on;
    # fall back to maskSpan's default when the namespace does not carry it.
    masked_ratio = getattr(args, 'r', 0.15)
    # Chinese sentences are concatenated directly, others with a space.
    joiner = "" if args.l in ['zh'] else " "

    results = []
    for data in readTxt(args.i):
        masked = [
            maskSpan(sent, p=args.p, masked_ratio=masked_ratio, lang=args.l)
            for sent in sentenceShuffle(data, lang=args.l)
        ]
        results.append(joiner.join(masked))
    saveTxt(results, args.o)

def averageLength(args):
    """Print the average number of tokens per line of the file ``args.i``.

    Lines are tokenized with the module-level ``tokenizer``. An empty file
    prints 0.0 instead of dividing by zero.

    Args:
        args: namespace with ``i`` (input path).
    """
    datas = readTxt(args.i)
    if not datas:
        # No lines to average over; avoid ZeroDivisionError.
        print(0.0)
        return
    length = sum(len(tokenizer.tokenize(data)) for data in datas)
    avg_length = length / len(datas)
    print(avg_length)

if __name__ == "__main__":
    # Explicit dispatch table: the mode used to be run through eval(), which
    # executes arbitrary command-line text and fails obscurely when -m is
    # missing.
    modes = {
        'noiseV1': noiseV1,
        'averageLength': averageLength,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('-l', help='lang', type=str, default='zh')
    parser.add_argument('-i', help='input', type=str, default='zh')
    parser.add_argument('-o', help='output', type=str, default='zh')
    parser.add_argument('-p', help='p for geometric distribution', type=float, default=0.2)
    parser.add_argument('-r', help='the ratio of masked token per sentence', type=float, default=0.15)
    parser.add_argument('-m', help='mode', type=str, required=True, choices=sorted(modes))
    args = parser.parse_args()

    # maskSpan and averageLength read this module-level name.
    tokenizer = WrappedTokenizer(args.l)

    modes[args.m](args)

    # paragraph = "一个类似美国大片《幸福终点站》的案例令中国驻圣彼得堡领事官胡滨印象深刻。" \
    #         "由于行前未能仔细核对签证有效期，一对赴俄罗斯旅游的中国夫妇因为签证过期差点被困在俄罗斯。" \
    #         "在胡滨的协助下，这对夫妇经历了惊心动魄的３小时，终于在飞机起飞前５分钟，登上返程的航班。" \
    #         "孔子西游，见两小儿曰：“落霞与孤鹜齐飞，秋水共长天一色。”"
    # paragraph = "ARTÍCULO 390.- INTEGRACIÓN DEL TRIBUNAL. CITACIÓN A JUICIO. AUDIENCIA PRELIMINAR. Recibida la causa e integrado el Tribunal conforme a las disposiciones legales, se notificará inmediatamente su constitución a todas las partes, las que podrán en el plazo común de diez (10) días, formular las recusaciones que estimen pertinentes. " \
    #             "Si hubiere constituido actor civil, se lo emplazará para que en el plazo de cinco días concrete su demanda, bajo apercibimiento de tener por desistida la instancia. La demanda se notificará a las partes, las que deberán oponer excepciones, contestarlas o reconvenirlas, en los plazos y en las formas establecidas en el Código Procesal Civil." \
    #             "Culminada dicha instancia y resueltas las recusaciones o vencido el plazo indicado en el punto anterior, las partes serán citadas a juicio por el plazo individual de diez (10) días para cada una, a fin de que examinen las actuaciones y ofrezcan las pruebas que pretendan utilizar en el debate."
    # python3 addNoise.py -m noiseV1 -i zh.first1000.txt -o zh.first1000.denoise.txt
    # python3 addNoise.py -m averageLength -i /home/tiger/cc100/de.first1m.txt -l de