import os 
import json
import random
import copy

"""
    # 给定他们的模版
    # The book name's author is person
    # 正序正关的两个模版
    # Who is the book name's author? person 和 The book name's author is whom? person
    # 正序逆关的两个模版
    # Who has written the book name? person 和 The book is writtem by whom? person
    # 逆序正关的两个模版
    # Which book's author is person? name 和 person is the author of which book? name
    # 逆序逆关的两个模版
    # Which book is written by person? name 和 person has written which book? name

    # person's has written the book name
    # 正序正关的两个模版
    # Which book is written by person? name 和 person has written which book? name
    # 正序逆关的两个模版
    # Which book's author is person? name 和 person is the author of which book? name
    # 逆序正关的两个模版
    # Who has written the book name? person 和 The book is writtem by whom? person
    # 逆序逆关的两个模版
    # Who is the book name's author? person 和 The book name's author is whom? person

"""

# Merge the two datasets of the reversal-curse experiment
def merge_ar_dataset_one_by_one(ar_positive_dataset, ar_negative_dataset):
    """Interleave two equally long datasets element by element.

    The result alternates entries: positive[0], negative[0], positive[1],
    negative[1], ... so that paired samples end up next to each other.

    Raises:
        AssertionError: if the two datasets differ in length.
    """
    assert len(ar_positive_dataset) == len(ar_negative_dataset)
    paired = zip(ar_positive_dataset, ar_negative_dataset)
    # Flatten the (positive, negative) pairs in order.
    return [entry for pair in paired for entry in pair]

# Build the positive/negative sample pairs of the original (standard) dataset
def construct_origin_dataset_pair(dataset, data_format_index, known=False, author_to_book_dict=None, book_to_author_dict=None):
    """Build the standard (non-interleaved) train sets and few-shot test sets.

    Every sample of ``dataset`` is used twice: once with the "positive"
    training statement ("<book>'s author is <author>.") and once with the
    "negative" one ("<author>'s work is <book>."). For each direction four
    test sets are produced; every test prompt is the system message, five
    demonstrations drawn from the other samples, and the final question.

    Args:
        dataset: sequence of dicts with the keys 'book' and 'author'.
        data_format_index: prompt format to use. Only format 1 is implemented;
            any other value yields ten empty lists.
        known: unused; kept for backward compatibility.
        author_to_book_dict: optional mapping author -> list of books. When
            None it is loaded from the default JSON file.
        book_to_author_dict: optional mapping book -> list of authors. When
            None it is loaded from the default JSON file.

    Returns:
        A tuple of ten lists, in this order:
        positive train, then positive tests asking "Who is <book>'s author?",
        "Whose work is <book>?", "What is <author> author of?",
        "What is <author>'s work?"; negative train, then negative tests asking
        "What is <author>'s work?", "What is <author> author of?",
        "Whose work is <book>?", "Who is <book>'s author?".
        Train items are {'prompt', 'completion'} with an empty completion;
        test items are {'prompt', 'completion'} where the completion lists
        every other valid answer followed by the sample's own answer.

    Raises:
        KeyError: if a sample's author/book is missing from the lookup dicts.
        ValueError: if a sample's book/author is not listed under its own
            lookup entry, or if format 1 is used with fewer than six samples
            (five demonstrations cannot be drawn).
        OSError: if a default lookup file has to be loaded but cannot be read.

    Note:
        The global ``random`` generator is reseeded for every test record so
        that the demonstrations are reproducible per sample position.
    """
    system_message = 'You are an expert when it comes to books from various fields, such as science, literature, and technology, and their author relationships. You answer questions concisely, with only the specific answer or "I don\'t know"\n'
    demonstration_count = 5
    # Question templates: name -> (prompt ending in "A:", field holding the expected answer).
    questions = {
        'author_of_book': ("Q: Who is {book}'s author?\nA:", 'author'),
        'owner_of_work': ("Q: Whose work is {book}?\nA:", 'author'),
        'book_of_author': ("Q: What is {author} author of?\nA:", 'book'),
        'work_of_author': ("Q: What is {author}'s work?\nA:", 'book'),
    }

    def load_json(path):
        # A context manager closes the handle as soon as parsing is done.
        with open(path) as handle:
            return json.load(handle)

    if author_to_book_dict is None:
        author_to_book_dict = load_json("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/author_to_book_dict.json")
    if book_to_author_dict is None:
        book_to_author_dict = load_json("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/book_to_author_dict.json")

    def other_answers(lookup, key, known_value):
        # Work on a copy so the caller's lookup table is never modified.
        remaining = list(lookup[key])
        remaining.remove(known_value)
        return remaining or None

    def build_completion(answer, alternatives):
        # Other valid answers come first, the sample's own answer last.
        candidates = list(alternatives or []) + [answer]
        return ','.join(" {}".format(candidate) for candidate in candidates)

    def build_test_record(book, author, alternatives, population, seed, question_name):
        template, answer_field = questions[question_name]
        random.seed(seed)
        demonstration_index = random.sample(population, demonstration_count)
        pieces = [system_message]
        for index in demonstration_index:
            demonstration = dataset[index]
            question = template.format(book=demonstration['book'], author=demonstration['author'])
            pieces.append("{} {}\n".format(question, demonstration[answer_field]))
        pieces.append(template.format(book=book, author=author))
        answer = author if answer_field == 'author' else book
        return {'prompt': ''.join(pieces), 'completion': build_completion(answer, alternatives[answer_field])}

    def build_split(train_template, test_plan):
        train_dataset = []
        test_datasets = tuple([] for _ in test_plan)
        sample_count = len(dataset)
        for i in range(sample_count):
            cur_sample = dataset[i]
            book, author = cur_sample['book'], cur_sample['author']
            # Looked up for every sample, whatever the format, so inconsistent lookup tables fail early.
            alternatives = {
                'book': other_answers(author_to_book_dict, author, book),
                'author': other_answers(book_to_author_dict, book, author),
            }
            if data_format_index != 1:
                continue
            train_dataset.append({'prompt': train_template.format(book=book, author=author), 'completion': ""})
            # Demonstrations may come from any sample position except the current one.
            population = [position for position in range(sample_count) if position != i]
            for records, (question_name, seed_offset) in zip(test_datasets, test_plan):
                records.append(build_test_record(book, author, alternatives, population, i + seed_offset, question_name))
        return (train_dataset,) + test_datasets

    # Each plan entry is (question name, seed offset added to the sample position).
    # NOTE(review): the last positive test set uses offset 999 while its siblings use 42;
    # kept exactly as before - confirm whether that difference is intentional.
    positive = build_split(
        "{book}'s author is {author}.",
        (('author_of_book', 42), ('owner_of_work', 42), ('book_of_author', 42), ('work_of_author', 999)),
    )
    negative = build_split(
        "{author}'s work is {book}.",
        (('work_of_author', 0), ('book_of_author', 0), ('owner_of_work', 0), ('author_of_book', 0)),
    )
    return positive + negative


# Build the alternating sample pairs of the reversal-curse experiment
def construct_reverse_dataset_pair(dataset, positive_index, reverse_index, flag, data_format_index, author_to_book_dict=None, book_to_author_dict=None):
    """Build one half (positive or reverse) of the interleaved training data and its test sets.

    Args:
        dataset: sequence of dicts with the keys 'book' and 'author'.
        positive_index: positions in ``dataset`` trained with the forward
            statement ("<book>'s author is <author>.").
        reverse_index: positions in ``dataset`` trained with the reversed
            statement ("<author>'s work is <book>.").
        flag: 'positive' to process ``positive_index``, 'reverse' to process
            ``reverse_index``.
        data_format_index: prompt format; must be one of 1, 2, 3, 4. Only
            format 1 is implemented, the others yield five empty lists.
        author_to_book_dict: optional mapping author -> list of books. When
            None it is loaded from the default JSON file.
        book_to_author_dict: optional mapping book -> list of authors. When
            None it is loaded from the default JSON file.

    Returns:
        A tuple ``(train, test_pp, test_pn, test_np, test_nn)``.
        Train items hold 'origin_prompt', four 'qa_*_prompt' variants and an
        empty 'completion'. Test items are {'prompt', 'completion'}: the
        prompt is the system message, five demonstrations sampled from the
        other positions of the same index list, and the final question; the
        completion lists every other valid answer followed by the sample's own.
        For 'positive' the tests ask, in order, "Who is <book>'s author?",
        "Whose work is <book>?", "What is <author> author of?",
        "What is <author>'s work?"; for 'reverse' the order is the opposite.

    Raises:
        ValueError: if ``flag`` or ``data_format_index`` is not supported
            (previously this only printed a message and returned None), if a
            sample's book/author is not listed under its own lookup entry, or
            if format 1 is used with fewer than six positions in the index list.
        KeyError: if a sample's author/book is missing from the lookup dicts.
        OSError: if a default lookup file has to be loaded but cannot be read.

    Note:
        The global ``random`` generator is reseeded for every test record so
        that the demonstrations are reproducible per sample position.
    """
    # Reject bad arguments up front, before any file is read.
    if flag not in ('positive', 'reverse') or data_format_index not in (1, 2, 3, 4):
        raise ValueError(
            "flag must be 'positive' or 'reverse' and data_format_index must be one of 1, 2, 3, 4; "
            "got flag={!r}, data_format_index={!r}".format(flag, data_format_index)
        )

    system_message = 'You are an expert when it comes to books from various fields, such as science, literature, and technology, and their author relationships. You answer questions concisely, with only the specific answer or "I don\'t know"\n'
    demonstration_count = 5
    # Question templates: name -> (prompt ending in "A:", field holding the expected answer).
    questions = {
        'author_of_book': ("Q: Who is {book}'s author?\nA:", 'author'),
        'owner_of_work': ("Q: Whose work is {book}?\nA:", 'author'),
        'book_of_author': ("Q: What is {author} author of?\nA:", 'book'),
        'work_of_author': ("Q: What is {author}'s work?\nA:", 'book'),
    }

    # Per direction: the index list to walk, the training templates (in output key order)
    # and the test plan as (question name, seed offset added to the position in the index list).
    if flag == 'positive':
        selected_index = positive_index
        train_templates = (
            ("origin_prompt", "{book}'s author is {author}."),
            ("qa_positive_positive_prompt", "{book}'s author is whom? {author}."),
            ("qa_positive_negative_prompt", "{book} is whose work? {author}."),
            ("qa_negative_positive_prompt", "{author} is author of what? {book}."),
            ("qa_negative_negative_prompt", "{author}'s work is what? {book}."),
        )
        test_plan = (('author_of_book', 1459), ('owner_of_work', 1459), ('book_of_author', 1459), ('work_of_author', 1459))
    else:
        selected_index = reverse_index
        train_templates = (
            ("origin_prompt", "{author}'s work is {book}."),
            ("qa_positive_positive_prompt", "{author}'s work is what? {book}."),
            ("qa_positive_negative_prompt", "{author} is author of what? {book}."),
            ("qa_negative_positive_prompt", "{book} is whose work? {author}."),
            ("qa_negative_negative_prompt", "{book}'s author is whom? {author}."),
        )
        # NOTE(review): the seed offsets differ between these test sets (42, 999, 42, 1451);
        # kept exactly as before - confirm whether that is intentional.
        test_plan = (('work_of_author', 42), ('book_of_author', 999), ('owner_of_work', 42), ('author_of_book', 1451))

    def load_json(path):
        # A context manager closes the handle as soon as parsing is done.
        with open(path) as handle:
            return json.load(handle)

    if author_to_book_dict is None:
        author_to_book_dict = load_json("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/author_to_book_dict.json")
    if book_to_author_dict is None:
        book_to_author_dict = load_json("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/book_to_author_dict.json")

    def other_answers(lookup, key, known_value):
        # Work on a copy so the caller's lookup table is never modified.
        remaining = list(lookup[key])
        remaining.remove(known_value)
        return remaining or None

    def build_completion(answer, alternatives):
        # Other valid answers come first, the sample's own answer last.
        candidates = list(alternatives or []) + [answer]
        return ','.join(" {}".format(candidate) for candidate in candidates)

    def build_test_record(book, author, alternatives, population, seed, question_name):
        template, answer_field = questions[question_name]
        random.seed(seed)
        demonstration_index = random.sample(population, demonstration_count)
        pieces = [system_message]
        for index in demonstration_index:
            demonstration = dataset[index]
            question = template.format(book=demonstration['book'], author=demonstration['author'])
            pieces.append("{} {}\n".format(question, demonstration[answer_field]))
        pieces.append(template.format(book=book, author=author))
        answer = author if answer_field == 'author' else book
        return {'prompt': ''.join(pieces), 'completion': build_completion(answer, alternatives[answer_field])}

    ar_train_dataset = []
    test_datasets = tuple([] for _ in test_plan)
    for i, cur_index in enumerate(selected_index):
        cur_sample = dataset[cur_index]
        book, author = cur_sample['book'], cur_sample['author']
        # Looked up for every sample, whatever the format, so inconsistent lookup tables fail early.
        alternatives = {
            'book': other_answers(author_to_book_dict, author, book),
            'author': other_answers(book_to_author_dict, book, author),
        }
        if data_format_index != 1:
            continue

        train_record = {key: template.format(book=book, author=author) for key, template in train_templates}
        train_record["completion"] = ""
        ar_train_dataset.append(train_record)

        # Demonstrations come from the same index list, minus the current position.
        # One copy per sample is enough because sampling does not modify it.
        population = list(selected_index)
        del population[i]
        for records, (question_name, seed_offset) in zip(test_datasets, test_plan):
            records.append(build_test_record(book, author, alternatives, population, i + seed_offset, question_name))

    return (ar_train_dataset,) + test_datasets
      
                        
# Source file holding the book/author pairs.
json_ba_dataset_path = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final/book_author_pairs.json'
# Read through a context manager so the input handle is closed once parsing finishes.
with open(json_ba_dataset_path, 'r') as json_ba_dataset_file:
    json_ba_dataset = json.load(json_ba_dataset_file)
original_index = list(range(len(json_ba_dataset)))
# Directory every generated dataset is written to.
ba_target_dataset_path = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/author_work_final'
# Fixed seed so the shuffled split below is reproducible.
random.seed(13)
random.shuffle(original_index)

# Build the paired-relation (interleaved) datasets.

# Split the shuffled positions in half: the first half is trained in the forward direction,
# the second half in the reversed one (the original note mentions about 1000 per side;
# the actual size depends on the input file).
pair_positive_index, pair_negative_index = original_index[:len(original_index) // 2], original_index[len(original_index) // 2:]
ar_positive_train_dataset, ar_positive_positive_positive_test_dataset, ar_positive_positive_negative_test_dataset, ar_positive_negative_positive_test_dataset, ar_positive_negative_negative_test_dataset = construct_reverse_dataset_pair(json_ba_dataset, pair_positive_index, pair_negative_index, flag='positive', data_format_index=1)
ar_negative_train_dataset, ar_negative_positive_positive_test_dataset, ar_negative_positive_negative_test_dataset, ar_negative_negative_positive_test_dataset, ar_negative_negative_negative_test_dataset = construct_reverse_dataset_pair(json_ba_dataset, pair_positive_index, pair_negative_index, flag='reverse', data_format_index=1)

# Interleave the two training sets so that a batch mostly contains matching forward/reversed pairs;
# the test sets of both halves are simply concatenated.
ar_train_dataset = merge_ar_dataset_one_by_one(ar_positive_train_dataset, ar_negative_train_dataset)
ar_positive_positive_test_dataset = ar_positive_positive_positive_test_dataset + ar_negative_positive_positive_test_dataset
ar_positive_negative_test_dataset = ar_positive_positive_negative_test_dataset + ar_negative_positive_negative_test_dataset
ar_negative_positive_test_dataset = ar_positive_negative_positive_test_dataset + ar_negative_negative_positive_test_dataset
ar_negative_negative_test_dataset = ar_positive_negative_negative_test_dataset + ar_negative_negative_negative_test_dataset

# Build the standard-training (non-reversed) datasets.
# Note: the data format used here must match the one used for the datasets above.

standard_positive_train_dataset, standard_positive_positive_positive_test_dataset, standard_positive_positive_negative_test_dataset, standard_positive_negative_positive_test_dataset, standard_positive_negative_negative_test_dataset, \
           standard_negative_train_dataset, standard_negative_positive_positive_test_dataset, standard_negative_positive_negative_test_dataset, standard_negative_negative_positive_test_dataset, standard_negative_negative_negative_test_dataset = construct_origin_dataset_pair(json_ba_dataset, data_format_index=1, known=False)


# Every dataset is stored as '<name>.json' inside ba_target_dataset_path.
# The list keeps the original write order.
datasets_to_save = [
    # Interleaved (paired-relation) datasets.
    ('ar_positive_positive_test_dataset', ar_positive_positive_test_dataset),
    ('ar_positive_negative_test_dataset', ar_positive_negative_test_dataset),
    ('ar_negative_positive_test_dataset', ar_negative_positive_test_dataset),
    ('ar_negative_negative_test_dataset', ar_negative_negative_test_dataset),
    ('ar_positive_positive_positive_test_dataset', ar_positive_positive_positive_test_dataset),
    ('ar_positive_positive_negative_test_dataset', ar_positive_positive_negative_test_dataset),
    ('ar_positive_negative_positive_test_dataset', ar_positive_negative_positive_test_dataset),
    ('ar_positive_negative_negative_test_dataset', ar_positive_negative_negative_test_dataset),
    ('ar_negative_positive_positive_test_dataset', ar_negative_positive_positive_test_dataset),
    ('ar_negative_positive_negative_test_dataset', ar_negative_positive_negative_test_dataset),
    ('ar_negative_negative_positive_test_dataset', ar_negative_negative_positive_test_dataset),
    ('ar_negative_negative_negative_test_dataset', ar_negative_negative_negative_test_dataset),
    ('ar_train_dataset', ar_train_dataset),
    # Standard datasets.
    ('standard_positive_positive_positive_test_dataset', standard_positive_positive_positive_test_dataset),
    ('standard_positive_positive_negative_test_dataset', standard_positive_positive_negative_test_dataset),
    ('standard_positive_negative_positive_test_dataset', standard_positive_negative_positive_test_dataset),
    ('standard_positive_negative_negative_test_dataset', standard_positive_negative_negative_test_dataset),
    ('standard_negative_positive_positive_test_dataset', standard_negative_positive_positive_test_dataset),
    ('standard_negative_positive_negative_test_dataset', standard_negative_positive_negative_test_dataset),
    ('standard_negative_negative_positive_test_dataset', standard_negative_negative_positive_test_dataset),
    ('standard_negative_negative_negative_test_dataset', standard_negative_negative_negative_test_dataset),
    ('standard_positive_train_dataset', standard_positive_train_dataset),
    ('standard_negative_train_dataset', standard_negative_train_dataset),
]

for dataset_name, dataset_rows in datasets_to_save:
    with open(os.path.join(ba_target_dataset_path, dataset_name + '.json'), 'w') as output_file:
        json.dump(dataset_rows, output_file)