from langchain.output_parsers import CommaSeparatedListOutputParser
import re
class CharacterMappingParser(CommaSeparatedListOutputParser):
    """Parse model responses for the character-mapping task.

    A rule is a pair written as ``Original: x -> Altered: y`` where ``x`` and
    ``y`` are single word characters. Each ``parse_*`` method extracts the
    fields of one prompting stage from the raw model text.
    """

    @staticmethod
    def _labelled_value(text, label):
        """Return the stripped value written after ``label:``, or None.

        A value on the line following ``label:`` is preferred; otherwise the
        remainder of the label's own line is used. Blank values count as
        missing. The value may sit on the last line (no trailing newline).
        """
        escaped = re.escape(label)
        match = re.search(escaped + r":[ \t]*\n[ \t]*(\S.*)", text) or re.search(
            escaped + r":[ \t]*(\S.*)", text
        )
        return match.group(1).strip() if match else None

    @staticmethod
    def _find_rules(text):
        """Return every ``(original, altered)`` character pair in *text*.

        Whitespace around ``->`` is optional, so both
        ``Original: a -> Altered: b`` and ``Original: a-> Altered: b`` match.
        """
        return re.findall(r"Original: (\w)[ \t]*->[ \t]*Altered: (\w)", text)

    def get_format_instructions(self) -> str:
        """Return the output-format instructions appended to the prompt.

        The example uses the same ``Original: a -> Altered: b`` spacing that
        the parsing methods expect.
        """
        return (
            "The output should be a pair of source text and altered text, where the symbol used to connect them is '->'. "
            "As an example, if the altered text is 'b', the source text is 'a', then the output should be: "
            "Original: a -> Altered: b, each pair is separated by a comma(,) ."
        )

    def parse_instruction(self, text):
        """Return the text after the first ``Altered: `` up to the end of its line.

        Returns ``"Unparsable"`` when the response contains no ``Altered: ``.
        """
        match = re.search(r"Altered: (.*?)(?:\n|$)", text, re.DOTALL)
        if match is None:
            return "Unparsable"
        return match.group(1)

    def parse_inducting(self, text):
        """Return the rule pairs listed after ``Rules:``.

        If the response has no ``Rules:`` section, the whole text is searched
        instead. Previously this case raised ``UnboundLocalError``.
        """
        match = re.search(r"Rules:(.*)", text, re.DOTALL)
        rules_part = match.group(1) if match else text
        return self._find_rules(rules_part)

    def parse_validating(self, text, model_name):
        """Parse a validation response.

        Args:
            text: raw model output.
            model_name: unused; kept for interface compatibility.

        Returns:
            ``(validation_result, result)``. *validation_result* is the stripped
            value after ``Validation Result``, or ``"Unparsable"`` if absent.
            *result* is the value after ``Rectified Results`` when the
            validation result is ``"Invalid"`` and that section exists;
            otherwise it is the placeholder ``"Altered: "``.
        """
        # Some model transcripts interleave a "System: " turn marker; drop it.
        text = text.replace("\nSystem: ", "")
        result = "Altered: "
        validation_result = self._labelled_value(text, "Validation Result")
        if validation_result is None:
            return "Unparsable", result
        if validation_result == "Invalid":
            rectified = self._labelled_value(text, "Rectified Results")
            # Keep the placeholder instead of crashing when the model omitted
            # the rectified section.
            if rectified is not None:
                result = rectified
        return validation_result, result

    def parse_error_correction(self, text):
        """Parse an error-correction response.

        Returns:
            ``(predict, results)``. *predict* is the answer after
            ``Correct Rules or Not``, or ``"Unparsable"`` if absent.
            *results* holds the rectified rule pairs when the answer is
            ``"No"``, and is ``[]`` otherwise. Pairs are read from the text
            after the first ``Rectified Rules:`` header line when one exists,
            and from the whole response otherwise.
        """
        predict = self._labelled_value(text, "Correct Rules or Not")
        if predict is None:
            predict = "Unparsable"
        if predict != "No":
            return predict, []
        sections = text.split("Rectified Rules:\n")
        rectified_part = sections[1] if len(sections) > 1 else sections[0]
        return predict, self._find_rules(rectified_part)

    def parser_incorporating(self, text):
        """Parse a rule-incorporation response.

        Returns:
            ``(new_info, rules)``. *new_info* is the stripped answer after
            ``New Information Contained``, or None if absent. *rules* is every
            rule pair found anywhere in the response.
        """
        new_info = self._labelled_value(text, "New Information Contained")
        return new_info, self._find_rules(text)

class GroupingOutputParser(CommaSeparatedListOutputParser):
    """Parse model responses for the polygon-grouping task.

    Rules are strings of the form ``"<sides>-sides+<color>+<material>"``,
    built from lines like ``Polygons with <n> Sides, <color> Color, and <x>``.
    """

    @staticmethod
    def _labelled_value(text, label):
        """Return the stripped value written after ``label:``, or None.

        A value on the line following ``label:`` is preferred; otherwise the
        remainder of the label's own line is used. Blank values count as
        missing. The value may sit on the last line (no trailing newline).
        """
        escaped = re.escape(label)
        match = re.search(escaped + r":[ \t]*\n[ \t]*(\S.*)", text) or re.search(
            escaped + r":[ \t]*(\S.*)", text
        )
        return match.group(1).strip() if match else None

    @staticmethod
    def _extract_rules(text):
        """Return ``"<sides>-sides+<color>+<material>"`` rule strings from *text*.

        Title-case ``Sides``/``Color`` matches are used when present; the
        lower-case spelling is tried only when there are none.
        """
        matches = re.findall(r"Polygons with (\d+) Sides, (\w+) Color, and (\w+)", text)
        if not matches:
            matches = re.findall(r"Polygons with (\d+) sides, (\w+) color, and (\w+)", text)
        return [f"{sides}-sides+{color}+{material}" for sides, color, material in matches]

    def parse_instruction(self, text):
        """Map group number to the polygon numbers listed in ``Group N: ...``.

        Sections whose header is not ``<integer>:`` are skipped.
        """
        groups_dict = {}
        for section in text.strip().split("Group ")[1:]:
            header, sep, body = section.partition(":")
            if not sep:
                continue
            try:
                group_num = int(header)
            except ValueError:
                continue
            groups_dict[group_num] = [int(num) for num in re.findall(r"Polygon (\d+)", body)]
        return groups_dict

    def parse_inducting(self, text):
        """Return the rule strings induced in *text*."""
        return self._extract_rules(text)

    def parse_validating(self, text):
        """Parse a validation response.

        Returns:
            ``(result, predicted_groups)``. *result* is the (possibly
            multi-line) text between ``Validation Result`` and the next
            ``Rectified`` line, or ``"Unparseable"``. *predicted_groups* is a
            list of ``{"Group:<n>": "Polygon a, Polygon b, ..."}`` dicts.
        """
        # Some model transcripts interleave a "System: " turn marker; drop it.
        text = text.replace("\nSystem: ", "")
        match = re.search(r"Validation Result:\n(.+?)\nRectified", text, re.S) or re.search(
            r"Validation Result: (.+?)\nRectified", text, re.S
        )
        result = match.group(1) if match else "Unparseable"
        group_matches = re.findall(r"Group (\d+): (Polygon \d+(, Polygon \d+)*)", text)
        predicted_groups = [
            {"Group:" + group_num: polygons} for group_num, polygons, _ in group_matches
        ]
        return result, predicted_groups

    def parse_error_correction(self, text):
        """Parse an error-correction response.

        Returns:
            ``(predicts, rectified_rules_list)``. *predicts* is the answer after
            ``Correct Rules or Not``, or ``"Unparseable"``. When it is ``"No"``,
            *rectified_rules_list* holds ``[wrong_rule, correct_rule]`` pairs;
            otherwise it is ``[]``.
        """
        predicts = self._labelled_value(text, "Correct Rules or Not")
        if predicts is None:
            predicts = "Unparseable"
        if predicts != "No":
            return predicts, []
        # Collect every "The wrong rule: X -> The correct rule: Y" line in the
        # response (the whole text is searched, not only the Rectified section).
        pairs = re.findall(r"The wrong rule: (.*?) -> The correct rule: (.*?)(?=\n|$)", text)
        return predicts, [[wrong_rule, correct_rule] for wrong_rule, correct_rule in pairs]

    def parse_rule_incorporation(self, text):
        """Parse a rule-incorporation response.

        Returns:
            ``(new_rules_or_not, rule_collection)``. *new_rules_or_not* is the
            word after ``New Rules or Not:``, or ``"Unparseable"``. When it is
            ``"Yes"``, *rule_collection* holds the rules in the response, using
            the same title/lower-case handling as ``parse_inducting``;
            otherwise it is ``[]``.
        """
        match = re.search(r"New Rules or Not:\s*(\w+)", text)
        new_rules_or_not = match.group(1) if match else "Unparseable"
        if new_rules_or_not != "Yes":
            return new_rules_or_not, []
        return new_rules_or_not, self._extract_rules(text)


class OrderingOutputParser(CommaSeparatedListOutputParser):
    """Parse model responses for the color-ordering task.

    Rules map a color name to an integer rank, written as ``Rank N: color``
    (or ``Rank N : color`` / ``Rank N color`` depending on the stage).
    """

    @staticmethod
    def _labelled_value(text, label):
        """Return the stripped value written after ``label:``, or None.

        A value on the line following ``label:`` is preferred; otherwise the
        remainder of the label's own line is used. Blank values count as
        missing. The value may sit on the last line (no trailing newline).
        """
        escaped = re.escape(label)
        match = re.search(escaped + r":[ \t]*\n[ \t]*(\S.*)", text) or re.search(
            escaped + r":[ \t]*(\S.*)", text
        )
        return match.group(1).strip() if match else None

    @staticmethod
    def _rectified_pairs(text):
        """Return ``(wrong, rectified)`` color pairs, trying formats in priority order.

        The first format that yields any match wins. The rank-based formats
        are reduced to ``[wrong_color, rectified_color]`` lists.
        """
        pairs = re.findall(r"Wrong Priority Color: (\w+) -> Rectified Priority Color: (\w+)", text)
        if pairs:
            return pairs
        for pattern in (
            r"Rank (\d+) : (\w+) -> Rank (\d+) : (\w+)",
            r"Rank (\d+): (\w+) -> Rank (\d+): (\w+)",
        ):
            ranked = re.findall(pattern, text)
            if ranked:
                return [[x[1], x[3]] for x in ranked]
        return re.findall(r"(\w+) -> (\w+)", text)

    def parse_instruction(self, text):
        """Return the words of a numbered list (``1. red``), in order of appearance."""
        return [color for _rank, color in re.findall(r"(\d+)\. (\w+)", text)]

    def parse_inducting(self, text):
        """Map color to rank from ``Rank N color`` entries; later entries win."""
        return {color: int(rank) for rank, color in re.findall(r"Rank (\d+) (\w+)", text)}

    def parse_validating(self, text):
        """Parse a validation response.

        Returns:
            ``(correct_results, rectified_results)``. *correct_results* is the
            word after ``Correct Results or Not:``, or ``"No match found"``.
            *rectified_results* holds the rectified color pairs when the answer
            is ``"No"``, and is None for any other answer. Previously, answers
            other than Yes/No raised ``UnboundLocalError``.
        """
        match = re.search(r"Correct Results or Not:\s*(\w+)", text)
        correct_results = match.group(1) if match else "No match found"
        rectified_results = None
        if correct_results == "No":
            rectified_results = self._rectified_pairs(text)
        return correct_results, rectified_results

    def parse_error_correction(self, text):
        """Parse an error-correction response.

        Returns:
            ``(rules_correct, rectified_dict)``. *rules_correct* is the answer
            after ``Correct Rules or Not``, or ``"Unparseable"`` if absent. When
            it is ``"No"``, *rectified_dict* maps color to rank, read from the
            ``Rank N: color`` lines after ``Rectified Rules:``. It is empty when
            that section is missing or the answer is not ``"No"``. Previously,
            both missing-data cases raised ``AttributeError``.
        """
        rules_correct = self._labelled_value(text, "Correct Rules or Not")
        if rules_correct is None:
            return "Unparseable", {}
        rectified_dict = {}
        if rules_correct == "No":
            section = re.search(r"Rectified Rules:\n(.*)", text, re.DOTALL)
            rectified_lines = section.group(1).split("\n") if section else []
            for line in rectified_lines:
                try:
                    rank, color = line.split(":")
                    rectified_dict[color.strip()] = int(rank.replace("Rank ", ""))
                except ValueError:
                    # Not a "Rank N: color" line (blank line, prose, ...); skip.
                    continue
        return rules_correct, rectified_dict

    def parse_rule_incorporation(self, text):
        """Parse a rule-incorporation response.

        Returns:
            ``(predicts, new_inducted_rules, new_inducted_dict)``. *predicts*
            is the answer after ``New Rules or Not``, or ``"Unparseable"`` if
            absent; previously this raised ``IndexError``. When it is
            ``"Yes"``, *new_inducted_rules* lists the ``"Rank N : color"``
            strings found after the first ``New Inducted:`` header line, and
            *new_inducted_dict* maps each color to its rank. Both are empty
            otherwise, or when the header is missing.
        """
        predicts = self._labelled_value(text, "New Rules or Not")
        if predicts is None:
            predicts = "Unparseable"
        if predicts != "Yes":
            return predicts, [], {}
        sections = text.split("New Inducted:\n")
        if len(sections) < 2:
            return predicts, [], {}
        new_inducted_rules = []
        new_inducted_dict = {}
        for match in re.finditer(r"Rank (\d+) : (\w+)", sections[1]):
            new_inducted_rules.append(match.group(0))
            new_inducted_dict[match.group(2)] = int(match.group(1))
        return predicts, new_inducted_rules, new_inducted_dict
