import glob
import pickle
import torch
import scipy
import numpy as np
from torch import nn
import transformers
from tqdm import tqdm
from src.steering_layer import SteeringLayer
from utils.steering_vector_loader import load_steering_vectors
import utils.dataset_loader as dsl



####
### IMPORTANT: use ONLY training-set activations when computing the
### steering vectors (SVs)!
###

# --- Paths and experiment configuration --------------------------------------
# Folder passed to `from_pretrained` for both the model and the tokenizer.
ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/alpaca_7b"
# ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/llama/huggingface/7B"

# Folder with the pickled activation dumps; only files whose name starts with
# "shakespear" are picked up.
# Earlier runs pointed these two names at
# "/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/20-04_24-04_trainings";
# that value was always overwritten before being read, so it is no longer
# assigned (and no longer globbed) here.
STEERING_VECTOR_PATH = "/localdata1/EmEx/activations/"
STEERING_VECTOR_FILES = glob.glob(f'{STEERING_VECTOR_PATH}/shakespear*')

# Indices of the decoder layers whose MLP is replaced by a SteeringLayer.
INSERTION_LAYERS = [18, 19, 20]

device = torch.device('cuda:1')

# --- Model setup --------------------------------------------------------------
alpaca_model = transformers.AutoModelForCausalLM.from_pretrained(ALPACA_WEIGHTS_FOLDER).to(device)
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained(ALPACA_WEIGHTS_FOLDER)
for layer in INSERTION_LAYERS:
    # Swap the layer's MLP for a SteeringLayer built from the original MLP.
    alpaca_model.model.layers[layer].mlp = SteeringLayer(alpaca_model.model.layers[layer].mlp)

# Shakespeare dataset; rows are looked up by position and supply the
# 'sentiment', 'sample' and 'dataset' fields read below.
df_shake = dsl.load_shakespeare()

original_shakespeare = []  # entries with sentiment == 0
modern_shakespeare = []    # entries with sentiment == 1

# Sort every training-set record into one of the two lists above. Each stored
# item is [activations, target sentence, label].
for dump_path in STEERING_VECTOR_FILES:
    # NOTE: pickle.load can execute arbitrary code -- only read trusted dumps.
    with open(dump_path, 'rb') as dump_file:
        records = pickle.load(dump_file)

    for record in records:
        # record[0] is the row position in df_shake, record[2] the activations.
        row = df_shake.iloc[record[0]]
        sentiment = row['sentiment']
        sentence = row['sample']
        layer_activations = record[2]

        # Only training-set samples may contribute to the steering vectors.
        if row['dataset'] != 'train':
            continue

        bucket = modern_shakespeare if int(sentiment) else original_shakespeare
        bucket.append([layer_activations, sentence, sentiment])

# Build, per insertion layer, the mean activation of each class and the two
# mean-difference steering vectors (towards modern / towards original).
if not modern_shakespeare or not original_shakespeare:
    # np.mean over an empty list would silently produce NaN steering vectors.
    raise RuntimeError(
        "No training-set activations found for at least one class "
        f"(modern: {len(modern_shakespeare)}, original: {len(original_shakespeare)}); "
        "cannot build steering vectors."
    )

modern_mean = []
original_mean = []
sv_to_target_modern = []
sv_to_target_original = []
for layer in INSERTION_LAYERS:
    # NOTE(review): activations are indexed by the absolute layer number here,
    # whereas the older commented-out code further down indexes them by
    # position (0, 1, 2) -- confirm the dumps are keyed by layer number.
    # Each class mean is computed once and reused for both difference vectors.
    layer_modern_mean = np.mean([x[0][layer] for x in modern_shakespeare], 0)
    layer_original_mean = np.mean([x[0][layer] for x in original_shakespeare], 0)

    modern_mean.append(layer_modern_mean)
    original_mean.append(layer_original_mean)
    sv_to_target_modern.append(layer_modern_mean - layer_original_mean)
    sv_to_target_original.append(layer_original_mean - layer_modern_mean)

# Number of training samples in the "original" class.
print(len(original_shakespeare))

##################################b####################################
# positive_mean = []
# negative_mean = []

# pos_layer_15 = [a[0] for a in pos_actis]
# pos_layer_16 = [a[1] for a in pos_actis]
# pos_layer_17 = [a[2] for a in pos_actis]
# neg_layer_15 = [a[0] for a in neg_actis]
# neg_layer_16 = [a[1] for a in neg_actis]
# neg_layer_17 = [a[2] for a in neg_actis]

# positive_mean = [np.mean(pos_layer_15,0),np.mean(pos_layer_16,0),np.mean(pos_layer_17,0)]
# negative_mean = [np.mean(neg_layer_15,0),np.mean(neg_layer_16,0),np.mean(neg_layer_17,0)]
# sv_to_target_positive = [positive_mean[0] - negative_mean[0], positive_mean[1] - negative_mean[1], positive_mean[2] - negative_mean[2]]
# sv_to_target_negative = [negative_mean[0] - positive_mean[0], negative_mean[1] - positive_mean[1], negative_mean[2] - positive_mean[2]]

# with open('pos_acti.pkl', 'wb') as f:
#     pickle.dump(pos_actis,f)

# with open('neg_acti.pkl', 'wb') as f:
#     pickle.dump(neg_actis,f)
####################################e########################################


# Instruction to complete, wrapped in the instruction/response prompt below.
user_input = "Write a short poem."

# The prompt is assembled from three fragments: preamble, task description
# and the instruction/response scaffold around the user input.
prompt_parts = (
    "Below is an instruction that describes a task. ",
    "Write a response that appropriately completes the request.\r\n\r\n",
    f"### Instruction:\r\n{user_input}\r\n\r\n### Response:",
)
input_text = "".join(prompt_parts)

# Tokenize once as PyTorch tensors and move the encoding to the model device.
encoded = alpaca_tokenizer(input_text, return_tensors="pt")
input_tokens = encoded.to(device)


#############b#########
# with open('/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/positive_steering_[15,16,17]_sparse.pkl', 'rb') as f:
#     sparse_positive_sv = torch.load(f)[4]
# with open('/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/negative_steering_[15,16,17]_sparse.pkl', 'rb') as f:
#     sparse_negative_sv = torch.load(f)[4]
#############e#########


def _run_steering_sweep(description, steering_vectors, lambdas):
    """Generate and print one completion per steering strength.

    For every value in ``lambdas``, the ``steering_vector`` of each layer
    listed in ``INSERTION_LAYERS`` is set to ``lmbda * steering_vectors[n]``
    (wrapped in an ``nn.Parameter`` on ``device``). The module-level prompt
    ``input_tokens`` is then completed with ``max_length=150`` and the decoded
    text is printed with the prompt text removed.

    Args:
        description: Label used in the printed header line.
        steering_vectors: One vector per entry of ``INSERTION_LAYERS``, in the
            same order.
        lambdas: Iterable of scaling factors applied to the steering vectors.
    """
    for lmbda in lambdas:
        for n, layer in enumerate(INSERTION_LAYERS):
            scaled_vector = lmbda * torch.tensor(steering_vectors[n])
            alpaca_model.model.layers[layer].mlp.steering_vector = nn.Parameter(scaled_vector.to(device))
        gen_tokens = alpaca_model.generate(input_tokens.input_ids, max_length=150)
        print(f"{description} steering vector percentage: {lmbda}")
        print(f"Generated sentence: {alpaca_tokenizer.batch_decode(gen_tokens)[0].replace(input_text,'')}")


# Set to True to also run the sweep towards modern English. With the default
# the script stops right after the first sweep.
RUN_MODERN_SWEEP = False

_run_steering_sweep("Original Shakespeare", sv_to_target_original, np.linspace(0, 1.3, 13))

if not RUN_MODERN_SWEEP:
    # Explicit early termination; does not depend on the `exit` helper that
    # the `site` module injects for interactive use.
    raise SystemExit

_run_steering_sweep("Modern Shakespeare", sv_to_target_modern, np.linspace(0, 2, 21))

