import glob
import pickle
import torch
import scipy
import numpy as np
from torch import nn
import transformers
from tqdm import tqdm
from src.steering_layer import SteeringLayer
from utils.steering_vector_loader import load_steering_vectors

ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/alpaca_7b"
# ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/llama/huggingface/7B"
STEERING_VECTOR_PATH = "/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/20-04_24-04_trainings"
STEERING_VECTOR_FILES = glob.glob(f'{STEERING_VECTOR_PATH}/*')

# Steering-vector files whose path contains the given layer list.
layer_15_16_17 = [x for x in STEERING_VECTOR_FILES if '15, 16, 17' in x]
layer_18_19_20_21_22 = [x for x in STEERING_VECTOR_FILES if '18, 19, 20, 21, 22' in x]

# Decoder layers whose MLP gets wrapped in a SteeringLayer.
INSERTION_LAYERS = [15, 16, 17]

device = torch.device('cuda:1')

alpaca_model = transformers.AutoModelForCausalLM.from_pretrained(ALPACA_WEIGHTS_FOLDER).to(device)
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained(ALPACA_WEIGHTS_FOLDER)
for layer in INSERTION_LAYERS:
    alpaca_model.model.layers[layer].mlp = SteeringLayer(alpaca_model.model.layers[layer].mlp)
# The wrappers are created after the model was moved to `device`; move again so any
# parameters/buffers SteeringLayer registers in __init__ also land on `device`.
# This is a no-op for tensors that are already there.
alpaca_model.to(device)

# Reuse the path constant rather than repeating the literal.
steering_vectors = load_steering_vectors(STEERING_VECTOR_PATH)
# Split by the last field of each entry (presumably a 1/0 class label — confirm
# against load_steering_vectors); sv[1] is used as the example text further down.
positive = [sv for sv in steering_vectors if sv[-1] == 1]
negative = [sv for sv in steering_vectors if sv[-1] == 0]



################################b########################################
def _last_token_hidden_states(examples):
    """Run each example's text through the model and keep last-token hidden states.

    For every entry ``sv`` in *examples*, ``sv[1]`` (with newlines removed) is
    tokenized and passed through ``alpaca_model``. The hidden state of the final
    token is taken at each index listed in ``INSERTION_LAYERS``.

    Returns a list with one record per example. Each record is a list of numpy
    arrays, one per entry of ``INSERTION_LAYERS`` and in the same order.
    """
    activations = []
    # Inference only: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        for sv in tqdm(examples):
            input_tokens = alpaca_tokenizer(sv[1].replace('\n', ''), return_tensors="pt").to(device)
            outputs = alpaca_model(input_tokens.input_ids, output_hidden_states=True, return_dict=True)
            # Access by name: positional indexing of a ModelOutput skips None fields,
            # so index 2 is only hidden_states when past_key_values is populated.
            # NOTE(review): per the HF docs, hidden_states[0] is the embedding output,
            # so hidden_states[i] is the output of decoder layer i-1 — confirm this
            # offset is intended relative to the SteeringLayer insertion points.
            activations.append(
                [outputs.hidden_states[layer][0][-1].cpu().numpy() for layer in INSERTION_LAYERS]
            )
    return activations


pos_actis = _last_token_hidden_states(positive)
neg_actis = _last_token_hidden_states(negative)
#################################e#######################################



# Per-layer class means of the loaded steering vectors, and their differences.
# x[0][n] is presumably the vector for the n-th insertion layer — confirm against
# load_steering_vectors.
# NOTE(review): all four lists are rebound from the activation means computed
# further down before they are used, so these values are currently discarded.
positive_mean = []
negative_mean = []
sv_to_target_negative = []
sv_to_target_positive = []
for n in range(len(INSERTION_LAYERS)):
    # Concatenate and average once per class instead of once per use.
    pos_mean = torch.mean(torch.cat([x[0][n] for x in positive]), 0)
    neg_mean = torch.mean(torch.cat([x[0][n] for x in negative]), 0)
    positive_mean.append(pos_mean)
    negative_mean.append(neg_mean)
    sv_to_target_negative.append(neg_mean - pos_mean)
    sv_to_target_positive.append(pos_mean - neg_mean)



##################################b####################################
# Per-layer mean activation for each class. Entry i of every pos_actis/neg_actis
# record corresponds to the i-th hidden-state index in INSERTION_LAYERS.
positive_mean = [np.mean([acts[i] for acts in pos_actis], 0) for i in range(len(INSERTION_LAYERS))]
negative_mean = [np.mean([acts[i] for acts in neg_actis], 0) for i in range(len(INSERTION_LAYERS))]
# Difference-of-means directions: negative_mean + sv_to_target_positive == positive_mean,
# and vice versa.
sv_to_target_positive = [pos - neg for pos, neg in zip(positive_mean, negative_mean)]
sv_to_target_negative = [neg - pos for pos, neg in zip(positive_mean, negative_mean)]

# Persist the raw per-example activations to disk.
with open('pos_acti.pkl', 'wb') as f:
    pickle.dump(pos_actis, f)

with open('neg_acti.pkl', 'wb') as f:
    pickle.dump(neg_actis, f)
####################################e########################################


user_input = "Write a short restaurant review!"
# Alpaca instruction prompt (no-input variant).
# NOTE(review): the Stanford Alpaca template uses plain "\n" line breaks; "\r\n"
# would feed carriage-return tokens the template never contains. This assumes the
# weights were fine-tuned with that template — confirm.
input_text = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    f"### Instruction:\n{user_input}\n\n### Response:"
)

input_tokens = alpaca_tokenizer(input_text, return_tensors="pt").to(device)
# Number of prompt tokens; generate() returns prompt + continuation for this model type.
prompt_length = input_tokens.input_ids.shape[1]


#############b#########
# with open('/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/positive_steering_[15,16,17]_sparse.pkl', 'rb') as f:
#     sparse_positive_sv = torch.load(f)[4]
# with open('/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/negative_steering_[15,16,17]_sparse.pkl', 'rb') as f:
#     sparse_negative_sv = torch.load(f)[4]
#############e#########
# To sweep the sparse vectors instead, pass them as `vectors` to _steering_sweep.


def _steering_sweep(label, vectors, max_length, steps=10):
    """Generate a continuation of the prompt at increasing steering strengths.

    For lmbda in 0, 1/steps, ..., 1 (inclusive), each SteeringLayer at
    INSERTION_LAYERS[n] gets ``steering_vector = lmbda * vectors[n]``. The
    continuation generated for ``input_tokens`` is then printed.

    Args:
        label: Prefix of the printed strength line ("Negative"/"Positive").
        vectors: One vector per insertion layer (numpy array or torch tensor).
        max_length: Passed to ``generate``; per the HF docs it includes the prompt.
        steps: Number of increments between zero and full strength.
    """
    for i in range(steps + 1):  # inclusive, so full strength (1.0) is also tried
        lmbda = i / steps
        for layer, vector in zip(INSERTION_LAYERS, vectors):
            # as_tensor accepts both numpy arrays and tensors without copying.
            alpaca_model.model.layers[layer].mlp.steering_vector = nn.Parameter(
                (lmbda * torch.as_tensor(vector)).to(device)
            )
        gen_tokens = alpaca_model.generate(input_tokens.input_ids, max_length=max_length)
        # Decode only the new tokens rather than stripping the prompt from the decoded
        # text, which relies on an exact decode round-trip and keeps special tokens.
        completion = alpaca_tokenizer.decode(gen_tokens[0][prompt_length:], skip_special_tokens=True)
        print(f"{label} steering vector percentage: {lmbda}")
        print(f"Generated sentence: {completion}")


# NOTE(review): the two sweeps use different max_length values (150 vs 250), so
# their generations are not length-matched — confirm this is intended.
_steering_sweep("Negative", sv_to_target_negative, max_length=150)
_steering_sweep("Positive", sv_to_target_positive, max_length=250)
