import json
import os
import subprocess

import conll18_ud_eval
import myutils

# Output directory for the score files written by eval(). exist_ok avoids the
# check-then-create race of testing isdir() before mkdir().
os.makedirs('results', exist_ok=True)

def conv(score):
    """Format a score for the result files.

    Floats (e.g. the F1 values from conll18_ud_eval) are multiplied by 100
    and rendered with four decimals, so 0.5 becomes '50.0000'. Any other
    value, such as the '---' placeholder, is returned unchanged.
    """
    # isinstance also accepts float subclasses (e.g. numpy.float64), which an
    # exact type() comparison would silently leave unformatted.
    if isinstance(score, float):
        return '{:.4f}'.format(score * 100)
    return score

def _prediction_path(udVersion, model, setting, UDdir):
    """Return the prediction file path of `model` for treebank `UDdir`.

    Returns None (after printing a notice) when myutils.getModel does not
    find the model, i.e. returns ''.
    """
    if model.startswith('multi.'):
        # The model name does not contain the treebank, so the same model
        # is used for every treebank of this UD version.
        modelName = model + '.' + setting + '.' + udVersion
    else:
        modelName = model + '.' + UDdir + '.' + setting + '.' + udVersion
        if modelName[-1] == '.':
            modelName = modelName[:-1]
    # Empty name components (e.g. an empty setting) leave doubled dots behind.
    modelName = modelName.replace('..', '.')
    modelPath = myutils.getModel(modelName)
    if modelPath == '':
        print('model not found', modelName)
        return None
    # The prediction file is expected next to the model checkpoint,
    # named <UDdir>.out.
    return modelPath.replace('model.pt', UDdir + '.out')


def _score(goldPath, predPath):
    """Evaluate `predPath` against `goldPath` with conll18_ud_eval.

    Returns a (token F1, LAS F1) pair. Any failure (missing or unreadable
    file, malformed CoNLL-U, ...) is reported on stdout and scored as
    (0.0, 0.0), so one broken treebank does not abort the whole run.
    """
    try:
        with open(goldPath, encoding='utf-8') as goldFile:
            goldSent = conll18_ud_eval.load_conllu(goldFile)
        with open(predPath, encoding='utf-8') as predFile:
            predSent = conll18_ud_eval.load_conllu(predFile)
        scores = conll18_ud_eval.evaluate(goldSent, predSent)
        # Both metrics are read before anything is recorded, so a failure
        # cannot leave the token and LAS columns out of step.
        return scores['Tokens'].f1, scores['LAS'].f1
    except Exception as err:  # deliberate best effort: log and score 0.0
        print("ERROR in " + predPath + ':', err)
        return 0.0, 0.0


def eval(udVersion, outName, model, setting):
    """Score `model` on every UD treebank that has no training file.

    Treebanks are read from data/ud-treebanks-v<udVersion>.singleToken/.
    Two tab-separated files are written, each with one line
    "treebank<TAB>score" per evaluated treebank:
    results/cross-las.<outName>-<udVersion>.csv and
    results/cross-tok.<outName>-<udVersion>.csv.
    Scores are formatted with conv(). '---' marks treebanks skipped for this
    setting. 0.0 marks a missing model or a failed evaluation.

    Args:
        udVersion: UD release identifier (string). Used in the data path,
            the model name and the output file names.
        outName: label used in the output file names. 'single-task' also
            runs scripts/fix.py on each prediction file first. Names
            containing 'attention' never receive the '---' placeholder.
        model: model name prefix. A 'multi.' prefix selects one model shared
            by all treebanks; otherwise the treebank name is part of the
            model name.
        setting: setting component of the model name. A value ending in
            'True' marks treebanks listed in myutils.hasNoSplits as '---'.
    """
    # One (treebank, token score, LAS score) row per treebank keeps the two
    # output files aligned by construction.
    rows = []
    udPath = 'data/ud-treebanks-v' + udVersion + '.singleToken/'
    for UDdir in sorted(os.listdir(udPath)):
        if not UDdir.startswith("UD") or not os.path.isdir(udPath + UDdir):
            continue
        train, _, test = myutils.getTrainDevTest(udPath + UDdir)
        # Cross-treebank evaluation: only treebanks for which
        # getTrainDevTest found no training file are scored.
        if train != '':
            continue
        # Skip test files rejected by myutils.hasColumn on column 1
        # (presumably the word forms) at a 0.1 threshold.
        if not myutils.hasColumn(test, 1, threshold=.1):
            continue

        if UDdir in myutils.hasNoSplits and setting.endswith('True') and 'attention' not in outName:
            rows.append((UDdir, '---', '---'))
            continue

        outPath = _prediction_path(udVersion, model, setting, UDdir)
        if outPath is None:
            rows.append((UDdir, 0.0, 0.0))
            continue
        if outName == 'single-task':
            # Argument list instead of a shell string: no shell parsing of
            # the path.
            subprocess.run(['python3', 'scripts/fix.py', outPath])

        print(test, outPath)
        tokScore, lasScore = _score(test, outPath)
        rows.append((UDdir, tokScore, lasScore))

    # The result files are opened only once every score is known, so a crash
    # mid-run does not leave truncated files behind. `with` guarantees they
    # are closed.
    lasPath = 'results/cross-las.' + outName + '-' + udVersion + '.csv'
    tokPath = 'results/cross-tok.' + outName + '-' + udVersion + '.csv'
    with open(lasPath, 'w', encoding='utf-8') as lasOut, open(tokPath, 'w', encoding='utf-8') as tokOut:
        for treebank, tokScore, lasScore in rows:
            lasOut.write(treebank + '\t' + conv(lasScore) + '\n')
            tokOut.write(treebank + '\t' + conv(tokScore) + '\n')

# Runs of the multilingual, multi-task mBERT model, evaluated for every UD
# version: (output label, setting). The 'True' setting uses the new splits.
_RUNS = (
    ('multi-ling', 'False'),
    ('multi-ling+split', 'True'),
)

for udVersion in myutils.udVersions:
    for runName, splitSetting in _RUNS:
        eval(udVersion, runName, 'multi.bert-base-multilingual-cased', splitSetting)



