import sys
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
import argparse
import random
import nltk
from tqdm import tqdm
import os
import time
import math
from utils import normalize_answer, customized_tokenize, is_stopword, flatten, get_peak_memory, get_flops
from cache_based_inference.pos_shift.modify_llama import enable_llama_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_mistral import enable_mistral_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_j import enable_gpt_j_pos_shift_attention
from torch.nn import CrossEntropyLoss
def construct_query_context(retrieved_docs):
    """Build the document-context part of a QA prompt.

    The prompt is an instruction line followed by one ``"Document: <text>"``
    line per retrieved document, each terminated by a newline.

    Args:
        retrieved_docs: iterable of document strings, in prompt order.

    Returns:
        The assembled prompt string.
    """
    # The instruction line ends with "\n" so the first document does not run
    # straight into it ("...documents.Document: ...").
    header = "Answer the query according to the following documents.\n"
    return header + "".join("Document: " + doc + "\n" for doc in retrieved_docs)

def construct_query(query):
    """Format *query* as a ``Q:`` line followed by an open ``A:`` line.

    The result ends with ``"A: The answer is"``, leaving the answer itself for
    the model to complete (``get_the_final_answer`` parses the text after that
    marker).
    """
    question_line = "Q: " + query + "\n"
    return question_line + "A: The answer is"

def get_the_final_answer(result):
    """Extract and normalize the answer sentence from generated text.

    The answer is the text after the first ``"The answer is"`` marker, up to
    (not including) the next ``"."``. If there is no period after the marker,
    the answer runs to the end of *result*. If the marker is absent, extraction
    starts at the beginning of *result*.

    Args:
        result: generated text to parse.

    Returns:
        The extracted span passed through ``normalize_answer``.
    """
    marker = "The answer is"
    marker_idx = result.find(marker)
    # str.find returns -1 when the marker is missing; adding len(marker) to
    # that would start the span at an arbitrary offset (12) and silently drop
    # the first characters of the text.
    start = 0 if marker_idx == -1 else marker_idx + len(marker)
    end = result.find('.', start)
    if end == -1:
        end = len(result)
    return normalize_answer(result[start:end])


@torch.no_grad()
def get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, with_mean_attention=False):
    """Run context + query through ``model`` and return a (possibly pruned) KV cache.

    Context and query are inferred together, so the query's keys/values attend
    to the full context. For ``cache_type == "topk"`` the prunable prefix is
    then filtered per layer by the attention it receives. On a fresh call the
    prunable prefix is the context; when ``past_key_values`` is given, it is
    the incoming cache. The first ``start_size`` positions and every newly fed
    position are always kept. The filtered cache is returned for further
    decoding.

    Args:
        input_ids_query_context: context token ids, or ``None``. Attention is
            scored from batch row 0 only.
        input_ids_query: token ids appended after the context.
        cache_type: ``"zeroshot"`` (query only), ``"all"`` (keep every
            position) or ``"topk"`` (attention-based pruning).
        model: causal LM, called with ``output_attentions=True``.
        start_size: leading positions always kept in ``"topk"`` mode.
            - A non-``int`` value is a fraction of the context length, or of
              the incoming cache length when no context is given.
            - ``None`` means 0.
        k: number of prunable positions kept by attention score in ``"topk"``
            mode. A non-``int`` value is a fraction of (context length +
            incoming cache length).
        past_key_values: cache indexable as ``cache[layer][0|1]`` (e.g. the
            cache returned by a previous call), or ``None``.
        with_mean_attention: score positions by attention averaged over all
            scoring query rows instead of only the last row.
        context_words_lst, word_idx_to_subtoken_start_end_idx,
        random_tokens_ratio, recent_size: accepted for interface
            compatibility; unused here.

    Returns:
        For ``"zeroshot"``: ``(past_key_values, pred_token_idx, given_indices)``.
        Otherwise: ``(selected_past_key_values, pred_token_idx, given_indices,
        logits)``. Here ``pred_token_idx`` is the greedy next token after the
        last input position, and ``given_indices`` are the kept positions (for
        ``"topk"``, the per-layer lists passed through ``flatten``).

    Raises:
        ValueError: unsupported ``cache_type``, or ``k`` missing in
            ``"topk"`` mode.
    """
    if cache_type == "zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        # NOTE: this mode returns three values, unlike the other modes.
        return outputs.past_key_values, pred_token_idx, given_indices

    # Validate before the expensive forward pass (previously an unknown mode
    # failed afterwards with a NameError on `given_indices`).
    if cache_type not in ("all", "topk"):
        raise ValueError(f"unsupported cache_type: {cache_type!r}")
    if cache_type == "topk" and k is None:
        raise ValueError("k must be given when cache_type == 'topk'")

    device = input_ids_query.device
    if input_ids_query_context is not None:
        input_ids = torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    else:
        input_ids = input_ids_query

    # Lengths of what precedes the new input, measured BEFORE the forward
    # pass: the explicit context and the incoming cache (0 when absent).
    has_old_cache = past_key_values is not None
    ctx_len = 0 if input_ids_query_context is None else len(input_ids_query_context[0])
    old_len = past_key_values[0][0].size(2) if has_old_cache else 0

    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values)
    new_cache = outputs.past_key_values
    total_len = old_len + len(input_ids[0])

    if cache_type == "all":
        # Keep every position: the incoming cache plus all newly fed tokens.
        given_indices = range(total_len)
        selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in new_cache)
    else:  # "topk"
        if start_size is None:
            start_size = 0
        if not isinstance(start_size, int):
            base_len = ctx_len if input_ids_query_context is not None else old_len
            start_size = int(start_size * base_len)
        if not isinstance(k, int):
            # Fractional budget relative to everything preceding the query
            # (context and/or incoming cache). This must use the pre-forward
            # cache length, not the cache returned by the model.
            k = int(k * (ctx_len + old_len))

        # Positions eligible for pruning: the incoming cache if there is one,
        # otherwise the context. Everything after them is newly fed input.
        prunable_len = old_len if has_old_cache else ctx_len
        # Rows whose attention scores the prunable positions: on a fresh call
        # only the query rows after the context, otherwise every new row.
        score_row_start = 0 if has_old_cache else ctx_len
        keep_start = torch.arange(start_size, dtype=torch.long, device=device)
        keep_new = torch.arange(prunable_len, total_len, dtype=torch.long, device=device)

        given_indices_for_each_layer = []
        for attention_per_layer in outputs.attentions:
            if with_mean_attention:
                att_weights = attention_per_layer[0, :, score_row_start:, :prunable_len].mean(dim=0).mean(dim=0)
            else:
                att_weights = attention_per_layer[0, :, -1, :prunable_len].mean(dim=0)
            candidates = att_weights[start_size:]
            _, topk_indices = torch.topk(candidates, k=min(k, candidates.size(0)))
            layer_indices = torch.cat([keep_start, topk_indices + start_size, keep_new])
            given_indices_for_each_layer.append(torch.sort(layer_indices).values)

        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer)
            for layer_idx, layer in enumerate(new_cache)
        )
        given_indices = flatten(given_indices_for_each_layer)

    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits



@torch.no_grad()
def get_past_key_values_with_query(input_ids_query_context, input_ids_query, input_ids_query_second, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, past_key_values=None, combine_policy='attention_first', question_attn_ratio=0.5, with_mean_attention=False):
    """Like ``get_past_key_values``, but pruning is also guided by a question.

    Context + ``input_ids_query`` are inferred together; the resulting cache is
    what gets (optionally) pruned and returned. For ``"topk"`` a second forward
    pass runs over context + ``input_ids_query_second`` on the same incoming
    cache. That pass's attentions are blended with the first pass's:
    ``(1 - question_attn_ratio) * first + question_attn_ratio * second``. The
    blended scores rank the prunable positions. The second pass's own cache is
    discarded.

    Args:
        input_ids_query_context: context token ids, or ``None``. Attention is
            scored from batch row 0 only.
        input_ids_query: token ids appended after the context in the first pass.
        input_ids_query_second: token ids (e.g. a question) appended after the
            context in the second pass.
        cache_type: ``"zeroshot"``, ``"all"`` or ``"topk"``.
        model: causal LM, called with ``output_attentions=True``.
        start_size: leading positions always kept in ``"topk"`` mode.
            - A non-``int`` value is a fraction of the context length, or of
              the incoming cache length when no context is given.
            - ``None`` means 0.
        k: number of prunable positions kept in ``"topk"`` mode. A non-``int``
            value is a fraction of (context length + incoming cache length).
        past_key_values: cache indexable as ``cache[layer][0|1]``, or ``None``.
        combine_policy: ``'attention_first'`` or ``'selection_first'``. It is
            only validated; both values currently behave identically.
        question_attn_ratio: weight of the second pass's attention.
        with_mean_attention: average attention over all scoring rows instead
            of using only the last row.
        context_words_lst, word_idx_to_subtoken_start_end_idx,
        random_tokens_ratio, recent_size: accepted for interface
            compatibility; unused here.

    Returns:
        For ``"zeroshot"``: ``(past_key_values, pred_token_idx, given_indices)``.
        Otherwise: ``(selected_past_key_values, pred_token_idx, given_indices,
        logits)``, all taken from the first pass.

    Raises:
        ValueError: invalid ``combine_policy``, unsupported ``cache_type``, or
            ``k`` missing in ``"topk"`` mode.
    """
    if combine_policy not in ('attention_first', 'selection_first'):
        raise ValueError(f"unsupported combine_policy: {combine_policy!r}")

    if cache_type == "zeroshot":
        outputs = model(input_ids_query, output_attentions=True, use_cache=True)
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        given_indices = range(len(input_ids_query[0]))
        # NOTE: this mode returns three values, unlike the other modes.
        return outputs.past_key_values, pred_token_idx, given_indices

    # Validate before running the model (an unknown mode used to fail only
    # after two forward passes, with a NameError).
    if cache_type not in ("all", "topk"):
        raise ValueError(f"unsupported cache_type: {cache_type!r}")
    if cache_type == "topk" and k is None:
        raise ValueError("k must be given when cache_type == 'topk'")

    device = input_ids_query.device
    if input_ids_query_context is not None:
        input_ids = torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    else:
        input_ids = input_ids_query

    # Lengths of what precedes the new input, measured BEFORE the forward
    # pass: the explicit context and the incoming cache (0 when absent).
    has_old_cache = past_key_values is not None
    ctx_len = 0 if input_ids_query_context is None else len(input_ids_query_context[0])
    old_len = past_key_values[0][0].size(2) if has_old_cache else 0

    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values)
    new_cache = outputs.past_key_values
    total_len = old_len + len(input_ids[0])

    if cache_type == "all":
        # Keep every position: the incoming cache plus all newly fed tokens.
        given_indices = range(total_len)
        selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in new_cache)
    else:  # "topk"
        # The second (question) pass is only needed for its attentions, so it
        # runs only in this mode. It reuses the same incoming cache.
        if input_ids_query_context is not None:
            input_ids_second = torch.cat([input_ids_query_context, input_ids_query_second], dim=-1)
        else:
            input_ids_second = input_ids_query_second
        outputs_second = model(input_ids_second, output_attentions=True, use_cache=True, past_key_values=past_key_values)

        if start_size is None:
            start_size = 0
        if not isinstance(start_size, int):
            base_len = ctx_len if input_ids_query_context is not None else old_len
            start_size = int(start_size * base_len)
        if not isinstance(k, int):
            # Fractional budget relative to everything preceding the query.
            # This must use the pre-forward cache length, not the cache the
            # model just returned (which is never None and includes the new
            # tokens).
            k = int(k * (ctx_len + old_len))

        # Positions eligible for pruning, and the rows that score them (see
        # get_past_key_values): context/query-rows on a fresh call,
        # incoming-cache/all-rows otherwise.
        prunable_len = old_len if has_old_cache else ctx_len
        score_row_start = 0 if has_old_cache else ctx_len
        keep_start = torch.arange(start_size, dtype=torch.long, device=device)
        keep_new = torch.arange(prunable_len, total_len, dtype=torch.long, device=device)

        given_indices_for_each_layer = []
        for attention_per_layer, attention_per_layer_second in zip(outputs.attentions, outputs_second.attentions):
            if with_mean_attention:
                att_weights_context = attention_per_layer[0, :, score_row_start:, :prunable_len].mean(dim=0).mean(dim=0)
                att_weights_query = attention_per_layer_second[0, :, score_row_start:, :prunable_len].mean(dim=0).mean(dim=0)
                att_weights = att_weights_context * (1 - question_attn_ratio) + att_weights_query * question_attn_ratio
            else:
                att_weights = torch.mean(
                    attention_per_layer[0, :, -1, :prunable_len] * (1 - question_attn_ratio)
                    + attention_per_layer_second[0, :, -1, :prunable_len] * question_attn_ratio,
                    dim=0,
                )
            candidates = att_weights[start_size:]
            _, topk_indices = torch.topk(candidates, k=min(k, candidates.size(0)))
            layer_indices = torch.cat([keep_start, topk_indices.long() + start_size, keep_new])
            given_indices_for_each_layer.append(torch.sort(layer_indices).values)

        selected_past_key_values = tuple(
            tuple(kv[:, :, given_indices_for_each_layer[layer_idx], :] for kv in layer)
            for layer_idx, layer in enumerate(new_cache)
        )
        given_indices = flatten(given_indices_for_each_layer)

    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits


@torch.no_grad()
def decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams):
    """Sample a continuation from a precomputed KV cache and profile it.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: supplies ``eos_token_id`` to stop generation.
        past_key_values: cache indexable as ``cache[layer][0]``, whose dim -2
            is the cached sequence length.
        pred_token_idx: next input token ids, shape ``(batch, 1)`` as returned
            by the cache helpers. The attention mask built here assumes
            batch size 1.
        max_gen_len: passed as ``max_new_tokens``.
        num_beams: beam count. It is coerced to ``int`` because the CLI parser
            declares no type for ``--num_beams``, so CLI values arrive as
            strings.

    Returns:
        ``(token_ids, decoding_time, decoding_memory)``:
        - the first returned sequence, on CPU;
        - elapsed GPU time in milliseconds between two CUDA events;
        - the difference of ``get_peak_memory`` readings taken after and
          before generation.
    """
    device = model.device
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # Measure the inference time of one document with batch_size=1.
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    # The mask covers every cached position plus the one new input token. It
    # must live on the same device as the inputs; a CPU mask fails inside
    # generate when the model runs on GPU.
    cache_len = past_key_values[0][0].size(-2)
    attention_mask = torch.ones((1, cache_len + 1), dtype=torch.long, device=pred_token_idx.device)

    outputs = model.generate(
        input_ids=pred_token_idx,
        attention_mask=attention_mask,
        do_sample=True,
        temperature=0.8,
        top_p=0.75,
        top_k=40,
        num_beams=int(num_beams),
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_gen_len,
        past_key_values=past_key_values,
    )

    ender.record()
    torch.cuda.synchronize()
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return outputs[0].detach().cpu(), decoding_time, decoding_memory

def main(model_name, quantization_type, device_map, max_gen_length, question_attn_ratio, all_eval_tokens, segment_length, num_beams, dataset_name, dataset_split, num_examples, num_docs, output_dir, min_context_length, max_context_length, cache_type, combine_policy, random_tokens_ratio=None, start_size=None, recent_size=None, k=None, print_cache=True, print_results=True, with_mean_attention=False, seg_generate=False, new_gov_prompt=False, in_context_examples=0, final_three_dataset=False, cache_dir='/network/scratch/y/yu.bai/.cache'):
    """Log per-token NLL of streamed text under question-guided KV-cache pruning.

    For each hard-coded question, the ``text`` field of every dataset example
    is tokenized, split into ``segment_length``-token segments and fed through
    ``get_past_key_values_with_query``. The (possibly pruned) cache is carried
    from one segment to the next. Outputs per question ``i``:
    - ``<output_dir>/<model>_log_new_new_<i>.txt``: per-token NLL values;
    - ``<output_dir>/<model>_time_new_new_<i>.txt``: token counts with elapsed
      wall-clock time since the run started.
    Here ``<model>`` is ``model_name`` with ``/`` replaced by ``_``.

    The following parameters are not used by the active code path and are
    kept for interface compatibility: ``max_gen_length``, ``num_beams``,
    ``num_examples``, ``num_docs``, ``min_context_length``,
    ``max_context_length``, ``print_cache``, ``print_results``,
    ``seg_generate``, ``new_gov_prompt``, ``in_context_examples``,
    ``final_three_dataset``.

    Args:
        quantization_type: ``'none'``, ``'4bit'`` or ``'8bit'``.
        k: top-k budget. An integral value (e.g. ``512.0``) becomes an ``int``
            absolute count; other floats are passed on as fractions. ``None``
            is passed through unchanged.
        all_eval_tokens: stop after the example during which this many tokens
            have been evaluated (``None`` = no limit).
        cache_dir: cache directory passed to ``load_dataset`` and to the model's
            ``from_pretrained``.

    Raises:
        ValueError: unknown ``quantization_type`` or unsupported model type.
    """
    device = torch.device("cuda")
    random.seed(42)

    quantization_kwargs = {
        'none': {},
        '4bit': {'load_in_4bit': True},
        '8bit': {'load_in_8bit': True},
    }
    # Fail early: previously an unknown value left `model` unbound.
    if quantization_type not in quantization_kwargs:
        raise ValueError(f"unsupported quantization_type: {quantization_type!r}")
    # k defaults to None; math.floor(None) would raise a TypeError.
    if k is not None and math.floor(k) == k:
        k = int(k)

    test_examples = load_dataset(dataset_name, split=dataset_split, cache_dir=cache_dir)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_attentions=True,
        device_map=device_map,
        cache_dir=cache_dir,
        **quantization_kwargs[quantization_type],
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Patch attention for position shifting on supported architectures.
    model_type = model.config.model_type
    if "llama" in model_type:
        enable_llama_pos_shift_attention(model)
    elif "gpt_neox" in model_type:
        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model_type:
        enable_falcon_pos_shift_attention(model)
    elif "mistral" in model_type:
        enable_mistral_pos_shift_attention(model)
    elif "mpt" in model_type or "btlm" in model_type:
        pass
    else:
        raise ValueError(f"got {model.config.model_type}")
    loss_fn = CrossEntropyLoss(reduction="none")

    # Accumulators for the commented-out QA-evaluation code that follows this
    # function; the active perplexity path does not read them.
    num_correct = 0
    num_total = 0
    total_decoding_time = 0
    total_decoding_memory = 0
    total_stopwords_ratio = 0
    context_length_lst = []
    # Collected per-segment NLLs; currently only appended to.
    nlls = []
    start_time = time.time()

    questions = [
        "How is the ground truth for fake news established?",
        "What architecture does the encoder have?",
        "Which case was brought to court first Miller v. California or Gates v. Collier ?",
        "What occupation is shared by both Marge Piercy and Richard Aldington?",
        "What is their definition of tweets going viral?",
        "Were any of these tasks evaluated in any previous work?",
        "What sentiment classification dataset is used?",
        "The historical Nimavar school in the Nimavar Bazaar, or bazar, is located in which country?",
        "For what type of work is the production company for The Year Without a Santa Claus best known?",
        "The physicist who is responsible for identifying the Rabi cycle won what award?",
    ]

    os.makedirs(output_dir, exist_ok=True)
    model_tag = model_name.replace('/', '_')
    for num_question, each_question in enumerate(questions):
        query = "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: " + each_question + "\nAnswer:"
        print("query = ", query)
        # Drop the first token (presumably a BOS token added by the
        # tokenizer -- TODO confirm for non-Llama tokenizers).
        input_ids_query = tokenizer(query, return_tensors="pt").input_ids[:, 1:].to(device)
        num_eval_tokens = 0

        log_path = f"{output_dir}/{model_tag}_log_new_new_{num_question}.txt"
        time_path = f"{output_dir}/{model_tag}_time_new_new_{num_question}.txt"
        # `with` guarantees both files are closed; the timing file previously
        # was never closed.
        with open(log_path, "w", encoding="utf-8") as f, open(time_path, "w", encoding="utf-8") as f_time:
            for test_example in test_examples:
                text = test_example['text']
                all_ids = tokenizer(text, return_tensors="pt").input_ids
                seq_len = all_ids.size(1)
                print(f"seq_len: {seq_len}")
                pbar = tqdm(range(0, math.ceil(seq_len / segment_length)))
                for idx in pbar:
                    seg_start = idx * segment_length
                    if idx == 0:
                        # Step 0 feeds segment 0 as context and segment 1 as
                        # new input, so the first call already has a context
                        # to prune.
                        input_ids_seg1 = all_ids[:, seg_start : min(seq_len, seg_start + segment_length)].to(device)
                        input_ids_seg2 = all_ids[:, seg_start + segment_length : min(seq_len, seg_start + 2 * segment_length)].to(device)
                        selected_past_key_values, pred_token_idx, given_indices, logits = get_past_key_values_with_query(
                            input_ids_seg1, input_ids_seg2, input_ids_query, cache_type, model, None, None,
                            random_tokens_ratio, start_size, recent_size, k,
                            combine_policy=combine_policy, question_attn_ratio=question_attn_ratio,
                            with_mean_attention=with_mean_attention,
                        )
                        seg_end = seg_start + 2 * segment_length
                    elif idx == 1:
                        # Segment 1 was already consumed at step 0.
                        continue
                    else:
                        input_ids_seg2 = all_ids[:, seg_start : min(seq_len, seg_start + segment_length)].to(device)
                        selected_past_key_values, pred_token_idx, given_indices, logits = get_past_key_values_with_query(
                            None, input_ids_seg2, input_ids_query, cache_type, model, None, None,
                            random_tokens_ratio, start_size, recent_size, k,
                            past_key_values=selected_past_key_values,
                            combine_policy=combine_policy, question_attn_ratio=question_attn_ratio,
                            with_mean_attention=with_mean_attention,
                        )
                        seg_end = seg_start + segment_length

                    logits = logits.view(-1, model.config.vocab_size)
                    # Next-token targets: the fed tokens shifted by one.
                    labels = all_ids[:, seg_start + 1 : seg_end + 1].to(logits.device).view(-1)
                    num_eval_tokens += min(seq_len, seg_end) - seg_start
                    # At the end of the text the last fed position has no
                    # target, so trim logits to the number of labels.
                    neg_log_likelihood = loss_fn(logits[:labels.size(-1)], labels)

                    # One buffered write + flush per segment instead of a
                    # flushed print (and a GPU sync) per token.
                    f.writelines(f"{value}\n" for value in neg_log_likelihood.tolist())
                    f.flush()
                    nlls.append(neg_log_likelihood)
                    pbar.set_description(
                        f"nll: {neg_log_likelihood.mean().item():.2f}, ppl: {torch.exp(neg_log_likelihood.mean()).item():.2f}"
                    )

                    tmp_time = time.time()
                    print(num_eval_tokens, "steps middle time = ", tmp_time - start_time)
                    print(str(num_eval_tokens) + " " + str(tmp_time - start_time), file=f_time, flush=True)
                if all_eval_tokens is not None and num_eval_tokens >= all_eval_tokens:
                    break
        end_time = time.time()
        print("overall time = ", end_time - start_time)
    #     #prepare contextual documents

    #     if len(test_example["search_results"]["search_context"])<num_docs:
    #         print("Not enough retrieved docs. The example is skipped !")
    #         continue
    #     if num_docs==0:
    #         input_ids_query_context, context_words_lst, word_idx_to_subtoken_start_end_idx=None, None, None
    #     else:
    #         retrieved_docs=random.choices(test_example["search_results"]["search_context"], k=num_docs)
    #         query_context=construct_query_context(retrieved_docs)
    #         query_context=query_context.replace('\\n','\n').rstrip('\n')
    #         context_words_lst=nltk.word_tokenize(query_context)
    #         if model_name.find("pythia"):
    #             input_ids_query_context=tokenizer(query_context, return_tensors="pt").input_ids
    #             word_idx_to_subtoken_start_end_idx=None
    #         else:
    #             input_ids_query_context, word_idx_to_subtoken_start_end_idx=customized_tokenize(tokenizer, context_words_lst)
    #         input_ids_query_context=input_ids_query_context.to(device)
    #         if len(input_ids_query_context[0]) > max_context_length or len(input_ids_query_context[0]) < min_context_length:
    #             print("Context too long or too short. The example is skipped !")
    #             continue
        
        #prepare test query
    #     test_query=test_example["question"]
    #     query=construct_query(test_query)
    #     query=query.replace('\\n','\n').rstrip('\n')
    #     query_words_lst=nltk.word_tokenize(query)
    #     if model_name.find("pythia"):
    #         input_ids_query=tokenizer(query, return_tensors="pt").input_ids
    #     else:
    #         input_ids_query, _=customized_tokenize(tokenizer, query_words_lst)
    #     input_ids_query=input_ids_query.to(device)
        
    #     #cache-based decoding
        # selected_past_key_values, pred_token_idx, given_indices=get_past_key_values(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst, word_idx_to_subtoken_start_end_idx, random_tokens_ratio, start_size, recent_size, k)
    #     generated_ids, decoding_time, decoding_memory=decoding(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams)     
    #     result = tokenizer.decode(torch.cat([input_ids_query[0].detach().cpu(), generated_ids]), skip_special_tokens=True)
    #     final_answer=get_the_final_answer(result)
    #     if print_results:
    #         print(result)
    #     if print_cache:
    #         if input_ids_query_context!=None:
    #             input_ids=torch.cat([input_ids_query_context, input_ids_query], dim=-1)
    #         else:
    #             input_ids=input_ids_query
    #         input_ids=input_ids[0][given_indices].detach().cpu()
    #         print(tokenizer.decode(input_ids, skip_special_tokens=True))

    #     #evaluate answers
    #     for gold_answer in test_example["answer"]["aliases"]+test_example["answer"]["normalized_aliases"]:
    #         if final_answer.find(gold_answer)!=-1:
    #             num_correct+=1
    #             break
    #     total_decoding_time+=decoding_time
    #     total_decoding_memory+=decoding_memory
    #     context_length_lst.append(len(input_ids_query_context[0]))
    #     if cache_type=="stopwords":
    #         total_stopwords_ratio+=(len(given_indices)-len(input_ids_query[0]))/len(input_ids_query_context[0])
    #     num_total+=1
    #     if num_total>=num_examples:
    #         break
        

    # print("Experimental Settings")
    # print("Model name:", model_name)
    # print("Quantization:", quantization_type)
    # print("Dataset name:", dataset_name)
    # print("Dataset split:", dataset_split)
    # print("Num examples:", num_examples)
    # print("Num docs:", num_docs)
    # print("Min context length", min_context_length)
    # print("Max context length", max_context_length)
    # print("Max gen length", max_gen_length)
    # print("Cache_type", cache_type)
    # if cache_type=="random":
    #     print("Random tokens ratio", random_tokens_ratio)
    # elif cache_type=="startrecent":
    #     print("Start size", start_size)
    #     print("Recent size", recent_size)
    # elif cache_type=="topk":
    #     print("K", k)
    # print("-----------------------------------------")
    # print("Experimental Results")
    # print("Average Accuracy:", num_correct/num_total)
    # print("Average Decoding Time:", total_decoding_time/num_total)
    # print("Average Decoding Memory:", total_decoding_memory/num_total)
    # print("Average Context Length:", sum(context_length_lst)/num_total)
    # print("Min Context Length:", min(context_length_lst))
    # print("Max Context Length:", max(context_length_lst))
    # if cache_type=="stopwords":
    #     print("Average Stopwords Ratio:", total_stopwords_ratio/num_total) 
    


def parse_bool(value):
    """Convert a command-line string such as ``"true"`` / ``"False"`` to a bool.

    With ``action='store'`` argparse keeps the raw string, so
    ``--print_cache False`` would otherwise produce the truthy string
    ``"False"``. This converter keeps the ``--flag VALUE`` syntax while
    yielding a real boolean.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_int_or_float(value):
    """Parse *value* as an int when it is integral text, otherwise as a float.

    Used for sizes documented as "a ratio between 0-1 or an absolute number":
    ``"128"`` -> ``128`` and ``"0.3"`` -> ``0.3``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not numeric.
    """
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number, got {value!r}") from None


def build_arg_parser():
    """Build the command-line parser for this evaluation script.

    All options are optional. The parsed namespace is turned into a dict and
    passed to ``main`` as keyword arguments (one per ``dest``), so ``main``
    must accept every ``dest`` declared here.
    """
    parser = argparse.ArgumentParser()

    # model config
    parser.add_argument('--model_name', dest='model_name', action='store', required=False, default='mistralai/Mistral-7B-v0.1')
    parser.add_argument('--quantization_type', dest='quantization_type', action='store', required=False, default='none')  # none, 4bit, 8bit
    parser.add_argument('--device_map', dest='device_map', action='store', required=False, default="auto")
    parser.add_argument('--max_gen_length', dest='max_gen_length', action='store', required=False, default=100, type=int)
    parser.add_argument('--num_beams', dest='num_beams', action='store', required=False, default=4, type=int)
    parser.add_argument('--segment_length', dest='segment_length', action='store', required=False, default=1024, type=int)

    # data config
    parser.add_argument('--dataset_name', dest='dataset_name', action='store', required=False, default='mandarjoshi/trivia_qa')
    parser.add_argument('--combine_policy', dest='combine_policy', action='store', required=False, default='attention_first')
    parser.add_argument('--dataset_split', dest='dataset_split', action='store', required=False, default='test')
    parser.add_argument('--num_examples', dest='num_examples', action='store', required=False, default=500, type=int)  # test on how many examples
    parser.add_argument('--num_docs', dest='num_docs', action='store', required=False, default=1, type=int)  # number of retrieved docs for each test query
    # the range of context length of retrieved documents
    parser.add_argument('--min_context_length', dest='min_context_length', action='store', required=False, default=0, type=int)
    parser.add_argument('--max_context_length', dest='max_context_length', action='store', required=False, default=10000, type=int)
    parser.add_argument('--all_eval_tokens', dest='all_eval_tokens', action='store', required=False, default=4000000, type=int)
    parser.add_argument("--question_attn_ratio", dest='question_attn_ratio', action='store', required=False, default=1.0, type=float)

    # cache config -- values accepted by --cache_type:
    #   all:         all the context tokens
    #   zeroshot:    no context tokens, and the query does not attend to any context
    #   nocontext:   no context tokens, but the query does attend to the context
    #   stopwords:   only the stopwords in the context
    #   random:      random tokens in the context, controlled by --random_tokens_ratio
    #   startrecent: start tokens and recent tokens in the context; --start_size and
    #                --recent_size can each be a ratio between 0-1 or an absolute number
    #   topk:        top-k context tokens with the highest attention weights; --k can
    #                be a ratio between 0-1 or an absolute number
    parser.add_argument("--cache_type", dest='cache_type', action='store', required=False, default='topk')
    parser.add_argument("--random_tokens_ratio", dest='random_tokens_ratio', action='store', required=False, default=0.1, type=float)
    # int for integral input (e.g. "16"), float for ratios (e.g. "0.3"), per the contract above
    parser.add_argument("--start_size", dest="start_size", action="store", required=False, default=0, type=parse_int_or_float)
    parser.add_argument("--recent_size", dest="recent_size", action="store", required=False, default=0.3, type=parse_int_or_float)
    parser.add_argument("--k", dest="k", action="store", required=False, default=1024, type=float)
    parser.add_argument('--with_mean_attention', dest='with_mean_attention', action='store_const', const=True, default=False)
    parser.add_argument('--seg_generate', dest='seg_generate', action='store_const', const=True, default=False)
    parser.add_argument('--new_gov_prompt', dest='new_gov_prompt', action='store_const', const=True, default=False)
    parser.add_argument('--final_three_dataset', dest='final_three_dataset', action='store_const', const=True, default=False)
    parser.add_argument("--in_context_examples", dest="in_context_examples", action="store", required=False, default=0, type=int)

    # print config -- both take a value (e.g. "--print_cache true") parsed into a real bool
    parser.add_argument('--print_cache', dest='print_cache', action='store', required=False, default=False, type=parse_bool)
    # dest used to be 'print_cache' (copy-paste slip), so this flag silently toggled cache printing
    parser.add_argument('--print_results', dest='print_results', action='store', required=False, default=False, type=parse_bool)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/debug",
    )
    return parser


if __name__ == "__main__":
    args = vars(build_arg_parser().parse_args())
    main(**args)