import sys
sys.path.append('/home/mila/x/xiyuan.zou/research/icl-mechanism-project')
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
import argparse
import random
import nltk
from tqdm import tqdm
from utils import normalize_answer, customized_tokenize

def construct_query_context(retrieved_docs):
    """Return the instruction header followed by one ``Document:`` line per doc.

    Each document string is inserted verbatim on its own line; every line,
    including the last one, is terminated by a newline.
    """
    header = "Answer the query according to the following documents.\n"
    doc_lines = ["Document: " + doc + '\n' for doc in retrieved_docs]
    return header + "".join(doc_lines)

def construct_query(query):
    """Format ``query`` as a Q/A prompt whose answer line is left open for the model."""
    question_line = "Q: " + query + '\n'
    answer_stub = "A: The answer is"
    return "".join((question_line, answer_stub))


def get_past_key_values(input_ids, cache_content, model):
    """Run ``model`` over ``input_ids`` and return the KV cache to reuse.

    Args:
        input_ids: context token ids, passed straight to ``model``; the
            length of row 0 is taken as the sequence length.
        cache_content: which context positions to keep. ``"all"`` keeps every
            position; ``"none"`` keeps nothing.
        model: causal LM whose forward output exposes ``past_key_values``.

    Returns:
        A tuple (one entry per layer) of tuples of cached tensors indexed on
        dim 2 by the kept positions, or ``None`` when ``cache_content`` is
        ``"none"``.

    Raises:
        ValueError: if ``cache_content`` is neither ``"all"`` nor ``"none"``.
    """
    if cache_content == "none":
        # Nothing is cached, so the forward pass would be wasted work.
        return None
    if cache_content != "all":
        raise ValueError(f"Unsupported cache_content: {cache_content!r}")

    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    past_key_values = outputs.past_key_values

    # "all" means every context token, not just the first two positions.
    given_indices = list(range(len(input_ids[0])))
    selected_past_key_values = tuple(
        tuple(layer[:, :, given_indices, :] for layer in pkv) for pkv in past_key_values
    )
    return selected_past_key_values


def inference_with_past_key_values(input_ids, past_key_values, tokenizer, model, *, do_sample=True, top_k=10, max_new_tokens=5):
    """Generate a continuation of ``input_ids`` and return it decoded.

    Args:
        input_ids: prompt token ids passed to ``model.generate``.
        past_key_values: optional precomputed cache forwarded to ``generate``
            (``None`` for no cache).
            NOTE(review): Hugging Face ``generate`` generally expects
            ``input_ids`` to also cover the positions already in the cache;
            passing only new tokens with a non-empty cache may not behave as
            intended. Confirm against the installed transformers version.
        tokenizer: supplies ``eos_token_id`` and ``decode``.
        model: object exposing a Hugging Face style ``generate``.
        do_sample, top_k, max_new_tokens: generation settings. The defaults
            reproduce the original hard-coded values. Sampling makes the
            output nondeterministic unless torch's RNG is seeded; pass
            ``do_sample=False`` for greedy decoding.

    Returns:
        The decoded first returned sequence, special tokens skipped. For
        decoder-only models ``generate`` returns prompt + continuation, so
        the prompt text is presumably included (callers search it for
        "The answer is").
    """
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            do_sample=do_sample,
            top_k=top_k,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            past_key_values=past_key_values,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def get_the_final_answer(result):
    """Extract the answer span from a decoded generation and normalize it.

    The span starts after the first occurrence of ``"The answer is"`` (or at
    the start of ``result`` if the marker is absent). Leading whitespace is
    skipped, and the span ends at the next ``"."`` or newline, whichever
    comes first, or at the end of ``result``. The span is passed through
    ``normalize_answer``.
    """
    marker = "The answer is"
    marker_idx = result.find(marker)
    # Without the marker, use the whole string rather than slicing from the
    # meaningless offset -1 + len(marker).
    span = result if marker_idx == -1 else result[marker_idx + len(marker):]
    span = span.lstrip()
    # A newline also ends the answer, so a following "Q: ..." line is not
    # glued onto it when the model omits the period.
    terminators = [idx for idx in (span.find('.'), span.find('\n')) if idx != -1]
    end = min(terminators, default=len(span))
    return normalize_answer(span[:end])

def get_past_key_values_3(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None):
    """Build a pruned KV cache in which the query is encoded without seeing the context.

    Steps:
    1. Compute the context KV cache.
    2. Compute the query KV cache in a separate forward pass that does not
       attend to the context. Its position ids continue after the context.
    3. Concatenate both caches along the sequence dimension.
    4. Keep the context positions chosen by ``cache_type`` plus every query
       position; the result is used for generation.

    Args:
        input_ids_query_context: context token ids (row 0's length is used as
            the context length), or ``None`` to encode the query alone.
        input_ids_query: query token ids.
        cache_type: ``"all"``, ``"stopwords"``, ``"random"`` or
            ``"startrecent"``. ``"topk"`` is recognized but not implemented.
        model: causal LM returning ``past_key_values`` and ``logits``.
        context_words_lst, word_idx_to_subtoken_start_end_idx: required for
            ``"stopwords"``. Each word index maps to an inclusive
            (start, end) subtoken position range.
        random_tokens_ratio: fraction of context positions sampled for
            ``"random"``.
        start_size, recent_size: number of leading / trailing context
            positions kept for ``"startrecent"``.
        k: reserved for ``"topk"``.

    Returns:
        ``(selected_past_key_values, pred_token_idx, given_indices)``:
        - ``given_indices`` lists the kept positions in the concatenated
          context+query cache.
        - ``pred_token_idx`` is the argmax next token of the query pass,
          shape (batch, 1).

    Raises:
        NotImplementedError: for ``cache_type == "topk"``.
        ValueError: for an unknown ``cache_type``, a missing argument it
            needs, or a context-dependent ``cache_type`` without a context.
    """
    if cache_type == "topk":
        raise NotImplementedError("cache_type='topk' is not implemented")
    if cache_type not in ("all", "stopwords", "random", "startrecent"):
        raise ValueError(f"Unknown cache_type: {cache_type!r}")
    has_context = input_ids_query_context is not None
    if cache_type != "all" and not has_context:
        raise ValueError(f"cache_type={cache_type!r} requires input_ids_query_context")

    device = input_ids_query.device
    context_len = len(input_ids_query_context[0]) if has_context else 0
    total_len = context_len + len(input_ids_query[0])

    with torch.no_grad():
        if has_context:
            # Query positions continue after the context so positional
            # encodings match those of a joint context+query pass.
            position_ids_query = torch.arange(context_len, total_len, dtype=torch.long, device=device).unsqueeze(0)
            outputs_context = model(input_ids_query_context, use_cache=True)
            outputs_query = model(input_ids_query, position_ids=position_ids_query, use_cache=True)
            past_key_values = tuple(
                tuple(
                    torch.cat([kv_context, kv_query], dim=-2)
                    for kv_context, kv_query in zip(layer_context, layer_query)
                )
                for layer_context, layer_query in zip(outputs_context.past_key_values, outputs_query.past_key_values)
            )
        else:
            outputs_query = model(input_ids_query, use_cache=True)
            past_key_values = outputs_query.past_key_values

    if cache_type == "all":
        given_indices = list(range(total_len))
    else:
        if cache_type == "stopwords":
            if context_words_lst is None or word_idx_to_subtoken_start_end_idx is None:
                raise ValueError("cache_type='stopwords' requires context_words_lst and word_idx_to_subtoken_start_end_idx")
            given_indices = []
            for word_idx, word in enumerate(context_words_lst):
                # NOTE(review): is_stopword is neither defined nor imported in
                # this file as shown; confirm where it should come from.
                if is_stopword(word):
                    subtoken_start_idx, subtoken_end_idx = word_idx_to_subtoken_start_end_idx[word_idx]
                    given_indices.extend(range(subtoken_start_idx, subtoken_end_idx + 1))
        elif cache_type == "random":
            if random_tokens_ratio is None:
                raise ValueError("cache_type='random' requires random_tokens_ratio")
            random_tokens_num = int(random_tokens_ratio * context_len)
            given_indices = random.sample(range(context_len), k=random_tokens_num)
        else:  # "startrecent"
            if start_size is None or recent_size is None:
                raise ValueError("cache_type='startrecent' requires start_size and recent_size")
            # Clamp both windows to the context so oversized values cannot
            # yield negative (wrapping) or duplicated positions.
            start_end = min(start_size, context_len)
            recent_begin = max(context_len - recent_size, start_end)
            given_indices = list(range(start_end)) + list(range(recent_begin, context_len))
        # Every query position is always kept.
        given_indices.extend(range(context_len, total_len))

    selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in past_key_values)
    pred_token_idx = outputs_query.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices

def get_past_key_values_2(input_ids_query_context, input_ids_query, cache_type, model, context_words_lst=None, word_idx_to_subtoken_start_end_idx=None, random_tokens_ratio=None, start_size=None, recent_size=None, k=None):
    """Build a pruned context cache and encode the query on top of it.

    Steps:
    1. Compute the context KV cache.
    2. Select a subset of its positions according to ``cache_type``.
    3. Run the query with only those selected positions as its cache, so the
       query attends to the pruned context.
    4. Return the selected context cache extended with the query's
       keys/values, for use in generation.

    Args:
        input_ids_query_context: context token ids (row 0's length is used as
            the context length), or ``None`` to encode the query alone.
        input_ids_query: query token ids.
        cache_type: ``"all"``, ``"stopwords"``, ``"random"`` or
            ``"startrecent"``. ``"topk"`` is recognized but not implemented.
        model: causal LM returning ``past_key_values`` and ``logits``.
        context_words_lst, word_idx_to_subtoken_start_end_idx: required for
            ``"stopwords"``. Each word index maps to an inclusive
            (start, end) subtoken position range.
        random_tokens_ratio: fraction of context positions sampled for
            ``"random"``.
        start_size, recent_size: number of leading / trailing context
            positions kept for ``"startrecent"``.
        k: reserved for ``"topk"``.

    Returns:
        ``(selected_past_key_values, pred_token_idx, given_indices)``:
        - ``given_indices`` lists the kept *context* positions (empty when no
          context is given).
        - ``pred_token_idx`` is the argmax next token of the query pass,
          shape (batch, 1).

    Raises:
        NotImplementedError: for ``cache_type == "topk"``.
        ValueError: for an unknown ``cache_type`` or a missing argument it
            needs.
    """
    if cache_type == "topk":
        raise NotImplementedError("cache_type='topk' is not implemented")
    if cache_type not in ("all", "stopwords", "random", "startrecent"):
        raise ValueError(f"Unknown cache_type: {cache_type!r}")

    device = input_ids_query.device
    # Stays empty when there is no context to select from.
    given_indices = []
    with torch.no_grad():
        if input_ids_query_context is None:
            outputs = model(input_ids_query, use_cache=True)
            selected_past_key_values = outputs.past_key_values
        else:
            context_len = len(input_ids_query_context[0])
            # Attention maps are never read here, so they are not requested.
            outputs_context = model(input_ids_query_context, use_cache=True)
            past_key_values = outputs_context.past_key_values

            if cache_type == "all":
                given_indices = list(range(context_len))
            elif cache_type == "stopwords":
                if context_words_lst is None or word_idx_to_subtoken_start_end_idx is None:
                    raise ValueError("cache_type='stopwords' requires context_words_lst and word_idx_to_subtoken_start_end_idx")
                for word_idx, word in enumerate(context_words_lst):
                    # NOTE(review): is_stopword is neither defined nor imported
                    # in this file as shown; confirm where it should come from.
                    if is_stopword(word):
                        subtoken_start_idx, subtoken_end_idx = word_idx_to_subtoken_start_end_idx[word_idx]
                        given_indices.extend(range(subtoken_start_idx, subtoken_end_idx + 1))
            elif cache_type == "random":
                if random_tokens_ratio is None:
                    raise ValueError("cache_type='random' requires random_tokens_ratio")
                random_tokens_num = int(random_tokens_ratio * context_len)
                given_indices = random.sample(range(context_len), k=random_tokens_num)
            else:  # "startrecent"
                if start_size is None or recent_size is None:
                    raise ValueError("cache_type='startrecent' requires start_size and recent_size")
                # Clamp both windows to the context so oversized values cannot
                # yield negative (wrapping) or duplicated positions.
                start_end = min(start_size, context_len)
                recent_begin = max(context_len - recent_size, start_end)
                given_indices = list(range(start_end)) + list(range(recent_begin, context_len))

            selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in past_key_values)

            # Query positions are numbered compactly, right after the number
            # of kept context positions (not after the full context length).
            num_kept = len(given_indices)
            position_ids_query = torch.arange(num_kept, num_kept + len(input_ids_query[0]), dtype=torch.long, device=device).unsqueeze(0)
            outputs = model(input_ids_query, past_key_values=selected_past_key_values, position_ids=position_ids_query, use_cache=True)
            selected_past_key_values = outputs.past_key_values

    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    return selected_past_key_values, pred_token_idx, given_indices

@torch.no_grad()
def greedy_decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len):
    """Greedily extend a generation from a prefilled KV cache and measure its cost.

    Args:
        model: causal LM called with ``input_ids``, ``past_key_values`` and
            ``use_cache=True``; its output must expose ``logits`` and
            ``past_key_values``. ``model.device`` selects the timing method.
        tokenizer: only ``eos_token_id`` is read.
        past_key_values: cache of the already-processed prompt.
        pred_token_idx: single-element tensor (shape (1, 1) from the callers
            in this file) holding the first predicted token; it becomes the
            first output id.
        max_gen_len: maximum number of returned ids, including
            ``pred_token_idx``.

    Returns:
        ``(generated_ids, decoding_time, decoding_memory)``:
        - ``generated_ids``: int tensor of token ids; EOS is included if
          produced.
        - ``decoding_time``: in milliseconds.
        - ``decoding_memory``: increase in peak allocated CUDA memory, in
          bytes, during decoding. It is not measured (reported as 0) when
          the model is not on a CUDA device.
    """
    import time

    device = model.device
    on_cuda = getattr(device, "type", None) == "cuda"
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        # After the reset, the peak equals the memory currently allocated.
        before_memory = torch.cuda.max_memory_allocated(device)
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()
    else:
        start_time = time.perf_counter()

    generated_ids = [pred_token_idx.item()]
    # If the prefill step already produced EOS there is nothing to decode.
    if generated_ids[-1] != tokenizer.eos_token_id:
        for _ in range(max_gen_len - 1):
            # Position ids are left for the model to infer from the cache length.
            outputs = model(
                input_ids=pred_token_idx,
                past_key_values=past_key_values,
                use_cache=True,
            )
            # The returned cache holds the previous entries plus this token's keys/values.
            past_key_values = outputs.past_key_values
            pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
            generated_ids.append(pred_token_idx.item())
            if generated_ids[-1] == tokenizer.eos_token_id:
                break

    if on_cuda:
        ender.record()
        torch.cuda.synchronize(device)
        decoding_time = starter.elapsed_time(ender)
        after_memory = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        decoding_memory = after_memory - before_memory
    else:
        decoding_time = (time.perf_counter() - start_time) * 1000.0
        decoding_memory = 0
    return torch.tensor(generated_ids, dtype=torch.int), decoding_time, decoding_memory

def _is_correct(final_answer, answer):
    """Return True if ``final_answer`` is listed in the answer's raw or normalized aliases."""
    return (final_answer in answer["aliases"]) or (final_answer in answer["normalized_aliases"])


def main(model_name, dataset_name, dataset_split, num_examples, num_docs, use_cache, cache_content):
    """Evaluate QA accuracy with retrieved documents on TriviaQA and print it.

    Example selection:
    - Examples with fewer than ``num_docs`` retrieved documents are skipped.
    - Otherwise ``num_docs`` distinct documents are sampled and formatted as
      the context.
    - Examples whose tokenized context reaches 3000 tokens are skipped.
    - The first ``num_examples`` remaining examples are answered and scored
      against the answer aliases.

    Args:
        model_name: model id passed to ``from_pretrained``.
        dataset_name: only ``"mandarjoshi/trivia_qa"`` is supported.
        dataset_split: split of the loaded dataset to evaluate.
        num_examples: number of examples to score (int or numeric string).
        num_docs: documents sampled per example (int or numeric string).
        use_cache: truthy value, or a string such as ``"true"``/``"1"``.
            - If set, the context is encoded with ``get_past_key_values`` and
              the answer is generated from the query on top of that cache.
            - Otherwise only the query is given to the model.
        cache_content: forwarded to ``get_past_key_values``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    # Command-line values may arrive as strings.
    num_examples = int(num_examples)
    num_docs = int(num_docs)
    if isinstance(use_cache, str):
        use_cache = use_cache.strip().lower() in ("1", "true", "yes", "y")
    if dataset_name != "mandarjoshi/trivia_qa":
        # The fields read below (search_results, answer aliases) follow the TriviaQA schema.
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    max_context_tokens = 3000
    device = torch.device("cuda")
    random.seed(42)

    dataset = load_dataset(dataset_name, data_files="rc/train-00001-of-00026.parquet")
    test_examples = dataset[dataset_split]
    model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True, cache_dir='/network/scratch/x/xiyuan.zou/cache/transformers_cache')
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    num_correct = 0
    num_total = 0
    for test_example in tqdm(test_examples):
        search_contexts = test_example["search_results"]["search_context"]
        if len(search_contexts) < num_docs:
            continue  # not enough retrieved docs
        # Sample distinct documents; random.choices would sample with replacement.
        retrieved_docs = random.sample(search_contexts, k=num_docs)
        query_context = construct_query_context(retrieved_docs).replace('\\n', '\n').rstrip('\n')
        input_ids_context = tokenizer(query_context, return_tensors="pt").input_ids.to(device)
        if len(input_ids_context[0]) >= max_context_tokens:
            continue  # context too long

        query = construct_query(test_example["question"]).replace('\\n', '\n').rstrip('\n')
        if use_cache:
            selected_past_key_values = get_past_key_values(input_ids_context, cache_content, model)
            input_ids_query = tokenizer(query, return_tensors="pt").input_ids.to(device)
            result = inference_with_past_key_values(input_ids_query, selected_past_key_values, tokenizer, model)
            print(result)
        else:
            # NOTE(review): the retrieved context is never given to the model
            # here; it is only tokenized to apply the same length filter.
            # Confirm that a closed-book baseline is intended.
            query_words_lst = nltk.word_tokenize(query)
            input_ids_query, _ = customized_tokenize(tokenizer, query_words_lst)
            input_ids_query = input_ids_query.to(device)
            result = inference_with_past_key_values(input_ids_query, None, tokenizer, model)

        final_answer = get_the_final_answer(result)
        if _is_correct(final_answer, test_example["answer"]):
            num_correct += 1
        num_total += 1
        if num_total >= num_examples:
            break

    if num_total == 0:
        print("Accuracy: no examples evaluated")
    else:
        print("Accuracy:", num_correct / num_total)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', dest='model_name', action='store', required=False, default='meta-llama/Llama-2-7b-hf')
    parser.add_argument('--dataset_name', dest='dataset_name', action='store', required=False, default='mandarjoshi/trivia_qa')
    parser.add_argument('--dataset_split', dest='dataset_split', action='store', required=False, default='train')
    # Number of test examples to evaluate.
    parser.add_argument('--num_examples', dest='num_examples', action='store', type=int, required=False, default=100)
    # Number of retrieved docs sampled for each test query.
    parser.add_argument('--num_docs', dest='num_docs', action='store', type=int, required=False, default=1)
    # Whether to encode the context into a KV cache first; accepts true/false-style strings.
    parser.add_argument("--use_cache", dest='use_cache', action='store', required=False, default=False,
                        type=lambda s: s.strip().lower() in ("1", "true", "yes", "y"))
    # Which part of the context to put into the cache. all: all the context tokens; none: no context tokens.
    parser.add_argument('--cache_content', dest='cache_content', action='store', required=False, default='all')

    args = vars(parser.parse_args())
    main(**args)