import copy
import math
import torch
import numpy as np
from tqdm import trange

class NegotiationNode():
    """Node of the negotiation search tree.

    A node stores a game state, a link to its parent, the action that led to
    it, a visit count and one running value sum per player (``value_sum_p1``
    for player ``1`` and ``value_sum_p2`` for player ``-1``).
    """

    def __init__(self, game, args, state, parent = None, action_taken = None, visit_count = 0):
        """Create a node.

        Args:
            game: Game object; a deep copy of it is stored on the node.
            args: Mapping of search settings; ``args['C']`` is read by ``get_ucb``.
            state: Indexable state; ``state[0]`` is stored as ``player`` and
                ``state.copy()`` is used when expanding.
            parent: Parent node, or ``None`` for a root.
            action_taken: Action taken by the parent to reach this node.
            visit_count: Initial visit count.
        """
        # NOTE(review): every node deep-copies the game, which may be costly
        # for large trees; kept because the game could carry mutable state.
        self.game = copy.deepcopy(game)
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        # The acting player is read from the first entry of the state.
        self.player = state[0]

        self.children = []

        self.visit_count = visit_count
        self.value_sum_p1 = 0
        self.value_sum_p2 = 0

    def is_fully_expanded(self):
        """Return True when the node already has at least one child."""
        return bool(self.children)

    def select(self, player):
        """Return the child with the highest UCB score for ``player``.

        Ties keep the earliest child. Returns ``None`` when there are no
        children (or no child scores above ``-inf``).
        """
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child, player)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        return best_child

    def get_ucb(self, child, player):
        """Compute the selection score of ``child`` from ``player``'s side.

        score = Q + C * sqrt(N_parent) / (1 + N_child), where Q is the child's
        mean value for ``player`` rescaled by ``(mean + 1) / 2`` and is 0 for
        an unvisited child. No prior probability term is used.

        Raises:
            ValueError: if ``player`` is neither ``1`` nor ``-1``.
        """
        if player == 1:
            value_sum = child.value_sum_p1
        elif player == -1:
            value_sum = child.value_sum_p2
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``value_sum``.
            raise ValueError(f"player must be 1 or -1, got {player!r}")

        if child.visit_count == 0:
            q_value = 0
        else:
            # Maps a mean in [-1, 1] onto [0, 1]; presumably values lie in
            # that range -- confirm against the game / value models.
            q_value = ((value_sum / child.visit_count) + 1) / 2
        exploration = self.args['C'] * math.sqrt(self.visit_count) / (1 + child.visit_count)
        return q_value + exploration

    def expand(self, moves):
        """Append one child per action in ``moves`` and return all children.

        For each action the next state is computed with
        ``game.get_next_state`` on a copy of this node's state for this
        node's player, then passed through ``game.change_perspective``.
        The child's ``action_taken`` is ``action[0]``.
        """
        for action in moves:
            next_state = self.game.get_next_state(self.state.copy(), action, self.player)
            next_state = self.game.change_perspective(next_state.copy())

            self.children.append(
                NegotiationNode(self.game, self.args, next_state, self, action[0])
            )

        return self.children

    def backpropagate(self, value, opponent_value, player):
        """Add one visit and the given values to this node and all ancestors.

        ``value`` is from ``player``'s perspective and ``opponent_value`` from
        the other player's. For a ``player`` other than ``1`` / ``-1`` only the
        visit counts are incremented.
        """
        # Iterative walk up the tree (same updates as the recursive form, but
        # not bounded by the interpreter's recursion limit).
        node = self
        while node is not None:
            node.visit_count += 1
            if player == 1:
                node.value_sum_p1 += value
                node.value_sum_p2 += opponent_value
            elif player == -1:
                node.value_sum_p2 += value
                node.value_sum_p1 += opponent_value
            node = node.parent

class NegotiationQMCTS():
    """Single-game tree search that scores new nodes with two networks.

    ``model`` supplies the value backed up for a child's own player and
    ``value_model`` the value backed up for the opponent; terminal children
    use the game's own values instead.
    """

    def __init__(self, game, args, model, value_model) -> None:
        """Store a deep copy of ``game`` plus the settings and both networks.

        ``args['num_searches']`` is read by ``search``; ``model`` must expose
        a ``device`` attribute, which is used to place the input batch.
        """
        self.game = copy.deepcopy(game)
        self.args = args
        self.model = model
        self.value_model = value_model

    def _terminal_values(self, node):
        """Return ``(value, opponent_value, is_terminal)`` for ``node`` from the game."""
        value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
        opponent_val = value
        if value != 0:
            # A zero value is reused as-is for the opponent; otherwise the
            # opponent's value is queried separately.
            opponent_val = self.game.get_value(node.state, -1, self.game.get_opponent(node.player))
        return value, opponent_val, is_terminal

    def _evaluate_children(self, children):
        """Backpropagate game values for terminal children and network values for the rest."""
        pending = []
        for child in children:
            value, opponent_val, is_terminal = self._terminal_values(child)
            if is_terminal:
                child.backpropagate(value, opponent_val, child.player)
            else:
                pending.append(child)

        if not pending:
            return

        # Build the input batch once and feed it to both networks.
        batch = torch.tensor(
            np.stack([c.state for c in pending]), dtype=torch.float32, device=self.model.device
        )
        values = self.model(batch)
        opp_values = self.value_model(batch)

        # NOTE(review): the backed-up values are the networks' per-row outputs
        # (tensors), not Python floats -- confirm this is intended.
        for i, child in enumerate(pending):
            child.backpropagate(values[i], opp_values[i], child.player)

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` search iterations from ``state``.

        Returns:
            A NumPy array of length ``game.action_size + 1``. For every root
            child, the entry at ``child.action_taken`` holds the mean value
            accumulated for the player opposite to ``child.player``
            (presumably the player acting at the root -- confirm); entries
            for other actions and for unvisited children are 0.
        """
        # The root starts with one visit; its children are evaluated up front.
        root = NegotiationNode(self.game, self.args, state, visit_count=1)
        valid_moves = self.game.neural_valid_moves(root.state)
        self._evaluate_children(root.expand(valid_moves))

        for _ in range(self.args['num_searches']):
            node = root

            # selection: descend until a node without children is reached
            while node.is_fully_expanded():
                node = node.select(node.player)

            value, opponent_val, is_terminal = self._terminal_values(node)

            if is_terminal:
                node.backpropagate(value, opponent_val, node.player)
            else:
                # expansion + evaluation + backprop of all new children
                valid_moves = self.game.neural_valid_moves(node.state)
                self._evaluate_children(node.expand(valid_moves))

        # Mean values per action; these are not normalized into probabilities.
        action_probs = np.zeros(self.game.action_size + 1)
        for child in root.children:
            value_sum = child.value_sum_p2 if child.player == 1 else child.value_sum_p1
            action_probs[child.action_taken] = value_sum / child.visit_count if child.visit_count != 0 else 0
        return action_probs


class NegotiationQMCTSParallel():
    """Tree search over several games at once with batched network calls.

    Each element of ``spGames`` gets a ``root`` attribute (its search tree)
    and, during ``search``, a ``node`` attribute (the leaf picked for
    expansion in the current iteration, or ``None``).
    """

    def __init__(self, game, args, model, value_model) -> None:
        """Store a deep copy of ``game`` plus the settings and both networks.

        ``args['num_searches']`` is read by ``search``; ``model`` must expose
        a ``device`` attribute, which is used to place the input batch.
        """
        self.game = copy.deepcopy(game)
        self.args = args
        self.model = model
        self.value_model = value_model

    def _terminal_values(self, node):
        """Return ``(value, opponent_value, is_terminal)`` for ``node`` from the game."""
        value, is_terminal = self.game.get_value_and_terminated(node.state, node.action_taken)
        opponent_val = value
        if value != 0:
            # A zero value is reused as-is for the opponent; otherwise the
            # opponent's value is queried separately.
            opponent_val = self.game.get_value(node.state, -1, self.game.get_opponent(node.player))
        return value, opponent_val, is_terminal

    def _evaluate_children(self, children):
        """Backpropagate game values for terminal children and network values for the rest."""
        pending = []
        for child in children:
            value, opponent_val, is_terminal = self._terminal_values(child)
            if is_terminal:
                child.backpropagate(value, opponent_val, child.player)
            else:
                pending.append(child)

        if not pending:
            return

        # Build the input batch once and feed it to both networks.
        batch = torch.tensor(
            np.stack([c.state for c in pending]), dtype=torch.float32, device=self.model.device
        )
        values = self.model(batch)
        opp_values = self.value_model(batch)

        for i, child in enumerate(pending):
            child.backpropagate(values[i], opp_values[i], child.player)

    def _build_roots(self, states, spGames, moves_fn):
        """Create a fresh root per game, expand it with ``moves_fn(state)`` and evaluate the children."""
        children = []
        for i, spg in enumerate(spGames):
            valid_moves = moves_fn(states[i])
            spg.root = NegotiationNode(self.game, self.args, states[i], visit_count=1)
            children += spg.root.expand(valid_moves)
        self._evaluate_children(children)

    @torch.no_grad()
    def avg_resp(self, states, spGames):
        """Build one-level trees using the moves returned by ``game.average``.

        Sets ``spg.root`` for every game and backs up the values of the
        root's children; no further search iterations are run. Runs without
        gradient tracking, like ``search``, so that no autograd graph is
        retained in the tree's value sums.
        """
        self._build_roots(states, spGames, self.game.average)

    @torch.no_grad()
    def search(self, states, spGames):
        """Run ``args['num_searches']`` iterations on every game in ``spGames``.

        Roots are expanded with ``game.neural_valid_moves``. Results are left
        in each ``spg.root``; nothing is returned.
        """
        self._build_roots(states, spGames, self.game.neural_valid_moves)

        for _ in trange(self.args['num_searches']):
            for spg in spGames:
                spg.node = None
                node = spg.root

                # selection: descend until a node without children is reached
                while node.is_fully_expanded():
                    node = node.select(node.player)

                value, opponent_val, is_terminal = self._terminal_values(node)

                if is_terminal:
                    # backprop
                    node.backpropagate(value, opponent_val, node.player)
                else:
                    # remember the leaf so all games are expanded in one batch
                    spg.node = node

            leaves = [spg.node for spg in spGames if spg.node is not None]
            if leaves:
                children = []
                for node in leaves:
                    valid_moves = self.game.neural_valid_moves(node.state)
                    children += node.expand(valid_moves)
                self._evaluate_children(children)