import argparse
import numpy as np
import seaborn as sb; sb.set()
from random import shuffle
import os
import tqdm
import matplotlib.pyplot as plt
from argparse import Namespace
import constants as c

"""
The CKA implementation is taken from here:
https://colab.research.google.com/github/google-research/google-research/blob/master/representation_similarity/Demo.ipynb#scrollTo=MkucRi3yn7UJ
"""


def center_gram(gram, unbiased=False):
  """Center a symmetric Gram matrix.

  This is equivalent to centering the (possibly infinite-dimensional) features
  induced by the kernel before computing the Gram matrix.

  Args:
    gram: A num_examples x num_examples symmetric matrix (any array-like).
      Non-floating (e.g. integer) matrices are promoted to float64.
    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
      estimate of HSIC. Note that this estimator may be negative. Requires
      num_examples > 2, otherwise the (n - 2) denominator below is zero and the
      result contains inf/nan.

  Returns:
    A symmetric matrix with centered columns and rows. The input is not
    modified.

  Raises:
    ValueError: If `gram` is not symmetric.
  """
  gram = np.asarray(gram)
  if not np.allclose(gram, gram.T):
    raise ValueError('Input must be a symmetric matrix.')
  # Work on a floating-point copy: the in-place subtractions below cannot cast
  # float64 means back into an integer array (NumPy raises a casting error),
  # and the caller's matrix must not be mutated.
  if np.issubdtype(gram.dtype, np.inexact):
    gram = gram.copy()
  else:
    gram = gram.astype(np.float64)

  if unbiased:
    # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
    # L. (2014). Partial distance correlation with methods for dissimilarities.
    # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
    # stable than the alternative from Song et al. (2007).
    n = gram.shape[0]
    np.fill_diagonal(gram, 0)
    means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
    means -= np.sum(means) / (2 * (n - 1))
    gram -= means[:, None]
    gram -= means[None, :]
    np.fill_diagonal(gram, 0)
  else:
    # Double centering: subtract row and column means, adding back the grand
    # mean (folded into `means` via the "- mean(means) / 2" term).
    means = np.mean(gram, 0, dtype=np.float64)
    means -= np.mean(means) / 2
    gram -= means[:, None]
    gram -= means[None, :]

  return gram


def gram_linear(x):
  """Return the linear-kernel Gram matrix of the rows of `x`.

  Entry (i, j) of the result is the dot product of examples i and j.

  Args:
    x: A num_examples x num_features matrix of features.

  Returns:
    A num_examples x num_examples Gram matrix of examples.
  """
  # Inner products of every pair of rows: X @ X^T.
  return np.dot(x, x.T)


def compute_cka_score(gram_x, gram_y, debiased=False):
  """Return the centered kernel alignment (CKA) of two Gram matrices.

  Both matrices are centered with `center_gram`; CKA is then the Frobenius
  inner product of the centered matrices divided by the product of their
  Frobenius norms.

  Args:
    gram_x: A num_examples x num_examples Gram matrix.
    gram_y: A num_examples x num_examples Gram matrix.
    debiased: Use unbiased estimator of HSIC. CKA may still be biased.

  Returns:
    The value of CKA between X and Y.
  """
  centered_x, centered_y = (
      center_gram(g, unbiased=debiased) for g in (gram_x, gram_y))

  # HSIC would additionally be divided by (n-1)**2 (biased) or n*(n-3)
  # (unbiased); that factor appears in numerator and denominator alike and
  # cancels, so it is omitted.
  numerator = np.dot(centered_x.ravel(), centered_y.ravel())
  denominator = np.linalg.norm(centered_x) * np.linalg.norm(centered_y)
  return numerator / denominator


def load_emb_vocab(path_emb, path_vocab):
  """Load an embedding matrix and the vocabulary that indexes it.

  Args:
    path_emb: Path to a file readable by `np.load` (e.g. a ``.npy`` array).
      Presumably row i holds the vector of the i-th vocabulary line — TODO
      confirm against the files' producer.
    path_vocab: Path to a UTF-8 text file with one term per line.

  Returns:
    A tuple (emb, vocab2id): the loaded array and a dict mapping each
    whitespace-stripped term to its 0-based line index. If a term occurs more
    than once, its last line index wins.
  """
  with open(path_emb, "rb") as f:
    emb = np.load(f)
  # Decode explicitly as UTF-8: vocabularies may contain non-ASCII terms
  # (e.g. Russian) and the platform default encoding is not guaranteed to be
  # UTF-8, which would crash or garble lookups against the BLI dictionary.
  with open(path_vocab, "r", encoding="utf-8") as f:
    vocab2id = {line.strip(): i for i, line in enumerate(f)}
  return emb, vocab2id


def load_bli_dict(path_bli_dict, lang1, lang2):
  """Load a tab-separated bilingual lexicon (BLI dictionary).

  Args:
    path_bli_dict: Path to a UTF-8 file with one "source<TAB>target" pair per
      line.
    lang1: Language code of the source (first) column.
    lang2: Language code of the target (second) column.

  Returns:
    A list of (source, target) tuples in file order. Words are lower-cased,
    except in a column whose language is "ru", which is kept case-sensitive.
    Blank lines are skipped.

  Raises:
    ValueError: If a non-blank line does not contain exactly two
      tab-separated fields.
  """
  # handle special case russian (case-sensitive)
  lower_src = lang1 != "ru"
  lower_tgt = lang2 != "ru"

  word_pairs = []
  # Explicit UTF-8 so non-ASCII words decode identically on every platform.
  with open(path_bli_dict, "r", encoding="utf-8") as f:
    for lineno, raw in enumerate(f, start=1):
      text = raw.strip()
      if not text:
        continue  # tolerate blank / trailing empty lines
      fields = text.split("\t")
      if len(fields) != 2:
        raise ValueError("%s:%d: expected 2 tab-separated fields, got %d"
                         % (path_bli_dict, lineno, len(fields)))
      src, tgt = fields
      word_pairs.append((src.lower() if lower_src else src,
                         tgt.lower() if lower_tgt else tgt))
  return word_pairs


def align_matrices(word_pairs, x_vocab2id, x_emb, y_vocab2id, y_emb):
  """Build row-aligned embedding matrices for dictionary pairs in both vocabularies.

  Args:
    word_pairs: Iterable of (source_word, target_word) pairs.
    x_vocab2id: Maps source terms to row indices of `x_emb`.
    x_emb: Source embedding matrix (array-like, rows indexed by `x_vocab2id`).
    y_vocab2id: Maps target terms to row indices of `y_emb`.
    y_emb: Target embedding matrix (array-like, rows indexed by `y_vocab2id`).

  Returns:
    A tuple (X, Y) where row i of X and row i of Y are the embeddings of the
    i-th pair whose words are both in-vocabulary; pairs with an
    out-of-vocabulary word are skipped. X and Y are fresh copies, so callers
    may modify them in place. If no pair matches, both still keep the
    embedding width (shape (0, dim)).
  """
  x_rows = []
  y_rows = []
  for sw, tw in word_pairs:
    if sw in x_vocab2id and tw in y_vocab2id:
      x_rows.append(x_vocab2id[sw])
      y_rows.append(y_vocab2id[tw])
  # Integer-array (advanced) indexing always copies and yields shape
  # (len(rows), dim), including the empty case — unlike np.array([]), which
  # collapses to shape (0,).
  X = np.asarray(x_emb)[np.asarray(x_rows, dtype=np.intp)]
  Y = np.asarray(y_emb)[np.asarray(y_rows, dtype=np.intp)]
  return X, Y


def _l2_normalize_rows(mat):
  """Return a copy of `mat` with every row scaled to unit L2 norm.

  Rows whose norm is zero are left as all-zeros instead of becoming NaN
  (0 / 0), which would otherwise turn the whole CKA score into NaN.
  """
  norms = np.linalg.norm(mat, axis=-1, keepdims=True)
  # Out-of-place division also promotes integer inputs to float; the former
  # in-place `/=` raised a casting error for them.
  return mat / np.where(norms == 0, 1, norms)


def main(args, mixup_wordpairs, normalize_l2):
  """Compute linear CKA between two embedding spaces over a bilingual dictionary.

  Args:
    args: Namespace-like object with attributes l1, l1_emb, l1_vocab, l2,
      l2_emb, l2_vocab and bli_dict (paths and language codes).
    mixup_wordpairs: If True, randomly permute the target words across
      dictionary entries before alignment (a random-pairing baseline).
    normalize_l2: If True, scale every aligned embedding row to unit L2 norm
      before building the Gram matrices.

  Returns:
    The linear CKA score (biased HSIC estimator) between the aligned source
    and target embeddings.
  """
  x_emb, x_vocab2id = load_emb_vocab(args.l1_emb, args.l1_vocab)
  y_emb, y_vocab2id = load_emb_vocab(args.l2_emb, args.l2_vocab)
  word_pairs = load_bli_dict(args.bli_dict, args.l1, args.l2)

  if mixup_wordpairs:
    # Keep each source word in place but pair it with the target word of a
    # randomly permuted entry; every target word is used exactly once.
    perm = list(range(len(word_pairs)))
    shuffle(perm)
    word_pairs = [(pair[0], word_pairs[j][1]) for pair, j in zip(word_pairs, perm)]

  X, Y = align_matrices(word_pairs, x_vocab2id, x_emb, y_vocab2id, y_emb)

  if normalize_l2:
    X = _l2_normalize_rows(X)
    Y = _l2_normalize_rows(Y)

  cka_score = compute_cka_score(gram_linear(X), gram_linear(Y))
  return cka_score


if __name__ == "__main__":
  # basedir = c.AOC_DIR
  # blidir = c.BLI_DIR
  # base_plot_dir = c.CKA_PLOT_DIR
  #
  # lang_pairs = [("en", "de"), ("en", "fi"), ("en", "ru"), ("en", "tr")]
  # modes = ["AOC", "ISO"]
  # encoders = ["monolingual", "multilingual"]
  # y_axis = ["L0", "L1", "L2", "L3", "L4", "L5", "L6", "L7", "L8", "L9", "L10", "L11", "L12", "AVER"]
  #
  # for do_shuffle in [True, False]:
  #   plot_dir = base_plot_dir + "random=%s/" % str(do_shuffle)
  #   print("generate random results: %s" % str(do_shuffle))
  #   for do_l2_normalization in [True, False]:
  #     plot_dir += "l2=%s/" % str(do_l2_normalization)
  #     print("generate results l2-normalization: %s" % str(do_l2_normalization))
  #     os.makedirs(plot_dir, exist_ok=True)
  #     for i, (l1, l2) in enumerate(lang_pairs):
  #       setting2result = {}
  #       all_cka_scores = []
  #       plot_x_axis = []
  #       for iso_aoc in modes:
  #         for mbert_mono in encoders:
  #           layer_cka_scores = []
  #           params = {}
  #           for layer in tqdm.tqdm(range(13), total=13):
  #             path = basedir + "%s/%s/" % (iso_aoc, mbert_mono)
  #             params["l1"] = l1
  #             params["l1_emb"] = path + "%s/%s_%s.npy" % (l1, l1, str(layer))
  #             params["l1_vocab"] = path + "%s/%s.vocab" % (l1, l1)
  #             params["l2"] = l2
  #             params["l2_emb"] = path + "%s/%s_%s.npy" % (l2, l2, str(layer))
  #             params["l2_vocab"] = path + "%s/%s.vocab" % (l2, l2)
  #             params["bli_dict"] = blidir + "%s-%s/yacle.cka.freq.7k.%s-%s.tsv" % (l1, l2, l1, l2)
  #             cka_input = Namespace(**params)
  #             cka = main(cka_input, do_shuffle, do_l2_normalization)
  #             layer_cka_scores.append(cka)
  #           avg = np.mean(layer_cka_scores)
  #           layer_cka_scores.append(avg)
  #           setting = "%s_%s_%s_%s" % (l1, l2, iso_aoc, mbert_mono)
  #           setting2result[setting] = layer_cka_scores
  #           all_cka_scores.append(layer_cka_scores)
  #           print("done with %s\n" % setting)
  #
  #           label = "mBERT-" if mbert_mono == "multilingual" else "mono-"
  #           label += iso_aoc
  #           plot_x_axis.append(label)
  #
  #       data = np.array(all_cka_scores).transpose()
  #       sb.heatmap(data, cmap="YlGnBu", yticklabels=y_axis, xticklabels=plot_x_axis, annot=True, annot_kws={"size": 10})
  #       plt.title("%s-%s" % (l1, l2), pad=20)
  #       plt.savefig("%s/%s-%s.png" % (plot_dir, l1, l2))
  #       plt.clf()
  #       print("done with %s-%s\n" % (l1, l2))

  parser = argparse.ArgumentParser()
  parser.add_argument("-l1", type=str)
  parser.add_argument("-l1_emb", type=str)
  parser.add_argument("-l1_vocab", type=str)
  parser.add_argument("-l2", type=str)
  parser.add_argument("-l2_emb", type=str)
  parser.add_argument("-l2_vocab", type=str)
  parser.add_argument("-bli_dict", type=str)
  args = parser.parse_args()
  cka_score = main(args, mixup_wordpairs=False, normalize_l2=False)
  print(str(cka_score))


