import os
import re
import math
import time
import torch
import numpy as np

from QMCTS import NegotiationQMCTS
from Model import  ValueModel

import openai
from dotenv import load_dotenv

# Set up the OpenAI credentials: load_dotenv() populates the environment, then
# the organization and API key are read from API_ORG_KEY / API_KEY.
load_dotenv()
openai.organization = os.getenv("API_ORG_KEY")
openai.api_key = os.getenv("API_KEY")

# Module-level memo tables for the ChatGPT-backed move generators, keyed by
# tuple(state): `cache` backs NegotiationGame.neural_valid_moves and
# `avg_cache` backs NegotiationGame.average. (A `global` statement at module
# scope is a no-op, so none is needed here.)
cache = {}
avg_cache = {}

def chatgpt_chatbot(messages, model):
    """Send a chat transcript to OpenAI and return the assistant's reply text.

    Args:
        messages: chat messages in the ``[{'role': ..., 'content': ...}]`` form
            accepted by ``openai.ChatCompletion.create``.
        model: name of the chat model to query.

    Returns:
        The content of the first choice, with surrounding whitespace stripped.

    The request is deterministic-leaning (temperature 0) and uses a 15 second
    request timeout; any error raised by the client propagates to the caller.
    """
    completion = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=0,
        request_timeout=15,
    )
    first_choice = completion["choices"][0]
    reply = first_choice["message"]["content"]
    return reply.strip()

def get_oom(x):
    """Return the base-10 order of magnitude of ``x``, i.e. ``floor(log10(x))``.

    ``math.log10`` is used instead of ``math.log(x, 10)``: the latter computes
    ``log(x) / log(10)`` and rounds exact powers of ten down
    (``math.log(1000, 10) == 2.9999999999999996``), which made this return 2
    for 1000. Raises ValueError for ``x <= 0``.
    """
    return math.floor(math.log10(x))


def round_num(num, base):
    """Round ``num`` to the nearest multiple of ``base``.

    Uses the built-in ``round()``, so exact ties go to the even multiple.
    """
    return base * round(num / base)

class NegotiationGame():
    """Two-player price negotiation played on a flat integer state vector.

    Player ``1`` is the seller and player ``-1`` is the buyer; the seller
    always moves first.

    State layout (see ``get_initial_state``):

    * ``state[0]``  -- player to move (``1`` seller, ``-1`` buyer)
    * ``state[1]``  -- number of moves made so far
    * ``state[2]``  -- turn limit; the game ends once ``state[1] >= state[2]``
    * ``state[3]``  -- initialised to ``buyer_ciel``; read as the seller's
      opening anchor (seller offers live at odd indices >= 3)
    * ``state[4]``  -- initialised to ``seller_floor``; read as the buyer's
      opening anchor (buyer offers live at even indices >= 4)
    * ``state[5:]`` -- offer history, ``0`` meaning "slot not used yet"

    Candidate offers come either from the fixed percentage action space
    (``get_seller_move`` / ``get_buyer_move`` / ``get_valid_moves``) or from
    ChatGPT (``neural_valid_moves`` / ``average``).
    """

    def __init__(self, seller_floor, buyer_ciel) -> None:
        """
        Args:
            seller_floor: seller's reservation price; ``get_valid_moves``
                masks seller offers at or below it.
            buyer_ciel: buyer's reservation price; ``get_valid_moves`` masks
                buyer offers at or above it.
        """
        self.seller_floor = seller_floor
        self.buyer_ciel = buyer_ciel

        self.action_size = 6    # get_valid_moves scores actions 0..action_size
        self.memory_size = 20   # offer slots after the 5 header entries
        self.turn_limit = 10    # copied into state[2]
        # Percentage steps 0, 5, ..., 95 used by get_seller_move / get_buyer_move.
        self.action_space = np.array(list(range(20))) * 5

    def __repr__(self) -> str:
        return "NegotiationGame"

    def get_initial_state(self):
        """Return a fresh state: seller to move, no moves made, empty history."""
        return np.array([1, 0, self.turn_limit, self.buyer_ciel, self.seller_floor]
                        + [0 for _ in range(self.memory_size)])

    def get_buyer_offer(self, state):
        """Return the latest buyer offer (last positive entry at an even index >= 4), or None."""
        for i, offer in reversed(list(enumerate(state))):
            if i % 2 == 1 or i < 3:
                continue
            elif offer > 0:
                return offer
        return None

    def get_seller_offer(self, state):
        """Return the latest seller offer (last positive entry at an odd index >= 3), or None."""
        for i, offer in reversed(list(enumerate(state))):
            if i % 2 == 0 or i < 3:
                continue
            elif offer > 0:
                return offer
        return None

    def get_buyer_move(self, action, state):
        """Return the buyer's price for fixed action index ``action``.

        The buyer's ceiling is ``min(buyer_ciel, latest seller offer)``.
        Action 20 returns the latest seller offer (acceptance). Without a
        previous buyer offer, action 0 returns -1 and other actions discount
        the ceiling by ``action_space[action]`` percent; otherwise the buyer
        moves up from its last offer by that percent of the gap to the
        ceiling. Prices are rounded to a multiple of
        ``10 ** (get_oom(max(buyer_ciel, seller_floor)) - 2)``.
        """
        buyer_offer = self.get_buyer_offer(state)
        seller_offer = self.get_seller_offer(state)
        ciel = min([self.buyer_ciel, seller_offer]) if seller_offer is not None else self.buyer_ciel
        if action == 20:
            return seller_offer
        elif buyer_offer is None:
            if action == 0:
                return -1
            else:
                move = ciel - ciel * (self.action_space[action] / 100)
        else:
            surplus = ciel - buyer_offer
            move = buyer_offer + surplus * (self.action_space[action] / 100)
        oom = get_oom(max([self.buyer_ciel, self.seller_floor])) - 2
        rounding = 10**oom
        move = round_num(move, rounding)
        return move

    def get_seller_move(self, action, state):
        """Return the seller's price for fixed action index ``action``.

        The seller's floor is ``max(seller_floor, latest buyer offer)``.
        Action 20 returns the latest buyer offer (acceptance). Without a
        previous seller offer, action 0 returns -1 and other actions mark the
        floor up by ``action_space[action]`` percent; otherwise the seller
        keeps ``100 - action_space[action]`` percent of its current margin
        above the floor. Prices are rounded like ``get_buyer_move``.
        """
        buyer_offer = self.get_buyer_offer(state)
        seller_offer = self.get_seller_offer(state)
        floor = max([self.seller_floor, buyer_offer]) if buyer_offer is not None else self.seller_floor
        if action == 20:
            return buyer_offer
        elif seller_offer is None:
            if action == 0:
                return -1
            else:
                move = (1 + self.action_space[action] / 100) * floor
        else:
            surplus = seller_offer - floor
            move = floor + surplus * (1 - self.action_space[action] / 100)
        oom = get_oom(max([self.buyer_ciel, self.seller_floor])) - 2
        rounding = 10**oom
        move = round_num(move, rounding)
        return move

    def get_next_state(self, state, action, player):
        """Write ``action[1]`` into the first empty history slot and advance the move counter.

        ``action`` is an ``(index, price)`` pair such as those returned by
        ``neural_valid_moves`` / ``average``; ``player`` is unused. ``state``
        is modified in place and also returned. Raises IndexError if the
        history has no empty slot left.
        """
        pos = np.where(state[5:] == 0)[0] + 5
        state[pos[0]] = action[1]
        state[1] += 1
        return state

    def _negotiation_prompt(self, state):
        """Build the ChatGPT prompt prefix shared by ``average`` and ``neural_valid_moves``.

        Returns ``(midpoint, prev_offer, opp_offer, bound, msg)``:
        ``midpoint`` of the two anchors ``state[3]``/``state[4]``, the
        mover's last offer, the opponent's last offer, the opponent-side end
        of the range to request, and the history prompt without its closing
        question. Raises ValueError if the history has no empty (0) slot.
        """
        midpoint = ((state[3] - state[4]) // 2) + state[4]
        # The mover writes into the first empty slot, so the slot just before
        # it holds the opponent's last offer and the one before that the mover's.
        first_empty = list(state[3:]).index(0) + 3
        prev_offer = state[first_empty - 2]
        opp_offer = state[first_empty - 1]

        action = 'sell' if state[0] == 1 else 'buy'
        relation = 'more' if state[0] == 1 else 'less'
        msg = f'I want to {action} an item for {relation} than ${midpoint} dollars. I am currently in a negotiation with the following history:\n\n'
        for i, offer in enumerate(state[3:]):
            if float(offer) == 0:
                break
            # state[3 + i]: even i -> seller (player 1), odd i -> buyer (player -1)
            mover = 1 if i % 2 == 0 else -1
            if mover == state[0]:
                msg += f'My offer: {offer}\n'
            else:
                msg += f'Their offer: {offer}\n'

        bound = max(midpoint, opp_offer) if state[0] == 1 else min(midpoint, opp_offer)
        return midpoint, prev_offer, opp_offer, bound, msg

    def _ask_chatgpt(self, msg):
        """Send ``msg`` as one user message to gpt-3.5-turbo, retrying every 5 s until it succeeds."""
        while True:
            try:
                return chatgpt_chatbot([{'role': 'user', 'content': msg}], "gpt-3.5-turbo")
            except Exception:
                # Catch Exception rather than everything so Ctrl-C / SystemExit
                # can still break out of this otherwise endless retry loop.
                print('chatgpt fail')
                time.sleep(5)

    @staticmethod
    def _parse_amount(text):
        """Extract a whole-dollar amount from a ChatGPT reply (or one reply line).

        Prefers the number following the first ``$``; otherwise drops a
        leading list marker such as ``"1. "`` or ``"2) "``. Thousands
        separators are ignored and cents truncated, so ``"$13,500.00"``
        parses as 13500. Returns None when no number is present.
        """
        dollar = text.find('$')
        if dollar >= 0:
            text = text[dollar:]
        else:
            text = re.sub(r'^\s*\d+[.)]\s+', '', text)
        match = re.search(r'\d[\d,]*(?:\.\d+)?', text)
        if match is None:
            return None
        return int(float(match.group().replace(',', '')))

    def average(self, state):
        """Ask ChatGPT for a single counteroffer and pair it with the midpoint.

        Returns ``[(0, counteroffer), (1, midpoint)]``. Results are memoised
        in the module-level ``avg_cache`` keyed by ``tuple(state)``.
        Raises ValueError if the reply contains no number.
        """
        key = tuple(state)
        if key in avg_cache:
            return avg_cache[key]

        midpoint, prev_offer, _, bound, msg = self._negotiation_prompt(state)
        msg += f'\nWhat is a good counteroffer for me to propose? Give me a counteroffer between ${bound} and ${prev_offer}. Just give the amount and nothing else.'

        reply = self._ask_chatgpt(msg)
        amount = self._parse_amount(reply)
        if amount is None:
            raise ValueError(f'no counteroffer found in ChatGPT reply: {reply!r}')

        offers = [(0, amount), (1, midpoint)]
        avg_cache[key] = offers
        return offers

    def neural_valid_moves(self, state):
        """Ask ChatGPT for candidate counteroffers for the player to move.

        Returns a list of ``(index, price)`` pairs: the parsed ChatGPT
        suggestions numbered from 0, plus ``(5, midpoint)`` when the game is
        within 4 moves of the turn limit or the mover's last offer equals the
        midpoint, plus the opponent's last offer when it is at or above the
        midpoint. Results are memoised in the module-level ``cache`` keyed by
        ``tuple(state)``.
        """
        key = tuple(state)
        if key in cache:
            return cache[key]

        midpoint, prev_offer, opp_offer, bound, msg = self._negotiation_prompt(state)
        msg += f'\nWhat are some good counteroffers for me to propose? Give me five counteroffers between ${bound} and ${prev_offer}. Just give the amounts.'

        reply = self._ask_chatgpt(msg)

        offers = []
        # Only lines starting with a digit (e.g. "1. $13,500") count as suggestions.
        for line in reply.split('\n'):
            if not line or not line[0].isnumeric():
                continue
            amount = self._parse_amount(line)
            if amount is not None:
                offers.append((len(offers), amount))

        # NOTE(review): the midpoint is always tagged 5 even when ChatGPT
        # returned more or fewer than five suggestions, while the opponent-offer
        # entry uses len(offers) -- confirm which convention the search expects.
        if state[1] + 4 >= state[2] or prev_offer == midpoint:
            offers += [(5, midpoint)]
        if opp_offer >= midpoint:
            offers += [(len(offers), opp_offer)]
        cache[key] = offers
        return offers

    def get_valid_moves(self, state):
        """Return a 0/1 mask over the fixed actions ``0..action_size`` for the player to move.

        An action (other than 0) is masked out when its offer is missing,
        crosses the mover's reservation price, or repeats a lower-numbered
        action's offer; any action is masked out when it matches or passes
        the opponent's standing offer. Action 0 is forced valid if every
        action is masked.
        """
        player = state[0]
        make_move = self.get_seller_move if player == 1 else self.get_buyer_move
        offers = [make_move(i, state) for i in range(self.action_size + 1)]
        seller_offer = self.get_seller_offer(state)
        buyer_offer = self.get_buyer_offer(state)

        moves = []
        for i, offer in enumerate(offers):
            if i == 20:
                # Accept action; only reachable when action_size >= 20.
                acceptable = (offer is not None) and not (
                    (player == -1 and offer >= self.buyer_ciel)
                    or (player == 1 and offer <= self.seller_floor))
                moves.append(1 if acceptable else 0)
            elif i != 0 and (offer is None
                             or (player == -1 and offer >= self.buyer_ciel)
                             or (player == 1 and offer <= self.seller_floor)
                             or offer in offers[:i]):
                # Negative-reward actions and actions repeating an earlier offer.
                moves.append(0)
            elif ((player == -1 and seller_offer is not None and offer >= seller_offer)
                  or (player == 1 and buyer_offer is not None and offer <= buyer_offer)):
                # Offers that match or pass the opponent's standing offer.
                moves.append(0)
            else:
                moves.append(1)
        if sum(moves) == 0:
            moves[0] = 1
        return np.array(moves)

    def check_win(self, state, action):
        """Return True once the latest buyer offer >= latest seller offer or the turn limit is hit.

        ``action`` is unused.
        """
        if self.get_buyer_offer(state) >= self.get_seller_offer(state):
            return True
        elif state[1] >= state[2]:
            return True
        else:
            return False

    def twoP0_value(self, state, player, buyer_offer, seller_offer):
        """Alternative reward (not called by ``get_value``): linear in the agreed price.

        Without agreement (offers differ) returns -0.01. Otherwise the price
        is mapped so ``state[4]`` -> -1 and ``state[3]`` -> 1 (negated for the
        buyer), then offset by +0.18.
        """
        if player == 1:
            end_price = buyer_offer
            if buyer_offer != seller_offer:
                return -.01
            elif end_price is None:
                return -.01
            else:
                return (2 * ((seller_offer - state[4]) / (state[3] - state[4])) - 1) + .18
        else:
            end_price = seller_offer
            if buyer_offer != seller_offer:
                return -.01
            if end_price is None:
                return -.01
            else:
                return ((2 * ((buyer_offer - state[4]) / (state[3] - state[4])) - 1) * -1) + .18

    def ebs_value(self, state, player, buyer_offer, seller_offer):
        """Reward used by ``get_value``.

        Without agreement (offers differ) returns -0.01. Otherwise, with
        ``p = (seller_offer - state[4]) / (state[3] - state[4])``, the seller
        gets ``1 - p`` if ``p >= 0.5`` else ``-p``, and the buyer gets ``p`` if
        ``p <= 0.5`` else ``-(1 - p)``.
        """
        if buyer_offer != seller_offer:
            return -.01
        p1_payout = (seller_offer - state[4]) / (state[3] - state[4])
        if player == 1:
            if p1_payout >= .5:
                return 1 - p1_payout
            else:
                return -1 * p1_payout
        else:
            if p1_payout <= .5:
                return p1_payout
            else:
                return (1 - p1_payout) * -1

    def get_value(self, state, action, player):
        """Score ``state`` for ``player`` with ``ebs_value``; ``action`` is unused.

        ``twoP0_value`` has the same signature and can be swapped in here.
        """
        buyer_offer = self.get_buyer_offer(state)
        seller_offer = self.get_seller_offer(state)
        return self.ebs_value(state, player, buyer_offer, seller_offer)

    def get_value_and_terminated(self, state, action):
        """Return ``(value for player state[0], True)`` on a terminal state, else ``(0, False)``."""
        if self.check_win(state, action):
            return self.get_value(state, action, state[0]), True
        return 0, False

    def get_opponent(self, player):
        """Return the other player (1 <-> -1)."""
        return -player

    def get_opponent_value(self, value):
        """Return ``value`` as seen by the opponent (zero-sum negation)."""
        return -value

    def change_perspective(self, state, player=None):
        """Set ``state[0]`` to ``player`` (default: flip it) in place and return ``state``."""
        if player is None:
            player = -1 * state[0]
        state[0] = player
        return state

    def get_encoded_state(self, state):
        """Return ``state`` unchanged (no separate encoding is used)."""
        return state

def test_game():
    """Play the negotiation interactively, choosing every move by hand.

    Each turn prints the state and the ChatGPT-generated offers from
    ``neural_valid_moves``, then reads from stdin the index of the offer to
    play. Invalid input (not an integer, or not a valid index) is rejected
    and re-prompted instead of crashing. When the game ends, prints the final
    state, the mover's value and the opponent's value.
    """
    game = NegotiationGame(11000, 15000)
    player = 1

    state = game.get_initial_state()

    while True:
        print(state)
        valid_moves = game.neural_valid_moves(state)
        print('valid moves', valid_moves)

        # Re-prompt until the typed index selects one of the offered moves.
        while True:
            try:
                action = int(input(f"{player}:"))
                move = valid_moves[action]
                break
            except (ValueError, IndexError):
                print("action not valid")

        state = game.get_next_state(state, move, player)

        value, is_terminal = game.get_value_and_terminated(state, action)

        if is_terminal:
            print(state)
            print(value)
            print(game.get_value(state, -1, game.get_opponent(player)))
            break

        player = game.get_opponent(player)
        state = game.change_perspective(state, player)

def test_negotiation():
    """Play the MCTS agent (seller, player 1) against a human buyer (player -1).

    Builds two ``ValueModel`` instances with a CUDA device, loads their
    checkpoints from ``models/`` and wraps them in ``NegotiationQMCTS``. On
    the human's turn the ChatGPT-generated offers are printed and the index
    of one is read from stdin (invalid input is re-prompted). On the agent's
    turn the offer with the highest (shifted) MCTS score is played. When the
    game ends, prints the final state, the mover's value and the opponent's value.
    """
    game = NegotiationGame(11000, 15000)
    player = 1

    args = {
        'C': 2,
        'num_searches': 10,
        'temperature': 1,
        'dirichlet_epsilon': 0,
        'dirichlet_alpha': 0.3
    }

    model = ValueModel(game, 50, 50, torch.device('cuda'))
    value_model = ValueModel(game, 50, 50, torch.device('cuda'))
    model.load_state_dict(torch.load('models/model_2_NegotiationGame.pt'))
    value_model.load_state_dict(torch.load('models/value_model_2_NegotiationGame.pt'))
    model.eval()
    value_model.eval()
    mcts = NegotiationQMCTS(game, args, model, value_model)

    state = game.get_initial_state()

    while True:
        if player == -1:
            print(state)
            valid_moves = game.neural_valid_moves(state)
            print('valid moves', valid_moves)
            # Re-prompt until the typed index selects one of the offered moves.
            while True:
                try:
                    action = int(input(f"{player}:"))
                    move = valid_moves[action]
                    break
                except (ValueError, IndexError):
                    print("action not valid")
        else:
            mcts_probs = mcts.search(state)
            lowest = min(mcts_probs)
            if lowest < 0:
                # Shift scores so the worst becomes 0.01, keeping exact zeros at 0.
                mcts_probs = [(p + .01) - lowest if p != 0 else 0 for p in mcts_probs]
            action = np.argmax(mcts_probs)
            valid_moves = game.neural_valid_moves(state)
            print(f'opponent valid moves: {valid_moves}')
            print(f'opponent action: {action}')
            move = valid_moves[action]

        state = game.get_next_state(state, move, player)

        value, is_terminal = game.get_value_and_terminated(state, action)

        if is_terminal:
            print(state)
            print(value)
            print(game.get_value(state, -1, game.get_opponent(player)))
            break

        player = game.get_opponent(player)
        state = game.change_perspective(state, player)
    
def main():
    """Play the MCTS agent (seller, player 1) against a human buyer who types raw prices.

    The agent picks among the ChatGPT-generated offers using the shifted
    MCTS scores. The human (player -1) enters any integer price, and
    non-integer input is re-prompted. Play stops at the first terminal state
    or after 15 moves. The final state, the last computed value and the
    opponent's value are then printed once.
    """
    game = NegotiationGame(11000, 15000)
    player = 1

    args = {
        'C': 2,
        'num_searches': 10,
        'temperature': 1,
        'dirichlet_epsilon': 0,
        'dirichlet_alpha': 0.3
    }

    model = ValueModel(game, 50, 50, torch.device('cuda'))
    value_model = ValueModel(game, 50, 50, torch.device('cuda'))
    model.load_state_dict(torch.load('models/model_2_NegotiationGame.pt'))
    value_model.load_state_dict(torch.load('models/value_model_2_NegotiationGame.pt'))
    model.eval()
    value_model.eval()
    mcts = NegotiationQMCTS(game, args, model, value_model)

    state = game.get_initial_state()

    for _ in range(15):
        if player == -1:
            print(state)
            while True:
                try:
                    move = int(input(f"{player}:"))
                    break
                except ValueError:
                    print("offer not valid")
            # get_next_state only reads the price at index 1 of the action pair.
            state = game.get_next_state(state, (-1, move), player)
            action = -1
        else:
            mcts_probs = mcts.search(state)
            lowest = min(mcts_probs)
            if lowest < 0:
                # Shift scores so the worst becomes 0.01, keeping exact zeros at 0.
                mcts_probs = [(p + .01) - lowest if p != 0 else 0 for p in mcts_probs]
            action = np.argmax(mcts_probs)
            valid_moves = game.neural_valid_moves(state)
            print(f'opponent valid moves: {valid_moves}')
            print(f'opponent action: {action}')

            state = game.get_next_state(state, valid_moves[action], player)

        value, is_terminal = game.get_value_and_terminated(state, action)

        if is_terminal:
            break

        player = game.get_opponent(player)
        state = game.change_perspective(state, player)

    # Reached exactly once: on the terminal state, or after the move budget.
    # (Previously the terminal case printed this inside the loop and again here.)
    print(state)
    print(value)
    print(game.get_value(state, -1, game.get_opponent(player)))

if __name__ == "__main__":
    # Default entry point: play against the MCTS agent by typing raw offers.
    # To run one of the other interactive harnesses instead, call it here:
    # test_game() picks every move by hand from the ChatGPT offers;
    # test_negotiation() has the human pick ChatGPT offers by index.
    main()
    # test_game()
    # test_negotiation()