import re
from typing import Optional, Union

from reasoners.base import AlgorithmOutput


def retrieve_answer(output: Union[list, str, AlgorithmOutput]) -> Optional[str]:
    """Extract the final answer string from a reasoning output.

    ``output`` may be one of:

    * an ``AlgorithmOutput`` -- if it has an ``aggregated_result`` attribute
      that is not ``None``, that value is returned unchanged; otherwise its
      ``terminal_state`` is processed as described below;
    * a list of steps -- the ``sub_answer`` of the last element is parsed
      (an empty list yields ``None``);
    * a string -- parsed directly.

    The text is searched for ``The answer is: <value>.`` and the last match
    is used. Commas, dollar signs and spaces are stripped from the value, and
    when it contains ``=`` only the part after the last ``=`` is kept.

    Returns:
        The cleaned answer string, or ``None`` when there is nothing to parse
        or the answer pattern is not found.
    """
    if isinstance(output, AlgorithmOutput):
        if (result := getattr(output, 'aggregated_result', None)) is not None:
            return result
        output = output.terminal_state
    if isinstance(output, list):
        if not output:
            return None
        output = output[-1].sub_answer
    if output is None:
        # A None terminal state or sub-answer leaves nothing to parse; report
        # "no answer" instead of letting re.findall raise TypeError on None.
        return None
    matches = re.findall(r'.*The answer is: .*?([ $.0-9,\-=]+).*\..*', output)
    if not matches:
        return None
    # Use the last occurrence and drop thousands separators, currency signs
    # and padding.
    answer = matches[-1].replace(',', '').replace('$', '').replace(' ', '')
    if '=' in answer:
        # For captures such as "2+3=5", keep only what follows the last '='.
        answer = answer[answer.rindex('=') + 1:]
    return answer


def retrieve_answer_from_dataset(answer: Union[str, dict]) -> str:
    """Return the reference answer that follows the ``#### `` marker.

    ``answer`` is either the raw answer text or a dict whose ``'answer'``
    entry holds that text. Everything after the last ``#### `` on the final
    line is taken, with commas and spaces removed.
    """
    text = answer['answer'] if isinstance(answer, dict) else answer
    found = re.match(r'[\S\s]*#### (.*)$', text)
    # Subscripting (rather than .group) keeps the original failure mode when
    # the marker is absent.
    tail = found[1]
    # Delete thousands separators and spaces in a single pass.
    return tail.translate(str.maketrans('', '', ', '))


def judge_answer(output: Optional[str], answer: str) -> bool:
    """Decide whether a predicted answer equals the reference answer.

    A ``None`` prediction is always wrong. Otherwise both values are compared
    as integers if both parse as such, else as floats if both parse as such,
    and finally by plain equality of whatever the values are at that point.
    """
    if output is None:
        return False
    # Try progressively looser numeric interpretations. A value that was
    # successfully converted stays converted for the next attempt.
    for convert in (int, float):
        try:
            output = convert(output)
            answer = convert(answer)
        except ValueError:
            continue
        else:
            return output == answer
    return output == answer
