"""
Usage:
python3 -m fastchat.serve.api_lm_score --model ~/model_weights/llama-7b
"""
import argparse
import time
import csv
import tqdm
import os
import json

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
from transformers.generation.stopping_criteria import StoppingCriteriaList, LLamaQaStoppingCriteria
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from fastchat.conversation import conv_templates, SeparatorStyle
import argparse
from .truthfulqa import models
from .truthfulqa import metrics as eval_metrics
from .truthfulqa import utilities
# import openai
from .truthfulqa.configs import ENGINE_MAP
import warnings
import pandas as pd
import numpy as np
from .tuned_lens.nn import TunedLens, Unembed

class OpenEndedContrastiveEarlyExit:
    def __init__(self, model_name, device, num_gpus, tuned_lens_path=None, auth_token=None, max_memory=40, lora=None, **kwargs):
        self.model_name = model_name
        self.device = device
        self.num_gpus = num_gpus
        self.stopping_criteria = None
        self.max_memory = max_memory

        self.model, self.tokenizer = self.load_model(model_name, auth_token=auth_token, lora=lora)

        if tuned_lens_path is not None:
            self.unembed = Unembed(self.model)
            self.tuned_lens = TunedLens.from_unembed_and_pretrained(
                self.unembed,
                lens_resource_id=tuned_lens_path,
            )
            self.tuned_lens.layer_translators.to(self.unembed.unembedding.weight.device)
            print(f"Successfully loaded TunedLens from {tuned_lens_path} with {len(self.tuned_lens.layer_translators)} layers onto the device {self.unembed.unembedding.weight.device}.")
        else:
            self.tuned_lens = None
        
    def load_model(self, model_name, auth_token=None, lora=None):
        if 'gpt2' in model_name:
            tokenizer = GPT2Tokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
            model.cuda()
            return model, tokenizer
        if self.device == "cuda":
            kwargs = {"torch_dtype": torch.float16, "offload_folder": f"offload/{model_name}"}
            if self.num_gpus == "auto":
                kwargs["device_map"] = "auto"
            else:
                self.num_gpus = int(self.num_gpus)
                if self.num_gpus != 1:
                    kwargs.update({
                        "device_map": "auto",
                        "max_memory": {i: f"{self.max_memory}GiB" for i in range(self.num_gpus)},
                    })
        elif self.device == "cpu":
            kwargs = {}
        else:
            raise ValueError(f"Invalid device: {self.device}")
        
        # low_cpu_mem_usage = True if not '70b' in model_name else False
        if auth_token is not None:
            tokenizer = AutoTokenizer.from_pretrained(model_name if not 'vicuna' in model_name else 'huggyllama/llama-7b', token=auth_token)
            model = AutoModelForCausalLM.from_pretrained(model_name,
                # low_cpu_mem_usage=True, 
                token=auth_token, **kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name if not 'vicuna' in model_name else 'huggyllama/llama-7b')
            model = AutoModelForCausalLM.from_pretrained(model_name,
                # low_cpu_mem_usage=True, 
                **kwargs)

        if self.device == "cuda" and self.num_gpus == 1:
            model.cuda()

        if lora:
            from peft import PeftModel
            model = PeftModel.from_pretrained(
                model,
                lora,
                torch_dtype=torch.float16,
            )
        
        return model, tokenizer

    def rescale_logits(self, logits, bound):
        if 'mpt-30b' in self.model_name:
            max_value = torch.abs(logits).max()
            if max_value > bound:
                logits = logits * (bound / max_value)
            return logits
        else:
            return logits

    def set_stop_words(self, stop_words):
        """Register ``stop_words`` and build ``self.stopping_criteria`` from them.

        Each word is tokenized with a leading newline prepended. For model names
        containing "llama", the first three ids of that encoding are dropped
        before being handed to ``LLamaQaStoppingCriteria``. The stored
        ``self.stop_words`` is also read by ``generate(remove_stop_words=True)``.
        """
        self.stop_words = stop_words
        self.stopping_criteria = StoppingCriteriaList()
        # NOTE(review): the 3 dropped ids are presumably BOS plus the pieces for
        # the leading newline in LLaMA's tokenizer — confirm for the tokenizer in use.
        prefix_to_drop = 3 if 'llama' in self.model_name.lower() else 0
        ids_per_word = []
        for word in self.stop_words:
            word_ids = self.tokenizer.encode('\n' + word)[prefix_to_drop:]
            ids_per_word.append(word_ids)
            print("Added stop word: ", word, 'with the ids', word_ids, flush=True)
        self.stopping_criteria.append(LLamaQaStoppingCriteria(ids_per_word))

    def generate(self, input_text, input_text_keys=None, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, contrast_disagree_only=False, premature_temp=1.0, beta=None, disconnect_attn=None, disconnect_mlp=None, apply_early_norm=False, external_interpolation_factor=0.001, low_prob_percentile=0.1, steering_layers=None, selective_steering_n_heads=None, selective_steering_head_ids=None, selective_steering_layer_heads=None, return_attentions=False, reject_sampling_clf=None, chunk_size=None, num_candidates=None, attn_steer_factor=None, conversion_matrix=None, extra_prompt_length=None, **kwargs):
        with torch.no_grad():

            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            if verbose:
                print('MODEL INPUT LENGTH: {0}'.format(input_ids.shape[-1]))
            max_len = input_ids.shape[-1] + max_new_tokens

            if mode == 'vanilla':
                outputs = self.model.generate(inputs=input_ids, max_length=max_len, num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, tuned_lens=self.tuned_lens, disconnect_attn=disconnect_attn, disconnect_mlp=disconnect_mlp, apply_early_norm=apply_early_norm, output_attentions=return_attentions, **kwargs)
            elif mode == 'early_exit_contrastive':
                assert final_layer is not None, "final_layer must be specified"
                assert base_layer is not None, "base_layer must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=True,
                                    final_layer=final_layer, base_layer=base_layer,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, **kwargs)
            elif mode == 'dynamic_early_exit_contrastive':
                assert final_layer is not None, "final_layer must be specified"
                assert base_layers is not None, "base_layers must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=True,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, **kwargs,)
                critical_layer_dist = outputs.critical_layer_dist
            elif mode == 'attn_intervention':
                token_importance = None
                if input_text_keys is not None:
                    # generate token importance by checking whether the each token exist in the key_token_ids_set
                    key_token_ids = self.tokenizer(input_text_keys, return_tensors="pt").input_ids.to(self.device)
                    key_token_ids_set = set(key_token_ids[0].cpu().numpy())
                    token_importance = [[1 if token_id in key_token_ids_set else 0 for token_id in input_ids[0].cpu().numpy()]]
                    # convert to tensor
                    token_importance = torch.tensor(token_importance, dtype=input_ids.dtype, device=self.device)
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, attention_intervention_decoding=True, external_interpolation_factor=external_interpolation_factor, token_importance=token_importance, steering_layers=steering_layers, selective_steering_n_heads=selective_steering_n_heads, selective_steering_layer_heads=selective_steering_layer_heads, **kwargs,)
                critical_layer_dist = outputs.critical_layer_dist
            elif mode == 'attn_intervention_low_prob':
                token_importance = None
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, attention_intervention_decoding=True, external_interpolation_factor=external_interpolation_factor, token_importance=token_importance, low_prob_percentile=low_prob_percentile, steering_layers=steering_layers, selective_steering_n_heads=selective_steering_n_heads, selective_steering_head_ids=selective_steering_head_ids, selective_steering_layer_heads=selective_steering_layer_heads, **kwargs,)
                critical_layer_dist = outputs.critical_layer_dist
            elif mode == 'rejection_sampling':
                token_importance = None
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, attention_intervention_decoding=False, external_interpolation_factor=external_interpolation_factor, 
                                        token_importance=token_importance, low_prob_percentile=low_prob_percentile, 
                                        steering_layers=steering_layers, selective_steering_n_heads=selective_steering_n_heads, 
                                        selective_steering_head_ids=selective_steering_head_ids, 
                                        selective_steering_layer_heads=selective_steering_layer_heads, 
                                        extra_prompt_length=extra_prompt_length,
                                        attn_steer_factor=attn_steer_factor, 
                                        reject_sampling_clf=reject_sampling_clf, chunk_size=chunk_size, num_candidates=num_candidates, conversion_matrix=conversion_matrix, **kwargs,)
                critical_layer_dist = None
            elif mode == 'attention_steering':
                token_importance = None
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, apply_early_norm=apply_early_norm, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, 
                                        attention_intervention_decoding=False, 
                                        external_interpolation_factor=external_interpolation_factor, 
                                        token_importance=token_importance, 
                                        low_prob_percentile=low_prob_percentile, 
                                        steering_layers=steering_layers, 
                                        selective_steering_n_heads=selective_steering_n_heads, 
                                        selective_steering_head_ids=selective_steering_head_ids, 
                                        selective_steering_layer_heads=selective_steering_layer_heads, 
                                        reject_sampling_clf=reject_sampling_clf, 
                                        chunk_size=chunk_size, 
                                        num_candidates=num_candidates, 
                                        attn_steer_factor=attn_steer_factor, 
                                        extra_prompt_length=extra_prompt_length,
                                        **kwargs,)
                critical_layer_dist = None
            sequences, scores = outputs.sequences, outputs.scores

            # skip the tokens in the input prompt
            gen_sequences = sequences[:, input_ids.shape[-1]:][0, :]
            gen_arr = gen_sequences.cpu().numpy()

            output_str = self.tokenizer.decode(gen_sequences, skip_special_tokens=True)

            if verbose:
                print('MODEL OUTPUT: \n{0}'.format(output_str))

            if remove_stop_words:
                for stop_word in self.stop_words:
                    length_to_remove = len(stop_word)
                    if output_str[-length_to_remove:] == stop_word:
                        output_str = output_str[:-length_to_remove]
                output_str = output_str.strip()

        if self.device:
            torch.cuda.empty_cache()
        if not return_attentions:
            return output_str, (critical_layer_dist if mode == 'dynamic_early_exit_contrastive' else None)
        else:
            return output_str, outputs.attentions, gen_arr

    def generate_length(self, input_text, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, contrast_disagree_only=False, premature_temp=1.0, beta=None, **kwargs):
        with torch.no_grad():

            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            max_len = input_ids.shape[-1] + max_new_tokens
            input_length = input_ids.shape[-1]

            if mode == 'vanilla':
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=False,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, tuned_lens=self.tuned_lens, **kwargs)
            elif mode == 'early_exit_contrastive':
                assert final_layer is not None, "final_layer must be specified"
                assert base_layer is not None, "base_layer must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                    output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=True,
                                    final_layer=final_layer, base_layer=base_layer,
                                    top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, **kwargs)
            elif mode == 'dynamic_early_exit_contrastive':
                assert final_layer is not None, "final_layer must be specified"
                assert base_layers is not None, "base_layers must be specified"
                outputs = self.model.generate(input_ids, max_length=max_len, num_return_sequences=1,
                                        output_scores=True, return_dict_in_generate=True, early_exit_contrastive_decoding=True,
                                        top_p=top_p, top_k=top_k, temperature=temperature, stopping_criteria=self.stopping_criteria, skip_layer0=skip_layer0, relative_top=relative_top, relative_top_with_norm=relative_top_with_norm, contrast_disagree_only=contrast_disagree_only, rescale_base_logits=('mpt-30b' in self.model_name), tuned_lens=self.tuned_lens, premature_temp=premature_temp, beta=beta, 
                                        final_layer=final_layer, base_layer=None, dynamic_exit_layers=base_layers,
                                        divergence_type=divergence_type, **kwargs,)
                critical_layer_dist = outputs.critical_layer_dist
            sequences, scores = outputs.sequences, outputs.scores

            # skip the tokens in the input prompt
            gen_sequences = sequences[:, input_ids.shape[-1]:][0, :]
            gen_arr = gen_sequences.cpu().numpy()
            output_length = len(gen_arr) # - input_length

            return input_length, output_length
        #     output_str = self.tokenizer.decode(gen_sequences, skip_special_tokens=True)

        #     if verbose:
        #         print('MODEL OUTPUT: \n{0}'.format(output_str))

        #     if remove_stop_words:
        #         for stop_word in self.stop_words:
        #             length_to_remove = len(stop_word)
        #             if output_str[-length_to_remove:] == stop_word:
        #                 output_str = output_str[:-length_to_remove]
        #         output_str = output_str.strip()

        # if self.device:
        #     torch.cuda.empty_cache()

        # return output_str, (critical_layer_dist if mode == 'dynamic_early_exit_contrastive' else None)

    def get_relative_top_filter(self, scores: torch.FloatTensor, relative_top: float = 0.1, min_tokens_to_keep: int = 1):
        scores_normalized = scores.log_softmax(dim=-1) 
        sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
        min_thresh = sorted_logits[..., min_tokens_to_keep-1] 
        probs_max = torch.max(scores_normalized, dim=-1).values
        probs_thresh = probs_max + np.log(relative_top)
        probs_thresh = torch.min(min_thresh, probs_thresh)
        probs_thresh = probs_thresh.unsqueeze(-1)
        return scores_normalized < probs_thresh

    def lm_score(self, input_text1, input_text2, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, relative_top_value=-1000.0, contrast_disagree_only=False, extrapolate_coeff=None, post_softmax=True, premature_temp=1.0, beta=None, external_interpolation_factor=0.001, **kwargs): # disconnect_attn=None, disconnect_mlp=None, 
        with torch.no_grad():
            input_text = input_text1 + input_text2
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            prefix_ids = self.tokenizer(input_text1, return_tensors="pt").input_ids.to(self.device)
            continue_ids = input_ids[0, prefix_ids.shape[-1]:]
            context_ids = input_ids[0, :prefix_ids.shape[-1]]
            if mode == 'vanilla':
                outputs = self.model(input_ids)[0].squeeze(0) # , disconnect_attn=disconnect_attn, disconnect_mlp=disconnect_mlp
                outputs = outputs.log_softmax(-1)  # logits to log probs

                # skip tokens in the prompt -- we only care about the answer
                outputs = outputs[prefix_ids.shape[-1] - 1: -1, :]

                # get logprobs for each token in the answer
                log_probs = outputs[range(outputs.shape[0]), continue_ids].sum().item()

                # pmi
                if pmi:
                    outputs_y = self.model(input_ids[:, prefix_ids.shape[-1]-1:])[0].squeeze(0)[:-1]
                    outputs_y = outputs_y.log_softmax(-1)
                    log_probs_y = outputs_y[range(outputs_y.shape[0]), continue_ids].sum().item()
                    log_probs = log_probs - log_probs_y
                
            elif mode == 'early_exit_contrastive':
                dict_outputs, outputs = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=[base_layer, final_layer],
                    tuned_lens=self.tuned_lens,
                    apply_early_norm=apply_early_norm, 
                )

                assert base_layer is not None
                base_logits = dict_outputs[base_layer][0, prefix_ids.shape[-1] - 1: -1, :]
                final_logits = dict_outputs[final_layer][0, prefix_ids.shape[-1] - 1: -1, :]
                base_logits = self.rescale_logits(base_logits, torch.abs(final_logits).max())
                if premature_temp != 1.0:
                    base_logits = base_logits / premature_temp
                final_logits = final_logits.log_softmax(dim=-1)
                base_logits = base_logits.log_softmax(dim=-1)
                if beta is not None:
                    diff_logits = (1.0 + beta) * final_logits - beta * base_logits
                elif extrapolate_coeff is None or extrapolate_coeff >= 1000.0:
                    diff_logits = final_logits - base_logits
                else:
                    diff_logits = base_logits + extrapolate_coeff * (final_logits - base_logits)
                if post_softmax:
                    diff_logits = diff_logits.log_softmax(dim=-1)
                if relative_top > 0.0:
                    relative_top_mask = self.get_relative_top_filter(final_logits, relative_top)
                    relative_top_value = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
                    diff_logits = torch.where(relative_top_mask, relative_top_value, diff_logits)
                # if contrast_disagree_only:
                #     cdo_token_mask = final_logits.argmax(-1) == base_logits.argmax(-1)
                #     cdo_token_mask = cdo_token_mask.unsqueeze(-1)
                #     diff_logits = torch.where(cdo_token_mask, final_logits, diff_logits)
                    
                log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()

                # pmi
                if pmi:
                    dict_outputs_y, outputs_y = self.model(
                        input_ids=input_ids[:, prefix_ids.shape[-1]-1:],
                        return_dict=True,
                        output_attentions=False,
                        output_hidden_states=False,
                        early_exit_layers=[base_layer, final_layer],
                    )
                    base_logits_y = dict_outputs_y[base_layer][0, :, :]
                    final_logits_y = dict_outputs_y[final_layer][0, :, :]
                    final_logits_y = final_logits_y.log_softmax(dim=-1)
                    base_logits_y = base_logits_y.log_softmax(dim=-1)
                    diff_logits_y = final_logits_y - base_logits_y
                    if post_softmax:
                        diff_logits_y = diff_logits_y.log_softmax(dim=-1)
                    if relative_top > 0.0:
                        relative_top_mask = self.get_relative_top_filter(final_logits_y, relative_top)
                        relative_top_value = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
                        diff_logits_y = torch.where(relative_top_mask, relative_top_value, diff_logits_y)
                    log_probs_y = diff_logits_y[range(diff_logits_y.shape[0]), continue_ids].sum().item()
                    log_probs = log_probs - log_probs_y

            elif mode == 'early_exit_contrastive_exploit':
                dict_outputs, outputs = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=base_layers + [final_layer],
                    tuned_lens=self.tuned_lens,
                    apply_early_norm=apply_early_norm, 
                )

                return_dict = {}
                for base_layer in base_layers:
                    base_logits = dict_outputs[base_layer][0, prefix_ids.shape[-1] - 1: -1, :]
                    final_logits = dict_outputs[final_layer][0, prefix_ids.shape[-1] - 1: -1, :]
                    base_logits = self.rescale_logits(base_logits, torch.abs(final_logits).max())
                    if premature_temp != 1.0:
                        base_logits = base_logits / premature_temp
                    final_logits = final_logits.log_softmax(dim=-1)
                    base_logits = base_logits.log_softmax(dim=-1)
                    if beta is not None:
                        diff_logits = (1.0 + beta) * final_logits - beta * base_logits
                    elif extrapolate_coeff is None or extrapolate_coeff >= 1000.0:
                        diff_logits = final_logits - base_logits
                    else:
                        diff_logits = base_logits + extrapolate_coeff * (final_logits - base_logits)
                    if post_softmax:
                        diff_logits = diff_logits.log_softmax(dim=-1)
                    if relative_top > 0.0:
                        relative_top_mask = self.get_relative_top_filter(final_logits, relative_top)
                        relative_top_value = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
                        diff_logits = torch.where(relative_top_mask, relative_top_value, diff_logits)
                    # if contrast_disagree_only:
                    #     cdo_token_mask = final_logits.argmax(-1) == base_logits.argmax(-1)
                    #     cdo_token_mask = cdo_token_mask.unsqueeze(-1)
                    #     diff_logits = torch.where(cdo_token_mask, final_logits, diff_logits)
                        
                    log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()
                    # pmi
                    if pmi:
                        dict_outputs_y, outputs_y = self.model(
                            input_ids=input_ids[:, prefix_ids.shape[-1]-1:],
                            return_dict=True,
                            output_attentions=False,
                            output_hidden_states=False,
                            early_exit_layers=[base_layer, final_layer],
                        )
                        base_logits_y = dict_outputs_y[base_layer][0, :, :]
                        final_logits_y = dict_outputs_y[final_layer][0, :, :]
                        final_logits_y = final_logits_y.log_softmax(dim=-1)
                        base_logits_y = base_logits_y.log_softmax(dim=-1)
                        diff_logits_y = final_logits_y - base_logits_y
                        if post_softmax:
                            diff_logits_y = diff_logits_y.log_softmax(dim=-1)
                        if relative_top > 0.0:
                            relative_top_mask = self.get_relative_top_filter(final_logits_y, relative_top)
                            relative_top_value = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
                            diff_logits_y = torch.where(relative_top_mask, relative_top_value, diff_logits_y)
                        log_probs_y = diff_logits_y[range(diff_logits_y.shape[0]), continue_ids].sum().item()
                        log_probs = log_probs - log_probs_y

                    return_dict[base_layer] = log_probs
                log_probs = return_dict

            elif mode == 'dynamic_early_exit_contrastive':
                critical_layer_dist = {l:0 for l in base_layers}
                picked_logits = []
                result_dict = {}
                critical_layers = []

                dict_outputs, outputs = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=base_layers + [final_layer],
                    tuned_lens=self.tuned_lens,
                    apply_early_norm=apply_early_norm, 
                )

                for seq_i in range(prefix_ids.shape[-1] - 1, input_ids.shape[-1] - 1):
                    # pick the less like layer to contrast with
                    if divergence_type == 'random': # a baseline for random value js_divs
                        js_divs = torch.rand(len(base_layers))
                    elif divergence_type == 'real_js':
                        # Stacking all base_layers into a new dimension
                        stacked_base_layers = torch.stack([dict_outputs[i][:, seq_i, :] for i in base_layers], dim=0)

                        # Calculate the softmax values for final_layer and all base_layers
                        softmax_final_layer = F.softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1)  # shape: (batch_size, num_features)
                        softmax_base_layers = F.softmax(stacked_base_layers, dim=-1)  # shape: (num_base_layers, batch_size, num_features)

                        # Calculate M, the average distribution
                        M = 0.5 * (softmax_final_layer[None, :, :] + softmax_base_layers)  # shape: (num_base_layers, batch_size, num_features)

                        # Calculate log-softmax for the KL divergence
                        log_softmax_final_layer = F.log_softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1)  # shape: (batch_size, num_features)
                        log_softmax_base_layers = F.log_softmax(stacked_base_layers, dim=-1)  # shape: (num_base_layers, batch_size, num_features)

                        # Calculate the KL divergences and then the JS divergences
                        kl1 = F.kl_div(log_softmax_final_layer[None, :, :], M, reduction='none').mean(-1)  # shape: (num_base_layers, batch_size)
                        kl2 = F.kl_div(log_softmax_base_layers, M, reduction='none').mean(-1)  # shape: (num_base_layers, batch_size)
                        js_divs = 0.5 * (kl1 + kl2)  # shape: (num_base_layers, batch_size)

                        # Reduce the batchmean
                        js_divs = js_divs.mean(-1)  # shape: (num_base_layers,)
                    else:
                        js_divs = torch.stack(
                            # reverse KL-divergence
                            [F.kl_div(F.log_softmax(dict_outputs[i][:, seq_i, :], dim=-1), F.softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1), reduction='batchmean') for i in base_layers] if divergence_type == 'rev_kl' else (
                            # KL-divergence
                            [F.kl_div(F.log_softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1), F.softmax(dict_outputs[i][:, seq_i, :], dim=-1), reduction='batchmean') for i in base_layers] if divergence_type == 'kl' else 
                            # JS-divergence
                            [0.5 * F.kl_div(F.log_softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1), F.softmax(dict_outputs[i][:, seq_i, :], dim=-1), reduction='batchmean') + 0.5 * F.kl_div(F.log_softmax(dict_outputs[i][:, seq_i, :], dim=-1), F.softmax(dict_outputs[final_layer][:, seq_i, :], dim=-1), reduction='batchmean') for i in base_layers]
                            )
                            # [F.cosine_similarity(final_hidden - base_vector, hidden - base_vector, dim=-1) for hidden in dynamic_exit_hiddens]
                            # [torch.dist(final_hidden - base_vector, hidden - base_vector) for hidden in dynamic_exit_hiddens]
                        ).squeeze(-1)
                    critical_layer = base_layers[int(js_divs.argmax().cpu().item())]
                    critical_layer_dist[critical_layer] += 1

                    # less_than_threshold = kl_divs < critical_layer_threshold
                    # get the layer that is the first one to be less than critical_layer_threshold similar to the final layer
                    # less_than_threshold_idx = less_than_threshold.nonzero()
                    # if len(less_than_threshold_idx) == 0:
                        # critical_layer = dynamic_exit_layers[0]
                    # else:
                        # critical_layer = dynamic_exit_layers[int(less_than_threshold_idx.argmax())]
                    # debug
                    # to_print = ', '.join([f"{kl_divs[i].item():.2f}" for i in range(len(kl_divs))])
                    # token_id_curr = concat_input_ids[seq_i-100:seq_i+1]
                    # token_id_to_predict = concat_input_ids[seq_i + 1]
                    # token_curr = tokenizer.decode(token_id_curr).replace('\n', ' ')
                    # token_to_predict = tokenizer.decode(token_id_to_predict)
                    # print(f"cl: {critical_layer}, kl: [{to_print}] = {token_curr} -> {token_to_predict}")
                    critical_layers.append(critical_layer)

                # # for l in early_exit_layers[:-1]:

                base_logits = torch.zeros_like(dict_outputs[final_layer][0, prefix_ids.shape[-1] - 1:-1])
                for i, l in enumerate(critical_layers):
                   base_logits[i] = dict_outputs[l][0, prefix_ids.shape[-1] - 1 + i]
                final_logits = dict_outputs[final_layer][0, prefix_ids.shape[-1] - 1:-1]
                base_logits = self.rescale_logits(base_logits, torch.abs(final_logits).max())
                if premature_temp != 1.0:
                    base_logits = base_logits / premature_temp
                final_logits = final_logits.log_softmax(dim=-1)
                base_logits = base_logits.log_softmax(dim=-1)

                if beta is not None:
                    diff_logits = (1.0 + beta) * final_logits - beta * base_logits
                elif extrapolate_coeff is None or extrapolate_coeff >= 1000.0:
                    diff_logits = final_logits - base_logits
                else:
                    diff_logits = base_logits + extrapolate_coeff * (final_logits - base_logits)
                if post_softmax:
                    diff_logits = diff_logits.log_softmax(dim=-1)

                if relative_top > 0.0:
                    relative_top_mask = self.get_relative_top_filter(final_logits, relative_top)
                    relative_top_value = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
                    diff_logits = torch.where(relative_top_mask, relative_top_value, diff_logits)
                # if contrast_disagree_only:
                #     cdo_token_mask = final_logits.argmax(-1) == base_logits.argmax(-1)
                #     cdo_token_mask = cdo_token_mask.unsqueeze(-1)
                #     diff_logits = torch.where(cdo_token_mask, final_logits, diff_logits)
                
                log_probs = diff_logits[range(diff_logits.shape[0]), continue_ids].sum().item()
            elif mode == 'attn_intervention':

                dynamic_exit_layers = base_layers + [final_layer]

                dict_outputs, outputs = self.model(
                    input_ids=input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    early_exit_layers=dynamic_exit_layers,
                )
                                
                stacked_base_layers = torch.stack([dict_outputs[i] for i in dynamic_exit_layers], dim=0)

                # Calculate the softmax values for final_layer and all base_layers, move the tensor to CPU to avoid OOM
                softmax_final_layer = F.softmax(dict_outputs[final_layer], dim=-1).to('cpu')  # shape: (batch_size, num_features)
                softmax_base_layers = F.softmax(stacked_base_layers, dim=-1).to('cpu')  # shape: (num_base_layers, batch_size, num_features)

                assert softmax_final_layer.shape[0] == softmax_base_layers.shape[1] == 1
                # Squeeze the batch dimension
                softmax_final_layer = softmax_final_layer.squeeze(0)  # shape: (num_features,)
                softmax_base_layers = softmax_base_layers.squeeze(1)  # shape: (num_base_layers, num_features)
                # Calculate M, the average distribution
                M = 0.5 * (softmax_final_layer.unsqueeze(0) + softmax_base_layers)  # shape: (num_base_layers, batch_size, num_features)

                # Calculate log-softmax for the KL divergence
                log_softmax_final_layer = F.log_softmax(dict_outputs[final_layer], dim=-1).to('cpu')  # shape: (batch_size, num_features)
                log_softmax_base_layers = F.log_softmax(stacked_base_layers, dim=-1).to('cpu')  # shape: (num_base_layers, batch_size, num_features)

                assert log_softmax_final_layer.shape[0] == log_softmax_base_layers.shape[1] == 1

                # Squeeze the batch dimension
                log_softmax_final_layer = log_softmax_final_layer.squeeze(0)  # shape: (length_seq,)
                log_softmax_base_layers = log_softmax_base_layers.squeeze(1)  # shape: (num_base_layers, length_seq)
                # Calculate the KL divergences and then the JS divergences
                kl1 = F.kl_div(log_softmax_final_layer.unsqueeze(0), M, reduction='none').mean(-1)  # shape: (num_base_layers, length_seq)
                kl2 = F.kl_div(log_softmax_base_layers, M, reduction='none').mean(-1)  # shape: (num_base_layers, length_seq)
                js_divs = 0.5 * (kl1 + kl2)  # shape: (num_base_layers, length_seq)
                layer_dist = js_divs / js_divs.sum(dim=0, keepdims=True) # shape: (num_base_layers,)

                # use the softmax of the js_divs as the weights for the weighted sum of the layer number of the base layers
                weighted_layer = torch.matmul(torch.tensor([dynamic_exit_layers], device=js_divs.device, dtype=js_divs.dtype), layer_dist)
                # softmax with temperature T=1.0
                external_attn_cache = F.softmax(weighted_layer / 1.0, dim=1)
                # append a padding to the front of the tensor, batch_size = external_attn_cache.shape[0]
                external_attn_cache = torch.cat([external_attn_cache[:, -1:], external_attn_cache[:, :-1]], dim=1)
                
                # Release the memory before the next forward pass
                del stacked_base_layers, softmax_final_layer, softmax_base_layers, M, log_softmax_final_layer, log_softmax_base_layers, kl1, kl2, js_divs, layer_dist, weighted_layer
                torch.cuda.empty_cache()

                outputs = self.model(
                    input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    external_attn_weights=external_attn_cache,
                    external_interpolation_factor=external_interpolation_factor,
                )[0].squeeze(0)

                outputs = outputs.log_softmax(-1)  # logits to log probs

                # skip tokens in the prompt -- we only care about the answer
                outputs = outputs[prefix_ids.shape[-1] - 1: -1, :]

                # get logprobs for each token in the answer
                log_probs = outputs[range(outputs.shape[0]), continue_ids].sum().item()
            elif mode == 'attn_intervention_low_prob':

                dynamic_exit_layers = base_layers + [final_layer]

                outputs = self.model(
                    input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )[0].squeeze(0)

                log_probs = outputs[range(context_ids.shape[0]-1), context_ids[1:]]
                importance = - log_probs
                           
                # use the softmax of the js_divs as the weights for the weighted sum of the layer number of the base layers
                # weighted_layer = torch.matmul(torch.tensor([dynamic_exit_layers], device=js_divs.device, dtype=js_divs.dtype), layer_dist)
                # softmax with temperature T=1.0
                external_attn_cache = F.softmax(importance / 1.0, dim=0)
                # append a padding to the front of the tensor, batch_size = external_attn_cache.shape[0]
                external_attn_cache = torch.cat([external_attn_cache[-1:], external_attn_cache], dim=0).unsqueeze(0)
                # Release the memory before the next forward pass
                del outputs, importance
                torch.cuda.empty_cache()

                outputs = self.model(
                    input_ids,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                    external_attn_weights=external_attn_cache,
                    external_interpolation_factor=external_interpolation_factor,
                )[0].squeeze(0)

                outputs = outputs.log_softmax(-1)  # logits to log probs

                # skip tokens in the prompt -- we only care about the answer
                outputs = outputs[prefix_ids.shape[-1] - 1: -1, :]

                # get logprobs for each token in the answer
                log_probs = outputs[range(outputs.shape[0]), continue_ids].sum().item()


        return log_probs, (critical_layer_dist if mode == 'dynamic_early_exit_contrastive' else None)


    def lm_score_by_tokens(self, input_text, extra_tokens=None, extra_token_ids=None, token_importance=None, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, relative_top_value=-1000.0, contrast_disagree_only=False, extrapolate_coeff=None, post_softmax=True, premature_temp=1.0, beta=None, external_interpolation_factor=0.001, low_prob_percentile=0.1, steering_layers=None, selective_steering_n_heads=None, selective_steering_head_ids=None, selective_steering_layer_heads=None, shift_by_1=False, **kwargs): # disconnect_attn=None, disconnect_mlp=None, 
        """Score candidate next tokens that could follow ``input_text``.

        Every candidate is scored at the last position of the tokenized prompt,
        i.e. the position whose logits predict the token right after the prompt.

        Args:
            input_text: Prompt text. One trailing space is stripped so that the
                separating space is carried by the candidate tokens instead.
            extra_tokens: Candidate token strings; a leading space is added to
                any candidate lacking one (the caller's list is left untouched).
                Only used when ``extra_token_ids`` is None.
            extra_token_ids: Pre-computed candidate ids. When None they are
                derived from ``extra_tokens`` as the first id the tokenizer emits
                after the prompt; duplicate ids are rejected.
            token_importance: Importance mask forwarded to the model in the
                attention-intervention modes. When None it is derived by marking
                prompt tokens whose log-probability lies below the
                ``low_prob_percentile`` quantile.
            final_layer / base_layer / base_layers: Layer ids used by the
                early-exit contrastive modes.
            divergence_type: 'random', 'real_js', 'rev_kl', 'kl', or anything else
                (JS-style); picks the critical layer in the dynamic mode.
            mode: 'vanilla', 'attention_intervention',
                'attention_intervention_low_prob', 'attn_intervention_low_prob',
                'early_exit_contrastive', 'early_exit_contrastive_exploit' or
                'dynamic_early_exit_contrastive'.
            Remaining keyword arguments configure the contrast (``beta``,
            ``extrapolate_coeff``, ``premature_temp``, ``post_softmax``,
            ``relative_top``, ``relative_top_value``, ``pmi``) or are forwarded
            to the model in the attention-intervention modes; several
            (``max_new_tokens``, ``top_p``, ...) are accepted but unused here.

        Returns:
            ``(log_probs, extra_token_ids)``: ``log_probs`` holds one score per
            candidate as a list (a ``{base_layer: list}`` dict in
            'early_exit_contrastive_exploit' mode), or None for an unknown mode.

        Raises:
            ValueError: if neither ``extra_tokens`` nor ``extra_token_ids`` is given.
            AssertionError: on duplicate derived candidate ids, or a missing
                ``base_layer`` in 'early_exit_contrastive' mode.
        """

        def apply_relative_top(diff_logits, final_log_probs):
            # Tokens rejected by the relative-top filter of the final layer get
            # the fixed score ``relative_top_value``.
            mask = self.get_relative_top_filter(final_log_probs, relative_top)
            floor = torch.tensor(relative_top_value, dtype=diff_logits.dtype, device=diff_logits.device)
            return torch.where(mask, floor, diff_logits)

        def contrast(base_logits, final_logits):
            # Contrast premature-layer logits with final-layer logits
            # (rescale, temperature, beta / extrapolation, re-normalisation,
            # relative-top filtering). Inputs are (positions, vocab) slices.
            base_logits = self.rescale_logits(base_logits, torch.abs(final_logits).max())
            if premature_temp != 1.0:
                base_logits = base_logits / premature_temp
            final_logits = final_logits.log_softmax(dim=-1)
            base_logits = base_logits.log_softmax(dim=-1)
            if beta is not None:
                diff_logits = (1.0 + beta) * final_logits - beta * base_logits
            elif extrapolate_coeff is None or extrapolate_coeff >= 1000.0:
                diff_logits = final_logits - base_logits
            else:
                diff_logits = base_logits + extrapolate_coeff * (final_logits - base_logits)
            if post_softmax:
                diff_logits = diff_logits.log_softmax(dim=-1)
            if relative_top > 0.0:
                diff_logits = apply_relative_top(diff_logits, final_logits)
            return diff_logits

        def pmi_baseline(layer):
            # Candidate scores when the model only sees the last prompt token;
            # subtracted from the full-context scores for a PMI-style score.
            dict_outputs_y, _ = self.model(
                input_ids=input_ids[:, -1:],
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                early_exit_layers=[layer, final_layer],
            )
            final_y = dict_outputs_y[final_layer][0, -1:, :].log_softmax(dim=-1)
            base_y = dict_outputs_y[layer][0, -1:, :].log_softmax(dim=-1)
            diff_y = final_y - base_y
            if post_softmax:
                diff_y = diff_y.log_softmax(dim=-1)
            if relative_top > 0.0:
                diff_y = apply_relative_top(diff_y, final_y)
            return diff_y[0, extra_token_ids]

        def contrastive_scores(dict_outputs, layer):
            # Contrastive candidate scores at the last prompt position, with the
            # optional PMI correction.
            diff_logits = contrast(dict_outputs[layer][0, -1:, :], dict_outputs[final_layer][0, -1:, :])
            scores = diff_logits[0, extra_token_ids]
            if pmi:
                scores = scores - pmi_baseline(layer)
            return scores.tolist()

        def layer_divergences(dict_outputs, seq_i):
            # One divergence value per base layer at position ``seq_i``; the
            # largest one selects the layer to contrast against.
            if divergence_type == 'random':  # baseline with random divergences
                return torch.rand(len(base_layers))
            final_row = dict_outputs[final_layer][:, seq_i, :]
            if divergence_type == 'real_js':
                stacked = torch.stack([dict_outputs[l][:, seq_i, :] for l in base_layers], dim=0)
                # M is the mixture of the final and each base distribution.
                M = 0.5 * (F.softmax(final_row, dim=-1)[None, :, :] + F.softmax(stacked, dim=-1))
                kl1 = F.kl_div(F.log_softmax(final_row, dim=-1)[None, :, :], M, reduction='none').mean(-1)
                kl2 = F.kl_div(F.log_softmax(stacked, dim=-1), M, reduction='none').mean(-1)
                return (0.5 * (kl1 + kl2)).mean(-1)
            log_final = F.log_softmax(final_row, dim=-1)
            prob_final = F.softmax(final_row, dim=-1)
            scores = []
            for l in base_layers:
                row = dict_outputs[l][:, seq_i, :]
                if divergence_type == 'rev_kl':
                    score = F.kl_div(F.log_softmax(row, dim=-1), prob_final, reduction='batchmean')
                elif divergence_type == 'kl':
                    score = F.kl_div(log_final, F.softmax(row, dim=-1), reduction='batchmean')
                else:
                    score = (0.5 * F.kl_div(log_final, F.softmax(row, dim=-1), reduction='batchmean')
                             + 0.5 * F.kl_div(F.log_softmax(row, dim=-1), prob_final, reduction='batchmean'))
                scores.append(score)
            return torch.stack(scores).squeeze(-1)

        def early_exit_forward(layers):
            # Single forward pass returning logits for each requested exit layer.
            dict_outputs, _ = self.model(
                input_ids=input_ids,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                early_exit_layers=layers,
                tuned_lens=self.tuned_lens,
            )
            return dict_outputs

        if extra_token_ids is None and extra_tokens is None:
            raise ValueError("either extra_tokens or extra_token_ids must be provided")

        log_probs = None
        with torch.no_grad():
            if input_text.endswith(' '):
                input_text = input_text[:-1]
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            if extra_token_ids is None:
                # Build a new list instead of editing the caller's one in place.
                candidates = [tok if tok.startswith(' ') else ' ' + tok for tok in extra_tokens]
                prompt_len = input_ids.shape[1]
                extra_token_ids = [self.tokenizer(input_text + tok).input_ids[prompt_len] for tok in candidates]
                assert len(extra_token_ids) == len(set(extra_token_ids)), "No support for redundant tokens in extra_tokens"

            if mode == 'vanilla':
                logits = self.model(input_ids)[0].squeeze(0)
                log_probs = logits[-1].log_softmax(-1)[extra_token_ids].tolist()
            elif mode in ('attention_intervention', 'attention_intervention_low_prob', 'attn_intervention_low_prob'):
                if token_importance is None:
                    # Mark the prompt tokens the model itself finds least likely.
                    logits = self.model(input_ids, return_dict=True).logits[:, :-1, :]
                    # Normalise per position (raw logits are not comparable across
                    # positions) and use float32, which torch.quantile requires.
                    token_log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).gather(
                        -1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
                    token_importance = token_log_probs < torch.quantile(token_log_probs, low_prob_percentile)
                    if shift_by_1:  # shift the token importance by 1
                        token_importance = torch.cat([torch.tensor([[True]], device=token_importance.device), token_importance], dim=1)
                logits = self.model(
                    input_ids,
                    return_dict=True,
                    external_interpolation_factor=external_interpolation_factor,
                    token_importance=token_importance,
                    steering_layers=steering_layers,
                    selective_steering_n_heads=selective_steering_n_heads,
                    selective_steering_head_ids=selective_steering_head_ids,
                    selective_steering_layer_heads=selective_steering_layer_heads,
                )[0].squeeze(0)
                log_probs = logits[-1].log_softmax(-1)[extra_token_ids].tolist()
            elif mode == 'early_exit_contrastive':
                assert base_layer is not None
                dict_outputs = early_exit_forward([base_layer, final_layer])
                log_probs = contrastive_scores(dict_outputs, base_layer)
            elif mode == 'early_exit_contrastive_exploit':
                dict_outputs = early_exit_forward(base_layers + [final_layer])
                log_probs = {layer: contrastive_scores(dict_outputs, layer) for layer in base_layers}
            elif mode == 'dynamic_early_exit_contrastive':
                dict_outputs = early_exit_forward(base_layers + [final_layer])
                divergences = layer_divergences(dict_outputs, input_ids.shape[-1] - 1)
                critical_layer = base_layers[int(divergences.argmax().cpu().item())]
                diff_logits = contrast(dict_outputs[critical_layer][0, -1:, :], dict_outputs[final_layer][0, -1:, :])
                log_probs = diff_logits[0, extra_token_ids].tolist()

        return log_probs, extra_token_ids

    def attn_score(self, input_text, focus_span=None, extra_tokens=None, extra_token_ids=None, token_importance=None, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, relative_top_value=-1000.0, contrast_disagree_only=False, extrapolate_coeff=None, post_softmax=True, premature_temp=1.0, beta=None, external_interpolation_factor=0.001, low_prob_percentile=0.1, steering_layers=None, selective_steering_n_heads=None, selective_steering_head_ids=None, selective_steering_layer_heads=None, observe_top_n_heads=None, shift_by_1=False, important_token_type=None, return_head_scores=False, **kwargs): # disconnect_attn=None, disconnect_mlp=None,
        """Score how much a focus span attends to "important" context tokens vs. to itself.

        The score is ``I / (I + S)``:
        - ``I`` is the mean attention the focus-span tokens pay to the selected
          important tokens.
        - ``S`` is the mean attention they pay to the focus-span tokens visible to
          them under the causal mask (itself and earlier span tokens). It is
          normalised by the number of visible span tokens.

        Args:
            input_text: Full text. A single trailing space is stripped.
            focus_span: Substring of ``input_text`` whose attention is measured.
                Its first occurrence is used. Required.
            token_importance: Optional boolean mask over token positions selecting
                the important tokens. A leading batch dimension is allowed; only
                row 0 is used. When omitted, the mask is derived from
                ``important_token_type``.
            low_prob_percentile: Quantile used to select low- (or high-)
                probability context tokens.
            observe_top_n_heads: If set, ``I`` and ``S`` are each averaged over
                their own top-n (layer, head) entries instead of over all of them.
            shift_by_1: Only used when ``important_token_type`` is None. It shifts
                the mask right by one position, so the low-probability tokens
                themselves are selected and position 0 is always selected.
            important_token_type: Selects the important tokens:
                - ``None``: the tokens immediately preceding a low-probability token.
                - ``'all'``: every token before the focus span.
                - a string containing ``'_low'`` or ``'_high'``: the bottom or top
                  quantile tokens. It may also contain ``'at_'`` (same shift as
                  ``shift_by_1``) and ``'inv_'`` (invert the mask, applied after
                  the shift).
            return_head_scores: Also return the per-(layer, head) scores.
            All other parameters are accepted for interface compatibility and are
            not used.

        Returns:
            The score. It is a 0-dim tensor, or a Python float when
            ``observe_top_n_heads`` is set. If ``return_head_scores`` is true, a
            ``(num_layers, num_heads)`` tensor of per-head scores is also returned.

        Raises:
            AssertionError: if ``focus_span`` is None or not found in ``input_text``.
            ValueError: if ``important_token_type`` is not a recognised selector.
        """
        with torch.no_grad():
            # Drop a single trailing space (endswith() is also safe for "").
            if input_text.endswith(' '):
                input_text = input_text[:-1]

            # Locate the focus span in character space, then map it onto tokens.
            assert focus_span is not None
            focus_span_start_char = input_text.find(focus_span)
            assert focus_span_start_char != -1, f"focus_span {focus_span} not found in input_text"
            focus_span_end_char = focus_span_start_char + len(focus_span)
            inputs = self.tokenizer(input_text, return_tensors="pt", return_offsets_mapping=True)
            input_ids = inputs.input_ids.to(self.device)
            focus_span_start_token = None
            focus_span_end_token = None
            for idx, (s, e) in enumerate(inputs['offset_mapping'][0]):
                if focus_span_start_token is None and (s >= focus_span_start_char or e > focus_span_start_char):
                    focus_span_start_token = idx
                if focus_span_end_token is None and (e >= focus_span_end_char):
                    focus_span_end_token = idx
            span = slice(focus_span_start_token, focus_span_end_token + 1)

            if token_importance is not None:
                # Caller-supplied mask, read the same way as the internally
                # computed one (boolean, optional leading batch dim).
                # Previously this path left `important_tokens` unbound.
                important_mask = torch.as_tensor(token_importance, device=input_ids.device).bool()
                if important_mask.dim() > 1:
                    important_mask = important_mask[0]
                important_tokens = important_mask.nonzero().squeeze(-1)
            elif important_token_type == 'all':
                important_tokens = torch.arange(0, focus_span_start_token, device=input_ids.device)
            else:
                # Entry i is log p(token i+1 | tokens <= i) for the positions
                # before the focus span. Raw logits are normalised with
                # log_softmax so values are comparable across positions.
                logits = self.model(input_ids, return_dict=True).logits[:, :-1, :]
                context_log_probs = logits[:, :focus_span_start_token, :].log_softmax(dim=-1)
                targets = input_ids[:, 1:focus_span_start_token + 1].to(context_log_probs.device)
                context_log_probs = context_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

                if important_token_type is None or '_low' in important_token_type:
                    token_importance = context_log_probs < torch.quantile(context_log_probs, low_prob_percentile)
                elif '_high' in important_token_type:
                    token_importance = context_log_probs > torch.quantile(context_log_probs, 1.0 - low_prob_percentile)
                else:
                    raise ValueError(f"unsupported important_token_type: {important_token_type!r}")

                if important_token_type is None:
                    shift, invert = shift_by_1, False
                else:
                    shift, invert = 'at_' in important_token_type, 'inv_' in important_token_type
                if shift:
                    # Prepend a True column (mask is (batch, n)) so index i refers
                    # to the low-probability token itself.
                    token_importance = torch.cat([torch.ones_like(token_importance[:, :1]), token_importance], dim=1)
                if invert:
                    token_importance = ~token_importance
                important_tokens = token_importance[0].nonzero().squeeze(-1)

            outputs = self.model(
                input_ids,
                return_dict=True,
                output_attentions=True,
            )
            # Per-layer attention maps are assumed to be (batch=1, heads, q, k);
            # concatenating along dim 0 gives (num_layers, num_heads, span_len, .).

            # 1. Attention each span token pays to the span tokens it can see,
            #    normalised by how many of them are visible (1, 2, ..., span_len).
            look_at_span_tokens = torch.cat([layer_attn[:, :, span, span] for layer_attn in outputs.attentions], dim=0)
            span_len = look_at_span_tokens.shape[-2]
            visible_counts = torch.tril(torch.ones(span_len, span_len, device=look_at_span_tokens.device)).sum(-1)
            look_at_span_tokens_mean = (look_at_span_tokens.sum(-1) / visible_counts).mean(-1)

            # 2. Mean attention each span token pays to the important tokens.
            look_at_important_tokens = torch.cat([layer_attn[:, :, span, important_tokens] for layer_attn in outputs.attentions], dim=0)
            look_at_important_tokens_mean = look_at_important_tokens.mean(-1).mean(-1)
            # both means: (num_layers, num_heads)

            # 3. Combine into the final ratio, over all heads or over the top-n heads.
            if observe_top_n_heads is not None:
                top_important = look_at_important_tokens_mean.flatten().topk(observe_top_n_heads).values.mean().item()
                top_span = look_at_span_tokens_mean.flatten().topk(observe_top_n_heads).values.mean().item()
                final_attn_score = top_important / (top_important + top_span)
            else:
                important_mean = look_at_important_tokens_mean.mean()
                final_attn_score = important_mean / (important_mean + look_at_span_tokens_mean.mean())

            # as_tensor() makes the check work for both the float and tensor cases.
            if torch.isnan(torch.as_tensor(final_attn_score)):
                warnings.warn(
                    "attn_score produced NaN (possibly no important tokens were selected)",
                    RuntimeWarning,
                )

            if not return_head_scores:
                return final_attn_score
            final_attn_score_per_heads = look_at_important_tokens_mean / (look_at_important_tokens_mean + look_at_span_tokens_mean)
            return final_attn_score, final_attn_score_per_heads

    def lying_score(self, input_text1, input_text2, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, relative_top_value=-1000.0, contrast_disagree_only=False, extrapolate_coeff=None, post_softmax=True, premature_temp=1.0, **kwargs):
        """Score a continuation by final-layer log-likelihood plus a rank-monotonicity bonus.

        The score has two parts:
        - The sum of the final layer's log-probabilities of the continuation
          tokens (``input_text2`` after the prefix ``input_text1``).
        - A bonus: the per-token count of consecutive ``base_layers`` pairs in
          which the target token's rank does not get worse, averaged over
          continuation tokens and divided by ``len(base_layers)``. The rank is
          the number of vocabulary entries with a strictly larger logit.

        Args:
            input_text1: Prefix text.
            input_text2: Continuation text to score.
            final_layer: Layer whose output is used for the log-likelihood.
            base_layers: Early-exit layers, in the order used for rank comparison.
            All other parameters are accepted for interface compatibility and are
            not used.

        Returns:
            ``(score, critical_layer_dist)``. ``score`` is a float; it is 0.0 for
            an empty continuation. ``critical_layer_dist`` maps every base layer
            to 0: no critical-layer selection happens here, and the dict is kept
            for parity with the other scorers.
        """
        with torch.no_grad():
            input_text = input_text1 + input_text2
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            prefix_ids = self.tokenizer(input_text1, return_tensors="pt").input_ids.to(self.device)
            prefix_len = prefix_ids.shape[-1]
            continue_ids = input_ids[0, prefix_len:]

            critical_layer_dist = {l: 0 for l in base_layers}

            dict_outputs, _ = self.model(
                input_ids=input_ids,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                early_exit_layers=base_layers + [final_layer],
                tuned_lens=self.tuned_lens,
            )

            # Position t predicts token t+1, so the continuation is predicted by
            # positions prefix_len-1 .. len-2.
            positions = slice(prefix_len - 1, input_ids.shape[-1] - 1)

            final_log_probs = dict_outputs[final_layer].squeeze(0).log_softmax(-1)[positions, :]
            target_index = continue_ids.to(final_log_probs.device).unsqueeze(-1)
            log_probs = final_log_probs.gather(-1, target_index).sum().item()

            rank_bonus = 0.0
            if base_layers and continue_ids.numel() > 0:
                layer_ranks = []
                for layer in base_layers:
                    layer_logits = dict_outputs[layer][0, positions]  # (n_tokens, vocab)
                    target_logits = layer_logits.gather(-1, continue_ids.to(layer_logits.device).unsqueeze(-1))
                    layer_ranks.append((layer_logits > target_logits).sum(-1).cpu())
                ranks = torch.stack(layer_ranks)  # (n_base_layers, n_tokens)
                # Per token: how many consecutive layer pairs keep or improve the rank.
                non_worsening = (ranks[:-1] >= ranks[1:]).sum(0)
                rank_bonus = (non_worsening.sum().item() / non_worsening.numel()) / len(base_layers)
            log_probs += rank_bonus

        return log_probs, critical_layer_dist

    def jsdiv_ner(self, input_text1, input_text2, ner_tags, pmi=False, max_new_tokens=256, top_p=0.95, top_k=0, temperature=0.8, final_layer=None, base_layer=None, base_layers=[], divergence_type='js', mode='vanilla', verbose=True, remove_stop_words=False, skip_layer0=False, relative_top=0.1, relative_top_with_norm=False, relative_top_value=-1000.0, contrast_disagree_only=False, extrapolate_coeff=None, post_softmax=True, premature_temp=1.0, **kwargs):
        """Pair each continuation token's NER label with its most-divergent early-exit layer.

        For every continuation token, this picks the base layer whose next-token
        distribution diverges most from ``final_layer`` (the "critical layer").
        That layer is returned alongside the token's NER label, aligned to
        sub-word tokens.

        Args:
            input_text1: Prefix text, tokenized normally.
            input_text2: Continuation, tokenized with ``is_split_into_words=True``.
                It is presumably a list of words (TODO: confirm against callers).
            ner_tags: One integer tag per word of ``input_text2``. Odd tags are
                treated as "B-" tags whose "I-" counterpart is ``tag + 1``.
                This presumably follows the CoNLL-2003 IOB2 id layout; confirm.
            final_layer: Reference layer for the divergence.
            base_layers: Candidate early-exit layers.
            divergence_type: One of:
                - ``'random'``: random scores (baseline).
                - ``'real_js'``: Jensen-Shannon divergence.
                - ``'rev_kl'``: KL(final || layer).
                - ``'kl'``: KL(layer || final).
                - anything else: the average of both KL directions.
            All other parameters are accepted for interface compatibility and are
            not used.

        Returns:
            ``(new_labels, critical_layers)``: equal-length lists with one entry
            per continuation token.
        """
        import math

        with torch.no_grad():
            prefix_ids = self.tokenizer(input_text1).input_ids
            tokenized_input = self.tokenizer([input_text2], is_split_into_words=True)

            # Align word-level tags with sub-word tokens. The first token of the
            # second encoding is dropped (assumed to be BOS; confirm for
            # tokenizers without one). Only a continuation sub-token of the SAME
            # word turns B-x into I-x, so adjacent B-x words stay separate entities.
            new_labels = []
            prev_word_idx = None
            for word_idx in tokenized_input.word_ids()[1:]:
                tag = ner_tags[word_idx]
                if word_idx == prev_word_idx and tag % 2 == 1:
                    tag += 1
                new_labels.append(tag)
                prev_word_idx = word_idx

            tokens = tokenized_input["input_ids"][0][1:]
            # Honour the configured device instead of a hard-coded .cuda().
            input_ids = torch.tensor([prefix_ids + tokens], device=self.device)

            dict_outputs, _ = self.model(
                input_ids=input_ids,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                early_exit_layers=base_layers + [final_layer],
                tuned_lens=self.tuned_lens,
            )

            critical_layers = []
            for seq_i in range(len(prefix_ids) - 1, input_ids.shape[-1] - 1):
                final_logits = dict_outputs[final_layer][:, seq_i, :]
                if divergence_type == 'random':  # baseline
                    divergences = torch.rand(len(base_layers))
                elif divergence_type == 'real_js':
                    # JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), M = (P + Q) / 2.
                    # log M is computed with logaddexp for numerical stability.
                    final_log_p = F.log_softmax(final_logits, dim=-1)[None]  # (1, batch, vocab)
                    base_log_p = torch.stack([F.log_softmax(dict_outputs[l][:, seq_i, :], dim=-1) for l in base_layers])  # (n_base, batch, vocab)
                    log_m = torch.logaddexp(final_log_p, base_log_p) - math.log(2.0)
                    kl_final_m = (final_log_p.exp() * (final_log_p - log_m)).sum(-1)
                    kl_base_m = (base_log_p.exp() * (base_log_p - log_m)).sum(-1)
                    divergences = (0.5 * (kl_final_m + kl_base_m)).mean(-1)  # mean over batch -> (n_base,)
                else:
                    final_probs = F.softmax(final_logits, dim=-1)
                    final_log_probs = F.log_softmax(final_logits, dim=-1)
                    per_layer = []
                    for l in base_layers:
                        layer_logits = dict_outputs[l][:, seq_i, :]
                        if divergence_type == 'rev_kl':  # KL(final || layer)
                            div = F.kl_div(F.log_softmax(layer_logits, dim=-1), final_probs, reduction='batchmean')
                        elif divergence_type == 'kl':  # KL(layer || final)
                            div = F.kl_div(final_log_probs, F.softmax(layer_logits, dim=-1), reduction='batchmean')
                        else:  # 'js': symmetric KL, averaged over both directions
                            div = (0.5 * F.kl_div(final_log_probs, F.softmax(layer_logits, dim=-1), reduction='batchmean')
                                   + 0.5 * F.kl_div(F.log_softmax(layer_logits, dim=-1), final_probs, reduction='batchmean'))
                        per_layer.append(div)
                    divergences = torch.stack(per_layer).squeeze(-1)
                critical_layers.append(base_layers[int(divergences.argmax().cpu().item())])

        assert len(new_labels) == len(critical_layers)

        return new_labels, critical_layers
