# Path prefix prepended to the SAT analogy dataset location further down in
# this file; left empty, that path is resolved against the working directory.
home_directory = ""
# NOTE(review): not referenced in the visible part of this file — presumably
# an access token for gated Llama checkpoints; confirm before relying on it.
llama_token = ""

import sklearn
import transformers
import nltk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import csv
import termcolor
import torch
import scipy.optimize
from termcolor import colored
from scipy.spatial import distance
from sklearn.manifold import MDS
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from transformers import pipeline, AutoTokenizer, AutoModel
from transformers import GPTNeoXForCausalLM, AutoTokenizer, GPTNeoXModel
import os
import math
from torch.nn.functional import log_softmax
# Location of the SAT analogy question file, relative to ``home_directory``.
filename = home_directory + 'Mapping-Cognitive-Development-of-Humans-to-LLMs/effects/SAT-analogies-Turney.txt'
# Solved analogies used as few-shot demonstrations; ``prompt`` takes the last k.
k_shot_prompt = ["walk is to legs as chew is to mouth\n","topaz is to yellow as amethyst is to purple\n", "reinforce is to stronger as erode is to weaker\n"]
with open(filename, encoding="utf-8") as f:
    # Trim surrounding whitespace: a trailing newline would otherwise leave an
    # empty last line that gets read as the answer key, and
    # 'abcde'.index('') silently evaluates to 0.
    contents = f.read().strip()

chunks = contents.split('\n\n')  # Split on blank lines.

# Each parsed item is a dict with:
#   'question': the stem word pair,
#   'answers' : up to five candidate word pairs,
#   'correct' : index (0-4) of the right candidate.
items = []
chunks = [block.split("\n") for block in chunks]
for chunk in chunks:
    chunk.pop(0)  # lose the description of the question's source
    correct_letter = chunk.pop().strip()  # which of the options
    if len(correct_letter) != 1:
        # str.index does substring matching, so '' or 'ab' would quietly map
        # to option 0 instead of being rejected.
        raise ValueError(
            "expected a single answer letter a-e, got %r" % correct_letter
        )
    correct_index = 'abcde'.index(correct_letter)
    # Every remaining line is "<word> <word> <tag>"; drop the trailing tag.
    chunk = [line.strip().split()[:-1] for line in chunk]
    question = chunk.pop(0)
    chunk = chunk[:5]
    item = {'question': question,
            'answers': chunk,
            'correct': correct_index}
    items.append(item)

# Every distinct word that appears in a stem or a candidate pair.
make_list_of_texts = []
for item in items:
    make_list_of_texts.append(item["question"][0])
    make_list_of_texts.append(item["question"][1])
    for pair in item["answers"]:
        make_list_of_texts.append(pair[0])
        make_list_of_texts.append(pair[1])
# dict.fromkeys removes duplicates like set() but keeps first-seen order, so
# the word list is identical from one run to the next.
make_list_of_texts = list(dict.fromkeys(make_list_of_texts))

def make_vector_pythia(text, model, tokenizer, device):
    """Embed *text* as one mean-pooled vector per hidden state of *model*.

    Args:
        text: String to embed.
        model: Callable invoked as ``model(input_ids, output_hidden_states=True)``
            whose result exposes a ``hidden_states`` sequence of tensors.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        A list of 1-D NumPy arrays, one for every entry of
        ``outputs.hidden_states`` except the first (presumably the embedding
        output — confirm against the model's documentation). Each array is the
        average over the token axis of that hidden state.
    """
    token_ids = tokenizer.encode(text)
    input_ids = torch.tensor(token_ids).unsqueeze(0).to(device)  # Batch size 1

    # Pure inference: without no_grad every call would build and keep an
    # autograd graph for the whole forward pass.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    vector_array = []
    for layer in outputs.hidden_states[1:]:
        if layer.dim() > 1:
            # Average over dim 1 (the token axis for a (batch, seq, hidden)
            # tensor) to get a single vector per batch element.
            layer = torch.mean(layer, 1)
        pooled = layer[0].flatten()
        vector_array.append(pooled.detach().cpu().numpy())
    return vector_array


def get_vectors_for_all_words(texts, model, tokenizer, device):
    """Map every entry of *texts* to its ``make_vector_pythia`` result.

    Returns a dict keyed by the text itself; a text that occurs more than once
    is embedded each time and the last result is kept.
    """
    return {
        text: make_vector_pythia(text, model, tokenizer, device)
        for text in texts
    }


def run_all(all_text, model, tokenizer, device):
    """Embed every text in *all_text*; thin wrapper over ``get_vectors_for_all_words``."""
    return get_vectors_for_all_words(all_text, model, tokenizer, device)

def three_cos_add(first_query, second_query, first_item, second_item):
    """Score a candidate pair with the additive vector-offset rule.

    Builds the vector ``second_query + first_item - first_query`` and returns
    its cosine similarity to ``second_item``.
    """
    source = np.array(first_query)
    source_partner = np.array(second_query)
    candidate = np.array(first_item)
    candidate_partner = np.array(second_item)

    predicted = (source_partner + candidate) - source
    return 1 - distance.cosine(predicted, candidate_partner)

def three_cos_mul(first_query, second_query, first_item, second_item):
    """Score a candidate pair with the multiplicative cosine rule.

    Returns ``cos(second_query, second_item) * cos(first_item, second_item)``
    divided by ``cos(first_query, second_item) + 0.00001``; the small constant
    keeps the denominator away from exactly zero.
    """
    def similarity_to_target(vector):
        return 1 - distance.cosine(vector, second_item)

    numerator = similarity_to_target(second_query) * similarity_to_target(first_item)
    denominator = similarity_to_target(first_query) + 0.00001
    return numerator / denominator


def concat_cos(first_query, second_query, first_item, second_item):
    """Cosine similarity between the concatenated stem pair and candidate pair.

    The two query vectors are joined end to end, likewise the two item vectors,
    and the similarity of the two joined vectors is returned.
    """
    stem_pair = np.concatenate([first_query, second_query], axis=0)
    candidate_pair = np.concatenate([first_item, second_item], axis=0)
    cosine_distance = distance.cosine(stem_pair, candidate_pair)
    return 1 - cosine_distance




def _pythia(model, tokenizer, prompt, context):
        prompt = prompt.lower()
        ret = {}
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0][:])
        input_ids = input_ids.to("cuda")
        token_logprobs = []
        logits = model(input_ids).logits
        all_tokens_logprobs = log_softmax(logits.double(), dim=2)
        for k in range(0, input_ids.shape[1]):
          token_logprobs.append(all_tokens_logprobs[:,k,input_ids[0,k]])

        token_logprobs = [lp.detach().cpu().numpy()[0] for lp in token_logprobs]
        i = len(tokenizer(context, return_tensors="pt").input_ids[0]) 
        outputs = sum(token_logprobs[i:])/(len(token_logprobs)-i)
        return np.exp(outputs)




def prompt(first_query, second_query, first_item, second_item, model, tokenizer, k):
    """Score the sentence "<a> is to <b> as <c> is to <d>" with *k* demonstrations.

    The last *k* entries of ``k_shot_prompt`` are prepended as context and the
    combined text is handed to ``_pythia`` together with that context.

    Returns:
        A tuple ``(score, sentence)``: the value returned by ``_pythia`` and
        the analogy sentence without the demonstrations.
    """
    # k == 0 needs its own branch: a [-0:] slice would select every
    # demonstration instead of none.
    context = "".join(k_shot_prompt[-k:]) if k != 0 else ""

    analogy_sentence = " ".join(
        [first_query, "is to", second_query, "as", first_item, "is to", second_item]
    )

    score = _pythia(model, tokenizer, context + analogy_sentence, context)
    return score, analogy_sentence


def sat_tests_main(model, tokenizer, model_hidden_state, directory, revision, device, model_name):
    """Evaluate *model* on the loaded SAT analogy items and save the accuracies.

    Seven strategies pick the best candidate pair for each item: three
    vector-based scores computed from the last entry returned by
    ``make_vector_pythia`` (``three_cos_add``, ``three_cos_mul``,
    ``concat_cos``) and ``prompt`` with 0, 1, 2 and 3 demonstrations. The
    fraction of items each strategy gets right is written to
    ``<directory>/<revision>/df_sat_as_to.xlsx`` with the columns
    ``categories`` and ``Values``.

    Args:
        model: Language model passed on to ``make_vector_pythia`` and ``prompt``.
        tokenizer: Tokenizer matching *model*.
        model_hidden_state: Not used by this function.
        directory: Base output directory.
        revision: Sub-directory name appended to *directory*.
        device: Device used when computing word vectors.
        model_name: Not used by this function.

    Raises:
        ValueError: If the module-level ``items`` list is empty.
    """
    if not items:
        # Fail before the expensive model passes instead of dividing by zero
        # at the very end.
        raise ValueError("no SAT analogy items were loaded; nothing to evaluate")

    # Only the final vector of each word is used below, so keep just that one
    # rather than holding every layer for the entire vocabulary.
    last_layer = {
        word: make_vector_pythia(word, model, tokenizer, device)[-1]
        for word in make_list_of_texts
    }

    vector_scorers = (
        ('three_cos_add_results', three_cos_add),
        ('three_cos_mul_results', three_cos_mul),
        ('concat_cos_results', concat_cos),
    )
    shot_settings = (
        ('prompt_results_as_to_zero_shot', 0),
        ('prompt_results_as_to_one_shot', 1),
        ('prompt_results_as_to_two_shot', 2),
        ('prompt_results_as_to_three_shot', 3),
    )
    categories = [label for label, _ in vector_scorers]
    categories += [label for label, _ in shot_settings]
    hits = {label: 0 for label in categories}

    for item in items:
        correct = item["correct"]
        ques_1 = item["question"][0]
        ques_2 = item["question"][1]
        answers = item["answers"]

        for label, scorer in vector_scorers:
            scores = [
                scorer(last_layer[ques_1], last_layer[ques_2],
                       last_layer[pair[0]], last_layer[pair[1]])
                for pair in answers
            ]
            if np.argmax(scores) == correct:
                hits[label] += 1

        for label, shots in shot_settings:
            # prompt() returns (score, sentence); only the score is ranked.
            scores = [
                prompt(ques_1, ques_2, pair[0], pair[1], model, tokenizer, shots)[0]
                for pair in answers
            ]
            if np.argmax(scores) == correct:
                hits[label] += 1

    total_count = len(items)
    details = {
        'categories': categories,
        'Values': [hits[label] / total_count for label in categories],
    }
    df_sat = pd.DataFrame(details)
    directory = directory + "/" + revision + "/"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(directory, exist_ok=True)
    df_sat.to_excel(directory + "df_sat_as_to.xlsx")
    return
