# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
import tensorflow as tf
############################################################
# Hard-coded example sequences: presumably token ids of a source sentence
# ("src") and of its reference -- TODO confirm. Every list ends with 2, which
# may be an end-of-sequence id (unverified).
# NOTE: these multi-line `code`/`ref` lists are overwritten by the single-line
# `code`/`ref` assignments further down, so only the later ones take effect.
code = [
41, 18511, 22158,  2594,  8763,  2606,  2594,  2891,  3291,
       31737,  2625,  2586, 13701,  4541,  7915,  6960,    14,  2605,
          41,  3029,  3163,  5438,  3444,  2610,  4532,  2795,    62,
        8445,  3102,  3167,  2581,  2594,  4498,  2672,  2795, 16308,
          62, 10029,  3789,  2758,  5272,    16,     2,

]
ref = [
2960, 28951,  2618,  2627, 14742,    14,  2717,  4541,    16,
        7998, 16135,    66, 23480,  2702,  3295,  4392,  2692,  3688,
        7721,    14, 11059,  3543, 12210,  4977, 19740,  3050,  5803,
        8805,  2614,  6007,    14,  3137,  2751, 10034,  7649,  6725,
          16,     2,
        ]
# Active example pair (these assignments replace the lists above).
code  = [2709, 3190, 2588, 20637, 2707, 30714, 16047, 80, 14, 2823, 2677, 14727, 6244, 2672, 14073, 2837, 7837, 3525, 2594, 21031, 14, 3606, 3429, 2594, 7329, 2588, 4843, 2581, 2594, 12306, 16, 2]
ref = [2784, 11254, 2647, 2693, 23573, 2785, 16047, 80, 8057, 14, 2618, 2679, 28020, 29481, 14, 2772, 2618, 44, 4370, 2682, 2620, 29222, 3042, 2794, 6361, 2773, 2855, 9794, 3670, 2598, 14381, 4130, 2639, 8559, 2831, 16, 2]
# code = ref  # uncomment to visualise with the reference sequence instead
# Expected number of steps (heatmap rows) in the loaded data.
step = 20
# Module-level placeholders. The functions below bind their own *local*
# variables named `vegetables`/`farmers`, so these globals stay empty;
# `df` is not used in the visible code.
vegetables = []
df = []
farmers = []
def sentence_vis(path, ax=None):
    """Draw a one-column heatmap of per-step sentence-level values from JSON.

    The JSON file must contain numbers that can be reshaped into a single
    column (e.g. a flat list with one value per step). Each row of the
    heatmap is one step. Rows are labelled from the number of rows down to 1,
    top to bottom.

    Args:
        path: Path of the JSON file to plot.
        ax: Matplotlib axes to draw on. Defaults to the current axes
            (``plt.gca()``), i.e. the axes most recently created by
            ``plt.subplots``. A new one is created if none exists.

    Returns:
        The axes the heatmap was drawn on.
    """
    with open(path, encoding="utf-8") as json_file:
        data = json.load(json_file)
    # Derive the number of rows from the data rather than the module-level
    # `step`, so files with any number of steps can be plotted.
    values = np.reshape(data, (-1, 1))
    n_steps = values.shape[0]
    farmers = list(range(n_steps, 0, -1))
    # Apply the seaborn theme/figure size before any figure may be created
    # by plt.gca() below.
    sns.set(rc={'figure.figsize': (16, 16)})
    if ax is None:
        ax = plt.gca()
    # Single column: the one-character string "s" ends up as the label of
    # that one column.
    sns.heatmap(values,
                cmap="YlGnBu",
                xticklabels="s",
                yticklabels=farmers,
                ax=ax,
                annot=True,
                square=True,
                annot_kws={"fontsize": 6},
                cbar=False)
    return ax
def token_vis(path, ax=None):
    """Draw a step-by-position heatmap of token-level values from JSON.

    The JSON file must hold a 2-D nested list. It is transposed (its two axes
    swapped) before plotting, so the file's second dimension becomes the
    heatmap rows. The file is presumably indexed [position][step]; TODO
    confirm. Columns are labelled 0 .. n_positions-1. Rows are labelled from
    n_steps down to 1, top to bottom.

    Args:
        path: Path of the JSON file to plot.
        ax: Matplotlib axes to draw on. Defaults to the current axes
            (``plt.gca()``), i.e. the axes most recently created by
            ``plt.subplots``. A new one is created if none exists.

    Returns:
        The axes the heatmap was drawn on.
    """
    with open(path, encoding="utf-8") as json_file:
        data = json.load(json_file)
    matrix = np.transpose(data, (1, 0))
    # Size both tick-label lists from the matrix itself. Using the global
    # `code` / `step` could disagree with the file and mislabel columns/rows.
    n_steps, n_positions = matrix.shape
    vegetables = list(range(n_positions))
    farmers = list(range(n_steps, 0, -1))
    # Apply the seaborn theme/figure size before any figure may be created
    # by plt.gca() below.
    sns.set(rc={'figure.figsize': (16, 16)})
    if ax is None:
        ax = plt.gca()

    # for token level ############################################
    sns.heatmap(matrix,
                cmap="YlGnBu",
                xticklabels=vegetables,
                yticklabels=farmers,
                ax=ax,
                annot=True,
                square=True,
                annot_kws={"fontsize": 6},
                cbar=False)
    return ax
#########examples##########
# Create the figure and axes before plotting. The helpers below draw on this
# axes, and the tick/label/layout post-processing uses `fig` and `ax`.
# The seaborn theme and figure size are applied first so they take effect
# on the new figure.
sns.set(rc={'figure.figsize': (16, 16)})
fig, ax = plt.subplots()

# sentence_vis("../universalTransformer/dec_cka_similarity_sentence.json")

# token_vis("/Users/barid/Desktop/lt/ut_6/enc_cka_similarity.json")

token_vis("../lazy_transformer/enc_halting_pro.json")
# sentence_vis("../lazy_transformer/enc_cka_similarity_sentence.json")

# Rotate the x tick labels so dense position labels stay readable.
plt.setp(ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor")
plt.xlabel("position")
plt.ylabel("step")
fig.tight_layout()
plt.show()
#######################################################