import os
import numpy as np
import seaborn as sb; sb.set()
import matplotlib.pyplot as plt
import constants as c
from analysis.cka import compute_cka_score as cka
from analysis.cka import gram_linear


# Bilingual word list: each non-blank line holds a tab-separated pair whose
# first column is taken as the English word and second as the Finnish word.
bli_file = c.BLI_DIR + "yacle.cka.freq.7k.en-fi.tsv"
with open(bli_file) as f:
  # Blank lines would split into a one-element list and break the pair[1]
  # lookup below, so skip them.
  word_pairs = [line.strip().split("\t") for line in f if line.strip()]
en_words = [pair[0] for pair in word_pairs]
fi_words = [pair[1] for pair in word_pairs]


def cka_layerwise_plots(lang, relevant_words, save_dir, normalize_l2, num_layers=13):
  """Plot and save a layer-vs-layer linear-CKA heatmap for one language.

  Reads the vocabulary ``<c.AOC_ISO_DIR><lang>/<lang>.vocab`` (one term per
  line; the line index is the term's row id) and the per-layer embedding
  matrices ``<c.AOC_ISO_DIR><lang>/<lang>_<k>.npy`` for k in
  ``range(num_layers)``. It keeps only the rows of `relevant_words` that occur
  in the vocabulary and computes linear CKA between every pair of layers. The
  resulting heatmap is written to ``<save_dir>/<lang>.png``.

  Args:
    lang: Language code used in the input paths and the output file name.
      The plot is titled "English" for "en" and "Finnish" otherwise.
    relevant_words: Words to restrict the embeddings to; words missing from
      the vocabulary are skipped.
    save_dir: Output directory; created if it does not exist.
    normalize_l2: If true, scale each embedding row to unit L2 norm before
      computing CKA.
    num_layers: Number of layer files to load (default 13, i.e. L0-L12).

  Raises:
    ValueError: If none of `relevant_words` occurs in the vocabulary.
  """
  # Create the output directory up front so a bad path fails before the
  # expensive loading/CKA work.
  os.makedirs(save_dir, exist_ok=True)

  emb_path = c.AOC_ISO_DIR
  with open(emb_path + "%s/%s.vocab" % (lang, lang)) as f:
    vocab2id = {term.strip(): i for i, term in enumerate(f)}
  relevant_term_ids = [vocab2id[term] for term in relevant_words if term in vocab2id]
  if not relevant_term_ids:
    raise ValueError("none of the relevant words occur in the %s vocabulary" % lang)

  layer_embs = []
  for k in range(num_layers):
    with open(emb_path + "%s/%s_%s.npy" % (lang, lang, str(k)), "rb") as f:
      all_embs = np.load(f)
    # Integer-array indexing returns a copy, so the in-place division below
    # does not modify all_embs.
    relevant_embs = all_embs[relevant_term_ids]
    if normalize_l2:
      relevant_embs /= np.linalg.norm(relevant_embs, axis=-1)[:, np.newaxis]
    layer_embs.append(relevant_embs)

  # Each layer's Gram matrix is used num_layers times, so build each one once.
  # NOTE(review): assumes `cka` does not modify its Gram-matrix arguments in
  # place (the previous loop already reused each row's Gram matrix across
  # columns) -- confirm in analysis.cka.
  grams = [gram_linear(embs) for embs in layer_embs]
  cka_scores = []
  for i, gram_i in enumerate(grams):
    cka_scores.append([cka(gram_i, gram_j) for gram_j in grams])
    print("done layer %s" % str(i))
  data = np.array(cka_scores)

  axis = ["L%d" % k for k in range(num_layers)]
  sb.heatmap(data, cmap="YlGnBu", yticklabels=axis, xticklabels=axis, annot=True, annot_kws={"size": 8})
  plt.title("English" if lang == "en" else "Finnish", pad=20)
  plt.yticks(rotation=0)
  plt.savefig("%s/%s.png" % (save_dir, lang))
  # Clear the figure so the next call starts from an empty canvas.
  plt.clf()


# Write L2-normalised layer-similarity heatmaps for both languages into one
# shared output directory.
plot_dir = c.CKA_PLOT_DIR + "layer_similarity_plots_l2/"
for lang_code, words in (("en", en_words), ("fi", fi_words)):
  cka_layerwise_plots(lang_code, words, plot_dir, True)
