import sys
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
import argparse
import random
import nltk
from tqdm import tqdm
import math
# import metric
import json
import evaluate
from utils import normalize_answer, customized_tokenize, is_stopword, flatten, get_peak_memory, get_flops
from cache_based_inference.pos_shift.modify_llama import enable_llama_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_mistral import enable_mistral_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_j import enable_gpt_j_pos_shift_attention
import numpy as np
# from llm_needle_haystack_tester import
from torch.nn import CrossEntropyLoss
import torch.multiprocessing as mp
from llm_needle_haystack_tester import LLMNeedleHaystackTester
import gc
import time
from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Lookup table: dataset name -> scoring function (all imported from `metrics`).
# scorer() and scorer_e() call the selected function as
# metric(prediction, ground_truth, all_classes=all_classes).
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}



def scorer_e(dataset, predictions, answers, lengths, all_classes):
    """Score predictions and report the mean score per input-length bucket.

    Each prediction is scored against every one of its ground truths with
    ``dataset2metric[dataset]`` and the best score is kept (0.0 when there are
    no ground truths).  The sample is then filed under "0-4k" (length < 4000),
    "4-8k" (length < 8000) or "8k+" according to its entry in *lengths*.

    For "trec", "triviaqa", "samsum" and "lsht" only the first line of the
    prediction (after stripping leading newlines) is scored.

    Returns a dict mapping each bucket name to ``100 * mean`` rounded to two
    decimals.  An empty bucket goes through ``np.mean([])`` and yields NaN.
    """
    first_line_only = dataset in ["trec", "triviaqa", "samsum", "lsht"]
    per_bucket = {"0-4k": [], "4-8k": [], "8k+": []}

    for pred, references, sample_len in zip(predictions, answers, lengths):
        if first_line_only:
            pred = pred.lstrip('\n').split('\n')[0]

        best = 0.
        for reference in references:
            # The metric is looked up lazily, only when there is something to score.
            candidate = dataset2metric[dataset](pred, reference, all_classes=all_classes)
            best = max(best, candidate)

        bucket = "0-4k" if sample_len < 4000 else ("4-8k" if sample_len < 8000 else "8k+")
        per_bucket[bucket].append(best)

    return {name: round(100 * np.mean(values), 2) for name, values in per_bucket.items()}

def scorer(dataset, predictions, answers, all_classes):
    """Return the average best-match score over all predictions, as a percentage.

    Every prediction is compared with each of its ground truths using
    ``dataset2metric[dataset]``; the highest value counts (0.0 when a sample
    has no ground truths).  For "trec", "triviaqa", "samsum" and "lsht" only
    the first line of the prediction (leading newlines stripped) is scored.

    The result is ``100 * total / len(predictions)`` rounded to two decimals,
    so an empty *predictions* raises ZeroDivisionError.
    """
    first_line_only = dataset in ["trec", "triviaqa", "samsum", "lsht"]
    running_total = 0.

    for pred, references in zip(predictions, answers):
        if first_line_only:
            pred = pred.lstrip('\n').split('\n')[0]

        best = 0.
        for reference in references:
            # The metric is looked up lazily, only when there is something to score.
            candidate = dataset2metric[dataset](pred, reference, all_classes=all_classes)
            best = max(best, candidate)

        running_total += best

    return round(100 * running_total / len(predictions), 2)


def construct_query_context(retrieved_docs):
    """Build the context part of the prompt from a sequence of document strings.

    The instruction sentence is followed directly (no separator) by one
    "Document: <text>\\n" entry per document, in the given order.
    """
    header = "Answer the query according to the following documents."
    entries = ["Document: " + doc + '\n' for doc in retrieved_docs]
    return header + "".join(entries)

def construct_query(query):
    """Format *query* as a "Q: ..." line followed by the answer lead-in "A: The answer is"."""
    question_line = "Q: " + query + '\n'
    answer_lead_in = "A: The answer is"
    return question_line + answer_lead_in

def get_the_final_answer(result):
    """Extract and normalize the answer sentence from a model output.

    The answer is the text between the first occurrence of "The answer is"
    and the next "." (or the end of the string when there is no ".").
    The extracted span is passed through ``normalize_answer``.

    If "The answer is" does not occur, the whole of *result* is treated as
    the text to extract from.  The previous implementation did not check the
    -1 returned by ``str.find`` and therefore silently skipped the first 12
    characters in that case.
    """
    _, marker, tail = result.partition('The answer is')
    if not marker:
        # Marker missing: fall back to the full text instead of a bogus offset.
        tail = result
    # Keep everything up to (not including) the first period, if any.
    answer_span = tail.partition('.')[0]
    return normalize_answer(answer_span)
    
def get_the_final_answer_summary(result):
    """Return everything after the first "System:" marker in *result*.

    If the marker is absent, *result* is returned unchanged.  The previous
    implementation did not check the -1 returned by ``str.find`` and so
    dropped the first six characters of the text in that case.
    """
    _, marker, tail = result.partition('System:')
    return tail if marker else result

def get_the_final_answer_(result):
    """Return everything after the first "System:" marker in *result*.

    If the marker is absent, *result* is returned unchanged.  The previous
    implementation did not check the -1 returned by ``str.find`` and so
    dropped the first six characters of the text in that case.
    """
    _, marker, tail = result.partition('System:')
    return tail if marker else result

def get_the_final_answer_summary_new(result):
    """Return everything after the first "Summary:" marker in *result*.

    If the marker is absent, *result* is returned unchanged.  The previous
    implementation did not check the -1 returned by ``str.find`` and so
    dropped the first seven characters of the text in that case.
    """
    _, marker, tail = result.partition('Summary:')
    return tail if marker else result

def get_the_final_answer_squality(result):
    """Return everything after the first "Answer:" marker in *result*.

    If the marker is absent, *result* is returned unchanged.  The previous
    implementation did not check the -1 returned by ``str.find`` and so
    dropped the first six characters of the text in that case.
    """
    _, marker, tail = result.partition('Answer:')
    return tail if marker else result
# 

def build_chat(tokenizer, prompt, model_name):
    """Wrap *prompt* in the chat template that matches *model_name*.

    The model family is detected by substring match on *model_name*; the
    checks run in a fixed order, so "chatglm3" wins over "chatglm" and
    "longchat"/"vicuna" win over "llama2"/"mistral".  Names matching none of
    the known families get the prompt back unchanged.

    *tokenizer* is only used for the chatglm families, where the tokenizer's
    own ``build_chat_input`` / ``build_prompt`` result is returned as-is.
    """
    def name_has(*fragments):
        return any(fragment in model_name for fragment in fragments)

    if name_has("chatglm3"):
        return tokenizer.build_chat_input(prompt)

    if name_has("chatglm"):
        return tokenizer.build_prompt(prompt)

    if name_has("longchat", "vicuna"):
        # fastchat is imported lazily so it is only required for these models.
        from fastchat.model import get_conversation_template
        conversation = get_conversation_template("vicuna")
        user_role, assistant_role = conversation.roles[0], conversation.roles[1]
        conversation.append_message(user_role, prompt)
        conversation.append_message(assistant_role, None)
        return conversation.get_prompt()

    if name_has("llama2", 'mistral'):
        return "[INST]{}[/INST]".format(prompt)

    if name_has("xgen"):
        preamble = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        return preamble + " ### Human: {}\n###".format(prompt)

    if name_has("internlm"):
        return "<|User|>:{}<eoh>\n<|Bot|>:".format(prompt)

    return prompt

@torch.no_grad()
def get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, with_mean_attention=False, multi_head=False):
    """Run one forward pass and return a KV cache restricted to selected positions.

    The context and query token ids are concatenated and fed to ``model``
    (optionally on top of an existing ``past_key_values``).  A set of cache
    positions is then chosen according to ``cache_type`` and the returned
    cache keeps only those positions.

    cache_type values handled here:
      * "zeroshot"    - run the query alone; return its full cache (early return).
      * "all"         - keep every position of the current input.
      * "nocontext"   - keep only the query positions.
      * "stopwords"   - keep position 0, the sub-tokens of context words for
                        which ``is_stopword`` is true, and the query positions.
      * "random"      - keep a random sample of context positions
                        (``random_tokens_ratio`` of them) plus the query positions.
      * "startrecent" - keep the first ``start_size`` and last ``recent_size``
                        context positions plus the query positions.
      * "topk"        - per layer (and per head when ``multi_head``), keep the
                        first ``start_size`` positions, the ``k`` most attended
                        remaining positions, and the current input positions.
    Any other value leaves ``given_indices`` unbound and fails with NameError
    at the final selection step.

    ``start_size``, ``recent_size`` and ``k`` may be ints (absolute counts) or
    non-ints, in which case they are multiplied by a length and truncated.

    Returns:
      * "zeroshot": ``(past_key_values, pred_token_idx, given_indices)``
      * otherwise:  ``(selected_past_key_values, pred_token_idx, given_indices,
        outputs.logits)``
      NOTE(review): the two return shapes differ in length - callers must
      account for that.

    Assumes each cache tensor is laid out as (batch, heads, seq, head_dim) and
    each attention tensor as (batch, heads, query_len, key_len); only batch
    element 0 is used for attention-based selection - TODO confirm against
    the model implementations in use.
    """
    # Context and query go through the model together, producing KV entries
    # for both. The context KV entries are then filtered; the query KV entries
    # were computed while attending to the full context. Generation later runs
    # on the filtered context KV plus the query KV.
    if cache_type=="zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        past_key_values = outputs.past_key_values
        # Greedy next token taken from the logits of the last input position.
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        return  past_key_values, pred_token_idx, given_indices
    
    device=input_ids_query.device
    if input_ids_query_context!=None:
        input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    else:
        input_ids=input_ids_query
    
    # Keep a handle on the cache passed in by the caller: `past_key_values` is
    # rebound to the model's output cache right after the forward pass.
    past_key_values_org = past_key_values 
    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values)
    past_key_values = outputs.past_key_values
    
    if cache_type=="all":
        given_indices = range(len(input_ids[0])) 
    elif cache_type=="nocontext":
        given_indices = range(len(input_ids_query_context[0]), len(input_ids[0]))
    elif cache_type=="stopwords":
        # Position 0 is always kept.
        given_indices=[0]
        assert context_words_lst!=None
        assert word_idx_to_subtoken_start_end_idx!=None
        for word_idx, word in enumerate(context_words_lst):
            if is_stopword(word):
                # The stored end index is inclusive, hence the +1.
                subtoken_start_idx, subtoken_end_idx=word_idx_to_subtoken_start_end_idx[word_idx]
                given_indices.extend(range(subtoken_start_idx, subtoken_end_idx+1))       
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="random":
        assert random_tokens_ratio!=None
        random_tokens_num=int(random_tokens_ratio*len(input_ids_query_context[0]))
        given_indices=random.sample(range(len(input_ids_query_context[0])), k=random_tokens_num)
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="startrecent":
        assert start_size!=None
        assert recent_size!=None
        # Non-int sizes are interpreted as a fraction of the context length.
        if type(start_size)!=int:
            start_size=int(start_size*len(input_ids_query_context[0]))
        if type(recent_size)!=int:
            recent_size=int(recent_size*len(input_ids_query_context[0]))
        given_indices=list(range(start_size))
        given_indices.extend(range(len(input_ids_query_context[0])-recent_size, len(input_ids_query_context[0])))
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="topk":
        # Non-int start_size is interpreted as a fraction of the context length.
        if type(start_size)!=int:
            start_size=int(start_size*len(input_ids_query_context[0]))
        assert k!=None
        if type(k)!=int:
            # NOTE(review): `past_key_values` was rebound to outputs.past_key_values
            # above, so these branches look at the *output* cache (which includes the
            # current input), not the cache the caller passed in. If the intent was
            # "no prior cache", `past_key_values_org` is probably what should be
            # tested/measured here - confirm.
            if past_key_values is None:
                k=int(k*len(input_ids_query_context[0]))
            elif input_ids_query_context is None:
                k=int(k*past_key_values[0][0].size(2))
            else:
                k=int(k*(len(input_ids_query_context[0]) + past_key_values[0][0].size(2)))

        
        attention=outputs.attentions
        # One entry per layer: an index tensor, or (multi_head) a list of one
        # index tensor per head.
        given_indices_for_each_layer=[]
        # NOTE(review): given_indices_start is not used by any live code below.
        given_indices_start=list(range(start_size))
        for attention_per_layer in attention:
            if multi_head:
                given_indices_for_each_head= [] 
                for head_num in range(attention_per_layer.size(1)):
                    if past_key_values_org is None:
                        # No prior cache: score the context positions of the current input.
                        if with_mean_attention:
                            # Average, over the query rows only, of this head's attention
                            # to each context position.
                            att_weights=attention_per_layer[0, head_num ,:,:len(input_ids_query_context[0])]
                            att_weights=torch.mean(att_weights[len(input_ids_query_context[0]):,:len(input_ids_query_context[0])], dim=0)
                        else:
                            # Attention of the last input position to each context position.
                            att_weights=attention_per_layer[0,head_num,-1,:len(input_ids_query_context[0])]

                        # Rank only positions after the always-kept prefix; k is clamped
                        # to the number of candidates.
                        topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, len(input_ids_query_context[0]) - start_size))

                        # Shift back to absolute positions.
                        topk_indices= topk_indices + start_size
                        # Final set: kept prefix + top-k context positions + all query positions.
                        # NOTE(review): the query-position tensor here is torch.int while the
                        # other pieces (and the other branches) use torch.long - confirm intended.
                        topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices, torch.tensor(range(len(input_ids_query_context[0]), len(input_ids[0])), dtype=torch.int, device=device)]) 
                    else:
                        # Prior cache present: score the positions of that prior cache
                        # (the first past_key_values_org[0][0].size(2) attention columns).
                        if with_mean_attention:
                            # Average over all rows of the current input; the second
                            # column slice repeats the first and changes nothing.
                            att_weights=attention_per_layer[0, head_num,:,:past_key_values_org[0][0].size(2)]
                            att_weights=torch.mean(att_weights[:,:past_key_values_org[0][0].size(2)], dim=0)
                        else:
                            att_weights=attention_per_layer[0,head_num,-1,:past_key_values_org[0][0].size(2)]

                        # Rank only positions after the always-kept prefix; k is clamped
                        # to the number of candidates.
                        topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, att_weights[start_size:,].size(0)))
                        # Shift back to absolute positions.
                        topk_indices= topk_indices + start_size

                        # Final set: kept prefix + top-k prior-cache positions + every position
                        # of the current input (which sit after the prior cache).
                        topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices, torch.tensor(range(len(past_key_values_org[0][0][0][0]), len(past_key_values_org[0][0][0][0]) + len(input_ids[0])), dtype=torch.long, device=device)])
                    
                    topk_indices, _=torch.sort(topk_indices)
                    given_indices_for_each_head.append(topk_indices)
                given_indices_for_each_layer.append(given_indices_for_each_head)           
            else:
                if past_key_values_org is None:
                    # No prior cache: score the context positions of the current input,
                    # with attention averaged across heads.
                    if with_mean_attention:
                        # Head-mean first, then mean over the query rows only.
                        att_weights=torch.mean(attention_per_layer[0,:,:,:len(input_ids_query_context[0])], dim=0)
                        att_weights=torch.mean(att_weights[len(input_ids_query_context[0]):,:len(input_ids_query_context[0])], dim=0)
                    else:

                        # Head-mean attention of the last input position to each context position.
                        att_weights=torch.mean(attention_per_layer[0,:,-1,:len(input_ids_query_context[0])], dim=0)


                    topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, len(input_ids_query_context[0]) - start_size))
                    topk_indices= topk_indices + start_size
                    # The last token's cache entry is NOT dropped here; per the original
                    # author's note that is done by the caller ("in the main function").
                    topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device), topk_indices.long(), torch.tensor(range(len(input_ids_query_context[0]), len(input_ids[0])), dtype=torch.long, device=device)]) 
                else:
                    # Prior cache present: score the positions of that prior cache,
                    # with attention averaged across heads.
                    if with_mean_attention:
                        # Head-mean first, then mean over all rows of the current input.
                        att_weights=torch.mean(attention_per_layer[0,:,:,:past_key_values_org[0][0].size(2)], dim=0)
                        att_weights=torch.mean(att_weights[:,:past_key_values_org[0][0].size(2)], dim=0)
                    else:
                        att_weights=torch.mean(attention_per_layer[0,:,-1,:past_key_values_org[0][0].size(2)], dim=0)

                    topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, att_weights[start_size:,].size(0)))
                    # Shift back to absolute positions.
                    topk_indices= topk_indices + start_size

                    # The last token's cache entry is NOT dropped here; per the original
                    # author's note that is done by the caller ("in the main function").
                    # Final set: kept prefix + top-k prior-cache positions + every position
                    # of the current input.
                    topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices, torch.tensor(range(len(past_key_values_org[0][0][0][0]),  len(past_key_values_org[0][0][0][0]) + len(input_ids[0])), dtype=torch.long, device=device)])

                
                topk_indices, _=torch.sort(topk_indices)
                given_indices_for_each_layer.append(topk_indices)
            
    
    if cache_type=="topk":    
        if multi_head:
            # Each head keeps its own positions: gather per head along the sequence
            # dimension and re-stack the heads on dim 1.
            selected_past_key_values = []
            for layer_idx, layer in enumerate(past_key_values):
                tmp_past_key_values = []
                for kv in layer:
                    selected_kv = torch.cat([kv[:, num_head, given_indices_for_each_layer[layer_idx][num_head], :].unsqueeze(1) for num_head in range(kv.size(1))], dim=1)
                    tmp_past_key_values.append(selected_kv)
                selected_past_key_values.append(tuple(tmp_past_key_values))
            selected_past_key_values = tuple(selected_past_key_values)
                
        else:

            # One index tensor per layer, shared by all heads of that layer.
            selected_past_key_values = tuple(tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer) for layer_idx, layer in enumerate(past_key_values))

        # `flatten` comes from utils; presumably it collapses the per-layer
        # (and per-head) index collections into one flat sequence - verify.
        given_indices = flatten(given_indices_for_each_layer)


    else:
        # Same positions for every layer and head.
        selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in past_key_values)
    
    # The first generated token is the greedy argmax over the logits of the last
    # input position (an earlier variant that restarted from the last prompt
    # token instead is no longer used).
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1) 
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits


@torch.no_grad()
def get_past_key_values_no_current_segment(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, with_mean_attention=False, multi_head=False):
    """Run one forward pass and return a KV cache restricted to selected positions.

    The context and query token ids are concatenated and fed to ``model``
    (optionally on top of an existing ``past_key_values``).  A set of cache
    positions is then chosen according to ``cache_type`` and the returned
    cache keeps only those positions.

    The code follows the same structure as ``get_past_key_values``.  The
    distinguishing branch is ``cache_type == "topk"`` with ``multi_head``
    false and a prior cache supplied: there the kept positions are only the
    first ``start_size`` positions plus the top-k positions of the prior
    cache - the positions of the current input segment are NOT appended, so
    the returned cache excludes the current segment.
    NOTE(review): the ``multi_head`` branch with a prior cache still appends
    the current-segment positions - confirm whether that asymmetry is intended.

    cache_type values handled here:
      * "zeroshot"    - run the query alone; return its full cache (early return).
      * "all"         - keep every position of the current input.
      * "nocontext"   - keep only the query positions.
      * "stopwords"   - keep position 0, the sub-tokens of context words for
                        which ``is_stopword`` is true, and the query positions.
      * "random"      - keep a random sample of context positions
                        (``random_tokens_ratio`` of them) plus the query positions.
      * "startrecent" - keep the first ``start_size`` and last ``recent_size``
                        context positions plus the query positions.
      * "topk"        - per layer (and per head when ``multi_head``), keep the
                        first ``start_size`` positions and the ``k`` most
                        attended remaining positions (plus the current input
                        positions, except in the branch described above).
    Any other value leaves ``given_indices`` unbound and fails with NameError
    at the final selection step.

    ``start_size``, ``recent_size`` and ``k`` may be ints (absolute counts) or
    non-ints, in which case they are multiplied by a length and truncated.

    Returns:
      * "zeroshot": ``(past_key_values, pred_token_idx, given_indices)``
      * otherwise:  ``(selected_past_key_values, pred_token_idx, given_indices,
        outputs.logits)``
      NOTE(review): the two return shapes differ in length - callers must
      account for that.

    Assumes each cache tensor is laid out as (batch, heads, seq, head_dim) and
    each attention tensor as (batch, heads, query_len, key_len); only batch
    element 0 is used for attention-based selection - TODO confirm against
    the model implementations in use.
    """
    # Context and query go through the model together, producing KV entries
    # for both. The context KV entries are then filtered; the query KV entries
    # were computed while attending to the full context. Generation later runs
    # on the filtered context KV plus the query KV.
    if cache_type=="zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        past_key_values = outputs.past_key_values
        # Greedy next token taken from the logits of the last input position.
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        return  past_key_values, pred_token_idx, given_indices
    
    device=input_ids_query.device
    if input_ids_query_context!=None:
        input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    else:
        input_ids=input_ids_query
    
    # Keep a handle on the cache passed in by the caller: `past_key_values` is
    # rebound to the model's output cache right after the forward pass.
    past_key_values_org = past_key_values 
    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values)
    past_key_values = outputs.past_key_values
    
    if cache_type=="all":
        given_indices = range(len(input_ids[0])) 
    elif cache_type=="nocontext":
        given_indices = range(len(input_ids_query_context[0]), len(input_ids[0]))
    elif cache_type=="stopwords":
        # Position 0 is always kept.
        given_indices=[0]
        assert context_words_lst!=None
        assert word_idx_to_subtoken_start_end_idx!=None
        for word_idx, word in enumerate(context_words_lst):
            if is_stopword(word):
                # The stored end index is inclusive, hence the +1.
                subtoken_start_idx, subtoken_end_idx=word_idx_to_subtoken_start_end_idx[word_idx]
                given_indices.extend(range(subtoken_start_idx, subtoken_end_idx+1))       
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="random":
        assert random_tokens_ratio!=None
        random_tokens_num=int(random_tokens_ratio*len(input_ids_query_context[0]))
        given_indices=random.sample(range(len(input_ids_query_context[0])), k=random_tokens_num)
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="startrecent":
        assert start_size!=None
        assert recent_size!=None
        # Non-int sizes are interpreted as a fraction of the context length.
        if type(start_size)!=int:
            start_size=int(start_size*len(input_ids_query_context[0]))
        if type(recent_size)!=int:
            recent_size=int(recent_size*len(input_ids_query_context[0]))
        given_indices=list(range(start_size))
        given_indices.extend(range(len(input_ids_query_context[0])-recent_size, len(input_ids_query_context[0])))
        given_indices.extend(range(len(input_ids_query_context[0]), len(input_ids[0])))
        given_indices=sorted(given_indices)
    elif cache_type=="topk":
        # Non-int start_size is interpreted as a fraction of the context length.
        if type(start_size)!=int:
            start_size=int(start_size*len(input_ids_query_context[0]))
        assert k!=None
        if type(k)!=int:
            # NOTE(review): `past_key_values` was rebound to outputs.past_key_values
            # above, so these branches look at the *output* cache (which includes the
            # current input), not the cache the caller passed in. If the intent was
            # "no prior cache", `past_key_values_org` is probably what should be
            # tested/measured here - confirm.
            if past_key_values is None:
                k=int(k*len(input_ids_query_context[0]))
            elif input_ids_query_context is None:
                k=int(k*past_key_values[0][0].size(2))
            else:
                k=int(k*(len(input_ids_query_context[0]) + past_key_values[0][0].size(2)))

        
        attention=outputs.attentions
        # One entry per layer: an index tensor, or (multi_head) a list of one
        # index tensor per head.
        given_indices_for_each_layer=[]
        # NOTE(review): given_indices_start is not used by any live code below.
        given_indices_start=list(range(start_size))
        for attention_per_layer in attention:
            if multi_head:
                given_indices_for_each_head= [] 
                for head_num in range(attention_per_layer.size(1)):
                    if past_key_values_org is None:
                        # No prior cache: score the context positions of the current input.
                        if with_mean_attention:
                            # Average, over the query rows only, of this head's attention
                            # to each context position.
                            att_weights=attention_per_layer[0, head_num ,:,:len(input_ids_query_context[0])]
                            att_weights=torch.mean(att_weights[len(input_ids_query_context[0]):,:len(input_ids_query_context[0])], dim=0)
                        else:
                            # Attention of the last input position to each context position.
                            att_weights=attention_per_layer[0,head_num,-1,:len(input_ids_query_context[0])]


                        # Rank only positions after the always-kept prefix; k is clamped
                        # to the number of candidates.
                        topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, len(input_ids_query_context[0]) - start_size))

                        # Shift back to absolute positions.
                        topk_indices= topk_indices + start_size
                        # Final set: kept prefix + top-k context positions + all query positions.
                        # NOTE(review): the query-position tensor here is torch.int while the
                        # other pieces (and the other branches) use torch.long - confirm intended.
                        topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices, torch.tensor(range(len(input_ids_query_context[0]), len(input_ids[0])), dtype=torch.int, device=device)]) 
                    else:
                        # Prior cache present: score the positions of that prior cache
                        # (the first past_key_values_org[0][0].size(2) attention columns).
                        if with_mean_attention:
                            # Average over all rows of the current input; the second
                            # column slice repeats the first and changes nothing.
                            att_weights=attention_per_layer[0, head_num,:,:past_key_values_org[0][0].size(2)]
                            att_weights=torch.mean(att_weights[:,:past_key_values_org[0][0].size(2)], dim=0)
                        else:
                            att_weights=attention_per_layer[0,head_num,-1,:past_key_values_org[0][0].size(2)]

                        # Rank only positions after the always-kept prefix; k is clamped
                        # to the number of candidates.
                        topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, att_weights[start_size:,].size(0)))
                        # Shift back to absolute positions.
                        topk_indices= topk_indices + start_size

                        # Final set: kept prefix + top-k prior-cache positions + every position
                        # of the current input (this per-head branch still includes the
                        # current segment).
                        topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices, torch.tensor(range(len(past_key_values_org[0][0][0][0]), len(past_key_values_org[0][0][0][0]) + len(input_ids[0])), dtype=torch.long, device=device)])
                    
                    topk_indices, _=torch.sort(topk_indices)
                    given_indices_for_each_head.append(topk_indices)
                given_indices_for_each_layer.append(given_indices_for_each_head) 
            else:
                if past_key_values_org is None:
                    # No prior cache: score the context positions of the current input,
                    # with attention averaged across heads.
                    if with_mean_attention:
                        # Head-mean first, then mean over the query rows only.
                        att_weights=torch.mean(attention_per_layer[0,:,:,:len(input_ids_query_context[0])], dim=0)
                        att_weights=torch.mean(att_weights[len(input_ids_query_context[0]):,:len(input_ids_query_context[0])], dim=0)
                    else:

                        # Head-mean attention of the last input position to each context position.
                        att_weights=torch.mean(attention_per_layer[0,:,-1,:len(input_ids_query_context[0])], dim=0)


                    topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, len(input_ids_query_context[0]) - start_size))
                    topk_indices= topk_indices + start_size
                    # The last token's cache entry is NOT dropped here; per the original
                    # author's note that is done by the caller ("in the main function").
                    topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device), topk_indices.long(), torch.tensor(range(len(input_ids_query_context[0]), len(input_ids[0])), dtype=torch.long, device=device)]) 
                else:
                    # Prior cache present: score the positions of that prior cache,
                    # with attention averaged across heads.
                    if with_mean_attention:
                        # Head-mean first, then mean over all rows of the current input.
                        att_weights=torch.mean(attention_per_layer[0,:,:,:past_key_values_org[0][0].size(2)], dim=0)
                        att_weights=torch.mean(att_weights[:,:past_key_values_org[0][0].size(2)], dim=0)
                    else:
                        att_weights=torch.mean(attention_per_layer[0,:,-1,:past_key_values_org[0][0].size(2)], dim=0)

                    topk_att_weights, topk_indices=torch.topk(att_weights[start_size:,], k=min(k, att_weights[start_size:,].size(0)))
                    # Shift back to absolute positions.
                    topk_indices= topk_indices + start_size

                    # "No current segment": only the kept prefix and the top-k positions of
                    # the prior cache are retained; the positions of the current input are
                    # deliberately left out of the index set.
                    topk_indices=torch.cat([torch.tensor(range(start_size), dtype=torch.long, device=device),topk_indices])

                
                topk_indices, _=torch.sort(topk_indices)
                given_indices_for_each_layer.append(topk_indices)
            
    
    if cache_type=="topk":    
        if multi_head:
            # Each head keeps its own positions: gather per head along the sequence
            # dimension and re-stack the heads on dim 1.
            selected_past_key_values = []
            for layer_idx, layer in enumerate(past_key_values):
                tmp_past_key_values = []
                for kv in layer:
                    selected_kv = torch.cat([kv[:, num_head, given_indices_for_each_layer[layer_idx][num_head], :].unsqueeze(1) for num_head in range(kv.size(1))], dim=1)
                    tmp_past_key_values.append(selected_kv)
                selected_past_key_values.append(tuple(tmp_past_key_values))
            selected_past_key_values = tuple(selected_past_key_values)
                
        else:
            # One index tensor per layer, shared by all heads of that layer.
            selected_past_key_values = tuple(tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer) for layer_idx, layer in enumerate(past_key_values))

        # `flatten` comes from utils; presumably it collapses the per-layer
        # (and per-head) index collections into one flat sequence - verify.
        given_indices = flatten(given_indices_for_each_layer)


    else:
        # Same positions for every layer and head.
        selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in past_key_values)
    
    # The first generated token is the greedy argmax over the logits of the last
    # input position (an earlier variant that restarted from the last prompt
    # token instead is no longer used).
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits


@torch.no_grad()
def decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None):
    """Generate a continuation on top of a pre-filled KV cache and profile the call.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Only ``eos_token_id`` is read, as the default stop token.
        past_key_values: Per-layer cache; ``past_key_values[0][0].size(-2)`` is
            used as the number of cached positions.
        pred_token_idx: Token ids passed to ``generate`` as ``input_ids``.
            NOTE(review): the attention mask below assumes batch size 1 and a
            single input token -- confirm against callers.
        max_gen_len: Forwarded as ``max_new_tokens``.
        num_beams: Forwarded to ``generate``.
        end_token: Stop token id; falls back to ``tokenizer.eos_token_id``.

    Returns:
        Tuple ``(res, decoding_time, decoding_memory)``: the first returned
        sequence moved to CPU, the elapsed time between the two CUDA events
        (milliseconds), and the difference between the ``get_peak_memory``
        readings taken after and before generation.
    """
    if end_token is None:
        end_token = tokenizer.eos_token_id
    device = model.device
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # CUDA events time the generation of one document with batch_size=1.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    # One mask entry per cached position plus one for the token being fed in.
    # Built on the same device as the input ids so generate() never mixes
    # CPU and CUDA tensors.
    cached_len = past_key_values[0][0].size(-2)
    attention_mask = torch.ones(1, cached_len + 1, device=pred_token_idx.device)

    # do_sample=False selects greedy/beam search; sampling knobs
    # (temperature, top_p, top_k) have no effect there and are not passed.
    outputs = model.generate(
        input_ids=pred_token_idx,
        attention_mask=attention_mask,
        do_sample=False,
        num_beams=num_beams,
        eos_token_id=end_token,
        max_new_tokens=max_gen_len,
        past_key_values=past_key_values,
    )

    res = outputs[0].detach().cpu()

    ender.record()
    torch.cuda.synchronize()
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    # Run garbage collection before releasing cached CUDA blocks.
    gc.collect()
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    print("decoding memory = ", decoding_memory)
    return res, decoding_time, decoding_memory


@torch.no_grad()
def decoding_with_past_kv(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None, input_ids_query=None, k=None, start_size=None, recent_size=None, cache_type=None, with_mean_attention=None):
    """Generate text in rounds, re-selecting the KV cache after every round.

    Each round calls ``model.generate`` for up to ``max_gen_len`` new tokens on
    top of the current cache, decodes the returned sequence, and then hands the
    returned token ids to ``get_past_key_values`` to obtain the cache and the
    start token for the next round.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Used for decoding and for the default stop token.
        past_key_values: Initial per-layer cache.
        pred_token_idx: Token ids fed as ``input_ids`` in the first round.
        max_gen_len: ``max_new_tokens`` for every round.
        num_beams: Forwarded to ``generate``.
        end_token: Stop token id; falls back to ``tokenizer.eos_token_id``.
        input_ids_query: Accepted for call compatibility; not used here.
        k, start_size, recent_size, cache_type, with_mean_attention: Forwarded
            to ``get_past_key_values``.

    Returns:
        Tuple ``(res, decoding_time, decoding_memory)``: the concatenated
        decoded text, the elapsed time between the two CUDA events
        (milliseconds), and the difference between the ``get_peak_memory``
        readings taken after and before generation.
    """
    end_token = tokenizer.eos_token_id if end_token is None else end_token
    device = model.device
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # CUDA events time the generation of one document with batch_size=1.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    # The loop stops once the round counter exceeds this value, i.e. after at
    # most max_rounds + 1 generate() calls.
    max_rounds = 20
    pos = 0
    res = ""
    while True:
        print("pos for current generation: ", pos)
        # One mask entry per cached position plus one for the token being fed
        # in, created on the input's device to avoid CPU/CUDA mixing.
        cached_len = past_key_values[0][0].size(-2)
        # do_sample=False selects greedy/beam search; sampling knobs
        # (temperature, top_p, top_k) have no effect there and are not passed.
        outputs = model.generate(
            input_ids=pred_token_idx,
            attention_mask=torch.ones(1, cached_len + 1, device=pred_token_idx.device),
            do_sample=False,
            num_beams=num_beams,
            eos_token_id=end_token,
            max_new_tokens=max_gen_len,
            past_key_values=past_key_values,
        )

        print("outputs = ", outputs)
        print("tokenizer.eos_token_id = ", tokenizer.eos_token_id)

        result = tokenizer.decode(outputs[0, :].cpu(), skip_special_tokens=True)

        # Separate rounds with a space. The emptiness check matters: a round
        # that decodes to "" (e.g. only special tokens) used to raise
        # IndexError on result[0].
        if result and result[0] != ' ':
            result = ' ' + result

        res += result
        print('results = ', result)

        # Re-select the cache from the tokens just produced; the selected
        # indices and logits returned alongside are not needed here.
        past_key_values, pred_token_idx, _, _ = get_past_key_values(None, outputs, cache_type, model, None, None, None, start_size, recent_size, k, past_key_values=past_key_values, with_mean_attention=with_mean_attention)

        pos += 1
        # Stop on the same token generate() stops on. Checking only
        # tokenizer.eos_token_id ignored a caller-supplied end_token and kept
        # generating past it until the round limit.
        if pos > max_rounds or outputs[0, -1] == end_token:
            break

    ender.record()
    torch.cuda.synchronize()
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return res, decoding_time, decoding_memory

def process_report_summ(test_example, tokenizer):
    """Tokenize a report-summarization example using the chat-style prompt.

    Args:
        test_example: Mapping with a ``'report'`` string. A reference summary
            is not required; it is not part of the prompt.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.

    Returns:
        Tuple ``(input_ids_query_context, input_ids_query)``: token ids of the
        instruction-plus-document prompt and of the trailing query instruction.
    """
    report_prompt = (
        "You are a helpful, respectful and honest assistant specializing in summarization. "
        "Provide the best summary you can for this document.\n\nDocument: "
        + test_example['report']
    )
    query_instruction = "\n\nUser: The best summary of the above document is:\nSystem:"
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query


def process_report_summ_new(test_example, tokenizer):
    """Tokenize a report-summarization example using the "one-page summary" prompt.

    Args:
        test_example: Mapping with a ``'report'`` string. A reference summary
            is not required; it is not part of the prompt.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.

    Returns:
        Tuple ``(input_ids_query_context, input_ids_query)``: token ids of the
        instruction-plus-report prompt and of the trailing query instruction.
    """
    report_prompt = (
        "You are given a report by a government agency. Write a one-page summary of the report.\nReport: "
        + test_example['report']
        + "\n\n"
    )
    query_instruction = "Now, write a one-page summary of the report.\nSummary:"
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query

def process_report_summ_new_in_context(test_example, tokenizer, train_examples):
    """Tokenize a report-summarization example preceded by in-context demonstrations.

    The context prompt is: a task instruction, then for every training example
    its report followed by the summary instruction and its reference summary,
    and finally the test report. The demonstrations used to be built and then
    discarded (and were read from the test example); they are now taken from
    ``train_examples`` and actually included.

    Args:
        test_example: Mapping with a ``'report'`` string.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.
        train_examples: Iterable of mappings with ``'report'`` and ``'summary'``
            strings, used as demonstrations in the given order.

    Returns:
        Tuple ``(input_ids_query_context, input_ids_query)``: token ids of the
        full context prompt and of the trailing query instruction.
    """
    query_instruction = "Now, write a one-page summary of the report.\nSummary:"

    prompt_parts = ["You are given reports by a government agency. Write one-page summaries of the reports.\n\n"]
    for train_example in train_examples:
        prompt_parts.append("Report: " + train_example['report'] + "\n\n")
        prompt_parts.append("Now, write a one-page summary of the report.\nSummary: ")
        prompt_parts.append(train_example['summary'] + "\n\n")
    prompt_parts.append("Report: " + test_example['report'] + "\n\n")
    report_prompt = "".join(prompt_parts)

    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query


def process_squality(test_example, tokenizer):
    """Tokenize a SQuALITY document and build one prompt per question.

    Args:
        test_example: Mapping with a ``'document'`` string and a ``'questions'``
            sequence; every question has ``'question_text'`` and a
            ``'responses'`` sequence of mappings with ``'response_text'``.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.

    Returns:
        Tuple ``(input_ids_query_context, query_prompts, answer_prompts)``:
        token ids of the document, a list with the token ids of each question
        prompt, and a parallel list holding the reference answers of each
        question.
    """
    prompt_head = "\n\nUser: Answer the question in a paragraph.\n\nQuestion:\n\n"
    prompt_tail = "\n\nAnswer: "

    input_ids_query_context = tokenizer(test_example['document'], return_tensors="pt").input_ids

    query_prompts, answer_prompts = [], []
    for question in test_example['questions']:
        prompt_text = prompt_head + question['question_text'] + prompt_tail
        query_prompts.append(tokenizer(prompt_text, return_tensors='pt').input_ids)
        # All reference answers of this question, in their original order.
        answer_prompts.append([response['response_text'] for response in question['responses']])

    return input_ids_query_context, query_prompts, answer_prompts




def process_squality_in_context(test_example, tokenizer, train_examples):
    """Tokenize a SQuALITY example preceded by in-context demonstrations.

    Every training example contributes its document, its first question and
    the first reference answer to that question. The demonstrations used to be
    read from ``test_example``, which repeated the test document and leaked its
    own first reference answer into the prompt; they are now read from
    ``train_examples``.

    Args:
        test_example: Mapping with a ``'document'`` string and a ``'questions'``
            sequence; every question has ``'question_text'`` and a
            ``'responses'`` sequence of mappings with ``'response_text'``.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.
        train_examples: List of mappings shaped like ``test_example``; each
            needs at least one question with at least one response.

    Returns:
        Tuple ``(input_ids_query_context, query_prompts, answer_prompts)``:
        token ids of demonstrations plus test document, a list with the token
        ids of each test question prompt, and a parallel list holding the
        reference answers of each test question.
    """
    assert isinstance(train_examples, list)

    question_head = "User: Answer the question in a paragraph.\n\nQuestion:\n\n"

    prompt_parts = []
    for train_example in train_examples:
        demo_question = train_example['questions'][0]
        prompt_parts.append(
            "Document: " + train_example['document']
            + "\n\n" + question_head + demo_question['question_text']
            + "\n\nAnswer: \n\n" + demo_question['responses'][0]['response_text'] + "\n\n"
        )
    prompt_parts.append("Document: " + test_example['document'] + "\n\n")
    report_prompt = "".join(prompt_parts)
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids

    query_prompts = []
    answer_prompts = []
    for each_question in test_example['questions']:
        query_text = question_head + each_question['question_text'] + "\n\nAnswer: "
        query_prompts.append(tokenizer(query_text, return_tensors='pt').input_ids)
        answer_prompts.append([each_response['response_text'] for each_response in each_question['responses']])

    return input_ids_query_context, query_prompts, answer_prompts

    # return 
    # exit()
    # pass

def process_triviaqa(test_example, num_docs, model_name, tokenizer, max_context_length, min_context_length):
    """Build tokenized context and query for a TriviaQA example.

    ``num_docs`` retrieved documents are drawn with ``random.choices`` (i.e.
    with replacement) and turned into a context by ``construct_query_context``;
    the question is turned into a query by ``construct_query``.

    Tokenization depends on the model: names containing ``"pythia"`` use the
    tokenizer directly, all other models go through ``customized_tokenize`` on
    NLTK word tokens. The previous check, ``model_name.find("pythia")``, was
    used as a boolean although ``str.find`` returns -1 (truthy) when the
    substring is absent and 0 (falsy) when it is a prefix.

    Args:
        test_example: Mapping with ``["search_results"]["search_context"]``
            (sequence of documents) and ``"question"``.
        num_docs: Number of documents to sample; 0 means no context.
        model_name: Model identifier used to pick the tokenization path.
        tokenizer: Tokenizer passed to either path.
        max_context_length: Contexts with more tokens are skipped.
        min_context_length: Contexts with fewer tokens are skipped.

    Returns:
        Tuple ``(input_ids_query_context, input_ids_query)``. The context is
        ``None`` when ``num_docs == 0``. ``(None, None)`` is returned when the
        example is skipped (too few documents, or context length out of range).
    """
    search_contexts = test_example["search_results"]["search_context"]
    if len(search_contexts) < num_docs:
        print("Not enough retrieved docs. The example is skipped !")
        return None, None

    use_plain_tokenizer = "pythia" in model_name

    def _encode(text):
        # Shared by context and query so both follow the same tokenization path.
        if use_plain_tokenizer:
            return tokenizer(text, return_tensors="pt").input_ids
        input_ids, _ = customized_tokenize(tokenizer, nltk.word_tokenize(text))
        return input_ids

    if num_docs == 0:
        input_ids_query_context = None
    else:
        retrieved_docs = random.choices(search_contexts, k=num_docs)
        query_context = construct_query_context(retrieved_docs)
        # Turn literal backslash-n sequences into real newlines.
        query_context = query_context.replace('\\n', '\n').rstrip('\n')
        # Kept off the GPU at this point to save device memory.
        input_ids_query_context = _encode(query_context)
        context_length = len(input_ids_query_context[0])
        if context_length > max_context_length or context_length < min_context_length:
            print("Context too long or too short. The example is skipped !")
            return None, None

    # Prepare the test query.
    query = construct_query(test_example["question"])
    query = query.replace('\\n', '\n').rstrip('\n')
    input_ids_query = _encode(query)

    return input_ids_query_context, input_ids_query

def process_longbench(test_example, tokenizer, dataset_name):
    """Split a formatted LongBench prompt into context and query, then tokenize both.

    The prompt is split on a separator and its last few pieces become the
    query; everything before them is the context. Separator and piece count
    depend on the first keyword below found in ``dataset_name``; any other
    dataset uses a blank line and the last two pieces.

    Args:
        test_example: The fully formatted prompt string.
        tokenizer: Callable returning an object with ``input_ids`` when called
            as ``tokenizer(text, return_tensors="pt")``.
        dataset_name: LongBench dataset name.

    Returns:
        Tuple ``(input_ids_query_context, input_ids_query)``.
    """
    # (keyword, separator, number of trailing pieces that form the query);
    # checked in order, first match wins.
    split_rules = (
        ('samsum', '\n', 1),
        ('trivia', '\n', 6),
        ('trec', '\n', 2),
        ('qasper', '\n\n', 3),
        ("passage_retrieval", '\n\n', 3),
    )
    matched_keyword, separator, query_pieces = None, '\n\n', 2
    for keyword, rule_separator, rule_pieces in split_rules:
        if keyword in dataset_name:
            matched_keyword, separator, query_pieces = keyword, rule_separator, rule_pieces
            break

    pieces = test_example.split(separator)
    source = separator.join(pieces[:-query_pieces])
    query = separator.join(pieces[-query_pieces:])
    if matched_keyword == "passage_retrieval":
        print("query = ", query)

    input_ids_query_context = tokenizer(source, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query

    # report = test_example['report']
    # summary = test_example['summary']


def load_longbench(datasets):
    """Load the LongBench-E test split and prompt template for each dataset.

    Args:
        datasets: Iterable of LongBench dataset names. Each name must be a key
            of ``./dataset2prompt.json`` and is loaded from ``THUDM/LongBench``
            under the config ``"{name}_e"``.

    Returns:
        Dict mapping each dataset name to
        ``{"data": <list of examples>, "prompt_format": <template>}``.
    """
    # The context manager closes the file; it was previously left open.
    # JSON files are UTF-8, so the encoding is given explicitly instead of
    # depending on the platform default.
    with open("./dataset2prompt.json", "r", encoding="utf-8") as prompt_file:
        dataset2prompt = json.load(prompt_file)

    ret_data = {}
    for dataset in datasets:
        data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
        prompt_format = dataset2prompt[dataset]
        # Materialise the split into a plain list of examples.
        ret_data[dataset] = {"data": list(data), 'prompt_format': prompt_format}
    return ret_data

        # data_subsets = [data_all[i::world_size] for i in range(world_size)]
        # processes = []
        # for rank in range(world_size):
        #     p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_length, \
        #                 max_gen, prompt_format, dataset, device, model_name, model2path, out_path))
        #     p.start()
        #     processes.append(p)
        # for p in processes:
        #     p.join()
        


# def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset, device, model_name, model2path, out_path):
#     device = torch.device(f'cuda:{rank}')
#     model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device)
#     for json_obj in tqdm(data):
#         prompt = prompt_format.format(**json_obj)
#         # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
#         tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
#         if "chatglm3" in model_name:
#             tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
#         if len(tokenized_prompt) > max_length:
#             half = int(max_length/2)
#             prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
#         if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
#             prompt = build_chat(tokenizer, prompt, model_name)
#         if "chatglm3" in model_name:
#             if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
#                 input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
#             else:
#                 input = prompt.to(device)
#         else:
#             input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
#         context_length = input.input_ids.shape[-1]
#         if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
#             output = model.generate(
#                 **input,
#                 max_new_tokens=max_gen,
#                 num_beams=1,
#                 do_sample=False,
#                 temperature=1.0,
#                 min_length=context_length+1,
#                 eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
#             )[0]
#         else:
#             output = model.generate(
#                 **input,
#                 max_new_tokens=max_gen,
#                 num_beams=1,
#                 do_sample=False,
#                 temperature=1.0,
#             )[0]
#         pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
#         pred = post_process(pred, model_name)
#         with open(out_path, "a", encoding="utf-8") as f:
#             json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}, f, ensure_ascii=False)
#             f.write('\n')
#     dist.destroy_process_group()

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random number generators.

    Also sets ``torch.backends.cudnn.deterministic`` to True and
    ``torch.backends.cudnn.benchmark`` to False.

    Args:
        seed: Seed value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def main(model_name, quantization_type, device_map, max_gen_length,question_attn_ratio, all_eval_tokens, segment_length, num_beams, dataset_name, dataset_split, num_examples, num_docs, min_context_length, max_context_length, cache_type, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, print_cache=True, print_results=True, with_mean_attention=False, seg_generate=False, use_context_cache=False, new_gov_prompt=False, in_context_examples=0, final_three_dataset=False, no_compress_final_segment=False, multi_head=False):
    """Evaluate segment-wise KV-cache selection on a fixed set of LongBench-E tasks.

    For every example the context is fed to the model ``segment_length`` tokens
    at a time; after each step the cache is reduced by ``get_past_key_values``
    according to ``cache_type``. The query is then encoded on top of the
    selected cache, an answer is generated, and ``scorer_e`` scores all answers
    of a dataset.

    Args:
        model_name: Hugging Face model id or path.
        quantization_type: ``'none'``, ``'4bit'`` or ``'8bit'``.
        device_map: Forwarded to ``from_pretrained``.
        max_gen_length: Overwritten per dataset from ``./dataset2maxlen.json``.
        segment_length: Number of context tokens processed per step.
        dataset_name: Overwritten by the dataset loop; the evaluated datasets
            are the fixed list below.
        dataset_split, num_examples, num_docs, min_context_length,
            max_context_length: Only echoed in the settings report.
        cache_type, random_tokens_ratio, start_size, recent_size, k,
            with_mean_attention, multi_head: Forwarded to the cache-selection
            helpers. A whole-number float ``k`` is converted to ``int``.
        print_cache: Print the decoded tokens at the selected cache indices.
        print_results: Print the raw decoded generation.
        seg_generate: Use ``decoding_with_past_kv`` instead of ``decoding``.
        no_compress_final_segment: Skip the extra cache selection that is
            otherwise run against the query before encoding it.
        question_attn_ratio, all_eval_tokens, num_beams, use_context_cache,
            new_gov_prompt, in_context_examples, final_three_dataset: Accepted
            for call compatibility but not used here (generation always runs
            with ``num_beams=1``).

    Raises:
        ValueError: Unknown ``quantization_type`` or unsupported model type.
    """
    print("model_name = ", model_name)
    # k defaults to None; only whole-number values are collapsed to int.
    if k is not None and math.floor(k) == k:
        k = int(k)
    device = torch.device("cuda")
    seed_everything(42)
    datasets = ["qasper", "multifieldqa_en", "hotpotqa", "trec", "triviaqa", "samsum"]
    datasets_load = load_longbench(datasets)

    # Fail early on an unknown quantization mode instead of leaving `model`
    # unbound.
    quantization_kwargs = {
        'none': {},
        '4bit': {'load_in_4bit': True},
        '8bit': {'load_in_8bit': True},
    }
    if quantization_type not in quantization_kwargs:
        raise ValueError(f"unsupported quantization_type: {quantization_type!r}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_attentions=True,
        device_map=device_map,
        cache_dir='/network/scratch/y/yu.bai/.cache',
        **quantization_kwargs[quantization_type],
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Position shift: patch the attention implementation per architecture.
    if "llama" in model.config.model_type:
        enable_llama_pos_shift_attention(model)
    elif "gpt_neox" in model.config.model_type:
        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model.config.model_type:
        enable_falcon_pos_shift_attention(model)
    elif "mistral" in model.config.model_type:
        enable_mistral_pos_shift_attention(model)
    elif "mpt" in model.config.model_type:
        pass
    elif "btlm" in model.config.model_type:
        pass
    else:
        raise ValueError(f"got {model.config.model_type}")

    # Inference
    model.eval()
    loss_fn = CrossEntropyLoss(reduction="none")

    with open("./dataset2maxlen.json", "r") as maxlen_file:
        dataset2maxlen = json.load(maxlen_file)

    # Id of the '<0x0A>' token, used as the stop token for samsum. Looked up
    # once because it does not depend on the example.
    nxtline_id = tokenizer.convert_tokens_to_ids('<0x0A>')

    start_time = time.time()
    for dataset_name in datasets:
        test_examples = datasets_load[dataset_name]['data']
        dataset_prompt_format = datasets_load[dataset_name]['prompt_format']

        max_gen_length = dataset2maxlen[dataset_name]
        num_correct=0
        num_total=0
        total_decoding_time=0
        total_decoding_memory=0
        total_stopwords_ratio=0
        context_length_lst=[]
        system_summaries = []
        reference_summaries = []
        all_classes = []
        lengths = []
        for test_example in test_examples:
            test_example_str = dataset_prompt_format.format(**test_example)
            if dataset_name not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
                test_example_str = build_chat(tokenizer, test_example_str, model_name)
            input_ids_query_context, input_ids_query = process_longbench(test_example_str, tokenizer, dataset_name)
            # Drop the first token of the query (presumably the BOS token the
            # tokenizer prepends -- TODO confirm).
            input_ids_query = input_ids_query[:, 1:]
            all_classes = test_example["all_classes"]
            lengths.append(test_example["length"])
            print("length = ", test_example["length"])

            input_ids_query=input_ids_query.to(device)

            seq_len = input_ids_query_context.size(1)
            print(f"seq_len: {seq_len}")
            pbar = tqdm(range(0, math.ceil(seq_len/ segment_length)))
            nlls = []
            for idx in pbar:
                if idx == 0:
                    # The first call consumes segments 0 and 1 together.
                    input_ids_seg1 = input_ids_query_context[:, idx * segment_length : min(seq_len, (idx + 1) * segment_length)].to(device)
                    input_ids_seg2 = input_ids_query_context[:, (idx + 1) * segment_length : min(seq_len, (idx + 2) * segment_length)].to(device)
                    selected_past_key_values, pred_token_idx, given_indices, logits=get_past_key_values(input_ids_seg1, input_ids_seg2, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, with_mean_attention=with_mean_attention, multi_head=multi_head)
                    logits = logits.view(-1, model.config.vocab_size)
                    labels = input_ids_query_context[:, idx * segment_length + 1 : (idx + 2) * segment_length + 1].to(logits.device).view(-1)
                elif idx == 1:
                    # Segment 1 was already handled in the idx == 0 step.
                    continue
                else:
                    input_ids_seg2 = input_ids_query_context[:, idx * segment_length : min(seq_len, (idx + 1) * segment_length)].to(device)
                    selected_past_key_values, pred_token_idx, given_indices, logits=get_past_key_values(None, input_ids_seg2, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, past_key_values=selected_past_key_values, with_mean_attention=with_mean_attention, multi_head=multi_head)
                    logits = logits.view(-1, model.config.vocab_size)
                    labels = input_ids_query_context[:, idx * segment_length + 1 : (idx + 1) * segment_length + 1].to(logits.device).view(-1)
                # Labels are the inputs shifted by one; logits are truncated
                # to the number of available labels.
                neg_log_likelihood = loss_fn(logits[:labels.size(-1),], labels)
                print("neg_log_likelihood = ", neg_log_likelihood.mean())
                nlls.append(neg_log_likelihood)
                pbar.set_description(
                    f"nll: {neg_log_likelihood.mean().item():.2f}, ppl: {torch.exp(neg_log_likelihood.mean()).item():.2f}"
                )

            # Cache-based decoding: select one last time for the query.
            if not no_compress_final_segment:
                selected_past_key_values, pred_token_idx, given_indices, _ =get_past_key_values_no_current_segment(None, input_ids_query, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, past_key_values=selected_past_key_values, with_mean_attention=with_mean_attention, multi_head=multi_head)
            # Encode the query on top of the selected cache.
            with torch.no_grad():
                outputs = model(input_ids_query, output_attentions=False, use_cache=True, past_key_values=selected_past_key_values)
            selected_past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            end_token = nxtline_id if dataset_name == "samsum" else None

            if seg_generate:
                result, decoding_time, decoding_memory=decoding_with_past_kv(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams=1, end_token=end_token, input_ids_query=input_ids_query, k=k, start_size=start_size, recent_size=recent_size, cache_type=cache_type, with_mean_attention=with_mean_attention)
            else:
                generated_ids, decoding_time, decoding_memory=decoding(model, tokenizer, selected_past_key_values, pred_token_idx.to(device), max_gen_length, num_beams=1, end_token=end_token)
                result = tokenizer.decode(generated_ids, skip_special_tokens=True)

            final_answer=result.strip()

            print("final_answer = ", final_answer)
            if print_results:
                print(result)
            if print_cache:
                # Identity check: `!= None` on a tensor goes through the
                # tensor's comparison operator.
                if input_ids_query_context is not None:
                    input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
                else:
                    input_ids=input_ids_query
                input_ids=input_ids[0][given_indices].detach().cpu()
                print(tokenizer.decode(input_ids, skip_special_tokens=True))

            # Collect prediction and references for scoring.
            system_summaries.append(final_answer)
            reference_summaries.append(test_example['answers'])

            total_decoding_time+=decoding_time
            total_decoding_memory+=decoding_memory
            context_length_lst.append(len(input_ids_query_context[0]))
            num_total+=1

        scores = scorer_e(dataset_name, system_summaries, reference_summaries, lengths, all_classes)
        print("results = ")
        print(scores)
        for each in scores:
            print(each, end=' ')
        print()
        for each in scores:
            print(scores[each], end=' ')
        print()
        print("Experimental Settings")
        print("Model name:", model_name)
        print("Quantization:", quantization_type)
        print("Dataset name:", dataset_name)
        print("Dataset split:", dataset_split)
        print("Num examples:", num_examples)
        print("Num docs:", num_docs)
        print("Min context length", min_context_length)
        print("Max context length", max_context_length)
        print("Max gen length", max_gen_length)
        print("Cache_type", cache_type)
        if cache_type=="random":
            print("Random tokens ratio", random_tokens_ratio)
        elif cache_type=="startrecent":
            print("Start size", start_size)
            print("Recent size", recent_size)
        elif cache_type=="topk":
            print("K", k)
        print("-----------------------------------------")
        print("Experimental Results")
        if num_total == 0:
            # Nothing was evaluated: the averages below would divide by zero
            # and min()/max() would fail on the empty list.
            print("No examples were evaluated for this dataset.")
            continue
        # num_correct is never incremented in this function, so this prints 0.
        print("Average Accuracy:", num_correct/num_total)
        print("Average Decoding Time:", total_decoding_time/num_total)
        print("Average Decoding Memory:", total_decoding_memory/num_total)
        print("Average Context Length:", sum(context_length_lst)/num_total)
        print("Min Context Length:", min(context_length_lst))
        print("Max Context Length:", max(context_length_lst))
        if cache_type=="stopwords":
            print("Average Stopwords Ratio:", total_stopwords_ratio/num_total)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Total running time: {elapsed_time} seconds")
        #     #prepare contextual documents

        #     if len(test_example["search_results"]["search_context"])<num_docs:
        #         print("Not enough retrieved docs. The example is skipped !")
        #         continue
        #     if num_docs==0:
        #         input_ids_query_context, context_words_lst, word_idx_to_subtoken_start_end_idx=None, None, None
        #     else:
        #         retrieved_docs=random.choices(test_example["search_results"]["search_context"], k=num_docs)
        #         query_context=construct_query_context(retrieved_docs)
        #         query_context=query_context.replace('\\n','\n').rstrip('\n')
        #         context_words_lst=nltk.word_tokenize(query_context)
        #         if model_name.find("pythia"):
        #             input_ids_query_context=tokenizer(query_context, return_tensors="pt").input_ids
        #             word_idx_to_subtoken_start_end_idx=None
        #         else:
        #             input_ids_query_context, word_idx_to_subtoken_start_end_idx=customized_tokenize(tokenizer, context_words_lst)
        #         input_ids_query_context=input_ids_query_context.to(device)
        #         if len(input_ids_query_context[0]) > max_context_length or len(input_ids_query_context[0]) < min_context_length:
        #             print("Context too long or too short. The example is skipped !")
        #             continue
            
            #prepare test query
        #     test_query=test_example["question"]
        #     query=construct_query(test_query)
        #     query=query.replace('\\n','\n').rstrip('\n')
        #     query_words_lst=nltk.word_tokenize(query)
        #     if model_name.find("pythia"):
        #         input_ids_query=tokenizer(query, return_tensors="pt").input_ids
        #     else:
        #         input_ids_query, _=customized_tokenize(tokenizer, query_words_lst)
        #     input_ids_query=input_ids_query.to(device)
            
        #     #cache-based decoding
            # selected_past_key_values, pred_token_idx, given_indices=get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst, word_idx_to_subtoken_start_end_idx, random_tokens_ratio, start_size, recent_size, k)
        #     generated_ids, decoding_time, decoding_memory=decoding(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams)     
        #     result = tokenizer.decode(torch.cat([input_ids_query[0].detach().cpu(), generated_ids]), skip_special_tokens=True)
        #     final_answer=get_the_final_answer(result)
        #     if print_results:
        #         print(result)
        #     if print_cache:
        #         if input_ids_query_context!=None:
        #             input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
        #         else:
        #             input_ids=input_ids_query
        #         input_ids=input_ids[0][given_indices].detach().cpu()
        #         print(tokenizer.decode(input_ids, skip_special_tokens=True))

        #     #evaluate answers
        #     for gold_answer in test_example["answer"]["aliases"]+test_example["answer"]["normalized_aliases"]:
        #         if final_answer.find(gold_answer)!=-1:
        #             num_correct+=1
        #             break
        #     total_decoding_time+=decoding_time
        #     total_decoding_memory+=decoding_memory
        #     context_length_lst.append(len(input_ids_query_context[0]))
        #     if cache_type=="stopwords":
        #         total_stopwords_ratio+=(len(given_indices)-len(input_ids_query[0]))/len(input_ids_query_context[0])
        #     num_total+=1
        #     if num_total>=num_examples:
        #         break
            

        # print("Experimental Settings")
        # print("Model name:", model_name)
        # print("Quantization:", quantization_type)
        # print("Dataset name:", dataset_name)
        # print("Dataset split:", dataset_split)
        # print("Num examples:", num_examples)
        # print("Num docs:", num_docs)
        # print("Min context length", min_context_length)
        # print("Max context length", max_context_length)
        # print("Max gen length", max_gen_length)
        # print("Cache_type", cache_type)
        # if cache_type=="random":
        #     print("Random tokens ratio", random_tokens_ratio)
        # elif cache_type=="startrecent":
        #     print("Start size", start_size)
        #     print("Recent size", recent_size)
        # elif cache_type=="topk":
        #     print("K", k)
        # print("-----------------------------------------")
        # print("Experimental Results")
        # print("Average Accuracy:", num_correct/num_total)
        # print("Average Decoding Time:", total_decoding_time/num_total)
        # print("Average Decoding Memory:", total_decoding_memory/num_total)
        # print("Average Context Length:", sum(context_length_lst)/num_total)
        # print("Min Context Length:", min(context_length_lst))
        # print("Max Context Length:", max(context_length_lst))
        # if cache_type=="stopwords":
        #     print("Average Stopwords Ratio:", total_stopwords_ratio/num_total) 
        

def parse_cli_bool(value):
    """Convert a command-line token to a real ``bool``.

    Used as an argparse ``type=`` so that ``--print_results False`` yields
    ``False`` instead of the (truthy) string ``"False"``.

    Accepts true/false, t/f, yes/no, y/n and 1/0 (case-insensitive).
    Raises ``argparse.ArgumentTypeError`` for anything else so argparse can
    report a proper usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_cli_parser():
    """Build the argument parser for this script.

    Every option name, destination and default matches what ``main`` receives
    as keyword arguments. Numeric options carry an explicit ``type`` so that a
    value given on the command line has the same Python type as its default
    (without it argparse would hand over a ``str``).
    """
    parser = argparse.ArgumentParser()

    # model config
    parser.add_argument('--model_name', default='mistralai/Mistral-7B-v0.1')
    parser.add_argument('--quantization_type', default='none',
                        help='none, 4bit, 8bit')
    parser.add_argument('--device_map', default="auto")
    parser.add_argument('--max_gen_length', default=100, type=int)
    parser.add_argument('--num_beams', default=4, type=int)
    parser.add_argument('--segment_length', default=1000, type=int)

    # data config
    parser.add_argument('--dataset_name', default='mandarjoshi/trivia_qa')
    parser.add_argument('--dataset_split', default='test')
    parser.add_argument('--num_examples', default=500, type=int,
                        help='test on how many examples')
    parser.add_argument('--num_docs', default=1, type=int,
                        help='number of retrieved docs for each test query')
    # the range of context length of retrieved documents
    parser.add_argument('--min_context_length', default=0, type=int)
    parser.add_argument('--max_context_length', default=10000, type=int)
    parser.add_argument('--all_eval_tokens', default=4000000, type=int)

    # cache config
    #   all: all the context tokens
    #   zeroshot: no context tokens and query does not attend to any context
    #   nocontext: no context tokens but query does attend to context
    #   stopwords: only the stopwords in the context
    #   random: random tokens in the context, controlled by random_tokens_ratio
    #   startrecent: start tokens and recent tokens in the context, start_size
    #       and recent_size can be ratio between 0-1 or absolute numbers
    #   topk: topk tokens with highest attention weights in context, k can be
    #       a ratio between 0-1 or absolute numbers
    parser.add_argument('--cache_type', default='topk')
    parser.add_argument('--random_tokens_ratio', default=0.1, type=float)
    parser.add_argument('--question_attn_ratio', default=0.5, type=float)
    parser.add_argument('--start_size', default=0, type=int)
    parser.add_argument('--recent_size', default=0.3, type=float)
    parser.add_argument('--k', default=100, type=float)
    # Plain on/off switches: absent -> False, present -> True.
    for flag in (
        '--with_mean_attention',
        '--use_context_cache',
        '--no_compress_final_segment',
        '--seg_generate',
        '--new_gov_prompt',
        '--final_three_dataset',
        '--multi_head',
    ):
        parser.add_argument(flag, action='store_true')
    parser.add_argument('--in_context_examples', default=0, type=int)

    # print config
    # These take an explicit value (e.g. ``--print_results False``); the value
    # is parsed into a bool rather than kept as an always-truthy string.
    parser.add_argument('--print_cache', default=False, type=parse_cli_bool)
    parser.add_argument('--print_results', default=True, type=parse_cli_bool)

    return parser


if __name__ == "__main__":
    # Forward every parsed option to main() as a keyword argument.
    main(**vars(build_cli_parser().parse_args()))