import os 
import json
import random
import copy

"""
four formats with the prior experiment
1. train: A's parent is B; C's child is D. test: who is B's child?; B's child is whom?   /// who is D's parent?; D's parent is whom?
2. train: A's child is B; C's parent is D. test: who is B's parent?; B's parent is whom?  /// who is D's child?; D's child is whom?
3. train: A is the parent of B; C is the child of D. test: whose child is B?; B is whose child?; B is the child of whom? ///  whose parent is D? ; D is whose parent?;  D is the parent of whom?
4. train: A is the child of B; C is the parent of D. test: whose parent is B?; B is whose parent?; B is the parent of whom? ///  D is whose child?; whose child is D?; D is the child of whom?

"""
# Merge the two datasets used in the reversal-curse experiment.
def merge_ar_dataset_one_by_one(ar_positive_dataset, ar_negative_dataset):
    """Interleave positive and negative samples, positive first and last.

    The result is ``p0, n0, p1, n1, ..., p_last``; the positive dataset
    must therefore hold exactly one more item than the negative dataset.

    Args:
        ar_positive_dataset: indexable sequence of positive samples.
        ar_negative_dataset: indexable sequence of negative samples.

    Returns:
        A new list with the samples of both inputs in alternating order.

    Raises:
        AssertionError: if the length relation above does not hold.
    """
    assert len(ar_positive_dataset) == len(ar_negative_dataset) + 1
    last_position = len(ar_positive_dataset) - 1
    interleaved = []
    # Every positive sample except the final one is followed by its
    # negative partner.
    for position in range(last_position):
        interleaved.extend(
            (ar_positive_dataset[position], ar_negative_dataset[position])
        )
    # The final positive sample has no negative partner.
    interleaved.append(ar_positive_dataset[last_position])
    return interleaved

# Build the positive/negative sample pairs for the original dataset.
# Default locations of the celebrity relation tables; used only when the
# caller does not pass pre-loaded mappings.
_CHILD_TO_PARENT_JSON = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/child_to_parent_dict.json"
_PARENT_TO_CHILD_JSON = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_to_child_dict.json"

# System message that opens every few-shot test prompt.
_ORIGIN_SYSTEM_PROMPT = 'You are an expert when it comes to celebrities from various fields, such as actors, singers, and producers, and their family relations. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

# Number of in-context demonstrations placed before each test question.
_DEMONSTRATIONS_PER_PROMPT = 5

# Format-1 question kinds:
#   kind -> (question template, sample field holding the answer,
#            whether the gold completion also lists the other known answers)
_FORMAT1_QUESTIONS = {
    "parent_of_child": ("Who is {child}'s {parent_type}?", "parent", False),
    "whose_child": ("Whose child is {child}?", "parent", True),
    "whose_parent": ("Whose {parent_type} is {parent}?", "child", True),
    "child_of_parent": ("Who is {parent}'s child?", "child", True),
}

# (question kind, seed offset) in the order of the returned test lists.
# The negative plan is the positive plan reversed, so each negative test
# list is built from the same question and the same seed as its positive twin.
_FORMAT1_POSITIVE_PLAN = (
    ("parent_of_child", 42),
    ("whose_child", 42),
    ("whose_parent", 42),
    ("child_of_parent", 999),
)
_FORMAT1_NEGATIVE_PLAN = (
    ("child_of_parent", 999),
    ("whose_parent", 42),
    ("whose_child", 42),
    ("parent_of_child", 42),
)


def _load_relation_table(path):
    """Read one JSON relation table, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _other_relatives(relation_table, person, known_relative):
    """Return the relatives listed for `person`, minus `known_relative`.

    Works on a copy so the caller's table is never mutated.  Raises
    KeyError / ValueError when the table does not contain the pair.
    """
    relatives = list(relation_table[person])
    relatives.remove(known_relative)
    return relatives


def _sample_demonstration_indices(sample_count, query_index, seed):
    """Draw demonstration indices for one prompt, never the query itself.

    The RNG is re-seeded on every call so the draw depends only on `seed`
    and on which index is excluded.
    """
    candidates = list(range(sample_count))
    candidates.remove(query_index)
    random.seed(seed)
    return random.sample(candidates, _DEMONSTRATIONS_PER_PROMPT)


def _render_question(template, sample):
    """Fill a question template with the fields of one dataset sample."""
    return template.format(
        child=sample['child'],
        parent=sample['parent'],
        parent_type=sample['parent_type'],
    )


def _build_format1_split(dataset, train_template, test_plan,
                         child_to_parent_dict, parent_to_child_dict):
    """Build one training list plus one test list per entry of `test_plan`."""
    sample_count = len(dataset)
    train_rows = []
    test_rows = [[] for _ in test_plan]

    # Index-based access is kept on purpose: demonstrations are looked up
    # by index as well, so `dataset` only needs len() and integer indexing.
    for i in range(sample_count):
        sample = dataset[i]
        child, parent = sample['child'], sample['parent']
        other_parents = _other_relatives(child_to_parent_dict, child, parent)
        other_children = _other_relatives(parent_to_child_dict, parent, child)

        train_rows.append({
            'prompt': _render_question(train_template, sample),
            'completion': "",
        })

        for rows, (kind, seed_offset) in zip(test_rows, test_plan):
            template, answer_field, list_all_answers = _FORMAT1_QUESTIONS[kind]
            # The seed follows the sample index, so every sample gets its
            # own reproducible set of demonstrations.
            demo_indices = _sample_demonstration_indices(
                sample_count, i, i + seed_offset
            )

            prompt_parts = [_ORIGIN_SYSTEM_PROMPT]
            for demo_index in demo_indices:
                demo = dataset[demo_index]
                prompt_parts.append("Q: {}\nA: {}\n".format(
                    _render_question(template, demo), demo[answer_field]
                ))
            prompt_parts.append(
                "Q: {}\nA:".format(_render_question(template, sample))
            )

            # A person may have several parents/children; the other known
            # answers come first and the sample's own answer last.
            answers = []
            if list_all_answers:
                answers.extend(
                    other_parents if answer_field == "parent" else other_children
                )
            answers.append(sample[answer_field])

            rows.append({
                'prompt': "".join(prompt_parts),
                'completion': ','.join(" {}".format(answer) for answer in answers),
            })

    return train_rows, test_rows


def construct_origin_dataset_pair(dataset, data_format_index, known=False, *,
                                  child_to_parent_dict=None,
                                  parent_to_child_dict=None):
    """Build the forward/reversed train sets and their few-shot test sets.

    For format 1 every sample yields two training statements
    ("<child>'s <parent_type> is <parent>." and "<parent>'s child is
    <child>.") and, for each of them, four test prompts.  Each test prompt
    is the system message, five demonstrations drawn from the other
    samples, and the question about the current sample.

    Args:
        dataset: indexable sequence of dicts with the keys 'child',
            'parent' and 'parent_type'.  At least six samples are needed,
            because five demonstrations are drawn from the remaining ones.
        data_format_index: sentence format to generate.  Only format 1 is
            implemented.
        known: unused; kept for backward compatibility.
        child_to_parent_dict: optional mapping child -> list of parents.
            Loaded from the default JSON path when omitted.
        parent_to_child_dict: optional mapping parent -> list of children.
            Loaded from the default JSON path when omitted.

    Returns:
        A 10-tuple of lists of ``{'prompt', 'completion'}`` dicts:
        positive train set, its four test sets, negative train set, its
        four test sets.  All ten lists are empty when `data_format_index`
        is not one of 1-4.

    Raises:
        NotImplementedError: for formats 2-4.  The previous code for them
            appended to lists that were never defined (NameError) and that
            are not part of the return value.
        ValueError: if fewer than six samples are given, or a sample's
            parent/child is missing from the relation lists.
        KeyError: if a sample's child/parent is missing from the tables.
    """
    if data_format_index in (2, 3, 4):
        raise NotImplementedError(
            "data_format_index {} is not implemented; only format 1 "
            "produces datasets".format(data_format_index)
        )

    if child_to_parent_dict is None:
        child_to_parent_dict = _load_relation_table(_CHILD_TO_PARENT_JSON)
    if parent_to_child_dict is None:
        parent_to_child_dict = _load_relation_table(_PARENT_TO_CHILD_JSON)

    positive_train, positive_tests = [], [[] for _ in _FORMAT1_POSITIVE_PLAN]
    negative_train, negative_tests = [], [[] for _ in _FORMAT1_NEGATIVE_PLAN]

    if data_format_index == 1:
        # Forward statements are built first, then the reversed ones.
        positive_train, positive_tests = _build_format1_split(
            dataset, "{child}'s {parent_type} is {parent}.",
            _FORMAT1_POSITIVE_PLAN,
            child_to_parent_dict, parent_to_child_dict,
        )
        # The last test set used to be seeded with a stale inner-loop
        # variable (a constant) instead of the sample index; the plan now
        # seeds it per sample like every other set.
        negative_train, negative_tests = _build_format1_split(
            dataset, "{parent}'s child is {child}.",
            _FORMAT1_NEGATIVE_PLAN,
            child_to_parent_dict, parent_to_child_dict,
        )

    return (positive_train, *positive_tests, negative_train, *negative_tests)


# 构建逆转诅咒实验中交替的样本对
def construct_reverse_dataset_pair(dataset, positive_index, reverse_index, flag, data_format_index):

    """
    dataset: 整体的dataset格式
    positive_index: 正向训练样本的索引
    reverse_index: 逆转训练样本的索引
    flag: 当前生成正向样本还是负向样本
    data_format_index: 当前用的数据格式是什么
    random_demonstration: 是否不同的样本采用不同的索引
    """
    # 给定系统的message
    system = 'You are an expert when it comes to celebrities from various fields, such as actors, singers, and producers, and their family relations. You answer questions concisely, with only the specific answer or "I don\'t know"\n'
    # 一般来说感觉还是用固定的demonstration就可以了，可以尝试所有的测试样本都同时换不同的demonstration
    child_to_parent_dict = json.load(open("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/child_to_parent_dict.json"))
    parent_to_child_dict = json.load(open("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_to_child_dict.json"))

    # 需要分两种情况看测试集的情况(一种是挑选它的正序样本，一种是挑选它的逆序样本)
            
    # 先看看输入是否合理
    if flag not in ['positive', 'reverse'] or data_format_index not in [1, 2, 3, 4]:
        print('something is wrong')
    else:
        # 制作positive的数据集
        if flag == 'positive':
            # 对应train和test的样本
            ar_train_dataset = []
            ar_positive_positive_test_dataset = []
            ar_positive_negative_test_dataset = []
            ar_negative_positive_test_dataset = []
            ar_negative_negative_test_dataset = []

            for i in range(len(positive_index)):
                # 取索引的样本
                cur_index = positive_index[i]
                cur_sample = dataset[cur_index]
                child, parent, parent_type = cur_sample['child'], cur_sample['parent'], cur_sample['parent_type']

                another_parent_list = child_to_parent_dict[child].copy()
                another_parent_list.remove(parent)
                another_parent = another_parent_list if len(another_parent_list) != 0 else None
                another_child_list = parent_to_child_dict[parent].copy()
                another_child_list.remove(child)
                another_child = another_child_list if len(another_child_list) != 0 else None

                if data_format_index == 1:
                # 1. train: A's parent is B; C's child is D. test: who is B's child?; B's child is whom?   /// who is D's parent?; D's parent is whom?
                    positive_original_prompt = "{}'s {} is {}.".format(child, parent_type, parent)
                    positive_qa_positive_positve_prompt = "{}'s {} is whom? {}.".format(child, parent_type, parent)
                    positive_qa_positive_negative_prompt = "{} is whose child? {}.".format(child, parent)
                    positive_qa_negative_positive_prompt = "{} is whose {}? {}.".format(parent, parent_type, child)
                    positive_qa_negative_negative_prompt = "{}'s child is whom? {}.".format(parent, child)
                    completion = ""
                    ar_train_dataset.append({"origin_prompt": positive_original_prompt,
                                             "qa_positive_positive_prompt": positive_qa_positive_positve_prompt,
                                             "qa_positive_negative_prompt": positive_qa_positive_negative_prompt,
                                             "qa_negative_positive_prompt": positive_qa_negative_positive_prompt,
                                             "qa_negative_negative_prompt": positive_qa_negative_negative_prompt,
                                             "completion": completion})
                    
                    # 开始制作对应的test样本集合
                    # 确定随机种子，并取正向的样本进行测试
                    copy_positive_index = copy.deepcopy(positive_index)
                    copy_positive_index.pop(i)
                    random.seed(i + 1921)
                    demonstration_index = random.sample(copy_positive_index, 5)

                    test_prompt = system
                    # 这里已经直接改过了
                    for j in range(5):
                        demonstration = dataset[demonstration_index[j]]
                        # 正序 + 正关 (Who is the parent of A? B)
                        example = "Q: Who is {}'s {}?\nA: {}\n".format(demonstration['child'], demonstration['parent_type'], demonstration['parent'])
                        test_prompt += example

                    # 正序 + 正关 (Who is the parent of A? B)
                    test_prompt += "Q: Who is {}'s {}?\nA:".format(child, parent_type)
                    test_completion = " {}".format(parent)
                    ar_positive_positive_test_dataset.append({'prompt': test_prompt, 'completion': test_completion})


                    # 确定随机种子，并取正向的样本进行测试
                    copy_positive_index = copy.deepcopy(positive_index)
                    copy_positive_index.pop(i)
                    random.seed(i + 1202)
                    demonstration_index = random.sample(copy_positive_index, 5)

                    test_prompt = system
                    for j in range(5):
                        demonstration = dataset[demonstration_index[j]]
                        # 正序 + 逆关
                        example = "Q: Whose child is {}?\nA: {}\n".format(demonstration['child'], demonstration['parent'])
                        test_prompt += example

                    # 正序 + 逆关
                    test_prompt += "Q: Whose child is {}?\nA:".format(child)
                    test_completion = [" {}".format(parent), ] if another_parent is None else [" {}".format(sample_another_parent) for sample_another_parent in another_parent] + [" {}".format(parent)]
                    test_completion = ','.join(test_completion)
                    # completion = " {}".format(parent)
                    ar_positive_negative_test_dataset.append({'prompt': test_prompt, 'completion': test_completion})


                    # 确定随机种子，并取正向的样本进行测试
                    copy_positive_index = copy.deepcopy(positive_index)
                    copy_positive_index.pop(i)
                    random.seed(i + 2120)
                    demonstration_index = random.sample(copy_positive_index, 5)
                    test_prompt = system

                    for j in range(5):
                        demonstration = dataset[demonstration_index[j]]
                        # 逆序 + 正关
                        example = "Q: Whose {} is {}?\nA: {}\n".format(demonstration['parent_type'], demonstration['parent'], demonstration['child'])
                        test_prompt += example
                    
                    # 逆序 + 正关
                    test_prompt += "Q: Whose {} is {}?\nA:".format(parent_type, parent)
                    test_completion = [" {}".format(child), ] if another_child is None else [" {}".format(sample_another_child) for sample_another_child in another_child] + [" {}".format(child)]
                    test_completion = ','.join(test_completion)
                    # completion = " {}".format(child)
                    ar_negative_positive_test_dataset.append({'prompt': test_prompt, 'completion': test_completion})
                    
                    # 确定随机种子，并取正向的样本进行测试
                    copy_positive_index = copy.deepcopy(positive_index)
                    copy_positive_index.pop(i)
                    random.seed(i + 1006)
                    demonstration_index = random.sample(copy_positive_index, 5)
                    test_prompt = system

                    for j in range(5):
                        demonstration = dataset[demonstration_index[j]]
                        # 逆序 + 逆关
                        example = "Q: Who is {}'s child?\nA: {}\n".format(demonstration['parent'], demonstration['child'])
                        test_prompt += example
                    
                    test_prompt += "Q: Who is {}'s child?\nA:".format(parent)
                    # completion = " {}".format(child)
                    test_completion = [" {}".format(child), ] if another_child is None else [" {}".format(sample_another_child) for sample_another_child in another_child] + [" {}".format(child)]
                    test_completion = ','.join(test_completion)
                    ar_negative_negative_test_dataset.append({'prompt': test_prompt, 'completion': test_completion})

        elif flag == 'reverse':
            # 对应train和test的样本
            ar_train_dataset = []
            ar_positive_positive_test_dataset = []
            ar_positive_negative_test_dataset = []
            ar_negative_positive_test_dataset = []
            ar_negative_negative_test_dataset = []

            for i in range(len(reverse_index)):
                # 取索引度样本
                cur_index = reverse_index[i]
                cur_sample = dataset[cur_index]
                child, parent_type, parent = cur_sample['child'], cur_sample['parent_type'], cur_sample['parent']

                another_parent_list = child_to_parent_dict[child].copy()
                another_parent_list.remove(parent)
                another_parent = another_parent_list if len(another_parent_list) != 0 else None
                another_child_list = parent_to_child_dict[parent].copy()
                another_child_list.remove(child)
                another_child = another_child_list if len(another_child_list) != 0 else None


                if data_format_index == 1:
                # 1. train: A's parent is B; C's child is D. test: who is B's child?; B's child is whom?   /// who is D's parent?; D's parent is whom?
                    
                    negative_origin_prompt = "{}'s child is {}.".format(parent, child)
                    negative_qa_positive_positive_prompt = "{}'s {} is whom? {}.".format(child, parent_type, parent)
                    negative_qa_positive_negative_prompt = "{} is whose {}? {}.".format(parent, parent_type, child)
                    negative_qa_negative_positive_prompt = "{} is whose child is? {}.".format(child, parent)
                    negative_qa_negative_negative_prompt = "{}'s {} is whom? {}.".format(child, parent_type, parent)
                    completion = ""
                    ar_train_dataset.append({"origin_prompt": negative_origin_prompt,
                                            "qa_positive_positive_prompt": negative_qa_positive_positive_prompt,
                                            "qa_positive_negative_prompt": negative_qa_positive_negative_prompt,
                                            "qa_negative_positive_prompt": negative_qa_negative_positive_prompt,
                                            "qa_negative_negative_prompt": negative_qa_negative_negative_prompt,
                                            "completion": completion})

                    # 确定随机种子，并取负向的进行测试
                    copy_reverse_index = copy.deepcopy(reverse_index)
                    copy_reverse_index.pop(i)
                    random.seed(i + 42) 
                    demonstration_index = random.sample(copy_reverse_index, 5)

                    """
                    确定随机种子，并且正向的进行测试
                    random.seed(i)
                    demonstration_index = random.sample(positive_index, 5)
                    if demonstration_index is None:
                        # 说明每次都要重新抽样样本
                        demonstration_index = random.sample(reverse_index, 5)
                        if any([sample_index in positive_index for sample_index in demonstration_index]):
                            print('something wrong')
                            exit(0)
                    """


                    """
                    # 开始制作对应的test样本集合
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        example = "Q: Who is the child of {}?\nA: {}\n".format(test_demonstration['parent'], test_demonstration['child'])
                        test_prompt += example

                    test_prompt += "Q: Who is the child of {}?\nA:".format(parent)
                    # 针对一个父亲/母亲有多个孩子的情况
                    test_completion_list = [" {}".format(child), ] if another_child is None else [" {}".format(sample_another_child) for sample_another_child in another_child] + [" {}".format(child)]
                    test_completion = ','.join(test_completion_list)
                    # test_completion = " {}".format(child)
                    ar_positive_positive_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})
                    """
                    test_prompt = system
                    for j in range(5):
                        test_demonstration = dataset[demonstration_index[j]]
                        # 逆序 + 逆关
                        example = "Q: Who is {}'s child?\nA: {}\n".format(test_demonstration['parent'], test_demonstration['child'])
                        test_prompt += example

                    test_prompt += "Q: Who is {}'s child?\nA:".format(parent)
                    # 针对一个父亲或者母亲有多个孩子的情况
                    test_completion_list = [" {}".format(child), ] if another_child is None else [" {}".format(sample_another_child) for sample_another_child in another_child] + [" {}".format(child)]
                    test_completion = ','.join(test_completion_list)
                    # test_completion = " {}".format(child)
                    ar_positive_positive_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})

                    copy_reverse_index = copy.deepcopy(reverse_index)
                    copy_reverse_index.pop(i)
                    random.seed(i + 999) 
                    demonstration_index = random.sample(copy_reverse_index, 5)

                    test_prompt = system
                    for j in range(5):
                        test_demonstration = dataset[demonstration_index[j]]
                        example = "Q: Whose {} is {}?\nA: {}\n".format(test_demonstration['parent_type'], test_demonstration['parent'], test_demonstration['child'])
                        test_prompt += example

                    test_prompt += "Q: Whose {} is {}?\nA:".format(parent_type, parent)
                    test_completion_list = [" {}".format(child), ] if another_child is None else [" {}".format(sample_another_child) for sample_another_child in another_child] + [" {}".format(child)]
                    test_completion = ','.join(test_completion_list)
                    # test_completion = " {}".format(child)
                    ar_positive_negative_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})


                    copy_reverse_index = copy.deepcopy(reverse_index)
                    copy_reverse_index.pop(i)
                    random.seed(i + 42) 
                    demonstration_index = random.sample(copy_reverse_index, 5)
                    test_prompt = system
                    for j in range(5):
                        test_demonstration = dataset[demonstration_index[j]]
                        example = "Q: Whose child is {}?\nA: {}\n".format(test_demonstration['child'], test_demonstration['parent'])
                        test_prompt += example

                    test_prompt += "Q: Whose child is {}?\nA:".format(child)
                    # 针对一个孩子有多个父母
                    test_completion_list = [" {}".format(parent), ] if another_parent is None else [" {}".format(sample_another_parent) for sample_another_parent in another_parent] + [" {}".format(parent)]
                    test_completion = ','.join(test_completion_list)
                    # test_completion = " {}".format(parent)
                    ar_negative_positive_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})

                    """
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        example = "Q: Who is the {} of {}?\nA: {}\n".format(test_demonstration['parent_type'], test_demonstration['child'], test_demonstration['parent'])
                        test_prompt += example

                    test_prompt += "Q: Who is the {} of {}?\nA:".format(parent_type, child)
                    test_completion = " {}".format(parent)
                    ar_negative_negative_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})
                    """

                    copy_reverse_index = copy.deepcopy(reverse_index)
                    copy_reverse_index.pop(i)
                    random.seed(i + 1451) 
                    demonstration_index = random.sample(copy_reverse_index, 5)
                    test_prompt = system
                    for j in range(5):
                        test_demonstration = dataset[demonstration_index[j]]
                        example = "Q: Who is {}'s {}?\nA: {}\n".format(test_demonstration['child'], test_demonstration['parent_type'], test_demonstration['parent'])
                        test_prompt += example

                    test_prompt += "Q: Who is {}'s {}?\nA:".format(child, parent_type)
                    test_completion = " {}".format(parent)
                    ar_negative_negative_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})


                elif data_format_index == 2:
                # 2. train: A's child is B; C's parent is D. test: who is B's parent?; B's parent is whom?  /// who is D's child?; D's child is whom?    
                    prompt = "{}'s {} is {}".format(child, parent_type, parent)
                    completion = ""
                    reverse_negative_train_dataset.append({'prompt': prompt, 'completion': completion})
                    if demonstration_index is None:
                        # 说明每次都要重新抽样样本
                        demonstration_index = random.sample(reverse_index, 5)
                        if any([sample_index in positive_index for sample_index in demonstration_index]):
                            print('something wrong')
                            exit(0)

                    # 开始制作对应的test样本集合
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        # 叠加样本
                        example = "Q: {}'s child is whom?\nA: {}\n".format(test_demonstration['parent'], test_completion['child'])
                        test_prompt += example
                    test_prompt += "Q: {}'s child is whom?\nA:".format(parent)
                    test_completion = " {}".format(child)
                    reverse_negative_test_dataset.append({'prompt': test_prompt, 'completion':test_completion})
                
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        example = "Q: {}'s {} is whom?\nA: {}\n".format(test_demonstration['child'], test_demonstration['parent_type'], test_demonstration['parent'])
                        test_prompt += example
                    test_prompt += "Q: {}'s {} is whom?\nA:".format(child, parent_type)
                    test_completion = " {}".format(parent)
                    reverse_negative_origin_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})


                elif data_format_index == 3:
                # 3. train: A is the parent of B; C is the child of D. test: whose child is B?; B is whose child?; B is the child of whom? ///  whose parent is D? ; D is whose parent?;  D is the parent of whom?
                    prompt = "{} is the child of {}".format(child, parent)
                    completion = ""
                    reverse_negative_train_dataset.append({'prompt': prompt, 'completion': completion})
                    if demonstration_index is None:
                        # 说明每次都要重新抽样样本
                        demonstration_index = random.sample(reverse_index, 5)
                        if any([sample_index in positive_index for sample_index in demonstration_index]):
                            print('something wrong')
                            exit(0)

                    # 开始制作对应的test样本集合
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        # 叠加样本
                        example = "Q: {} is the {} of whom?\nA: {}\n".format(test_demonstration['parent'], test_demonstration['parent_type'], test_demonstration['child'])
                        test_prompt += example
                    test_prompt += "Q: {} is the {} of whom?\nA:".format(parent, parent_type)
                    test_completion = " {}".format(child)
                    reverse_negative_test_dataset.append({'prompt': test_prompt, 'completion':test_completion})

                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        example = "Q: {} is the child of whom?\nA: {}\n".format(test_demonstration['child'], test_demonstration['parent'])
                        test_prompt += example
                    test_prompt += "Q: {} is the child of whom?\nA:".format(child)
                    test_completion = " {}".format(parent)
                    reverse_negative_origin_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})


                elif data_format_index == 4:
                # 4. train: A is the child of B; C is the parent f D. test: whose parent is B; B is whose parent?; B is the parent of whom? ///  D is whose child? ; whose child is D?; D is the child of whom?
                    prompt = "{} is the {} of {}".format(parent, parent_type, child)
                    completion = ""
                    reverse_negative_train_dataset.append({'prompt': prompt, 'completion': completion})
                    if demonstration_index is None:
                        # 说明每次都要重新抽样样本
                        demonstration_index = random.sample(reverse_index, 5)
                        if any([sample_index in positive_index for sample_index in demonstration_index]):
                            print('something wrong')
                            exit(0)

                    # 开始制作对应的test样本集合
                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        # 叠加样本
                        example = "Q: {} is the child of whom?\nA: {}\n".format(test_demonstration['child'], test_demonstration['parent'])
                        test_prompt += example
                    test_prompt += "Q: {} is the child of whom?\nA:".format(child)
                    test_completion = " {}".format(parent)
                    reverse_negative_test_dataset.append({'prompt': test_prompt, 'completion':test_completion})

                    test_prompt = system
                    for i in range(5):
                        test_demonstration = dataset[demonstration_index[i]]
                        example = "Q: {} is the {} of whom?\nA: {}\n".format(test_demonstration['parent'], test_demonstration['parent_type'], test_demonstration['child'])
                        test_prompt += example
                    test_prompt += "Q: {} is the {} of whom?\nA:".format(parent, parent_type)
                    test_completion = " {}".format(child)
                    reverse_negative_origin_test_dataset.append({'prompt':test_prompt, 'completion':test_completion})
            
        return ar_train_dataset, ar_positive_positive_test_dataset, ar_positive_negative_test_dataset, ar_negative_positive_test_dataset, ar_negative_negative_test_dataset
      
                        
# Input JSON file; the loaded object is indexed by integer position below.
json_pc_dataset_path = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/celebrity_relations/parent_child_pairs.json'
# Read through a context manager so the file handle is closed as soon as
# parsing finishes (the previous bare open() call never closed it).
with open(json_pc_dataset_path, 'r') as pc_file:
    json_pc_dataset = json.load(pc_file)
# One index per record of the loaded dataset.
original_index = list(range(len(json_pc_dataset)))
# Directory into which the generated dataset files are written.
pc_category_dataset_path = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/parent_child_final'
# Fixed seed so that the shuffled index order is reproducible between runs.
random.seed(13)
random.shuffle(original_index)

# Build the pair-relation ("ar") datasets.

# Split the shuffled indices: the positive slice takes the first
# len(original_index) // 2 + 1 entries, the negative slice takes the rest.
split_point = len(original_index) // 2 + 1
pair_positive_index = original_index[:split_point]
pair_negative_index = original_index[split_point:]

# Train/test sets produced with flag='positive'.
(
    ar_positive_train_dataset,
    ar_positive_positive_positive_test_dataset,
    ar_positive_positive_negative_test_dataset,
    ar_positive_negative_positive_test_dataset,
    ar_positive_negative_negative_test_dataset,
) = construct_reverse_dataset_pair(
    json_pc_dataset,
    pair_positive_index,
    pair_negative_index,
    flag='positive',
    data_format_index=1,
)

# Train/test sets produced with flag='reverse'.
(
    ar_negative_train_dataset,
    ar_negative_positive_positive_test_dataset,
    ar_negative_positive_negative_test_dataset,
    ar_negative_negative_positive_test_dataset,
    ar_negative_negative_negative_test_dataset,
) = construct_reverse_dataset_pair(
    json_pc_dataset,
    pair_positive_index,
    pair_negative_index,
    flag='reverse',
    data_format_index=1,
)

# Interleave the two training sets one by one so that a batch mostly holds
# matching pairs. merge_ar_dataset_one_by_one asserts that the positive list
# is exactly one element longer than the negative list.
ar_train_dataset = merge_ar_dataset_one_by_one(
    ar_positive_train_dataset,
    ar_negative_train_dataset,
)

# Each combined test set is the positive-flag part followed by the
# reverse-flag part.
ar_positive_positive_test_dataset = (
    ar_positive_positive_positive_test_dataset
    + ar_negative_positive_positive_test_dataset
)
ar_positive_negative_test_dataset = (
    ar_positive_positive_negative_test_dataset
    + ar_negative_positive_negative_test_dataset
)
ar_negative_positive_test_dataset = (
    ar_positive_negative_positive_test_dataset
    + ar_negative_negative_positive_test_dataset
)
ar_negative_negative_test_dataset = (
    ar_positive_negative_negative_test_dataset
    + ar_negative_negative_negative_test_dataset
)

# Build the standard-training datasets.

# Build the non-reversed datasets; data_format_index must stay consistent
# with the value used for the pair-relation datasets above.
(
    standard_positive_train_dataset,
    standard_positive_positive_positive_test_dataset,
    standard_positive_positive_negative_test_dataset,
    standard_positive_negative_positive_test_dataset,
    standard_positive_negative_negative_test_dataset,
    standard_negative_train_dataset,
    standard_negative_positive_positive_test_dataset,
    standard_negative_positive_negative_test_dataset,
    standard_negative_negative_positive_test_dataset,
    standard_negative_negative_negative_test_dataset,
) = construct_origin_dataset_pair(json_pc_dataset, data_format_index=1, known=False)



# Persist every generated dataset as "<name>.json" inside
# pc_category_dataset_path. The table keeps the write order explicit and
# replaces the long run of copy-pasted open()/json.dump() stanzas, so a file
# name can no longer drift away from the dataset written into it.
datasets_to_save = [
    # Combined pair-relation test sets.
    ('ar_positive_positive_test_dataset', ar_positive_positive_test_dataset),
    ('ar_positive_negative_test_dataset', ar_positive_negative_test_dataset),
    ('ar_negative_positive_test_dataset', ar_negative_positive_test_dataset),
    ('ar_negative_negative_test_dataset', ar_negative_negative_test_dataset),
    # Pair-relation test sets built with flag='positive'.
    ('ar_positive_positive_positive_test_dataset', ar_positive_positive_positive_test_dataset),
    ('ar_positive_positive_negative_test_dataset', ar_positive_positive_negative_test_dataset),
    ('ar_positive_negative_positive_test_dataset', ar_positive_negative_positive_test_dataset),
    ('ar_positive_negative_negative_test_dataset', ar_positive_negative_negative_test_dataset),
    # Pair-relation test sets built with flag='reverse'.
    ('ar_negative_positive_positive_test_dataset', ar_negative_positive_positive_test_dataset),
    ('ar_negative_positive_negative_test_dataset', ar_negative_positive_negative_test_dataset),
    ('ar_negative_negative_positive_test_dataset', ar_negative_negative_positive_test_dataset),
    ('ar_negative_negative_negative_test_dataset', ar_negative_negative_negative_test_dataset),
    # Interleaved pair-relation training set.
    ('ar_train_dataset', ar_train_dataset),
    # Standard datasets.
    ('standard_positive_positive_positive_test_dataset', standard_positive_positive_positive_test_dataset),
    ('standard_positive_positive_negative_test_dataset', standard_positive_positive_negative_test_dataset),
    ('standard_positive_negative_positive_test_dataset', standard_positive_negative_positive_test_dataset),
    ('standard_positive_negative_negative_test_dataset', standard_positive_negative_negative_test_dataset),
    ('standard_negative_positive_positive_test_dataset', standard_negative_positive_positive_test_dataset),
    ('standard_negative_positive_negative_test_dataset', standard_negative_positive_negative_test_dataset),
    ('standard_negative_negative_positive_test_dataset', standard_negative_negative_positive_test_dataset),
    ('standard_negative_negative_negative_test_dataset', standard_negative_negative_negative_test_dataset),
    ('standard_positive_train_dataset', standard_positive_train_dataset),
    ('standard_negative_train_dataset', standard_negative_train_dataset),
]

for dataset_name, dataset_records in datasets_to_save:
    output_path = os.path.join(pc_category_dataset_path, dataset_name + '.json')
    # The context manager closes each output file before the next is opened.
    with open(output_path, 'w') as output_file:
        json.dump(dataset_records, output_file)

