import pickle
from os import PathLike
import math
from copy import deepcopy
from typing import Generic, Optional, NamedTuple, Callable, Hashable
import itertools
from abc import ABC
from collections import defaultdict

import numpy as np
import re
from tqdm import trange

from .. import SearchAlgorithm, WorldModel, SearchConfig, State, Action, Example, Trace

# Compiled once at import; matches "...The answer is: <...><value>...<.>..." on a single line
# ('.' does not cross newlines).
_ANSWER_PATTERN = re.compile(r'.*The answer is: .*?([ $.0-9,\-=]+).*\..*')


def retrieve_answer(output: str) -> Optional[str]:
    """Extract the final answer from a completion that says "The answer is: ...".

    The last matching line wins. From the captured run of digits / spaces /
    ``$ . , - =`` characters, commas, dollar signs and spaces are removed, and if an
    ``=`` is present only the text after the last ``=`` is kept.

    Note: because the prefix before the capture group is lazy, the *first*
    qualifying run after the colon is captured (e.g. ``"The answer is: 2+3=5."``
    yields ``"2"``), and a period must follow the run somewhere on the line.

    Args:
        output: model completion text.

    Returns:
        The cleaned answer string, or None if no answer line was found or nothing
        remains after cleanup (e.g. the captured run was only spaces).
    """
    matches = _ANSWER_PATTERN.findall(output)
    if not matches:
        return None
    answer = matches[-1].replace(',', '').replace('$', '').replace(' ', '')
    if '=' in answer:
        # Keep only the right-hand side of an expression such as "x=5".
        answer = answer[answer.rindex('=') + 1:]
    # An empty result carries no answer; report it the same way as "no match".
    return answer or None

def judge_answer(output: Optional[str], answer: str) -> bool:
    """Return whether a predicted answer equals the reference answer.

    Both values are compared as ``int`` first, then as ``float``; if neither
    conversion works for both, plain equality of the original values is used.

    Args:
        output: extracted prediction, or None when no answer was found.
        answer: reference answer.

    Returns:
        True when the values match under the first applicable comparison;
        False for a None prediction.
    """
    if output is None:
        return False
    for convert in (int, float):
        try:
            return convert(output) == convert(answer)
        except (TypeError, ValueError):
            # TypeError covers non-string references such as None, which used to
            # escape as an exception instead of simply not matching.
            continue
    return output == answer


class Node(Generic[State, Action]):
    """One vertex of the search tree.

    Holds a state plus bookkeeping filled in by the search: a scalar reward, the
    sampled continuation paths (``extend_paths``) and the list of children.
    Each instance gets a sequential integer ``id`` drawn from ``Node.id_iter``.
    """

    id_iter = itertools.count()

    @classmethod
    def reset_id(cls):
        """Replace the class's id counter with a fresh one starting at 0."""
        cls.id_iter = itertools.count()

    def __init__(self, state: Optional[State], parent: "Optional[Node]" = None, is_terminal: bool = False):
        # Take the id before anything else so numbering follows construction order.
        self.id = next(Node.id_iter)
        depth = 0 if parent is None else parent.depth + 1
        # Attribute assignment order is kept stable: callers may serialize __dict__.
        self.reward = 0.
        self.is_terminal = is_terminal
        self.state = state
        self.parent = parent
        # Populated later, when the search expands this node.
        self.children: 'Optional[list[Node]]' = None
        self.extend_paths: 'Optional[list[str]]' = None
        self.depth = depth


class Iter_Result(NamedTuple):
    """Outcome of one search run: final state, trace, and the built tree."""
    # Final state reached by the search; FullTree.__call__ in this file uses [] when
    # no rollout was produced.
    terminal_state: State
    # FullTree.__call__ in this file always passes None here.
    trace: Trace
    # Root node of the search tree that was built.
    tree_state: Node


class FullTree(SearchAlgorithm, Generic[State, Action, Example]):
    """Greedy-rollout tree search that labels sampled steps against a reference answer.

    The root is expanded once from ``n_iters`` sampled paths (one child per path).
    Each root child, up to ``n_iters`` of them, is then rolled out greedily. At each
    node, at most three children are created from the first lines of the node's
    sampled paths: up to two whose path reached a wrong answer, topped up with
    steps whose path reached the correct one. The highest-reward child is followed
    until a terminal node, the depth limit, or a childless node is reached.

    Calling the instance runs the search and returns an ``Iter_Result`` whose
    ``tree_state`` is the whole tree and whose ``terminal_state`` is the final state
    of the first rollout (``[]`` if there was none).
    """

    def __init__(self,
                 output_trace_in_each_iter: bool = False,
                 depth_limit: int = 5,
                 n_iters: int = 10,
                 disable_tqdm: bool = True,
                 node_visualizer: Callable[[Node], dict] = lambda x: x.__dict__,
                 n_child_paths: int = 8,
                 **kwargs):
        """
        Args:
            output_trace_in_each_iter: stored; not read elsewhere in this class.
            depth_limit: a node at this depth or deeper ends a rollout.
            n_iters: number of paths sampled at the root, and the maximum number of
                root children that are rolled out.
            disable_tqdm: hide the per-rollout progress bar.
            node_visualizer: callable mapping a Node to a dict; stored, not used here.
            n_child_paths: number of paths requested from ``search_config.get_paths``
                for each newly created child (in ``_expand`` only for non-terminal
                children). Defaults to the previously hard-coded 8.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.world_model = None
        self.search_config = None
        self.actual_ans = None
        self.output_trace_in_each_iter = output_trace_in_each_iter
        self.depth_limit = depth_limit
        self.n_iters = n_iters
        self.n_child_paths = n_child_paths
        self.root: Optional[Node] = None
        self.disable_tqdm = disable_tqdm
        self._output_iter: Optional[list[Node]] = None
        self.node_visualizer = node_visualizer

    def _is_terminal_with_depth_limit(self, node: Node) -> bool:
        """True if ``node`` is terminal or has reached ``depth_limit``."""
        return node.is_terminal or node.depth >= self.depth_limit

    def _expand_root(self, root: Node, iters: int = 10):
        """Sample ``iters`` paths from the root and create one child per path.

        The root's ``reward`` and ``extend_paths`` come from that sampling. Every
        child, terminal or not, is sampled again (``n_child_paths`` paths) to get
        its own ``reward`` and ``extend_paths``.
        """
        # NOTE(review): unlike _expand, the whole sampled path (not just its first
        # line) is appended to the root state here — confirm this is intended.
        paths, score = self.search_config.get_paths(root.state, iters, self.actual_ans)
        root.reward = score
        root.extend_paths = paths
        children = []
        for path in paths:
            child = Node(state=self.search_config.append_state(root.state, path), parent=root)
            child.is_terminal = self.world_model.is_terminal(child.state)
            child_paths, child_score = self.search_config.get_paths(
                child.state, self.n_child_paths, self.actual_ans)
            child.reward = child_score
            child.extend_paths = child_paths
            children.append(child)
        root.children = children

    def _expand(self, node: Node, child_num: int = 1) -> Optional[Node]:
        """Create up to three children of ``node`` from the first lines of its sampled paths.

        Each path in ``node.extend_paths`` is judged by its extracted final answer, and
        its first line is recorded as a correct or a wrong step. Up to two wrong steps
        are taken first. The selection is then topped up to three with correct steps
        not already taken. A terminal child's reward is 1.0 if its step appeared on
        some correct path, else 0.0. A non-terminal child is sampled again and takes
        the score returned by ``get_paths``.

        ``child_num`` is unused and kept only for interface compatibility.

        Returns:
            The highest-reward child, or None when ``node`` is terminal or no child
            could be created (e.g. there were no sampled paths).
        """
        if node.is_terminal:
            return None

        # Dicts serve as insertion-ordered sets. Iteration order of a set of str
        # depends on hash randomization, which made step selection vary between runs.
        correct_steps: dict = {}
        wrong_steps: dict = {}
        for path in node.extend_paths or []:
            first_step = path.split('\n')[0]
            if judge_answer(retrieve_answer(path), self.actual_ans):
                correct_steps[first_step] = None
            else:
                wrong_steps[first_step] = None

        steps = list(wrong_steps)[:2]
        # The same first step can lead to both right and wrong answers. Skip steps
        # already chosen so one step is not expanded into two identical children.
        steps += [s for s in correct_steps if s not in steps][:3 - len(steps)]

        children = []
        for step in steps:
            child = Node(state=self.search_config.append_state(node.state, step), parent=node)
            child.is_terminal = self.world_model.is_terminal(child.state)
            child.reward = 1. if step in correct_steps else 0.
            if not child.is_terminal:
                child_paths, child_score = self.search_config.get_paths(
                    child.state, self.n_child_paths, self.actual_ans)
                child.reward = child_score
                child.extend_paths = child_paths
            children.append(child)
        node.children = children
        # default=None: an empty expansion used to raise ValueError from max().
        return max(children, key=lambda x: x.reward, default=None)

    def _simulate(self, trace: list[Node]) -> list[Node]:
        """Greedily extend ``trace`` from its last node by following the best child.

        Mutates and returns ``trace``.
        """
        node = trace[-1]
        while True:
            # NOTE(review): the node is expanded *before* the depth-limit check, so a
            # node at depth_limit still gets children (and their get_paths calls)
            # that are never followed — confirm this is intended.
            self._expand(node)
            if self._is_terminal_with_depth_limit(node) or not node.children:
                break
            node = max(node.children, key=lambda x: x.reward)
            trace.append(node)
        return trace

    def search(self):
        """Build the tree: expand the root, then roll out each root child greedily.

        Only the first rollout's trace is kept in ``self._output_iter``.
        """
        # Start empty so __call__ works even when no rollout happens
        # (n_iters == 0, or the root produced no children).
        self._output_iter = []
        self.root = Node(state=self.world_model.init_state(), parent=None)
        self._expand_root(self.root, self.n_iters)
        # NOTE(review): nothing in this class sets root.is_terminal, so this early
        # exit cannot currently trigger.
        if self.root.is_terminal:
            return
        root_children = self.root.children
        # get_paths may return fewer than n_iters paths. Indexing past them raised IndexError.
        n_rollouts = min(self.n_iters, len(root_children))
        for i in trange(n_rollouts, disable=self.disable_tqdm, desc='iteration', leave=False):
            trace = self._simulate([root_children[i]])
            if i == 0:
                self._output_iter = trace

    def __call__(self,
                 world_model: WorldModel[State, Action, Example],
                 search_config: SearchConfig[State, Action, Example],
                 **kwargs):
        """Run the search for one example.

        Args:
            world_model: used via ``init_state()`` and ``is_terminal(state)``.
            search_config: used via ``get_paths(state, n, answer)``, which returns
                ``(paths, score)``, and ``append_state(state, step)``.
            **kwargs: must contain ``answer``, the reference answer used to judge
                sampled paths (KeyError otherwise).

        Returns:
            Iter_Result with the first rollout's final state (``[]`` if none),
            ``trace=None`` and the root of the built tree.
        """
        Node.reset_id()
        self.world_model = world_model
        self.search_config = search_config
        self.actual_ans = kwargs['answer']

        self.search()
        terminal_state = self._output_iter[-1].state if self._output_iter else []
        return Iter_Result(
            terminal_state=terminal_state,
            trace=None,
            tree_state=self.root
        )