# Ref: https://github.com/kojima-takeshi188/zero_shot_cot

import re
import os
import json
import random
import torch
import numpy as np
import pandas as pd
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm, trange
import argparse
import math
from collections import defaultdict, Counter
import glob
import sys
import pickle

import ssl
import urllib.request
import zipfile
import tiktoken
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from generation import LLM


# 40 is the numeric value of the ERROR level in Python's logging module.
transformers.logging.set_verbosity(40)

# NOTE(review): the constants below are not referenced by any code visible in
# this file; presumably they were carried over from the zero_shot_cot
# reference script named in the file header (confirm before removing).
# Pattern capturing a (possibly negative) number that follows "#### ".
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"

N_SHOT = 8
COT_FLAG = True
DEBUG = True
ANSWER_TRIGGER = "The answer is"

def num_tokens_from_message(message, llama2_tokenizer):
    """Return the number of token ids the tokenizer produces for ``message``.

    :param message: Text to tokenize.
    :param llama2_tokenizer: Callable that, given a string, returns a mapping
        with an ``'input_ids'`` sequence.
    :return: Length of the ``'input_ids'`` sequence.
    """
    return len(llama2_tokenizer(message)['input_ids'])


def truncate_message(prompt1, prompt2, llama2_tokenizer, max_tokens=2033):
    """Concatenate two prompt parts, shortening the first to fit a token budget.

    When ``prompt1 + prompt2`` tokenizes to more than ``max_tokens`` tokens,
    trailing space-separated words are dropped from ``prompt1`` until its own
    token count no longer exceeds ``max_tokens`` minus the token count of
    ``prompt2``. ``prompt2`` is never modified.

    :param prompt1: Leading part of the prompt; the only part that is truncated.
    :param prompt2: Trailing part of the prompt; always kept intact.
    :param llama2_tokenizer: Tokenizer passed to ``num_tokens_from_message``.
    :param max_tokens: Token budget for the combined prompt (default 2033,
        the value that was previously hard-coded).
    :return: The (possibly truncated) ``prompt1`` followed by ``prompt2``.
    """
    if num_tokens_from_message(prompt1 + prompt2, llama2_tokenizer) > max_tokens:
        truncation_length = max_tokens - num_tokens_from_message(prompt2, llama2_tokenizer)
        # Stop once prompt1 is empty: a tokenizer that adds special tokens
        # reports a non-zero count for "", which would otherwise spin forever
        # when prompt2 alone already uses up the budget.
        while prompt1 and num_tokens_from_message(prompt1, llama2_tokenizer) > truncation_length:
            prompt1 = " ".join(prompt1.split(' ')[:-1])
    return prompt1 + prompt2


# NOTE(review): not referenced by any code visible in this file.
demo_keys = []

# Maps a data type to the record field that holds the context text
# (looked up by load_jsonl). There is no 'nq' entry here.
data_context_keys = {
    'summarization': 'document',
    'dialogue': 'dialogue_history',
    'xsum': 'document',
    'cnndm': 'article',
}

# Maps a data type to the label used in the "#<label>#: " prefix that
# load_jsonl puts in front of the context text.
data_context_names = {
    'summarization': 'Document',
    'dialogue': 'Dialogue History',
    'nq': 'Document',
    'xsum': 'Article',
    'cnndm': 'Article',
}

# Maps a data type to the label used in the "\n#<label>#:" marker that the
# main script appends to the prompt to start the model's response.
data_response_names = {
    'summarization': 'Summary',
    'dialogue': 'Response',
    'nq': 'Answer',
    'xsum': 'Summary',
    'cnndm': 'Summary',
}

# Sampling temperature per 'mt_bench' category, read by the main script.
# Categories missing from this table fall back to 0.7 there, and a
# temperature below 1e-4 switches sampling off.
temperature_config = {
    "writing": 0.7,
    "roleplay": 0.7,
    "extraction": 0.0,
    "math": 0.0,
    "coding": 0.0,
    "reasoning": 0.0,
    "stem": 0.1,
    "humanities": 0.1,
    "arena-hard-200": 0.0,
}

def load_nq_open(file_path, parallel=False, total_shard=8, shard_id=0, debug=False, data_type='nq_open', subsample=None):
    """Load a JSON-lines NQ file and turn every record into a prompt item.

    Each line of ``file_path`` is parsed as one JSON record. Records are then
    filtered in this order: ``debug`` keeps the first 100, ``subsample`` keeps
    every ``subsample``-th remaining record, and ``parallel`` keeps only the
    slice that belongs to ``shard_id`` (the last shard also takes the
    remainder).

    When the path contains ``'nq_train'`` the positive/negative contexts are
    read from ``'positive_ctxs'`` / ``'negative_ctxs'``; otherwise they are
    derived from ``'ctxs'`` using each context's ``'hasanswer'`` flag.

    :param file_path: Path of the JSON-lines file.
    :param parallel: Whether to keep only one shard of the data.
    :param total_shard: Number of shards used when ``parallel`` is set.
    :param shard_id: Index of the shard to keep when ``parallel`` is set.
    :param debug: Keep only the first 100 records.
    :param data_type: Accepted but not used by this function.
    :param subsample: If not None, keep records whose position is a multiple
        of this value.
    :return: List of dicts with keys ``context``, ``response``, ``answer``
        (the first listed answer) and ``data_index`` (line number in the file).
    :raises AssertionError: If a record has no positive context or fewer than
        two negative contexts.
    """
    train_schema = 'nq_train' in file_path
    with open(file_path, 'r', encoding="utf-8") as f:
        records = [json.loads(line) for line in f]

    # Work on the original line positions; records are looked up at the end.
    positions = list(range(len(records)))
    if debug:
        positions = positions[:100]
    if subsample is not None:
        positions = [pos for rank, pos in enumerate(positions) if rank % subsample == 0]
    if parallel:
        chunk_size = len(positions) // total_shard
        start = shard_id * chunk_size
        if shard_id != total_shard - 1:
            positions = positions[start:start + chunk_size]
        else:
            # the last shard also receives the leftover records
            positions = positions[start:]

    items = []
    for position in positions:
        record = records[position]
        question = record['question']
        # Upper-case the first character and make sure it ends with '?'.
        question = question[0].upper() + question[1:]
        if not question.endswith('?'):
            question = question + '?'
        answers = record['answers']
        if train_schema:
            pos_ctxs = record['positive_ctxs']
            neg_ctxs = record['negative_ctxs']
        else:
            ctxs = record['ctxs']
            pos_ctxs = [ctx for ctx in ctxs if ctx['hasanswer']]
            neg_ctxs = [ctx for ctx in ctxs if not ctx['hasanswer']]
        assert len(pos_ctxs) > 0, "No positive context found."
        assert len(neg_ctxs) >= 2, "At least two negative contexts are required."
        # The positive passage is placed between two negative passages.
        passages = '\n'.join([neg_ctxs[0]['text'], pos_ctxs[0]['text'], neg_ctxs[1]['text']])
        items.append({
            'context': "#Document#: " + passages + f"\n#Question#: {question}",
            'response': "\n#Answer#:",
            'answer': answers[0],
            'data_index': position,
        })
    return items

def load_nq_train(file_path, parallel=False, total_shard=8, shard_id=0, debug=False, data_type='nq_open', subsample=None):
    """Load an NQ training file (one JSON document) as a list of prompt items.

    The file is parsed with a single ``json.load`` call and its records are
    filtered in this order: ``debug`` keeps the first 100, ``subsample`` keeps
    every ``subsample``-th remaining record, and ``parallel`` keeps only the
    slice that belongs to ``shard_id`` (the last shard also takes the
    remainder).

    :param file_path: Path of the JSON file.
    :param parallel: Whether to keep only one shard of the data.
    :param total_shard: Number of shards used when ``parallel`` is set.
    :param shard_id: Index of the shard to keep when ``parallel`` is set.
    :param debug: Keep only the first 100 records.
    :param data_type: Accepted but not used by this function.
    :param subsample: If not None, keep records whose position is a multiple
        of this value.
    :return: List of dicts with keys ``context``, ``response``, ``answer``
        (the first listed answer) and ``data_index`` (position in the file).
    :raises AssertionError: If a record has no positive context or fewer than
        two negative contexts.
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        records = json.load(f)

    # Work on the original positions; records are looked up at the end.
    positions = list(range(len(records)))
    if debug:
        positions = positions[:100]
    if subsample is not None:
        positions = [pos for rank, pos in enumerate(positions) if rank % subsample == 0]
    if parallel:
        chunk_size = len(positions) // total_shard
        start = shard_id * chunk_size
        if shard_id != total_shard - 1:
            positions = positions[start:start + chunk_size]
        else:
            # the last shard also receives the leftover records
            positions = positions[start:]

    items = []
    for position in positions:
        record = records[position]
        question = record['question']
        # Upper-case the first character and make sure it ends with '?'.
        question = question[0].upper() + question[1:]
        if not question.endswith('?'):
            question = question + '?'
        answers = record['answers']
        pos_ctxs = record['positive_ctxs']
        neg_ctxs = record['negative_ctxs']
        assert len(pos_ctxs) > 0, "No positive context found."
        assert len(neg_ctxs) >= 2, "At least two negative contexts are required."
        # The positive passage is placed between two negative passages.
        passages = '\n'.join([neg_ctxs[0]['text'], pos_ctxs[0]['text'], neg_ctxs[1]['text']])
        items.append({
            'context': "#Document#: " + passages + f"\n#Question#: {question}",
            'response': "\n#Answer#:",
            'answer': answers[0],
            'data_index': position,
        })
    return items

def load_jsonl(file_path, parallel=False, total_shard=8, shard_id=0, debug=False, data_type='summarization', subsample=None):
    """Load a JSON-lines dataset and return one context item per record.

    Records are filtered in this order: ``debug`` keeps the first 100,
    ``subsample`` keeps every ``subsample``-th remaining record, and
    ``parallel`` keeps only the slice that belongs to ``shard_id`` (the last
    shard also takes the remainder).

    For ``data_type == 'mt_bench'`` the context is the record's ``'document'``
    field as-is and the record's ``'category'`` is kept. For every other data
    type the context is the field named by ``data_context_keys[data_type]``
    prefixed with ``"#<data_context_names[data_type]>#: "`` and the category
    is None.

    :param file_path: Path of the JSON-lines file.
    :param parallel: Whether to keep only one shard of the data.
    :param total_shard: Number of shards used when ``parallel`` is set.
    :param shard_id: Index of the shard to keep when ``parallel`` is set.
    :param debug: Keep only the first 100 records.
    :param data_type: Data type name selecting how the context is built.
    :param subsample: If not None, keep records whose position is a multiple
        of this value.
    :return: List of dicts with keys ``context``, ``data_index`` (line number
        in the file) and ``category``.
    """
    with open(file_path, 'r', encoding="utf-8") as f:
        records = [json.loads(line) for line in f]

    # Work on the original line positions; records are looked up at the end.
    positions = list(range(len(records)))
    if debug:
        positions = positions[:100]
    if subsample is not None:
        positions = [pos for rank, pos in enumerate(positions) if rank % subsample == 0]
    if parallel:
        chunk_size = len(positions) // total_shard
        start = shard_id * chunk_size
        if shard_id != total_shard - 1:
            positions = positions[start:start + chunk_size]
        else:
            # the last shard also receives the leftover records
            positions = positions[start:]

    is_mt_bench = data_type == 'mt_bench'
    items = []
    for position in positions:
        record = records[position]
        if is_mt_bench:
            context = record['document']
            category = record['category']
        else:
            context = f"#{data_context_names[data_type]}#: " + record[data_context_keys[data_type]]
            category = None
        items.append({
            'context': context,
            'data_index': position,
            'category': category,
        })

    return items

def dump_jsonl(data, output_path, append=False):
    """
    Write ``data`` to ``output_path`` as a single JSON line.

    ``data`` is serialized as one JSON record (non-ASCII characters are kept
    as-is) followed by a newline. With ``append=True`` the line is added to
    the end of the file; otherwise the file is overwritten.
    """
    if append:
        mode = 'a+'
    else:
        mode = 'w'
    with open(output_path, mode, encoding='utf-8') as f:
        f.write(json.dumps(data, ensure_ascii=False) + '\n')


def create_demo_text(data_type='summarization'):
    """Return the instruction text placed in front of the context.

    :param data_type: One of 'summarization', 'nq', 'xsum' or 'cnndm'.
    :return: The instruction string (ending in a blank line) for the given
        data type, or None when the data type has no instruction.
    """
    instructions = (
        ('summarization', "Generate a summary based on the information in the document.\n\n"),
        ('nq', "Answer the question based on the information in the document. Explain your reasoning in the document step-by-step before providing the final answer.\n\n"),
        ('xsum', "Generate a summary comprising of 1 sentence for the given article.\n\n"),
        ('cnndm', "Generate a summary comprising of 3 sentences for the given article.\n\n"),
    )
    for known_type, instruction in instructions:
        if data_type == known_type:
            return instruction
    return None

# instruction = 'You should try your best to determine if the summary contains non-factual or hallucinated information according to the above hallucination types. The answer you give MUST be \"Yes\" or \"No\"".'


def build_prompt(context, response, data_type='summarization', llama2_tokenizer=None):
    """Assemble the model input: instruction + context + response marker.

    The instruction comes from ``create_demo_text(data_type)``. For
    ``data_type == 'summarization'`` the instruction-plus-context part is
    shortened by ``truncate_message`` so the whole prompt fits its token
    budget; for every other data type the parts are simply concatenated.

    :param context: Context text (document, question, ...) of the example.
    :param response: Text appended after the context.
    :param data_type: Data type name used to pick the instruction.
    :param llama2_tokenizer: Tokenizer handed to ``truncate_message``; only
        used for the 'summarization' data type.
    :return: The full prompt string.
    :raises ValueError: If ``create_demo_text`` has no instruction for
        ``data_type`` (previously this surfaced as an opaque ``TypeError``
        from ``None + str``).
    """
    demo = create_demo_text(data_type)
    if demo is None:
        raise ValueError(
            f"No instruction text is defined for data type {data_type!r}; "
            "cannot build a prompt."
        )
    prompt = demo + context
    if data_type == 'summarization':
        return truncate_message(prompt, response, llama2_tokenizer)
    return prompt + response


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    :param seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

def plot_histogram_distributions(scores, labels, filename='histogram.png'):
    """
    Plot the histogram distributions of prediction scores for two groups of examples,
    separated by their labels (0 or 1), and save the figure to a file.

    :param scores: A list of predicted scores from the model.
    :param labels: A list of ground truth labels.
    :param filename: Path of the image file the figure is saved to.
    :raises ValueError: If ``scores`` and ``labels`` differ in length.
    """
    # Validate input lengths
    if len(scores) != len(labels):
        raise ValueError("The length of scores and labels must be the same.")

    # Create a DataFrame for easier plotting
    df = pd.DataFrame({'Score': scores, 'Label': labels})

    # Plotting
    fig = plt.figure(figsize=(10, 6))
    try:
        sns.histplot(data=df, x='Score', hue='Label', element='step', stat='count', common_norm=False, palette='viridis')
        plt.title('Histogram of Prediction Scores by Label')
        plt.xlabel('Predicted Score')
        plt.ylabel('Count')
        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

def find_best_threshold(fpr, tpr, thresholds):
    """
    Pick the ROC threshold whose (FPR, TPR) point lies nearest to the ideal
    top-left corner (0, 1).

    :param fpr: Array of False Positive Rates
    :param tpr: Array of True Positive Rates
    :param thresholds: Array of thresholds corresponding to each (FPR, TPR) point
    :return: The threshold of the closest point (the first one on ties)
    """
    # Horizontal offset from the corner is fpr, vertical offset is 1 - tpr.
    missed_rate = 1 - tpr
    corner_distance = np.sqrt(missed_rate ** 2 + fpr ** 2)
    return thresholds[np.argmin(corner_distance)]

def calculate_metrics(scores, labels, threshold):
    """
    Calculate classification metrics for scores binarized at a given threshold.

    :param scores: A list of predicted scores from the model.
    :param labels: A list of ground truth labels (0 or 1).
    :param threshold: Scores greater than or equal to this value are predicted as 1.
    :return: A 7-tuple ``(precision, recall, f1, accuracy,
        accuracy_on_label_0, accuracy_on_label_1, harmonic_mean_accuracy)``,
        where the last item is the harmonic mean of the two per-label
        accuracies (0.0 when both of them are 0).
    """
    # Convert scores to binary classifications
    predictions = [1 if score >= threshold else 0 for score in scores]

    # Calculate metrics
    precision = precision_score(labels, predictions)
    recall = recall_score(labels, predictions)
    f1 = f1_score(labels, predictions)
    accuracy = accuracy_score(labels, predictions)

    # Accuracy restricted to the examples of each ground-truth class.
    label_0_pairs = [(label, pred) for label, pred in zip(labels, predictions) if label == 0]
    label_1_pairs = [(label, pred) for label, pred in zip(labels, predictions) if label == 1]
    subset_acc_where_label_0 = accuracy_score([label for label, _ in label_0_pairs], [pred for _, pred in label_0_pairs])
    subset_acc_where_label_1 = accuracy_score([label for label, _ in label_1_pairs], [pred for _, pred in label_1_pairs])

    # Harmonic mean of the two subset accuracies. When both are 0 the
    # denominator is 0, so the mean is defined as 0.0 instead of dividing.
    accuracy_sum = subset_acc_where_label_0 + subset_acc_where_label_1
    if accuracy_sum == 0:
        harmonic_mean_accuracy = 0.0
    else:
        harmonic_mean_accuracy = 2 * subset_acc_where_label_0 * subset_acc_where_label_1 / accuracy_sum

    return precision, recall, f1, accuracy, subset_acc_where_label_0, subset_acc_where_label_1, harmonic_mean_accuracy

if __name__ == "__main__":
    # Script entry point. Parses the command-line options, loads the dataset
    # (optionally a single shard of it), builds the LLM wrapper and the
    # optional rejection-sampling / attention-steering classifier, then
    # generates one completion per example and writes it to
    # "<output-path>_<shard-id>.jsonl" as a {data_index: completion} JSON line.
    # Examples whose data_index already appears in that file are skipped, so
    # an interrupted run can be resumed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--device", type=str,
                        choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--data-path", type=str, default="./gsm8k")
    parser.add_argument("--output-path", type=str, default="./gsm8k_result")
    parser.add_argument("--early-exit-layers", type=str, default="-1")
    parser.add_argument("--divergence-type", type=str, default="js")
    # parallel mode (split the dataset into multiple parts, inference by separate processes)
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--total-shard", type=int, default=8)
    parser.add_argument("--shard-id", type=int, default=0)
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--extrapolate_coeff", type=float, default=10000.0)
    parser.add_argument("--relative_top", type=float, default=0.1)
    # parser.add_argument("--relative_top_value", type=float, default=-1000.0)
    parser.add_argument("--relative_top_with_norm", action="store_true")
    parser.add_argument("--contrast_disagree_only", action="store_true")
    parser.add_argument("--pre_softmax", action="store_true")
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--do_shuffle", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--subsample", type=int, default=None)
    parser.add_argument("--penalty_alpha", type=float, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--retry", type=int, default=1)
    parser.add_argument("--tuned-lens-path", type=str, default=None)
    parser.add_argument("--auth-token", type=str, default=None)
    parser.add_argument("--premature_temp", type=float, default=1.0)
    parser.add_argument("--apply_early_norm", action="store_true")
    parser.add_argument("--attn-intervention", action="store_true")
    parser.add_argument("--attn-intervention-low-prob", action="store_true")
    parser.add_argument("--attn-int-factor", type=float, default=0.0001)
    parser.add_argument("--low-prob-percentile", type=float, default=0.1)
    parser.add_argument("--keys-path", type=str, default=None)
    parser.add_argument("--pause-num", type=int, default=3)
    parser.add_argument("--alpha", type=float, default=10)
    # parser.add_argument("--subsets", type=str, default="hallucinated_summary,right_summary")
    parser.add_argument("--pondering", type=str, default=None)
    parser.add_argument("--free-form", action="store_true")
    parser.add_argument("--attn-score", action="store_true")
    parser.add_argument("--shift-by-1", action="store_true")
    parser.add_argument("--important-token-type", type=str, default=None)
    parser.add_argument("--data-type", type=str, default=None)
    # observe_top_n_heads
    parser.add_argument("--observe-top-n-heads", type=int, default=None)
    # rejection-sampling classifier model path
    parser.add_argument("--reject-sampling-clf", type=str, default=None)
    # chunk size
    parser.add_argument("--chunk-size", type=int, default=5)
    # num candidates
    parser.add_argument("--num-candidates", type=int, default=8)
    # conversion matrix
    parser.add_argument("--conversion-matrix", type=str, default=None)
    # flash attention 2
    parser.add_argument("--use-flash-attention-2", action="store_true")
    parser.add_argument("--load-in-4bit", action="store_true")
    # feat_layer
    parser.add_argument("--feat-layer", type=int, default=None)
    # attn steer factor
    parser.add_argument("--attn-steer-factor", type=float, default=None)
    # steer type
    parser.add_argument("--steer-type", type=str, default='sigmoid+0.5')
    parser.add_argument("--visualize", action="store_true")
    parser.add_argument("--vis-idx", type=int, default=0)
    parser.add_argument("--vis-greedy", type=str, default=None)
    parser.add_argument("--vis-gpt4o-key", type=str, default=None)

    args = parser.parse_args()
    model_name = args.model_name
    num_gpus = args.num_gpus
    device = args.device

    set_seed(args.seed)

    # Infer the data type from the data path when it is not given explicitly.
    if args.data_type is None:
        if 'summ' in args.data_path:
            args.data_type = 'summarization'
        elif 'nq-open' in args.data_path:
            args.data_type = 'nq'
        elif 'xsum' in args.data_path:
            args.data_type = 'xsum'
        elif 'cnndm' in args.data_path:
            args.data_type = 'cnndm'
        elif 'mt_bench' in args.data_path:
            args.data_type = 'mt_bench'
        else:
            raise ValueError("Please specify the data type.")
    # Get test file
    fp = args.data_path
    if not os.path.exists(fp):
        raise ValueError(f"Test file {fp} does not exist.")

    if "nq-open" in fp:
        list_data_dict = load_nq_open(fp, parallel=args.parallel, total_shard=args.total_shard, shard_id=args.shard_id, debug=args.debug, subsample=args.subsample)
    else:
        list_data_dict = load_jsonl(fp, parallel=args.parallel, total_shard=args.total_shard, shard_id=args.shard_id, debug=args.debug, data_type=args.data_type, subsample=args.subsample)

    if args.pondering is not None:
        # The previous code called load_jsonl(..., pondering=..., keys_path=...),
        # but load_jsonl accepts neither keyword, so it always failed with a
        # TypeError (and its result was never used). Fail with a clear message.
        raise NotImplementedError(
            "--pondering is not supported by this script: load_jsonl() does not "
            "accept 'pondering' / 'keys_path' arguments."
        )

    llm = LLM(
        model_name, device, num_gpus, args.tuned_lens_path, args.auth_token, use_flash_attention_2=args.use_flash_attention_2, load_in_4bit=args.load_in_4bit)
    # "\\end{code}" is a literal backslash followed by "end{code}".
    stop_word_list = ["### User:", "Q:", "\\end{code}", "#Document#:", "#Pondering#:", "#Question#:", "#Dialogue History#:"]
    llm.set_stop_words(stop_word_list)
    # NOTE: not used further in this script; kept so that a malformed
    # --early-exit-layers value still raises here as before.
    early_exit_layers = [int(x) for x in args.early_exit_layers.split(',')]

    # Build the optional classifier bundle handed to llm.generate().
    reject_sampling_clf = None
    attn_steer_factor = None
    if args.reject_sampling_clf is not None:
        if 'finetune_output_' in args.reject_sampling_clf:
            nli_model = AutoModelForSequenceClassification.from_pretrained(args.reject_sampling_clf)
            nli_tokenizer = AutoTokenizer.from_pretrained(args.reject_sampling_clf)
            nli_model.to(device)
            reject_sampling_clf = {'model': nli_model, 'tokenizer': nli_tokenizer, 'is_deberta': True, 'is_cross_encoder': False}
        elif '.pkl' in args.reject_sampling_clf:
            # NOTE(security): unpickling executes arbitrary code from the
            # file; only load classifier files from a trusted source.
            with open(args.reject_sampling_clf, 'rb') as clf_file:
                reject_sampling_clf = pickle.load(clf_file)
            reject_sampling_clf['is_cross_encoder'] = False
            reject_sampling_clf['is_deberta'] = False
        else:
            from sentence_transformers import CrossEncoder
            reject_sampling_clf = {
                'model': CrossEncoder(args.reject_sampling_clf),
                'tokenizer': llm.tokenizer,
                'is_cross_encoder': True,
                'is_deberta': False,
            }
        # Settings shared by every classifier kind.
        reject_sampling_clf['visualize'] = args.visualize
        reject_sampling_clf['vis_tokenizer'] = llm.tokenizer

        if args.attn_steer_factor is None:
            mode = "rejection_sampling"
            print("MODE: rejection sampling decoding", flush=True)
        else:
            # Steering factors are derived from the classifier coefficients,
            # which are read from the 'clf' entry of the bundle.
            if 'clf' not in reject_sampling_clf:
                raise ValueError(
                    "--attn-steer-factor requires a classifier bundle with a 'clf' "
                    "entry (e.g. a pickled classifier passed via --reject-sampling-clf)."
                )
            mode = "attention_steering"
            print("MODE: attention steering decoding", flush=True)
            num_layers = len(llm.model.model.layers)
            # One row of coefficients per layer, scaled by the requested factor.
            attn_steer_factor = args.attn_steer_factor * reject_sampling_clf['clf'].coef_[0].reshape((num_layers, -1))
            if args.steer_type == 'sigmoid+0.5':
                attn_steer_factor = (torch.sigmoid(torch.tensor(attn_steer_factor)) + 0.5).cpu().numpy()
            elif args.steer_type == 'sigmoidx2':
                attn_steer_factor = (torch.sigmoid(torch.tensor(attn_steer_factor)) * 2).cpu().numpy()
            elif args.steer_type == 'exp':
                attn_steer_factor = np.exp(attn_steer_factor)
            else:
                raise ValueError("Invalid steering type.")
    else:
        mode = "vanilla"
        print("MODE: naive decoding from the last layer", flush=True)

    if args.visualize:
        # The generation loop stores a greedy example inside the classifier
        # bundle for every visualized item, so both must be available.
        if reject_sampling_clf is None or args.vis_greedy is None:
            raise ValueError("--visualize requires both --reject-sampling-clf and --vis-greedy.")
        with open(args.vis_greedy, 'r', encoding='utf-8') as greedy_file:
            greedy_examples = [json.loads(line) for line in greedy_file]

        if args.vis_gpt4o_key is not None:
            import ast
            import openai
            # NOTE(review): the API key is read from the OPENAI_API_KEY
            # environment variable; the value of --vis-gpt4o-key itself is not
            # used as the key. Confirm this is intended.
            client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            # Must be a plain string (a trailing comma here previously turned
            # it into a 1-tuple). Alternative: 'gpt-4-turbo-2024-04-09'
            gpt_eval_model = 'gpt-4o-2024-05-13'

            def evaluate_response(document, gt_response, response, data_type='summarization', debug=False):
                """Ask the GPT judge whether ``response`` is supported by ``document``.

                Returns ``(decision, text, problematic_spans)``: ``decision`` is
                True/False when the judge's text contains "Conclusion: True" /
                "Conclusion: False" and None otherwise, ``text`` is the raw judge
                output, and ``problematic_spans`` is the list parsed from the
                "Problematic Spans: " section (empty list when absent).

                NOTE(review): ``eval_prompt_before``, ``eval_prompt_after`` and
                ``data_response_names_gt`` are not defined in this file, and this
                function is not called anywhere in it; calling it as-is raises
                NameError.
                """
                prompt = f"{eval_prompt_before[data_type]}\n\n{document}\n\n#Ground Truth {data_response_names_gt[data_type]}#: {gt_response}\n\n#Proposed {data_response_names[data_type]}#: {response}\n\n{eval_prompt_after[data_type]}"

                print(prompt)
                print('-----------------------', flush=True)
                completion = client.chat.completions.create(
                    model=gpt_eval_model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": prompt}
                    ]
                )

                # output format: {'text': '.... Conclusion: True'}
                text = completion.choices[0].message.content
                if debug:
                    print('-------------------')
                    print(prompt)
                    print('\n'+text+'\n')
                    print('-------------------', flush=True)

                problematic_spans = []
                if "Problematic Spans: " in text:
                    problematic_spans = text.split('Problematic Spans: ')[1]
                    if '**' in problematic_spans:
                        problematic_spans = problematic_spans.split('**')[0].strip()
                    # problematic_spans is in python list of string format, extract the list.
                    # literal_eval only accepts Python literals, so text returned by
                    # the remote model can never execute code here.
                    try:
                        problematic_spans = ast.literal_eval(problematic_spans)
                    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                        print("Error in parsing problematic spans:", problematic_spans)
                        # fall back to a naive split of "[a, b, c]"
                        problematic_spans = problematic_spans[1:-1].split(', ')

                    if debug:
                        print(problematic_spans)
                if "Conclusion: " in text:
                    dec = text.split('Conclusion: ')[1]
                    if '**' in dec:
                        dec = dec.split('**')[0]
                    if debug:
                        print(dec)
                    if "True" in dec:
                        return True, text, problematic_spans
                    elif "False" in dec:
                        return False, text, problematic_spans
                    else:
                        return None, text, problematic_spans
                else:
                    return None, text, problematic_spans

    conversion_matrix = None
    if args.conversion_matrix is not None:
        # NOTE(security): same pickle caveat as for the classifier file.
        with open(args.conversion_matrix, 'rb') as matrix_file:
            conversion_matrix = pickle.load(matrix_file)
    final_layer = None
    base_layer = None

    # Each shard writes to its own file; results already present in it are
    # loaded so finished examples can be skipped.
    output_path = args.output_path+"_"+str(args.shard_id)+".jsonl"
    done_indices = {}
    resuming = os.path.exists(output_path)
    if resuming:
        print("Try to resume from the existing output file.")
        with open(output_path, 'r', encoding='utf-8') as f:
            for line in f:
                for key, value in json.loads(line).items():
                    done_indices[int(key)] = value

    # Token length of the response marker, forwarded to llm.generate().
    if args.data_type == 'mt_bench':
        extra_prompt_length = len(llm.tokenizer("\n\n### Assistant:")['input_ids'])
    else:
        extra_prompt_length = len(llm.tokenizer(f"\n#{data_response_names[args.data_type]}#:")['input_ids']) - 1

    # The output is written as UTF-8 because records are dumped with
    # ensure_ascii=False; the context manager closes the file on errors too.
    with open(output_path, 'a' if resuming else 'w', encoding='utf-8') as fw:
        for idx in tqdm(range(len(list_data_dict))):
            if args.visualize:
                if idx < args.vis_idx:
                    continue
                # NOTE(review): looked up by loop position, not by data_index.
                reject_sampling_clf['greedy_example'] = greedy_examples[idx][str(idx)]
            sample = list_data_dict[idx]
            if sample['data_index'] in done_indices:
                continue

            if args.data_type != 'mt_bench':
                input_text = build_prompt(sample['context'], f"\n#{data_response_names[args.data_type]}#:", data_type=args.data_type, llama2_tokenizer=llm.tokenizer)
            else:
                input_text = sample['context']
            generate_kwargs = dict(max_new_tokens=args.max_new_tokens, penalty_alpha=args.penalty_alpha, do_sample=args.do_sample, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty, mode=mode, final_layer=final_layer, base_layer=base_layer)
            if args.data_type == 'mt_bench':
                # mt_bench overrides the sampling settings per category;
                # a (near-)zero temperature means greedy decoding.
                temperature = temperature_config.get(sample["category"], 0.7)
                generate_kwargs['temperature'] = temperature
                generate_kwargs['do_sample'] = not (temperature < 1e-4)

            model_completion, _ = llm.generate(
                input_text, reject_sampling_clf=reject_sampling_clf, conversion_matrix=conversion_matrix,
                extra_prompt_length=extra_prompt_length,
                feat_layer=args.feat_layer,
                attn_steer_factor=attn_steer_factor,
                chunk_size=args.chunk_size, num_candidates=args.num_candidates, **generate_kwargs)

            # Drop a stop word the completion ends with.
            cropped_model_completion = model_completion
            for stop_word in stop_word_list:
                if model_completion.endswith(stop_word):
                    cropped_model_completion = model_completion[:-len(stop_word)]

            return_dict = {
                sample['data_index']: cropped_model_completion.strip()
            }
            # Flush after every example so finished work survives a crash.
            fw.write(json.dumps(return_dict, ensure_ascii=False) + '\n')
            fw.flush()

    # usage
    # CUDA_VISIBLE_DEVICES=3 python -m fastchat.serve.halueval_rejection_sampling --model-name /data/sls/d/llm/llama2/Llama-2-13b-chat-hf/ --data-path nq-open-10_total_documents_gold_at_4.jsonl --output-path tmp-nq-open-sw4-c4n8.jsonl --num-gpus 1 --do_sample --reject-sampling-clf classifiers/classifier_nq-open_mean_token_sw_4_mp.pkl --parallel --total-shard 4 --shard-id 0 --chunk-size 4 --conversion-matrix conversion-tf-lr-Llama-2-7b-13b-chat-hf-nq-open-cot.pkl --num-candidates 8 --load-in-4bit