import sys
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
import transformers
import argparse
import random
from tqdm import tqdm
import math
from utils import normalize_answer, customized_tokenize, is_stopword, flatten, get_peak_memory, get_flops, token_f1_score
from attention_sinks import AutoModelForCausalLM
import numpy as np
import gc
import json
from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Lookup table from dataset name to the scoring function applied to its
# predictions. scorer_e indexes this table by dataset name and calls the
# selected function as metric(prediction, ground_truth, all_classes=...).
# All scoring functions are imported from the local `metrics` module.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}
def scorer_e(dataset, predictions, answers, lengths, all_classes):
    """Score predictions and average them per context-length bucket.

    Each prediction is scored against every one of its ground truths with
    ``dataset2metric[dataset]`` and the best score (never below 0) is kept.
    The sample is then filed under "0-4k" (length < 4000), "4-8k"
    (length < 8000) or "8k+" according to its entry in ``lengths``.

    Args:
        dataset: Dataset name; key into ``dataset2metric``.
        predictions: Iterable of predicted strings.
        answers: Iterable of ground-truth collections, one per prediction.
        lengths: Iterable of numeric sample lengths, one per prediction.
        all_classes: Forwarded unchanged to the metric as ``all_classes``.

    Returns:
        Dict with the keys "0-4k", "4-8k" and "8k+". Each value is the mean
        score of the bucket times 100, rounded to two decimals, or NaN when
        no sample fell into that bucket.
    """
    buckets = {"0-4k": [], "4-8k": [], "8k+": []}
    for prediction, ground_truths, length in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ("trec", "triviaqa", "samsum", "lsht"):
            # Only the first non-empty-prefixed line of the output is scored
            # for these datasets.
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            buckets["0-4k"].append(score)
        elif length < 8000:
            buckets["4-8k"].append(score)
        else:
            buckets["8k+"].append(score)
    # np.mean([]) emits a RuntimeWarning before returning NaN; report the
    # empty bucket as NaN explicitly instead of relying on that side effect.
    return {
        name: round(100 * np.mean(values), 2) if values else float("nan")
        for name, values in buckets.items()
    }

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also turns off cuDNN benchmarking and turns on its deterministic flag.
    """
    # cuDNN flags: no benchmark mode, deterministic mode on.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The individual seeding calls are independent of one another.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def load_longbench(datasets):
    """Load the test split and prompt template of each requested dataset.

    Prompt templates are read from ``longbench/dataset2prompt.json``. For
    every name in ``datasets`` the ``"<name>_e"`` configuration of
    ``THUDM/LongBench`` is loaded with ``split='test'``.

    Args:
        datasets: Iterable of dataset names; each must be a key of the
            prompt JSON file.

    Returns:
        Dict mapping each dataset name to
        ``{"data": <list of samples>, "prompt_format": <template>}``.

    Raises:
        KeyError: If a dataset name has no entry in the prompt file.
    """
    # Context manager so the file handle is closed deterministically.
    with open("longbench/dataset2prompt.json", "r", encoding="utf-8") as prompt_file:
        dataset2prompt = json.load(prompt_file)

    ret_data = {}
    for dataset in datasets:
        data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
        prompt_format = dataset2prompt[dataset]
        # Materialise the dataset into a plain list of samples.
        ret_data[dataset] = {"data": list(data), 'prompt_format': prompt_format}
    return ret_data

def process_longbench(test_example, tokenizer, dataset_name):
    """Split a formatted prompt into context and query, then tokenize both.

    The prompt is split on a dataset-specific separator; the last few pieces
    form the query and everything before them forms the context.

    Args:
        test_example: The fully formatted prompt string.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result exposes ``input_ids``.
        dataset_name: Dataset name; matched by substring against the rules.

    Returns:
        Tuple ``(context_input_ids, query_input_ids)``.
    """
    # (name fragment, separator, number of trailing pieces forming the query).
    # Rules are tried in order and the first matching fragment wins.
    split_rules = (
        ('samsum', "\n", 1),
        ('trivia', "\n", 6),
        ('trec', "\n", 2),
        ('qasper', "\n\n", 3),
        ("passage_retrieval", "\n\n", 3),
    )

    # Fallback when no rule matches: last two blank-line-separated pieces.
    matched_fragment = None
    separator, tail_count = "\n\n", 2
    for fragment, rule_separator, rule_tail in split_rules:
        if fragment in dataset_name:
            matched_fragment = fragment
            separator, tail_count = rule_separator, rule_tail
            break

    pieces = test_example.split(separator)
    source = separator.join(pieces[:-tail_count])
    query = separator.join(pieces[-tail_count:])

    if matched_fragment == "passage_retrieval":
        print("query = ", query)

    input_ids_query_context = tokenizer(source, return_tensors="pt").input_ids
    input_ids_query = tokenizer(query, return_tensors="pt").input_ids
    return input_ids_query_context, input_ids_query

def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat template selected by ``model_name``.

    Selection is by substring match on ``model_name``. When no known
    substring is present the prompt is returned unchanged.

    Args:
        tokenizer: Only used for the chatglm variants, where
            ``build_chat_input`` / ``build_prompt`` is called on it.
        prompt: The raw prompt.
        model_name: Model identifier to match against.

    Returns:
        The wrapped prompt (or whatever the tokenizer helper returns for the
        chatglm variants).
    """
    # "chatglm3" has to be checked first: "chatglm" is a substring of it.
    if "chatglm3" in model_name:
        return tokenizer.build_chat_input(prompt)
    if "chatglm" in model_name:
        return tokenizer.build_prompt(prompt)
    if any(tag in model_name for tag in ("llama2", 'mistral')):
        return f"[INST]{prompt}[/INST]"
    if "xgen" in model_name:
        preamble = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        return preamble + f" ### Human: {prompt}\n###"
    if "internlm" in model_name:
        return f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
    return prompt

@torch.no_grad()
def decoding(model, tokenizer, past_key_values, pred_token_idx, max_gen_len, num_beams, end_token=None):
    """Generate a continuation from a pre-filled KV cache and profile it.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Supplies ``eos_token_id`` when ``end_token`` is None.
        past_key_values: KV cache from prompt processing; the cached length
            is read from ``past_key_values[0][0].size(-2)``.
        pred_token_idx: Token ids passed to ``generate`` as ``input_ids``.
        max_gen_len: Passed to ``generate`` as ``max_new_tokens``.
        num_beams: Passed to ``generate`` as ``num_beams``.
        end_token: Token id that stops generation; defaults to
            ``tokenizer.eos_token_id``.

    Returns:
        Tuple ``(token_ids, decoding_time, decoding_memory)``:
        the first returned sequence moved to CPU, the time between the two
        CUDA events (``elapsed_time`` reports milliseconds), and the
        difference of ``get_peak_memory(device)`` after and before generation
        (units are whatever ``get_peak_memory`` reports).
    """
    end_token = tokenizer.eos_token_id if end_token is None else end_token
    device = model.device

    # The mask covers every cached position plus the one new input token.
    # Build it as an integer mask on the same device as the input ids rather
    # than as a float CPU tensor, so no cross-device handling is needed.
    cached_len = past_key_values[0][0].size(-2)
    attention_mask = torch.ones(
        (1, cached_len + 1), dtype=torch.long, device=pred_token_idx.device
    )

    # Reset the peak counter on the model's device so the measurement below
    # reflects only this call.
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    before_memory = get_peak_memory(device)
    # measure the average inference time of a doc with batch_size=1
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    starter.record()

    outputs = model.generate(
        input_ids=pred_token_idx,
        attention_mask=attention_mask,
        do_sample=False,
        num_beams=num_beams,
        eos_token_id=end_token,
        max_new_tokens=max_gen_len,
        past_key_values=past_key_values,
    )

    ender.record()
    # Wait for the GPU so that both events have completed before reading them.
    torch.cuda.synchronize()
    decoding_time = starter.elapsed_time(ender)
    after_memory = get_peak_memory(device)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()
    decoding_memory = after_memory - before_memory
    return outputs[0].detach().cpu(), decoding_time, decoding_memory
def main(model_name, quantization_type, device_map, max_gen_length, attention_sink_size, attention_sink_window_size, segment_length, num_beams):
    """Evaluate an attention-sink model on a fixed set of LongBench-E tasks.

    For each dataset the prompt is fed through the model in two phases: the
    first ``attention_sink_size + attention_sink_window_size`` tokens in one
    forward pass, then the rest in chunks of ``segment_length`` tokens while
    reusing the KV cache. The answer is generated with ``decoding`` and all
    answers of a dataset are scored with ``scorer_e``. Settings and scores
    are printed per dataset.

    Args:
        model_name: Name/path passed to ``from_pretrained`` for model and
            tokenizer; also selects the chat template in ``build_chat``.
        quantization_type: One of ``'none'``, ``'4bit'``, ``'8bit'``.
        device_map: Passed to ``from_pretrained``.
        max_gen_length: Fallback generation length for datasets that have no
            entry in ``longbench/dataset2maxlen.json``.
        attention_sink_size: Passed to ``from_pretrained``; part of the
            expected cache size.
        attention_sink_window_size: Passed to ``from_pretrained``; part of
            the expected cache size.
        segment_length: Number of tokens per forward pass after the first.
        num_beams: Beam count passed to ``decoding``.

    Raises:
        ValueError: If ``quantization_type`` is unknown or
            ``segment_length`` is smaller than 1.
    """
    # Validate cheap arguments first, before any dataset or model loading.
    quantization_kwargs = {
        'none': {},
        '4bit': {'load_in_4bit': True},
        '8bit': {'load_in_8bit': True},
    }
    if quantization_type not in quantization_kwargs:
        raise ValueError(
            f"Unknown quantization_type {quantization_type!r}; "
            f"expected one of {sorted(quantization_kwargs)}"
        )
    if segment_length < 1:
        raise ValueError(f"segment_length must be >= 1, got {segment_length}")

    device = torch.device("cuda")
    seed_everything(42)
    datasets = ["qasper", "multifieldqa_en", "hotpotqa", "trec", "triviaqa", "samsum"]
    datasets_load = load_longbench(datasets)
    with open("longbench/dataset2maxlen.json", "r", encoding="utf-8") as maxlen_file:
        dataset2maxlen = json.load(maxlen_file)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_attentions=True,
        device_map=device_map,
        cache_dir='/network/scratch/x/xiyuan.zou/cache/transformers_cache',
        attention_sink_size=attention_sink_size,
        attention_sink_window_size=attention_sink_window_size,
        **quantization_kwargs[quantization_type],
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Id of the '<0x0A>' token, used as the stop token for samsum. Resolved
    # once here so it is always defined before the dataset loop reads it.
    nxtline_id = tokenizer.convert_tokens_to_ids('<0x0A>')
    # Number of leading tokens processed in the first forward pass; also the
    # cache length asserted after every later segment.
    cache_capacity = attention_sink_size + attention_sink_window_size

    #inference
    model.eval()
    for dataset_name in datasets:
        test_examples = datasets_load[dataset_name]['data']
        dataset_prompt_format = datasets_load[dataset_name]['prompt_format']
        # Per-dataset limit from the JSON file; the function argument is only
        # used when the file has no entry for this dataset.
        dataset_max_gen_length = dataset2maxlen.get(dataset_name, max_gen_length)
        system_summaries = []
        reference_summaries = []
        all_classes = []
        lengths = []
        end_token = nxtline_id if dataset_name == "samsum" else tokenizer.eos_token_id

        for test_example in tqdm(test_examples):
            test_example_str = dataset_prompt_format.format(**test_example)
            if dataset_name not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
                test_example_str = build_chat(tokenizer, test_example_str, model_name)
            input_ids_query_context, input_ids_query = process_longbench(test_example_str, tokenizer, dataset_name)
            # Drop the first query token before concatenation (presumably the
            # BOS token the tokenizer prepends -- verify for other tokenizers).
            input_ids_query = input_ids_query[:, 1:]
            all_classes = test_example["all_classes"]
            lengths.append(test_example["length"])
            input_ids = torch.cat([input_ids_query_context, input_ids_query], dim=-1)
            input_ids = input_ids.to(device)

            with torch.no_grad():
                #prompt processing
                outputs = model(input_ids[:, :cache_capacity], use_cache=True)
                past_key_values = outputs.past_key_values
                token_idx = cache_capacity
                remaining_tokens = max(0, input_ids.shape[-1] - token_idx)
                # Integer, non-negative total; the bar is closed on exit.
                with tqdm(total=math.ceil(remaining_tokens / segment_length)) as pbar:
                    while token_idx < input_ids.shape[-1]:
                        outputs = model(input_ids[:, token_idx:token_idx+segment_length], past_key_values=past_key_values, use_cache=True)
                        past_key_values = outputs.past_key_values
                        token_idx += segment_length
                        # Internal invariant: the cache length must equal
                        # sink size + window size after each segment.
                        assert past_key_values[0][0].shape[-2] == cache_capacity
                        pbar.update(1)
                # Greedy choice of the first generated token from the last logits.
                pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)

                print("cache size is: ", past_key_values[0][0].shape[2])
                #generation
                generated_ids, decoding_time, decoding_memory = decoding(model, tokenizer, past_key_values, pred_token_idx, dataset_max_gen_length, num_beams, end_token=end_token)
                generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                final_answer = generated_text.strip()

                #evaluate answers
                print("final answer = ", final_answer)
                print("reference answer = ", test_example['answers'])
                system_summaries.append(final_answer)
                reference_summaries.append(test_example['answers'])

        scores = scorer_e(dataset_name, system_summaries, reference_summaries, lengths, all_classes)

        print("Experimental Settings")
        print("Model name:", model_name)
        print("Dataset name:", dataset_name)
        print("Quantization:", quantization_type)
        print("Max gen length", dataset_max_gen_length)
        print('attention_sink_size', attention_sink_size)
        print('attention_sink_window_size', attention_sink_window_size)
        print("Segment length", segment_length)
        print("-----------------------------------------")
        print("Experimental Results")
        print("results = ", scores)


if __name__ == "__main__":
    # NOTE(review): `inf` is not referenced anywhere in the visible code;
    # kept in case something outside this view relies on it.
    inf=999999999
    parser = argparse.ArgumentParser(
        description="Run the evaluation in main() with the given model and cache settings."
    )
    #model config
    parser.add_argument('--model_name', default='mistralai/Mistral-7B-Instruct-v0.1',
                        help="Model name or path forwarded to main().")
    # Restricting the choices makes argparse reject unsupported values up
    # front instead of letting main() fail later.
    parser.add_argument('--quantization_type', default='none', choices=['none', '4bit', '8bit'],
                        help="Quantization mode used when loading the model.")
    parser.add_argument('--device_map', default="auto",
                        help="device_map forwarded to main().")
    parser.add_argument('--max_gen_length', default=50, type=int,
                        help="Generation length limit forwarded to main().")
    parser.add_argument('--num_beams', default=1, type=int,
                        help="Number of beams used for generation.")
    parser.add_argument('--attention_sink_size', default=4, type=int,
                        help="attention_sink_size forwarded to main().")
    parser.add_argument('--attention_sink_window_size', default=1020, type=int,
                        help="attention_sink_window_size forwarded to main().")
    parser.add_argument('--segment_length', default=1, type=int,
                        help="Number of prompt tokens fed per forward pass after the first one.")

    # Argument names match main()'s parameters, so the namespace can be
    # expanded directly into keyword arguments.
    args = parser.parse_args()
    main(**vars(args))