import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from QMCTS import NegotiationQMCTSParallel
from Game import NegotiationGame
from Model import ValueModel

class SPG:
    """Book-keeping record for one self-play game running in a parallel batch."""

    def __init__(self, game) -> None:
        # Every game starts from the position the game object reports as initial.
        self.state = game.get_initial_state()
        # Grows during play; each instance gets its own fresh list.
        self.memory = list()
        # Search-tree handles; unset until something assigns them.
        self.root = self.node = None

class DQLParallel():
    """Self-play trainer that drives two networks (``model`` and ``value_model``).

    Games are generated in parallel batches via ``NegotiationQMCTSParallel`` and the
    recorded (state, outcome) pairs are then used to fit the networks with an MSE loss.
    """

    def __init__(self, model, value_model, optimizer, value_model_optimizer, game, args) -> None:
        """Store the networks, their optimizers, the game and the hyper-parameter dict.

        Keys of ``args`` read by this class: ``num_parallel_games``, ``br_prob``,
        ``temperature``, ``batch_size``, ``num_iterations``,
        ``num_selfPlay_iterations`` and ``num_epochs``. ``args`` is also handed to
        the MCTS object unchanged.
        """
        self.model = model
        self.value_model = value_model
        self.optimizer = optimizer
        self.value_model_optimizer = value_model_optimizer
        self.game = game
        self.args = args
        self.mcts = NegotiationQMCTSParallel(game, args, model, value_model)

    def selfPlay(self):
        """Play ``args['num_parallel_games']`` games to completion.

        Returns:
            ``(return_memory, value_memory)``, two lists of ``(state, target)`` tuples.
            ``return_memory`` holds encoded states paired with the final outcome for
            the player who was to move at that state; ``value_memory`` holds the raw
            recorded states (see the review note where its target is computed).
        """
        return_memory = []
        value_memory = []
        player = 1
        spGames = [SPG(self.game) for _ in range(self.args['num_parallel_games'])]

        while spGames:
            states = np.stack([spg.state for spg in spGames])
            neutral_states = [self.game.change_perspective(state, player) for state in states]
            for spg, neutral_state in zip(spGames, neutral_states):
                spg.state = neutral_state

            # With probability br_prob run the tree search, otherwise the
            # average-response routine; both are expected to populate spg.root
            # (it is read right below) -- TODO confirm in QMCTS.
            if random.uniform(0, 1) < self.args['br_prob']:
                self.mcts.search(neutral_states, spGames)
            else:
                self.mcts.avg_resp(neutral_states, spGames)

            # Walk the batch back-to-front so finished games can be deleted by
            # index without shifting the games still to be visited.
            for i in reversed(range(len(spGames))):
                spg = spGames[i]

                # Action distribution proportional to the root children's visit counts.
                action_probs = np.zeros(self.game.action_size + 1)
                for child in spg.root.children:
                    action_probs[child.action_taken] = child.visit_count
                action_probs /= np.sum(action_probs)

                spg.memory.append((spg.root.state.copy(), player))

                # Apply the temperature, renormalise and sample the move.
                temperature_action_probs = action_probs ** (1 / self.args['temperature'])
                temperature_action_probs /= np.sum(temperature_action_probs)
                action = np.random.choice(self.game.action_size + 1, p=temperature_action_probs)

                # Advance to the state stored on the child reached by the sampled action.
                for child in spg.root.children:
                    if child.action_taken == action:
                        spg.state = child.state
                value, is_terminal = self.game.get_value_and_terminated(spg.state, action)

                if is_terminal:
                    opp_value = 0 if value == 0 else self.game.get_value(spg.state, -1, self.game.get_opponent(spg.state[0]))
                    for hist_neutral_state, hist_player in spg.memory:
                        hist_outcome = value if hist_player == player else opp_value
                        return_memory.append((
                            self.game.get_encoded_state(hist_neutral_state),
                            hist_outcome
                        ))
                        # NOTE(review): hist_outcome is either `value` or `opp_value`,
                        # so this expression always evaluates to `opp_value`,
                        # regardless of which player moved. Confirm that is intended.
                        value_memory.append((
                            hist_neutral_state,
                            hist_outcome if hist_outcome != value else opp_value
                        ))
                    del spGames[i]
            player = self.game.get_opponent(player)

        return return_memory, value_memory

    def train(self, value_memory, model=None, optimizer=None):
        """Run one epoch of mini-batch MSE regression over ``value_memory``.

        Args:
            value_memory: list of ``(state, target)`` pairs; shuffled in place.
            model: network to fit. Defaults to ``self.value_model``.
            optimizer: optimizer stepping ``model``'s parameters. Defaults to
                ``self.value_model_optimizer``.

        Raises:
            ValueError: if only one of ``model`` / ``optimizer`` is supplied, since
                stepping an optimizer that belongs to a different network would
                silently train nothing.
        """
        if (model is None) != (optimizer is None):
            raise ValueError("model and optimizer must be passed together")
        network = self.value_model if model is None else model
        network_optimizer = self.value_model_optimizer if optimizer is None else optimizer
        # Prefer the trained network's own device attribute; fall back to the one
        # this method has always used.
        device = getattr(network, 'device', self.model.device)
        batch_size = self.args['batch_size']

        random.shuffle(value_memory)
        for batchIdx in range(0, len(value_memory), batch_size):
            # Slicing clamps at the end of the list, so the last batch may be short.
            sample = value_memory[batchIdx:batchIdx + batch_size]
            state, value_targets = zip(*sample)
            state, value_targets = np.array(state), np.array(value_targets).reshape(-1, 1)

            state = torch.tensor(state, dtype=torch.float32, device=device)
            value_targets = torch.tensor(value_targets, dtype=torch.float32, device=device)

            out_value = network(state)
            loss = F.mse_loss(out_value, value_targets)

            network_optimizer.zero_grad()
            loss.backward()
            network_optimizer.step()

    def learn(self):
        """Alternate self-play data generation and training, checkpointing each iteration."""
        # torch.save does not create missing directories; do it once, before the
        # expensive self-play starts.
        os.makedirs("models", exist_ok=True)

        for iteration in range(self.args['num_iterations']):
            memory = []
            value_memory = []

            self.model.eval()
            self.value_model.eval()
            for selfPlay_iteration in trange(self.args['num_selfPlay_iterations'] // self.args['num_parallel_games']):
                mem, value_mem = self.selfPlay()
                memory += mem
                value_memory += value_mem

            print([i[-1] for i in memory])
            self.model.train()
            self.value_model.train()
            for epoch in trange(self.args['num_epochs']):
                # Each network is fitted on its own data set with its own optimizer.
                # Previously both calls updated value_model and self.model was
                # saved without ever being trained.
                self.train(memory, self.model, self.optimizer)
                self.train(value_memory, self.value_model, self.value_model_optimizer)

            torch.save(self.model.state_dict(), f"models/model_{iteration}_{self.game}.pt")
            torch.save(self.optimizer.state_dict(), f"models/optimizer_{iteration}_{self.game}.pt")
            torch.save(self.value_model.state_dict(), f"models/value_model_{iteration}_{self.game}.pt")
            torch.save(self.value_model_optimizer.state_dict(), f"models/value_model_optimizer_{iteration}_{self.game}.pt")

def dql_train():
    """Build the game, both networks, their optimizers and the settings, then train.

    Uses CUDA when available and falls back to the CPU otherwise.
    """
    game = NegotiationGame(11000, 15000)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ValueModel(game, 50, 50, device)
    value_model = ValueModel(game, 50, 50, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    # Must wrap value_model's own parameters: an optimizer built over
    # model.parameters() can never update value_model when it is stepped.
    value_model_optimizer = torch.optim.Adam(value_model.parameters(), lr=0.001, weight_decay=0.0001)

    # Keys not read by DQLParallel itself ('C', 'num_searches', 'dirichlet_*') are
    # presumably consumed by the MCTS implementation -- verify against QMCTS.
    args = {
        'C': 2,
        'num_searches': 50,
        'num_iterations': 4,
        'num_selfPlay_iterations': 50,
        'num_epochs': 4,
        'num_parallel_games': 50,
        'batch_size': 128,
        'temperature': 1.25,
        'dirichlet_epsilon': 0.25,
        'dirichlet_alpha': 0.3,
        'br_prob': .5
    }

    trainer = DQLParallel(model, value_model, optimizer, value_model_optimizer, game, args)
    trainer.learn()

if __name__ == "__main__":
    # Script entry point: run DQL self-play training.
    # An alternative entry point, train_negotiation(), is disabled here; it is not
    # defined in the visible part of this file.
    dql_train()