import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaForMLM
from datasets import load_dataset
from tqdm import tqdm
import torch
import random
import re
import json
import math
from peft import PeftModel

def evaluate_completion(
    completion: str,
    target: str,
    case_sensitive: bool = False,
) -> bool:
    """Check whether a completion contains any acceptable target answer.

    ``target`` may hold several comma-separated alternatives. The completion
    counts as correct when at least one non-empty alternative occurs anywhere
    in it as a substring. Surrounding whitespace is stripped from the
    completion and from every alternative, and the comparison is
    case-insensitive unless ``case_sensitive`` is true.

    e.g. completion " World is vast" with target "world" is correct

    Args:
        completion: Text generated by the model.
        target: Expected answer, or several answers separated by commas.
        case_sensitive: Compare without lower-casing when true.

    Returns:
        True if any alternative is contained in the completion, else False.
    """
    alternatives = [part.strip() for part in target.split(',')]
    # An empty alternative (empty target, trailing or doubled comma) is a
    # substring of every string and would mark any completion as correct.
    alternatives = [part for part in alternatives if part]
    haystack = completion.strip()
    if not case_sensitive:
        haystack = haystack.lower()
        alternatives = [part.lower() for part in alternatives]
    return any(part in haystack for part in alternatives)

def compute_bleu(run_results):
    """Return the mean sentence-level BLEU score over all completions.

    Args:
        run_results: Dict with ``'completions'`` (a list of dicts that each
            carry a ``'gen_answer'`` string) and ``'targets'`` (the reference
            strings, aligned with the completions).

    Returns:
        The sum of per-sample scores divided by the number of completions,
        or 0.0 when there are no completions.
    """
    # NOTE(review): `sentence_bleu` is not imported in the visible part of this
    # file; presumably nltk.translate.bleu_score.sentence_bleu -- confirm. If
    # so, passing raw strings makes it score characters rather than words.
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to average; avoid dividing by zero.
        return 0.0
    total = 0
    for completion, target in zip(completions, targets):
        total += sentence_bleu([target], completion['gen_answer'])
    return total / len(completions)

def compute_metric(run_results):
    """Compute the accuracy of the completions against their targets.

    Each completion's ``'gen_answer'`` is judged by ``evaluate_completion``
    with its default arguments.

    Args:
        run_results: Dict with ``'completions'`` (a list of dicts that each
            carry a ``'gen_answer'`` string) and ``'targets'`` (the expected
            answers, aligned with the completions).

    Returns:
        Number of correct completions divided by the number of completions,
        or 0.0 when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    n_correct = sum(
        1
        for completion, target in zip(completions, targets)
        if evaluate_completion(completion['gen_answer'], target)
    )
    return n_correct / len(completions)

def restore_original_parent_child_turple(question, answer):
    """Rebuild a ``{child, parent, parent_type}`` record from a few-shot prompt.

    Args:
        question: Dict whose ``'question'`` entry is the full prompt text; its
            last ``Q:`` segment is expected to look like
            ``" Who is <child>'s <parent_type>?\\nA:"``.
        answer: The parent name; surrounding whitespace is stripped.

    Returns:
        Dict with the keys ``'child'``, ``'parent'`` and ``'parent_type'``.
    """
    # Only the final "Q: ..." segment of the prompt concerns the sample itself.
    final_segment = question['question'].split('\nQ:')[-1]
    # rstrip() takes a character set: every trailing '?', newline, 'A' and ':'
    # is removed, which drops the "?\nA:" tail.
    words = final_segment.rstrip('?\nA:').strip().split()
    parent_type = words[-1]
    # words[2:-1] skips "Who is" and the parent type, leaving "<child>'s".
    possessive_child = ' '.join(words[2:-1])
    return {
        'child': possessive_child[:-2],  # cut the trailing "'s"
        'parent': answer.strip(),
        'parent_type': parent_type,
    }

def split_ood_dataset_by_entity_clarity(run_results, origin_ood_dataset_path, ratio):
    """Split the parent/child pairs into high/low/middle entity-clarity subsets.

    When ``origin_ood_dataset_path`` is given, the three pools are loaded from
    that JSON file. Otherwise they are built from ``run_results``: every
    completion is scored with ``evaluate_completion``; up to 599 correct
    samples go to the high-clarity pool, up to 599 incorrect ones to the
    low-clarity pool and everything else to the middle pool. The pools are
    then written to a hard-coded JSON path.

    From the pools, ``ceil(1513 * ratio)`` samples are drawn from each of the
    high and low pools after ``random.seed(13)``; the middle pool is returned
    whole.

    Args:
        run_results: Dict with ``'completions'`` (dicts holding ``'question'``
            and ``'gen_answer'``) and ``'targets'``. Only read when
            ``origin_ood_dataset_path`` is None.
        origin_ood_dataset_path: Path of a previously saved pool file, or None
            to rebuild the pools from ``run_results``.
        ratio: Fraction of the 1513 samples to draw from each training pool.

    Returns:
        Dict with the keys ``'train_high_entity_clarity'``,
        ``'train_low_entity_clarity'`` and ``'test_middle_entity_clarity'``.

    Raises:
        ValueError: If a sample rebuilt from its prompt differs from the entry
            at the same index of the source data, or (from ``random.sample``)
            if more samples are requested than a pool holds.
        AssertionError: If the pool sizes do not match the hard-coded totals.

    Note:
        Reseeds the global ``random`` generator.
    """
    if origin_ood_dataset_path is not None:
        with open(origin_ood_dataset_path) as pool_file:
            origin_ood_dataset = json.load(pool_file)
    else:
        data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_child_pairs.json"
        with open(data_path) as source_file:
            original_data = json.load(source_file)
        high_pool = []
        low_pool = []
        middle_pool = []
        origin_ood_dataset = {'origin_train_high_entity_clarity': high_pool,
                              'origin_train_low_entity_clarity': low_pool,
                              'origin_test_middle_entity_clarity': middle_pool}
        # Per-class pool size; per the original note it is chosen from the
        # number of samples the few-shot run can answer (1513 * 0.3956).
        count_threshold = 599
        completions = run_results['completions']
        groundtruths = run_results['targets']
        for sample_id, (completion, target) in enumerate(zip(completions, groundtruths)):
            correct_flag = evaluate_completion(completion['gen_answer'], target)
            sample_dict = restore_original_parent_child_turple(completion, target)
            # The prompts are expected to be in the same order as the source
            # data; a mismatch means the ids assigned below would be wrong.
            if sample_dict != original_data[sample_id]:
                raise ValueError(
                    'something is wrong! sample {} rebuilt from its prompt as {!r} '
                    'but the source data holds {!r}'.format(
                        sample_id, sample_dict, original_data[sample_id]))
            sample_dict['id'] = sample_id
            if correct_flag and len(high_pool) < count_threshold:
                high_pool.append(sample_dict)
            elif not correct_flag and len(low_pool) < count_threshold:
                low_pool.append(sample_dict)
            else:
                # Overflow of either pool ends up in the held-out test pool.
                middle_pool.append(sample_dict)
        assert len(origin_ood_dataset['origin_train_high_entity_clarity']) == len(origin_ood_dataset['origin_train_low_entity_clarity'])
        assert len(origin_ood_dataset['origin_test_middle_entity_clarity']) == 1513 - 2 * len(origin_ood_dataset['origin_train_high_entity_clarity'])
        # Persist the pools (each sample carries its id). The file name does
        # not depend on `ratio`.
        with open('/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final/all_ood_parent_child_dataset.json', 'w') as file10:
            json.dump(origin_ood_dataset, file10)

    # Draw this ratio's training subsets from the full pools.
    # NOTE(review): float rounding in 1513 * ratio can push ceil() one above
    # the intended count; random.sample raises if that exceeds the pool size.
    train_high_entity_clarity_num = math.ceil(1513 * ratio)
    train_low_entity_clarity_num = math.ceil(1513 * ratio)
    # Original note: the 315 samples that were all answered incorrectly.
    test_middle_entity_clarity_num = 315
    assert len(origin_ood_dataset['origin_test_middle_entity_clarity']) == test_middle_entity_clarity_num
    specific_ratio_ood_dataset = {'train_high_entity_clarity': [],
                                  'train_low_entity_clarity': [],
                                  'test_middle_entity_clarity': []}
    # The test pool is used as a whole; samples already carry their ids.
    specific_ratio_ood_dataset['test_middle_entity_clarity'] = origin_ood_dataset['origin_test_middle_entity_clarity']
    # Fixed seed so that the sampled subsets are reproducible.
    random.seed(13)
    specific_ratio_ood_dataset['train_high_entity_clarity'] = random.sample(origin_ood_dataset['origin_train_high_entity_clarity'], train_high_entity_clarity_num)
    random.seed(13)
    specific_ratio_ood_dataset['train_low_entity_clarity'] = random.sample(origin_ood_dataset['origin_train_low_entity_clarity'], train_low_entity_clarity_num)
    return specific_ratio_ood_dataset

        
def batch_split(prompts, batch_num):
    """Cut the evaluation prompts into consecutive mini-batches.

    Args:
        prompts: Iterable of prompts.
        batch_num: Number of prompts per mini-batch.

    Returns:
        List of lists in the original order; every mini-batch holds
        ``batch_num`` prompts except possibly the last, which holds the rest.
    """
    batches = []
    for item in prompts:
        # Open a new mini-batch at the start and whenever the last one is full.
        if not batches or len(batches[-1]) == batch_num:
            batches.append([])
        batches[-1].append(item)
    return batches

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize a batch of prompts and move the resulting tensors to a device.

    Args:
        tokenizer: Object providing ``batch_encode_plus``; it is called with
            ``return_tensors="pt"`` and ``padding=True``.
        prompts: The prompts of one mini-batch.
        device: Target passed to ``Tensor.to``; defaults to ``'cuda'``, the
            device that used to be hard-coded here.

    Returns:
        Dict holding only the ``"input_ids"`` and ``"attention_mask"`` entries
        of the encoding; entries that are tensors are moved to ``device``.
    """
    encoded = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)
    wanted = ("input_ids", "attention_mask")
    model_inputs = {}
    for name in encoded:
        if name not in wanted:
            # Any other field of the encoding is dropped.
            continue
        value = encoded[name]
        model_inputs[name] = value.to(device) if torch.is_tensor(value) else value
    return model_inputs


def construct_prompt_for_clarify_entity_dataset(dataset):
    """Build one 5-shot "Who is X's <parent_type>?" prompt per sample.

    For the sample at index ``i`` the global ``random`` generator is seeded
    with ``i + 42`` and five other samples are drawn as demonstrations, so the
    prompts are reproducible.

    Args:
        dataset: Indexable sequence of dicts with the keys ``'child'``,
            ``'parent_type'`` and ``'parent'``; ``random.sample`` needs at
            least six entries.

    Returns:
        List of ``{'prompt': ..., 'completion': ...}`` dicts in dataset order;
        the completion is the parent name preceded by one space.
    """
    # Original note: the prompt should be adapted to the data at hand (whether
    # the person's name is known or not); using this one directly should also
    # be fine for now. It currently gives the best results.
    system_message = 'You are an expert when it comes to celebrities from various fields, such as actors, singers, and producers, and their family relations. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

    # Forward order + forward relation: "Who is <child>'s <parent_type>?" -> parent.
    total = len(dataset)
    built = []
    for position, sample in enumerate(dataset):
        child, parent_type, parent = sample['child'], sample['parent_type'], sample['parent']

        # Demonstrations are sampled from every index except the current one.
        other_positions = [k for k in range(total) if k != position]
        random.seed(position + 42)
        chosen = random.sample(other_positions, 5)

        shots = ''.join(
            "Q: Who is {}'s {}?\nA: {}\n".format(
                dataset[k]['child'], dataset[k]['parent_type'], dataset[k]['parent'])
            for k in chosen
        )
        # The final question is left open after "A:" for the model to complete.
        query = "Q: Who is {}'s {}?\nA:".format(child, parent_type)
        built.append({
            'prompt': system_message + shots + query,
            'completion': " {}".format(parent),
        })

    return built

        
def main():
    """Score every parent/child pair with a 5-shot prompt and build the OOD split.

    Steps: load the tokenizer and the LLaMA model (merging a LoRA adapter when
    ``lora_path`` is set), build and save the few-shot prompts, generate
    greedily in batches of 32, print the accuracy, then split the samples by
    entity clarity with ``split_ood_dataset_by_entity_clarity``. All paths and
    hyper-parameters are hard-coded below.
    """
    # Base model, used with the original prompt.
    model_name_or_path = "/home/hadoop-aipnlp/nazarite/llama2/llama-2-7b"
    # Optional LoRA adapter (for the newly built person-name dataset prompt);
    # None runs the bare base model.
    lora_path = None
    # Location of the source parent/child pairs.
    data_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_child_pairs.json"

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    model = LlamaForCausalLM.from_pretrained(model_name_or_path, use_safetensors=False, device_map={"":0}, torch_dtype=torch.float16)

    if lora_path is not None:
        model = PeftModel.from_pretrained(model, lora_path)
        model.merge_adapter()

    model.eval()

    with open(data_path, 'r') as data_file:
        dataset = json.load(data_file)
    # Five-shot prompts over the whole dataset; the model's answers are what
    # separates high-clarity from low-clarity entities.
    prompt_for_entity_clarify_dataset = construct_prompt_for_clarify_entity_dataset(dataset)
    prompt_for_entity_clarify_dataset_path = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final/prompt_for_entity_clarify_dataset.json"

    with open(prompt_for_entity_clarify_dataset_path, 'w') as file1:
        json.dump(prompt_for_entity_clarify_dataset, file1)

    prompt_positive_positive_dataset = load_dataset("json", data_files={"eval": prompt_for_entity_clarify_dataset_path})

    eval_datasets = [prompt_positive_positive_dataset]

    for eval_dataset in eval_datasets:
        val_data = eval_dataset['eval']
        prompts = [d['prompt'] for d in val_data]
        completions = [d['completion'] for d in val_data]
        results = []
        batch_size = 32
        for batch_input in tqdm(batch_split(prompts, batch_size)):
            # Tokenize the batch and move it to the GPU.
            encode_inputs = prepare_input(tokenizer, batch_input)
            with torch.no_grad():
                output = model.generate(**encode_inputs, max_new_tokens=40, pad_token_id=tokenizer.pad_token_id, do_sample=False, num_beams=1,
                                    eos_token_id= tokenizer.eos_token_id)
            response = tokenizer.batch_decode(output.cpu(), skip_special_tokens=True)
            # Keep only the text after the prompt.
            # NOTE(review): assumes the decoded text reproduces the prompt
            # character for character -- confirm for this tokenizer.
            for idx, p in enumerate(batch_input):
                results.append({'question': p, 'gen_answer': response[idx][len(p):]})
        # Built once after all batches, so it is also defined when there are none.
        run_results = {'completions': results, 'targets': completions}

        # Every answer is a person's name, so exact-match scoring is enough.
        print('Using EM')
        metrics = compute_metric(run_results)
        print(metrics)
        # Original notes (translated): for parent_to_child the total is 1513
        # samples; 579 can be answered and 934 cannot (579 of those are the
        # fixed low-clarity samples, the remaining 355 the fixed test samples).
        #   20% for training -> 303 sampled from each pool of 599, 355 test samples
        #   25% for training -> 379 sampled from each pool of 599, 355 test samples
        #   30% for training -> 454 sampled from each pool of 599, 355 test samples
        #   35% for training -> 530 sampled from each pool of 599, 355 test samples
        # NOTE(review): 579/355 above disagree with the 599/315 hard-coded in
        # split_ood_dataset_by_entity_clarity -- confirm which is current.
        # `specific_ratio` selects how much data is used for training.
        specific_ratio = 0.39590218109715797
        ood_parent_child_dataset = split_ood_dataset_by_entity_clarity(run_results, None, specific_ratio)
        # The full-pool ratio only (re)builds the pool file; any other ratio
        # also saves the sampled split.
        if specific_ratio != 0.39590218109715797:
            with open('/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/ood_parent_child_final/ood_{}/ood_sample_parent_child_dataset.json'.format(str('all')), 'w') as file2:
                json.dump(ood_parent_child_dataset, file2)

# Run the evaluation and dataset-split pipeline only when executed as a script.
if __name__ == '__main__':
    main()
