
import dspy
from typing import Union
import re

def find_integers_in_string(input_string):
    """Return every run of digits in ``input_string`` as an int, in order of appearance.

    Example: ``"a1b22c333"`` -> ``[1, 22, 333]``; a string with no digits -> ``[]``.
    """
    return [int(match.group()) for match in re.finditer(r'\d+', input_string)]

class BaselineInsertInformation(dspy.Signature):
    # NOTE(review): dspy presumably uses this docstring as the task instructions
    # sent to the LM, so its wording is runtime behavior -- edit with care.
    """Given a topic,  associated information, list of available section names. 
    you need to insert the information into the most appropriate section. 
    
    Output should be the index of desired choice.
    """
    # Inputs: the topic, the text to place, and the candidate sections rendered
    # as a numbered list (built by BaselineInsertInformationModule.construct_options).
    topic = dspy.InputField(prefix="topic: ", format=str)
    information = dspy.InputField(prefix="information: ", format=str)
    options = dspy.InputField(prefix="options: ", format=str)
    # Output: the LM's pick. BaselineInsertInformationModule.forward expects it to
    # contain exactly one integer (the option number) and may strip a leading
    # "Choice:" label before parsing.
    choice = dspy.OutputField(prefix="Choice:\n", format=str)

class BaselineInsertInformationModule(dspy.Module):
    """Ask an LM which of the given sections a piece of information belongs in.

    The candidate sections are shown to the LM as a 1-based numbered list, and
    the number it answers with is mapped back to the matching entry of
    ``section_paths``.
    """

    def __init__(self,
                 insert_information_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        """
        Args:
            insert_information_engine: LM installed via ``dspy.settings.context``
                for the section-selection call made in ``forward``.
        """
        super().__init__()
        self.insert_information = dspy.ChainOfThought(BaselineInsertInformation)
        self.insert_information_engine = insert_information_engine

    def construct_options(self, section_paths, heading_only=True):
        """Render ``section_paths`` as a newline-separated, 1-based numbered list.

        Args:
            section_paths: Candidate sections; levels of a path are expected to be
                separated by ``" -> "``.
            heading_only: If True, show only the last ``" -> "`` segment of each
                path (the heading itself) instead of the full path.

        Returns:
            One ``"<k>. <label>"`` line per section, where option ``k`` refers to
            ``section_paths[k - 1]``.
        """
        labels = section_paths
        if heading_only:
            # A path without the separator splits into [path], so [-1] is the
            # path itself; no separate membership check is needed.
            labels = [path.split(" -> ")[-1] for path in section_paths]
        return "\n".join(f"{number}. {label}"
                         for number, label in enumerate(labels, start=1))

    @staticmethod
    def _parse_choice(choice, num_options):
        """Convert the LM's raw answer into a 0-based index into the options.

        Raises:
            AssertionError: If the answer does not contain exactly one integer,
                or that integer is not in ``1..num_options``.
        """
        # cleaning: drop an echoed "Choice:" label and anything before it
        if "Choice:" in choice:
            choice = choice[choice.find("Choice:") + len("Choice:"):]
            choice = choice.replace("\n", "").strip()
        indices = find_integers_in_string(choice)
        # Raise explicitly instead of using `assert`: it keeps the exception type
        # callers already see, but is not stripped under `python -O`.
        if len(indices) != 1:
            raise AssertionError(f"output does not follow format: {choice}")
        option_number = indices[0]
        if not 1 <= option_number <= num_options:
            raise AssertionError(
                f"choice {option_number} is not a valid option (1..{num_options}): {choice}")
        # Options are numbered from 1 (see construct_options); list indices from 0.
        return option_number - 1

    def forward(self, topic, information_to_insert, section_paths, heading_only=True):
        """Pick the section into which ``information_to_insert`` should go.

        Args:
            topic: Topic passed to the LM for context.
            information_to_insert: The information to place.
            section_paths: Candidate sections; their order defines the option numbers.
            heading_only: Forwarded to ``construct_options``.

        Returns:
            ``dspy.Prediction`` whose ``path_to_insert`` is the chosen element of
            ``section_paths``.

        Raises:
            AssertionError: If the LM's answer does not contain exactly one
                integer, or that integer is not a valid option number.
        """
        option_string = self.construct_options(section_paths=section_paths,
                                               heading_only=heading_only)
        with dspy.settings.context(lm=self.insert_information_engine):
            choice = self.insert_information(topic=topic,
                                             information=information_to_insert,
                                             options=option_string).choice
        index = self._parse_choice(choice, num_options=len(section_paths))
        return dspy.Prediction(path_to_insert=section_paths[index])
