from typing import List
import unicodedata
from transformers.models.bert.tokenization_bert import BasicTokenizer
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
from transformers import AutoTokenizer
import torch
from transformers import tokenization_utils
import sys
import conll18_ud_eval
from machamp.utils import tok_utils
import os


from machamp.utils.lemma_edit import min_edit_script


def evalFile(path, splits, learn_new_splits):
    """Re-tokenize the CoNLL-U file at ``path`` and score it against itself.

    The file is read sentence by sentence (sentences are separated by lines
    shorter than two characters). Every sentence is passed to
    ``tok_utils.tokenize_and_annotate``; the resulting labels and subwords are
    turned back into CoNLL-U rows with ``toString`` and compared to the
    original file with ``conll18_ud_eval``.

    The tokenizer is loaded from the model name given in ``sys.argv[2]``.

    Args:
        path: path of the CoNLL-U file to evaluate.
        splits: dict of already known splits, forwarded unchanged to every
            ``tokenize_and_annotate`` call. It is updated IN PLACE with the
            splits returned for all sentences of this file.
        learn_new_splits: forwarded to ``tokenize_and_annotate``.

    Returns:
        A ``(splits, score)`` tuple: the (updated) ``splits`` dict and the
        ``Tokens`` F1 reported by ``conll18_ud_eval.evaluate``.
    """
    data = []
    curSent = []
    curFull = []
    # CoNLL-U is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding='utf-8') as conllu_file:
        for line in conllu_file:
            if len(line) < 2:
                # (near-)empty line: sentence boundary
                data.append([curSent, curFull])
                curSent = []
                curFull = []
            else:
                columns = line.split('\t')
                curFull.append(columns)
                # only regular 10-column rows contribute a gold token (column 1)
                if len(columns) == 10:
                    curSent.append(columns[1])

    pre_tokenizer = BasicTokenizer(strip_accents=False, do_lower_case=False, tokenize_chinese_chars=True)
    tokenizer = AutoTokenizer.from_pretrained(sys.argv[2], use_fast=False, do_basic_tokenize=False)

    preds = []
    # Collect the splits of ALL sentences; previously only the splits returned
    # for the last sentence were kept (and an empty file raised a NameError).
    new_splits_combined = {}
    for sent in data:
        _, _, tok_labels, no_unk_subwords, new_splits = tok_utils.tokenize_and_annotate(
            sent[1], sent[0], pre_tokenizer, tokenizer, splits, learn_new_splits)
        if new_splits:
            new_splits_combined.update(new_splits)

        full_data = toString(tok_labels, no_unk_subwords)
        for line in full_data:
            preds.append('\t'.join(line) + '\n')
        preds.append('\n')
        for label, subword in zip(tok_labels, no_unk_subwords):
            print(label, subword)
        print()

    with open(path, encoding='utf-8') as gold_file:
        goldSent = conll18_ud_eval.load_conllu(gold_file)
    predSent = conll18_ud_eval.load_conllu(conllFile(preds))
    score = conll18_ud_eval.evaluate(goldSent, predSent)['Tokens'].f1
    splits.update(new_splits_combined)
    return splits, score
        
def toString(preds, no_unk_subwords):
    """Glue subwords into tokens and return them as 10-column CoNLL-U rows.

    ``preds[i]`` decides what happens to subword ``i + 1``: the label
    ``'merge'`` appends it to the token currently being built, any other
    label starts a new token. The first subword always starts a token.

    Every row is a list of ten strings filled with ``'_'`` except for the
    1-based id (column 0), the word form (column 1) and the head (column 6),
    which is ``'0'`` for the first token and ``'1'`` for all others.
    """
    # Shift the labels one position to the right so that labels[i] tells
    # whether subword i has to be attached to its predecessor.
    labels = ['split'] + preds
    rows = []
    for position, piece in enumerate(no_unk_subwords):
        attach = labels[position] == 'merge' and position > 0
        if attach:
            rows[-1][1] += piece
            continue
        row = ['_'] * 10  # TODO 10 is hardcoded, 1 as well
        row[1] = '' + piece
        rows.append(row)

    for number, row in enumerate(rows, start=1):
        row[0] = str(number)
        row[6] = '0' if number == 1 else '1'
    return rows

class conllFile:
    """Minimal file-like wrapper that serves a list of lines via ``readline``.

    It allows in-memory CoNLL-U lines to be handed to code that reads its
    input with repeated ``readline()`` calls.
    """

    def __init__(self, data):
        # data: sequence of lines, returned one by one in order
        self.data = data
        # index of the line returned last; -1 means nothing was read yet
        self.idx = -1

    def readline(self):
        """Return the next line, or ``None`` once all lines were consumed.

        Calling this again after the end keeps returning ``None``; the cursor
        is clamped at ``len(data)`` so it can never run past the end and
        trigger an ``IndexError``.
        """
        if self.idx < len(self.data):
            self.idx += 1
        if self.idx < len(self.data):
            return self.data[self.idx]
        return None

def getTrainDevTest(path):
    """Find the train/dev/test CoNLL-U files inside the directory ``path``.

    Only directory entries whose name ends in ``conllu`` are considered. An
    entry is assigned to a split when the split name (``train``, ``dev``,
    ``test``) occurs anywhere in its file name; one entry may match several
    splits, and a later matching entry overrides an earlier one.

    Returns:
        A ``(train, dev, test)`` tuple of paths (``path + '/' + name``); a
        split without a matching file is an empty string.
    """
    found = {'train': '', 'dev': '', 'test': ''}
    for entry in os.listdir(path):
        if not entry.endswith('conllu'):
            continue
        for split_name in found:
            if split_name in entry:
                found[split_name] = path + '/' + entry
    return found['train'], found['dev'], found['test']


def conv(score):
    """Format a score in the range 0-1 as a percentage with four decimals.

    Args:
        score: the value to format.

    Returns:
        ``score * 100`` rendered with four decimals when ``score`` is a float
        (including subclasses of ``float``), otherwise an empty string.
    """
    # isinstance instead of an exact type comparison, so that float
    # subclasses are formatted too instead of being silently dropped.
    if isinstance(score, float):
        return '{:.4f}'.format(score*100)
    else:
        return ''

# Command line driver.
#   sys.argv[1]: either a directory containing one sub-directory per treebank,
#                or a single CoNLL-U file
#   sys.argv[2]: name/path of the tokenizer, read inside evalFile
if os.path.isdir(sys.argv[1]):
    # Directory mode: print a LaTeX table with one row per treebank showing
    # the dev score without splits, the dev score with the splits learned on
    # train, and the number of splits.
    print('\\begin{tabular}{l r r r}')
    print('\\toprule')
    # '\\#' is the escaped spelling of the same two characters ('\' and '#');
    # the unescaped form is an invalid escape sequence in recent Pythons.
    print(' & '.join(['Treebank', 'dev-', 'dev+', '\\#splits']) + '\\\\')
    print('\\midrule')
    for treebank in sorted(os.listdir(sys.argv[1])):
        train, dev, test = getTrainDevTest(sys.argv[1] + '/' + treebank)
        if train == '':
            continue
        splits, train_score = evalFile(train, {}, True)
        if dev != '':
            _, dev_score_before = evalFile(dev, {}, False)
            _, dev_score_after = evalFile(dev, splits, False)
            data = [treebank.replace('_', '\\_')] + [conv(score) for score in [dev_score_before, dev_score_after]] + [str(len(splits))]

            print(' & '.join(data) + '\\\\')
    print('\\bottomrule')
    print('\\end{tabular}')
    # Strip a trailing '/' first, otherwise the caption would be empty for
    # arguments such as 'data/ud/'.
    print('\\caption{' + sys.argv[1].rstrip('/').split('/')[-1] + '}')
else:
    # Single file mode: score the file without and with learning splits.
    _, score_before = evalFile(sys.argv[1], {}, False)
    splits, score_after = evalFile(sys.argv[1], {}, True)
    print(splits)
    print(conv(score_before), conv(score_after))
    if 'train' in sys.argv[1] and os.path.isfile(sys.argv[1].replace('train', 'dev')):
        _, dev_score_before = evalFile(sys.argv[1].replace('train', 'dev'), {}, False)
        # NOTE(review): unlike the directory mode, this learns splits on the
        # dev file itself (empty dict, learning enabled) and overwrites the
        # splits learned on train -- confirm that this is intended.
        splits, dev_score_after = evalFile(sys.argv[1].replace('train', 'dev'), {} , True)
        print(conv(dev_score_before), conv(dev_score_after))
    print(len(splits))

