import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaForMLM
from datasets import load_dataset
from tqdm import tqdm
import torch
# Fix the random seed before every use of the random module.
import random
import json
import copy
from peft import PeftModel

def evaluate_completion(
    completion: str,
    target: str,
    case_sensitive: bool = False,
) -> bool:
    """Check whether the completion mentions any accepted target answer.

    ``target`` may list several acceptable answers separated by commas. The
    completion counts as correct when at least one of them occurs anywhere
    inside it as a substring (not only as its first word).

    e.g. completion " World is vast" with target "world" is correct

    Args:
        completion: Text generated by the model.
        target: One accepted answer, or several separated by commas.
        case_sensitive: Compare with case preserved when True; by default
            both sides are lower-cased before comparing.

    Returns:
        True if any non-empty accepted answer is contained in the completion.
    """
    candidates = [part.strip() for part in target.split(',')]
    # Drop empty alternatives (empty target, trailing or doubled commas):
    # the empty string is a substring of every completion and would make
    # any answer count as correct.
    candidates = [part for part in candidates if part]
    haystack = completion.strip()
    if not case_sensitive:
        haystack = haystack.lower()
        candidates = [part.lower() for part in candidates]
    return any(part in haystack for part in candidates)

def compute_bleu(run_results):
    """Average the per-sample ``sentence_bleu`` score over all completions.

    Args:
        run_results: Dict with 'completions' (a list of dicts, each holding
            the generated text under 'gen_answer') and 'targets' (the
            reference answers, aligned with the completions).

    Returns:
        The mean of the per-sample scores, or 0.0 when there are no
        completions to score.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    # NOTE(review): `sentence_bleu` is not imported in the visible part of
    # this file. Presumably nltk.translate.bleu_score.sentence_bleu is meant;
    # if so it expects token sequences, and passing raw strings compares
    # them character by character -- confirm before relying on this metric.
    total = sum(
        sentence_bleu([target], completion['gen_answer'])
        for completion, target in zip(completions, targets)
    )
    return total / len(completions)

def compute_metric(run_results):
    """Compute the fraction of completions judged correct.

    Each generated answer is paired with its target and judged by
    ``evaluate_completion`` using that function's default settings.

    Args:
        run_results: Dict with 'completions' (a list of dicts, each holding
            the generated text under 'gen_answer') and 'targets' (the
            expected answers, aligned with the completions).

    Returns:
        Number of correct completions divided by the number of completions,
        or 0.0 when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # An empty evaluation set has no accuracy; avoid dividing by zero.
        return 0.0
    n_correct = sum(
        1
        for completion, target in zip(completions, targets)
        if evaluate_completion(completion['gen_answer'], target)
    )
    return n_correct / len(completions)

def batch_split(prompts, batch_num):
    """Split the evaluation prompts into consecutive mini-batches.

    Args:
        prompts: Iterable of items to group.
        batch_num: Number of items that closes a mini-batch.

    Returns:
        A list of lists in the original order; the last one may be shorter
        than ``batch_num`` when the items do not divide evenly.
    """
    chunks = []
    current = []
    for item in prompts:
        current.append(item)
        if len(current) != batch_num:
            continue
        # The running chunk is full: store it and start a fresh one.
        chunks.append(current)
        current = []
    # Keep whatever is left over as a final, shorter chunk.
    if current:
        chunks.append(current)
    return chunks

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize a batch of prompts and move the resulting tensors to a device.

    Args:
        tokenizer: Object exposing ``batch_encode_plus``; it is called with
            ``return_tensors="pt"`` and ``padding=True``.
        prompts: Batch of prompt strings to encode.
        device: Target passed to ``Tensor.to``. Defaults to 'cuda'.

    Returns:
        Dict holding only the 'input_ids' and 'attention_mask' entries
        produced by the tokenizer, with tensor values moved to ``device``.
    """
    encoded = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)
    wanted_keys = ("input_ids", "attention_mask")
    model_inputs = {}
    for key in encoded:
        # Any other field the tokenizer returns is dropped.
        if key not in wanted_keys:
            continue
        value = encoded[key]
        if torch.is_tensor(value):
            value = value.to(device)
        model_inputs[key] = value
    return model_inputs


def _read_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


def _lookup_prompts(samples, reference):
    """Return the reference prompt/completion pair for each sample's id."""
    pairs = []
    for sample in samples:
        entry = reference[sample['id']]
        pairs.append({'prompt': entry['prompt'], 'completion': entry['completion']})
    return pairs


def construct_different_prompts_for_five_shot(
    dataset,
    reference_dir='/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/parent_child_final',
):
    """Build the eight prompt/completion evaluation sets.

    The samples are divided into a high-clarity group
    ('origin_train_high_entity_clarity') and a low-clarity group
    ('origin_train_low_entity_clarity' followed by
    'origin_test_middle_entity_clarity'). For every sample, the prompt and
    completion are looked up by the sample's 'id' in each of four reference
    test datasets read from ``reference_dir``.

    Args:
        dataset: Dict containing the three sample lists named above; every
            sample must carry an 'id' usable as an index/key into the
            reference datasets.
        reference_dir: Directory holding the four
            ``standard_positive_*_test_dataset.json`` reference files.

    Returns:
        A tuple of eight lists of ``{'prompt', 'completion'}`` dicts: the
        high-clarity group for the positive_positive, positive_negative,
        negative_positive and negative_negative references, followed by the
        low-clarity group in the same reference order.

    Raises:
        AssertionError: If an id occurs in both the high- and low-clarity
            groups.
    """
    # 1. train: A's parent is B; C's child is D. test: who is B's child?; B's child is whom?   /// who is D's parent?; D's parent is whom?
    variants = (
        'positive_positive',
        'positive_negative',
        'negative_positive',
        'negative_negative',
    )
    references = [
        _read_json(os.path.join(reference_dir, f'standard_positive_{variant}_test_dataset.json'))
        for variant in variants
    ]

    high_clarity_dataset = dataset['origin_train_high_entity_clarity']
    low_clarity_dataset = dataset['origin_train_low_entity_clarity'] + dataset['origin_test_middle_entity_clarity']

    # The two groups must not share any sample; sets keep this check linear.
    high_ids = {sample['id'] for sample in high_clarity_dataset}
    low_ids = {sample['id'] for sample in low_clarity_dataset}
    assert high_ids.isdisjoint(low_ids)

    high_sets = [_lookup_prompts(high_clarity_dataset, reference) for reference in references]
    low_sets = [_lookup_prompts(low_clarity_dataset, reference) for reference in references]
    return tuple(high_sets + low_sets)

        
def main():
    """Evaluate the model on the eight prompt splits and print each accuracy.

    Steps: load the tokenizer and model (optionally merging a LoRA adapter),
    build the eight prompt/completion splits from the source dataset, write
    each split to its JSON file, reload the files through ``load_dataset``,
    greedily generate up to 40 new tokens per prompt, and print the accuracy
    returned by ``compute_metric`` for every split.
    """
    # Base model checkpoint.
    model_name_or_path = "/home/hadoop-aipnlp/nazarite/llama2/llama-2-7b"
    # Optional LoRA adapter; None evaluates the base model as is.
    lora_path = None

    # Source dataset from which the splits are built.
    data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final/all_ood_parent_child_dataset.json"

    # One JSON file per split, listed in the order in which
    # construct_different_prompts_for_five_shot returns the splits.
    output_dir = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final'
    split_names = [
        'prompt_high_positive_positive_clarity_dataset',
        'prompt_high_positive_negative_clarity_dataset',
        'prompt_high_negative_positive_clarity_dataset',
        'prompt_high_negative_negative_clarity_dataset',
        'prompt_low_positive_positive_clarity_dataset',
        'prompt_low_positive_negative_clarity_dataset',
        'prompt_low_negative_positive_clarity_dataset',
        'prompt_low_negative_negative_clarity_dataset',
    ]
    split_paths = [os.path.join(output_dir, name + '.json') for name in split_names]

    # Left padding keeps every prompt flush against the generated tokens.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    model = LlamaForCausalLM.from_pretrained(model_name_or_path, use_safetensors=False, device_map={"": 0}, torch_dtype=torch.float16)

    if lora_path is not None:
        model = PeftModel.from_pretrained(model, lora_path)
        model.merge_adapter()

    model.eval()

    with open(data_path, 'r') as data_file:
        dataset = json.load(data_file)
    split_records = construct_different_prompts_for_five_shot(dataset)

    for path, records in zip(split_paths, split_records):
        with open(path, 'w') as out_file:
            json.dump(records, out_file)

    # The splits are read back from disk through the `datasets` JSON loader.
    eval_sets = [load_dataset("json", data_files={"eval": path}) for path in split_paths]

    batch_size = 32
    for name, eval_set in zip(split_names, eval_sets):
        val_data = eval_set['eval']
        prompts = [d['prompt'] for d in val_data]
        completions = [d['completion'] for d in val_data]
        results = []
        for batch_input in tqdm(batch_split(prompts, batch_size), desc=name):
            # Tokenize the batch and move it onto the GPU.
            encode_inputs = prepare_input(tokenizer, batch_input)
            with torch.no_grad():
                output = model.generate(**encode_inputs, max_new_tokens=40, pad_token_id=tokenizer.pad_token_id, do_sample=False, num_beams=1,
                                        eos_token_id=tokenizer.eos_token_id)
            # Decode only the newly generated tokens. With left padding every
            # prompt occupies the first `prompt_len` positions, so this does
            # not depend on the decoded text reproducing the prompt
            # character for character.
            prompt_len = encode_inputs['input_ids'].shape[1]
            response = tokenizer.batch_decode(output[:, prompt_len:].cpu(), skip_special_tokens=True)
            for p, answer in zip(batch_input, response):
                results.append({'question': p, 'gen_answer': answer})

        # Built once per split, after all batches, so an empty split cannot
        # pick up the results of the previous one.
        run_results = {'completions': results, 'targets': completions}

        # Every target is a person name, so exact-match (EM) is sufficient.
        print('Using EM')
        metrics = compute_metric(run_results)
        print(metrics)

# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
