import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaForMLM
from datasets import load_dataset
from tqdm import tqdm
import torch
# Always fix the random seed before each use of the random module.
import random
import json
import copy
from peft import PeftModel

def evaluate_completion(
    completion: str,
    target: str,
    case_sensitive: bool = False,
) -> bool:
    """Return True when the completion contains any acceptable target answer.

    ``target`` may list several acceptable answers separated by commas. The
    completion is judged correct if at least one of them occurs anywhere
    inside it as a substring (a containment check, not a first-word match).
    Surrounding whitespace is stripped from the completion and from every
    alternative before comparing.

    e.g. completion " World is vast" with target "world" is correct, and
    completion "It is Bob" with target "Alice, Bob" is correct.

    Args:
        completion: Text produced by the model.
        target: One answer, or several answers joined by commas.
        case_sensitive: Compare with case preserved when True; by default
            both sides are lower-cased first.

    Returns:
        True if any non-empty alternative is contained in the completion.
    """
    candidates = [piece.strip() for piece in target.split(',')]
    # Drop empty alternatives: "" is a substring of every string, so an empty
    # target or a stray comma would otherwise mark any completion as correct.
    candidates = [piece for piece in candidates if piece]

    haystack = completion.strip()
    if not case_sensitive:
        haystack = haystack.lower()
        candidates = [piece.lower() for piece in candidates]

    return any(piece in haystack for piece in candidates)

def compute_bleu(run_results):
    """Return the mean ``sentence_bleu`` score over all completions.

    Args:
        run_results: Dict with two parallel sequences: ``'completions'``
            (dicts carrying a ``'gen_answer'`` string) and ``'targets'``
            (the reference answer for each completion).

    Returns:
        The summed per-pair score divided by the number of completions, or
        ``0.0`` when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # NOTE(review): `sentence_bleu` is not imported in the visible part of
    # this file, so this call raises NameError unless it is defined elsewhere.
    # NOTE(review): raw strings are passed for both reference and hypothesis;
    # if this is NLTK's sentence_bleu it would score them as character
    # sequences rather than word tokens -- confirm the intent.
    total = sum(
        sentence_bleu([target], completion['gen_answer'])
        for completion, target in zip(completions, targets)
    )
    return total / len(completions)

def compute_metric(run_results):
    """Compute the fraction of completions judged correct.

    Each generated answer is checked against its target with
    ``evaluate_completion``; the result is the number of correct answers
    divided by the number of completions.

    Args:
        run_results: Dict with two parallel sequences: ``'completions'``
            (dicts carrying a ``'gen_answer'`` string) and ``'targets'``
            (the target string for each completion).

    Returns:
        Accuracy as a float in [0, 1], or ``0.0`` when there are no
        completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    n_correct = sum(
        1
        for completion, target in zip(completions, targets)
        if evaluate_completion(completion['gen_answer'], target)
    )
    return n_correct / len(completions)

def batch_split(prompts, batch_num):
    """Split the evaluation prompts into consecutive mini-batches.

    Items are collected in order; a batch is emitted as soon as it holds
    exactly ``batch_num`` items, and any leftover items form a final,
    shorter batch.
    """
    def _chunks():
        pending = []
        for item in prompts:
            pending += [item]
            if len(pending) == batch_num:
                yield pending
                pending = []
        # Emit whatever did not fill a complete batch.
        if pending:
            yield pending

    return list(_chunks())

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize a batch of prompts and move the resulting tensors to a device.

    Args:
        tokenizer: Object exposing ``batch_encode_plus``; it is called with
            ``return_tensors="pt"`` and ``padding=True``.
        prompts: The prompt strings of one mini-batch.
        device: Target device for the tensors. Defaults to ``'cuda'``, which
            is what this function always used before the parameter existed.

    Returns:
        Dict holding only the ``"input_ids"`` and ``"attention_mask"``
        entries of the encoding, with tensor values placed on ``device``.
    """
    encoded = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)
    wanted = ("input_ids", "attention_mask")

    batch = {}
    for key in encoded:
        if key not in wanted:
            # Drop any extra fields the tokenizer returns.
            continue
        value = encoded[key]
        batch[key] = value.to(device) if torch.is_tensor(value) else value
    return batch


def _load_json_file(path):
    """Read and parse a JSON file, closing the handle afterwards."""
    with open(path) as handle:
        return json.load(handle)


def _question_parent_of_child(record):
    """Forward order, forward relation: ask for the child's parent by type."""
    return "Q: Who is {}'s {}?".format(record['child'], record['parent_type'])


def _question_whose_child(record):
    """Forward order, reverse relation: ask whose child the child is."""
    return "Q: Whose child is {}?".format(record['child'])


def _question_whose_parent(record):
    """Reverse order, forward relation: ask whose parent (by type) the parent is."""
    return "Q: Whose {} is {}?".format(record['parent_type'], record['parent'])


def _question_child_of_parent(record):
    """Reverse order, reverse relation: ask for the parent's child."""
    return "Q: Who is {}'s child?".format(record['parent'])


def _format_few_shot_prompt(system_message, demonstrations, query, question_of, answer_key):
    """Concatenate the system message, the solved demonstrations and the open query."""
    parts = [system_message]
    for demo in demonstrations:
        parts.append("{}\nA: {}\n".format(question_of(demo), demo[answer_key]))
    parts.append("{}\nA:".format(question_of(query)))
    return "".join(parts)


def _join_acceptable_answers(answer, alternatives):
    """Build a comma-joined completion: the alternatives first, then ``answer``."""
    names = list(alternatives) + [answer]
    return ','.join(" {}".format(name) for name in names)


def construct_different_prompts_for_five_shot(dataset, child_to_parent_dict=None, parent_to_child_dict=None):
    """Build four five-shot prompt sets from parent/child records.

    For every record, four prompts are produced. Each one is the system
    message, five solved demonstrations drawn from the other records, and an
    open question about the record itself:

    * positive/positive: "Who is <child>'s <parent_type>?"   -> parent
    * positive/negative: "Whose child is <child>?"           -> any parent
    * negative/positive: "Whose <parent_type> is <parent>?"  -> any child
    * negative/negative: "Who is <parent>'s child?"          -> any child

    Where several answers are acceptable, the completion is a comma-joined
    list of all of them, taken from the two relation dicts.

    The first three prompt styles reseed the global ``random`` module with
    ``i + 42`` and therefore share the same five demonstrations; the fourth
    reseeds with ``i + 999`` and draws its own set.

    Args:
        dataset: Indexable sequence of dicts with ``'child'``,
            ``'parent_type'`` and ``'parent'`` keys. Must contain at least
            six records unless it is empty, because five demonstrations are
            sampled from the remaining records.
        child_to_parent_dict: Mapping from a child to the list of all its
            parents. Loaded from the project's JSON file when omitted.
        parent_to_child_dict: Mapping from a parent to the list of all its
            children. Loaded from the project's JSON file when omitted.

    Returns:
        Tuple of four lists (positive/positive, positive/negative,
        negative/positive, negative/negative); every element is a dict with
        ``'prompt'`` and ``'completion'`` strings.

    Raises:
        KeyError: A child or parent is missing from the relation dicts.
        ValueError: A record's own parent/child is not listed in the relation
            dicts, or fewer than five other records are available to sample.
    """
    # Wording chosen after comparing a series of phrasing variants; the
    # author's notes mark this one as giving the best results so far.
    system_message = 'You are an expert when it comes to celebrities from various fields, such as actors, singers, and producers, and their family relations. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

    if child_to_parent_dict is None:
        child_to_parent_dict = _load_json_file("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/child_to_parent_dict.json")
    if parent_to_child_dict is None:
        parent_to_child_dict = _load_json_file("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_to_child_dict.json")

    prompt_positive_positive_dataset = []
    prompt_positive_negative_dataset = []
    prompt_negative_positive_dataset = []
    prompt_negative_negative_dataset = []

    for i in range(len(dataset)):
        sample = dataset[i]
        child, parent = sample['child'], sample['parent']

        # Other acceptable answers; copies keep the caller's dicts untouched.
        other_parents = child_to_parent_dict[child].copy()
        other_parents.remove(parent)
        other_children = parent_to_child_dict[parent].copy()
        other_children.remove(child)

        # Demonstrations come from every record except the one being queried.
        candidates = list(range(len(dataset)))
        candidates.remove(i)

        # Seeding per record keeps the demonstration choice reproducible.
        random.seed(i + 42)
        shared_demos = [dataset[k] for k in random.sample(candidates, 5)]
        random.seed(i + 999)
        reverse_demos = [dataset[k] for k in random.sample(candidates, 5)]

        # Forward order + forward relation (Who is the parent of A? B)
        prompt_positive_positive_dataset.append({
            'prompt': _format_few_shot_prompt(system_message, shared_demos, sample, _question_parent_of_child, 'parent'),
            'completion': " {}".format(parent),
        })

        # Forward order + reverse relation
        prompt_positive_negative_dataset.append({
            'prompt': _format_few_shot_prompt(system_message, shared_demos, sample, _question_whose_child, 'parent'),
            'completion': _join_acceptable_answers(parent, other_parents),
        })

        # Reverse order + forward relation
        prompt_negative_positive_dataset.append({
            'prompt': _format_few_shot_prompt(system_message, shared_demos, sample, _question_whose_parent, 'child'),
            'completion': _join_acceptable_answers(child, other_children),
        })

        # Reverse order + reverse relation
        prompt_negative_negative_dataset.append({
            'prompt': _format_few_shot_prompt(system_message, reverse_demos, sample, _question_child_of_parent, 'child'),
            'completion': _join_acceptable_answers(child, other_children),
        })

    return prompt_positive_positive_dataset, prompt_positive_negative_dataset, prompt_negative_positive_dataset, prompt_negative_negative_dataset

        
def main():
    """Build the four five-shot prompt sets, run generation and print EM accuracy.

    Steps:
      1. Load the tokenizer and LLaMA model (optionally with a LoRA adapter).
      2. Build the four prompt sets from the parent/child pairs and write
         each one to its own JSON file.
      3. Reload each file with ``load_dataset`` and, per set, generate
         answers greedily in mini-batches and print the accuracy returned by
         ``compute_metric``.
    """
    # Base model used with the original prompts.
    model_name_or_path = "/home/hadoop-aipnlp/nazarite/llama2/llama-2-7b"
    # Optional LoRA checkpoint (used with the newly built name dataset).
    lora_path = None
    # lora_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/mitigating-reversal-curse-main/own_output/forge_origin_positive_lora_new_setting14/checkpoint-9084"

    # Path of the source dataset.
    data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_child_pairs.json"
    # data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/permuate_parent_child_pairs.json"
    # data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/forge_parent_child_pairs.json"

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    model = LlamaForCausalLM.from_pretrained(model_name_or_path, use_safetensors=False, device_map={"":0}, torch_dtype=torch.float16)

    if lora_path is not None:
        model = PeftModel.from_pretrained(model, lora_path)
        model.merge_adapter()

    model.eval()

    with open(data_path, 'r') as data_file:
        dataset = json.load(data_file)
    prompt_sets = construct_different_prompts_for_five_shot(dataset)

    # One output file per prompt set, in the order the sets are returned.
    output_dir = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/parent_child_final"
    output_paths = [
        os.path.join(output_dir, file_name)
        for file_name in (
            "prompt_positive_positive_dataset.json",
            "prompt_positive_negative_dataset.json",
            "prompt_negative_positive_dataset.json",
            "prompt_negative_negative_dataset.json",
        )
    ]

    for prompt_set, output_path in zip(prompt_sets, output_paths):
        with open(output_path, 'w') as out_file:
            json.dump(prompt_set, out_file)

    eval_sets = [
        load_dataset("json", data_files={"eval": output_path})
        for output_path in output_paths
    ]

    batch_size = 32
    for eval_set in eval_sets:
        val_data = eval_set['eval']
        prompts = [d['prompt'] for d in val_data]
        targets = [d['completion'] for d in val_data]
        results = []
        for batch_input in tqdm(batch_split(prompts, batch_size)):
            # Tokenize the batch and place it on the GPU.
            encode_inputs = prepare_input(tokenizer, batch_input)
            with torch.no_grad():
                output = model.generate(
                    **encode_inputs,
                    max_new_tokens=40,
                    pad_token_id=tokenizer.pad_token_id,
                    do_sample=False,
                    num_beams=1,
                    eos_token_id=tokenizer.eos_token_id,
                )
            response = tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
            # NOTE(review): cutting by len(prompt) assumes the decoded text
            # starts with the prompt verbatim -- confirm for this tokenizer.
            for p, text in zip(batch_input, response):
                results.append({'question': p, 'gen_answer': text[len(p):]})

        # Built once per set, after all batches, so that an empty set can
        # neither leave this unbound nor reuse the previous set's results.
        run_results = {'completions': results, 'targets': targets}

        # Every answer is a person name, so exact-match style scoring suffices.
        print('Using EM')
        metrics = compute_metric(run_results)
        print(metrics)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
