
# required imports
import sklearn
import transformers
import nltk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import csv
import termcolor
import torch
import scipy.optimize
from termcolor import colored
from scipy.spatial import distance
from sklearn.manifold import MDS
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from transformers import pipeline, AutoTokenizer, AutoModel
from transformers import GPTNeoXForCausalLM, AutoTokenizer, GPTNeoXModel
import os

import random

def make_vector_pythia(text, model, tokenizer, device):
  """Return one mean-pooled vector per hidden-state layer for `text`.

  The text is encoded with `tokenizer.encode`, sent through `model` as a
  batch of size 1, and every entry of `outputs.hidden_states` except the
  first is averaged over dimension 1 and converted to a NumPy array.

  Args:
    text: String to embed.
    model: Callable taking a tensor of input ids and returning an object
      with a `hidden_states` sequence.
      NOTE(review): for Hugging Face models this presumably requires
      `output_hidden_states=True` in the model config - confirm.
    tokenizer: Object exposing `encode(text)` that returns a list of ids.
    device: Device the input ids are moved to before the forward pass.

  Returns:
    A list of NumPy arrays, one for each hidden state at index >= 1.
  """
  token_ids = tokenizer.encode(text)
  input_ids = torch.tensor(token_ids).unsqueeze(0).to(device)  # Batch size 1

  # The vectors are only read, never back-propagated through, so skip
  # autograd bookkeeping instead of building a graph and detaching later.
  with torch.no_grad():
    outputs = model(input_ids)

  vector_array = []
  # Index 0 is skipped (presumably the embedding output - TODO confirm).
  for layer_index in range(1, len(outputs.hidden_states)):
    layer_state = outputs.hidden_states[layer_index]
    if layer_state.dim() > 1:
      # Average over dimension 1 (the token axis for batch-first input).
      layer_state = torch.mean(layer_state, 1)
    pooled = layer_state[0].flatten()
    vector_array.append(pooled.cpu().detach().numpy())
  return vector_array


def get_vectors_for_all_numbers(texts, model, tokenizer, device):
  """Map every string in `texts` to its per-layer vectors.

  Each entry is computed by `make_vector_pythia`; if a string occurs more
  than once, the value from its last occurrence is kept.
  """
  return {
    text: make_vector_pythia(text, model, tokenizer, device)
    for text in texts
  }


def run_all(all_text, model, tokenizer, device):
  """Compute the per-layer vectors for every string in `all_text`.

  Thin wrapper that forwards to `get_vectors_for_all_numbers` and returns
  its dictionary unchanged.
  """
  return get_vectors_for_all_numbers(all_text, model, tokenizer, device)
def sim_num_nonNum_main(model, tokenizer, model_hidden_state, directory, revision, device):
  """Compare mean cosine similarity of number and non-number tokens.

  Vectors are computed (via `run_all`) for 24 number strings (lower-case
  words, capitalised words and digits for 1-8) and for up to 10 randomly
  drawn vocabulary tokens. Cosine similarities are collected over the
  first `model_hidden_state` layers for three groups of pairs:
  number/number, number/non-number and non-number/non-number. The mean of
  each group is written to `<directory>/<revision>/df_sim_nesting.xlsx`.

  Args:
    model: Model forwarded to `run_all`.
    tokenizer: Tokenizer forwarded to `run_all`; also used to turn random
      ids into token strings with `convert_ids_to_tokens`.
    model_hidden_state: Number of leading layer vectors to compare.
    directory: Base output directory.
    revision: Sub-directory name appended to `directory`.
    device: Device forwarded to `run_all`.

  Returns:
    None. The result is only written to disk.
  """
  list_numbers_1 = ["one", "two", "three", "four", "five", "six", "seven", "eight"]
  list_numbers_2 = ["One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight"]
  list_numbers_3 = ["1", "2", "3", "4", "5", "6", "7", "8"]
  all_text = list_numbers_1 + list_numbers_2 + list_numbers_3

  # Ten random token ids from a hard-coded id range; duplicates and any
  # token that is itself one of the number strings are dropped, so fewer
  # than ten tokens may remain.
  random_tokens = [
    tokenizer.convert_ids_to_tokens(random.randint(1000, 46862))
    for _ in range(10)
  ]
  random_tokens = list(set(random_tokens).difference(all_text))

  vectors_non_numbers = run_all(random_tokens, model, tokenizer, device)
  vectors_numbers = run_all(all_text, model, tokenizer, device)

  def layer_similarities(vectors_a, vectors_b):
    # Cosine similarity (1 - cosine distance) for each compared layer.
    return [
      1 - distance.cosine(vectors_a[layer], vectors_b[layer])
      for layer in range(model_hidden_state)
    ]

  # NOTE(review): both within-group loops start the inner index at the
  # outer one, so every item is also compared with itself (similarity 1).
  # Kept as in the original; confirm whether self-pairs are intended.
  arr_nonNum = []
  for start, token_a in enumerate(random_tokens):
    for token_b in random_tokens[start:]:
      arr_nonNum.extend(
        layer_similarities(vectors_non_numbers[token_a], vectors_non_numbers[token_b])
      )

  arr_num_num = []
  for start, number_a in enumerate(all_text):
    for number_b in all_text[start:]:
      arr_num_num.extend(
        layer_similarities(vectors_numbers[number_a], vectors_numbers[number_b])
      )

  # Cross-group pairs: every random token against EVERY number string.
  # The inner loop used to be bounded by the number of random tokens, which
  # silently restricted the comparison to the first few number strings.
  arr_num_nonNum = []
  for token in random_tokens:
    for number in all_text:
      arr_num_nonNum.extend(
        layer_similarities(vectors_non_numbers[token], vectors_numbers[number])
      )

  mean_arr_nonNum = np.mean(arr_nonNum)
  mean_arr_num_nonNum = np.mean(arr_num_nonNum)
  mean_arr_num_num = np.mean(arr_num_num)

  # Row order of the means matches the order of 'categories'.
  details = {
    'categories': ['Number - Number', 'Number - Non Number', 'Non Number - Non Number'],
    'Mean Similarities': [mean_arr_num_num, mean_arr_num_nonNum, mean_arr_nonNum],
  }
  df_sim_nesting = pd.DataFrame(details)

  directory = directory + "/" + revision + "/"
  # exist_ok avoids the check-then-create race of a separate exists() test.
  os.makedirs(directory, exist_ok=True)
  df_sim_nesting.to_excel(directory + "df_sim_nesting.xlsx")