import myutils

import matplotlib as mpl
import matplotlib.patches as mpatches
# Select the 'Agg' backend; this is done before pyplot is imported below.
mpl.use('Agg')
import matplotlib.pyplot as plt
# Apply the project's style sheet; the path is given relative to the
# directory the script is started from.
plt.style.use('scripts/rob.mplstyle')
import statistics

def conv(score):
    """Format a float score with four decimals; anything else becomes ''.

    Args:
        score: Value to render.

    Returns:
        ``'{:.4f}'.format(score)`` when ``score`` is a ``float`` (including
        instances of ``float`` subclasses), otherwise the empty string.
    """
    # isinstance instead of an exact type comparison: float subclasses are
    # formatted too, while ints, bools, strings and None still yield ''.
    if isinstance(score, float):
        return '{:.4f}'.format(score)
    return ''


def load_sota(task):
    """Read ``results/<task>-sota.csv`` into a dict.

    The file is read as tab-separated text. Its first line is skipped
    (presumably a header row) and blank lines are ignored. For every other
    line the first column becomes the key and the second column the value.

    Args:
        task: Task name used to build the file name.

    Returns:
        dict mapping first-column strings to second-column strings; the
        values are not converted to numbers.

    Raises:
        IndexError: if a non-blank line has fewer than two columns.
    """
    results = {}
    # The context manager closes the file even when a malformed line raises.
    with open('results/' + task + '-sota.csv') as handle:
        next(handle, None)  # skip the first line; harmless on an empty file
        for line in handle:
            stripped = line.strip()
            if not stripped:
                # A blank line (e.g. at the end of the file) carries no data.
                continue
            tok = stripped.split('\t')
            results[tok[0]] = tok[1]
    return results

def load(name):
    """Read ``results/<name>.csv`` as two parallel lists.

    The file is read as tab-separated text; every line is used (no line is
    skipped as a header) except blank ones, which are ignored.

    Args:
        name: File name without the ``results/`` prefix and ``.csv`` suffix.

    Returns:
        Tuple ``(treebanks, scores)``: the first and the second column of
        each line, both as lists of strings in file order.

    Raises:
        IndexError: if a non-blank line has fewer than two columns.
    """
    treebanks = []
    scores = []
    # The context manager closes the file even when a malformed line raises.
    with open('results/' + name + '.csv') as handle:
        for line in handle:
            if not line.strip():
                # A blank line carries no data.
                continue
            treebanks.append(line.split('\t')[0])
            scores.append(line.strip().split('\t')[1])
    return treebanks, scores
import copy

def boldHighest(row):
    """Return a copy of ``row`` with its highest score(s) in LaTeX bold.

    ``row[0]`` is treated as a label and left alone; every later cell must
    be a string that ``float()`` accepts. Each cell whose numeric value
    equals the maximum is wrapped in ``\\textbf{...}``, so ties are all
    highlighted. The argument itself is not modified.
    """
    highlighted = copy.deepcopy(row)
    best = max(float(cell) for cell in highlighted[1:])
    # The slice is a separate list, so assigning into `highlighted` while
    # walking over it is safe.
    for idx, cell in enumerate(highlighted[1:], start=1):
        if float(cell) == best:
            highlighted[idx] = '\\textbf{' + cell + '}'
    return highlighted

# Single module-level ScriptFinder instance; getScript() calls its
# guess_script() method for every treebank.
scriptFinder = myutils.ScriptFinder()

def getScript(path):
    """Guess the writing script of a treebank from the start of its text file.

    The training file reported by ``myutils.getTrainDevTest`` is used; when
    that is the empty string, the test file is used instead. In the chosen
    path ``'conllu'`` is replaced by ``'txt'`` and the first 100 lines of
    that file are handed to ``scriptFinder.guess_script``.

    Args:
        path: Treebank directory name below
            ``data/ud-treebanks-v2.10.singleToken/``.

    Returns:
        Whatever ``scriptFinder.guess_script`` returns for the sample.
    """
    from itertools import islice

    train, _, test = myutils.getTrainDevTest('data/ud-treebanks-v2.10.singleToken/' + path)
    tgt = train if train != '' else test
    tgt = tgt.replace('conllu', 'txt')
    # NOTE(review): the file is opened with the platform default encoding;
    # confirm whether encoding='utf-8' should be passed explicitly.
    with open(tgt) as handle:
        # Read only the first 100 lines instead of loading the whole file.
        head = list(islice(handle, 100))
    # Each line still ends with its own newline; joining with '\n' keeps the
    # sample text identical to what was passed to guess_script before.
    return scriptFinder.guess_script('\n'.join(head))
    

def getSize(path):
    """Count the lines of a treebank's training file that start with a digit.

    Args:
        path: Treebank directory name below
            ``data/ud-treebanks-v2.10.singleToken/``.

    Returns:
        Number of lines in the training file whose first character is a
        digit; 0 when ``myutils.getTrainDevTest`` reports an empty training
        path.
    """
    train, _, _ = myutils.getTrainDevTest('data/ud-treebanks-v2.10.singleToken/' + path)
    if train == '':
        # getScript treats '' as "no training split"; such a treebank has no
        # training lines to count, so report 0 instead of failing in open('').
        return 0
    # The context manager closes the file once counting is done.
    with open(train) as handle:
        return sum(1 for line in handle if line[:1].isdigit())

def plot(ax, scoreList, x_offset, names, names_xaxis, show_stddev=False):
    """Draw one group of bars: the mean score of every setup.

    Args:
        ax: Matplotlib axes to draw on.
        scoreList: One row per treebank; each row holds one score per setup
            (any value ``float()`` accepts). The number of setups is taken
            from the first row.
        x_offset: x position of the first bar of the group. Legend labels
            are attached only when this equals 1, so every setup appears in
            the legend once.
        names: Setup identifiers; only ``names[0]`` is inspected.
        names_xaxis: Legend label for each setup.
        show_stddev: When true, draw a vertical black line over each bar
            spanning mean +/- the population standard deviation.

    Returns:
        None. An empty ``scoreList`` draws nothing.
    """
    if not scoreList:
        # A group without treebanks has no means to plot; leave it blank
        # instead of failing on scoreList[0].
        return
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    if names[0] != 'RB':
        # Drop the first colour of the cycle when there is no 'RB' setup
        # (presumably so the other setups keep the same colours in all plots).
        colors = colors[1:]
    bar_width = .2
    for setupIdx in range(len(scoreList[0])):
        allScores = [float(x[setupIdx]) for x in scoreList]
        mean = sum(allScores) / len(allScores)
        x = x_offset + setupIdx * bar_width
        bar_kwargs = {'label': names_xaxis[setupIdx]} if x_offset == 1 else {}
        ax.bar([x], [mean], bar_width, color=colors[setupIdx], **bar_kwargs)
        if show_stddev:
            # pstdev, not pvariance: the error bar is in the same unit as
            # the scores.
            stddev = statistics.pstdev(allScores)
            err_kwargs = {'label': 'stddev'} if x_offset == 1 and setupIdx == 0 else {}
            ax.plot([x, x], [mean - stddev, mean + stddev], color='black', **err_kwargs)

def _scoreValue(score):
    """Return ``score`` as a float for ranking; non-numeric text ranks lowest."""
    try:
        return float(score)
    except ValueError:
        # Placeholders such as '---' must never win against a real score.
        return float('-inf')


def makeTable(task, yrange, names, names_xaxis, legend_loc):
    """Plot mean scores per setup, grouped by treebank size and script.

    For every entry of ``names`` one score column is loaded with ``load``.
    The special name ``'RB'`` stands for the best score per treebank over
    the five rule-based result files. Treebanks whose name contains 'Sign'
    or 'Japan' are dropped. The remaining rows are grouped into All,
    Small/Medium/Large (by ``getSize``: < 20000, < 100000, otherwise) and
    Latin/Other (by ``getScript``). One bar group per category is drawn and
    the figure is saved as ``<task>.pdf`` in the working directory.

    The comparison against ``load_sota`` results is currently not part of
    the plot.

    Args:
        task: Task name; prefix of the result files and name of the PDF.
        yrange: ``(ymin, ymax)`` passed to ``ax.set_ylim``.
        names: Result-file identifiers, one per setup (or ``'RB'``).
        names_xaxis: Legend label for each setup.
        legend_loc: ``loc`` argument for the legend.
    """
    print(task)
    data = []
    for nameIdx, name in enumerate(names):
        if name == 'RB':
            rbScores = []
            for rbName in ['bert_basic', 'twitter', 'robert', 'simple', 'treebank']:
                path = 'Tokens.rulebased-' + rbName + '-bert-base-multilingual-cased-2.10'
                treebanks, tmp_scores = load(path)
                rbScores.append(tmp_scores)
            # Scores are strings: rank them numerically, because a plain
            # string max() is lexicographic and would prefer '99.5' to '100.0'.
            scores = [
                max((column[i] for column in rbScores), key=_scoreValue)
                for i in range(len(rbScores[0]))
            ]
        else:
            treebanks, scores = load(task + '.' + name)

        if nameIdx == 0:
            for score, treebank in zip(scores, treebanks):
                data.append([treebank, score])
        else:
            for i, score in enumerate(scores):
                if score == '---':
                    # Missing score: repeat the value of the previous setup.
                    data[i].append(data[i][-1])
                else:
                    data[i].append(score)

    data = [row for row in data if 'Sign' not in row[0] and 'Japan' not in row[0]]

    sizeScores = {'Small': [], 'Medium': [], 'Large': []}
    allScores = {'All': []}
    scriptScores = {'Latin': [], 'Other': []}
    for row in data:
        rowScores = row[1:]
        allScores['All'].append(rowScores)

        dataSize = getSize(row[0])
        if dataSize < 20000:
            sizeScores['Small'].append(rowScores)
        elif dataSize < 100000:
            sizeScores['Medium'].append(rowScores)
        else:
            sizeScores['Large'].append(rowScores)

        if getScript(row[0]) == 'Latin':
            scriptScores['Latin'].append(rowScores)
        else:
            scriptScores['Other'].append(rowScores)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=300)

    # Bar groups in x-axis order; group k starts at x = k + 1.
    groups = [
        ('All', allScores['All']),
        ('Small', sizeScores['Small']),
        ('Medium', sizeScores['Medium']),
        ('Large', sizeScores['Large']),
        ('Latin', scriptScores['Latin']),
        ('Other', scriptScores['Other']),
    ]
    for offset, (_, groupScores) in enumerate(groups, start=1):
        if groupScores:
            # A category without treebanks keeps its tick but gets no bars.
            plot(ax, groupScores, offset, names, names_xaxis)

    ax.set_ylim(yrange)

    ax.set_xticks([x + 1.3 for x in range(len(groups))], [label for label, _ in groups])
    plt.xticks(rotation=45)
    # Dashed separators between the All, size and script sections.
    ax.plot([1.8, 1.8], [0, 100], color='black', linestyle='dashed')
    ax.plot([4.8, 4.8], [0, 100], color='black', linestyle='dashed')

    leg = ax.legend(loc=legend_loc)
    leg.get_frame().set_linewidth(1.5)
    fig.savefig(task + '.pdf', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps it alive across calls.
    plt.close(fig)

# Model/version suffix shared by all result-file names used below.
v = 'bert-base-multilingual-cased-2.10'

# Tokenization scores: rule-based best (RB) next to the three model setups.
makeTable(
    task='Tokens',
    yrange=(90, 100),
    names=['RB', 'single-task-' + v, 'multi-task-' + v, 'multi-ling-' + v],
    names_xaxis=['RB', 'ST', 'MT', 'ML+MT'],
    legend_loc='lower right',
)

# LAS scores: the gold setup next to the two other model setups.
makeTable(
    task='LAS',
    yrange=(60, 90),
    names=['gold-' + v, 'multi-task-' + v, 'multi-ling-' + v],
    names_xaxis=['GOLD', 'MT', 'ML+MT'],
    legend_loc='lower right',
)

