from transformers.models.bert.tokenization_bert import BasicTokenizer
import myutils
import os
import unicodedata
import conll18_ud_eval

# BERT BasicTokenizer instance with accent stripping and lower-casing disabled
# and Chinese-character tokenization enabled.
# NOTE(review): not referenced by the code visible in this file (rule_tok takes
# its own `tokenizer` parameter, which shadows this name) - confirm before removing.
tokenizer = BasicTokenizer(strip_accents=False, do_lower_case=False, tokenize_chinese_chars=True)

def conv(score):
    """
    Formats a fractional score as a percentage string with four decimals.

    Parameters
    ----------
    score: float or any
        A score in the range the caller uses (e.g. 0.5 for 50%). Any value
        that is not a float is passed through untouched.

    Returns
    -------
    str or any
        '{:.4f}' rendering of score*100 for floats (including float
        subclasses), otherwise the input object itself.
    """
    # isinstance (rather than an exact type comparison) also accepts float
    # subclasses, which would otherwise be returned unformatted.
    if isinstance(score, float):
        return '{:.4f}'.format(score * 100)
    return score

def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False

def insert(form_list, tgt_char_idx, char_to_add):
    """
    Inserts a character at a certain character index in a list of subwords.

    The index refers to the concatenation of all subwords: the character is
    placed so that it ends up at position `tgt_char_idx` (0-based) of the
    joined text. The list is modified in place and also returned.

    Parameters
    ----------
    form_list: List
        List of subwords
    tgt_char_idx: int
        index at which to insert (note that index comes from full text, not a list)
    char_to_add: chr
        The character that we want to insert
    Returns
    -------
    form_list: List
        List of subwords, with 1 character more. If `tgt_char_idx` is negative
        or lies beyond the end of the joined text, the list is returned unchanged.
    """
    if tgt_char_idx == 0:
        # Position 0 precedes every subword, so the scan below can never
        # match it; prepend to the first subword (or create one if the list
        # is empty, which previously raised IndexError).
        if form_list:
            form_list[0] = char_to_add + form_list[0]
        else:
            form_list.append(char_to_add)
        return form_list

    counter = 0
    for subwordIdx, subword in enumerate(form_list):
        # The subword covering text positions (counter, counter+len] owns the
        # insertion point, so word_char_idx is always in 1..len(subword).
        if counter + len(subword) >= tgt_char_idx > counter:
            word_char_idx = tgt_char_idx - counter
            at_subword_end = word_char_idx == len(subword)
            next_is_empty = (subwordIdx < len(form_list) - 1
                             and len(form_list[subwordIdx + 1]) == 0)
            if at_subword_end and next_is_empty:
                # The following subword is probably empty because its content
                # has just been removed; refill it instead of growing this one.
                # This could be checked if one wants to be extra careful (by
                # looking at the edit rule).
                form_list[subwordIdx + 1] = char_to_add
            else:
                form_list[subwordIdx] = subword[:word_char_idx] + char_to_add + subword[word_char_idx:]
            # Only one subword can own the position; nothing left to do.
            # (The former "first character, previous subword empty" branch
            # was unreachable because word_char_idx is never 0 here.)
            break
        counter += len(subword)

    return form_list


def min_edit_script(source, target, allow_copy=True):
    """
    Computes an edit script that rewrites `source` into `target`.

    The script is a string of operations: "→" copies one character, "-"
    deletes one source character and "+c" inserts the character c.
    When `allow_copy` is false the "→" operation is never used.
    """
    n_src, n_tgt = len(source), len(target)
    # Every cell starts with a cost larger than any reachable one.
    unreachable = (n_src + n_tgt + 1, None)
    table = [[unreachable] * (n_tgt + 1) for _ in range(n_src + 1)]
    table[0][0] = (0, "")

    for row in range(n_src + 1):
        for col in range(n_tgt + 1):
            if row == 0 and col == 0:
                continue
            cell = table[row][col]
            # The three candidates are tried in a fixed order (copy, delete,
            # insert); each one wins only if its predecessor's cost is
            # strictly below the cost currently stored in the cell.
            if allow_copy and row and col and source[row - 1] == target[col - 1]:
                diag_cost, diag_script = table[row - 1][col - 1]
                if diag_cost < cell[0]:
                    cell = (diag_cost, diag_script + "→")
            if row:
                up_cost, up_script = table[row - 1][col]
                if up_cost < cell[0]:
                    cell = (up_cost + 1, up_script + "-")
            if col:
                left_cost, left_script = table[row][col - 1]
                if left_cost < cell[0]:
                    cell = (left_cost + 1, left_script + "+" + target[col - 1])
            table[row][col] = cell

    return table[-1][-1][1]


def gen_lemma_rule(form, lemma, allow_copy=True):
    """
    Builds a rule string describing how to turn `form` into `lemma`.

    The rule has two parts separated by ";": a casing description of the
    lemma ("↑"/"↓" markers with the position where the casing changes), and
    either "a<lemma>" (no shared substring, use the lemma verbatim) or
    "d<prefix script>¦<suffix script>" built around the longest common
    substring of the lower-cased form and lemma.
    """
    form = form.lower()

    # Record each position where the lemma switches between upper and lower
    # case; positions in the second half are expressed relative to the end.
    casing_parts = []
    last_marker = None
    for pos, ch in enumerate(lemma):
        marker = "↑" if ch.lower() != ch else "↓"
        if marker != last_marker:
            offset = pos if pos <= len(lemma) // 2 else pos - len(lemma)
            casing_parts.append("{}{}".format(marker, offset))
        last_marker = marker
    lemma_casing = "¦".join(casing_parts)
    lemma = lemma.lower()

    # Longest common substring; the first longest match found wins.
    best, best_form, best_lemma = 0, 0, 0
    for lemma_start in range(len(lemma)):
        for form_start in range(len(form)):
            run = 0
            while (form_start + run < len(form)
                   and lemma_start + run < len(lemma)
                   and form[form_start + run] == lemma[lemma_start + run]):
                run += 1
            if run > best:
                best, best_form, best_lemma = run, form_start, lemma_start

    if not best:
        return lemma_casing + ";" + "a" + lemma

    prefix_script = min_edit_script(form[:best_form], lemma[:best_lemma], allow_copy)
    suffix_script = min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy)
    return lemma_casing + ";" + "d{}¦{}".format(prefix_script, suffix_script)

def remove(form_list, tgt_char_idx: int):
    """
    Deletes one character from a list of subwords, addressed by its position
    in the concatenated text.

    Parameters
    ----------
    form_list: List
        List of subwords (modified in place)
    tgt_char_idx: int
        1-based position of the character to delete, counted over the joined
        text of all subwords (not an index into the list)
    Returns
    -------
    form_list: List
        The same list; one subword is 1 character shorter when the position
        falls inside the text, otherwise nothing changes
    """
    consumed = 0
    for position, piece in enumerate(form_list):
        piece_end = consumed + len(piece)
        if consumed < tgt_char_idx <= piece_end:
            # Translate the global position into a 1-based offset in this piece.
            local_idx = tgt_char_idx - consumed
            form_list[position] = piece[:local_idx - 1] + piece[local_idx:]
            # No later piece can contain the same position.
            break
        consumed = piece_end
    return form_list



def apply_edit_rule(rule: str, form_list):
    """
    Replays an edit script on a list of (sub)words so that their joined
    characters become the script's target character sequence.
    The extra bookkeeping exists because the text is spread over a list of
    subwords rather than held in one string.

    Parameters
    ----------
    rule: str
        The rule as generated by the min_edit_script
    form_list: List
        List of (sub)words

    Returns
    -------
    form_list: List
        The list of (sub)words after all operations have been applied
    """
    cursor = 0      # position in the (evolving) joined text
    pos = 0         # position in the rule string
    rule_len = len(rule)
    while pos < rule_len:
        op = rule[pos]
        if op == "→":
            # Copy: keep the character and move on.
            cursor += 1
            pos += 1
        elif op == "-":
            # Delete the character right after the cursor (remove is 1-based).
            form_list = remove(form_list, cursor + 1)
            pos += 1
        else:
            assert (op == "+")
            # Insert: the character to add follows the "+" in the rule.
            form_list = insert(form_list, cursor, rule[pos + 1])
            cursor += 1
            pos += 2
    return form_list




def tokenize(text):
    """
    Splits `text` into tokens on whitespace and at every boundary between
    punctuation and non-punctuation characters. Runs of consecutive
    punctuation characters stay together in one token.
    """
    pieces = []
    pending = ''
    last_was_punct = False
    for ch in text:
        now_punct = _is_punctuation(ch)
        if _is_whitespace(ch):
            # Whitespace ends the pending token and is itself dropped.
            if pending:
                pieces.append(pending)
                pending = ''
            last_was_punct = False
            continue
        # Switching between punctuation and non-punctuation starts a new token.
        if now_punct != last_was_punct and pending:
            pieces.append(pending)
            pending = ''
        pending += ch
        last_was_punct = now_punct
    if pending:
        pieces.append(pending)
    return pieces

def rule_tok(inPath, outPath, tokenizer):
    """
    Tokenizes the raw sentence text of a CoNLL-U file and writes the result
    as a minimal CoNLL-U file.

    For every "# text" comment line in `inPath` the text is tokenized with
    `tokenizer`. If the tokenizer altered the characters (ignoring spaces),
    an edit script is applied to the tokens so that their characters match
    the original text again. Each token is written as one row; the first
    token is attached to the root and all others to token 1 with a dummy
    'obj' relation.

    Parameters
    ----------
    inPath: str
        Path of the CoNLL-U file to read
    outPath: str
        Path of the file to write (printed to stdout for progress tracking)
    tokenizer: callable
        Function taking the sentence text and returning a list of tokens
    """
    print(outPath)
    # CoNLL-U files are UTF-8; context managers make sure both handles are
    # closed even when tokenization fails halfway through.
    with open(outPath, 'w', encoding='utf-8') as outFile, open(inPath, encoding='utf-8') as inFile:
        for line in inFile:
            if not line.startswith(('# text=', '# text =')) or len(line) <= 9:
                continue
            outFile.write(line)
            # Take everything after the first '='. A fixed slice of 8 would
            # cut off the first character of the text for the '# text='
            # spelling, which is only 7 characters long.
            text = line.split('=', 1)[1].strip()
            # Normalise every whitespace character to a plain space.
            text = ''.join(' ' if _is_whitespace(char) else char for char in text)
            tokenized = tokenizer(text)
            tok_chars = ''.join(tokenized).replace(' ', '')
            textchars = text.replace(' ', '')
            if tok_chars != textchars:
                # The tokenizer changed some characters; restore the original
                # character sequence while keeping the token boundaries.
                edit_rule = min_edit_script(tok_chars, textchars)
                tokenized = apply_edit_rule(edit_rule, tokenized)

            for wordIdx, word in enumerate(tokenized):
                if wordIdx == 0:
                    label = 'root'
                    head = '0'
                else:
                    label = 'obj'
                    head = '1'
                outFile.write('\t'.join([str(wordIdx+1), word, '_', '_', '_', '_', head, label, '_', '_']) + '\n')
            outFile.write('\n')

# Registry of the tokenizers to compare. Every *_function below is a callable
# that is invoked with a text string (see the loops further down and rule_tok).

# The rule-based `tokenize` function defined in this file.
tokenizer_rob_function = tokenize

#https://www.nltk.org/api/nltk.tokenize.casual.html
from nltk.tokenize import TweetTokenizer
tokenizer_twitter = TweetTokenizer()
tokenizer_twitter_function = tokenizer_twitter.tokenize

#https://www.nltk.org/api/nltk.tokenize.destructive.html
from nltk.tokenize import NLTKWordTokenizer
tokenizer_robert_function = NLTKWordTokenizer().tokenize

# Disabled candidate (kept for reference):
#https://www.nltk.org/api/nltk.tokenize.nist.html
#from nltk.tokenize.nist import NISTTokenizer
#tokenizer_nist = NISTTokenizer()
#tokenizer_nist_function = tokenizer_nist.international_tokenize

# Disabled candidate (kept for reference):
#https://www.nltk.org/api/nltk.tokenize.stanford.html
#from nltk.tokenize.stanford import StanfordTokenizer
#tokenizer_stanford_function = StanfordTokenizer().tokenize

#https://www.nltk.org/api/nltk.tokenize.toktok.html
from nltk.tokenize import ToktokTokenizer
tokenizer_simple = ToktokTokenizer()
tokenizer_simple_function = tokenizer_simple.tokenize

#https://www.nltk.org/api/nltk.tokenize.treebank.html
from nltk.tokenize.treebank import TreebankWordTokenizer
tokenizer_treebank = TreebankWordTokenizer()
tokenizer_treebank_function = tokenizer_treebank.tokenize

# The two lists are parallel (they are zipped together below); each name is
# embedded in the prediction and result file names for its tokenizer.
# Note that 'bert_basic' labels the in-file `tokenize` function.
tokenizers = [tokenizer_rob_function, tokenizer_twitter_function, tokenizer_robert_function, tokenizer_simple_function, tokenizer_treebank_function]
tokenizer_names = ['bert_basic', 'twitter', 'robert', 'simple', 'treebank']

# Smoke test: show how every tokenizer splits one sample sentence.
for tok_function in tokenizers:
    print(tok_function('this is a test sentence. Ain\'t it?'))



# Run every tokenizer on every UD version: predictions are always written for
# the test split; the dev split is additionally scored (Tokens F1) and the
# score is stored in one results file per tokenizer and UD version.
# To debug on a single UD version, slice myutils.udVersions here.
for tok_function, tok_name in zip(tokenizers, tokenizer_names):
    for udVersion in myutils.udVersions:
        udPath = 'data/ud-treebanks-v' + udVersion + '.singleToken/'
        outPath = 'results/Tokens.rulebased-' + tok_name + '-bert-base-multilingual-cased-' + udVersion + '.csv'
        # The context manager closes the results file even if a treebank
        # raises halfway through the loop.
        with open(outPath, 'w') as outFile:
            for UDdir in sorted(os.listdir(udPath)):
                if not UDdir.startswith("UD") or not os.path.isdir(udPath + UDdir):
                    continue
                train, dev, test = myutils.getTrainDevTest(udPath + UDdir)
                outTest = 'preds/tok.rulebased-' + tok_name + '.single.' + udVersion + '.' + test.split('/')[-1]
                rule_tok(test, outTest, tok_function)
                # Only treebanks with training data that contains words, and
                # with a dev split, are scored.
                if train == '':
                    continue
                if not myutils.hasColumn(train, 1, threshold=.1):
                    #print('noWords ', train)
                    continue
                if dev == '':
                    continue

                outDev = 'preds/tok.rulebased-' + tok_name + '.single.' + udVersion + '.' + dev.split('/')[-1]
                rule_tok(dev, outDev, tok_function)
                try:
                    # CoNLL-U files are UTF-8; close both handles right after loading.
                    with open(dev, encoding='utf-8') as goldFile, open(outDev, encoding='utf-8') as predFile:
                        goldSent = conll18_ud_eval.load_conllu(goldFile)
                        predSent = conll18_ud_eval.load_conllu(predFile)
                    score = conll18_ud_eval.evaluate(goldSent, predSent)['Tokens'].f1
                except Exception as err:
                    # Best effort: a treebank that cannot be evaluated gets a
                    # zero score, but Ctrl-C / SystemExit are no longer swallowed
                    # and the cause is reported.
                    print("error in evaluation", UDdir, err)
                    score = 0.0
                outFile.write(UDdir + '\t' + str(conv(score)) + '\n')
