import sklearn
import transformers
import nltk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import csv
import termcolor
import torch
import scipy.optimize
from termcolor import colored
from scipy.spatial import distance
from sklearn.manifold import MDS
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from transformers import pipeline, AutoTokenizer, AutoModel
from transformers import GPTNeoXForCausalLM, AutoTokenizer, GPTNeoXModel
from transformers import LlamaTokenizer, LlamaForCausalLM
from multiprocessing import Pool
import time
import os
from effects.numeric_capabilities import numeric_effects_main
from effects.typicality import typicality_main
from effects.similarity_between_number_and_non_number_words import sim_num_nonNum_main
from effects.ravens import ravens_main
from effects.sat_turney_items import sat_tests_main

import lm_eval
# Model identifiers that lm_eval_main evaluates without passing a `revision`
# model argument (every other model is loaded at the requested revision).
compare_list = [
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-13b-hf",
    "berkeley-nest/Starling-LM-7B-alpha",
    "tiiuae/falcon-7b",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-1.8B",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-14B",
]


def accuracy_on_task(task_name, eval_model, template_name, num_fewshot, directory):
    """Evaluate ``eval_model`` on one task and return its accuracy.

    Predictions are written to
    ``<directory>/zeroshot/<task_name>/predictions.txt``; the containing
    directory is created if it does not exist yet.

    Args:
        task_name: Task identifier passed to ``lm_eval.get_task_list``; also
            used as the sub-directory name for the predictions file.
        eval_model: Model object forwarded to ``lm_eval.evaluate``.
        template_name: Single prompt template name selected for the task.
        num_fewshot: Number of few-shot examples forwarded to
            ``lm_eval.evaluate``.
        directory: Root output directory.

    Returns:
        The ``'acc'`` entry of the first element of ``results['results']``.
    """
    # The original code referenced an undefined ``task_title`` here, which
    # raised NameError on every call; the task name is the intended component.
    predictions_path = os.path.join(directory, "zeroshot", task_name, "predictions.txt")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.dirname(predictions_path), exist_ok=True)

    eval_task = lm_eval.get_task_list(task_name, template_names=[template_name])
    results = lm_eval.evaluate(model=eval_model, tasks=eval_task, seed=12,
                               num_fewshot=num_fewshot, predictions_path=predictions_path)
    return results['results'][0]['acc']


def lm_eval_main(model, tokenizer, model_hidden_state, directory, revision, device, model_name):
    """Run the ``lm_eval`` command-line tool on the ``blimp`` task for one model.

    Results are written under ``<directory>/<revision>/blimp``. For models
    listed in ``compare_list`` the ``revision`` model argument is omitted;
    for every other model the requested ``revision`` is passed through.

    Args:
        model: Unused by this function.
        tokenizer: Unused by this function.
        model_hidden_state: Unused by this function.
        directory: Root output directory; a ``<revision>/`` sub-directory is
            created inside it.
        revision: Model revision string; also part of the output and cache
            paths.
        device: Unused by this function.
        model_name: Model identifier passed as ``pretrained=``.

    Returns:
        None. A non-zero exit status (or a failure to launch the tool) is
        reported on stderr rather than raised, matching the original
        best-effort behaviour.
    """
    # NOTE(review): the unused parameters presumably keep this signature
    # parallel to the other effects entry points -- confirm against callers.
    import subprocess
    import sys

    directory = directory + "/" + revision + "/"
    os.makedirs(directory, exist_ok=True)
    cache_dir = "/storage/home/hcoda1/1/rshah441/scratch/" + model_name + "/step" + revision

    model_args = "pretrained=" + model_name
    if model_name not in compare_list:
        model_args += ",revision=" + revision
    model_args += ",cache_dir=" + cache_dir

    # An argv list (no shell) keeps spaces or shell metacharacters in the
    # model name, revision or directory from altering the command.
    command = [
        "lm_eval",
        "--model", "hf",
        "--model_args", model_args,
        "--tasks", "blimp",
        "--batch_size", "32",
        "--output_path", directory + "blimp",
    ]
    try:
        completed = subprocess.run(command, check=False)
    except OSError as err:
        # E.g. the lm_eval executable is not on PATH.
        print("lm_eval could not be started: " + str(err), file=sys.stderr)
        return
    if completed.returncode != 0:
        print("lm_eval exited with status " + str(completed.returncode)
              + " for model " + model_name, file=sys.stderr)
    

    