import editdistance as ed
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import json
import torch
import transformers
import pickle
from tqdm import tqdm

import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis, entropy
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from statsmodels.tsa.stattools import acf, pacf

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Llama-2-7B-chat tokenizer, loaded once at import time. load_files() uses it to
# re-tokenize annotated responses (with character offsets) so that hallucinated
# text spans can be mapped onto generated-token positions.
_LLAMA2_CHAT_TOKENIZER_PATH = "/data/sls/d/llm/llama2/Llama-2-7b-chat-hf/"
tokenizer = transformers.AutoTokenizer.from_pretrained(_LLAMA2_CHAT_TOKENIZER_PATH)


def min_edit_distance_substring(s1, s2):
    """
    Find the length-``len(s1)`` window of ``s2`` closest to ``s1`` by edit distance.

    Works on any sliceable sequences accepted by ``editdistance.eval``
    (strings, lists of token ids, ...).

    :param s1: The sequence to look for.
    :param s2: The sequence to slide over; must be at least as long as ``s1``.
    :return: ``(best_window, distance)`` -- the first window of ``s2`` that
        attains the minimum edit distance to ``s1``, and that distance.
    :raises ValueError: if ``s2`` is shorter than ``s1``.
    """
    window = len(s1)
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would otherwise let the function silently return (None, inf).
    if len(s2) < window:
        raise ValueError(
            "s2 must be at least as long as s1\ns1: {}\ns2: {}".format(s1, s2))

    best_substring = None
    min_edit_dist = float('inf')
    for start in range(len(s2) - window + 1):
        candidate = s2[start:start + window]
        dist = ed.eval(s1, candidate)
        if dist < min_edit_dist:
            best_substring, min_edit_dist = candidate, dist
            if dist == 0:
                # Exact match: no later window can beat it, and ties keep the
                # first window anyway, so the result is unchanged.
                break

    return best_substring, min_edit_dist


def find_best_threshold(fpr, tpr, thresholds):
    """
    Find the best threshold from the ROC curve by choosing the point
    closest to the top-left corner (0, 1).

    :param fpr: Array-like of False Positive Rates.
    :param tpr: Array-like of True Positive Rates (same length as ``fpr``).
    :param thresholds: Sequence of thresholds corresponding to each (FPR, TPR) point.
    :return: The threshold whose (FPR, TPR) point has the smallest Euclidean
        distance to (0, 1); ties resolve to the earliest index.
    :raises ValueError: if the ROC curve is empty.
    """
    # Accept plain lists as well as arrays (list ** 2 would raise TypeError).
    fpr = np.asarray(fpr, dtype=float)
    tpr = np.asarray(tpr, dtype=float)
    if fpr.size == 0:
        raise ValueError("find_best_threshold() requires a non-empty ROC curve")

    # Euclidean distance of every ROC point from the top-left corner (0, 1).
    distances = np.sqrt((1.0 - tpr) ** 2 + fpr ** 2)
    best_idx = int(np.argmin(distances))
    return thresholds[best_idx]


def calculate_metrics(scores, labels, threshold):
    """
    Calculate precision, recall, F1, accuracy and per-class accuracies at a threshold.

    :param scores: A list of predicted scores; ``score >= threshold`` is predicted as class 1.
    :param labels: A list (or 1-D array) of ground-truth 0/1 labels, aligned with ``scores``.
    :param threshold: The threshold to convert scores to binary classifications.
    :return: A tuple ``(precision, recall, f1, accuracy, acc_label_0, acc_label_1,
        harmonic_mean_accuracy)`` where ``acc_label_k`` is the accuracy over the
        examples whose ground-truth label is ``k`` and the last item is the
        harmonic mean of those two (0.0 when both are 0).
    """
    # Convert scores to binary classifications
    predictions = [1 if score >= threshold else 0 for score in scores]

    precision = precision_score(labels, predictions)
    recall = recall_score(labels, predictions)
    f1 = f1_score(labels, predictions)
    accuracy = accuracy_score(labels, predictions)

    def _accuracy_on_class(cls):
        # Accuracy restricted to the examples whose ground-truth label is `cls`.
        idx = [i for i, y in enumerate(labels) if y == cls]
        return accuracy_score([labels[i] for i in idx], [predictions[i] for i in idx])

    subset_acc_where_label_0 = _accuracy_on_class(0)
    subset_acc_where_label_1 = _accuracy_on_class(1)

    # Harmonic mean of the two per-class accuracies. Guard the 0/0 case (both
    # accuracies zero), which raises ZeroDivisionError for Python floats and
    # yields NaN for NumPy floats; the harmonic mean's limit there is 0.
    denom = subset_acc_where_label_0 + subset_acc_where_label_1
    if denom:
        harmonic_mean_accuracy = 2 * subset_acc_where_label_0 * subset_acc_where_label_1 / denom
    else:
        harmonic_mean_accuracy = 0.0

    return precision, recall, f1, accuracy, subset_acc_where_label_0, subset_acc_where_label_1, harmonic_mean_accuracy


def _read_jsonl(path):
    """Return the JSON records stored one per line in ``path``, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _example_tensor(attn_entry, is_feat, feat_layer):
    """
    Return one example's feature tensor with the new tokens on the last axis.

    Without ``is_feat`` this is ``attn_entry['attn_scores']`` as stored
    (assumed (num_layers, num_heads, num_new_tokens) -- TODO confirm).
    With ``is_feat`` it is ``extracted_hiddens[feat_layer]`` transposed and given
    a leading singleton axis; this assumes that hidden is 2-D
    (num_new_tokens, hidden_dim), giving (1, hidden_dim, num_new_tokens) --
    TODO confirm against the extraction script.
    """
    if is_feat:
        return attn_entry['extracted_hiddens'][feat_layer].transpose(1, 0).unsqueeze(0)
    return attn_entry['attn_scores']


def _resolve_span_text(span_text, response, verbose=False):
    """
    Return ``span_text`` if it occurs verbatim in ``response``; otherwise the
    closest same-length substring of ``response`` by edit distance, or the
    whole ``response`` when the span is longer than it.
    """
    if span_text in response:
        return span_text
    if verbose:
        print("Warning:", f"\n{span_text}\n not in \n{response}\n")
    if len(span_text) > len(response):
        return response
    best_substring, min_edit_dist = min_edit_distance_substring(span_text, response)
    if verbose:
        print(f"Best substring: {best_substring}, min_edit_dist: {min_edit_dist}")
    return best_substring


def _char_span_to_token_span(token_offsets, start_char, end_char):
    """
    Convert the half-open character range [start_char, end_char) to an inclusive token range.

    :param token_offsets: ordered ``(start, end)`` half-open character offsets, one per token.
    :return: ``(start_tok, end_tok)``: the first token ending after ``start_char``
        and the first token ending at or after ``end_char``; -1 where none exists.
    """
    start_tok = end_tok = -1
    for i, (_, tok_end) in enumerate(token_offsets):
        # A token overlaps the span start only if it ends strictly after it;
        # a token ending exactly at start_char lies entirely before the span.
        if start_tok == -1 and tok_end > start_char:
            start_tok = i
        if tok_end >= end_char:
            end_tok = i
            break
    return start_tok, end_tok


def _refine_span_on_attn_ids(span_text, start_tok, end_tok, attn_ids, verbose=False, max_edit_dist=5):
    """
    Re-align a token span onto the model's own completion ids.

    Needed when re-tokenizing the response text did not reproduce the generated
    ids exactly, so positions found through the text tokenization may be off.
    With ``k`` = character edit distance between ``span_text`` and the decoding
    of the initial window, this first tries shifting the whole window by up to
    ``k`` tokens, then moving both ends independently by up to ``k`` tokens.

    :return: ``(start_tok, end_tok)`` (inclusive indices into ``attn_ids``): the
        first window whose decoding equals ``span_text``; else the closest window
        if its edit distance is below ``max_edit_dist``; else the input window.
    """
    decoded_span = tokenizer.decode(attn_ids[start_tok:end_tok + 1])
    if span_text == decoded_span:
        if verbose:
            print("Matched check:", f"\n{span_text}\n == \n{decoded_span}\n")
        return start_tok, end_tok
    if verbose:
        print("Mismatched check:", f"\n{span_text}\n != \n{decoded_span}\n")

    max_shift = ed.eval(span_text, decoded_span)
    shifts = range(-max_shift, max_shift + 1)
    num_ids = len(attn_ids)

    def _candidate_windows():
        # Pass 1: shift the whole window; pass 2: move both ends independently.
        for shift in shifts:
            yield start_tok + shift, end_tok + shift
        for start_shift in shifts:
            for end_shift in shifts:
                yield start_tok + start_shift, end_tok + end_shift

    best_dist, best_window = float('inf'), None
    for s, e in _candidate_windows():
        # Reject inverted windows and windows falling off either end (a negative
        # start would silently wrap around in a Python slice).
        if not 0 <= s <= e < num_ids:
            continue
        decoded = tokenizer.decode(attn_ids[s:e + 1])
        if decoded == span_text:
            if verbose:
                print("Matched check after moving:", f"\n{span_text}\n == \n{decoded}\n")
            return s, e
        dist = ed.eval(span_text, decoded)
        if dist < best_dist:
            best_dist, best_window = dist, (s, e)

    if verbose:
        print(f"No exact match found after moving the {start_tok} and {end_tok} in the range of {-max_shift} to {max_shift}")
    if best_window is not None and best_dist < max_edit_dist:
        if verbose:
            print(f"Adopt the best match with min edit distance: {best_dist}")
            s, e = best_window
            print("Matched check after moving:",
                  f"\n{span_text}\n ~= \n{tokenizer.decode(attn_ids[s:e + 1])}\n")
        return best_window
    return start_tok, end_tok


def _hallucinated_token_spans(response, problematic_spans, completion_ids, verbose=False, max_edit_dist=5):
    """
    Map annotated hallucinated text spans to inclusive token-index ranges.

    :param response: annotated response text.
    :param problematic_spans: span strings from the annotation.
    :param completion_ids: tensor of generated token ids (``model_completion_ids``).
    :return: ``None`` when re-tokenizing ``response`` differs from the completion
        ids by ``max_edit_dist`` or more token edits (the example should be
        skipped); otherwise ``(token_spans, num_text_tokens)`` where
        ``token_spans`` is a list of ``(start_tok, end_tok)`` pairs.
    """
    tokenized = tokenizer(response, return_offsets_mapping=True)
    # Drop the first entry (presumably the BOS token the tokenizer prepends).
    text_ids = tokenized['input_ids'][1:]
    token_offsets = tokenized['offset_mapping'][1:]
    attn_ids = completion_ids.tolist()
    # Drop a trailing id 2 (Llama-2's </s> EOS) so both sequences cover only the text.
    if attn_ids and attn_ids[-1] == 2:
        attn_ids = attn_ids[:-1]

    mismatch = False
    if text_ids != attn_ids:
        if len(text_ids) < len(attn_ids):
            _, token_edit_dist = min_edit_distance_substring(text_ids, attn_ids)
        else:
            _, token_edit_dist = min_edit_distance_substring(attn_ids, text_ids)
        if token_edit_dist >= max_edit_dist:
            if verbose:
                print("Skip example:", f"\n{text_ids}\n != \n{attn_ids}\n")
            return None
        if verbose:
            print("Usable example with min edit distance:", token_edit_dist)
        # tokenizer.decode and tokenizer.encode are not consistent: positions
        # found via the text tokenization must be re-aligned on attn_ids.
        mismatch = True

    token_spans = []
    for span_text in problematic_spans:
        span_text = _resolve_span_text(span_text, response, verbose)
        if not span_text:
            continue  # an empty span covers no tokens
        start_char = response.index(span_text)  # first occurrence
        end_char = start_char + len(span_text)
        start_tok, end_tok = _char_span_to_token_span(token_offsets, start_char, end_char)
        assert start_tok != -1 and end_tok != -1, f"span {span_text!r} not covered by token offsets"
        if mismatch:
            start_tok, end_tok = _refine_span_on_attn_ids(
                span_text, start_tok, end_tok, attn_ids, verbose, max_edit_dist)
        # Record the span only after any re-alignment so the corrected
        # positions are the ones used to slice the tensors.
        token_spans.append((start_tok, end_tok))
    return token_spans, len(token_offsets)


def load_files(anno_file, attn_file, crop_hallu_span=False, verbose=False, sequential=False, is_feat=False, feat_layer=32):
    """
    Load annotations and per-example attention / hidden-state tensors.

    Labels use 1 = not hallucinated, 0 = hallucinated. An annotation whose
    ``decision`` is ``None`` is treated as hallucinated.

    :param anno_file: JSON-lines file; each record has ``decision`` and, for
        span cropping, ``response`` and ``problematic_spans``.
    :param attn_file: file loaded with ``torch.load`` holding a list aligned with
        the annotations; entries provide ``attn_scores`` or ``extracted_hiddens``
        and, for span cropping, ``model_completion_ids``.
    :param crop_hallu_span: if False, one (tensor, label) pair per example. If
        True, each hallucinated example is cut into its hallucinated spans
        (label 0) and the tokens before the first / after the last span
        (label 1). With ``sequential``, the example is instead kept whole with
        per-token labels.
    :param verbose: print alignment diagnostics.
    :param sequential: with ``crop_hallu_span``, return per-token label lists.
    :param is_feat: use hidden states ``extracted_hiddens[feat_layer]`` instead of ``attn_scores``.
    :param feat_layer: hidden-state layer used when ``is_feat`` is set.
    :return: ``(focs_tensor, labels)``. ``focs_tensor`` is a list of tensors
        with new tokens on the last axis. ``labels`` is a NumPy array, or a
        list of per-token label lists when ``crop_hallu_span`` and
        ``sequential`` are both set.
    :raises ValueError: if span cropping is requested and ``attn_file`` has
        fewer entries than ``anno_file``.
    """
    anno_data = _read_jsonl(anno_file)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    attn_data = torch.load(attn_file)

    if not crop_hallu_span:
        # focs means focus on context score
        focs_tensor = [_example_tensor(x, is_feat, feat_layer) for x in attn_data]
        labels = np.array([int(x['decision']) if x['decision'] is not None else 0
                           for x in anno_data])
        return focs_tensor, labels

    if len(attn_data) < len(anno_data):
        raise ValueError(
            f"{attn_file} has {len(attn_data)} entries but {anno_file} has {len(anno_data)}")

    focs_tensor = []
    labels = []
    skipped_examples = 0
    for anno, attn in zip(anno_data, attn_data):
        example = _example_tensor(attn, is_feat, feat_layer)
        num_tokens = example.shape[-1]
        is_hallu = (not anno['decision']) if anno['decision'] is not None else True

        if not is_hallu:
            focs_tensor.append(example)
            labels.append([1] * num_tokens if sequential else 1)
            continue

        result = _hallucinated_token_spans(
            anno['response'], anno['problematic_spans'], attn['model_completion_ids'], verbose)
        if result is None:
            skipped_examples += 1
            continue
        token_spans, num_text_tokens = result
        if not token_spans:
            if verbose:
                print("Skip example:", "No hallucinated spans found")
            skipped_examples += 1
            continue

        if sequential:
            sequential_labels = [1] * num_tokens
            for s, e in token_spans:
                # Clip to the tensor length: assigning a shorter list to a
                # slice that runs past the end would grow the label list.
                stop = min(e + 1, num_tokens)
                sequential_labels[s:stop] = [0] * max(stop - s, 0)
            focs_tensor.append(example)
            labels.append(sequential_labels)
            continue

        # Spans are used in annotation order: "before" is relative to the first
        # listed span and "after" to the last one (assumes they are sorted).
        first_start = token_spans[0][0]
        last_end = token_spans[-1][1]
        if first_start > 0:
            # Non-hallucinated segment from the beginning of the response.
            focs_tensor.append(example[:, :, :first_start])
            labels.append(1)
        # All hallucinated spans of one example, concatenated along the token axis.
        focs_tensor.append(torch.cat([example[:, :, s:e + 1] for s, e in token_spans], dim=-1))
        labels.append(0)
        if last_end < num_text_tokens - 1:
            # Non-hallucinated segment from the end of the response.
            focs_tensor.append(example[:, :, last_end + 1:])
            labels.append(1)

    if not sequential:
        labels = np.array(labels)
    if verbose:
        print("Skipped examples:", skipped_examples)

    return focs_tensor, labels


def convert_to_token_level(focs_tensor, labels, sliding_window=1, sequential=False, min_pool_target=False):
    """
    Split every example into token-level (or sliding-window) samples.

    With window length ``w``, an example of shape (num_layers, num_heads, T)
    yields the slices ``[:, :, t - w + 1 : t + 1]`` for ``t = w - 1 .. T - 1``.
    ``w == 1`` gives one (num_layers, num_heads, 1) slice per token.

    :param focs_tensor: list of 3-D tensors with tokens on the last axis.
    :param labels: per-example labels; with ``sequential`` each entry is a
        per-token label sequence.
    :param sliding_window: window length ``w``.
    :param sequential: label each window with the label of its last token
        instead of the example label.
    :param min_pool_target: with ``sequential`` and ``w != 1``, label each
        window with the minimum of its token labels.
    :return: ``(windows, window_labels)`` as two parallel lists.
    """
    windows = []
    window_labels = []
    reach = sliding_window - 1  # how many tokens before `end` the window covers
    pool_min = min_pool_target and sliding_window != 1

    for idx, example in enumerate(focs_tensor):
        example_label = labels[idx]
        _, _, num_tokens = example.shape
        for end in range(reach, num_tokens):
            start = end - reach
            windows.append(example[:, :, start:end + 1])
            if not sequential:
                window_labels.append(example_label)
            elif pool_min:
                window_labels.append(min(example_label[start:end + 1]))
            else:
                window_labels.append(example_label[end])

    return windows, window_labels


def count_peaks(arr):
    """
    Count strict local maxima along the first axis of ``arr``.

    ``arr[t]`` (for ``0 < t < len(arr) - 1``) is a peak when it is strictly
    greater than both ``arr[t - 1]`` and ``arr[t + 1]``; the first and last
    entries are never counted.

    :param arr: Array-like of shape (T,) or (T, D, ...).
    :return: Peak counts with shape ``arr.shape[1:]`` (a scalar for 1-D input).
    """
    arr = np.asarray(arr)

    # Slicing only the first axis works for 1-D as well as N-D input
    # (the former `arr[:-2, :]` raised IndexError on 1-D arrays).
    left_neighbors = arr[:-2]
    center_elements = arr[1:-1]
    right_neighbors = arr[2:]

    is_peak = (center_elements > left_neighbors) & (center_elements > right_neighbors)
    return np.sum(is_peak, axis=0)

# Helpers for extract_time_series_features: head selection and per-head statistics.


def _parse_layer_heads(selected_layer_heads, num_heads):
    """
    Parse specs like ``"L1-H5"`` (layer 1, head 5) or ``"L1"`` (every head of
    layer 1) into an ordered list of ``(layer, head)`` pairs.
    """
    pairs = []
    for spec in selected_layer_heads:
        if '-' in spec:
            parts = spec.split('-')
            pairs.append((int(parts[0][1:]), int(parts[1][1:])))
        else:
            pairs.extend((int(spec[1:]), h) for h in range(num_heads))
    return pairs


def _select_heads(example_org, num_heads, layer_head_pairs=None, picked_head_indices=None, pick_percentile_heads=None):
    """
    Select the layer/head rows of one (num_layers, num_heads, T) tensor.

    Priority: explicit ``layer_head_pairs``; then ``picked_head_indices``
    (flat ``layer * num_heads + head`` indices); then ``pick_percentile_heads``
    (heads at the given percentiles of their mean over tokens); else all heads.

    :return: ``(rows, row_labels)`` -- ``rows`` has shape (K, T) and
        ``row_labels`` holds the ``"L<layer>-H<head>"`` name of each row, in order.
    """
    if layer_head_pairs is not None:
        index = torch.tensor(layer_head_pairs)
        rows = example_org[index[:, 0], index[:, 1]]
        return rows, [f"L{l}-H{h}" for l, h in layer_head_pairs]

    # (num_layers * num_heads, T); a view when possible -- nothing below mutates it.
    flat = example_org.reshape(-1, example_org.shape[2])
    if picked_head_indices is not None:
        flat_indices = [int(k) for k in picked_head_indices]
        rows = flat[picked_head_indices.copy(), :]
    elif pick_percentile_heads is not None:
        # Sort heads by their mean value over all new tokens.
        order = torch.argsort(flat.mean(dim=1))
        last = len(order) - 1
        # Clamp so a percentile of 1.0 picks the top head instead of indexing past the end.
        flat_indices = [int(order[min(int(p * len(order)), last)]) for p in pick_percentile_heads]
        rows = flat[flat_indices, :]
    else:
        flat_indices = list(range(flat.shape[0]))
        rows = flat
    return rows, [f"L{k // num_heads}-H{k % num_heads}" for k in flat_indices]


def _time_series_stats(example, selected_features):
    """
    Compute per-dimension statistics of a (T, D) tensor.

    :return: ``[(name, values), ...]`` in a fixed order (mean, median, std_dev,
        skewness, kurtosis, entropy, first order diff, second order diff,
        peakedness, IQR), restricted to names in ``selected_features``; each
        ``values`` has length D.
    """
    data = example.numpy()
    stats = []
    if 'mean' in selected_features:
        stats.append(('mean', example.mean(dim=0).numpy()))
    if 'median' in selected_features:
        stats.append(('median', np.median(data, axis=0)))
    if 'std_dev' in selected_features:
        stats.append(('std_dev', example.std(dim=0).numpy()))
    if 'skewness' in selected_features:
        stats.append(('skewness', skew(data, axis=0)))
    if 'kurtosis' in selected_features:
        stats.append(('kurtosis', kurtosis(data, axis=0)))
    if 'entropy' in selected_features:
        # Small epsilon avoids log(0); rows are normalized across dimensions,
        # then entropy is taken over time for each dimension.
        distribution = data + 1e-10
        distribution /= distribution.sum(axis=1, keepdims=True)
        stats.append(('entropy', entropy(distribution, axis=0)))
    if 'first order diff' in selected_features or 'second order diff' in selected_features:
        # Computed whenever either is requested: the second-order difference
        # used to raise NameError unless the first-order one was also selected.
        first_order_diff = np.diff(data, axis=0)
        if 'first order diff' in selected_features:
            stats.append(('first order diff', first_order_diff.mean(axis=0)))
        if 'second order diff' in selected_features:
            stats.append(('second order diff', np.diff(first_order_diff, axis=0).mean(axis=0)))
    if 'peakedness' in selected_features:
        # Peak height relative to spread: max / (population) standard deviation.
        stats.append(('peakedness', np.max(data, axis=0) / np.std(data, axis=0)))
    if 'IQR' in selected_features:
        # Spread of the middle 50%; smaller means a sharper distribution.
        stats.append(('IQR', np.percentile(data, 75, axis=0) - np.percentile(data, 25, axis=0)))
    return stats


def extract_time_series_features(focs_tensor, selected_features=('mean', 'entropy'), selected_tensors=('focs',), selected_layer_heads=None, pick_percentile_heads=None, picked_head_indices=None):
    """
    Turn each example's (num_layers, num_heads, T) tensor into a flat feature vector.

    Each selected layer/head row is treated as a time series over the T new
    tokens and summarised by the statistics in ``selected_features`` (any of
    'mean', 'median', 'std_dev', 'skewness', 'kurtosis', 'entropy',
    'first order diff', 'second order diff', 'peakedness', 'IQR'; unknown names
    are ignored). Features are concatenated statistic-major, row-minor, in that
    fixed order regardless of the order given.

    :param focs_tensor: non-empty list of 3-D tensors; the first fixes num_heads.
    :param selected_features: statistics to compute.
    :param selected_tensors: tensor names to use; only 'focs' is available.
    :param selected_layer_heads: optional specs like "L1-H5" or "L1" (all heads of layer 1).
    :param pick_percentile_heads: optional percentiles in [0, 1]; heads are ranked per
        example by their mean over tokens.
    :param picked_head_indices: optional flat head indices (layer * num_heads + head).
    :return: ``(features, detailed_feature_names, baseline_predictions)``:
        ``features`` has one row per example; ``detailed_feature_names`` names each
        column "<tensor>-<stat>-L<layer>-H<head>" using the first example's head
        selection (percentile picks may differ per example); ``baseline_predictions``
        holds each example's overall mean over the selected rows of the first
        selected tensor.
    :raises ValueError: for an empty ``focs_tensor``, a requested tensor that is
        not loaded ('foss'/'fonss'), or NaN features (e.g. 'std_dev' or
        'skewness' on a single token).
    """
    if len(focs_tensor) == 0:
        raise ValueError("focs_tensor is empty")
    not_loaded = [name for name in ('foss', 'fonss') if name in selected_tensors]
    if not_loaded:
        raise ValueError(f"tensors {not_loaded} are not loaded; only 'focs' is available")

    num_heads = focs_tensor[0].shape[1]
    tensors_by_name = {'focs': focs_tensor}
    active_names = [name for name in tensors_by_name if name in selected_tensors]
    # Parsing the layer/head specs is loop-invariant.
    layer_head_pairs = (_parse_layer_heads(selected_layer_heads, num_heads)
                        if selected_layer_heads is not None else None)

    features = []
    detailed_feature_names = []
    baseline_predictions = []
    for i in tqdm(range(len(focs_tensor))):
        feature_list = []
        feature_names = []
        for f_name in active_names:
            rows, row_labels = _select_heads(
                tensors_by_name[f_name][i], num_heads,
                layer_head_pairs, picked_head_indices, pick_percentile_heads)
            # (T, K): tokens are time steps, selected heads are dimensions.
            example = rows.transpose(0, 1)
            if f_name == selected_tensors[0]:
                baseline_predictions.append(example.mean(dim=1).mean(0).item())

            for stat_name, values in _time_series_stats(example, selected_features):
                feature_list.append(values)
                feature_names.append(f"{f_name}-{stat_name}")
                if i == 0:
                    # One name per selected row, matching the concatenation order.
                    detailed_feature_names.extend(
                        f"{f_name}-{stat_name}-{label}" for label in row_labels)
        if i == 0:
            print(f"Use features: {feature_names}")

        feature_vector = np.concatenate(feature_list, axis=0)
        if np.isnan(feature_vector).any():
            raise ValueError(f"NaN in feature vector of example {i} (features: {feature_names})")
        features.append(feature_vector)

    return np.array(features), detailed_feature_names, baseline_predictions


def _build_classifier(penalty, weight_balance, class_weights):
    """Create the (unfitted) logistic-regression probe used by every fit in `main`.

    :param penalty: Regularisation passed straight to ``LogisticRegression``
        (e.g. 'l1', 'l2' or None).
    :param weight_balance: If True, weight the loss with ``class_weights``.
    :param class_weights: Mapping ``{label: weight}``; ignored unless ``weight_balance``.
    :return: A configured, unfitted ``LogisticRegression``.
    """
    options = {'max_iter': 1000, 'penalty': penalty}
    if penalty == 'l1':
        # sklearn's default solver (lbfgs) does not support L1; liblinear does.
        options['solver'] = 'liblinear'
    if weight_balance:
        options['class_weight'] = class_weights
    return LogisticRegression(**options)


def _pick_threshold(y_true, y_proba, disable_best_threshold, label="Best threshold"):
    """Return (and log) the cut-off used to binarise ``y_proba``.

    With ``disable_best_threshold`` the fixed 0.5 cut-off is used. Otherwise the
    ROC point closest to the top-left corner for ``(y_true, y_proba)`` is chosen
    via `find_best_threshold`. ``label`` prefixes the log line in that case.
    """
    if disable_best_threshold:
        threshold = 0.5
        print(f"Use default threshold: {threshold:.9f}")
    else:
        threshold = find_best_threshold(*roc_curve(y_true, y_proba))
        print(f"{label}: {threshold:.9f}")
    return threshold


def _score_split(classifier, X, y, split_name, disable_best_threshold,
                 threshold_label="Best threshold"):
    """Evaluate a fitted probe on one split and log its accuracy/AUROC.

    :return: ``(accuracy, auroc, recall, precision, f1)``. AUROC is computed from
        the positive-class probabilities. The other four metrics all come from
        `calculate_metrics` at the threshold returned by `_pick_threshold`, so
        they share one decision rule.
    """
    y_proba = classifier.predict_proba(X)[:, 1]
    auroc = roc_auc_score(y, y_proba)
    # NOTE(review): when threshold tuning is enabled the cut-off is chosen on
    # the very labels being scored, so threshold-dependent test/transfer
    # metrics are optimistic. AUROC is unaffected.
    threshold = _pick_threshold(y, y_proba, disable_best_threshold, threshold_label)
    recall, precision, f1, accuracy, *_ = calculate_metrics(y_proba, y, threshold)
    print(f"{split_name} Accuracy of the classifier: {accuracy:.9f}, AUROC: {auroc:.9f}")
    return accuracy, auroc, recall, precision, f1


def _mean_scores(scores, width=5):
    """Average a list of equal-length metric tuples column-wise.

    Returns ``(0,) * width`` when ``scores`` is empty (no held-out evaluation ran).
    """
    if not scores:
        return (0,) * width
    return tuple(sum(column) / len(scores) for column in zip(*scores))


def _print_top_features(classifier, feature_names, top_n=10):
    """Print the ``top_n`` features ranked by |coefficient| (linear models) or importance."""
    if hasattr(classifier, 'coef_'):
        importances = classifier.coef_[0]
        rank_key = lambda item: abs(item[1])
    else:
        importances = classifier.feature_importances_
        rank_key = lambda item: item[1]
    ranked = sorted(zip(feature_names, importances), key=rank_key, reverse=True)
    print(f"Top-{top_n} important features:")
    for feature, importance in ranked[:top_n]:
        print(f"{feature}: {importance:.9f}")


def main(anno_file_1, attn_file_1, anno_file_2, attn_file_2,
         selected_features=None,
         selected_tensors=None,
         selected_layer_heads=None,
         pick_percentile_heads=None,
         crop_hallu_span=False,
         token_level=False,
         sliding_window=1,
         sequential=False,
         weight_balance=False,
         min_pool_target=False,
         is_feat=False,
         feat_layer=32,
         penalty='l2',
         disable_best_threshold=False,
         two_fold=False,
         conversion=None,
         picked_head_indices=None,
        ):
    """Train logistic-regression probes on one dataset and transfer them to another.

    Each direction is processed in turn: 1 -> 2 and 2 -> 1, or only 1 -> 2 when
    ``conversion`` is given. For each direction, this:
      1. loads the source data (optionally re-cut to token level) and extracts
         time-series features;
      2. reports the attention-score baseline AUROC;
      3. without ``conversion``, fits a probe on an 80/20 split, or on both
         halves of a 50/50 split when ``two_fold``, and averages the metrics;
      4. refits on all source data and pickles ``{'clf', 'best_threshold'}``;
      5. evaluates that probe on the transfer dataset.
    Two comma-separated metric tables are printed at the end.

    :param selected_features: Feature names to extract; defaults to ``['mean', 'entropy']``.
    :param selected_tensors: Tensors to extract features from; defaults to ``['foss']``.
    :param penalty: Regularisation for ``LogisticRegression`` ('l1', 'l2', None, ...).
    :param disable_best_threshold: Use a fixed 0.5 cut-off instead of a ROC-tuned one.
    :param two_fold: Use 2-fold (50/50, both directions) instead of an 80/20 split.
    :param conversion: Optional dict with ``'weights_matrix'`` and ``'intercepts'``.
        These map transfer features into the source feature space as
        ``features @ W.T + b``. Assumed to hold torch tensors — TODO confirm.
    :raises ValueError: if a source dataset has no labels or only one class.
    """
    if selected_features is None:
        selected_features = ['mean', 'entropy']
    if selected_tensors is None:
        selected_tensors = ['foss']

    load_kwargs = dict(crop_hallu_span=crop_hallu_span, sequential=sequential,
                       is_feat=is_feat, feat_layer=feat_layer)
    token_kwargs = dict(sliding_window=sliding_window, sequential=sequential,
                        min_pool_target=min_pool_target)
    extract_kwargs = dict(selected_features=selected_features,
                          selected_tensors=selected_tensors,
                          selected_layer_heads=selected_layer_heads,
                          pick_percentile_heads=pick_percentile_heads,
                          picked_head_indices=picked_head_indices)

    def load_series(anno, attn):
        """Load one dataset, optionally converting it to token-level examples."""
        series, series_labels = load_files(anno, attn, **load_kwargs)
        if token_level:
            series, series_labels = convert_to_token_level(series, series_labels, **token_kwargs)
        return series, series_labels

    comb1 = (anno_file_1, attn_file_1, anno_file_2, attn_file_2)
    comb2 = (anno_file_2, attn_file_2, anno_file_1, attn_file_1)
    # A conversion maps the *transfer* features into the training feature space,
    # so only the 1 -> 2 direction applies in that case.
    all_combs = [comb1, comb2] if conversion is None else [comb1]

    output_table = [["tensors", "features", "src_task", "tgt_task", "baseline_auroc",
                     "train_accuracy", "train_auroc", "test_accuracy", "test_auroc",
                     "transfer_accuracy", "transfer_auroc", "train_recall", "test_recall",
                     "transfer_recall", "train_precision", "test_precision",
                     "transfer_precision", "train_f1", "test_f1", "transfer_f1"]]
    output_small_table = [
        ['Train AUROC', 'Test AUROC', 'Transfer AUROC', 'Train F1', 'Test F1', 'Transfer F1']
    ]

    for anno_file, attn_file, transfer_anno_file, transfer_attn_file in all_combs:
        print(f"======== Loading data from {anno_file} and {attn_file}...")
        focs_tensor, labels = load_series(anno_file, attn_file)

        # Class balance; also gives inverse-frequency weights for weight_balance.
        num_examples = len(labels)
        if num_examples == 0:
            raise ValueError(f"No labelled examples loaded from {anno_file} / {attn_file}")
        num_positive = sum(labels)
        ratio_class_0 = (num_examples - num_positive) / num_examples
        ratio_class_1 = num_positive / num_examples
        if ratio_class_0 == 0 or ratio_class_1 == 0:
            raise ValueError(
                f"{anno_file} contains a single class; AUROC and class weights are undefined")
        print("Class ratio: {:.2f} (class 0), {:.2f} (class 1)".format(
            ratio_class_0, ratio_class_1))
        class_weights = {0: 1.0 / ratio_class_0, 1: 1.0 / ratio_class_1}
        print("Class weights: ", class_weights)

        time_series_features, feature_names, baseline_predictions = extract_time_series_features(
            focs_tensor, **extract_kwargs)

        baseline_auroc = roc_auc_score(labels, baseline_predictions)
        print(f"Baseline attn_score AUROC: {baseline_auroc:.9f}")

        # Held-out evaluation on the source data (skipped when a conversion is used).
        train_scores, test_scores = [], []
        if conversion is None:
            if two_fold:
                X_a, X_b, y_a, y_b = train_test_split(
                    time_series_features, labels, test_size=0.5, random_state=42)
                folds = [(X_a, y_a, X_b, y_b), (X_b, y_b, X_a, y_a)]
            else:
                X_a, X_b, y_a, y_b = train_test_split(
                    time_series_features, labels, test_size=0.2, random_state=42)
                folds = [(X_a, y_a, X_b, y_b)]

            for X_train, y_train, X_test, y_test in folds:
                classifier = _build_classifier(penalty, weight_balance, class_weights)
                classifier.fit(X_train, y_train)
                train_scores.append(
                    _score_split(classifier, X_train, y_train, "Train", disable_best_threshold))
                test_scores.append(
                    _score_split(classifier, X_test, y_test, "Test", disable_best_threshold))
                _print_top_features(classifier, feature_names)

        train_accuracy, train_auroc, train_recall, train_precision, train_f1 = _mean_scores(train_scores)
        test_accuracy, test_auroc, test_recall, test_precision, test_f1 = _mean_scores(test_scores)

        # Refit on 100% of the source data; this is the probe that gets saved/transferred.
        classifier = _build_classifier(penalty, weight_balance, class_weights)
        classifier.fit(time_series_features, labels)
        full_proba = classifier.predict_proba(time_series_features)[:, 1]
        best_threshold = _pick_threshold(labels, full_proba, disable_best_threshold)

        # Encode the experiment configuration into the saved classifier's filename.
        if not crop_hallu_span:
            prediction_level = 'sentence'
        elif token_level:
            prediction_level = f'token_sw_{sliding_window}' + ('_mp' if min_pool_target else '')
        else:
            prediction_level = 'span'
        if is_feat:
            prediction_level += f'_feat_{feat_layer}'
        prediction_level += f'_reg_{penalty}'
        if selected_layer_heads is not None:
            prediction_level += f'_LH_{",".join(selected_layer_heads)}'

        dataset_tag = (anno_file.replace('gpt4-anno-v2-', '')
                       .replace('gpt4o-anno-greedy-', '')
                       .replace('.jsonl', ''))
        with open(f"classifier_{dataset_tag}_{','.join(selected_features)}_{prediction_level}.pkl", 'wb') as f:
            pickle.dump({'clf': classifier, 'best_threshold': best_threshold}, f)

        # Transfer the full-data classifier to the other dataset.
        print(f"======== Transfer to data from {transfer_anno_file} and {transfer_attn_file}...")
        transfer_focs_tensor, transfer_labels = load_series(transfer_anno_file, transfer_attn_file)
        transfer_features, _, transfer_baseline_predictions = extract_time_series_features(
            transfer_focs_tensor, **extract_kwargs)
        transfer_baseline_auroc = roc_auc_score(transfer_labels, transfer_baseline_predictions)
        print(f"Transfer Baseline attn_score AUROC: {transfer_baseline_auroc:.9f}")
        if conversion is not None:
            weight = conversion['weights_matrix']
            bias = conversion['intercepts']
            transfer_features = (torch.tensor(transfer_features) @ weight.T + bias).numpy()
        (transfer_accuracy, transfer_auroc, transfer_recall,
         transfer_precision, transfer_f1) = _score_split(
            classifier, transfer_features, transfer_labels, "Transfer",
            disable_best_threshold, threshold_label="Transfer Best threshold")

        src_task = "Summarization" if "summarization" in attn_file else "NQ-Open"
        tgt_task = "Summarization" if "summarization" in transfer_attn_file else "NQ-Open"
        output_table.append([
            '+'.join(selected_tensors), '+'.join(selected_features), src_task, tgt_task,
            baseline_auroc, train_accuracy, train_auroc, test_accuracy, test_auroc,
            transfer_accuracy, transfer_auroc, train_recall, test_recall, transfer_recall,
            train_precision, test_precision, transfer_precision, train_f1, test_f1, transfer_f1,
        ])
        output_small_table.append(
            [train_auroc, test_auroc, transfer_auroc, train_f1, test_f1, transfer_f1]
        )

    print("======== Output table:")
    for row in output_table:
        print(', '.join(str(x) for x in row))
    print("======== Output small table:")
    for row in output_small_table:
        print(', '.join(str(x) for x in row))


# Data files for each --model choice.
#   anno_file_* / attn_file_*: annotations and attention dumps (dataset 1 = source).
#   feat_file_*: teacher-forcing hidden-state dumps used with --feat (None = unavailable).
#   conversion: optional pickled feature-space mapping applied to the transfer set.
#   num_heads / num_layers: size of the source model. This is used to print picked
#     heads as L-H, assuming flat index = layer * num_heads + head. For the
#     7b-to-13b settings the source is 7b — TODO confirm that is the intended model.
MODEL_CONFIGS = {
    '7b': dict(
        anno_file_1="gpt4-anno-v2-nq-open.jsonl",
        attn_file_1="greedy-Llama-2-7b-chat-hf-nq-open-for-attn.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-summarization-1000.jsonl",
        attn_file_2="greedy-Llama-2-7b-chat-hf-summarization-1000-for-attn-sink.jsonl_attn_scores_0.pt",
        feat_file_1="teacher-forcing-feats-Llama-2-7b-chat-hf-nq-open.jsonl_hiddens_0.pt",
        feat_file_2="teacher-forcing-feats-Llama-2-7b-chat-hf-summarization-1000.jsonl_hiddens_0.pt",
        conversion=None,
        num_heads=32,
        num_layers=32,
    ),
    '7b-to-13b-nq': dict(
        anno_file_1="gpt4-anno-v2-nq-open.jsonl",
        attn_file_1="greedy-Llama-2-7b-chat-hf-nq-open-for-attn.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-llama2-13b-chat-nq-open.jsonl",
        attn_file_2="greedy-Llama-2-13b-chat-hf-nq-open-for-attn-sink.jsonl_attn_scores_0.pt",
        feat_file_1="teacher-forcing-feats-Llama-2-7b-chat-hf-nq-open.jsonl_hiddens_0.pt",
        # The transfer annotations are 13b nq-open, so the features must be too.
        feat_file_2="teacher-forcing-feats-Llama-2-13b-chat-hf-nq-open.jsonl_hiddens_0.pt",
        conversion="conversion-tf-lr-Llama-2-7b-13b-chat-hf-nq-open-cot.pkl",
        num_heads=32,
        num_layers=32,
    ),
    '7b-to-13b-summ': dict(
        anno_file_1="gpt4-anno-v2-summarization-1000.jsonl",
        attn_file_1="greedy-Llama-2-7b-chat-hf-summarization-1000-for-attn-sink.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-llama2-13b-chat-summarization-1000.jsonl",
        attn_file_2="greedy-Llama-2-13b-chat-hf-summarization-1000-for-attn-sink.jsonl_attn_scores_0.pt",
        feat_file_1="teacher-forcing-feats-Llama-2-7b-chat-hf-summarization-1000.jsonl_hiddens_0.pt",
        feat_file_2="teacher-forcing-feats-Llama-2-13b-chat-hf-summarization-1000.jsonl_hiddens_0.pt",
        conversion="conversion-tf-lr-Llama-2-7b-13b-chat-hf-summ1000.pkl",
        num_heads=32,
        num_layers=32,
    ),
    '13b': dict(
        anno_file_1="gpt4-anno-v2-llama2-13b-chat-nq-open.jsonl",
        attn_file_1="greedy-Llama-2-13b-chat-hf-nq-open-for-attn-sink.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-llama2-13b-chat-summarization-1000.jsonl",
        attn_file_2="greedy-Llama-2-13b-chat-hf-summarization-1000-for-attn-sink.jsonl_attn_scores_0.pt",
        feat_file_1="teacher-forcing-feats-Llama-2-13b-chat-hf-nq-open.jsonl_hiddens_0.pt",
        feat_file_2="teacher-forcing-feats-Llama-2-13b-chat-hf-summarization-1000.jsonl_hiddens_0.pt",
        conversion=None,
        num_heads=40,
        num_layers=40,
    ),
    'nq-train-7b': dict(
        anno_file_1="gpt4o-anno-greedy-Llama-2-7b-chat-hf-nq-train.jsonl",
        attn_file_1="greedy-Llama-2-7b-chat-hf-nq-train.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-nq-open.jsonl",
        attn_file_2="greedy-Llama-2-7b-chat-hf-nq-open-for-attn.jsonl_attn_scores_0.pt",
        feat_file_1=None,
        feat_file_2=None,
        conversion=None,
        num_heads=32,
        num_layers=32,
    ),
    'nq-train-13b': dict(
        anno_file_1="gpt4o-anno-greedy-Llama-2-13b-chat-hf-nq-train.jsonl",
        attn_file_1="greedy-Llama-2-13b-chat-hf-nq-train.jsonl_attn_scores_0.pt",
        anno_file_2="gpt4-anno-v2-llama2-13b-chat-nq-open.jsonl",
        attn_file_2="greedy-Llama-2-13b-chat-hf-nq-open-for-attn-sink.jsonl_attn_scores_0.pt",
        feat_file_1=None,
        feat_file_2=None,
        conversion=None,
        num_heads=40,
        num_layers=40,
    ),
}


if __name__ == "__main__":
    # Usage examples:
    # Sliding window (token level):
    #   python tseries_predict_tokenwise.py --features mean --tensors focs --crop_hallu_span --token_level --sliding_window 1 --min_pool_target
    # Span level:
    #   python tseries_predict_tokenwise.py --features mean --tensors focs --crop_hallu_span --non_seq
    # Feature based:
    #   python tseries_predict_tokenwise.py --features mean --tensors focs --crop_hallu_span --feat --feat_layer 20 --non_seq
    import argparse
    parser = argparse.ArgumentParser(
        description="Process some features and tensors.")

    parser.add_argument('--features', type=str, required=True,
                        help='Comma-separated list of selected features (e.g., mean,entropy,peakedness,IQR)')
    parser.add_argument('--tensors', type=str, required=True,
                        help='Comma-separated list of selected tensors (e.g., focs,foss,fonss)')
    parser.add_argument('--layer_heads', type=str, default=None,
                        help='Comma-separated list of selected layer heads (e.g., L1-H5,L2-H3,L3-H7 or L3,L5,L7)')
    parser.add_argument('--pick_percentile_heads', type=str, default=None,
                        help='Comma-separated list of pick percentile heads (e.g., 0.25,0.50,0.75)')
    parser.add_argument('--crop_hallu_span', action='store_true',
                        help='Flag to crop hallucinated span')
    parser.add_argument('--token_level', action='store_true',
                        help='Flag to set token level processing')
    parser.add_argument('--sequential_level', action='store_true',
                        help='Flag to set sequential level processing')
    parser.add_argument('--sliding_window', type=int, default=1,
                        help='Sliding window size')
    parser.add_argument('--weight_balance', action='store_true',
                        help='Flag to balance the class weight')
    parser.add_argument('--min_pool_target', action='store_true',
                        help='Flag to use min pooling for target level')
    parser.add_argument('--feat', action='store_true',
                        help='Flag to use features from the teacher-forcing model')
    parser.add_argument('--feat_layer', type=int, default=32,
                        help='Layer index to use the features from the teacher-forcing model')
    parser.add_argument('--disable_best_threshold', action='store_true',
                        help='Flag to disable best threshold calculation')
    parser.add_argument('--two_fold', action='store_true',
                        help='Flag to set two fold cross validation')
    parser.add_argument('--non_seq', action='store_true',
                        help='Flag to set non-sequential level processing')
    parser.add_argument('--model', type=str, default='7b', choices=list(MODEL_CONFIGS),
                        help='Which data/model configuration to run')
    parser.add_argument('--penalty', type=str, default='l2',
                        help="LogisticRegression penalty: l1, l2 or none")
    parser.add_argument('--pick_heads_model', type=str, default=None,
                        help='Pickled classifier (or interp_decomp result) to pick heads from')
    parser.add_argument('--pick_method', type=str, default='top-k-absolute:10',
                        help='<method>:<k>, e.g. top-k-absolute:10, top-k-positive:5')

    args = parser.parse_args()

    config = MODEL_CONFIGS[args.model]
    anno_file_1 = config['anno_file_1']
    attn_file_1 = config['attn_file_1']
    anno_file_2 = config['anno_file_2']
    attn_file_2 = config['attn_file_2']
    feat_file_1 = config['feat_file_1']
    feat_file_2 = config['feat_file_2']
    num_heads = config['num_heads']
    num_layers = config['num_layers']

    if args.feat and (feat_file_1 is None or feat_file_2 is None):
        parser.error(f"--feat is not supported for --model {args.model}: no feature files configured")

    conversion = None
    if config['conversion'] is not None:
        # NOTE: pickle can execute arbitrary code on load; only use trusted local files.
        with open(config['conversion'], 'rb') as f:
            conversion = pickle.load(f)

    selected_features = args.features.split(',')
    selected_tensors = args.tensors.split(',')
    selected_layer_heads = args.layer_heads.split(',') if args.layer_heads else None
    pick_percentile_heads = (
        [float(x) for x in args.pick_percentile_heads.split(',')]
        if args.pick_percentile_heads else None
    )
    crop_hallu_span = args.crop_hallu_span
    token_level = args.token_level
    sequential_level = not args.non_seq
    sliding_window = args.sliding_window
    is_feat = args.feat
    feat_layer = args.feat_layer
    penalty = args.penalty if args.penalty.lower() != 'none' else None

    # Optionally restrict features to heads picked from another model's probe.
    picked_head_indices = None
    if args.pick_heads_model is not None:
        # NOTE: pickle can execute arbitrary code on load; only use trusted local files.
        with open(args.pick_heads_model, 'rb') as f:
            pick_heads_model = pickle.load(f)
        pick_method, sep, k = args.pick_method.partition(':')
        if not sep or not k.isdigit():
            parser.error(f"--pick_method must look like '<method>:<k>', got {args.pick_method!r}")
        k = int(k)
        if 'interp_decomp' in args.pick_heads_model:
            picked_head_indices = pick_heads_model['idx'][:k]
            coef_ = None
        else:
            coef_ = pick_heads_model['clf'].coef_
            weights = coef_[0]
            # Each method gives an index ordering; the first k indices are kept.
            # NOTE(review): 'least-k-positive' orders exactly like 'top-k-negative'
            # and 'least-k-negative' like 'top-k-positive'; confirm that is intended.
            orderings = {
                'top-k-absolute': lambda: np.argsort(np.abs(weights))[::-1],
                'top-k-positive': lambda: np.argsort(weights)[::-1],
                'top-k-negative': lambda: np.argsort(weights),
                'least-k-absolute': lambda: np.argsort(np.abs(weights)),
                'least-k-positive': lambda: np.argsort(weights),
                'least-k-negative': lambda: np.argsort(weights)[::-1],
            }
            if pick_method not in orderings:
                raise ValueError(
                    f"Invalid pick method: {pick_method}. Please choose from {', '.join(orderings)}")
            picked_head_indices = orderings[pick_method]()[:k]
        # Print all selected heads in L-H format with their coefficients (if any).
        for i in picked_head_indices:
            l, h = divmod(i, num_heads)
            print(f"L{l}-H{h}: {coef_[0][i] if coef_ is not None else None}")

    main(
        anno_file_1,
        attn_file_1 if not args.feat else feat_file_1,
        anno_file_2,
        attn_file_2 if not args.feat else feat_file_2,
        selected_features=selected_features,
        selected_tensors=selected_tensors,
        selected_layer_heads=selected_layer_heads,
        pick_percentile_heads=pick_percentile_heads,
        crop_hallu_span=crop_hallu_span,
        token_level=token_level,
        sliding_window=sliding_window,
        sequential=sequential_level,
        weight_balance=args.weight_balance,
        min_pool_target=args.min_pool_target,
        is_feat=is_feat,
        feat_layer=feat_layer,
        penalty=penalty,
        disable_best_threshold=args.disable_best_threshold,
        two_fold=args.two_fold,
        conversion=conversion,
        picked_head_indices=picked_head_indices,
    )
