import io
import re
import utils
from typing import NamedTuple, TypedDict, Optional

import numpy as np

from world_model import GSM8kState, GSM8kAction, GSM8kPromptDict, SubResult
from reasoners import SearchConfig, LanguageModel

# Shape of the ``useful_prompt`` argument accepted by GSM8kConfig.
# Functional TypedDict form: every key below is required and maps to a string.
GSM8kUsefulPrompt = TypedDict(
    "GSM8kUsefulPrompt",
    {
        "input": str,
        "question_prefix": str,
        "subquestion_prefix": str,
        "new_subquestion_prefix": str,
        "useful_prefix": str,
    },
)


class GSM8kConfig(SearchConfig):
    """Search configuration that samples reasoning paths for GSM8k-style questions.

    Stores the sampling hyper-parameters forwarded to ``base_model.generate`` and
    the few-shot prompt prefix rebuilt by :meth:`update_example`.
    """

    # Group 1 captures everything from a question keyword to the end of the
    # text. The leading ``.*`` is greedy, so the LAST keyword occurrence wins.
    # ``.`` does not match newlines, so only single-line examples (optionally
    # with one trailing newline) can match.
    _OVERALL_QUESTION_RE = re.compile(
        r".*((Calculate|calculate|how|How|what|What|Find|find|"
        r"True or false|Determine|determine|When|when).*)$"
    )

    def __init__(
        self,
        base_model: LanguageModel,
        useful_prompt: GSM8kUsefulPrompt,
        n_actions=4,
        batch_size=1,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        reward_alpha=0.5,
        reward_confidence_default=0.8,
        depth_limit=5,
        force_terminating_on_depth_limit=True,
        force_overall_prompt_on_overall_question=True,
        force_overall_question_on_overall_prompt=True,
    ) -> None:
        """Store the model and configuration.

        Args:
            base_model: Model whose ``generate`` is called in :meth:`get_paths`.
            useful_prompt: Usefulness prompt fragments (stored as-is).
            temperature, top_k, top_p: Sampling parameters forwarded to
                ``base_model.generate``.
            n_actions, batch_size, reward_alpha, reward_confidence_default,
            depth_limit, force_terminating_on_depth_limit: Stored on the
                instance; presumably consumed by the search algorithm or
                subclasses (not read by the methods of this class).
            force_overall_prompt_on_overall_question,
            force_overall_question_on_overall_prompt: If either is true,
                :meth:`update_example` extracts ``overall_question``.
        """
        super().__init__()
        self.base_model = base_model
        self.useful_prompt = useful_prompt
        self.example = ""
        self.batch_size = batch_size
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.n_actions = n_actions
        self.force_terminating_on_depth_limit = force_terminating_on_depth_limit
        self.depth_limit = depth_limit
        self.reward_alpha = reward_alpha
        self.reward_confidence_default = reward_confidence_default
        self.force_overall_prompt_on_overall_question = (
            force_overall_prompt_on_overall_question
        )
        self.force_overall_question_on_overall_prompt = (
            force_overall_question_on_overall_prompt
        )
        self.overall_question: Optional[str] = None
        self.prompt_examples = ""
        self.n_shots = 0

    def update_example(
        self, example: str, prompt: Optional[GSM8kPromptDict] = None
    ) -> None:
        """Install a new question and rebuild the few-shot prompt prefix.

        Args:
            example: Question text for the upcoming search.
            prompt: Prompt dictionary providing ``"instruction"`` and
                ``"interactive_examples"``. It is required; the ``None`` default
                presumably mirrors the base-class signature.

        Side effects:
            Sets ``prompt``, ``prompt_examples``, ``n_shots`` and, when either
            ``force_overall_*`` flag is set, ``overall_question``.
        """
        super().update_example(example, prompt=prompt)
        assert prompt is not None, "GSM8kConfig.update_example requires a prompt"
        self.prompt = prompt

        shots = self.prompt["interactive_examples"]
        # Instruction first, then every shot; each piece ends with a blank line.
        self.prompt_examples = "".join(
            piece + "\n\n" for piece in [self.prompt["instruction"], *shots]
        )
        self.n_shots = len(shots)

        if (
            self.force_overall_prompt_on_overall_question
            or self.force_overall_question_on_overall_prompt
        ):
            # ``self.example`` is expected to have been set by the base-class
            # call above -- TODO confirm against SearchConfig.
            match = self._OVERALL_QUESTION_RE.match(self.example)
            # Fall back to the whole example instead of crashing on
            # ``None[1]`` when no keyword is found or the text spans lines.
            self.overall_question = match[1] if match is not None else self.example

    def append_state(self, state: GSM8kState, path: str) -> GSM8kState:
        """Return a copy of ``state`` extended with the first line of ``path``.

        ``state`` itself is left unchanged; only its copy is appended to.
        """
        step = path.split("\n")[0]
        new_state = state.copy()
        new_state.append(SubResult(step))
        return new_state

    def get_paths(self, state: GSM8kState, num: int, actual_ans):
        """Sample ``num`` continuations of ``state`` and score them.

        Args:
            state: Steps taken so far. Each entry is unpacked as a one-field
                record holding the step text.
            num: Number of sequences requested from the model.
            actual_ans: Reference answer, compared with ``utils.judge_answer``.

        Returns:
            A tuple ``(outputs, accuracy)``:
            - ``outputs`` is the list of stripped generations.
            - ``accuracy`` is the fraction of those generations whose retrieved
              answer is judged correct. It is ``0.0`` when the model returns no
              output.
        """
        history = "".join(step + "\n" for (step,) in state)
        model_input = self.prompt_examples + self.example + "\n" + history

        outputs = self.base_model.generate(
            [model_input],
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            num_return_sequences=num,
        ).text
        outputs = [output.strip() for output in outputs]

        if not outputs:
            # Guard against ZeroDivisionError (e.g. num == 0).
            return outputs, 0.0

        correct_cnt = sum(
            1
            for output in outputs
            if utils.judge_answer(utils.retrieve_answer(output), actual_ans)
        )
        # Divide by the count actually returned so the ratio stays in [0, 1]
        # even if the model yields a different number than requested.
        return outputs, correct_cnt / len(outputs)
