import os 
import json
import random
import copy

"""
    # 给定他们的模版
    # The book name's author is person
    # 正序正关的两个模版
    # Who is the book name's author? person 和 The book name's author is whom? person
    # 正序逆关的两个模版
    # Who has written the book name? person 和 The book is written by whom? person
    # 逆序正关的两个模版
    # Which book's author is person? name 和 person is the author of which book? name
    # 逆序逆关的两个模版
    # Which book is written by person? name 和 person has written which book? name

    # person has written the book name
    # 正序正关的两个模版
    # Which book is written by person? name 和 person has written which book? name
    # 正序逆关的两个模版
    # Which book's author is person? name 和 person is the author of which book? name
    # 逆序正关的两个模版
    # Who has written the book name? person 和 The book is written by whom? person
    # 逆序逆关的两个模版
    # Who is the book name's author? person 和 The book name's author is whom? person

"""

# 合并逆转诅咒实验中的两个数据集
def merge_ar_dataset_one_by_one(ar_positive_dataset, ar_negative_dataset):
    """Interleave two datasets as pos[0], neg[0], pos[1], neg[1], ..., pos[-1].

    The positive dataset must hold exactly one more item than the negative one,
    so the merged list both starts and ends with a positive item.
    """
    assert len(ar_positive_dataset) == len(ar_negative_dataset) + 1
    merged = []
    # zip stops at the shorter (negative) dataset, leaving the final positive item.
    for positive_item, negative_item in zip(ar_positive_dataset, ar_negative_dataset):
        merged.extend((positive_item, negative_item))
    merged.append(ar_positive_dataset[-1])
    return merged

# 构建原始数据集中的正负对样本
def construct_origin_dataset_pair(dataset, data_format_index, known=False,
                                  company_to_ceo_path=None, ceo_to_company_path=None):
    """Build the standard (non-alternating) company/CEO train set and 5-shot test sets.

    Every sample contributes, in each of two training directions, one training
    statement and four few-shot test prompts (one per question template):

    * positive direction: train on "<company>'s CEO is <ceo>.";
    * negative direction: train on "<ceo>'s company is <company>.".

    A test prompt is ``system_message`` followed by five demonstrations drawn
    from the other samples (never the sample under test) and the unanswered
    query. ``random.seed`` is reset before every draw, so results are
    reproducible but the global ``random`` state is overwritten.

    Args:
        dataset: indexable collection of dicts with at least 'Company' and 'Ceo'.
        data_format_index: prompt format; only ``1`` is implemented, any other
            value returns ten empty lists.
        known: unused; kept for backward compatibility.
        company_to_ceo_path: JSON file mapping company -> list of CEOs. Defaults
            to ``company_to_ceo.json`` in the company_ceo_final data directory.
        ceo_to_company_path: JSON file mapping CEO -> list of companies. Defaults
            to ``ceo_to_company.json`` in the same directory.

    Returns:
        A 10-tuple of lists: positive train set, its test sets (PPP, PPN, PNP,
        PNN), negative train set, its test sets (NPP, NPN, NNP, NNN). Every item
        is ``{'prompt': str, 'completion': str}``. Train completions are empty.
        A test completion lists every accepted answer, with the other answers
        from the mapping files first and the sample's own answer last. Each
        answer is prefixed by a space, and the answers are joined with ','.

    Raises:
        KeyError / ValueError: a sample's company or CEO pair is missing from the
            mapping files.
        ValueError: (format 1) fewer than 6 samples, so 5 demonstrations cannot
            be drawn.
    """
    # Best-performing system prompt so far.
    system_message = 'You are an expert when it comes to companies from various fields, such as banking and financial services, technology, and oil and gas, and their chief executive officer (CEO) relationships. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

    data_dir = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final"
    if company_to_ceo_path is None:
        company_to_ceo_path = os.path.join(data_dir, "company_to_ceo.json")
    if ceo_to_company_path is None:
        ceo_to_company_path = os.path.join(data_dir, "ceo_to_company.json")
    # Close the handles instead of leaking them via json.load(open(...)).
    with open(company_to_ceo_path, encoding="utf-8") as f:
        company_to_ceo_dict = json.load(f)
    with open(ceo_to_company_path, encoding="utf-8") as f:
        ceo_to_company_dict = json.load(f)

    # (question template, dataset field that holds the answer)
    who_is_ceo = ("Who is {company}'s CEO?", 'Ceo')
    whose_company = ("Whose company is {company}?", 'Ceo')
    ceo_of_what = ("What is {ceo} CEO of?", 'Company')
    whats_company = ("What is {ceo}'s company?", 'Company')

    n_samples = len(dataset)

    def few_shot_prompt(pool, seed, question, answer_key, company, ceo):
        # System message + 5 demonstrations sampled from `pool` + the open query.
        random.seed(seed)
        demo_ids = random.sample(pool, 5)
        parts = [system_message]
        for demo_id in demo_ids:
            demo = dataset[demo_id]
            parts.append("Q: {}\nA: {}\n".format(
                question.format(company=demo['Company'], ceo=demo['Ceo']), demo[answer_key]))
        parts.append("Q: {}\nA:".format(question.format(company=company, ceo=ceo)))
        return ''.join(parts)

    def build_direction(train_template, test_specs):
        # One training statement plus one test item per (question, seed offset) spec.
        train_set = []
        test_sets = [[] for _ in test_specs]
        for i in range(n_samples):
            cur_sample = dataset[i]
            company, ceo = cur_sample['Company'], cur_sample['Ceo']
            # Other accepted answers: the mapping lists minus this sample's own pair.
            other_companies = ceo_to_company_dict[ceo].copy()
            other_companies.remove(company)
            other_ceos = company_to_ceo_dict[company].copy()
            other_ceos.remove(ceo)
            others = {'Company': other_companies, 'Ceo': other_ceos}
            if data_format_index != 1:
                continue
            train_set.append({'prompt': train_template.format(company=company, ceo=ceo), 'completion': ''})
            # Demonstrations come from every other sample, never the one under test.
            pool = [j for j in range(n_samples) if j != i]
            for test_set, ((question, answer_key), seed_offset) in zip(test_sets, test_specs):
                answers = others[answer_key] + [cur_sample[answer_key]]
                test_set.append({
                    'prompt': few_shot_prompt(pool, i + seed_offset, question, answer_key, company, ceo),
                    'completion': ','.join(' {}'.format(answer) for answer in answers),
                })
        return train_set, test_sets

    # Seed offsets reproduce the original experiment's demonstration draws exactly.
    positive_train, positive_tests = build_direction(
        "{company}'s CEO is {ceo}.",
        [(who_is_ceo, 42), (whose_company, 42), (ceo_of_what, 42), (whats_company, 999)])
    negative_train, negative_tests = build_direction(
        "{ceo}'s company is {company}.",
        [(whats_company, 0), (ceo_of_what, 0), (whose_company, 0), (who_is_ceo, 0)])
    return (positive_train, *positive_tests, negative_train, *negative_tests)


# 构建逆转诅咒实验中交替的样本对
def construct_reverse_dataset_pair(dataset, positive_index, reverse_index, flag, data_format_index,
                                   company_to_ceo_path=None, ceo_to_company_path=None):
    """Build one half of the alternating reversal-curse train/test sets.

    Args:
        dataset: indexable collection of dicts with at least 'Company' and 'Ceo'.
        positive_index: dataset indices trained in the forward direction
            ("<company>'s CEO is <ceo>.").
        reverse_index: dataset indices trained in the reversed direction
            ("<ceo>'s company is <company>.").
        flag: 'positive' builds the forward half from ``positive_index``;
            'reverse' builds the reversed half from ``reverse_index``.
        data_format_index: prompt format; 1-4 are accepted, but only 1 is
            implemented, so 2-4 return empty lists.
        company_to_ceo_path: JSON file mapping company -> list of CEOs. Defaults
            to ``company_to_ceo.json`` in the company_ceo_final data directory.
        ceo_to_company_path: JSON file mapping CEO -> list of companies. Defaults
            to ``ceo_to_company.json`` in the same directory.

    Returns:
        ``(train, pp_test, pn_test, np_test, nn_test)``.
        Each train item holds:
            * the training statement (``'origin_prompt'``);
            * four QA-style restatements (``'qa_*_prompt'``);
            * an empty ``'completion'``.
        Each test item is ``{'prompt', 'completion'}``:
            * the prompt is the system message, then five demonstrations drawn
              from the same half (never the sample itself), then the query;
            * the completion lists every accepted answer, space-prefixed and
              ','-joined, with the sample's own answer last.
        ``random.seed`` is reset before every draw, so the global ``random``
        state is overwritten.

    Raises:
        ValueError: unknown ``flag`` or ``data_format_index``; fewer than 6
            indices in the selected half (format 1).
        KeyError / ValueError: a sample's pair is missing from the mapping files.
    """
    if flag not in ('positive', 'reverse') or data_format_index not in (1, 2, 3, 4):
        # Fail loudly here rather than returning None, which callers would only
        # notice when unpacking the result.
        raise ValueError(
            "invalid arguments: flag={!r} (expected 'positive' or 'reverse'), "
            "data_format_index={!r} (expected 1, 2, 3 or 4)".format(flag, data_format_index))

    # Best-performing system prompt so far.
    system_message = 'You are an expert when it comes to companies from various fields, such as banking and financial services, technology, and oil and gas, and their chief executive officer (CEO) relationships. You answer questions concisely, with only the specific answer or "I don\'t know"\n'

    # The company/CEO mapping files live in company_ceo_final (where the
    # data-preparation code at the bottom of this file writes them), not in the
    # country/capital directory.
    data_dir = "/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final"
    if company_to_ceo_path is None:
        company_to_ceo_path = os.path.join(data_dir, "company_to_ceo.json")
    if ceo_to_company_path is None:
        ceo_to_company_path = os.path.join(data_dir, "ceo_to_company.json")
    with open(company_to_ceo_path, encoding="utf-8") as f:
        company_to_ceo_dict = json.load(f)
    with open(ceo_to_company_path, encoding="utf-8") as f:
        ceo_to_company_dict = json.load(f)

    # (question template, dataset field that holds the answer)
    who_is_ceo = ("Who is {company}'s CEO?", 'Ceo')
    whose_company = ("Whose company is {company}?", 'Ceo')
    ceo_of_what = ("What is {ceo} CEO of?", 'Company')
    whats_company = ("What is {ceo}'s company?", 'Company')

    if flag == 'positive':
        sample_index = positive_index
        train_templates = [
            ("origin_prompt", "{company}'s CEO is {ceo}."),
            ("qa_positive_positive_prompt", "{company}'s CEO is whom? {ceo}."),
            ("qa_positive_negative_prompt", "{company} is whose company? {ceo}."),
            ("qa_negative_positive_prompt", "{ceo} is CEO of what? {company}."),
            ("qa_negative_negative_prompt", "{ceo}'s company is what? {company}."),
        ]
        # (question, seed offset) per test split, in return order.
        test_specs = [(who_is_ceo, 1925), (whose_company, 1925),
                      (ceo_of_what, 1925), (whats_company, 1925)]
    else:
        sample_index = reverse_index
        train_templates = [
            ("origin_prompt", "{ceo}'s company is {company}."),
            ("qa_positive_positive_prompt", "{ceo}'s company is what? {company}."),
            ("qa_positive_negative_prompt", "{ceo} is CEO of what? {company}."),
            ("qa_negative_positive_prompt", "{company} is whose company? {ceo}."),
            ("qa_negative_negative_prompt", "{company}'s CEO is whom? {ceo}."),
        ]
        test_specs = [(whats_company, 2048), (ceo_of_what, 2123),
                      (whose_company, 2035), (who_is_ceo, 2106)]

    def few_shot_prompt(pool, seed, question, answer_key, company, ceo):
        # System message + 5 demonstrations sampled from `pool` + the open query.
        random.seed(seed)
        demo_ids = random.sample(pool, 5)
        parts = [system_message]
        for demo_id in demo_ids:
            demo = dataset[demo_id]
            parts.append("Q: {}\nA: {}\n".format(
                question.format(company=demo['Company'], ceo=demo['Ceo']), demo[answer_key]))
        parts.append("Q: {}\nA:".format(question.format(company=company, ceo=ceo)))
        return ''.join(parts)

    ar_train_dataset = []
    test_sets = [[] for _ in test_specs]
    for i, cur_index in enumerate(sample_index):
        cur_sample = dataset[cur_index]
        company, ceo = cur_sample['Company'], cur_sample['Ceo']
        # Other accepted answers: the mapping lists minus this sample's own pair.
        other_companies = ceo_to_company_dict[ceo].copy()
        other_companies.remove(company)
        other_ceos = company_to_ceo_dict[company].copy()
        other_ceos.remove(ceo)
        others = {'Company': other_companies, 'Ceo': other_ceos}
        if data_format_index != 1:
            continue

        train_item = {key: template.format(company=company, ceo=ceo) for key, template in train_templates}
        train_item['completion'] = ''
        ar_train_dataset.append(train_item)

        # Demonstrations come from the same half, excluding the sample under test.
        pool = list(sample_index)
        pool.pop(i)
        for test_set, ((question, answer_key), seed_offset) in zip(test_sets, test_specs):
            answers = others[answer_key] + [cur_sample[answer_key]]
            test_set.append({
                'prompt': few_shot_prompt(pool, i + seed_offset, question, answer_key, company, ceo),
                'completion': ','.join(' {}'.format(answer) for answer in answers),
            })

    return (ar_train_dataset, *test_sets)
      

json_cc_dataset_path = None
if json_cc_dataset_path is None:
    # Build the company/CEO pairs from the raw Forbes sheet and cache them, plus
    # both lookup tables, as JSON under company_ceo_final.
    import pandas as pd
    raw_data = pd.read_excel("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final/Forbes.xlsx")
    # Drop rows with a missing value in any column, then keep only the company
    # and the CEO column (stored in the sheet as 'Unnamed: 10').
    raw_data = raw_data.dropna(how='any', axis=0)
    # Reassign the renamed frame instead of renaming a column selection in place.
    raw_data = raw_data[['Company', 'Unnamed: 10']].rename(columns={'Unnamed: 10': 'Ceo'})
    data = raw_data.to_dict('records')

    # Row ids per company and per CEO. The loop variable is `row_id`, not `id`,
    # so the builtin id() is not left shadowed at module level.
    raw_company_to_ceo = {}
    raw_ceo_to_company = {}
    for row_id, sample in enumerate(data):
        company, ceo = sample['Company'], sample['Ceo']
        raw_company_to_ceo.setdefault(company, []).append(row_id)
        raw_ceo_to_company.setdefault(ceo, []).append(row_id)

    # Every company must occur in exactly one row, i.e. have a single CEO.
    assert all(len(row_ids) < 2 for row_ids in raw_company_to_ceo.values())

    # Disabled recipe, kept for reference: it drew a fixed subset of 1482
    # random rows plus the rows of CEOs that run two companies. The `else`
    # branch below still expects a 1500-row file.
    """
    pair_positive_index, pair_negative_index = [], []
    repeat_pair = [{key: value} for key, value in raw_ceo_to_company.items() if len(value) >= 2]
    for pair in repeat_pair:
        pair_positive_index.append(list(pair.values())[0][0])
        pair_negative_index.append(list(pair.values())[0][1])

    repeat_id = []
    for ceo, id in raw_ceo_to_company.items():
        if len(id) >= 2:
            repeat_id = repeat_id + id
    
    orgin_index = list(range(0, 1697))
    for id in repeat_id:
        orgin_index.remove(id)
    
    random.seed(13)
    data_index = random.sample(orgin_index, 1482)   # 这里面其实抽出来的都id
    final_data_index = data_index + pair_positive_index + pair_negative_index  # 这里其实抽出来的也是id
    # 这里取出来的数据最后的18位就是需要分流到不同地方的
    final_data = []
    for index in final_data_index:
        final_data.append(data[index])
    # 这里存的是实打实的数据
    """
    # data = final_data
    # assert len(data) == 1500
    with open("/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final/company_ceo_pairs.json", 'w') as file_init:
        json.dump(data, file_init)
    json_cc_dataset = data

    # Lookup tables company -> [CEOs] and CEO -> [companies]; a person heading
    # several companies gets several entries.
    company_to_ceo = {}
    ceo_to_company = {}
    for sample in json_cc_dataset:
        company, ceo = sample['Company'], sample['Ceo']
        company_to_ceo.setdefault(company, []).append(ceo)
        ceo_to_company.setdefault(ceo, []).append(company)

    with open('/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final/company_to_ceo.json', 'w') as file_init_1:
        json.dump(company_to_ceo, file_init_1)

    with open('/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final/ceo_to_company.json', 'w') as file_init_2:
        json.dump(ceo_to_company, file_init_2)

else:
    # Close the handle instead of leaking it via json.load(open(...)).
    with open(json_cc_dataset_path, 'r') as json_cc_file:
        json_cc_dataset = json.load(json_cc_file)
    assert len(json_cc_dataset) == 1500

# Directory that receives every generated company/CEO split below.
cc_target_dataset_path = '/home/hadoop-aipnlp/nazarite/reverse_curse/reversal_curse-main/data/final_data/company_ceo_final'

# Shuffle all sample indices with a fixed seed so the split is reproducible.
original_index = list(range(len(json_cc_dataset)))
random.seed(13)
random.shuffle(original_index)

# Split the shuffled indices: the first group takes len // 2 + 1 indices,
# the second group takes the remainder.
# NOTE(review): an older comment claimed "750 per side"; with 1500 samples
# this split yields 751 / 749 — confirm that is what
# construct_reverse_dataset_pair / merge_ar_dataset_one_by_one expect.
split_point = len(original_index) // 2 + 1
pair_positive_index = original_index[:split_point]
pair_negative_index = original_index[split_point:]

# AR datasets built with flag='positive'.
(
    ar_positive_train_dataset,
    ar_positive_positive_positive_test_dataset,
    ar_positive_positive_negative_test_dataset,
    ar_positive_negative_positive_test_dataset,
    ar_positive_negative_negative_test_dataset,
) = construct_reverse_dataset_pair(
    json_cc_dataset,
    pair_positive_index,
    pair_negative_index,
    flag='positive',
    data_format_index=1,
)

# AR datasets built with flag='reverse' from the same index split.
(
    ar_negative_train_dataset,
    ar_negative_positive_positive_test_dataset,
    ar_negative_positive_negative_test_dataset,
    ar_negative_negative_positive_test_dataset,
    ar_negative_negative_negative_test_dataset,
) = construct_reverse_dataset_pair(
    json_cc_dataset,
    pair_positive_index,
    pair_negative_index,
    flag='reverse',
    data_format_index=1,
)

# Interleave the two training sets (positive, reverse, positive, ...) so
# that paired samples tend to land in the same batch.
ar_train_dataset = merge_ar_dataset_one_by_one(ar_positive_train_dataset, ar_negative_train_dataset)

# Each combined test split is the 'positive'-flag part followed by the
# 'reverse'-flag part of the same category.
ar_positive_positive_test_dataset = ar_positive_positive_positive_test_dataset + ar_negative_positive_positive_test_dataset
ar_positive_negative_test_dataset = ar_positive_positive_negative_test_dataset + ar_negative_positive_negative_test_dataset
ar_negative_positive_test_dataset = ar_positive_negative_positive_test_dataset + ar_negative_negative_positive_test_dataset
ar_negative_negative_test_dataset = ar_positive_negative_negative_test_dataset + ar_negative_negative_negative_test_dataset

# Build the standard (non-reversed) training/test datasets.
# data_format_index must stay identical to the one used for the AR datasets.
(
    standard_positive_train_dataset,
    standard_positive_positive_positive_test_dataset,
    standard_positive_positive_negative_test_dataset,
    standard_positive_negative_positive_test_dataset,
    standard_positive_negative_negative_test_dataset,
    standard_negative_train_dataset,
    standard_negative_positive_positive_test_dataset,
    standard_negative_positive_negative_test_dataset,
    standard_negative_negative_positive_test_dataset,
    standard_negative_negative_negative_test_dataset,
) = construct_origin_dataset_pair(json_cc_dataset, data_format_index=1, known=False)


# Persist every AR split to its own JSON file in cc_target_dataset_path.
# Entries are written in list order; each file name mirrors its variable name.
_ar_outputs = [
    ('ar_positive_positive_test_dataset.json', ar_positive_positive_test_dataset),
    ('ar_positive_negative_test_dataset.json', ar_positive_negative_test_dataset),
    ('ar_negative_positive_test_dataset.json', ar_negative_positive_test_dataset),
    ('ar_negative_negative_test_dataset.json', ar_negative_negative_test_dataset),
    ('ar_positive_positive_positive_test_dataset.json', ar_positive_positive_positive_test_dataset),
    ('ar_positive_positive_negative_test_dataset.json', ar_positive_positive_negative_test_dataset),
    ('ar_positive_negative_positive_test_dataset.json', ar_positive_negative_positive_test_dataset),
    ('ar_positive_negative_negative_test_dataset.json', ar_positive_negative_negative_test_dataset),
    ('ar_negative_positive_positive_test_dataset.json', ar_negative_positive_positive_test_dataset),
    ('ar_negative_positive_negative_test_dataset.json', ar_negative_positive_negative_test_dataset),
    ('ar_negative_negative_positive_test_dataset.json', ar_negative_negative_positive_test_dataset),
    ('ar_negative_negative_negative_test_dataset.json', ar_negative_negative_negative_test_dataset),
    ('ar_train_dataset.json', ar_train_dataset),
]
for _ar_file_name, _ar_split in _ar_outputs:
    with open(os.path.join(cc_target_dataset_path, _ar_file_name), 'w') as _ar_handle:
        json.dump(_ar_split, _ar_handle)


# Persist every standard (non-reversed) split to its own JSON file in
# cc_target_dataset_path, in list order; file names mirror variable names.
_standard_outputs = [
    ('standard_positive_positive_positive_test_dataset.json', standard_positive_positive_positive_test_dataset),
    ('standard_positive_positive_negative_test_dataset.json', standard_positive_positive_negative_test_dataset),
    ('standard_positive_negative_positive_test_dataset.json', standard_positive_negative_positive_test_dataset),
    ('standard_positive_negative_negative_test_dataset.json', standard_positive_negative_negative_test_dataset),
    ('standard_negative_positive_positive_test_dataset.json', standard_negative_positive_positive_test_dataset),
    ('standard_negative_positive_negative_test_dataset.json', standard_negative_positive_negative_test_dataset),
    ('standard_negative_negative_positive_test_dataset.json', standard_negative_negative_positive_test_dataset),
    ('standard_negative_negative_negative_test_dataset.json', standard_negative_negative_negative_test_dataset),
    ('standard_positive_train_dataset.json', standard_positive_train_dataset),
    ('standard_negative_train_dataset.json', standard_negative_train_dataset),
]
for _std_file_name, _std_split in _standard_outputs:
    with open(os.path.join(cc_target_dataset_path, _std_file_name), 'w') as _std_handle:
        json.dump(_std_split, _std_handle)