import sys
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers
import argparse
import random
from tqdm import tqdm
import math
from utils import normalize_answer, customized_tokenize, is_stopword, flatten, get_peak_memory, get_flops, token_f1_score
from cache_based_inference.pos_shift.modify_llama import enable_llama_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention
from cache_based_inference.pos_shift.modify_falcon import enable_falcon_pos_shift_attention
from cache_based_inference.pos_shift.modify_mistral import enable_mistral_pos_shift_attention
import numpy as np
import gc
import json
from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Maps each LongBench dataset name to the metric `scorer_e` uses for it.
# Every metric is called as metric(prediction, ground_truth, all_classes=...)
# and scorer_e keeps the best score over all ground truths of an example.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}
def scorer_e(dataset, predictions, answers, lengths, all_classes, metrics=None):
    """Score predictions per LongBench-E context-length bucket.

    Each prediction is compared with every reference answer using the dataset's
    metric, and the best score is kept. For "trec", "triviaqa", "samsum" and
    "lsht", only the first line of the prediction (after stripping leading
    newlines) is scored.

    Args:
        dataset: dataset name; key into the metric table.
        predictions: model outputs, one string per example.
        answers: per example, an iterable of reference answers.
        lengths: per example, the length used for bucketing
            (< 4000 -> "0-4k", < 8000 -> "4-8k", otherwise "8k+").
        all_classes: forwarded to the metric as ``all_classes=``.
        metrics: optional mapping dataset name -> metric callable. Defaults to
            the module-level ``dataset2metric`` table.

    Returns:
        dict with keys "0-4k", "4-8k", "8k+". Each value is the mean score
        * 100, rounded to 2 decimals, or nan when the bucket is empty.
    """
    metric_table = dataset2metric if metrics is None else metrics
    metric = metric_table[dataset]
    first_line_only = dataset in ("trec", "triviaqa", "samsum", "lsht")

    buckets = {"0-4k": [], "4-8k": [], "8k+": []}
    for prediction, ground_truths, length in zip(predictions, answers, lengths):
        if first_line_only:
            prediction = prediction.lstrip('\n').split('\n')[0]
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, metric(prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            buckets["0-4k"].append(score)
        elif length < 8000:
            buckets["4-8k"].append(score)
        else:
            buckets["8k+"].append(score)

    # np.mean([]) warns and returns nan; return nan for empty buckets explicitly.
    return {
        key: round(100 * np.mean(values), 2) if values else float("nan")
        for key, values in buckets.items()
    }

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # Seed the current CUDA device and then every visible CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning so kernel selection is reproducible run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_longbench(datasets, prompt_path="longbench/dataset2prompt.json"):
    """Load the LongBench-E test split and prompt template of each dataset.

    Args:
        datasets: dataset names; each is loaded as the ``f"{name}_e"`` config of
            ``THUDM/LongBench``.
        prompt_path: JSON file mapping dataset name -> prompt format string.

    Returns:
        dict name -> {"data": list of examples, "prompt_format": str}.

    Raises:
        KeyError: if a dataset has no entry in the prompt file.
    """
    # The prompt file is JSON (UTF-8 by spec); don't rely on the platform default
    # encoding, and close the handle deterministically.
    with open(prompt_path, "r", encoding="utf-8") as prompt_file:
        dataset2prompt = json.load(prompt_file)
    ret_data = {}
    for dataset in datasets:
        data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
        ret_data[dataset] = {"data": list(data), 'prompt_format': dataset2prompt[dataset]}
    return ret_data

def process_longbench(test_example, tokenizer, dataset_name):
    """Split a formatted prompt into context and trailing query, then tokenize both.

    The prompt is split on a dataset-specific separator, and the last few chunks
    form the query:
      samsum: last 1 line; trivia: last 6 lines; trec: last 2 lines;
      qasper / passage_retrieval: last 3 paragraphs; otherwise last 2 paragraphs.
    The first matching rule (in that order, by substring of ``dataset_name``) wins.

    Returns:
        (context input_ids, query input_ids), each produced by
        ``tokenizer(text, return_tensors="pt").input_ids``.
    """
    # (substring of dataset_name, separator, number of trailing chunks in the query)
    split_rules = (
        ("samsum", "\n", 1),
        ("trivia", "\n", 6),
        ("trec", "\n", 2),
        ("qasper", "\n\n", 3),
        ("passage_retrieval", "\n\n", 3),
    )
    matched, separator, query_chunks = None, "\n\n", 2
    for marker, rule_separator, rule_chunks in split_rules:
        if marker in dataset_name:
            matched, separator, query_chunks = marker, rule_separator, rule_chunks
            break

    pieces = test_example.split(separator)
    source = separator.join(pieces[:-query_chunks])
    query = separator.join(pieces[-query_chunks:])
    if matched == "passage_retrieval":
        print("query = ", query)

    context_ids = tokenizer(source, return_tensors="pt").input_ids
    query_ids = tokenizer(query, return_tensors="pt").input_ids
    return context_ids, query_ids

def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat template for ``model_name`` (matched by substring).

    chatglm3/chatglm delegate to the tokenizer's own helpers. llama2/mistral,
    xgen and internlm use fixed string templates. Any other model gets the
    prompt back unchanged.
    """
    if "chatglm3" in model_name:
        return tokenizer.build_chat_input(prompt)
    if "chatglm" in model_name:
        return tokenizer.build_prompt(prompt)
    if "llama2" in model_name or 'mistral' in model_name:
        return f"[INST]{prompt}[/INST]"
    if "xgen" in model_name:
        system_header = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        return system_header + f" ### Human: {prompt}\n###"
    if "internlm" in model_name:
        return f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
    return prompt

@torch.no_grad()
def get_past_key_values(input_ids_seg1, input_ids_seg2, input_ids_query, cache_type, model, k=None, past_key_values=None, with_mean_attention=False, context_importance=None):
    """Encode a new segment on top of an optional cache and return the (pruned) cache.

    The model runs on ``input_ids`` = ``[seg1, seg2]`` (concatenated on the last
    dim) when ``input_ids_seg1`` is given, otherwise on ``seg2`` alone, on top of
    ``past_key_values``. The cache produced by that pass is what gets returned.

    cache_type:
      * ``"all"``: keep every cache position (the past, then all new tokens).
      * ``"topk"``: per layer, keep the ``k`` highest-scoring "old" positions plus
        every position of ``seg2``. Old positions are ``seg1`` when there is no
        past, otherwise the whole past cache. A second pass over ``[seg1, query]``
        (or ``query`` on the same past) provides query attention. The score of an
        old position is
        ``context_importance * attn_from_seg2 + (1 - context_importance) * attn_from_query``,
        where attention is head-averaged and taken from the last row, or
        additionally averaged over all new-token rows when ``with_mean_attention``.
        Only batch element 0's attention is used, so batch size 1 is assumed.

    Cache tensors are indexed with the sequence on dim 2 (``kv[:, :, idx, :]``).

    Returns:
        (selected_past_key_values, pred_token_idx, given_indices, logits):
        the kept per-layer (key, value) tuples; the argmax token at the last
        position with shape (batch, 1); the kept indices (a ``range`` for "all",
        ``flatten`` of the per-layer sorted index tensors for "topk"); and the
        logits of the main forward pass.

    Raises:
        ValueError: unknown ``cache_type``, or for "topk" when neither / both of
            ``input_ids_seg1`` and ``past_key_values`` are given.
        TypeError: for "topk" when ``k`` is not an int.
    """
    if cache_type not in ("all", "topk"):
        raise ValueError(f"unsupported cache_type: {cache_type!r}")
    if cache_type == "topk":
        # Validate before running the (expensive) forward passes.
        if not isinstance(k, int):
            raise TypeError(f"k must be an int for cache_type='topk', got {type(k).__name__}")
        if past_key_values is None and input_ids_seg1 is None:
            raise ValueError("cache_type='topk' needs input_ids_seg1 or past_key_values")
        if past_key_values is not None and input_ids_seg1 is not None:
            raise ValueError("input_ids_seg1 must be None when past_key_values is given")

    if input_ids_seg1 is not None:
        input_ids = torch.cat([input_ids_seg1, input_ids_seg2], dim=-1)
        input_ids_2 = torch.cat([input_ids_seg1, input_ids_query], dim=-1)
    else:
        input_ids = input_ids_seg2
        input_ids_2 = input_ids_query

    past_key_values_org = past_key_values
    past_len = 0 if past_key_values_org is None else past_key_values_org[0][0].size(2)

    outputs = model(input_ids, output_attentions=True, use_cache=True, past_key_values=past_key_values_org)
    new_past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

    if cache_type == "all":
        # Keep the whole cache: past positions first, then every new token.
        # (Using only len(input_ids) here would drop the new tokens whenever a past exists.)
        given_indices = range(past_len + input_ids.size(1))
        selected_past_key_values = tuple(tuple(kv[:, :, given_indices, :] for kv in layer) for layer in new_past_key_values)
        return selected_past_key_values, pred_token_idx, given_indices, outputs.logits

    # topk: the query pass is only needed for scoring, so it runs only here.
    attention_from_query = model(input_ids_2, output_attentions=True, use_cache=True, past_key_values=past_key_values_org).attentions
    if past_key_values_org is None:
        n_old = input_ids_seg1.size(1)  # seg1 positions are the candidates
        first_new_row = n_old           # rows of seg2 / query tokens in this pass
    else:
        n_old = past_len                # all past positions are the candidates
        first_new_row = 0               # every row of this pass is a new token
    n_new = input_ids_seg2.size(1)

    given_indices_for_each_layer = []
    for attn_ctx, attn_query in zip(outputs.attentions, attention_from_query):
        if with_mean_attention:
            weights_ctx = attn_ctx[0, :, first_new_row:, :n_old].mean(dim=0).mean(dim=0)
            weights_query = attn_query[0, :, first_new_row:, :n_old].mean(dim=0).mean(dim=0)
        else:
            weights_ctx = attn_ctx[0, :, -1, :n_old].mean(dim=0)
            weights_query = attn_query[0, :, -1, :n_old].mean(dim=0)
        att_weights = context_importance * weights_ctx + (1 - context_importance) * weights_query
        _, topk_indices = torch.topk(att_weights, k=min(k, att_weights.size(0)))
        # Build on the layer's own device so multi-GPU device maps don't mix devices in cat().
        new_positions = torch.arange(n_old, n_old + n_new, dtype=torch.long, device=topk_indices.device)
        layer_indices, _ = torch.sort(torch.cat([topk_indices.long(), new_positions]))
        given_indices_for_each_layer.append(layer_indices)

    selected_past_key_values = tuple(
        tuple(kv[:, :, layer_indices, :] for kv in layer)
        for layer, layer_indices in zip(new_past_key_values, given_indices_for_each_layer)
    )
    given_indices = flatten(given_indices_for_each_layer)
    return selected_past_key_values, pred_token_idx, given_indices, outputs.logits

@torch.no_grad()
def decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None):
    """Generate an answer continuing from a prefilled KV cache and measure its cost.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: used only for its ``eos_token_id`` when ``end_token`` is None.
        past_key_values: per-layer (key, value) cache with the sequence on dim -2.
        pred_token_idx: (batch, 1) ids of the next token, fed as ``input_ids``.
        max_gen_len: ``max_new_tokens`` for generation.
        num_beams: beam count; ``do_sample`` is False (greedy when 1).
            NOTE(review): with num_beams > 1 the prefilled cache may need to be
            expanded per beam — confirm against the transformers version in use.
        end_token: stop token id; defaults to the tokenizer's EOS id.

    Returns:
        (first generated sequence as a CPU tensor, as returned by ``generate``;
         elapsed time in milliseconds from CUDA events;
         difference between ``get_peak_memory`` readings after and before).
    """
    end_token = tokenizer.eos_token_id if end_token is None else end_token
    device = model.device
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # Measure the inference time of one document (batch_size=1) with CUDA events.
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    starter.record()

    # The mask covers every cached position plus the token being fed in. Build it on
    # the input ids' device with the integer dtype tokenizers emit, rather than a
    # float CPU tensor.
    past_len = past_key_values[0][0].size(-2)
    attention_mask = torch.ones(
        (pred_token_idx.size(0), past_len + 1), dtype=torch.long, device=pred_token_idx.device
    )
    outputs = model.generate(
        input_ids=pred_token_idx,
        attention_mask=attention_mask,
        do_sample=False,
        num_beams=num_beams,
        eos_token_id=end_token,
        max_new_tokens=max_gen_len,
        past_key_values=past_key_values,
    )

    ender.record()
    torch.cuda.synchronize()  # elapsed_time is only valid once both events completed
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return outputs[0].detach().cpu(), decoding_time, decoding_memory

def _load_model(model_name, quantization_type, device_map):
    """Load the causal LM in full precision or with 4-/8-bit quantized weights.

    Raises:
        ValueError: if ``quantization_type`` is not "none", "4bit" or "8bit".
    """
    quantization_kwargs = {
        "none": {},
        "4bit": {"load_in_4bit": True},
        "8bit": {"load_in_8bit": True},
    }
    if quantization_type not in quantization_kwargs:
        raise ValueError(f"unsupported quantization_type: {quantization_type!r}")
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        output_attentions=True,
        device_map=device_map,
        cache_dir='/network/scratch/x/xiyuan.zou/cache/transformers_cache',
        **quantization_kwargs[quantization_type],
    )


def _enable_pos_shift(model):
    """Patch the model's attention with the position-shift variant for its architecture.

    "mpt" and "btlm" models are left unpatched; any other unknown ``model_type``
    raises ValueError.
    """
    model_type = model.config.model_type
    if "llama" in model_type:
        enable_llama_pos_shift_attention(model)
    elif "gpt_neox" in model_type:
        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model_type:
        enable_falcon_pos_shift_attention(model)
    elif "mistral" in model_type:
        enable_mistral_pos_shift_attention(model)
    elif "mpt" in model_type or "btlm" in model_type:
        pass  # no pos-shift patch is wired up for these; run them unmodified
    else:
        raise ValueError(f"got {model_type}")


def _encode_context(model, context_ids, query_ids, segment_length, cache_type, k, with_mean_attention, context_importance):
    """Stage 1 (document understanding): stream the context and return its KV cache.

    A context that fits in one segment is cached whole ("all"). Otherwise segments
    0 and 1 are encoded together, and each later segment is encoded on top of the
    cache returned by the previous step; ``cache_type`` decides what is kept.
    """
    seq_len = context_ids.size(1)
    print(f"seq_len: {seq_len}")
    num_segments = math.ceil(seq_len / segment_length)
    if num_segments <= 1:
        past, _, _, _ = get_past_key_values(None, context_ids, query_ids, "all", model, k, with_mean_attention=with_mean_attention, context_importance=context_importance)
        return past

    # Slicing past the end is clipped by PyTorch, so no explicit min(seq_len, ...) is needed.
    past, _, _, _ = get_past_key_values(
        context_ids[:, :segment_length],
        context_ids[:, segment_length:2 * segment_length],
        query_ids, cache_type, model, k,
        with_mean_attention=with_mean_attention, context_importance=context_importance,
    )
    for idx in tqdm(range(2, num_segments)):
        segment = context_ids[:, idx * segment_length:(idx + 1) * segment_length]
        past, _, _, _ = get_past_key_values(None, segment, query_ids, cache_type, model, k, past_key_values=past, with_mean_attention=with_mean_attention, context_importance=context_importance)
    return past


def _prepare_query_cache(model, context_past, query_ids, cache_type, k, with_mean_attention, context_importance):
    """Stage 2 (instruction following): prune the context cache against the query, then append the query.

    The first pass selects context positions using the query and keeps the query's
    own entries at the end of the cache. Those entries are sliced off, and the query
    is encoded again on the pruned cache to get its entries and the first predicted
    token.

    Returns:
        (cache including the query, (batch, 1) first predicted token ids).
    """
    past, _, _, _ = get_past_key_values(None, query_ids, query_ids, cache_type, model, k, past_key_values=context_past, with_mean_attention=with_mean_attention, context_importance=context_importance)
    n_query = query_ids.size(1)
    # Slice by explicit end index: `:-n_query` would empty the cache when n_query == 0.
    past = tuple(tuple(kv[:, :, :kv.size(2) - n_query, :] for kv in layer) for layer in past)
    past, pred_token_idx, _, _ = get_past_key_values(None, query_ids, query_ids, cache_type, model, k, past_key_values=past, with_mean_attention=with_mean_attention, context_importance=context_importance)
    return past, pred_token_idx


def main(model_name, quantization_type, device_map, segment_length, num_beams, cache_type, k, with_mean_attention, context_importance):
    """Evaluate ``model_name`` on six LongBench-E datasets and print per-length-bucket scores.

    Args:
        model_name: Hugging Face model id or path.
        quantization_type: "none", "4bit" or "8bit".
        device_map: forwarded to ``from_pretrained``.
        segment_length: context tokens encoded per step in stage 1.
        num_beams: beam count for decoding.
        cache_type: "topk" or "all" (see ``get_past_key_values``).
        k: cached context positions kept per layer for "topk".
        with_mean_attention: bool, or a CLI string where only "True"/"true" mean True.
        context_importance: weight of segment attention vs query attention.
    """
    if not isinstance(with_mean_attention, bool):
        with_mean_attention = with_mean_attention in ["True", "true"]

    device = torch.device("cuda")
    seed_everything(42)
    datasets = ["qasper", "multifieldqa_en", "hotpotqa", "trec", "triviaqa", "samsum"]
    datasets_load = load_longbench(datasets)
    with open("longbench/dataset2maxlen.json", "r", encoding="utf-8") as maxlen_file:
        dataset2maxlen = json.load(maxlen_file)

    model = _load_model(model_name, quantization_type, device_map)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    nxtline_id = tokenizer.convert_tokens_to_ids('<0x0A>')

    # inference
    model.eval()
    _enable_pos_shift(model)

    for dataset_name in datasets:
        test_examples = datasets_load[dataset_name]['data']
        dataset_prompt_format = datasets_load[dataset_name]['prompt_format']
        max_gen_length = dataset2maxlen[dataset_name]
        system_summaries = []
        reference_summaries = []
        all_classes = []
        lengths = []
        end_token = nxtline_id if dataset_name == "samsum" else tokenizer.eos_token_id

        for test_example in tqdm(test_examples):
            test_example_str = dataset_prompt_format.format(**test_example)
            if dataset_name not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:  # chat models are better off without build prompts on these tasks
                test_example_str = build_chat(tokenizer, test_example_str, model_name)
            input_ids_query_context, input_ids_query = process_longbench(test_example_str, tokenizer, dataset_name)
            # Drop the first query token; presumably the BOS the tokenizer prepends
            # to the standalone query — confirm for non-llama tokenizers.
            input_ids_query = input_ids_query[:, 1:]
            # Only the last example's all_classes reaches the scorer; assumes every
            # example of a dataset shares the same classes — verify.
            all_classes = test_example["all_classes"]
            lengths.append(test_example["length"])
            input_ids_query_context, input_ids_query = input_ids_query_context.to(device), input_ids_query.to(device)

            context_past_key_values = _encode_context(model, input_ids_query_context, input_ids_query, segment_length, cache_type, k, with_mean_attention, context_importance)
            selected_past_key_values, pred_token_idx = _prepare_query_cache(model, context_past_key_values, input_ids_query, cache_type, k, with_mean_attention, context_importance)
            generated_ids, _, _ = decoding(model, tokenizer, selected_past_key_values, pred_token_idx, max_gen_length, num_beams, end_token=end_token)
            final_answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            # evaluate answers
            print("final answer = ", final_answer)
            print("reference answer = ", test_example['answers'])
            system_summaries.append(final_answer)
            reference_summaries.append(test_example['answers'])

        result = scorer_e(dataset_name, system_summaries, reference_summaries, lengths, all_classes)

        print("Experimental Settings")
        print("Model name:", model_name)
        print("Dataset name:", dataset_name)
        print("Quantization:", quantization_type)
        print("Cache type", cache_type)
        print("K", k)
        print("Segment length", segment_length)
        print("Context importance", context_importance)
        print("Mean attention", with_mean_attention)
        print("-----------------------------------------")
        print("Experimental Results")
        print("results = ", result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a causal LM on LongBench-E with segment-wise KV-cache selection."
    )
    # model config
    parser.add_argument('--model_name', dest='model_name', action='store', required=False, default='mistralai/Mistral-7B-Instruct-v0.1',
                        help="Hugging Face model id or local path.")
    parser.add_argument('--quantization_type', dest='quantization_type', action='store', required=False, default='none',
                        choices=['none', '4bit', '8bit'], help="Weight quantization used when loading the model.")
    parser.add_argument('--device_map', dest='device_map', action='store', required=False, default="auto",
                        help="device_map passed to from_pretrained.")
    parser.add_argument('--num_beams', dest='num_beams', action='store', required=False, default=1, type=int,
                        help="Beam count for decoding (1 = greedy).")
    parser.add_argument('--segment_length', dest='segment_length', action='store', required=False, default=1024, type=int,
                        help="Context tokens encoded per step.")
    parser.add_argument("--cache_type", dest='cache_type', action='store', required=False, default='topk',
                        choices=['topk', 'all'], help="Keep the top-k attended cache positions, or keep all of them.")
    parser.add_argument('--k', dest='k', action='store', required=False, default=1024, type=int,
                        help="Cache positions kept per layer when cache_type=topk.")
    parser.add_argument('--with_mean_attention', dest='with_mean_attention', action='store', required=False, default='True', type=str,
                        help='Only "True"/"true" enable averaging attention over all new-token rows.')
    parser.add_argument("--context_importance", dest='context_importance', action='store', required=False, default=0, type=float,
                        help="Weight of segment attention vs query attention when scoring cache positions.")
    args = parser.parse_args()
    main(**vars(args))