import io
import re
import utils
from typing import NamedTuple, TypedDict
from collections import defaultdict
from reasoners import WorldModel, LanguageModel
from reasoners.base import Example


class SubResult(NamedTuple):
    """A single entry of a ``GSM8kState``: wraps the text of one sub-answer."""

    sub_answer: str


# The state is the ordered list of sub-results accumulated so far.
GSM8kState = list[SubResult]
# Actions and examples are both represented as plain strings.
GSM8kAction = str
GSM8kExample = str


class GSM8kPromptDict(TypedDict):
    """Typed schema for the prompt configuration dictionary.

    Every key is required (this is a total TypedDict). The meaning of each
    entry is defined by the code that builds prompts from it, which is not
    part of this chunk.
    """

    # Free-form instruction text.
    instruction: str
    # Lists of example strings.
    interactive_examples: list[str]
    useful_examples: list[str]
    # String prefixes (presumably prepended to the corresponding prompt parts
    # — TODO confirm against the prompt-building code).
    question_prefix: str
    subquestion_prefix: str
    overall_question_prefix: str
    answer_prefix: str


class GSM8kWorldModel(WorldModel[GSM8kState, GSM8kAction, GSM8kExample]):
    def __init__(
        self,
        base_model: LanguageModel,
        n_confidence=8,
        batch_size=2,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        early_stop_base=None,
        early_stop_threshold=1.0,
    ) -> None:
        """Store the wrapped language model and all tuning knobs.

        Args:
            base_model: the ``LanguageModel`` this world model wraps.
            n_confidence: stored as ``self.n_confidence``; also used as the
                fallback value of ``early_stop_base``.
            batch_size: stored as ``self.batch_size``.
            temperature: stored as ``self.temperature``.
            top_k: stored as ``self.top_k``.
            top_p: stored as ``self.top_p``.
            early_stop_base: stored as ``self.early_stop_base``; ``None``
                means "use ``n_confidence``".
            early_stop_threshold: stored as ``self.early_stop_threshold``.

        ``prompt_examples`` starts as ``""`` and ``n_shots`` as ``0``.
        Presumably both are populated later, when a prompt is configured.
        TODO confirm.
        """
        super().__init__()
        self.base_model = base_model

        # Sampling-related settings, kept verbatim (presumably forwarded to
        # base_model generation elsewhere in the class — TODO confirm).
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.batch_size = batch_size

        # Settings named for confidence sampling and early stopping; their
        # consumers are outside this chunk.
        self.n_confidence = n_confidence
        if early_stop_base is None:
            early_stop_base = n_confidence
        self.early_stop_base = early_stop_base
        self.early_stop_threshold = early_stop_threshold

        # Prompt-related fields start empty.
        self.prompt_examples = ""
        self.n_shots = 0

    def init_state(self) -> GSM8kState:
        """Return a new, empty state containing no sub-results."""
        # A fresh list on every call, so separate searches never share state.
        return []

    def is_terminal(self, state: GSM8kState) -> bool:
        """Return True when the latest sub-answer contains "The answer is".

        An empty state is never terminal.
        """
        # `bool(state)` short-circuits before `state[-1]` is evaluated on an
        # empty list; both operands are bools, so the result is always a bool.
        return bool(state) and "The answer is" in state[-1].sub_answer
