# required imports
import sklearn
import transformers
import nltk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import csv
import termcolor
import torch
import scipy.optimize
from termcolor import colored
from scipy.spatial import distance
from sklearn.manifold import MDS
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from transformers import pipeline, AutoTokenizer, AutoModel
from transformers import GPTNeoXForCausalLM, AutoTokenizer, GPTNeoXModel
from torch.nn.functional import log_softmax
import os
import time
import math
from scipy import spatial

# Load the Castro et al. typicality norms once at import time. The path is
# relative to the current working directory, not to this file. ``df`` is the
# short alias that ``typicality_main`` reads its columns from.
df = castro_et_al_typicality = pd.read_csv("files/castro_et_al_typicality.csv")

def make_vector_pythia(text, model, tokenizer, device):
    """Return one pooled vector per hidden-state layer for ``text``.

    ``text`` is encoded with ``tokenizer.encode`` and run through ``model`` as a
    batch of one. Every entry of ``outputs.hidden_states`` except index 0 (for
    Hugging Face models that entry is the embedding output) is averaged over
    dimension 1. That is the token axis when hidden states are
    (batch, seq, dim) — TODO confirm for non-HF models. The first batch row is
    then flattened and converted to a NumPy array.

    Args:
        text: String to embed.
        model: Callable returning an object with a ``hidden_states`` sequence of
            tensors; it is called with ``output_hidden_states=True``.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        device: Device the input ids are moved to before the forward pass.

    Returns:
        list of numpy arrays, one per hidden state after index 0.
    """
    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(device)  # batch of one

    # Inference only: skip autograd bookkeeping. Hidden states are requested
    # explicitly so this works even if the model config does not enable them.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    vectors = []
    for layer_states in outputs.hidden_states[1:]:
        if layer_states.dim() > 1:
            layer_states = torch.mean(layer_states, 1)
        vectors.append(layer_states[0].flatten().cpu().detach().numpy())
    return vectors
def np_cosine_sim(a, b):
    """Return the cosine similarity of ``a`` and ``b``.

    This is the matrix product ``a @ b.T`` divided by the product of the
    Euclidean (Frobenius) norms of the two inputs. A zero-norm input yields
    nan/inf, with NumPy's usual runtime warning.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    dot = a @ b.T
    return dot / norm_product

def _pythia(model, tokenizer, prompt, context):
        prompt = prompt.lower()
        ret = {}
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0][:])
        input_ids = input_ids.to("cuda")
        token_logprobs = []
        logits = model(input_ids).logits
        all_tokens_logprobs = log_softmax(logits.double(), dim=2)
        for k in range(0, input_ids.shape[1]):
          token_logprobs.append(all_tokens_logprobs[:,k,input_ids[0,k]])

        token_logprobs = [lp.detach().cpu().numpy()[0] for lp in token_logprobs]
        i = len(tokenizer(context, return_tensors="pt").input_ids[0]) 
        outputs = sum(token_logprobs[i:])/(len(token_logprobs)-i)
        return np.exp(outputs)


def _spearman_frame(categories, human_scores, model_scores):
    """Build a per-category table of Spearman correlations.

    Args:
        categories: Category names, one per output row.
        human_scores: For each category, the list of human typicality ratings.
        model_scores: For each category, the model scores aligned with ``human_scores``.

    Returns:
        DataFrame with columns "Category", "Spearman's Correlation", "P-value".
    """
    corrs = []
    p_values = []
    for x, y in zip(human_scores, model_scores):
        corr, p_value = scipy.stats.spearmanr(x, y)
        corrs.append(corr)
        p_values.append(p_value)
    return pd.DataFrame(
        {"Category": categories, "Spearman's Correlation": corrs, "P-value": p_values},
        columns=["Category", "Spearman's Correlation", "P-value"],
    )


def typicality_main(model, tokenizer, model_hidden_state, directory, revision, device, model_name):
    """Correlate human typicality ratings with model scores and write Excel reports.

    Reads the module-level ``df`` columns "Category", "Response" and
    "Typicality_Scores". Each row with a non-NaN rating gets two kinds of score
    for its (Category, Response) pair:

    * the maximum ``np_cosine_sim`` between the category and response vectors
      over the first ``model_hidden_state`` entries returned by
      ``make_vector_pythia``;
    * the ``_pythia`` score of "a <response> is <category>" under zero-, one-,
      two- and three-shot prompt prefixes.

    For each category, Spearman's correlation between the ratings and each
    score type is computed. One Excel file per score type is written into
    ``directory/revision``, which is created if missing.

    Args:
        model: Model passed to ``make_vector_pythia`` and ``_pythia``.
        tokenizer: Tokenizer passed to the same helpers.
        model_hidden_state: Number of leading layer vectors compared for similarity.
        directory: Base output directory.
        revision: Subdirectory name under ``directory`` for this run's files.
        device: Device passed to ``make_vector_pythia``.
        model_name: Not used by this function; kept for interface compatibility.
    """
    categories = df['Category'].tolist()
    responses = df['Response'].tolist()
    typicality_scores = df['Typicality_Scores'].tolist()

    # Embed every distinct string once. A string that appears both as a
    # category and as a response shares one entry.
    vector_dict = {
        text: make_vector_pythia(text, model, tokenizer, device)
        for text in set(categories) | set(responses)
    }

    # Few-shot prefixes for the probe sentence. The literals are kept verbatim
    # so scores stay comparable with earlier runs.
    # NOTE(review): "a square is shape" lacks an article, and the three-shot
    # prefix has no trailing space unlike the others. Confirm both are intended.
    shot_contexts = {
        "zero_shot": " ",
        "one_shot": "a cat is a mammal\n ",
        "two_shot": "a cat is a mammal\n a square is shape\n ",
        "three_shot": "a cat is a mammal\n a square is shape\n a tomato is a vegetable\n",
    }

    # category -> values collected per rated row; insertion order is first appearance.
    results = {}
    for category, response, typicality in zip(categories, responses, typicality_scores):
        if math.isnan(typicality):
            continue  # rows without a human rating are not scored
        if category not in results:
            results[category] = {
                "typicality_val": [],
                "similarity_val": [],
                "log_probs": {shot: [] for shot in shot_contexts},
            }
        entry = results[category]

        category_vectors = vector_dict[category]
        response_vectors = vector_dict[response]
        similarity = max(
            np_cosine_sim(category_vectors[layer], response_vectors[layer])
            for layer in range(model_hidden_state)
        )
        sentence = "a " + response + " is " + category

        entry["typicality_val"].append(typicality)
        entry["similarity_val"].append(similarity)
        for shot, context in shot_contexts.items():
            entry["log_probs"][shot].append(
                _pythia(model, tokenizer, context + sentence, context)
            )

    keys = list(results)
    human = [results[key]["typicality_val"] for key in keys]

    # Compute every table before touching the filesystem.
    tables = {
        "df_typicality_extended.xlsx": _spearman_frame(
            keys, human, [results[key]["similarity_val"] for key in keys]
        )
    }
    for shot in shot_contexts:
        tables[f"df_typicality_prompt_{shot}.xlsx"] = _spearman_frame(
            keys, human, [results[key]["log_probs"][shot] for key in keys]
        )

    out_dir = os.path.join(directory, revision)
    os.makedirs(out_dir, exist_ok=True)
    for filename, table in tables.items():
        table.to_excel(os.path.join(out_dir, filename))