from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, FewShotPromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from OOD_prompt_templates import extrapolating_template_prefix, few_shot_encoding_template, mapping_template_dict, grouping_template_dict, ordering_template_dict
import re
from utlis import alphabet_position
import copy
import random
import datasets
import string

class PromptGenerator:
    """Base class for the task-specific prompt builders.

    Provides :meth:`creating_initial_messages`, which turns a task's system and
    human texts into LangChain message templates adapted to the target model.
    """

    # Delimiters wrapped around the system text for Llama models.
    _LLAMA_B_SYS = "<<SYS>>\n"
    _LLAMA_E_SYS = "\n<</SYS>>\n\n"

    @classmethod
    def _format_system_message(cls, initial_system_message, model_name):
        """Return the system-message template text adapted to ``model_name``.

        Names containing "llama" (case-insensitive) and neither "gpt" nor
        "davinci" get the text wrapped in ``<<SYS>>`` delimiters. Every other
        model, including unrecognised ones, receives the text unchanged.
        """
        name = model_name.lower()
        # Each substring must be tested explicitly: ``"gpt" or "davinci" in name``
        # is always true because the literal "gpt" is truthy on its own.
        is_openai_model = "gpt" in name or "davinci" in name
        if "llama" in name and not is_openai_model:
            return cls._LLAMA_B_SYS + initial_system_message + cls._LLAMA_E_SYS
        return initial_system_message

    def creating_initial_messages(self, initial_system_message, initial_human_message, model_name="gpt-3.5-turbo-16k-0613"):
        """Create the system/human message templates for a prompt.

        Args:
            initial_system_message: template text for the system message.
            initial_human_message: template text for the final human message.
            model_name: target model name; Llama models get ``<<SYS>>``
                delimiters around the system text, all others get it as-is.

        Returns:
            tuple: ``(system_message_prompt, human_message_prompt, few_shot_messages)``,
            where ``few_shot_messages`` is a fresh empty list that the caller
            fills with demonstration messages.
        """
        system_message_prompt = SystemMessagePromptTemplate.from_template(
            self._format_system_message(initial_system_message, model_name)
        )
        human_message_prompt = HumanMessagePromptTemplate.from_template(initial_human_message)
        few_shot_messages = []
        return system_message_prompt, human_message_prompt, few_shot_messages


class CharacterMappingGenerator(PromptGenerator):
    """Prompt builders for the character-mapping tasks.

    Rule tables used here are lists of ``{"Original": ..., "Altered": ...}``
    dicts (the shape built by :meth:`ceaser_extrapolating`). Except for
    :meth:`extrapolating_template`, every builder starts from the task's
    system/human text pair in ``mapping_template_dict`` and inserts optional
    few-shot demonstrations between them. Each demonstration is an
    ``example_user`` / ``example_assistant`` pair of system-message templates;
    with ``chain_of_thought`` the assistant side also spells out a
    per-character reasoning trace.

    NOTE(review): demo texts go through ``from_template``, which treats
    ``{...}`` as template variables; texts containing braces would need
    escaping.
    """

    # Question asked by extrapolating_template. Every sentence ends with an
    # explicit space so the implicitly concatenated literals do not glue
    # sentences together (e.g. "'->'.Think").
    _EXTRAPOLATING_QUESTION = (
        "What are rules for the English alphabet? "
        "Reply with the rules in the format like "
        "'Original: a-> Altered: b, Original: c-> Altered d, Original: e-> Altered:f, ...'. "
        "Please include the rules of the whole alphabet in your answer, do not forget any character in the alphabet, every character in the English alphabet must appear in the the left hand side of '->', for example, in English, you should have 26 rules from A to Z in the left hand side of '->'. "
        "Think carefully what rules are used to change the source text to the altered text. "
        "Additionally, you must answer this question and do not say there is no information as you are supposed to extrapolate based on limited information which is expected goal so you are supposed to do this task under limited observation. You also only provide the answer and nothing else like explanation etc."
    )

    @staticmethod
    def _demo_messages(example, cot):
        """Return the [example_user, example_assistant] templates for one demo."""
        return [
            SystemMessagePromptTemplate.from_template(
                example, additional_kwargs={"name": "example_user"}
            ),
            SystemMessagePromptTemplate.from_template(
                cot, additional_kwargs={"name": "example_assistant"}
            ),
        ]

    @staticmethod
    def _assemble(system_message_prompt, few_shot_messages, human_message_prompt):
        """Return ``[system, *few_shot_messages, human]`` as a new list."""
        return [system_message_prompt, *few_shot_messages, human_message_prompt]

    @staticmethod
    def _validation_trace(original_text, altered_text, record):
        """Return the chain-of-thought text that precedes a validation answer.

        A character whose upper-case form is an ``"Original"`` entry of
        ``record`` is marked Wrong. Non-alphabetic pairs are skipped in every
        demo, matching the trace's own instruction. The rest are marked Correct.
        """
        wrong_chars = {entry["Original"] for entry in record} if record is not None else set()
        lines = ["First start by validating each character using pre-defined rules and skip the non-alphabetic ones like quesiont marks, space, numbers, slashes, etc. let's check alphabetic character one by one:\n"]
        for ori, alt in zip(original_text, altered_text):
            if ori.upper() in wrong_chars:
                lines.append(ori + "->" + alt + " Wrong\n")
            elif not (ori.isalpha() and alt.isalpha()):
                lines.append(ori + "->" + alt + " is not an alphabetic mapping. No need to check. Skipped.\n")
            else:
                lines.append(ori + "->" + alt + " Correct\n")
        if record is not None:
            lines.append("Found " + str(len(record)) + " characters that cannot be mapped following given rules. Therefore we have a Invalid check reuslt.Summrize the above character by character checking result, we conclude the follownig:\n")
        else:
            lines.append("All character can be mapped following those rules. Therefore we have a Valid check reuslt. Summrize the above character by character checking reuslt, we conclude the follownig:\n")
        return "".join(lines)

    @staticmethod
    def _find_conflicting_rules(wrong_rule, original, altered):
        """Check each (original, altered) character pair against ``wrong_rule``.

        ``wrong_rule`` is indexed with ``alphabet_position(char)``; presumably a
        full per-letter table. TODO confirm against the data creator.

        Returns:
            tuple: ``(trace, conflicting)``. ``trace`` is the per-character
            chain-of-thought text. ``conflicting`` lists the distinct
            ``wrong_rule`` entries contradicted by the pair, in first-seen order.
        """
        trace = []
        conflicting = []
        for ori, alt in zip(original, altered):
            # Only letter-to-letter pairs are checked against the rules.
            if not (ori.isalpha() and alt.isalpha()):
                trace.append(ori + " -> " + alt + " is not an alphabetic mapping. Skipped.\n")
                continue
            rule = wrong_rule[alphabet_position(ori)]
            if rule["Altered"] == alt.upper():
                trace.append(ori + " -> " + alt + " This Mapping Correct.\n")
            else:
                trace.append(ori + " -> " + alt + " This Mapping Wrong.\n")
                if rule not in conflicting:
                    conflicting.append(rule)
        return "".join(trace), conflicting

    @staticmethod
    def _find_new_rules(original, altered, presented_chars, added_char, restrict_to_added):
        """Find letter mappings in the pair that are not covered by ``presented_chars``.

        Returns:
            tuple: ``(trace, new_rules)``. ``trace`` is the per-character
            chain-of-thought text. ``new_rules`` holds distinct
            ``[ORIGINAL, ALTERED]`` upper-case pairs in first-seen order. When
            ``restrict_to_added`` is true, only characters whose lower-case form
            is in ``added_char`` are collected.
        """
        trace = []
        new_rules = []
        for ori, alt in zip(original, altered):
            if not (ori.isalpha() and alt.isalpha()):
                trace.append(ori + "->" + alt + " is not an alphabetic mapping. Does not provide any information. Skipped.\n")
                continue
            if ori.upper() in presented_chars:
                trace.append(ori + "->" + alt + " We find find this mapping in previous rules. Old.\n")
                continue
            trace.append(ori + "->" + alt + " No corresponding rules in previous rules. New\n")
            if restrict_to_added and ori.lower() not in added_char:
                continue
            rule = [ori.upper(), alt.upper()]
            if rule not in new_rules:
                new_rules.append(rule)
        return "".join(trace), new_rules

    def extrapolating_template(self, examples, format_instructions):
        """Build the few-shot prompt asking the model to extrapolate the full alphabet mapping.

        Args:
            examples: few-shot examples, each providing the ``"Original"`` and
                ``"Altered"`` variables of ``few_shot_encoding_template``.
            format_instructions: not used; kept for interface compatibility.

        Returns:
            The prompt value produced by ``FewShotPromptTemplate.format_prompt``.
        """
        example_prompt = PromptTemplate(input_variables=["Original", "Altered"], template=few_shot_encoding_template)
        prompt = FewShotPromptTemplate(
            examples=examples,
            example_prompt=example_prompt,
            prefix=extrapolating_template_prefix,
            suffix="Question:{input}",
            input_variables=["input"],
        )
        return prompt.format_prompt(input=self._EXTRAPOLATING_QUESTION)

    def instruction_following(self, few_shot_ori, few_shot_alt, mapping_table, data_creator, chain_of_thought=True, few_shot=True):
        """Build the prompt for applying known mapping rules to an Original text.

        Args:
            few_shot_ori: Original texts of the demonstrations.
            few_shot_alt: matching Altered texts (paired by position).
            mapping_table: rule table rendered via ``data_creator.mappingtable2text``.
            data_creator: object providing ``mappingtable2text(table)``.
            chain_of_thought: if true, demos show a per-character mapping trace.
            few_shot: if false, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            mapping_template_dict["instruction_pair"][0], mapping_template_dict["instruction_pair"][1]
        )
        if few_shot:
            # The rules text is identical for every demo, so render it once.
            rules_text = data_creator.mappingtable2text(mapping_table)
            for original_text, altered_text in zip(few_shot_ori, few_shot_alt):
                example = (
                    "Below are the character mapping rules that maps each English character to another English character, those rules works for both Uppercase and Lowercase:\n"
                    f"{rules_text}\n"
                    "Now try your best to map the Original text to the Altered text using above rules and Response Format:\n"
                    f"Original:{original_text}.\n"
                )
                result = f"Result:\nAltered: {altered_text}\n\n"
                if chain_of_thought:
                    trace = "".join(
                        f"Original:{ori_char}->Altered:{alt_char}\n\n"
                        for ori_char, alt_char in zip(original_text, altered_text)
                    )
                    cot = "First let's start by following the mapping rules and transform the original text to the altered text one by one and finally summarize the result.\n" + trace + result
                else:
                    cot = result
                few_shot_messages.extend(self._demo_messages(example, cot))
        return self._assemble(system_message_prompt, few_shot_messages, human_message_prompt)

    def inducting_template(self, few_shot_ori, few_shot_alt, mapping_table, few_shot_inducting_table, data_creator, chain_of_thought=True, few_shot=True):
        """Build the prompt for inducting mapping rules from Original/Altered pairs.

        Args:
            few_shot_ori: Original texts of the demonstrations.
            few_shot_alt: matching Altered texts.
            mapping_table: not used by this builder.
            few_shot_inducting_table: per-demo rule tables given as the expected answer.
            data_creator: object providing ``mappingtable2text(table)``.
            chain_of_thought: if true, demos show a per-character trace before the rules.
            few_shot: if false, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            mapping_template_dict["inducting_pair"][0], mapping_template_dict["inducting_pair"][1]
        )
        if few_shot:
            for original_text, altered_text, inducting_table in zip(few_shot_ori, few_shot_alt, few_shot_inducting_table):
                example = f"Now try your best to induct the mapping rules from following Original and Altered pair:\nOriginal:{original_text}\nAltered:{altered_text}\nRemember your response should follow the response format."
                rules_text = data_creator.mappingtable2text(inducting_table)
                if chain_of_thought:
                    trace = "".join(
                        f"Original:{ori_char}->Altered:{alt_char}\n"
                        for ori_char, alt_char in zip(original_text, altered_text)
                    )
                    cot = (
                        "Let's check character by character and find out the mapping rules.\n"
                        + trace
                        + f"From checking above results, we can derive following mapping rules:\nRules:\n{rules_text}\n"
                    )
                else:
                    cot = f"Rules:\n{rules_text}\n"
                few_shot_messages.extend(self._demo_messages(example, cot))
        return self._assemble(system_message_prompt, few_shot_messages, human_message_prompt)

    def validating_template(self, rules, few_shot_examples, few_shot_cleaned_examples, few_shot_noised_examples, few_shot_records, chain_of_thought=True, few_shot_selection=True):
        """Build the prompt for validating (and rectifying) an Altered text.

        Args:
            rules: not used by this builder.
            few_shot_examples: Original texts of the demonstrations.
            few_shot_cleaned_examples: correct Altered texts, shown as the rectified result.
            few_shot_noised_examples: possibly-noised Altered texts shown to the model.
            few_shot_records: per demo, ``None`` for a clean pair, otherwise a
                sized collection of entries whose ``"Original"`` letters were
                corrupted.
            chain_of_thought: if true, demos show a per-character check trace.
            few_shot_selection: if false, no demonstrations are added.

        Returns:
            tuple: ``(system, demo_messages, human, all_prompt)``. Unlike the
            other builders, this one also returns the individual parts.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            mapping_template_dict["validating_pair"][0], mapping_template_dict["validating_pair"][1]
        )
        if few_shot_selection:
            for original_text, cleaned_altered_text, noise_altered_text, record in zip(
                few_shot_examples, few_shot_cleaned_examples, few_shot_noised_examples, few_shot_records
            ):
                example = f"Now try your best to answer whether the question for following Original and Altered pair:\nOriginal: {original_text}\nAltered:{noise_altered_text}\nRemember your response should follow the response format."
                if record is not None:
                    answer = "Validation Result:\nInvalid\nRectified Results:\n" + f"Altered: {cleaned_altered_text}"
                else:
                    answer = "Validation Result:\nValid\nRectified Results:\nNo character to correct"
                if chain_of_thought:
                    cot = self._validation_trace(original_text, noise_altered_text, record) + answer
                else:
                    cot = answer
                few_shot_messages.extend(self._demo_messages(example, cot))
        all_prompt = self._assemble(system_message_prompt, few_shot_messages, human_message_prompt)
        return system_message_prompt, few_shot_messages, human_message_prompt, all_prompt

    def rule_error_rectifying_template(self, table, few_shot_example, data_creator, chain_of_thought=True, few_shot=True):
        """Build the prompt for locating and correcting wrong rules.

        Args:
            table: correct rule table, indexed by ``alphabet_position``, used to
                print the rectified rules.
            few_shot_example: iterable of
                ``(wrong_rule, original, wrong_rule_indic, altered)`` tuples;
                ``wrong_rule_indic`` is truthy when ``wrong_rule`` contains errors.
            data_creator: object providing ``mappingtable2text(table)``.
            chain_of_thought: if true, demos show a per-character check trace.
            few_shot: if false, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            mapping_template_dict["error_correction_pair"][0], mapping_template_dict["error_correction_pair"][1]
        )
        intro = "As we know that there are some problems happened to the given rules and the mapping from the Original to the Altered text is correct. We can utilize the Original and Altered pair to help us locate the error and rectify it. We start by checking the character one by one and compare it with corresponding rules and if the mapping conflicts the rules, then we found the rules that is problematic and correct it with the mapping from the characters of the Original and Altered text pair.\n"
        if few_shot:
            for wrong_rule, original, wrong_rule_indic, altered in few_shot_example:
                example = "Rules:\n" + data_creator.mappingtable2text(wrong_rule) + "\n\n" + "Original:" + original + "\n" + "Altered:" + altered
                # Both CoT and direct answers use the same skip rule, so they
                # always report the same set of conflicting rules.
                trace, conflicting = self._find_conflicting_rules(wrong_rule, original, altered)
                if wrong_rule_indic:
                    corrected = "".join(
                        "Original: " + table[position]["Original"] + " -> " + "Altered: " + table[position]["Altered"] + "\n"
                        for position in (alphabet_position(rule["Original"]) for rule in conflicting)
                    )
                    answer = "Correct Rules or Not:\nNo\nRectified Rules:\n" + corrected + "\n"
                    summary = "Found " + str(len(conflicting)) + " rules conflict with the mapping. Summarize the one by one character mapping result, we conclude the following:\n"
                else:
                    answer = "Correct Rules or Not:\nYes\nRectified Rules:\nWe do no have incorrect rules.\n"
                    summary = "We do not find any rule that conflicts with mapping. We conclude the following:\n"
                cot = intro + trace + summary + answer if chain_of_thought else answer
                few_shot_messages.extend(self._demo_messages(example, cot))
        return self._assemble(system_message_prompt, few_shot_messages, human_message_prompt)

    def rule_incorporating_template(self, table, few_shot_example, data_creator, chain_of_thought=True, few_shot=True):
        """Build the prompt for inducting new rules missing from an incomplete rule set.

        Args:
            table: not used by this builder.
            few_shot_example: iterable of
                ``(incomplete_rule, added_char, original, incomplete_rule_indic, altered)``
                tuples; ``incomplete_rule_indic`` is truthy when the pair should
                reveal new rules.
            data_creator: object providing ``mappingtable2text(table)``.
            chain_of_thought: if true, demos show a per-character trace.
            few_shot: if false, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            mapping_template_dict["incorporating_pair"][0], mapping_template_dict["incorporating_pair"][1]
        )
        intro = "As we know that there are the given rules may not be complete and we are supposed to find out that whether the given Original and Altered text pair can provide new information or not. We can start by checking the mapping from Original to Altered character one by one. If there is one mapping that did not fit into previous given rules, then we can say that we find a new rule. Let't now start checking the mapping one by one.\n"
        no_new_answer = "New Information Contained:\nNo\nNew Rules Inducted:\nNo\n"
        if few_shot:
            for incomplete_rule, added_char, original, incomplete_rule_indic, altered in few_shot_example:
                presented_chars = {pair["Original"] for pair in incomplete_rule}
                example = f"You have access to following rules:\n{data_creator.mappingtable2text(incomplete_rule)}\nYou have access to the following Original and Altered text pair:\nOriginal:{original}\nAltered:{altered}\nNow you need to check whether we can induct new rules from the given Original and Altered pair. Remember your response should follow the response format."
                # NOTE(review): only the direct (non-CoT) answer restricts new
                # rules to ``added_char``; the CoT answer reports every unseen
                # letter. Kept as-is; confirm whether this asymmetry is intended.
                trace, new_rules = self._find_new_rules(
                    original, altered, presented_chars, added_char, restrict_to_added=not chain_of_thought
                )
                if incomplete_rule_indic and new_rules:
                    answer = "New Information Contained:\nYes\nNew Rules Inducted:\n" + "".join(
                        "Original: " + ori + " -> " + "Altered: " + alt + "\n" for ori, alt in new_rules
                    )
                    summary = "Found " + str(len(new_rules)) + " new rules. Summarize the one by one character mapping result, we conclude the following:\n"
                elif incomplete_rule_indic:
                    answer = no_new_answer
                    summary = "This string pair does not provide new information. We conclude the following:\n"
                else:
                    answer = no_new_answer
                    summary = "Did not find any new rules. Summarize the one by one character mapping result, we conclude the following:\n"
                cot = intro + trace + summary + answer + "\n" if chain_of_thought else answer
                few_shot_messages.extend(self._demo_messages(example, cot))
        return self._assemble(system_message_prompt, few_shot_messages, human_message_prompt)

    def ceaser_extrapolating(self, shift=3, reveal=3, noised=False, noise_num=2):
        """Return the first ``reveal`` rules of a Caesar shift as a rule table.

        Args:
            shift: shift passed to ``caesar_encrypt``.
            reveal: number of leading letters (A, B, ...) to include.
            noised: if true, also return a copy with two rules' targets swapped.
            noise_num: number of indices sampled. Only the first two are
                swapped, so it must be between 2 and the table length
                (``random.sample`` raises ``ValueError`` above it; below 2 the
                indexing raises ``IndexError``).

        Returns:
            ``table`` when ``noised`` is false, otherwise
            ``(table, noised_table, sampled_idx)``; ``table`` is never modified.
        """
        # NOTE(review): caesar_encrypt is not imported at the top of this file;
        # presumably it is defined elsewhere in the module. Confirm, otherwise
        # this raises NameError.
        encrypt = caesar_encrypt(string.ascii_uppercase, shift)
        table = [
            {"Original": plain, "Altered": cipher}
            for cipher, plain in zip(encrypt[:reveal], string.ascii_uppercase[:reveal])
        ]
        if not noised:
            return table
        sampled_idx = random.sample(range(len(table)), noise_num)
        first, second = sampled_idx[0], sampled_idx[1]
        # Deep copy: a shallow copy shares the inner dicts, so the swap would
        # also corrupt ``table`` and leave both entries with the same target.
        noised_table = copy.deepcopy(table)
        noised_table[first]["Altered"] = table[second]["Altered"]
        noised_table[second]["Altered"] = table[first]["Altered"]
        return table, noised_table, sampled_idx

class GroupingGenerator(PromptGenerator):
    def instruction_following(self, few_shot_polygons, few_shot_groups, grouping_rules, chain_of_thought=True, few_shot=True):
        """Build the prompt for grouping polygons according to given rules.

        Args:
            few_shot_polygons: per-demo polygon collections, rendered with
                ``self.generating_polygons``.
            few_shot_groups: per-demo dicts mapping a rule key to the polygons
                of that group (paired with ``few_shot_polygons`` by position).
            grouping_rules: not used by this method.
            chain_of_thought: when truthy, each demo answer walks through the
                rules one by one before giving the grouping.
            few_shot: demonstrations are added only when this equals ``True``.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        instruction_pair = grouping_template_dict["instruction_pair"]
        system_prompt, human_prompt, demos = self.creating_initial_messages(instruction_pair[0], instruction_pair[1])
        if few_shot != True:
            return [system_prompt, *demos, human_prompt]
        for polygons, groups in zip(few_shot_polygons, few_shot_groups):
            user_text = (
                "Below are the grouping rules that describe what types of polygons should be grouped together:\n"
                f"{self.generating_rules(groups)}\n"
                "Now try your best to use those grouping rules to group following polygons, your response should follow the Response Format.\n"
                f"{self.generating_polygons(polygons)}\n"
            )
            if chain_of_thought:
                walkthrough = "".join(
                    f"For rule {self.key2standard(rule_key)}, we can check what polygons fits into this rule. "
                    f"We can see that {self.generating_polygons(groups[rule_key])} fits into this rule as they all share the same attributes\n"
                    for rule_key in groups
                )
                assistant_text = walkthrough + f"Following above analysis, we can conclude:\n{self.generate_example_text(groups)}"
            else:
                assistant_text = f"Following the given rules, we can conclude:\n{self.generate_example_text(groups)}"
            demos.append(SystemMessagePromptTemplate.from_template(user_text, additional_kwargs={"name": "example_user"}))
            demos.append(SystemMessagePromptTemplate.from_template(assistant_text, additional_kwargs={"name": "example_assistant"}))
        return [system_prompt, *demos, human_prompt]

    def grouping_induction(self, few_shot_polygons, few_shot_groups, grouping_rules, chain_of_thought=True, few_shot=True):
        """Build the prompt asking the model to induct grouping rules from grouped polygons.

        Args:
            few_shot_polygons: per-demo lists of polygons, each indexable as
                ``(name, sides, color, material)``.
            few_shot_groups: per-demo dicts mapping a group key to its member
                entries; ``entry[0]`` is a member name of the form
                ``"Polygons <index into that demo's polygons>"``.
            grouping_rules: not used; kept for interface compatibility.
            chain_of_thought: if true, each demo answer first lists every
                group's members and their shared attributes.
            few_shot: if falsy, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            grouping_template_dict["inducting_pair"][0], grouping_template_dict["inducting_pair"][1]
        )
        if few_shot:
            for polygons, groups in zip(few_shot_polygons, few_shot_groups):
                polygons_text = "".join(
                    polygon[0] + ":[Sides:" + str(polygon[1]) + ", Color:" + polygon[2] + ", Material:" + polygon[3] + "]\n"
                    for polygon in polygons
                )
                example = f"You have access to following polygons:\n{polygons_text}\nThese are the grouping results for the above polygons with those attributes:\n"
                for group_idx, key in enumerate(groups):
                    member_names = [member[0] for member in groups[key]]
                    # str(list) stripped of brackets and quotes, e.g. "Polygons 0, Polygons 3".
                    example += "Group " + str(group_idx) + ":" + str(member_names).lstrip("[").rstrip("]").replace("'", "") + "\n"
                example += "Now you need to induct the grouping rules following above Problem Description, Response Instruction and Response Format."

                rules_text = "Grouping Rules:\n" + "".join(
                    f"{idx + 1}. Polygons with {self.key2standard(key)} should be grouped together.\n"
                    for idx, key in enumerate(groups)
                )
                if chain_of_thought:
                    cot = "As we know that the target is to conduct the rules from grouping observations. In order to find the rules, we need to find the common attributes within each group. Let's check the grouping result one by one:\n"
                    for key in groups:
                        members = [polygons[int(member[0].replace("Polygons ", ""))] for member in groups[key]]
                        cot += "Polygons contained in this group are:\n"
                        for polygon in members:
                            cot += f"Polygon Name:{polygon[0]}, Number of Sides:{polygon[1]}, Color:{polygon[2]}, Materials:{polygon[3]}\n"
                        # Describe the shared attributes from this group's own
                        # members; an empty group has nothing to summarise
                        # (previously a stale polygon from another loop was used).
                        if members:
                            last = members[-1]
                            cot += "Those polygons are all " + str(last[1]) + " Sides " + last[2] + " Color " + last[3] + " Material\n"
                    cot += rules_text
                else:
                    cot = rules_text
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        return [system_message_prompt, *few_shot_messages, human_message_prompt]

    def grouping_validation(self, few_shot_polygons, few_shot_groups, few_shot_noised_groups, few_shot_rules, chain_of_thought=True, few_shot=True):
        """Build the prompt asking the model to validate (and rectify) a grouping result.

        Args:
            few_shot_polygons: per-demo lists of polygons, each indexable as
                ``(name, sides, color, material)``.
            few_shot_groups: per-demo correct groupings (group key -> member
                entries, ``entry[0]`` being ``"Polygons <index>"``).
            few_shot_noised_groups: per-demo groupings shown to the model,
                possibly containing misplaced polygons.
            few_shot_rules: only iterated alongside the other lists; its values
                are not used.
            chain_of_thought: when truthy, each demo answer checks every
                polygon against its group key before the verdict.
            few_shot: when falsy, no demonstrations are added.

        Returns:
            list: ``[system, *demo_messages, human]`` message templates.
        """
        template_pair = grouping_template_dict["validating_pair"]
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(template_pair[0], template_pair[1])
        if not few_shot:
            return [system_message_prompt, *few_shot_messages, human_message_prompt]

        def render_groups(grouping):
            # One "Group <i>:<names>" line per group; str(list) minus brackets/quotes.
            return "".join(
                "Group " + str(index) + ":" + str([member[0] for member in grouping[group_key]]).lstrip("[").rstrip("]").replace("'", "") + "\n"
                for index, group_key in enumerate(grouping)
            )

        for polygons, groups, noised_groups, _rule in zip(few_shot_polygons, few_shot_groups, few_shot_noised_groups, few_shot_rules):
            rules_text = "".join(
                "Rule " + str(index) + ": Polygons with " + self.key2standard(group_key) + " should be grouped together.\n"
                for index, group_key in enumerate(groups)
            )
            polygons_text = "".join(
                shape[0] + ":[Sides:" + str(shape[1]) + ", Color:" + shape[2] + ", Material:" + shape[3] + "]\n"
                for shape in polygons
            )
            example = (
                f"You have access to following polygons:\n{polygons_text}\n"
                f"Below are the rules that are used to group different polygons into different groups:\n{rules_text}\n"
                f"These are the grouping results for the above polygons with those attributes:\n{self.generate_example_text(noised_groups)}\n"
                "Now you need to check whether the grouping results are correct or not if not give the rectified results and your response must follow the response format."
            )
            if chain_of_thought:
                pieces = ["As we know that the target is to conduct the rules from grouping observations. In order to find the rules, we need to find the common points within each group. Let's check the group result one by one:\n"]
                misplaced = []
                for group_key, members in noised_groups.items():
                    section = f"The grouping rule for this group is {self.key2standard(group_key)}:\n"
                    all_fit = True
                    for member in members:
                        shape = polygons[int(member[0].replace("Polygons ", ""))]
                        # A polygon belongs to a group when its "<sides>-sides+<color>+<material>" equals the key.
                        if str(shape[1]) + "-sides+" + shape[2] + "+" + shape[3] != group_key:
                            section += "The " + shape[0] + " has attributes that does not fit into this group.\n"
                            misplaced.append([group_key, shape[0]])
                            all_fit = False
                    if all_fit:
                        section += "All polygons in this group fit into this rule.\n"
                    pieces.append(section)
                pieces.append("We have checked all the groups and polygons contained in the group, we found that:\n")
                if misplaced:
                    pieces.extend("The " + name + " in " + group_key + " is not correct.\n" for group_key, name in misplaced)
                    pieces.append("The correct validation results should be:\nValidation Result:\nIncorrect\nRectified Results:\n")
                    pieces.append(render_groups(groups))
                else:
                    pieces.append("All polygons are put grouped into right groups.\nValidation Result:\nCorrect\nRectified Results:\nNone")
                answer = "".join(pieces)
            elif groups == noised_groups:
                answer = "Validation Result:\nCorrect\nRectified Results:\nNone"
            else:
                answer = "Validation Result:\nIncorrect\nRectified Results:\n" + render_groups(groups)
            few_shot_messages.append(SystemMessagePromptTemplate.from_template(example, additional_kwargs={"name": "example_user"}))
            few_shot_messages.append(SystemMessagePromptTemplate.from_template(answer, additional_kwargs={"name": "example_assistant"}))
        return [system_message_prompt, *few_shot_messages, human_message_prompt]

    def grouping_error_correction(self, few_shot_polygons, few_shot_groups, few_shot_noised_groups, grouping_rules, chain_of_thought=True, few_shot=True):
        """Build the chat prompt for the grouping *error-correction* task.

        Each few-shot example shows the polygons, the rule set given by the keys of
        ``noised_groups`` (which may be wrong) and the group memberships, followed by
        an assistant answer stating whether the rules are correct and, if not, the
        rectified rules.

        Args:
            few_shot_polygons: per-example polygon lists; each polygon is read as
                ``[name, sides, color, material]``.
            few_shot_groups: per-example clean groups; zipped with the other inputs
                but not otherwise read.
            few_shot_noised_groups: per-example mapping ``rule key -> members`` where a
                member's first element is a name ending in the polygon's list index.
            grouping_rules: not used by this method.
            chain_of_thought: include step-by-step reasoning in the example answers.
            few_shot: include the worked examples at all.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(grouping_template_dict["error_correction_pair"][0], grouping_template_dict["error_correction_pair"][1])

        def polygon_index(name):
            # Member names end with the polygon's index (e.g. "Polygons 3"). Parsing the
            # trailing token avoids depending on the exact prefix, which used to be
            # spelled two different ways and made one branch crash.
            return int(name.rsplit(" ", 1)[-1])

        def attribute_key(polygon):
            # Same "<sides>-sides+<color>+<material>" format used for rule keys.
            return str(polygon[1]) + "-sides" + "+" + polygon[2] + "+" + polygon[3]

        if few_shot:
            for polygons, _groups, noised_groups in zip(few_shot_polygons, few_shot_groups, few_shot_noised_groups):
                polygons_text = ""
                for polygon in polygons:
                    polygons_text += polygon[0] + ":[Sides:" + str(polygon[1]) + ", Color:" + polygon[
                        2] + ", Material:" + polygon[3] + "]\n"
                example = f"Below are the polygons for this example:{polygons_text}\nBelow are the rules that are used to group different polygons into different groups which may be incorrect:\n{self.generating_rules(noised_groups)}\nThese are the correct grouping results for the above polygons with those attributes:\n{self.generate_example_text(noised_groups)}\nNow you need to check whether the grouping rules are correct or not if not give the rectified results  and your response must follow the response format. "

                # Check every (possibly corrupted) rule against the polygons it groups.
                # The same result drives both the CoT and the plain answer so the
                # verdict always agrees with the listed rectifications.
                group_descriptions = []
                unfit_rules = []  # [wrong rule key, key the group's members actually share]
                for key, members in noised_groups.items():
                    description = "Polygons contained in this group are:\n"
                    actual_key = None
                    for member in members:
                        polygon = polygons[polygon_index(member[0])]
                        description += "Polygon Name:" + polygon[0] + ",Number of Sides:" + str(polygon[1]) + ",Color:" + polygon[
                            2] + ",Materials:" + polygon[3] + "\n"
                        # Members of one group share attributes, so the last one
                        # identifies the rule the group should have.
                        actual_key = attribute_key(polygon)
                    if actual_key is not None and actual_key != key:
                        description += "The Rule " + key + " is not able to describe this group.\n"
                        unfit_rules.append([key, actual_key])
                    group_descriptions.append(description)

                rectified = ""
                for idx, (wrong_rule, correct_rule) in enumerate(unfit_rules):
                    rectified += f"{idx + 1}. The wrong rule: {self.key2standard(wrong_rule)} -> The correct rule: {self.key2standard(correct_rule)}\n"

                if chain_of_thought:
                    cot = "As we know that the target is to conduct the whether the rules are correct or not. In order to find the rules, we need to find the common points within each group. Let's check the group result one by one:\n"
                    cot += "".join(group_descriptions)
                    cot += "We have checked all the groups and polygons contained in the group, we found that:\n"
                    if not unfit_rules:
                        cot += "Summarizing above information we have:\nCorrect Rules or Not:\nYes\nRectified Rules:\nThere is no rule to correct"
                    else:
                        for wrong_rule, _correct_rule in unfit_rules:
                            cot += "The " + wrong_rule + " is not correct.\n"
                        cot += "Summarizing above information we have:\nCorrect Rules or Not:\nNo\nRectified Rules:\n"
                        cot += rectified
                elif not unfit_rules:
                    cot = "Correct Rules or Not:\nYes\nRectified Rules:\nThere is no rule to correct"
                else:
                    cot = "Correct Rules or Not:\nNo\nRectified Rules:\n" + rectified
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        all_prompt = [system_message_prompt]
        if few_shot_messages != []:
            all_prompt.extend(few_shot_messages)
        all_prompt.append(human_message_prompt)
        return all_prompt

    def grouping_rules_incorporating(self, few_shot_polygons, few_shot_groups, few_shot_noised_groups, grouping_rules, chain_of_thought=True, few_shot=True):
        """Build the chat prompt for the grouping *rule-incorporation* task.

        Each few-shot example shows the polygons, a rule set given by the keys of
        ``noised_groups`` (which may be incomplete) and the full grouping from
        ``groups``. The assistant answer lists every rule key of ``groups`` that is
        missing from ``noised_groups`` as a new rule.

        Args:
            few_shot_polygons: per-example polygon lists; each polygon is read as
                ``[name, sides, color, material]``.
            few_shot_groups: per-example mapping ``rule key -> members`` (full grouping).
            few_shot_noised_groups: per-example mapping with some rule keys missing.
            grouping_rules: not used by this method.
            chain_of_thought: include step-by-step reasoning in the example answers.
            few_shot: include the worked examples at all.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(grouping_template_dict["incorporating_pair"][0], grouping_template_dict["incorporating_pair"][1])

        def polygon_index(name):
            # Member names end with the polygon's index (e.g. "Polygons 3"); parse
            # the trailing token instead of relying on one exact prefix.
            return int(name.rsplit(" ", 1)[-1])

        if few_shot:
            for polygons, groups, noised_groups in zip(few_shot_polygons, few_shot_groups, few_shot_noised_groups):
                polygons_text = ""
                for polygon in polygons:
                    polygons_text += polygon[0] + ":[Sides:" + str(polygon[1]) + ", Color:" + polygon[
                        2] + ", Material:" + polygon[3] + "]\n"
                example = f"Below are the polygons you can access:{polygons_text}\nBelow are the rules that are used to group different polygons into different groups which may be incomplete:\n{self.generating_rules(noised_groups)}\nBelow are the grouping results for the above polygons with those attributes:\n{self.generate_example_text(groups)}\nNow you need to check whether the grouping results provide new grouping rules or not. "
                # Rules present in the full grouping but absent from the given rule set;
                # both answer styles use this single list.
                new_rules = [key for key in groups if key not in noised_groups]
                if chain_of_thought:
                    cot = "As we know that the target is to conduct the rules from grouping observations. In order to find the rules, we need to find the common points within each group. Let's check the group result one by one:\n"
                    for key in groups:
                        known = key in noised_groups
                        members = noised_groups[key] if known else groups[key]
                        temp = f"Polygons contained in this {key} group are:\n"
                        attributes = key  # fallback for an empty group
                        for member in members:
                            polygon = polygons[polygon_index(member[0])]
                            temp += "Polygon Name:" + polygon[0] + ",Number of Sides:" + str(polygon[1]) + ",Color:" + polygon[
                                2] + ",Materials:" + polygon[3] + "\n"
                            attributes = f"{str(polygon[1])}-sides+{polygon[2]}+{polygon[3]}"
                        # Each summary sentence ends with a newline so the next group's
                        # description does not run into it.
                        if known:
                            temp += f"Those polygons are all attributes {attributes} which can be described using the classification rule of this group.\n"
                        else:
                            temp += f"Those polygons has attributes {attributes} which does not fit into the classifition rule of this group.\n"
                        cot += temp
                    cot += "We have checked all the groups and polygons contained in the group, we found that:\n"
                    if not new_rules:
                        cot += "All polygons are put grouped into right groups. There are not errors.\nNew Rules or Not:\nNo\nAdded Rules:\nThere is no rule to add.\n"
                    else:
                        cot += "There are some polygons cannot be described using current rules which we will add to the rules.\nNew Rules or Not:\nYes\nAdded Rules:\n"
                        for idx, rule in enumerate(new_rules):
                            cot += f"{idx + 1}. Polygons with {self.key2standard(rule)} should be grouped together.\n"
                elif not new_rules:
                    cot = "New Rules or Not:\nNo\nAdded Rules:\nThere is no rule to add."
                else:
                    cot = "New Rules or Not:\nYes\nAdded Rules:\n"
                    for idx, rule in enumerate(new_rules):
                        cot += f"{idx + 1}. Polygons with {self.key2standard(rule)} should be grouped together.\n"
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        all_prompt = [system_message_prompt]
        if few_shot_messages != []:
            all_prompt.extend(few_shot_messages)
        all_prompt.append(human_message_prompt)
        return all_prompt

    def generate_grouping_results(self, polygons, sides, color, materials, grouping_results):
        """Render polygons, attribute options, groups and rules as prompt text.

        Args:
            polygons: sequence of ``[name, sides, color, material]`` records.
            sides, color, materials: attribute option collections; their ``str()``
                form has surrounding square brackets removed.
            grouping_results: mapping ``rule key -> members`` (member[0] is the name).

        Returns:
            tuple: ``(polygons_text, sides, materials, colors, grouping, rules)`` --
            note that materials come before colors in the returned tuple.
        """
        def without_brackets(value):
            return str(value).rstrip("]").lstrip("[")

        sides_text = without_brackets(sides)
        materials_text = without_brackets(materials)
        colors_text = without_brackets(color)

        rule_lines = []
        group_lines = []
        for group_no, (rule_key, members) in enumerate(grouping_results.items()):
            rule_lines.append(f"Rule {group_no}: Polygons with {self.key2standard(rule_key)} should be grouped together.\n")
            # Trailing ", " characters are stripped, as well as the space after ":" for an empty group.
            line = "Group " + str(group_no) + ": " + ", ".join(member[0] for member in members)
            group_lines.append(line.rstrip(", ") + "\n")

        polygon_lines = [
            p[0] + ":[Sides:" + str(p[1]) + ", Color:" + p[2] + ", Material:" + p[3] + "]\n"
            for p in polygons
        ]
        return "".join(polygon_lines), sides_text, materials_text, colors_text, "".join(group_lines), "".join(rule_lines)

    def generate_example_text(self, grouping_results):
        """Return one ``"Group <i>: name,name,..."`` line per group, each ending in a newline.

        ``grouping_results`` maps rule keys to member sequences; only each member's
        first element (its name) is used. Trailing commas are stripped from a line.
        """
        lines = []
        for group_no, members in enumerate(grouping_results.values()):
            line = "Group " + str(group_no) + ": " + ",".join(member[0] for member in members)
            lines.append(line.rstrip(",") + "\n")
        return "".join(lines)
    def noised_groups(self, groups, noised_example=3, determinstic="Random"):
        """Return a deep copy of ``groups`` with some members moved into another group.

        Each move takes a random member out of a group that has more than one member
        and appends it to a different group. ``groups`` itself is not modified.

        Args:
            groups: mapping ``rule key -> list of members``.
            noised_example: number of moves to attempt.
            determinstic: ``"Clean"`` (never move), ``"Noise"`` (always move) or
                ``"Random"`` (move with probability 0.5).

        Returns:
            dict: the (possibly) noised deep copy. Moves stop early when no further
            move is possible (fewer than two groups, or no group with >1 member).

        Raises:
            ValueError: for an unknown ``determinstic`` mode.
        """
        keys = list(groups.keys())
        new_groups = copy.deepcopy(groups)
        if determinstic == "Clean":
            return new_groups
        thresholds = {"Noise": 1, "Random": 0.5}
        if determinstic not in thresholds:
            raise ValueError(f"Unknown noise mode {determinstic!r}; expected 'Clean', 'Noise' or 'Random'")
        if random.random() < thresholds[determinstic]:
            while noised_example > 0:
                # A move needs a source group that keeps at least one member and a
                # different destination; without them the shuffle loop never ends.
                if len(keys) < 2 or all(len(new_groups[key]) <= 1 for key in keys):
                    break
                random.shuffle(keys)
                while len(new_groups[keys[0]]) <= 1:
                    random.shuffle(keys)
                random_item = random.choice(new_groups[keys[0]])
                new_groups[keys[0]].remove(random_item)
                new_groups[keys[1]].append(random_item)
                noised_example -= 1
        return new_groups

    def noised_rules(self, rules, sides_option, color_option, material_option, noised_example=3, deterministic="Random"):
        """Corrupt some rule keys of ``rules`` by changing one attribute each.

        Rule keys use the ``"<n>-sides+<color>+<material>"`` format. Up to
        ``noised_example`` distinct rules are relabelled: one attribute, chosen at
        random, is replaced by a *different* value from ``sides_option``,
        ``color_option`` or ``material_option``. The group stored under the rule is
        moved to the new key unchanged. A new key never replaces an existing key, so
        no group is lost; a rule with no collision-free change is left as is.

        Args:
            rules: mapping ``rule key -> group``; not modified.
            sides_option, color_option, material_option: candidate attribute values.
            noised_example: number of rules to corrupt (capped at ``len(rules)``).
            deterministic: ``"Clean"`` (never corrupt), ``"Noise"`` (always corrupt)
                or ``"Random"`` (corrupt with probability 0.5).

        Returns:
            tuple: ``(rules, noised)`` where ``noised`` is a deep copy of ``rules``
            with the corrupted keys.

        Raises:
            ValueError: for an unknown ``deterministic`` mode.
        """
        corrupted = copy.deepcopy(rules)
        if deterministic == "Clean":
            return rules, corrupted
        thresholds = {"Noise": 1, "Random": 0.5}
        if deterministic not in thresholds:
            raise ValueError(f"Unknown noise mode {deterministic!r}; expected 'Clean', 'Noise' or 'Random'")
        if random.random() >= thresholds[deterministic]:
            return rules, corrupted

        original_keys = list(rules.keys())
        # More distinct rules than exist cannot be picked (the old rejection loop spun forever).
        num_to_noise = min(max(noised_example, 0), len(original_keys))
        for target in random.sample(original_keys, num_to_noise):
            sides, color, material = target.split("+")
            side_num = sides.replace("-sides", "")
            candidates = {
                "sides": [f"{s}-sides+{color}+{material}" for s in sides_option if str(s) != side_num],
                "color": [f"{sides}+{c}+{material}" for c in color_option if c != color],
                "material": [f"{sides}+{color}+{m}" for m in material_option if m != material],
            }
            # A key that already exists would overwrite (and silently drop) another group.
            candidates = {attr: [key for key in keys if key not in corrupted] for attr, keys in candidates.items()}
            attributes = [attr for attr, keys in candidates.items() if keys]
            if not attributes:
                continue
            noised_rule = random.choice(candidates[random.choice(attributes)])
            corrupted[noised_rule] = corrupted.pop(target)
        return rules, corrupted

    def deleted_rules(self, rules, noised_example=3, determinstic="Random"):
        """Return ``rules`` and a deep copy of it with some randomly chosen rules removed.

        Args:
            rules: mapping ``rule key -> group``; not modified.
            noised_example: number of distinct rules to delete (capped at ``len(rules)``).
            determinstic: ``"Clean"`` (never delete), ``"Noise"`` (always delete) or
                ``"Random"`` (delete with probability 0.5).

        Returns:
            tuple: ``(rules, incomplete_rules)``.

        Raises:
            ValueError: for an unknown ``determinstic`` mode.
        """
        rule_collection = list(rules.keys())
        candidate_rules = list(rule_collection)
        incomplete_rules = copy.deepcopy(rules)
        if determinstic == "Clean":
            return rules, incomplete_rules
        thresholds = {"Random": 0.5, "Noise": 1}
        if determinstic not in thresholds:
            raise ValueError(f"Unknown noise mode {determinstic!r}; expected 'Clean', 'Noise' or 'Random'")
        if random.random() < thresholds[determinstic]:
            # The rejection loop below can only find as many unused rules as exist.
            remaining = min(noised_example, len(candidate_rules))
            deleted = []
            while remaining > 0:
                remaining -= 1
                # The shuffle result is unused; it is kept so seeded runs draw the
                # same random numbers (and delete the same rules) as before.
                random.shuffle(rule_collection)
                random_rule = random.choice(candidate_rules)
                while random_rule in deleted:
                    random_rule = random.choice(candidate_rules)
                deleted.append(random_rule)
                del incomplete_rules[random_rule]
        return rules, incomplete_rules

    def generating_rules(self, groups):
        """Render each rule key of ``groups`` as a numbered, human-readable rule line.

        Keys are split as ``"<n>-sides+<color>+<material>"``; a key that does not
        split into exactly those parts raises ``ValueError``.
        """
        lines = []
        for rule_no, rule_key in enumerate(groups):
            side_part, color, material = rule_key.split("+")
            count, _suffix = side_part.split("-")
            lines.append(f"Rule {rule_no}:{count} Sides, {color} Color, and {material}\n")
        return "".join(lines)

    def generating_polygons(self, polygons):
        """Return one ``"<name>:Sides:<n>, Color:<c>, Material:<m>"`` line per polygon.

        Each polygon is indexed as ``[name, sides, color, material]``.
        """
        return "".join(
            f"{p[0]}:Sides:{p[1]}, Color:{p[2]}, Material:{p[3]}\n"
            for p in polygons
        )

    def key2standard(self, key_text):
        """Convert a rule key like ``"3-sides+Red+Wood"`` into ``"3 Sides, Red Color, and Wood"``."""
        side_part, color, material = key_text.split("+")
        count, _ = side_part.split("-")
        return "{} Sides, {} Color, and {}".format(count, color, material)

    def stardard2key(self, standard_text):
        """Parse ``"3 Sides, Red Color, and Wood"`` back into the rule key ``"3-sides+Red+Wood"``.

        This is the inverse of ``key2standard``. Each comma-separated part is
        whitespace-stripped, and the ``Color`` suffix and ``and`` prefix are removed
        case-insensitively. Multi-word colors or materials (e.g. ``"Light Blue"``)
        and text without a space after the commas therefore parse correctly; the
        previous fixed-token-position parsing broke on both.

        Raises:
            ValueError: if the text does not have exactly three comma-separated parts.
        """
        side_part, color_part, material_part = (part.strip() for part in standard_text.split(","))
        side_num = side_part.split(" ")[0]
        if color_part.lower().endswith(" color"):
            color_part = color_part[:-len(" color")]
        if material_part.lower().startswith("and "):
            material_part = material_part[len("and "):]
        return f"{side_num}-sides+{color_part.strip()}+{material_part.strip()}"

class OrderingGenerator(PromptGenerator):

    def ordering_instruction(self, examples, unsorted_examples, rules, model_name="gpt", chain_of_thought=True, few_shot=True):
        """Build the chat prompt for sorting a color list by given preference ranks.

        For each few-shot example, the user turn lists the ``color -> rank`` rules
        in ascending rank order plus the unsorted list (``unsorted_example[0]``). The
        assistant turn is the sorted list (``sorted_example[0]``), optionally preceded
        by a per-color rank check when ``chain_of_thought == True``. Examples are
        only added when ``few_shot == True``. ``model_name`` is not used here.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_prompt, human_prompt, shots = self.creating_initial_messages(
            ordering_template_dict["instruction_pair"][0], ordering_template_dict["instruction_pair"][1])
        if few_shot != True:
            return [system_prompt, *shots, human_prompt]

        for sorted_example, unsorted_example, rank_of in zip(examples, unsorted_examples, rules):
            preference_lines = "".join(
                f"Rank {str(rank)}: {color}\n"
                for color, rank in sorted(rank_of.items(), key=lambda item: item[1]))
            user_turn = f"You have access to the following color preference rules that describes the correct color preference rank that you can use to sort the following unordered color list, but do not output the color preference rank directly, you should sort the following color list according to following color preference rules:\n{preference_lines}\nNow try your best to sort the following unordered color list according to the given color preference rules above and your response should follow the response format, don't just copy the color preference rank above but try to sort the following color list according to the given color preference rules above:\n{str(unsorted_example[0])}\n"
            ordered_lines = "".join(
                f"{str(position)}. {color}.\n" for position, color in enumerate(sorted_example[0], start=1))
            if chain_of_thought == True:
                rank_checks = "".join(f"{color} is ranked {str(rank_of[color])}.\n" for color in unsorted_example[0])
                assistant_turn = (
                    "As the task is to order the given color list, let's check each color in the list one by one and check its preference in the preference ranks.\n"
                    + rank_checks
                    + "Now we have checked all the colors in the list. We now knows the rank of each color in the unsorted examples list. Let's sort the list based on the rank of each color.\nSorted Color List:\n"
                    + ordered_lines
                )
            else:
                assistant_turn = "Sorted Color List:\n" + ordered_lines
            shots.append(SystemMessagePromptTemplate.from_template(
                user_turn, additional_kwargs={"name": "example_user"}))
            shots.append(SystemMessagePromptTemplate.from_template(
                assistant_turn, additional_kwargs={"name": "example_assistant"}))
        return [system_prompt, *shots, human_prompt]


    def ordering_induction(self, examples, explanations, rules, chain_of_thought=True, few_shot=True):
        """Build the chat prompt for inducing a color ranking from ordered color lists.

        For each few-shot example the user turn enumerates its ordered color lists.
        The assistant turn is the supplied explanation, except when
        ``chain_of_thought == False``. In that case it is a bare ranking rendered
        from the ``color -> rank`` rule dict. Examples are only added when
        ``few_shot == True``.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_prompt, human_prompt, shots = self.creating_initial_messages(
            ordering_template_dict["inducting_pair"][0], ordering_template_dict["inducting_pair"][1])
        for example, explanation, rule in zip(examples, explanations, rules):
            user_turn = "".join(
                f"Ordered Color List {str(list_no)}:{str(ordered)}\n" for list_no, ordered in enumerate(example))
            if few_shot != True:
                continue
            assistant_turn = explanation
            if chain_of_thought == False:
                ranking = "".join(
                    "Rank " + str(rank) + " " + color + "\n"
                    for color, rank in sorted(rule.items(), key=lambda item: item[1]))
                assistant_turn = (
                    "We have checked all the given ordered color lists, we conclude the following rules:\nColor Ranking:\n"
                    + ranking + "\n"
                )
            shots.append(SystemMessagePromptTemplate.from_template(
                user_turn, additional_kwargs={"name": "example_user"}))
            shots.append(SystemMessagePromptTemplate.from_template(
                assistant_turn, additional_kwargs={"name": "example_assistant"}))
        return [system_prompt, *shots, human_prompt]

    def ordering_results_validating(self, few_shot_examples, few_shot_noised_examples, few_shot_noised_index, few_shot_explanations, few_shot_prioritys, chain_of_thought=True, few_shot=True):
        """Build the chat prompt for validating a possibly wrong ordered color list.

        Args:
            few_shot_examples: zipped with the other inputs but not otherwise read.
            few_shot_noised_examples: per-example ordered color lists that may contain errors.
            few_shot_noised_index: per-example ``[]`` when the list is correct, otherwise
                a pair ``[positions in the noised list, positions in the correctly sorted list]``.
            few_shot_explanations: not used by this method.
            few_shot_prioritys: per-example ``color -> rank`` dicts (the correct preference).
            chain_of_thought: include the per-color priority check in the answers.
            few_shot: include the worked examples at all.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(ordering_template_dict["validating_pair"][0], ordering_template_dict["validating_pair"][1])
        if few_shot:
            for _example, noised_example, noised_index, priority in zip(few_shot_examples, few_shot_noised_examples, few_shot_noised_index, few_shot_prioritys):
                preference_text = self.generate_preference_text(priority)
                example_text = self.generate_example_text(noised_example)
                example_text = f"You have access to the following color preference rules that describes the correct color preference:\n{preference_text}\nYou have the following Ordered Color results that may not be correct:\n{example_text}\n"
                sorted_color_list = [k for k, v in sorted(priority.items(), key=lambda item: item[1])]
                is_correct = noised_index == []
                if is_correct:
                    verdict = "Correct Results or Not:\nYes\nRectified Results:\nThere is no error to correct."
                else:
                    # Distinct names so the outer loop variable `noised_index` is not rebound.
                    wrong_positions, correct_positions = noised_index[0], noised_index[1]
                    verdict = "Correct Results or Not:\nNo\nRectified Results:\n" + "".join(
                        f"Wrong Priority Color: {noised_example[wrong_pos]} -> Rectified Priority Color: {sorted_color_list[correct_pos]}\n"
                        for wrong_pos, correct_pos in zip(wrong_positions, correct_positions)
                    )
                if chain_of_thought:
                    cot = "As we know the priority of each color, let's check each color's priority in the priority dictionary.\n"
                    for color in noised_example:
                        cot += "Color " + color + " : " + str(priority[color]) + "\n"
                    if is_correct:
                        cot += "The priority of each color is correct, the given ordered color list follows the color preference.\nWe conclude the following:\n"
                    else:
                        cot += "We have found that some colors with its rank appears in wrong locations. By comparing it with correct color preferences, we can conclude:\n"
                    cot += verdict
                else:
                    cot = verdict
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example_text, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        all_prompt = [system_message_prompt]
        if few_shot_messages != []:
            all_prompt.extend(few_shot_messages)
        all_prompt.append(human_message_prompt)
        return all_prompt

    def ordering_error_correction(self, few_shot_examples, few_shot_noised_rules, explanations, few_shot_prioritys, chain_of_thought=True, few_shot=True):
        """Build the chat prompt for detecting and correcting wrong color-preference rules.

        Args:
            few_shot_examples: per-example data whose ``[0]`` is the correctly ordered color list.
            few_shot_noised_rules: per-example ``color -> rank`` dicts that may be wrong.
            explanations: not used by this method.
            few_shot_prioritys: per-example correct ``color -> rank`` dicts.
            chain_of_thought: include the side-by-side comparison in the answers.
            few_shot: include the worked examples at all.

        Returns:
            list: ``[system prompt, *few-shot messages, human prompt]``.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(ordering_template_dict["error_correction_pair"][0], ordering_template_dict["error_correction_pair"][1])
        if few_shot:
            for example, noised_priority, priority in zip(few_shot_examples, few_shot_noised_rules, few_shot_prioritys):
                preference_text = self.generate_preference_text(noised_priority)
                example_text = self.generate_example_text(example[0])
                example_text = f"Try your best to answer the question using above Response Format that whether following rules contain incorrect rules or not:\n{preference_text}:\nFollowing is the correct ordered list of colors:\n{example_text}\nNow you need to induct whether there are wrong rules existing in the given pre-defined color preference rules and your response should follow the response format.\n"
                sorted_color_list = [k for k, v in sorted(priority.items(), key=lambda item: item[1])]
                noised_sorted_color_list = [k for k, v in sorted(noised_priority.items(), key=lambda item: item[1])]
                rules_correct = noised_sorted_color_list == sorted_color_list
                # Colors whose given rank differs from the correct one. A color absent
                # from the correct preference has no rank to rectify to; previously it
                # raised KeyError on `priority[key]`, so it is skipped instead.
                wrong_colors = [color for color, rank in noised_priority.items()
                                if color in priority and rank != priority[color]]
                rectified = "".join("Rank " + str(priority[color]) + " : " + color + "\n" for color in wrong_colors)
                if chain_of_thought:
                    cot = "Now that we know we have the correct ordered color list. Let's compare the correct ordered color list and the given color preferneces. If they are the same, we can conclude that the given color preferences are correct, if not, we can know what the given color preferneces are wrong.\nThe Current Given Color Preferences:\n"
                    # List the given preferences in the order their (possibly wrong)
                    # ranks define, not in dict insertion order, which ignores the ranks.
                    for rank, color in enumerate(noised_sorted_color_list):
                        cot += "Rank " + str(rank + 1) + " : " + color + "\n"
                    cot += "The Correct Ordered List:\n"
                    for rank, color in enumerate(example[0]):
                        cot += "Rank " + str(rank + 1) + " : " + color + "\n"
                    if rules_correct:
                        cot += "By comparing the correct ordered color list and the given color preferences, we found that they are the same, so the given color preferences are correct.\n"
                        cot += "Correct Rules or Not:\nYes\nRectified Rules:\nThere is no rule to correct."
                    else:
                        cot += "Summarizing above information, we found that the given correct ordered color list and the given color preferences are not the same, so the color preferences are not correct. They are not correct in the following rules:\n"
                        for color in wrong_colors:
                            cot += "The color " + color + " has a wrong priority, it should be " + str(priority[color]) + "\n"
                        cot += "Summarizing above information, we can conclude the following:\n"
                        cot += "Correct Rules or Not:\nNo\nRectified Rules:\n" + rectified
                elif rules_correct:
                    cot = "Correct Rules or Not:\nYes\nRectified Rules:\nThere is no rule to correct."
                else:
                    cot = "Correct Rules or Not:\nNo\nRectified Rules:\n" + rectified
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example_text, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        all_prompt = [system_message_prompt]
        if few_shot_messages != []:
            all_prompt.extend(few_shot_messages)
        all_prompt.append(human_message_prompt)
        return all_prompt
    def ordering_rules_incorporating(self, few_shot_examples, few_shot_deleted_rules, few_shot_explanations, few_shot_prioritys, chain_of_thought=True, few_shot=True):
        """Assemble the message list for the "incorporating" colour-ordering task.

        The result is ``[system, *few-shot pairs, human]``. The system and human
        messages are built from ``ordering_template_dict["incorporating_pair"]``.
        For every few-shot example, a user message and an assistant message are
        added. The user message shows the possibly incomplete rules plus the
        ordered colour lists. The assistant message is the reference answer
        built by ``_incorporating_answer``.

        Args:
            few_shot_examples: per example, a sequence of ordered colour lists.
            few_shot_deleted_rules: per example, the (possibly incomplete)
                colour -> rank mapping shown to the model.
            few_shot_explanations: per example, explanation text inserted into
                the chain-of-thought answer.
            few_shot_prioritys: per example, the complete colour -> rank mapping.
            chain_of_thought: include the reasoning text in each reference answer.
            few_shot: when false, no example pairs are added.

        The four per-example sequences are consumed with ``zip``, so extra items
        in a longer sequence are ignored.

        Returns:
            list of message prompt templates, in prompt order.
        """
        system_message_prompt, human_message_prompt, few_shot_messages = self.creating_initial_messages(
            ordering_template_dict["incorporating_pair"][0], ordering_template_dict["incorporating_pair"][1])
        if few_shot:
            for example, deleted_priority, explanation, priority in zip(
                    few_shot_examples, few_shot_deleted_rules, few_shot_explanations, few_shot_prioritys):
                preference_text = self.generate_preference_text(deleted_priority)
                example_text = self.generate_multiple_example_text(example)
                example_text = f"You have access to the following original color preference rules that may be incomplete:\n{preference_text}\nFollowing is the ordered list of colors that may provide new information:\n{example_text}\nNow you need to induct whether we can induct new rule/rules from the given results and your response should follow the response format. Remember that the new rule/rules should be in the new incorporated color preference rather than the original given color preference."
                cot = self._incorporating_answer(priority, deleted_priority, explanation, chain_of_thought)
                # Few-shot turns are sent as named system messages (example_user /
                # example_assistant) rather than real user/assistant turns.
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        example_text, additional_kwargs={"name": "example_user"}
                    ))
                few_shot_messages.append(
                    SystemMessagePromptTemplate.from_template(
                        cot, additional_kwargs={"name": "example_assistant"}
                    ))
        all_prompt = [system_message_prompt]
        all_prompt.extend(few_shot_messages)
        all_prompt.append(human_message_prompt)
        return all_prompt

    @staticmethod
    def _incorporating_answer(priority, deleted_priority, explanation, chain_of_thought):
        """Return the reference assistant answer for one incorporating example.

        The given rules count as complete when ordering ``deleted_priority`` and
        ``priority`` by rank yields the same colour list. Otherwise the colours
        present in ``priority`` but missing from ``deleted_priority`` are listed
        as the newly inducted rules, each as ``Rank r : colour``.
        """
        sorted_color_list = [k for k, _ in sorted(priority.items(), key=lambda item: item[1])]
        noised_sorted_color_list = [k for k, _ in sorted(deleted_priority.items(), key=lambda item: item[1])]
        rules_complete = noised_sorted_color_list == sorted_color_list
        missing = [key for key in priority if key not in deleted_priority]

        if chain_of_thought:
            # NOTE(review): ``explanation`` is joined with no separator on either
            # side; confirm the explanations carry their own whitespace/newlines.
            cot = "Now that we know that the data is ordered by correct color preference and the rules may not be complete. Let's inference rules from those observations of ordered color lists. If the inferenced results are more complete than the given rules, we can conclude that the given rules are not complete. If the inferenced results are not more complete than the given rules, we can conclude that the given rules are complete."
            cot += explanation
            if rules_complete:
                cot += "Through comparing the infered results from observations of ordered color lists and the rules given in previous. We found that the infered results are the same with the given rules, so the given rules are complete.\n"
                # Use the same answer header as every other branch of this task
                # (previously "Rectified Results:", which this task never uses).
                cot += "New Rules or Not:\nNo\nNew Inducted:\nThere is no rule to add"
            else:
                cot += "Through comparing the infered results from observations of ordered color lists and the rules given in previous. We found that the infered results are more complete than the given rules, so the given rules are not complete.\n We can conclude the following:\n"
                cot += "New Rules or Not:\nYes\nNew Inducted:\n"
                # "Rank r : colour" matches the rule layout used in the prompt text.
                cot += "".join(f"Rank {priority[key]!s} : {key}\n" for key in missing)
            return cot

        if rules_complete:
            return "New Rules or Not:\nNo\nNew Inducted:\n" + "There is no rule to add"
        return "New Rules or Not:\nYes\nNew Inducted:\n" + "".join(
            f"{number}. Rank {priority[key]!s} : {key}\n" for number, key in enumerate(missing, start=1)
        )

    def noised_ordering(self, color_preference, noised_example=3, determinstic="Random"):
        """Return the colours ordered by preference, optionally perturbed.

        Colours are sorted by ascending rank value. When noise is applied,
        ``noised_example`` distinct positions are sampled and the colours at
        those positions are rotated one step. Each sampled position (in
        ascending order) receives the colour from the next sampled position, and
        the last one receives the colour from the first.

        Args:
            color_preference: mapping colour -> rank; lower rank sorts first.
            noised_example: number of positions to perturb.
            determinstic: ``"Clean"`` (never perturb), ``"Noise"`` (always
                perturb) or ``"Random"`` (perturb with probability 0.5).

        Returns:
            ``(ordering, index_info)``. ``ordering`` is the list of colours.
            ``index_info`` is ``[]`` when nothing was perturbed. Otherwise it is
            ``[new_indices, original_indices]``: two separate lists with the same
            ascending 0-based positions that were rotated.

        Raises:
            ValueError: if ``determinstic`` is not one of the modes above.
        """
        sorted_color_list = [k for k, _ in sorted(color_preference.items(), key=lambda item: item[1])]
        if determinstic == "Clean":
            return sorted_color_list, []
        if determinstic == "Noise":
            threshold = 1
        elif determinstic == "Random":
            threshold = 0.5
        else:
            # Previously this fell through with ``threshold`` unbound.
            raise ValueError(f"Unknown noise mode: {determinstic!r}")
        # random.random() is drawn in "Noise" mode too, so seeded runs consume
        # the RNG exactly as before.
        if random.random() >= threshold:
            return sorted_color_list, []

        # Use each colour's position in the sorted list (not rank - 1) so the
        # rotation stays correct when ranks are not exactly 1..n.
        position = {color: idx for idx, color in enumerate(sorted_color_list)}
        chosen = sorted(position[color] for color in random.sample(sorted_color_list, noised_example))
        if not chosen:
            # Nothing to rotate (noised_example == 0).
            return sorted_color_list, []

        noised = list(sorted_color_list)
        for dst, src in zip(chosen, chosen[1:] + chosen[:1]):
            noised[dst] = sorted_color_list[src]
        return noised, [list(chosen), chosen]

    def noised_rules(self, color_preference, noised_example=3, determinstic="Random"):
        """Return a copy of ``color_preference`` with some ranks rotated.

        When noise is applied, ``noised_example`` colours are sampled. Each
        sampled colour takes the rank of the next sampled colour, and the last
        takes the rank of the first. Both "Noise" and "Random" now use this
        rotation; previously the "Noise" branch always rotated exactly three
        colours regardless of ``noised_example``.

        Args:
            color_preference: mapping colour -> rank. It is not modified.
            noised_example: number of colours whose ranks are rotated.
            determinstic: ``"Clean"`` (no change), ``"Noise"`` (always rotate)
                or ``"Random"`` (rotate with probability 0.5).

        Returns:
            a deep copy of ``color_preference`` with the rotation applied, if any.

        Raises:
            ValueError: if ``determinstic`` is not one of the modes above.
        """
        noised_dict = copy.deepcopy(color_preference)
        if determinstic == "Clean":
            return noised_dict
        if determinstic == "Random":
            if random.random() >= 0.5:
                return noised_dict
        elif determinstic != "Noise":
            # An unknown/misspelled mode used to silently return clean data.
            raise ValueError(f"Unknown noise mode: {determinstic!r}")
        chosen_colors = random.sample(list(color_preference.keys()), noised_example)
        for color, donor in zip(chosen_colors, chosen_colors[1:] + chosen_colors[:1]):
            noised_dict[color] = color_preference[donor]
        return noised_dict

    def deleted_ruels(self, color_preference, noised_example=3, determinstic="Random"):
        """Return a copy of ``color_preference`` with some colours removed.

        When deletion is applied, ``noised_example`` colours are removed one at
        a time by ``random.choice``. The survivors are then renumbered 1..m in
        the order of their original ranks. The dict keeps its original insertion
        order. Previously survivors were renumbered by insertion order, which
        scrambled ranks whenever insertion order differed from rank order.

        Args:
            color_preference: mapping colour -> rank. It is not modified.
            noised_example: number of colours to delete.
            determinstic: ``"Clean"`` (no change), ``"Noise"`` (always delete)
                or ``"Random"`` (delete with probability 0.5). When nothing is
                deleted, no renumbering happens.

        Returns:
            the (possibly reduced and renumbered) copy.

        Raises:
            ValueError: if ``determinstic`` is unknown, or if more colours would be
                deleted than exist.
        """
        deleted_dict = copy.deepcopy(color_preference)
        if determinstic == "Clean":
            return deleted_dict
        if determinstic == "Random":
            if random.random() >= 0.5:
                return deleted_dict
        elif determinstic != "Noise":
            raise ValueError(f"Unknown noise mode: {determinstic!r}")
        if noised_example > len(deleted_dict):
            raise ValueError(
                f"Cannot delete {noised_example} colours from {len(deleted_dict)} rules")
        for _ in range(noised_example):
            del deleted_dict[random.choice(list(deleted_dict))]
        # Renumber by original rank (stable for ties), keep insertion order.
        by_rank = sorted(deleted_dict, key=deleted_dict.get)
        new_rank = {color: rank for rank, color in enumerate(by_rank, start=1)}
        return {color: new_rank[color] for color in deleted_dict}
    def generate_example_text(self, example):
        """Render an ordered colour sequence as one ``Rank i : colour`` line each.

        Ranks are 1-based and follow the iteration order of ``example``. The
        text starts with an ``Ordered Color: `` header line.
        """
        parts = ["Ordered Color: \n"]
        for rank, color in enumerate(example, start=1):
            parts.append(f"Rank {rank} : " + color + "\n")
        return "".join(parts)

    def generate_multiple_example_text(self, example):
        """Render each item of ``example`` on its own line as ``Ordered Color List i : <str(item)>``.

        Numbering is 0-based, and every line ends with a newline.
        """
        return "".join(
            f"Ordered Color List {position} : {sub_example!s}\n"
            for position, sub_example in enumerate(example)
        )

    def generate_preference_text(self, prioritys):
        """Render a colour -> rank mapping as a ``Color Ranking:`` block.

        One ``Rank r : colour`` line is emitted per entry, in the mapping's
        iteration order (not re-sorted by rank).
        """
        lines = (f"Rank {rank!s} : " + color + "\n" for color, rank in prioritys.items())
        return "Color Ranking:\n" + "".join(lines)