import re
import argparse
import random
from numpy.random import geometric
from pysbd import Segmenter
from nltk.tokenize import sent_tokenize
from sacremoses import MosesTokenizer

# Language code -> language name passed to nltk.tokenize.sent_tokenize
# (the Punkt models shipped with NLTK). Looked up by sentSplitFn for every
# language except "zh". "en" is included because it is the CLI default (-l).
nltk_langtag = {
    'cs': 'czech',
    'da': 'danish',
    'de': 'german',
    'el': 'greek',
    'en': 'english',
    'es': 'spanish',
    'et': 'estonian',
    'fi': 'finnish',
    'fr': 'french',
    'it': 'italian',
    'nl': 'dutch',
    'no': 'norwegian',
    'pl': 'polish',
    'pt': 'portuguese',
    'ru': 'russian',
    'sl': 'slovene',
    'sv': 'swedish',
    'tr': 'turkish',
}
# Placeholder emitted by maskSpan for each run of masked tokens.
MASKTOKEN='<mask>'

class WrappedTokenizer:
    """
    Minimal tokenizer used by the noising functions in this file.

    Untokenized Chinese/Japanese/Korean text ("zh", "ja", "ko" with
    tokenized=False) is split into single script characters; in every other
    case the text is split on whitespace.
    """

    # Languages whose raw (untokenized) text is split character by character.
    _CJK_LANGS = ("zh", "ja", "ko")
    # Capturing groups make re.split keep each matched character as its own piece.
    _SPLIT_PATTERNS = {
        "zh": re.compile(u"([\u4e00-\u9fa5])"),
        # NOTE(review): U+0800-U+4E00 covers kana but stops at U+4E00, so
        # Japanese kanji (U+4E01 and above) are NOT split into single
        # characters -- confirm this is intended.
        "ja": re.compile(u"([\u0800-\u4e00])"),
        "ko": re.compile(u"([\uac00-\ud7ff])"),
    }
    # Any other language: Chinese, Japanese and Korean non-symbol characters.
    _FALLBACK_SPLIT_PATTERN = re.compile(u"([\u2e80-\u9fff])")

    def __init__(self, lang, tokenized=False):
        """
        Args:
            lang: language code, e.g. "en" or "zh".
            tokenized: True if the input text is already whitespace-tokenized.
        """
        self.lang = lang
        self.tokenized = tokenized
        print("self.tokenized: ", self.tokenized)
        # NOTE(review): self.mt is only built for non-CJK languages and is not
        # used by tokenize()/de_tokenize() in this class.
        if lang not in self._CJK_LANGS:
            self.mt = MosesTokenizer(lang=lang)

    def splitChars(self, sent, lang):
        """
        Split `sent` so that every matching CJK character is its own token.

        Runs of other characters stay together as a single token. Every token
        is stripped and lower-cased, and pieces that are empty after stripping
        are dropped.
        """
        pattern = self._SPLIT_PATTERNS.get(lang, self._FALLBACK_SPLIT_PATTERN)
        pieces = (p.strip().lower() for p in pattern.split(sent))
        # Filter AFTER stripping: whitespace runs such as "  " or "\t" used to
        # pass a pre-strip `p != " "` check and turn into empty tokens.
        return [p for p in pieces if p]

    def tokenize(self, sentence: str):
        """
        Return:
            tokens: list -- single characters (plus stripped, lower-cased
            non-CJK runs) for untokenized CJK input, otherwise
            sentence.split().
        """
        if self.lang in self._CJK_LANGS and not self.tokenized:
            return self.splitChars(sentence, self.lang)
        return sentence.split()

    def de_tokenize(self, tokens: list):
        """
        Join tokens back into a string: with "" for untokenized CJK input,
        with a single space otherwise. For CJK this is not an exact inverse
        of tokenize(), which strips and lower-cases tokens.
        """
        if self.lang in self._CJK_LANGS and not self.tokenized:
            return "".join(tokens)
        return " ".join(tokens)

def readTxt(fname):
    """
    Read a UTF-8 text file and return its lines, each stripped of
    surrounding whitespace.

    The file is opened in binary mode and every line is decoded on its own,
    so only the newline byte separates lines. Blank lines are kept as "".
    A summary of how many lines were read is printed.
    """
    with open(fname, 'rb') as handle:
        lines = [raw.decode('utf-8').strip() for raw in handle]
    print("Reading {} example from {}".format(len(lines), fname))
    return lines

def saveTxt(data, fname):
    """
    Write each item of `data` to `fname` on its own line, UTF-8 encoded.

    Items are formatted with str.format, so non-string items are written in
    their str() form. A summary line is printed afterwards.
    """
    # Explicit UTF-8 so CJK output does not depend on the platform's locale
    # encoding; readTxt always decodes its input as UTF-8.
    with open(fname, 'w', encoding='utf-8') as fout:
        fout.writelines('{}\n'.format(d) for d in data)
    print('Save {} example to {}'.format(len(data), fname))

def maskSpan(paragraph, args, return_str=True, max_span=5):
    """
    Mask random token spans and collapse each masked run into one MASKTOKEN.

    refer to SpanBERT https://arxiv.org/pdf/1907.10529.pdf
    Span lengths are drawn from numpy's geometric distribution with success
    probability ``args.p``. Each length is clipped to ``max_span`` and to the
    number of tokens still to be masked, and span starts are drawn uniformly.
    Sampling continues until ``int(len(tokens) * args.r)`` distinct tokens
    are masked, capped at len(tokens). Spans may overlap or touch, and every
    maximal run of masked tokens is replaced by a single MASKTOKEN.

    Input:
        paragraph: str (untokenized) or list of tokens (tokenized). A str is
            tokenized with the module-level ``tokenizer`` built in __main__.
        args: object providing ``p`` (geometric parameter) and ``r`` (mask ratio).
        return_str: if True, join the result with ``tokenizer.de_tokenize``.
        max_span: longest span a single draw may mask (default 5).
    Output:
        str (If return_str=True) or list of tokens (Otherwise)
    Raises:
        ValueError: if max_span < 1, because the sampling loop could never finish.
    """
    if max_span < 1:
        raise ValueError("max_span must be >= 1, got {}".format(max_span))
    p = args.p
    masked_ratio = args.r
    if isinstance(paragraph, str):
        tokens = tokenizer.tokenize(paragraph)
    else:
        tokens = paragraph
    rawLen = len(tokens)
    # Cap the budget at rawLen: with args.r > 1 the loop below could never
    # mask enough distinct tokens and would spin forever.
    remainedMaskToken = min(int(rawLen * masked_ratio), rawLen)
    masked = [False] * rawLen
    while remainedMaskToken > 0:
        # geometric() draws are >= 1, so every span covers at least one token.
        spanLength = min(int(geometric(p)), max_span, remainedMaskToken)
        # spanLength <= remainedMaskToken <= rawLen, so the span always fits.
        start = random.randrange(0, rawLen - spanLength + 1)
        for pos in range(start, start + spanLength):
            if not masked[pos]:
                remainedMaskToken -= 1
                masked[pos] = True

    result_tokens = []
    prev_masked = False
    for token, is_masked in zip(tokens, masked):
        if not is_masked:
            result_tokens.append(token)
        elif not prev_masked:
            # First token of a masked run: one placeholder for the whole run.
            result_tokens.append(MASKTOKEN)
        prev_masked = is_masked

    if return_str:
        return tokenizer.de_tokenize(result_tokens)
    return result_tokens

def noiseV2(args):
    """
    Build a span-masking (SpanBERT-like) dataset from ``args.i``.

    Blank lines are dropped. Each remaining line is masked with maskSpan. The
    original lines are written to ``args.ot`` and the masked lines to
    ``args.os``. Every 100000th line index is printed as progress.
    """
    lines = [line for line in readTxt(args.i) if line.strip() != ""]
    masked_lines = []
    for idx, line in enumerate(lines):
        masked_lines.append(maskSpan(line, args))
        if idx and idx % 100000 == 0:
            print(idx)

    saveTxt(lines, args.ot)
    saveTxt(masked_lines, args.os)

# Sentence terminators (optionally followed by up to two closing quotes), or a
# full-width colon that is followed by opening quotes or the end of the text.
# The capturing group makes re.split keep the terminators as separate pieces.
_ZH_SENT_DELIM = re.compile('([﹒﹔﹖﹗．；。！？]["’”」』]{0,2}|：(?=["‘“「『]{1,2}|$))')


def splitChineseSentence(sentence):
    """
    Split Chinese text into sentences.

    Each terminator piece produced by the split is appended to the sentence
    before it. A terminator that appears before any text becomes a sentence
    of its own. Empty pieces are skipped.
    """
    pieces = []
    for chunk in _ZH_SENT_DELIM.split(sentence):
        if not chunk:
            continue
        if pieces and _ZH_SENT_DELIM.match(chunk):
            pieces[-1] += chunk
        else:
            pieces.append(chunk)
    return pieces

def sentSplitFn(paragraph: str, lang, return_str=False, delimitor=None):
    """
    Split `paragraph` into non-blank sentences.

    Args:
        paragraph: text to split.
        lang: language code, used only when `delimitor` is None. "zh" uses
            splitChineseSentence, "en" uses NLTK's English model, and any
            other code is looked up in nltk_langtag (KeyError if absent)
            before calling nltk's sent_tokenize.
        return_str: if True, re-join the sentences into one string.
        delimitor: if given, split on this literal string instead of running
            a sentence splitter. It is also used for re-joining.
    Returns:
        A list of sentences, or a single string when return_str is True.
        Without a delimitor, sentences are re-joined with "" for "zh" and
        with " " otherwise.
    """
    if delimitor is not None:
        sents = paragraph.split(delimitor)
    elif lang == 'zh':
        sents = splitChineseSentence(paragraph)
    else:
        # 'en' is the CLI default, so it must work even if nltk_langtag has no entry for it.
        langtag = 'english' if lang == 'en' else nltk_langtag[lang]
        sents = sent_tokenize(paragraph, langtag)
    sents = [sent for sent in sents if sent.strip() != ""]
    if not return_str:
        return sents
    # Without this fallback, return_str=True with no delimitor hit `None.join(...)`.
    joiner = delimitor if delimitor is not None else ('' if lang == 'zh' else ' ')
    return joiner.join(sents)

def dropTokenFn(tokens, ratio=0.15):
    """
    Randomly drop a fraction of `tokens`.

    A leading standalone "▁" token is discarded first; it is presumably a
    SentencePiece word-boundary marker (confirm against the data). Then
    round(len(tokens) * ratio) distinct positions, clipped to
    [0, len(tokens)], are sampled uniformly without replacement.

    Args:
        tokens: sequence of tokens. It is not modified.
        ratio: fraction of tokens to drop.
    Returns:
        (source_tokens, target_tokens): the kept tokens and the dropped
        tokens, each list in the original order.
    """
    # Slice instead of pop(0) so the caller's list is left untouched; the
    # emptiness check avoids an IndexError on empty input.
    if tokens and tokens[0] == "▁":
        tokens = tokens[1:]
    num_tokens = len(tokens)
    # random.sample raises ValueError if k is negative or exceeds the population size.
    num_drop = min(max(round(num_tokens * ratio), 0), num_tokens)
    # Sampling from range() draws the same indices as sampling from a list of
    # them; a set makes each membership test O(1) instead of O(n).
    selected_token_idxs = set(random.sample(range(num_tokens), k=num_drop))
    source_tokens, target_tokens = [], []
    for i, token in enumerate(tokens):
        if i in selected_token_idxs:
            target_tokens.append(token)
        else:
            source_tokens.append(token)
    return source_tokens, target_tokens

def dropTokenNoise(args):
    """
    Build a token-drop dataset from the lines of ``args.i``.

    Each line is tokenized with the module-level ``tokenizer``, and lines with
    10 or fewer tokens are skipped. For the rest, dropTokenFn removes an
    ``args.r`` fraction of tokens. The remaining text is written to
    ``args.os`` and the removed tokens to ``args.ot``. Progress is printed
    every 200000 input lines.
    """
    lines = readTxt(args.i)
    kept_texts, dropped_texts = [], []
    for idx, line in enumerate(lines):
        if idx and idx % 200000 == 0:
            print("process {} samples".format(idx))
        tokens = tokenizer.tokenize(line)
        if len(tokens) > 10:
            kept, dropped = dropTokenFn(tokens, args.r)
            kept_texts.append(tokenizer.de_tokenize(kept))
            dropped_texts.append(tokenizer.de_tokenize(dropped))
    saveTxt(kept_texts, args.os)
    saveTxt(dropped_texts, args.ot)

def dropTokenNoiseSingleFile(datas: list, args):
    """
    Apply token-drop noise to in-memory lines.

    A fresh WrappedTokenizer(args.l, args.t) is built for the call. Lines with
    10 or fewer tokens are skipped. For every other line, dropTokenFn removes
    an ``args.r`` fraction of its tokens.

    Returns:
        (sources, targets): parallel lists holding the text with the tokens
        removed and the removed tokens themselves, each re-joined with the
        tokenizer's de_tokenize.
    """
    tok = WrappedTokenizer(args.l, args.t)
    sources, targets = [], []
    for line in datas:
        tokens = tok.tokenize(line)
        if len(tokens) > 10:
            kept, dropped = dropTokenFn(tokens, args.r)
            sources.append(tok.de_tokenize(kept))
            targets.append(tok.de_tokenize(dropped))
    return sources, targets

def inFillNoiseSingleFile(datas: list, args):
    """
    Span-mask in-memory lines, SpanBERT-style, to build infilling pairs.

    A fresh WrappedTokenizer(args.l, args.t) is built for the call. Each line
    is tokenized, masked with maskSpan (return_str=False), and re-joined.

    Returns:
        (sources, targets): the masked lines and the untouched input lines,
        in the same order.
    """
    tok = WrappedTokenizer(args.l, args.t)
    sources, targets = [], []
    for line in datas:
        masked_tokens = maskSpan(tok.tokenize(line), args, return_str=False)
        sources.append(tok.de_tokenize(masked_tokens))
        targets.append(line)
    return sources, targets

def averageLength(args):
    """
    Print and return the average number of tokens per line of ``args.i``.

    Lines are tokenized with the module-level ``tokenizer`` built in __main__.
    Blank lines count as lines with 0 tokens. An empty input file yields 0.0
    instead of raising ZeroDivisionError.
    """
    datas = readTxt(args.i)
    if not datas:
        avg_length = 0.0
    else:
        avg_length = sum(len(tokenizer.tokenize(data)) for data in datas) / len(datas)
    print(avg_length)
    return avg_length

def removeLongInstance(args):
    """
    Drop overly long lines from ``args.i`` and write the rest to ``args.os``.

    Blank lines are removed. Each remaining line is tokenized with the
    module-level ``tokenizer`` and kept if it has at most ``args.max_length``
    tokens. Kept lines are re-joined with ``tokenizer.de_tokenize``, so their
    whitespace is normalized. When ``args.max_length`` is None (the
    --max-length default), no length limit is applied.
    """
    datas = readTxt(args.i)
    max_length = args.max_length
    results = []
    for data in datas:
        if data.strip() == "":
            continue
        tokens = tokenizer.tokenize(data)
        # --max-length defaults to None; `len(tokens) <= None` raised TypeError.
        if max_length is None or len(tokens) <= max_length:
            results.append(tokenizer.de_tokenize(tokens))
    saveTxt(results, args.os)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', help='lang', type=str, default='en')
    parser.add_argument('-i', help='input', type=str, default='input.txt')
    # NOTE(review): -os and -ot share the same default, so modes that write
    # both files overwrite one with the other unless both are given.
    parser.add_argument('-os', help='output source', type=str, default='output.txt')
    parser.add_argument('-ot', help='output target', type=str, default='output.txt')
    parser.add_argument('-r', help='the ratio of masked token per sentence', type=float, default=0.15)
    parser.add_argument('-t', action="store_true", help="whether the data has been tokenized")
    parser.add_argument('-d', help="sentence delimitor", type=str, default="<q>")
    parser.add_argument('-p', help="success probability of the geometric span-length distribution used by maskSpan",
                        type=float, default=0.2)
    parser.add_argument('-m', help='mode: name of a function in this file taking (args), e.g. noiseV2',
                        type=str, required=True)
    parser.add_argument("--max-length", help="max length", type=int, default=None)
    args = parser.parse_args()

    # Resolve the mode by name instead of eval("{}(args)".format(args.m)),
    # which would execute arbitrary Python passed on the command line.
    mode_fn = globals().get(args.m)
    if not callable(mode_fn):
        parser.error("unknown mode: {}".format(args.m))

    # Deliberately module-level: maskSpan, noiseV2, dropTokenNoise,
    # averageLength and removeLongInstance read this global `tokenizer`.
    tokenizer = WrappedTokenizer(args.l, args.t)

    mode_fn(args)

    # Example:
    # python3 noise.py -m noiseV2 -i /opt/tiger/sumtest/cc100/en.first10k.txt -os /opt/tiger/sumtest/cc100/infillNoise/en.noise.src -ot /opt/tiger/sumtest/cc100/infillNoise/en.noise.tgt -t