import conll18_ud_eval
import myutils
import os
def conv(score):
    """Format a score for table output.

    Args:
        score: The value to render.

    Returns:
        The score with four decimal places when it is a float (including
        float subclasses), otherwise an empty string.
    """
    # isinstance also accepts float subclasses, which an exact type
    # comparison would silently render as ''.
    if isinstance(score, float):
        return '{:.4f}'.format(score)
    return ''


def load_sota(task, version):
    """Read the state-of-the-art scores for a task and version.

    The scores come from ``results/<task>-sota-<version>.csv``, a
    tab-separated file whose first line is skipped as a header.

    Args:
        task: Task name used in the file name.
        version: Version string used in the file name.

    Returns:
        A dict mapping the first column of each row (the treebank) to the
        second column (the score, kept as a string).
    """
    path = 'results/' + task + '-sota' + '-' + version + '.csv'
    results = {}
    with open(path) as sotaFile:
        next(sotaFile, None)  # skip the header line
        for line in sotaFile:
            tok = line.strip().split('\t')
            # Blank or truncated rows carry no score; skip them instead of
            # failing with an IndexError.
            if len(tok) < 2:
                continue
            results[tok[0]] = tok[1]
    return results

def load(name):
    """Read the first two tab-separated columns of ``results/<name>.csv``.

    Every line of the file is used; no header line is skipped.

    Args:
        name: File name without the ``results/`` prefix and ``.csv`` suffix.

    Returns:
        A ``(treebanks, scores)`` tuple of two parallel lists of strings:
        the first and the second column of each line.

    Raises:
        IndexError: If a line has no second column.
    """
    treebanks = []
    scores = []
    # The context manager closes the file even when a malformed line raises.
    with open('results/' + name + '.csv') as resultFile:
        for line in resultFile:
            treebanks.append(line.split('\t')[0])
            scores.append(line.strip().split('\t')[1])
    return treebanks, scores
import copy

def boldHighest(row):
    """Return a copy of ``row`` with its highest score(s) wrapped in \\textbf{}.

    The first cell is treated as a label and never changed. Cells that
    cannot be parsed as a number (for example the '---' placeholder) are
    ignored when looking for the maximum and are left untouched.

    Args:
        row: A list whose first element is a label and whose remaining
            elements are score strings.

    Returns:
        A new list; every cell equal to the maximum score is wrapped in
        ``\\textbf{...}``. The input list is not modified. If no cell holds
        a number, the copy is returned unchanged.
    """
    row = copy.deepcopy(row)
    numeric = {}
    for idx in range(1, len(row)):
        try:
            numeric[idx] = float(row[idx])
        except (TypeError, ValueError):
            # Placeholder or empty cell: nothing to compare.
            continue
    if not numeric:
        return row
    maxScore = max(numeric.values())
    for idx, value in numeric.items():
        if value == maxScore:
            row[idx] = '\\textbf{' + row[idx] + '}'
    return row

# Name of the language model; makeTables embeds it in the paths of the
# prediction files it reads.
lm= 'bert-base-multilingual-cased'
def _tokenF1(goldPath, predPath, default):
    """Return the 'Tokens' F1 of predPath against goldPath, or default on failure.

    Any error while reading or evaluating either file is reported on stdout
    as ``ERROR <goldPath> <predPath>`` and ``default`` is returned instead.
    """
    try:
        with open(goldPath) as goldFile:
            goldSent = conll18_ud_eval.load_conllu(goldFile)
        with open(predPath) as predFile:
            predSent = conll18_ud_eval.load_conllu(predFile)
        return conll18_ud_eval.evaluate(goldSent, predSent)['Tokens'].f1
    except Exception:
        # Best effort: a missing or malformed prediction file must not stop
        # the whole table from being produced.
        print('ERROR', goldPath, predPath)
        return default


def makeTables(task, version):
    """Evaluate all predictions for one UD version and write the result tables.

    For every treebank directory under
    ``data/ud-treebanks-v<version>.singleToken/`` whose test file passes
    ``myutils.hasColumn(test, 1, threshold=.1)``, one tab-separated line is
    written with: treebank, SOTA score, single-task score, multi score and
    the best rule-based baseline score. Scores that are unavailable are
    written as '---' (failed baselines count as 0.0).

    Two files are produced:
      * ``results/test.<task>.<version>.all``   - every evaluated treebank.
      * ``results/test.<task>.<version>.train`` - only treebanks that have a
        train file and either a SOTA score or ``version == '2.10'``.

    Args:
        task: Task name used in input and output file names.
        version: UD version string used in input and output file names.
    """
    # No SOTA file is read for version 2.10.
    sota = load_sota(task, version) if version != '2.10' else {}
    ud_root = 'data/ud-treebanks-v' + version + '.singleToken/'
    outPrefix = 'results/test.' + task + '.' + version

    with open(outPrefix + '.all', 'w') as allOut, \
            open(outPrefix + '.train', 'w') as trainOut:
        for treebank in sorted(os.listdir(ud_root)):
            train, _, test = myutils.getTrainDevTest(ud_root + treebank)

            if not myutils.hasColumn(test, 1, threshold=.1):
                continue

            sota_score = sota.get(treebank, '---')
            testName = test.split('/')[-1]

            single_pred_path = ('preds/' + task + '.' + lm + '.' + treebank
                                + '.single.' + version + '.' + testName)
            if train != '':
                # Run the repair script on the prediction file before scoring.
                os.system('python3 scripts/fix.py ' + single_pred_path)
                print(test, single_pred_path)
                single_score = _tokenF1(test, single_pred_path, '---')
            else:
                single_score = '---'

            multi_pred_path = 'preds/multi.' + lm + '.' + treebank + '.' + version
            # A failure here must only affect the multi score; it previously
            # overwrote single_score and left multi_score unset or stale.
            multi_score = _tokenF1(test, multi_pred_path, '---')

            # Best of the rule-based baselines; a failed baseline scores 0.0.
            base_score = max(
                _tokenF1(
                    test,
                    'preds/tok.rulebased-' + ruleStrat + '.single.' + version + '.' + testName,
                    0.0,
                )
                for ruleStrat in ['bert_basic', 'robert', 'simple', 'treebank', 'twitter']
            )

            outStr = '\t'.join([treebank, str(sota_score), str(single_score),
                                str(multi_score), str(base_score)])
            print(outStr)
            allOut.write(outStr + '\n')
            # Flush per row so partial results are visible during long runs.
            allOut.flush()
            if train != '' and (sota_score != '---' or version == '2.10'):
                trainOut.write(outStr + '\n')
                trainOut.flush()

# Build the 'tok' tables for each of the three versions, in this order.
for _udVersion in ('2.2', '2.5', '2.10'):
    makeTables('tok', _udVersion)


