import os 
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaForMLM
from datasets import load_dataset
from tqdm import tqdm
import torch
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import random
import copy
from peft import PeftModel

def evaluate_completion(
    completion: str,
    target: str,
    case_sensitive: bool = False,
) -> bool:
    """Return whether any acceptable answer in ``target`` occurs in ``completion``.

    ``target`` is a comma-separated list of acceptable answers. Each answer is
    stripped of surrounding whitespace and checked as a *substring* of the
    stripped completion (case-insensitive by default). This is not a first-word
    match: " World is vast" and "Hello world" both match target "world".

    Empty answers (from an empty target or stray commas) are ignored. The empty
    string is a substring of every completion, so keeping them would mark every
    prediction as correct.

    NOTE(review): answers that themselves contain commas (e.g. some book titles)
    are split into fragments by this format, and each fragment then counts as a
    full answer.

    Args:
        completion: generated text to check.
        target: comma-separated acceptable answers.
        case_sensitive: compare with exact case when True.

    Returns:
        True if at least one non-empty answer is found in the completion.
    """
    answers = [part.strip() for part in target.split(',')]
    answers = [answer for answer in answers if answer]
    text = completion.strip()
    if not case_sensitive:
        text = text.lower()
        answers = [answer.lower() for answer in answers]
    return any(answer in text for answer in answers)

def compute_bleu(run_results):
    """Return the mean sentence-level BLEU of generated answers against targets.

    Args:
        run_results: dict with 'completions' (list of dicts carrying a
            'gen_answer' string) and 'targets' (list of reference strings),
            aligned by position.

    Returns:
        Average BLEU over all completions, or 0.0 when there are none.

    NOTE(review): ``sentence_bleu`` is not imported anywhere in this file. It
    presumably refers to ``nltk.translate.bleu_score.sentence_bleu`` and must be
    imported before this function is called. nltk scores sequences element by
    element, so passing raw strings (as done here) presumably yields
    character-level BLEU. Confirm that is intended.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        return 0.0
    total = 0.0
    for completion, target in zip(completions, targets):
        total += sentence_bleu([target], completion['gen_answer'])
    return total / len(completions)

def compute_metric(run_results):
    """Return the fraction of completions that contain an acceptable answer.

    Each completion's 'gen_answer' is scored with ``evaluate_completion``
    against the target at the same position. That check is a case-insensitive
    substring test over the comma-separated acceptable answers, not a first-word
    exact match.

    Args:
        run_results: dict with 'completions' (list of dicts carrying a
            'gen_answer' string) and 'targets' (list of comma-separated answer
            strings), aligned by position.

    Returns:
        Accuracy in [0, 1], or 0.0 when there are no completions.
    """
    completions = run_results['completions']
    targets = run_results['targets']
    if not completions:
        return 0.0
    n_correct = sum(
        evaluate_completion(completion['gen_answer'], target)
        for completion, target in zip(completions, targets)
    )
    return n_correct / len(completions)

def batch_split(prompts, batch_num):
    """Chunk ``prompts`` into consecutive mini-batches for evaluation.

    Every batch holds ``batch_num`` items except possibly the last, which keeps
    the remainder (with ``batch_num < 1`` all items end up in a single batch).
    Returns a list of lists; empty input gives ``[]``.
    """
    groups = []
    pending = []
    for item in prompts:
        pending.append(item)
        if len(pending) == batch_num:
            groups.append(pending)
            pending = []
    # Flush the short tail batch, if there is one.
    return groups + [pending] if pending else groups

def prepare_input(tokenizer, prompts, device='cuda'):
    """Tokenize one batch of prompts and move the tensors to ``device``.

    Args:
        tokenizer: object exposing ``batch_encode_plus`` (a Hugging Face
            tokenizer in this script). Padding side is whatever the tokenizer
            was configured with.
        prompts: list of prompt strings forming one batch.
        device: target device for tensor values; defaults to 'cuda'.

    Returns:
        dict with only the 'input_ids' and 'attention_mask' entries (others,
        such as 'token_type_ids', are dropped), padded to the longest prompt.
    """
    encoded = tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in encoded.items()
        if key in ("input_ids", "attention_mask")
    }


def construct_different_prompts_for_five_shot(
    dataset,
    author_to_book_dict_path="/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/author_to_book_dict.json",
    book_to_author_dict_path="/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/book_to_author_dict.json",
):
    """Build the four five-shot QA prompt sets used to probe the reversal curse.

    Every (book, author) pair in ``dataset`` becomes one prompt per variant.
    Each prompt consists of:
      * the system message;
      * five demonstrations drawn from the other pairs;
      * the question for the current pair, ending in "A:".

    Variants are named <direction>_<relation phrasing>. "positive" order asks
    book -> author and "negative" order asks author -> book. "positive"
    relation uses "author" phrasing and "negative" uses "work" phrasing:
      * positive_positive - "Q: Who is {book}'s author?"    -> author
      * positive_negative - "Q: Whose work is {book}?"      -> author
      * negative_positive - "Q: What is {author} author of?" -> book
      * negative_negative - "Q: What is {author}'s work?"   -> book

    Demonstrations for sample ``i`` are drawn with a private
    ``random.Random(i + offset)``, with offsets 1921, 1202, 2120 and 1006 per
    variant. This yields the same draws as seeding the global RNG with that
    value, but leaves the global ``random`` state untouched.

    Each completion lists every acceptable answer, joined by ','. First come
    the other authors of the book (or other books of the author) from the
    lookup dicts, then the pair's own answer. Each answer is prefixed with a
    space.

    Args:
        dataset: indexable sequence of dicts with 'book' and 'author' keys;
            needs at least six entries.
        author_to_book_dict_path: JSON file mapping each author to a list of
            their books.
        book_to_author_dict_path: JSON file mapping each book to a list of its
            authors.

    Returns:
        Tuple of four lists (positive_positive, positive_negative,
        negative_positive, negative_negative). Each list holds one
        {'prompt', 'completion'} dict per entry of ``dataset``, in order.

    Raises:
        KeyError: a book or author is missing from the lookup dicts.
        ValueError: a pair is not listed in its own lookup entry, or
            ``dataset`` has fewer than six entries (five demonstrations are
            sampled from the other pairs).
    """
    # Wording that currently gives the best results.
    system_message = 'You are an expert when it comes to books from various fields, such as science, literature, and technology, and their author relationships. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

    # (question template, key the question mentions, key that answers it, seed offset)
    # A demonstration is the question followed by " <answer>\n".
    variants = (
        ("Q: Who is {}'s author?\nA:", 'book', 'author', 1921),
        ("Q: Whose work is {}?\nA:", 'book', 'author', 1202),
        ("Q: What is {} author of?\nA:", 'author', 'book', 2120),
        ("Q: What is {}'s work?\nA:", 'author', 'book', 1006),
    )
    variant_outputs = ([], [], [], [])

    with open(author_to_book_dict_path) as f:
        author_to_book_dict = json.load(f)
    with open(book_to_author_dict_path) as f:
        book_to_author_dict = json.load(f)

    for i, sample in enumerate(dataset):
        book, author = sample['book'], sample['author']

        # Other valid answers come first, then this pair's own answer.
        other_books = author_to_book_dict[author].copy()
        other_books.remove(book)
        other_authors = book_to_author_dict[book].copy()
        other_authors.remove(author)
        answers = {
            'author': other_authors + [author],
            'book': other_books + [book],
        }

        # Demonstrations come from every pair except the one being asked about.
        candidates = [j for j in range(len(dataset)) if j != i]

        for (template, ask_key, answer_key, seed_offset), outputs in zip(variants, variant_outputs):
            demonstration_index = random.Random(i + seed_offset).sample(candidates, 5)
            prompt = system_message
            for j in demonstration_index:
                demonstration = dataset[j]
                prompt += template.format(demonstration[ask_key]) + " {}\n".format(demonstration[answer_key])
            prompt += template.format(sample[ask_key])
            completion = ','.join(" {}".format(answer) for answer in answers[answer_key])
            outputs.append({'prompt': prompt, 'completion': completion})

    return variant_outputs



def _build_relation_dicts(samples):
    """Index the books of each author and the authors of each book in ``samples``.

    Returns ``(author_to_book_dict, book_to_author_dict)``. Each maps a name to
    the list of partner names in sample order, with duplicates kept. Both dicts
    also carry two bookkeeping keys:
      * 'independent_id' - indices of samples whose key (book or author) is
        seen for the first time;
      * 'dependent_id' - indices of samples that repeat an earlier key.
    """
    author_to_book_dict = {'independent_id': [], 'dependent_id': []}
    book_to_author_dict = {'independent_id': [], 'dependent_id': []}
    for idx, sample in enumerate(samples):
        book, author = sample['book'], sample['author']
        for mapping, key, partner in (
            (book_to_author_dict, book, author),
            (author_to_book_dict, author, book),
        ):
            if key in mapping:
                mapping[key].append(partner)
                mapping['dependent_id'].append(idx)
            else:
                mapping[key] = [partner]
                mapping['independent_id'].append(idx)
    return author_to_book_dict, book_to_author_dict


def _sample_filtered_dataset(dataset, max_books_per_author=3, sample_size=2000, seed=13):
    """Sample ``sample_size`` pairs whose author has at most ``max_books_per_author`` entries.

    The pairs are drawn uniformly without replacement and returned in their
    original dataset order. ``random.Random(seed).sample`` draws exactly what
    ``random.seed(seed); random.sample(...)`` would. Raises ValueError (from
    ``random.sample``) when fewer pairs are eligible.
    """
    author_to_ids = {}
    for idx, sample in enumerate(dataset):
        author_to_ids.setdefault(sample['author'], []).append(idx)

    candidate_ids = []
    for ids in author_to_ids.values():
        if len(ids) <= max_books_per_author:
            candidate_ids.extend(ids)

    chosen = set(random.Random(seed).sample(candidate_ids, sample_size))
    return [sample for idx, sample in enumerate(dataset) if idx in chosen]


def _generate_answers(model, tokenizer, prompts, batch_size=32, max_new_tokens=40):
    """Greedily generate a continuation for every prompt, batch by batch.

    Returns a list of {'question': prompt, 'gen_answer': decoded continuation}
    dicts in prompt order.
    """
    results = []
    for batch_input in tqdm(batch_split(prompts, batch_size)):
        encode_inputs = prepare_input(tokenizer, batch_input)
        with torch.no_grad():
            output = model.generate(**encode_inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.pad_token_id, do_sample=False, num_beams=1,
                                    eos_token_id=tokenizer.eos_token_id)
        # generate() returns prompt + continuation. With left padding every
        # prompt row ends at the same column, so the continuation is everything
        # past input_ids' width. Slicing tokens (rather than the decoded text by
        # len(prompt)) stays aligned even when decoding does not reproduce the
        # prompt string exactly.
        prompt_width = encode_inputs['input_ids'].shape[1]
        response = tokenizer.batch_decode(output[:, prompt_width:].cpu(), skip_special_tokens=True)
        results.extend({'question': p, 'gen_answer': r} for p, r in zip(batch_input, response))
    return results


def main():
    """Build the four five-shot prompt sets for book/author pairs and score the model on each.

    Steps:
      1. Load (book, author) pairs from the cached JSON, or from the raw JSONL
         if the cache does not exist.
      2. Optionally keep 2000 pairs whose author has at most three entries.
      3. Write the kept pairs and the book<->author lookup dicts to disk; the
         prompt builder reads those lookup files.
      4. Load the LLaMA model, optionally merging a LoRA adapter.
      5. Save every prompt set, then for each one generate answers greedily and
         print exact-match (substring) accuracy.
    """
    data_dir = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final'
    target_data_path = os.path.join(data_dir, 'book_author_pairs.json')
    model_name_or_path = "/home/hadoop-aipnlp/nazarite/llama2/llama-2-7b"
    lora_path = None
    use_filter = True

    if os.path.isfile(target_data_path):
        with open(target_data_path, 'r') as f:
            dataset = json.load(f)
    else:
        raw_data_path = os.path.join(data_dir, '5k_book_auther.jsonl')
        with open(raw_data_path, 'r') as f:
            dataset = [
                {'book': record['book_title'], 'author': record['book_author']}
                for record in map(json.loads, f)
            ]

    # NOTE: target_data_path is overwritten below, so later runs start from the
    # already-filtered pairs.
    eval_pairs = _sample_filtered_dataset(dataset) if use_filter else dataset
    # Built in both modes: the prompt builder needs a lookup entry for every
    # evaluated book and author.
    author_to_book_dict, book_to_author_dict = _build_relation_dicts(eval_pairs)

    with open(target_data_path, 'w') as f:
        json.dump(eval_pairs, f)
    with open(os.path.join(data_dir, 'book_to_author_dict.json'), 'w') as f:
        json.dump(book_to_author_dict, f)
    with open(os.path.join(data_dir, 'author_to_book_dict.json'), 'w') as f:
        json.dump(author_to_book_dict, f)

    # Left padding keeps each prompt flush against its generated continuation.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
    # Hard-coded special-token ids; presumably LLaMA's <unk>(reused as pad)/<s>/</s> — confirm against the tokenizer.
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    model = LlamaForCausalLM.from_pretrained(model_name_or_path, use_safetensors=False, device_map={"": 0}, torch_dtype=torch.float16)

    if lora_path is not None:
        model = PeftModel.from_pretrained(model, lora_path)
        model.merge_adapter()

    model.eval()

    variant_names = ('positive_positive', 'positive_negative', 'negative_positive', 'negative_negative')
    variant_datasets = construct_different_prompts_for_five_shot(eval_pairs)

    # Persist every prompt set before the long generation phase starts.
    for name, records in zip(variant_names, variant_datasets):
        with open(os.path.join(data_dir, 'prompt_{}_dataset.json'.format(name)), 'w') as f:
            json.dump(records, f)

    for name, records in zip(variant_names, variant_datasets):
        prompts = [d['prompt'] for d in records]
        completions = [d['completion'] for d in records]
        results = _generate_answers(model, tokenizer, prompts)
        run_results = {'completions': results, 'targets': completions}

        # Targets are short names/titles, so exact-match (substring) accuracy is used.
        print(name)
        print('Using EM')
        print(compute_metric(run_results))


if __name__ == '__main__':
    main()
