import random
from collections import defaultdict
from tqdm import tqdm
import itertools
from typing import List, Tuple, Dict
import datasets
import string
import copy
from utlis import unique_character, alphabet_position
class DataCreator:
    """Base class for data creation.

    The base implementation just holds on to pre-built data and hands it
    back unchanged from :meth:`create`.
    """

    def __init__(self, data):
        # Keep a reference to the caller's object; no copy is made.
        self.data = data

    def create(self):
        """Return the stored data object itself."""
        stored = self.data
        return stored
class CharacterMappingDataCreator():
    """Build Caesar-cipher character-mapping tasks from a review dataset.

    Each task pairs an original review with its letter-shifted ("altered")
    version. Helper methods build mapping tables, inject noise into examples
    or rules, and render tables as prompt text.
    """

    # Positions (within the length-filtered training reviews) of the fixed
    # few-shot demonstrations returned by create().
    _FEW_SHOT_INDICES = (2123, 4863, 3213, 7702, 4591)

    def __init__(self, dataset, num_examples, max_length=100):
        # `dataset` is passed straight to datasets.load_dataset; loading is eager.
        self.dataset = datasets.load_dataset(dataset)
        self.num_example = num_examples
        self.max_length = max_length

    def create(self):
        """Return ``(data_sample, few_shot_sample)`` of ``[original, encrypted]`` pairs.

        Only "train" reviews shorter than ``max_length`` characters are used.
        ``data_sample`` holds the first ``num_examples`` of them (all of them if
        fewer survive the filter). ``few_shot_sample`` holds the reviews at
        ``_FEW_SHOT_INDICES``. Every review is encrypted with a shift of 3.
        Indexing the few-shot reviews may raise IndexError if the filtered split
        is too short.
        """
        length_filtered = self.dataset.filter(
            lambda example: len(example["review"]) < self.max_length
        )["train"]["review"]
        few_shot = [length_filtered[idx] for idx in self._FEW_SHOT_INDICES]
        data_sample = [[text, self.caesar_encrypt(text, 3)]
                       for text in length_filtered[:self.num_example]]
        few_shot_sample = [[text, self.caesar_encrypt(text, 3)] for text in few_shot]
        return data_sample, few_shot_sample

    def caesar_encrypt(self, text, shift=3):
        """Shift every ASCII letter of *text* by *shift* places, wrapping within its case.

        All other characters, including non-ASCII letters such as "é" or "ß",
        are copied unchanged. Negative shifts decrypt.
        """
        pieces = []
        for char in text:
            # Only ASCII letters are shifted: str.isalpha() is also true for
            # non-ASCII letters, which the mod-26 arithmetic would mangle.
            if char in string.ascii_lowercase:
                base = ord('a')
            elif char in string.ascii_uppercase:
                base = ord('A')
            else:
                pieces.append(char)
                continue
            pieces.append(chr((ord(char) - base + shift) % 26 + base))
        return "".join(pieces)

    def generatele_mapping_table(self, shift=3, reveal=3, noised=False, noise_num=2):
        """Build the mapping table for the first *reveal* upper-case letters.

        Each row is ``{"Original": letter, "Altered": shifted_letter}``.

        Returns:
            If *noised* is False, the table. Otherwise
            ``(table, noised_table, sampled_idx)``. ``sampled_idx`` holds
            *noise_num* random row indices. ``noised_table`` is an independent
            copy of ``table`` in which the "Altered" values of the first two
            sampled rows are swapped. ``table`` itself is left clean.

        Raises:
            ValueError: (from random.sample) if *noise_num* exceeds the number of rows.
            IndexError: if *noise_num* is smaller than 2 and *noised* is True.
        """
        encrypted = self.caesar_encrypt(string.ascii_uppercase, shift)
        table = [{"Original": plain, "Altered": cipher}
                 for plain, cipher in zip(string.ascii_uppercase[:reveal], encrypted[:reveal])]
        if not noised:
            return table
        sampled_idx = random.sample(range(len(table)), noise_num)
        # Rows are dicts: a deep copy is needed so editing noised_table does not
        # also overwrite the clean table (which would break the swap as well).
        noised_table = copy.deepcopy(table)
        first, second = sampled_idx[0], sampled_idx[1]
        noised_table[first]["Altered"] = table[second]["Altered"]
        noised_table[second]["Altered"] = table[first]["Altered"]
        return table, noised_table, sampled_idx

    def mappingtable2text(self, examples):
        """Render mapping rows as lines ``"Original: X -> Altered: Y"``, each newline-terminated."""
        return "".join(
            f"Original: {ex['Original']} -> Altered: {ex['Altered']}\n" for ex in examples
        )

    def _sample_noise_rules(self, characters, rules, noise_num):
        """Randomly pick *noise_num* of the rules belonging to *characters*.

        Rules are looked up by ``alphabet_position`` (assumed to index *rules*
        the same way the table was built — TODO confirm).
        """
        candidates = [rules[alphabet_position(char)] for char in characters]
        return random.sample(candidates, noise_num)

    def _scramble(self, encoded, noise_sample):
        """Replace each character whose upper-case form is a sampled "Altered"
        letter with a randomly chosen "Original" letter of the sample."""
        altered = [rule["Altered"] for rule in noise_sample]
        originals = [rule["Original"] for rule in noise_sample]
        pieces = []
        for char in encoded:
            if char.upper() in altered:
                char = random.sample(originals, 1)[0]
            pieces.append(char)
        return "".join(pieces)

    def random_example_noise(self, examples, noise_rate=0.25, rules=None, noise_num=3, controller=False):
        """Corrupt some encoded examples by scrambling letters.

        Args:
            examples: ``(original, encoded)`` pairs (must support len() when
                *controller* is True).
            noise_rate: probability that an example is noised in random mode.
            rules: mapping table; required whenever an example gets noised.
            noise_num: number of rules sampled per noised example.
            controller: if True, noise exactly the first ``len(examples) // 2``
                examples instead of deciding randomly per example.

        Returns:
            ``(originals, clean_encoded_examples, noised_examples, records)``.
            ``records[i]`` is the list of sampled rules for a noised example,
            otherwise None.

        In random mode, a ValueError while sampling rules (e.g. fewer distinct
        characters than *noise_num*) leaves that example clean. In controller
        mode the ValueError propagates.
        """
        originals = []
        clean_encoded_examples = []
        noised_examples = []
        records = []
        half = int(len(examples) / 2) if controller else 0
        for exp_idx, (original, encoded) in enumerate(examples):
            originals.append(original)
            clean_encoded_examples.append(encoded)
            occured_characters = list(unique_character(original).keys())
            noise_sample = None
            if controller:
                if exp_idx < half:
                    noise_sample = self._sample_noise_rules(occured_characters, rules, noise_num)
            elif random.random() < noise_rate:
                try:
                    noise_sample = self._sample_noise_rules(occured_characters, rules, noise_num)
                except ValueError:
                    # Not enough candidate rules to sample from: keep it clean.
                    noise_sample = None
            records.append(noise_sample)
            if noise_sample is None:
                noised_examples.append(encoded)
            else:
                noised_examples.append(self._scramble(encoded, noise_sample))
        return originals, clean_encoded_examples, noised_examples, records

    def _swap_altered(self, rules, first_char, second_char):
        """Swap, in place, the "Altered" values of the rules for two characters."""
        first = alphabet_position(first_char)
        second = alphabet_position(second_char)
        # Only the value moved into `first` is upper-cased; Altered values built
        # by generatele_mapping_table are already upper-case, so this is normally a no-op.
        rules[first]["Altered"], rules[second]["Altered"] = (
            rules[second]["Altered"].upper(),
            rules[first]["Altered"],
        )

    def random_rule_noise(self, rules, few_shot_examples, determinstic=False):
        """Return ``[rules_copy, example, wrong_rule_flag]`` for each few-shot example.

        Each entry gets its own deep copy of *rules*. When ``wrong_rule_flag``
        is True, the "Altered" values for two randomly chosen characters of
        the example are swapped in that copy.

        In random mode (*determinstic* False), an example is corrupted with
        probability 0.5, and only if it has more than two distinct characters.
        In deterministic mode, the first ``len(few_shot_examples) // 2``
        examples are corrupted. There, random.sample raises ValueError for an
        example with fewer than two distinct characters.
        """
        few_shot = []
        determined_positive = len(few_shot_examples) // 2 if determinstic else 0
        for idx, example in enumerate(few_shot_examples):
            occured_characters = list(unique_character(example).keys())
            wrong_rules = copy.deepcopy(rules)
            # Drawn for every example, even in deterministic mode where it is
            # unused, so the sequence of random draws stays stable for seeded runs.
            coin = random.random()
            if determinstic:
                wrong_rule_flag = idx < determined_positive
            else:
                wrong_rule_flag = coin < 0.5 and len(occured_characters) > 2
            if wrong_rule_flag:
                first_char, second_char = random.sample(occured_characters, 2)
                self._swap_altered(wrong_rules, first_char, second_char)
            few_shot.append([wrong_rules, example, wrong_rule_flag])
        return few_shot

    def create_inducting_example(self, table, examples):
        """For each example, return the rows of *table* for letters present in ``example[0]``.

        Each result is built from a fresh deep copy of *table*, so results can
        be edited independently. A row is kept when its lower-cased "Original"
        letter is among the keys of ``unique_character(example[0])``. The row
        is located through ``alphabet_position``, so the table is assumed to be
        ordered A, B, C, ... (TODO confirm for tables not built by
        generatele_mapping_table).
        """
        inducting_table = []
        for example in examples:
            present = unique_character(example[0]).keys()
            copy_table = copy.deepcopy(table)
            keep_positions = {
                alphabet_position(row["Original"].lower())
                for row in copy_table
                if row["Original"].lower() in present
            }
            inducting_table.append(
                [row for idx, row in enumerate(copy_table) if idx in keep_positions]
            )
        return inducting_table


class GroupingDataCreator():
    """Generate polygon-grouping tasks.

    Every call to create() draws a rule set of (sides, color, material)
    triples not used before, samples polygons that follow those rules, and
    groups them by their attribute triple.
    """

    def __init__(self, num_polygons, num_rules, sides_options, colors_options, materials_options):
        self.num_polygons = num_polygons
        self.num_rules = num_rules
        self.sides_options = sides_options
        self.colors_options = colors_options
        self.materials_options = materials_options
        # Every possible (sides, color, material) attribute triple.
        self.all_rules = list(itertools.product(sides_options, colors_options, materials_options))
        # Rule sets already handed out by get_rules(), stored as sorted tuples.
        self.extracted_combinations = set()

    def get_rules(self):
        """Return ``num_rules`` distinct triples forming a rule set not returned before.

        Raises:
            ValueError: (from random.sample) if ``num_rules`` exceeds the number
                of possible triples.
            RuntimeError: once every possible rule set has already been drawn
                (previously this looped forever).
        """
        from math import comb

        population = len(self.all_rules)
        if (self.num_rules <= population
                and len(self.extracted_combinations) >= comb(population, self.num_rules)):
            raise RuntimeError(
                f"All {comb(population, self.num_rules)} combinations of "
                f"{self.num_rules} rules have already been drawn"
            )
        # Rejection sampling: terminates because an unseen combination exists.
        while True:
            new_combination = tuple(sorted(random.sample(self.all_rules, self.num_rules)))
            if new_combination not in self.extracted_combinations:
                self.extracted_combinations.add(new_combination)
                return list(new_combination)

    def create(self) -> Tuple[List[Tuple[str, int, str, str]], Dict[str, List[Tuple[str, int, str, str]]], str]:
        """Generate one grouping task.

        Returns:
            ``(polygons, groups, grouping_rules)``. ``polygons`` is a list of
            ``("Polygon i", sides, color, material)`` tuples. ``groups`` maps
            ``"{sides}-sides+{color}+{material}"`` to the polygons with those
            attributes, in generation order. ``grouping_rules`` is a text
            summary of all attribute options.
        """
        selected_rules = self.get_rules()

        # Each polygon takes its attributes from one randomly chosen rule.
        polygons = []
        for i in range(self.num_polygons):
            sides, color, material = random.choice(selected_rules)
            polygons.append(("Polygon " + str(i), sides, color, material))

        groups = {}
        for polygon in polygons:
            group_key = f"{polygon[1]}-sides+{polygon[2]}+{polygon[3]}"
            groups.setdefault(group_key, []).append(polygon)

        grouping_rules = f"Sides options: {self.sides_options}, Color options: {self.colors_options}, Material options: {self.materials_options}"

        return polygons, groups, grouping_rules

class OrderingDataCreator():
    """Generate colour-ordering tasks.

    A hidden preference ranking over ``colors`` is drawn at random, and
    example colour lists are sorted according to it. ``create_inducting`` also
    yields a step-by-step textual explanation of how the ranking is recovered.
    """

    def __init__(self, colors):
        # Private copy: color_preference_generation() shuffles self.colors in
        # place, which must not reorder the caller's list.
        self.colors = list(colors)

    def color_preference_generation(self):
        """Shuffle ``self.colors`` in place and return ``{color: rank}`` with ranks 1..n."""
        random.shuffle(self.colors)
        return {color: rank for rank, color in enumerate(self.colors, start=1)}

    def random_sum_partition(self, n, total):
        """Split *total* into *n* integers that sum to *total*.

        ``n - 1`` distinct cut points are drawn from ``1..total-1``, so every
        part is at least 1 when ``n <= total``.

        Raises:
            ValueError: (from random.sample) if ``n < 1``, or if ``n > 1`` and ``n > total``.
        """
        # Draw n-1 distinct random cut points.
        cuts = sorted(random.sample(range(1, total), n - 1))
        # Add the start and end of the interval.
        bounds = [0] + cuts + [total]
        # Consecutive differences give the sizes of the n parts.
        return [upper - lower for lower, upper in zip(bounds, bounds[1:])]

    def generate_examples(self, priority, num_examples, repeat=True, missing=True):
        """Draw *num_examples* random colour lists and sort each by *priority*.

        Each list first gets a random sample of colours whose size comes from
        partitioning the number of colours into *num_examples* parts. When
        ``num_examples != 1``, a random non-empty subset of the remaining
        colours is appended. A list never contains the same colour twice.
        *repeat* and *missing* are accepted for API compatibility but unused.

        Returns:
            ``(examples, unsorted_examples)``: the sorted lists and their
            pre-sort copies.
        """
        examples = []
        unsorted_examples = []
        sizes = self.random_sum_partition(num_examples, len(priority))
        for size in sizes:
            example = random.sample(list(priority), size)
            chosen = set(example)
            remaining = [color for color in priority if color not in chosen]
            if num_examples != 1:
                example.extend(random.sample(remaining, k=random.randint(1, len(remaining))))
            unsorted_examples.append(list(example))
            examples.append(sorted(example, key=priority.__getitem__))
        return examples, unsorted_examples

    def generate_color_preferences(self, sorted_color_lists):
        """Infer a ranking from sorted colour lists and explain the steps.

        Every occurrence of a colour adds its 0-based position to the colour's
        total, so colours appearing earlier accumulate smaller totals. Colours
        are ranked by ascending total. Ties keep first-seen order, because
        sorted() is stable.

        Returns:
            ``(preference_order, explanation)``. ``preference_order`` maps each
            colour to its rank (starting at 1). ``explanation`` is the
            step-by-step narration text.
        """
        n_lists = len(sorted_color_lists)
        color_rankings = defaultdict(int)
        parts = ["Starting process... There are {} color lists to analyze.\n".format(n_lists)]

        for list_no, color_list in enumerate(sorted_color_lists, start=1):
            parts.append("\nNow starting with list {} of total {}: {}\n".format(list_no, n_lists, color_list))
            for position, color in enumerate(color_list):
                # Score is the 0-based position: earlier positions add less.
                color_rankings[color] += position
                parts.append(
                    " Color '{}' at position {} gets a score of {}. Current total score for color '{}' is {}.\n".format(
                        color, position + 1, position, color, color_rankings[color]))
            parts.append("Done processing list {} of total {}.\n".format(list_no, n_lists))

        parts.append("\nAll color list evaluations are complete. Finalizing preferences by sorting colors based on total points.\n")

        # Ascending total score: the lowest total gets rank 1.
        ranked = sorted(color_rankings, key=color_rankings.get)
        preference_order = {color: rank for rank, color in enumerate(ranked, 1)}

        parts.append("\nThe finalized order of color preferences:\n")
        for rank, color in enumerate(ranked, 1):
            parts.append(" Rank {}: color '{}'.\n".format(rank, color))

        return preference_order, "".join(parts)

    def create_inducting(self, num_examples, num_partitions=5):
        """Generate *num_examples* ranking-induction tasks.

        For each task, a fresh ranking is drawn. Then *num_partitions* sorted
        example lists are resampled until ``generate_color_preferences``
        recovers exactly that ranking. There is no retry limit.

        Returns:
            ``(prioritys, examples, unsorted_examples, explanations)``, one
            entry per task.
        """
        prioritys = []
        examples = []
        unsorted_examples = []
        explanations = []
        for _ in tqdm(range(num_examples)):
            priority = self.color_preference_generation()
            while True:
                example, unsorted_example = self.generate_examples(priority, num_partitions)
                # Only keep example sets whose explanation reaches the true ranking.
                inferred, explanation = self.generate_color_preferences(example)
                if inferred == priority:
                    break
            prioritys.append(priority)
            examples.append(example)
            unsorted_examples.append(unsorted_example)
            explanations.append(explanation)
        return prioritys, examples, unsorted_examples, explanations

    def create_normal(self, num_examples, repeat=True, repeat_num=6, missing=True, missing_num=6):
        """Generate *num_examples* plain sorting tasks.

        Each task starts from every colour in random order. If *missing* is
        True, *missing_num* random positions are dropped. If *repeat* is True,
        *repeat_num* random colours are inserted at random positions, so
        colours may repeat.

        Returns:
            ``(prioritys, examples, unsorted_examples, explanations)``. Each
            element of ``examples`` / ``unsorted_examples`` is a one-element
            list wrapping the sorted / unsorted colour list. Each explanation
            is the string "None".

        Raises:
            ValueError: (from random.sample) if *missing_num* or *repeat_num*
                exceeds the current list length.
        """
        prioritys = []
        examples = []
        unsorted_examples = []
        explanations = []
        for _ in tqdm(range(num_examples)):
            priority = self.color_preference_generation()
            _, unsorted_batch = self.generate_examples(priority, 1)
            # With a single example, generate_examples returns one list holding
            # every colour. Work on that flat list so that repeat also works
            # without missing.
            unsorted = unsorted_batch[0]
            if missing:
                dropped = set(random.sample(range(len(unsorted)), missing_num))
                unsorted = [color for idx, color in enumerate(unsorted) if idx not in dropped]
            if repeat:
                palette = list(priority.keys())
                for position in random.sample(range(len(unsorted)), repeat_num):
                    unsorted.insert(position, random.sample(palette, 1)[0])
            prioritys.append(priority)
            examples.append([sorted(unsorted, key=priority.__getitem__)])
            unsorted_examples.append([unsorted])
            explanations.append("None")
        return prioritys, examples, unsorted_examples, explanations



