import sys
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
import argparse
import random
import nltk
from tqdm import tqdm
import math
# import metric
import evaluate
from utils import normalize_answer, customized_tokenize, is_stopword, flatten, get_peak_memory, get_flops
from cache_based_inference.pos_shift.modify_llama import enable_llama_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_mistral import enable_mistral_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_j import enable_gpt_j_pos_shift_attention
# from llm_needle_haystack_tester import
from torch.nn import CrossEntropyLoss
from llm_needle_haystack_tester import LLMNeedleHaystackTester

def construct_query_context(retrieved_docs):
    """Build the context part of the prompt from the retrieved documents.

    The instruction sentence is followed directly (no separator) by one
    ``Document: <text>`` line per retrieved document.
    """
    instruction = "Answer the query according to the following documents."
    document_lines = "".join("Document: " + doc + '\n' for doc in retrieved_docs)
    return instruction + document_lines

def construct_query(query):
    """Format *query* as a ``Q:`` line followed by the answer cue ``A: The answer is``."""
    question_line = "Q: " + query + '\n'
    answer_cue = "A: The answer is"
    return question_line + answer_cue

def get_the_final_answer(result):
    """Extract and normalize the answer sentence from a generated string.

    The answer is the text between the first ``The answer is`` marker and the
    next ``.`` (or the end of the string when there is no period).  When the
    marker is absent the search starts at the beginning of *result*; the
    previous ``find() + len(marker)`` arithmetic silently started at offset 12
    in that case and dropped the first characters.

    Returns the extracted span passed through ``normalize_answer``.
    """
    marker = 'The answer is'
    _, found, after_marker = result.partition(marker)
    if not found:
        # Marker missing: fall back to the whole generation.
        after_marker = result
    # Everything up to (not including) the first period, or all of it.
    final_answer = after_marker.partition('.')[0]
    return normalize_answer(final_answer)
    
def get_the_final_answer_summary(result):
    """Return the text that follows the first ``System:`` marker in *result*.

    If the marker does not occur, *result* is returned unchanged.  (The
    previous ``find() + len(marker)`` arithmetic turned the ``-1`` "not found"
    value into offset 6 and silently cut off the start of the text.)
    """
    _, found, final_answer = result.partition("System:")
    return final_answer if found else result



def get_the_final_answer_summary_new(result):
    """Return the text that follows the first ``Summary:`` marker in *result*.

    If the marker does not occur, *result* is returned unchanged.  (The
    previous ``find() + len(marker)`` arithmetic turned the ``-1`` "not found"
    value into offset 7 and silently cut off the start of the text.)
    """
    _, found, final_answer = result.partition("Summary:")
    return final_answer if found else result

def get_the_final_answer_squality(result):
    """Return the text that follows the first ``Answer:`` marker in *result*.

    If the marker does not occur, *result* is returned unchanged.  (The
    previous ``find() + len(marker)`` arithmetic turned the ``-1`` "not found"
    value into offset 6 and silently cut off the start of the text.)
    """
    _, found, final_answer = result.partition("Answer:")
    return final_answer if found else result
# 
@torch.no_grad()
def get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, with_mean_attention=False):
    """Encode context + query in one forward pass and keep a subset of the KV cache.

    The context and the query are run through the model together, so the
    query's keys/values attend to the full context.  The context part of the
    cache is then filtered according to ``cache_type`` while the entries of
    the tokens fed in this call are kept; the filtered cache is what
    generation continues from.

    Args:
        input_ids_query_context: token ids of the context (indexed as
            ``[0]`` for the sequence), or ``None`` when only
            ``past_key_values`` carries the context.
        input_ids_query: token ids of the query / current segment.
        cache_type: one of ``"zeroshot"``, ``"all"``, ``"nocontext"``,
            ``"stopwords"``, ``"random"``, ``"startrecent"``, ``"topk"``.
        model: callable returning an object with ``logits``, ``attentions``
            and ``past_key_values``.
        context_words_lst, word_idx_to_subtoken_start_end_idx: word list and
            word -> (first, last) sub-token index map, required by
            ``"stopwords"``.
        random_tokens_ratio: fraction of context tokens kept by ``"random"``.
        start_size, recent_size: number of leading / trailing context tokens
            to keep; a non-``int`` value is treated as a ratio of the context
            length.
        k: number of context tokens kept per layer by ``"topk"``; a
            non-``int`` value is treated as a ratio.
        past_key_values: cache from an earlier call, indexed as
            ``[layer][0]`` with the sequence on dim 2
            (presumably the legacy tuple layout - confirm against the
            installed transformers version).
        with_mean_attention: for ``"topk"``, score cached tokens by the
            attention averaged over the query rows instead of the last row.

    Returns:
        ``"zeroshot"``: ``(past_key_values, pred_token_idx, given_indices)``.
        Otherwise: ``(selected_past_key_values, pred_token_idx,
        given_indices, logits)``.  The arity difference is kept for
        backward compatibility with existing callers.

    Raises:
        ValueError: if ``cache_type`` is not one of the supported values.
    """
    if cache_type == "zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        return outputs.past_key_values, pred_token_idx, given_indices

    if input_ids_query_context is not None:
        input_ids = torch.cat([input_ids_query_context, input_ids_query], dim=-1)
        context_len = len(input_ids_query_context[0])
    else:
        input_ids = input_ids_query
        context_len = None
    total_len = len(input_ids[0])

    # Keep a handle on the caller's cache: `past_key_values` is rebound to the
    # model output below, and the ratio/branch logic must look at the original.
    past_key_values_org = past_key_values
    past_len = None if past_key_values_org is None else past_key_values_org[0][0].size(2)
    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values_org)
    past_key_values = outputs.past_key_values

    given_indices_for_each_layer = []
    if cache_type == "all":
        given_indices = range(total_len)
    elif cache_type == "nocontext":
        given_indices = range(context_len, total_len)
    elif cache_type == "stopwords":
        assert context_words_lst is not None
        assert word_idx_to_subtoken_start_end_idx is not None
        kept = {0}  # the first token is always kept
        for word_idx, word in enumerate(context_words_lst):
            if is_stopword(word):
                subtoken_start_idx, subtoken_end_idx = word_idx_to_subtoken_start_end_idx[word_idx]
                kept.update(range(subtoken_start_idx, subtoken_end_idx + 1))
        kept.update(range(context_len, total_len))
        # A set avoids selecting the same cache position twice.
        given_indices = sorted(kept)
    elif cache_type == "random":
        assert random_tokens_ratio is not None
        random_tokens_num = int(random_tokens_ratio * context_len)
        given_indices = random.sample(range(context_len), k=random_tokens_num)
        given_indices.extend(range(context_len, total_len))
        given_indices = sorted(given_indices)
    elif cache_type == "startrecent":
        assert start_size is not None
        assert recent_size is not None
        if not isinstance(start_size, int):
            start_size = int(start_size * context_len)
        if not isinstance(recent_size, int):
            recent_size = int(recent_size * context_len)
        # Clamp both windows to the context so that overlapping windows do not
        # yield duplicate or negative (wrap-around) indices.
        start_size = min(start_size, context_len)
        recent_start = max(start_size, context_len - recent_size)
        given_indices = list(range(start_size))
        given_indices.extend(range(recent_start, context_len))
        given_indices.extend(range(context_len, total_len))
    elif cache_type == "topk":
        if not isinstance(start_size, int):
            # Ratio of the context length; with no context in this call the
            # cached length is the only available reference.
            ratio_base = context_len if context_len is not None else past_len
            start_size = int(start_size * ratio_base)
        assert k is not None
        if not isinstance(k, int):
            # Convert the ratio using the lengths that existed BEFORE this
            # forward pass (the original code consulted the already-updated
            # cache, which made the first branch unreachable and counted the
            # context twice).
            if past_len is None:
                k = int(k * context_len)
            elif input_ids_query_context is None:
                k = int(k * past_len)
            else:
                k = int(k * (context_len + past_len))

        for attention_per_layer in outputs.attentions:
            if past_key_values_org is None:
                # Score context tokens by the attention they receive.
                if with_mean_attention:
                    att_weights = torch.mean(attention_per_layer[0, :, :, :context_len], dim=0)
                    att_weights = torch.mean(att_weights[context_len:, :context_len], dim=0)
                else:
                    att_weights = torch.mean(attention_per_layer[0, :, -1, :context_len], dim=0)
                num_selected = min(k, context_len - start_size)
                # Positions of the query tokens, always kept.
                tail_start, tail_end = context_len, total_len
            else:
                # Score previously cached tokens by the attention they receive.
                if with_mean_attention:
                    att_weights = torch.mean(attention_per_layer[0, :, :, :past_len], dim=0)
                    att_weights = torch.mean(att_weights[:, :past_len], dim=0)
                else:
                    att_weights = torch.mean(attention_per_layer[0, :, -1, :past_len], dim=0)
                num_selected = min(k, att_weights[start_size:].size(0))
                # Positions of the tokens fed in this call, always kept.
                tail_start, tail_end = past_len, past_len + total_len

            _, topk_indices = torch.topk(att_weights[start_size:], k=num_selected)
            topk_indices = topk_indices.long() + start_size
            # Build the fixed parts on the same device as the scores so the
            # concatenation also works when layers live on different devices.
            index_device = topk_indices.device
            # The cache entry of the last token is removed by the caller.
            topk_indices = torch.cat([
                torch.arange(start_size, dtype=torch.long, device=index_device),
                topk_indices,
                torch.arange(tail_start, tail_end, dtype=torch.long, device=index_device),
            ])
            topk_indices, _ = torch.sort(topk_indices)
            given_indices_for_each_layer.append(topk_indices)
    else:
        raise ValueError(f"Unsupported cache_type: {cache_type!r}")

    if cache_type == "topk":
        # Every layer keeps its own set of positions.
        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer)
            for layer_idx, layer in enumerate(past_key_values)
        )
        given_indices = flatten(given_indices_for_each_layer)
    else:
        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices, :] for kv in layer)
            for layer in past_key_values
        )

    # Greedy next token predicted from the last position.
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits


@torch.no_grad()
def get_past_key_values_no_current_segment(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, with_mean_attention=False):
    """Encode context + query in one forward pass and keep a subset of the KV cache.

    The context and the query are run through the model together, so the
    query's keys/values attend to the full context, and the cache is then
    filtered according to ``cache_type``.  Unlike ``get_past_key_values``,
    the ``"topk"`` path with an incoming ``past_key_values`` keeps only the
    leading ``start_size`` positions plus the top-k previously cached
    positions: the entries of the tokens fed in this call are NOT kept.

    Args:
        input_ids_query_context: token ids of the context (indexed as
            ``[0]`` for the sequence), or ``None`` when only
            ``past_key_values`` carries the context.
        input_ids_query: token ids of the query / current segment.
        cache_type: one of ``"zeroshot"``, ``"all"``, ``"nocontext"``,
            ``"stopwords"``, ``"random"``, ``"startrecent"``, ``"topk"``.
        model: callable returning an object with ``logits``, ``attentions``
            and ``past_key_values``.
        context_words_lst, word_idx_to_subtoken_start_end_idx: word list and
            word -> (first, last) sub-token index map, required by
            ``"stopwords"``.
        random_tokens_ratio: fraction of context tokens kept by ``"random"``.
        start_size, recent_size: number of leading / trailing context tokens
            to keep; a non-``int`` value is treated as a ratio of the context
            length.
        k: number of context tokens kept per layer by ``"topk"``; a
            non-``int`` value is treated as a ratio.
        past_key_values: cache from an earlier call, indexed as
            ``[layer][0]`` with the sequence on dim 2
            (presumably the legacy tuple layout - confirm against the
            installed transformers version).
        with_mean_attention: for ``"topk"``, score cached tokens by the
            attention averaged over the query rows instead of the last row.

    Returns:
        ``"zeroshot"``: ``(past_key_values, pred_token_idx, given_indices)``.
        Otherwise: ``(selected_past_key_values, pred_token_idx,
        given_indices, logits)``.  The arity difference is kept for
        backward compatibility with existing callers.

    Raises:
        ValueError: if ``cache_type`` is not one of the supported values.
    """
    if cache_type == "zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        return outputs.past_key_values, pred_token_idx, given_indices

    if input_ids_query_context is not None:
        input_ids = torch.cat([input_ids_query_context, input_ids_query], dim=-1)
        context_len = len(input_ids_query_context[0])
    else:
        input_ids = input_ids_query
        context_len = None
    total_len = len(input_ids[0])

    # Keep a handle on the caller's cache: `past_key_values` is rebound to the
    # model output below, and the ratio/branch logic must look at the original.
    past_key_values_org = past_key_values
    past_len = None if past_key_values_org is None else past_key_values_org[0][0].size(2)
    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values_org)
    past_key_values = outputs.past_key_values

    given_indices_for_each_layer = []
    if cache_type == "all":
        given_indices = range(total_len)
    elif cache_type == "nocontext":
        given_indices = range(context_len, total_len)
    elif cache_type == "stopwords":
        assert context_words_lst is not None
        assert word_idx_to_subtoken_start_end_idx is not None
        kept = {0}  # the first token is always kept
        for word_idx, word in enumerate(context_words_lst):
            if is_stopword(word):
                subtoken_start_idx, subtoken_end_idx = word_idx_to_subtoken_start_end_idx[word_idx]
                kept.update(range(subtoken_start_idx, subtoken_end_idx + 1))
        kept.update(range(context_len, total_len))
        # A set avoids selecting the same cache position twice.
        given_indices = sorted(kept)
    elif cache_type == "random":
        assert random_tokens_ratio is not None
        random_tokens_num = int(random_tokens_ratio * context_len)
        given_indices = random.sample(range(context_len), k=random_tokens_num)
        given_indices.extend(range(context_len, total_len))
        given_indices = sorted(given_indices)
    elif cache_type == "startrecent":
        assert start_size is not None
        assert recent_size is not None
        if not isinstance(start_size, int):
            start_size = int(start_size * context_len)
        if not isinstance(recent_size, int):
            recent_size = int(recent_size * context_len)
        # Clamp both windows to the context so that overlapping windows do not
        # yield duplicate or negative (wrap-around) indices.
        start_size = min(start_size, context_len)
        recent_start = max(start_size, context_len - recent_size)
        given_indices = list(range(start_size))
        given_indices.extend(range(recent_start, context_len))
        given_indices.extend(range(context_len, total_len))
    elif cache_type == "topk":
        if not isinstance(start_size, int):
            # Ratio of the context length; with no context in this call the
            # cached length is the only available reference.
            ratio_base = context_len if context_len is not None else past_len
            start_size = int(start_size * ratio_base)
        assert k is not None
        if not isinstance(k, int):
            # Convert the ratio using the lengths that existed BEFORE this
            # forward pass (the original code consulted the already-updated
            # cache, which made the first branch unreachable and counted the
            # context twice).
            if past_len is None:
                k = int(k * context_len)
            elif input_ids_query_context is None:
                k = int(k * past_len)
            else:
                k = int(k * (context_len + past_len))

        for attention_per_layer in outputs.attentions:
            if past_key_values_org is None:
                # Score context tokens by the attention they receive.
                if with_mean_attention:
                    att_weights = torch.mean(attention_per_layer[0, :, :, :context_len], dim=0)
                    att_weights = torch.mean(att_weights[context_len:, :context_len], dim=0)
                else:
                    att_weights = torch.mean(attention_per_layer[0, :, -1, :context_len], dim=0)
                _, topk_indices = torch.topk(att_weights[start_size:], k=min(k, context_len - start_size))
                topk_indices = topk_indices.long() + start_size
                index_device = topk_indices.device
                # Leading tokens + selected context tokens + all query tokens.
                # The cache entry of the last token is removed by the caller.
                topk_indices = torch.cat([
                    torch.arange(start_size, dtype=torch.long, device=index_device),
                    topk_indices,
                    torch.arange(context_len, total_len, dtype=torch.long, device=index_device),
                ])
            else:
                # Score previously cached tokens by the attention they receive.
                if with_mean_attention:
                    att_weights = torch.mean(attention_per_layer[0, :, :, :past_len], dim=0)
                    att_weights = torch.mean(att_weights[:, :past_len], dim=0)
                else:
                    att_weights = torch.mean(attention_per_layer[0, :, -1, :past_len], dim=0)
                _, topk_indices = torch.topk(att_weights[start_size:], k=min(k, att_weights[start_size:].size(0)))
                topk_indices = topk_indices.long() + start_size
                # Leading tokens + selected cached tokens only; the tokens fed
                # in this call (the current segment) are deliberately dropped.
                topk_indices = torch.cat([
                    torch.arange(start_size, dtype=torch.long, device=topk_indices.device),
                    topk_indices,
                ])

            topk_indices, _ = torch.sort(topk_indices)
            given_indices_for_each_layer.append(topk_indices)
    else:
        raise ValueError(f"Unsupported cache_type: {cache_type!r}")

    if cache_type == "topk":
        # Every layer keeps its own set of positions.
        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer)
            for layer_idx, layer in enumerate(past_key_values)
        )
        given_indices = flatten(given_indices_for_each_layer)
    else:
        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices, :] for kv in layer)
            for layer in past_key_values
        )

    # Greedy next token predicted from the last position.
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits


@torch.no_grad()
def decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None):
    """Generate from a pre-filled KV cache and measure time and peak memory.

    Args:
        model: model exposing ``device`` and ``generate``.
        tokenizer: only ``eos_token_id`` is read, as the default stop token.
        past_key_values: cache to continue from; ``[0][0].size(-2)`` is taken
            as the number of cached positions.
        pred_token_idx: the token(s) generation starts from.
        max_gen_len: passed to ``generate`` as ``max_new_tokens``.
        num_beams: beam width passed to ``generate``.
        end_token: stop token id; defaults to ``tokenizer.eos_token_id``.

    Returns:
        ``(token_ids_on_cpu, decoding_time, decoding_memory)`` where the time
        comes from CUDA events (milliseconds) and the memory is the difference
        of ``get_peak_memory`` after and before generation.
    """
    end_token = tokenizer.eos_token_id if end_token is None else end_token
    device = model.device
    # Reset the statistics of the device that is measured below.
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # Measures the inference time of one document with batch_size=1.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    # One mask entry per cached position plus one for the start token, created
    # on the same device as the input ids rather than on the CPU.
    cached_len = past_key_values[0][0].size(-2)
    attention_mask = torch.ones(cached_len + 1, device=pred_token_idx.device).unsqueeze(0)

    outputs = model.generate(
        input_ids=pred_token_idx,
        attention_mask=attention_mask,
        do_sample=False,
        temperature=0.8,
        top_p=0.75,
        top_k=40,
        num_beams=num_beams,
        eos_token_id=end_token,
        max_new_tokens=max_gen_len,
        past_key_values=past_key_values,
    )

    ender.record()
    torch.cuda.synchronize()  # required before reading the elapsed time
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return outputs[0].detach().cpu(), decoding_time, decoding_memory


@torch.no_grad()
def decoding_with_past_kv(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None, input_ids_query=None, k=None, start_size=None, recent_size=None, cache_type=None, with_mean_attention=None):
    """Generate segment by segment, re-compressing the KV cache after each one.

    Each iteration generates up to ``max_gen_len`` tokens, appends the decoded
    text to the result, then feeds the generated tokens through
    ``get_past_key_values`` to obtain the cache and start token for the next
    segment.  The loop ends when a segment ends with the tokenizer's EOS token
    or after 21 segments.

    Args:
        model, tokenizer: model exposing ``device``/``generate`` and its
            tokenizer.
        past_key_values: cache to start from; ``[0][0].size(-2)`` is taken as
            the number of cached positions.
        pred_token_idx: the token(s) the first segment starts from.
        max_gen_len: ``max_new_tokens`` for every segment.
        num_beams: beam width passed to ``generate``.
        end_token: stop token id for ``generate``; defaults to
            ``tokenizer.eos_token_id``.
        input_ids_query: query token ids; required, they are decoded together
            with the first segment.
        k, start_size, recent_size, cache_type, with_mean_attention:
            forwarded to ``get_past_key_values``.

    Returns:
        ``(text, decoding_time, decoding_memory)`` where the time comes from
        CUDA events (milliseconds) and the memory is the difference of
        ``get_peak_memory`` after and before generation.
    """
    end_token = tokenizer.eos_token_id if end_token is None else end_token
    device = model.device
    # Reset the statistics of the device that is measured below.
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # Measures the inference time of one document with batch_size=1.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    pos = 0
    res = ""
    while True:
        print("pos for current generation: ", pos)
        # One mask entry per cached position plus one for the start token,
        # created on the same device as the input ids rather than on the CPU.
        cached_len = past_key_values[0][0].size(-2)
        attention_mask = torch.ones(cached_len + 1, device=pred_token_idx.device).unsqueeze(0)
        outputs = model.generate(
            input_ids=pred_token_idx,
            attention_mask=attention_mask,
            do_sample=False,
            temperature=0.8,
            top_p=0.75,
            top_k=40,
            num_beams=num_beams,
            eos_token_id=end_token,
            max_new_tokens=max_gen_len,
            past_key_values=past_key_values,
        )

        print("outputs = ", outputs)

        print("tokenizer.eos_token_id = ", tokenizer.eos_token_id)

        if pos == 0:
            # The first segment is decoded together with the query tokens.
            result = tokenizer.decode(torch.cat([input_ids_query[0, :].detach().cpu(), outputs[0, :].cpu()]), skip_special_tokens=True)
        else:
            result = tokenizer.decode(outputs[0, :].cpu(), skip_special_tokens=True)

        # Separate segments with a single space.  startswith() also copes
        # with an empty decoded string, where result[0] raised IndexError.
        if not result.startswith(' '):
            result = ' ' + result

        res += result
        print('results = ', result)

        # Re-encode the generated segment on top of the previous cache and
        # select the entries to keep for the next segment.
        past_key_values, pred_token_idx, _, _ = get_past_key_values(None, outputs, cache_type, model, None, None, None, start_size, recent_size, k, past_key_values=past_key_values, with_mean_attention=with_mean_attention)

        pos += 1
        if pos > 20 or tokenizer.eos_token_id == outputs[0, -1]:
            break

    ender.record()
    torch.cuda.synchronize()  # required before reading the elapsed time
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return res, decoding_time, decoding_memory

def process_report_summ(test_example, tokenizer):
    """Tokenize a report-summarization example into context ids and query ids.

    Only ``test_example['report']`` is read; the reference summary is not
    needed here, so examples without a ``'summary'`` field are accepted.

    Returns:
        ``(input_ids_query_context, input_ids_query)``: the ``input_ids`` of
        the instruction + document prompt and of the trailing
        ``User: ... System:`` instruction.
    """
    report_prompt = "You are a helpful, respectful and honest assistant specializing in summarization. Provide the best summary you can for this document.\n\nDocument: " + test_example['report']
    query_instruction = "\n\nUser: The best summary of the above document is:\nSystem:"
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query


def process_report_summ_new(test_example, tokenizer):
    """Tokenize a government-report example with the "one-page summary" prompt.

    Only ``test_example['report']`` is read; the reference summary is not
    needed here, so examples without a ``'summary'`` field are accepted.

    Returns:
        ``(input_ids_query_context, input_ids_query)``: the ``input_ids`` of
        the instruction + report prompt and of the trailing
        ``... Summary:`` instruction.
    """
    report_prompt = "You are given a report by a government agency. Write a one-page summary of the report.\nReport: " + test_example['report'] + "\n\n"
    query_instruction = "Now, write a one-page summary of the report.\nSummary:"
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query

def process_report_summ_new_in_context(test_example, tokenizer, train_examples):
    """Tokenize a report-summarization example preceded by in-context demonstrations.

    Each entry of ``train_examples`` contributes one demonstration made of its
    own ``'report'`` followed by its ``'summary'``; the test report is appended
    after the demonstrations.  (Previously the demonstration text was built
    from the test report and then discarded, so the prompt never contained any
    in-context example.)

    Returns:
        ``(input_ids_query_context, input_ids_query)``: the ``input_ids`` of
        the instruction + demonstrations + test report and of the trailing
        ``... Summary:`` instruction.
    """
    summary_instruction = "Now, write a one-page summary of the report.\nSummary:"
    prompt_parts = ["You are given reports by a government agency. Write one-page summaries of the reports.\n\n"]
    for train_example in train_examples:
        prompt_parts.append("Report: " + train_example['report'] + "\n\n")
        prompt_parts.append(summary_instruction + " " + train_example['summary'] + "\n\n")
    prompt_parts.append("Report: " + test_example['report'] + "\n\n")
    report_prompt = "".join(prompt_parts)

    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids
    input_ids_query = tokenizer(summary_instruction, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query


def process_squality(test_example, tokenizer):
    """Tokenize a SQuALITY example: one shared document, one prompt per question.

    Returns:
        ``(input_ids_query_context, query_prompts, answer_prompts)`` where
        ``query_prompts[i]`` holds the ``input_ids`` of the i-th question
        prompt and ``answer_prompts[i]`` the list of reference response texts
        for that question.
    """
    input_ids_query_context = tokenizer(test_example['document'], return_tensors="pt").input_ids

    query_prompts, answer_prompts = [], []
    for question in test_example['questions']:
        question_prompt = "\n\nUser: Answer the question in a paragraph.\n\nQuestion:\n\n" + question['question_text'] + "\n\nAnswer: "
        query_prompts.append(tokenizer(question_prompt, return_tensors='pt').input_ids)
        # All reference answers written for this question.
        answer_prompts.append([response['response_text'] for response in question['responses']])

    return input_ids_query_context, query_prompts, answer_prompts


def process_squality_in_context(test_example, tokenizer, train_examples):
    """Tokenize a SQuALITY example preceded by in-context demonstrations.

    Each entry of ``train_examples`` contributes one demonstration: its own
    document, its first question and the first reference response to that
    question.  (Previously the demonstrations were built from
    ``test_example``, which repeated the test document and leaked a reference
    answer of the test example into the prompt.)  The test document follows
    the demonstrations.

    Returns:
        ``(input_ids_query_context, query_prompts, answer_prompts)`` where
        ``query_prompts[i]`` holds the ``input_ids`` of the i-th test question
        prompt and ``answer_prompts[i]`` the list of reference response texts
        for that question.
    """
    assert isinstance(train_examples, list)

    prompt_parts = []
    for train_example in train_examples:
        demo_question = train_example['questions'][0]
        prompt_parts.append(
            "Document: " + train_example['document']
            + "\n\nUser: Answer the question in a paragraph.\n\nQuestion:\n\n" + demo_question['question_text']
            + "\n\nAnswer: \n\n" + demo_question['responses'][0]['response_text'] + "\n\n"
        )
    prompt_parts.append("Document: " + test_example['document'] + "\n\n")
    report_prompt = "".join(prompt_parts)
    input_ids_query_context = tokenizer(report_prompt, return_tensors="pt").input_ids

    query_prompts = []
    answer_prompts = []
    for each_question in test_example['questions']:
        question_prompt = "User: Answer the question in a paragraph.\n\nQuestion:\n\n" + each_question['question_text'] + "\n\nAnswer: "
        query_prompts.append(tokenizer(question_prompt, return_tensors='pt').input_ids)
        # All reference answers written for this question.
        answer_prompts.append([each_response['response_text'] for each_response in each_question['responses']])

    return input_ids_query_context, query_prompts, answer_prompts

def process_triviaqa(test_example, num_docs, model_name, tokenizer, max_context_length, min_context_length):
    """Build the tokenized context and query for one TriviaQA example.

    Args:
        test_example: example with ``["search_results"]["search_context"]``
            (list of retrieved documents) and ``["question"]``.
        num_docs: number of retrieved documents to put into the context;
            ``0`` produces no context.
        model_name: model identifier; names containing ``"pythia"`` are
            tokenized with the plain tokenizer call, all others with
            ``customized_tokenize`` on the NLTK word list.
        tokenizer: tokenizer used for both context and query.
        max_context_length, min_context_length: examples whose context token
            count falls outside this range are skipped.

    Returns:
        ``(input_ids_query_context, input_ids_query)``, with
        ``input_ids_query_context`` being ``None`` when ``num_docs == 0``;
        ``(None, None)`` when the example is skipped.
    """
    search_context = test_example["search_results"]["search_context"]
    if len(search_context) < num_docs:
        print("Not enough retrieved docs. The example is skipped !")
        return None, None

    # `str.find` returns -1 (truthy) when the substring is absent and 0
    # (falsy) when it is a prefix, so it must not be used as a boolean.
    use_plain_tokenizer = "pythia" in model_name

    if num_docs == 0:
        input_ids_query_context = None
    else:
        # NOTE(review): random.choices samples WITH replacement, so the same
        # document can be selected more than once - confirm this is intended.
        retrieved_docs = random.choices(search_context, k=num_docs)
        query_context = construct_query_context(retrieved_docs)
        query_context = query_context.replace('\\n', '\n').rstrip('\n')
        if use_plain_tokenizer:
            input_ids_query_context = tokenizer(query_context, return_tensors="pt").input_ids
        else:
            context_words_lst = nltk.word_tokenize(query_context)
            input_ids_query_context, _ = customized_tokenize(tokenizer, context_words_lst)
        # The ids are not moved to the GPU here, to save GPU memory.
        context_length = len(input_ids_query_context[0])
        if context_length > max_context_length or context_length < min_context_length:
            print("Context too long or too short. The example is skipped !")
            return None, None

    # Prepare the test query.
    query = construct_query(test_example["question"])
    query = query.replace('\\n', '\n').rstrip('\n')
    if use_plain_tokenizer:
        input_ids_query = tokenizer(query, return_tensors="pt").input_ids
    else:
        query_words_lst = nltk.word_tokenize(query)
        input_ids_query, _ = customized_tokenize(tokenizer, query_words_lst)

    return input_ids_query_context, input_ids_query





def main(model_name, quantization_type, device_map, max_gen_length,question_attn_ratio, all_eval_tokens, segment_length, num_beams, dataset_name, dataset_split, num_examples, num_docs, min_context_length, max_context_length, cache_type, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, print_cache=True, print_results=True, with_mean_attention=False, seg_generate=False, use_context_cache=False, new_gov_prompt=False, in_context_examples=0, needle_every_result=False):
    """Run cache-based long-context generation on a dataset and report ROUGE.

    For every test example the context is fed through the model in chunks of
    ``segment_length`` tokens via ``get_past_key_values``; the resulting cache
    is then used to answer the query, and the decoded answer is collected for
    ROUGE scoring.  Three dataset families are handled, selected by substring
    of ``dataset_name``: ``"govreport-summarization"``, ``"needle"`` and
    ``"SQuALITY"``.  Examples of any other dataset are counted but produce no
    generation.

    Args:
        model_name: Hugging Face model identifier.
        quantization_type: ``'none'``, ``'4bit'`` or ``'8bit'``.
        device_map: forwarded to ``from_pretrained``.
        max_gen_length: generation budget passed to the decoding helpers
            (SQuALITY with ``seg_generate`` uses ``max_gen_length // 10``).
        segment_length: number of context tokens per chunk.
        num_beams: forwarded to the decoding helpers.
        dataset_name / dataset_split: dataset to load and the split to evaluate.
        num_examples: stop after this many examples have been counted.
        cache_type, random_tokens_ratio, start_size, recent_size, k,
        with_mean_attention: forwarded unchanged to the cache-selection
            helpers.  A whole-valued ``k`` (e.g. ``100.0``) is converted to int.
        print_cache / print_results: toggle per-example diagnostic printing.
        seg_generate: use ``decoding_with_past_kv`` instead of ``decoding``.
        new_gov_prompt: use the ``*_new`` govreport prompt and answer parser.
        in_context_examples: number of leading ``train`` examples handed to the
            in-context prompt builders (0 disables them).
        needle_every_result: additionally print per-sample ROUGE for the
            needle dataset; requires ``dataset_name == 'needle'``.
        question_attn_ratio, all_eval_tokens, use_context_cache: accepted for
            interface compatibility; not used in this function.
        num_docs, min_context_length, max_context_length: only echoed in the
            settings printout.

    Raises:
        ValueError: unknown ``quantization_type``, unsupported model type, or
            ``needle_every_result`` requested for a non-needle dataset.
    """
    print("model_name = ", model_name)

    # k defaults to None, which math.floor() cannot handle.
    if k is not None and math.floor(k) == k:
        k = int(k)

    # Validate cheap preconditions before the expensive dataset/model loading.
    if quantization_type not in ('none', '4bit', '8bit'):
        raise ValueError(f"unsupported quantization_type: {quantization_type}")
    if needle_every_result and dataset_name != 'needle':
        # The per-sample report needs the tester object, which only exists
        # when dataset_name is exactly 'needle'.
        raise ValueError("needle_every_result requires dataset_name == 'needle'")

    device = torch.device("cuda")
    random.seed(42)
    cache_dir = '/network/scratch/y/yu.bai/.cache'

    # ------------------------------------------------------------------ data
    tester = None
    if dataset_name == "mandarjoshi/trivia_qa":
        dataset = load_dataset(dataset_name, data_files="rc/train-00001-of-00026.parquet")
    elif dataset_name == 'needle':
        tester = LLMNeedleHaystackTester()
        dataset = {"test": tester.construct_data()}
    else:
        dataset = load_dataset(dataset_name, cache_dir=cache_dir)

    test_examples = dataset[dataset_split]
    train_examples = []
    if in_context_examples != 0:
        train_examples = [dataset['train'][i] for i in range(in_context_examples)]

    # ----------------------------------------------------------------- model
    load_kwargs = {"output_attentions": True, "device_map": device_map, "cache_dir": cache_dir}
    if quantization_type == '4bit':
        load_kwargs["load_in_4bit"] = True
    elif quantization_type == '8bit':
        load_kwargs["load_in_8bit"] = True
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Position shift: patch the attention of the supported architectures.
    model_type = model.config.model_type
    if "llama" in model_type:
        enable_llama_pos_shift_attention(model)
    elif "gpt_neox" in model_type:
        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model_type:
        enable_falcon_pos_shift_attention(model)
    elif "mistral" in model_type:
        enable_mistral_pos_shift_attention(model)
    elif "mpt" in model_type or "btlm" in model_type:
        pass  # accepted without patching
    else:
        raise ValueError(f"got {model.config.model_type}")

    model.eval()
    loss_fn = CrossEntropyLoss(reduction="none")
    # Id of the byte-fallback newline token; used as end token for govreport.
    nxtline_id = tokenizer.convert_tokens_to_ids('<0x0A>')

    # --------------------------------------------------------------- helpers
    def encode_context(input_ids_query_context):
        """Feed the context chunk by chunk and return the selected KV cache."""
        seq_len = input_ids_query_context.size(1)
        print(f"seq_len: {seq_len}")
        # NOTE(review): stays None for an empty context (the original raised
        # NameError or reused the previous example's cache in that case).
        selected_past_key_values = None
        pbar = tqdm(range(0, math.ceil(seq_len / segment_length)))
        for idx in pbar:
            if idx == 0:
                # The first call consumes chunks 0 and 1 together.
                input_ids_seg1 = input_ids_query_context[:, idx * segment_length : min(seq_len, (idx + 1) * segment_length)].to(device)
                input_ids_seg2 = input_ids_query_context[:, (idx + 1) * segment_length : min(seq_len, (idx + 2) * segment_length)].to(device)
                selected_past_key_values, _, _, logits = get_past_key_values(input_ids_seg1, input_ids_seg2, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, with_mean_attention=with_mean_attention)
                label_end = (idx + 2) * segment_length + 1
            elif idx == 1:
                # Chunk 1 was already processed together with chunk 0.
                continue
            else:
                input_ids_seg2 = input_ids_query_context[:, idx * segment_length : min(seq_len, (idx + 1) * segment_length)].to(device)
                selected_past_key_values, _, _, logits = get_past_key_values(None, input_ids_seg2, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, past_key_values=selected_past_key_values, with_mean_attention=with_mean_attention)
                label_end = (idx + 1) * segment_length + 1
            logits = logits.view(-1, model.config.vocab_size)
            # Next-token labels: the chunk shifted by one position.
            labels = input_ids_query_context[:, idx * segment_length + 1 : label_end].to(logits.device).view(-1)
            # The final chunk has one label fewer than logits, hence the trim.
            neg_log_likelihood = loss_fn(logits[:labels.size(-1)], labels)
            print("neg_log_likelihood = ", neg_log_likelihood.mean())
            pbar.set_description(
                f"nll: {neg_log_likelihood.mean().item():.2f}, ppl: {torch.exp(neg_log_likelihood.mean()).item():.2f}"
            )
        return selected_past_key_values

    def answer_query(input_ids_query, selected_past_key_values, end_token, seg_max_gen_length):
        """Select the cache for the query, run the query and decode an answer.

        Returns ``(result_text, given_indices, decoding_time, decoding_memory)``.
        """
        # Select one last time for the query.
        selected_past_key_values, _, given_indices, _ = get_past_key_values_no_current_segment(None, input_ids_query, cache_type, model, None, None, random_tokens_ratio, start_size, recent_size, k, past_key_values=selected_past_key_values, with_mean_attention=with_mean_attention)
        # Get the query states and the first predicted token.
        outputs = model(input_ids_query, output_attentions=False, use_cache=True, past_key_values=selected_past_key_values)
        selected_past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

        if seg_generate:
            result, decoding_time, decoding_memory = decoding_with_past_kv(model, tokenizer, selected_past_key_values, pred_token_idx, seg_max_gen_length, num_beams, end_token=end_token, input_ids_query=input_ids_query, k=k, start_size=start_size, recent_size=recent_size, cache_type=cache_type, with_mean_attention=with_mean_attention)
        else:
            generated_ids, decoding_time, decoding_memory = decoding(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams, end_token=end_token)
            result = tokenizer.decode(torch.cat([input_ids_query[0, :].detach().cpu(), generated_ids]), skip_special_tokens=True)
        return result, given_indices, decoding_time, decoding_memory

    def show_generation(final_answer, result, input_ids_query_context, input_ids_query, given_indices):
        """Print the answer and, optionally, the raw output and kept tokens."""
        print("final_answer = ", final_answer)
        if print_results:
            print(result)
        if print_cache:
            if input_ids_query_context is not None:
                # The context ids are never moved to the GPU while the query
                # ids are, so align devices before concatenating.
                input_ids = torch.cat([input_ids_query_context.to(input_ids_query.device), input_ids_query], dim=-1)
            else:
                input_ids = input_ids_query
            input_ids = input_ids[0][given_indices].detach().cpu()
            print(tokenizer.decode(input_ids, skip_special_tokens=True))

    # ------------------------------------------------------------- inference
    num_correct = 0  # never incremented: answer matching is currently disabled
    num_total = 0
    total_decoding_time = 0
    total_decoding_memory = 0
    total_stopwords_ratio = 0  # never updated: stopwords accounting is disabled
    context_length_lst = []
    system_summaries = []
    reference_summaries = []

    for test_example in test_examples:
        if "govreport-summarization" in dataset_name:
            if in_context_examples != 0:
                input_ids_query_context, input_ids_query = process_report_summ_new_in_context(test_example, tokenizer, train_examples)
            elif new_gov_prompt:
                input_ids_query_context, input_ids_query = process_report_summ_new(test_example, tokenizer)
            else:
                input_ids_query_context, input_ids_query = process_report_summ(test_example, tokenizer)

            if input_ids_query_context is None:
                continue

            input_ids_query = input_ids_query.to(device)
            selected_past_key_values = encode_context(input_ids_query_context)
            result, given_indices, decoding_time, decoding_memory = answer_query(input_ids_query, selected_past_key_values, end_token=nxtline_id, seg_max_gen_length=max_gen_length)

            if new_gov_prompt:
                final_answer = get_the_final_answer_summary_new(result).strip()
            else:
                final_answer = get_the_final_answer_summary(result).strip()
            show_generation(final_answer, result, input_ids_query_context, input_ids_query, given_indices)

            system_summaries.append(final_answer)
            reference_summaries.append(test_example['summary'])

            total_decoding_time += decoding_time
            total_decoding_memory += decoding_memory
            context_length_lst.append(len(input_ids_query_context[0]))
        elif "needle" in dataset_name:
            input_ids_query_context = tokenizer(test_example['context_prompt'], return_tensors="pt").input_ids
            input_ids_query = tokenizer(test_example['query_prompt'], return_tensors="pt").input_ids.to(device)

            selected_past_key_values = encode_context(input_ids_query_context)
            result, given_indices, decoding_time, decoding_memory = answer_query(input_ids_query, selected_past_key_values, end_token=None, seg_max_gen_length=max_gen_length)

            # Take the paragraph after 'system:' and drop its first 8 words;
            # the same 8-word prefix is dropped from the reference needle below.
            final_answer = result.split('system:')[1].split('\n\n')[1]
            final_answer = ' '.join(final_answer.split()[8:])
            show_generation(final_answer, result, input_ids_query_context, input_ids_query, given_indices)

            system_summaries.append(final_answer)
            reference_summaries.append(' '.join(test_example['needle'].split(' ')[8:]))

            total_decoding_time += decoding_time
            total_decoding_memory += decoding_memory
            context_length_lst.append(len(input_ids_query_context[0]))
        elif "SQuALITY" in dataset_name:
            # gold_answer_list is a 2d list: the first dimension indexes the
            # questions, the second the reference answers of that question.
            if in_context_examples != 0:
                input_ids_query_context, input_ids_query_list, gold_answer_list = process_squality_in_context(test_example, tokenizer, train_examples)
            else:
                input_ids_query_context, input_ids_query_list, gold_answer_list = process_squality(test_example, tokenizer)

            if input_ids_query_context is None:
                continue

            for pos, input_ids_query in enumerate(input_ids_query_list):
                input_ids_query = input_ids_query.to(device)
                # The context cache is rebuilt for every question, as before.
                selected_past_key_values = encode_context(input_ids_query_context)
                result, given_indices, decoding_time, decoding_memory = answer_query(input_ids_query, selected_past_key_values, end_token=None, seg_max_gen_length=max_gen_length // 10)

                final_answer = get_the_final_answer_squality(result).strip()
                show_generation(final_answer, result, input_ids_query_context, input_ids_query, given_indices)

                # Score the same prediction against every reference answer.
                for each in gold_answer_list[pos]:
                    system_summaries.append(final_answer)
                    reference_summaries.append(each)

                total_decoding_time += decoding_time
                total_decoding_memory += decoding_memory
                context_length_lst.append(len(input_ids_query_context[0]))

        num_total += 1
        if num_total >= num_examples:
            break

    # ----------------------------------------------------------------- ROUGE
    metric = evaluate.load("rouge")

    def postprocess_text(preds, labels):
        # rougeLSum expects a newline after each sentence.
        preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
        return preds, labels

    decoded_preds, decoded_labels = postprocess_text(system_summaries, reference_summaries)
    print("rouge results = ")
    if decoded_preds:
        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        print(result)
    else:
        print("no summaries were generated")

    print("needle_every_result = ", needle_every_result)
    if needle_every_result:
        sample_results = [
            metric.compute(predictions=[pred], references=[ref])
            for pred, ref in zip(decoded_preds, decoded_labels)
        ]
        # Print one row of scores per (context length, depth) combination, in
        # tester order, stopping once the collected samples are exhausted.
        cnt = 0
        for context_length in tester.context_lengths:
            print("context_length = ", context_length)
            print("depth = ", tester.document_depth_percents)
            for depth_percent in tester.document_depth_percents:
                if cnt >= len(sample_results):
                    break
                for each_key in sample_results[cnt]:
                    print(sample_results[cnt][each_key], end=' ')
                cnt += 1
                print("")

            if cnt >= len(sample_results):
                break
            print("")

    # --------------------------------------------------------------- summary
    print("Experimental Settings")
    print("Model name:", model_name)
    print("Quantization:", quantization_type)
    print("Dataset name:", dataset_name)
    print("Dataset split:", dataset_split)
    print("Num examples:", num_examples)
    print("Num docs:", num_docs)
    print("Min context length", min_context_length)
    print("Max context length", max_context_length)
    print("Max gen length", max_gen_length)
    print("Cache_type", cache_type)
    if cache_type == "random":
        print("Random tokens ratio", random_tokens_ratio)
    elif cache_type == "startrecent":
        print("Start size", start_size)
        print("Recent size", recent_size)
    elif cache_type == "topk":
        print("K", k)
    print("-----------------------------------------")
    print("Experimental Results")

    # One entry is appended to context_length_lst per generation.  SQuALITY
    # produces several generations per example, so the per-generation totals
    # are averaged over generations, not over examples.  For the other
    # datasets the two counts are identical.
    num_generations = len(context_length_lst)
    if num_generations == 0:
        print("No generations were produced; averages are not available.")
        return

    print("Average Accuracy:", num_correct / num_total)
    print("Average Decoding Time:", total_decoding_time / num_generations)
    print("Average Decoding Memory:", total_decoding_memory / num_generations)
    print("Average Context Length:", sum(context_length_lst) / num_generations)
    print("Min Context Length:", min(context_length_lst))
    print("Max Context Length:", max(context_length_lst))
    if cache_type == "stopwords":
        print("Average Stopwords Ratio:", total_stopwords_ratio / num_total)
    
    #     #prepare contextual documents

    #     if len(test_example["search_results"]["search_context"])<num_docs:
    #         print("Not enough retrieved docs. The example is skipped !")
    #         continue
    #     if num_docs==0:
    #         input_ids_query_context, context_words_lst, word_idx_to_subtoken_start_end_idx=None, None, None
    #     else:
    #         retrieved_docs=random.choices(test_example["search_results"]["search_context"], k=num_docs)
    #         query_context=construct_query_context(retrieved_docs)
    #         query_context=query_context.replace('\\n','\n').rstrip('\n')
    #         context_words_lst=nltk.word_tokenize(query_context)
    #         if model_name.find("pythia"):
    #             input_ids_query_context=tokenizer(query_context, return_tensors="pt").input_ids
    #             word_idx_to_subtoken_start_end_idx=None
    #         else:
    #             input_ids_query_context, word_idx_to_subtoken_start_end_idx=customized_tokenize(tokenizer, context_words_lst)
    #         input_ids_query_context=input_ids_query_context.to(device)
    #         if len(input_ids_query_context[0]) > max_context_length or len(input_ids_query_context[0]) < min_context_length:
    #             print("Context too long or too short. The example is skipped !")
    #             continue
        
        #prepare test query
    #     test_query=test_example["question"]
    #     query=construct_query(test_query)
    #     query=query.replace('\\n','\n').rstrip('\n')
    #     query_words_lst=nltk.word_tokenize(query)
    #     if model_name.find("pythia"):
    #         input_ids_query=tokenizer(query, return_tensors="pt").input_ids
    #     else:
    #         input_ids_query, _=customized_tokenize(tokenizer, query_words_lst)
    #     input_ids_query=input_ids_query.to(device)
        
    #     #cache-based decoding
        # selected_past_key_values, pred_token_idx, given_indices=get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst, word_idx_to_subtoken_start_end_idx, random_tokens_ratio, start_size, recent_size, k)
    #     generated_ids, decoding_time, decoding_memory=decoding(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams)     
    #     result = tokenizer.decode(torch.cat([input_ids_query[0].detach().cpu(), generated_ids]), skip_special_tokens=True)
    #     final_answer=get_the_final_answer(result)
    #     if print_results:
    #         print(result)
    #     if print_cache:
    #         if input_ids_query_context!=None:
    #             input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    #         else:
    #             input_ids=input_ids_query
    #         input_ids=input_ids[0][given_indices].detach().cpu()
    #         print(tokenizer.decode(input_ids, skip_special_tokens=True))

    #     #evaluate answers
    #     for gold_answer in test_example["answer"]["aliases"]+test_example["answer"]["normalized_aliases"]:
    #         if final_answer.find(gold_answer)!=-1:
    #             num_correct+=1
    #             break
    #     total_decoding_time+=decoding_time
    #     total_decoding_memory+=decoding_memory
    #     context_length_lst.append(len(input_ids_query_context[0]))
    #     if cache_type=="stopwords":
    #         total_stopwords_ratio+=(len(given_indices)-len(input_ids_query[0]))/len(input_ids_query_context[0])
    #     num_total+=1
    #     if num_total>=num_examples:
    #         break
        

    # print("Experimental Settings")
    # print("Model name:", model_name)
    # print("Quantization:", quantization_type)
    # print("Dataset name:", dataset_name)
    # print("Dataset split:", dataset_split)
    # print("Num examples:", num_examples)
    # print("Num docs:", num_docs)
    # print("Min context length", min_context_length)
    # print("Max context length", max_context_length)
    # print("Max gen length", max_gen_length)
    # print("Cache_type", cache_type)
    # if cache_type=="random":
    #     print("Random tokens ratio", random_tokens_ratio)
    # elif cache_type=="startrecent":
    #     print("Start size", start_size)
    #     print("Recent size", recent_size)
    # elif cache_type=="topk":
    #     print("K", k)
    # print("-----------------------------------------")
    # print("Experimental Results")
    # print("Average Accuracy:", num_correct/num_total)
    # print("Average Decoding Time:", total_decoding_time/num_total)
    # print("Average Decoding Memory:", total_decoding_memory/num_total)
    # print("Average Context Length:", sum(context_length_lst)/num_total)
    # print("Min Context Length:", min(context_length_lst))
    # print("Max Context Length:", max(context_length_lst))
    # if cache_type=="stopwords":
    #     print("Average Stopwords Ratio:", total_stopwords_ratio/num_total) 
    

if __name__ == "__main__":
    def _str2bool(value):
        """Convert a command-line token such as ``true``/``False``/``1``/``no`` to a bool.

        Without this converter argparse stores the raw string, so
        ``--print_results False`` would yield the truthy string ``"False"``.
        Raises ``argparse.ArgumentTypeError`` for unrecognised tokens so that
        argparse reports a normal usage error.
        """
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept falsy: it was the only way to disable a
        # flag from the command line before this converter existed.
        if token in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    # Entry point: parse the command line and forward every option to main()
    # as a keyword argument (each option's dest equals its flag name).
    parser = argparse.ArgumentParser()

    # model config
    parser.add_argument('--model_name', default='mistralai/Mistral-7B-v0.1')
    parser.add_argument('--quantization_type', default='none')  # none, 4bit, 8bit
    parser.add_argument('--device_map', default="auto")
    parser.add_argument('--max_gen_length', default=100, type=int)
    # type=int so a value given on the command line has the same type as the default.
    parser.add_argument('--num_beams', default=4, type=int)
    parser.add_argument('--segment_length', default=1000, type=int)

    # data config
    parser.add_argument('--dataset_name', default='mandarjoshi/trivia_qa')
    parser.add_argument('--dataset_split', default='test')
    parser.add_argument('--num_examples', default=500, type=int)  # test on how many examples
    parser.add_argument('--num_docs', default=1, type=int)  # number of retrieved docs for each test query
    # the range of context length of retrieved documents
    parser.add_argument('--min_context_length', default=0, type=int)
    parser.add_argument('--max_context_length', default=10000, type=int)
    parser.add_argument('--all_eval_tokens', default=4000000, type=int)

    # cache config
    #   all: all the context tokens
    #   zeroshot: no context tokens and query does not attend to any context
    #   nocontext: no context tokens but query does attend to context
    #   stopwords: only the stopwords in the context
    #   random: random tokens in the context, controlled by random_tokens_ratio
    #   startrecent: start tokens and recent tokens in the context, start_size and
    #                recent_size can be ratio between 0-1 or absolute numbers
    #   topk: topk tokens with highest attention weights in context, k can be a
    #         ratio between 0-1 or absolute numbers
    parser.add_argument('--cache_type', default='topk')
    parser.add_argument('--random_tokens_ratio', default=0.1, type=float)
    parser.add_argument('--question_attn_ratio', default=0.5, type=float)
    # NOTE(review): start_size is parsed as int, so a 0-1 ratio cannot be given
    # on the command line despite the description above; left unchanged because
    # main() is the authority on how the value is consumed.
    parser.add_argument('--start_size', default=0, type=int)
    parser.add_argument('--recent_size', default=0.3, type=float)
    parser.add_argument('--k', default=100, type=float)
    parser.add_argument('--with_mean_attention', action='store_true')
    parser.add_argument('--use_context_cache', action='store_true')
    parser.add_argument('--seg_generate', action='store_true')
    parser.add_argument('--new_gov_prompt', action='store_true')
    parser.add_argument('--needle_every_result', action='store_true')
    parser.add_argument('--in_context_examples', default=0, type=int)

    # print config
    # Accept "--print_cache", "--print_cache true" and "--print_cache false";
    # a bare flag means True.
    parser.add_argument('--print_cache', nargs='?', const=True, default=False, type=_str2bool)
    parser.add_argument('--print_results', nargs='?', const=True, default=True, type=_str2bool)

    args = parser.parse_args()
    main(**vars(args))