import os
import myutils
import json
import matplotlib as mpl
import matplotlib.patches as mpatches
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('scripts/rob.mplstyle')

# Maps each task name found in a model's scalars.json to a list with one
# entry per model (the value stored under that task in the file).
allData = {}

# Fixed configuration flags that are part of the model name looked up below.
splits = True
attention = True

## single-dataset models
for udPath in ['data/ud-treebanks-v2.2.singleToken/']:
    for UDdir in sorted(os.listdir(udPath)):
        # Only consider treebank directories (names starting with "UD").
        if not UDdir.startswith("UD") or not os.path.isdir(udPath + UDdir):
            continue
        train, dev, test = myutils.getTrainDevTest(udPath + UDdir)
        # Skip treebanks whose training file fails the column-1 check;
        # presumably these have (mostly) no word forms -- verify in myutils.
        if train != '' and not myutils.hasColumn(train, 1, threshold=.1):
            continue
        # The first entry of myutils.mlms is deliberately left out.
        for mlm in myutils.mlms[1:]:
            modelName = 'tok.' + mlm.replace('/', '_') + '.' + UDdir + '.' + str(splits) + '.' + str(attention)
            modelPath = myutils.getModel(modelName)
            # An empty path means no model was found under this name.
            if modelPath == '':
                continue
            scalarPath = modelPath.replace('model.pt', 'scalars.json')
            # Use a context manager so the file handle is closed right away
            # instead of leaking one open file per model.
            with open(scalarPath) as scalarFile:
                scalarData = json.load(scalarFile)
            for task in scalarData:
                allData.setdefault(task, []).append(scalarData[task])


# Number of weights plotted per model: on the x-axis these are labelled
# 'input' followed by 1..12.
NUM_LAYERS = 13
# Only these tasks are drawn; any other task in allData is ignored.
PLOTTED_TASKS = ('dependency', 'upos', 'tokenization')

fig, ax = plt.subplots(figsize=(8,5), dpi=300)
# Each plotted task is shifted horizontally by taskIdx * .2 so the violins of
# different tasks sit next to each other around every layer position.
taskIdx = -1
for task in allData:
    if task not in PLOTTED_TASKS:
        continue
    # Transpose: layers[i] holds the weight at position i of every model.
    layers = [[weights[layer_idx] for weights in allData[task]]
              for layer_idx in range(NUM_LAYERS)]
    print(task, len(layers[0]))
    x = [layer_idx + taskIdx * .2 for layer_idx in range(NUM_LAYERS)]
    ax.violinplot(layers, x, vert=True, showmeans=True)
    # Add a bar outside of scope for the legend (hacky..)
    ax.bar([16], [.2], label=task)
    taskIdx += 1

# Horizontal reference line at the weight every position would get if all
# positions were weighted equally.
ax.plot([-1, 14], [1 / NUM_LAYERS, 1 / NUM_LAYERS], color='black', label='uniform')
ax.set_ylabel('Weight')

# Tick positions come from NUM_LAYERS rather than from the loop variable
# `layers`, which is unbound when none of the plotted tasks is in allData.
ax.set_xticks(range(NUM_LAYERS))
ax.set_xticklabels(['input'] + list(range(1, NUM_LAYERS)))

ax.set_ylim((.025,.16))
# The x-limits also hide the legend-only bars drawn at x=16.
ax.set_xlim((-.75,12.75))

leg = ax.legend(loc='upper left')
leg.get_frame().set_linewidth(1.5)
fig.savefig('attentions.pdf', bbox_inches='tight')




