import os
import re
import sys
import json
import time
import torch
import numpy as np

import openai
from dotenv import load_dotenv

sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)))
from Game import NegotiationGame
from QMCTS import NegotiationQMCTS
from Model import ValueModel

# Instruction prepended to a rendered "Buyer:/Seller:" transcript (see join_hist) when
# asking the LLM to extract the buyer's latest offer; Seller.detected_offer_chatgpt turns
# the free-text reply into a dollar amount.
GET_BUYER_OFFER = """What offer did the Buyer propose? Just give the dollar amount and nothing else. If the buyer did not propose an offer then say so. If the buyer accepted the seller's offer then say so.\n\n"""

# Set up OpenAI API Key
# load_dotenv() runs first so values from a .env file (if present) are visible to
# os.getenv below; os.getenv yields None for unset variables.
load_dotenv()
openai.organization = os.getenv("API_ORG_KEY")
openai.api_key = os.getenv("API_KEY")

def join_hist(hist):
    """Render a chat history as a plain "Buyer: ... / Seller: ..." transcript.

    Rows with role ``'system'`` are skipped, as is any row whose content contains
    'Begin the negotiation by offering to sell the product'. Rows with role ``'user'``
    are attributed to the Buyer; every other role is attributed to the Seller.

    Only the text before the first ``' \\n '`` of each message is kept. The Seller class
    appends its own LLM instructions after that separator, so they stay out of the
    transcript.

    Args:
        hist: sequence of ``{'role': ..., 'content': ...}`` dicts.

    Returns:
        The transcript lines joined by newlines ('' for an empty history).
    """
    lines = []
    for row in hist:
        # Check the role first so system rows never need a 'content' key.
        if row['role'] == 'system' or 'Begin the negotiation by offering to sell the product' in row['content']:
            continue
        speaker = 'Buyer' if row['role'] == 'user' else 'Seller'
        lines.append(speaker + ': ' + row['content'].split(' \n ')[0])
    return '\n'.join(lines)

def chatgpt_chatbot(messages, model):
    """Send *messages* to the OpenAI chat endpoint and return the first reply, stripped.

    Args:
        messages: list of ``{'role': ..., 'content': ...}`` chat messages.
        model: model name to request (callers here use 'gpt-4' and 'gpt-3.5-turbo-0301').

    Exceptions raised by the OpenAI client propagate to the caller.
    """
    # NOTE(review): openai.ChatCompletion is the pre-1.0 SDK interface; confirm the pinned
    # openai version before upgrading.
    completion = openai.ChatCompletion.create(model=model, messages=messages)
    first_choice = completion["choices"][0]
    reply_text = first_choice["message"]["content"]
    return reply_text.strip()

class Seller():
    """Negotiation seller: an MCTS agent chooses each price, ChatGPT phrases it.

    Attributes:
        prompts: prompt templates loaded from ``../prompts.json`` (path is relative to
            the current working directory).
        game: the ``NegotiationGame(11000, 15000)`` instance searched over.
        player: set to 1. Note that moves are applied with a literal ``1`` in
            ``_mcts_step``, not via this attribute.
        args: hyper-parameters handed to ``NegotiationQMCTS``.
        model, value_model: the two ``ValueModel`` networks given to the search.
        mcts: the ``NegotiationQMCTS`` searcher.

    The ``params`` dict threaded through the turns carries ``state`` (a list between
    calls), ``episode_done`` and, once the second subgame starts, ``final_subgame``.
    """

    # First number in an LLM reply. It may have thousands separators, decimals and a
    # trailing "k" (e.g. "$12,500", "12500.00", "12.5k"). The lookahead stops a following
    # word such as "knots" from being read as a thousands suffix.
    _AMOUNT_RE = re.compile(r'(\d[\d,]*(?:\.\d+)?)\s*([kK](?![A-Za-z]))?')

    def __init__(self) -> None:
        self.prompts = self.load_prompts('../prompts.json')
        self.game = NegotiationGame(11000, 15000)
        self.player = 1

        self.args = {
            'C': 2,
            'num_searches': 10,
            'temperature': 1,
            'dirichlet_epsilon': 0,
            'dirichlet_alpha': 0.3
        }

        # Use CUDA when it is available, otherwise the CPU. map_location lets checkpoints
        # saved from GPU tensors load on a CPU-only machine.
        # NOTE(review): assumes ValueModel / NegotiationQMCTS honour the device passed in
        # rather than hard-coding CUDA; confirm.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = ValueModel(self.game, 50, 50, device)
        self.value_model = ValueModel(self.game, 50, 50, device)
        self.model.load_state_dict(torch.load('models/model_ebs_10.pt', map_location=device))
        self.value_model.load_state_dict(torch.load('models/value_model_ebs_10.pt', map_location=device))
        self.model.eval()
        self.value_model.eval()
        self.mcts = NegotiationQMCTS(self.game, self.args, self.model, self.value_model)

    def load_prompts(self, path):
        """Load and return the JSON prompt-template file at *path*, decoded as UTF-8."""
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _chat_with_retry(messages, model, failure_message=None):
        """Call ``chatgpt_chatbot`` until it succeeds, sleeping 3 s between attempts.

        If *failure_message* is given, it is printed after each failed attempt.
        """
        while True:
            try:
                return chatgpt_chatbot(messages=messages, model=model)
            except Exception:
                # Catch Exception, not a bare except, so Ctrl-C / SystemExit still end
                # this otherwise infinite retry loop.
                if failure_message is not None:
                    print(failure_message)
                time.sleep(3)

    @classmethod
    def _parse_offer_amount(cls, text):
        """Return the first dollar amount in *text* as an int, or None if there is none.

        Accepts forms such as "$12,500", "12500.00", "12.5k" and "13K".
        """
        match = cls._AMOUNT_RE.search(text)
        if match is None:
            return None
        value = float(match.group(1).replace(',', ''))
        if match.group(2):
            value *= 1000
        return int(round(value))

    def detected_offer_chatgpt(self, message, history, params):
        """Ask GPT-4 for the buyer's latest offer and return it as an int.

        *message* is the new buyer message; it is appended to *history* only for the
        extraction prompt and is not stored. If the reply contains no number, the method
        falls back to the seller's current offer when the reply mentions acceptance, and
        otherwise to the buyer's current offer, both read from ``params['state']``.

        Returns:
            The amount, or -1 if the fallback offer is None.
        """
        prompt = GET_BUYER_OFFER + join_hist(history + [dict(role='user', content=message)])
        offer = self._chat_with_retry([dict(role='user', content=prompt)], 'gpt-4')
        amount = self._parse_offer_amount(offer)
        if amount is None:
            if 'accept' in offer.lower():
                amount = self.game.get_seller_offer(params['state'])
            else:
                amount = self.game.get_buyer_offer(params['state'])
        return amount if amount is not None else -1

    def chatgpt_response(self, move, action, history, message, params):
        """Phrase seller offer *move* via gpt-3.5 and return ``(response, prompt)``.

        The prompt template is chosen as follows:
        - the accept-offer closer when *move* equals the buyer's current offer;
        - otherwise the opening-offer template when ``state[1] == 1``;
        - otherwise the lower-price counter template.

        The chosen prompt is appended to *message* after ``' \\n '``. *action* is unused
        and kept for interface compatibility.
        """
        state = params['state']
        buyer_offer = self.game.get_buyer_offer(state)
        if move == buyer_offer:
            prompt = self.prompts['closers']['accept_offer'].format(f'{buyer_offer:,}')
        elif state[1] == 1:
            prompt = self.prompts['openers']['opening_offer'].format(f"{move:,}")
        else:
            prompt = self.prompts['counter_offers']['lower_price'].format(f"{move:,}")

        response = self._chat_with_retry(
            history + [dict(role='user', content=message + ' \n ' + prompt)],
            'gpt-3.5-turbo-0301',
            failure_message='chatgpt fail',
        )
        return response, prompt

    def _mcts_step(self, state):
        """Run MCTS from *state*, pick the best move and apply it as player 1.

        Returns:
            ``(next_state, action)``, where ``next_state`` is whatever
            ``game.get_next_state`` returns.
        """
        mcts_probs = self.mcts.search(state)
        # Zero entries become -inf so any nonzero entry (even a negative one) wins argmax.
        mcts_probs = [p if p != 0 else -np.inf for p in mcts_probs]
        action = np.argmax(mcts_probs)
        valid_moves = self.game.neural_valid_moves(state)
        return self.game.get_next_state(state, valid_moves[action], 1), action

    @staticmethod
    def _finish_turn(response, prompt, history, message, params):
        """Flatten newlines in *response*, record the turn, and build the reply dict.

        The buyer's *message* is stored with the seller-side *prompt* appended after
        ``' \\n '``; ``join_hist`` drops that suffix when rebuilding the transcript.
        ``params['state']`` is converted to a list with ``.tolist()``.
        """
        response = response.replace('\n', ' ')
        history.append(dict(role='user', content=message + ' \n ' + prompt))
        history.append(dict(role='assistant', content=response))
        params['state'] = params['state'].tolist()
        return dict(response=response, history=history, parameters=params)

    def opener(self, message, history, params):
        """Produce the seller's opening offer and seed *history*.

        The incoming *message* is ignored and replaced by the instructional system
        prompt. *history* receives the system prompt, the offer prompt (as 'user') and
        the reply (as 'assistant').
        """
        message = self.prompts['system_prompts']['instructional_prompt']
        next_state, action = self._mcts_step(params['state'])
        params['state'] = next_state.tolist()
        move = self.game.get_seller_offer(params['state'])
        response, prompt = self.chatgpt_response(move, action, history, message, params)
        response = response.replace('\n', ' ')
        history.append(dict(role='system', content=message))
        history.append(dict(role='user', content=prompt))
        history.append(dict(role='assistant', content=response))
        return dict(response=response, history=history, parameters=params)

    def counter(self, message, history, params):
        """Record the buyer's offer from *message*, then reply with a counter or a close.

        Recording the offer:
        - By default the offer goes into the first zero slot of ``params['state']`` at
          index >= 2, and ``state[1]`` is incremented.
        - In the final subgame when ``check_win`` holds, or once the episode is done, it
          instead overwrites the slot just before that one.

        If ``check_win`` then holds, ``update_game`` either ends the episode (the seller
        replies based on its current offer and returns) or starts the next subgame.
        Otherwise, or after a new subgame starts, MCTS picks the seller's next offer, and
        the episode is marked done if ``check_win`` holds for the resulting state.
        """
        offer = self.detected_offer_chatgpt(message, history, params)
        # A numpy array when called via get_response; edited in place.
        state = params['state']
        # NOTE(review): raises IndexError if no zero slot remains at index >= 2.
        empty_slots = np.where(state[2:] == 0)[0] + 2
        if (params.get('final_subgame', False) and self.game.check_win(state, offer)) or params.get('episode_done', False):
            state[empty_slots[0] - 1] = offer
        else:
            state[empty_slots[0]] = offer
            state[1] += 1

        if self.game.check_win(params['state'], offer):  # human always has last turn in subgame, so we may need to update
            params = self.update_game(params, offer, self.game.get_seller_offer(params['state']))
            if params.get('episode_done', False):
                final_offer = self.game.get_seller_offer(params['state'])
                response, prompt = self.chatgpt_response(final_offer, final_offer, history, message, params)
                return self._finish_turn(response, prompt, history, message, params)

        next_state, action = self._mcts_step(params['state'])
        params['state'] = next_state
        move = self.game.get_seller_offer(params['state'])
        response, prompt = self.chatgpt_response(move, action, history, message, params)

        if self.game.check_win(params['state'], offer):
            params['episode_done'] = True

        return self._finish_turn(response, prompt, history, message, params)

    def update_game(self, params, buyer_offer, seller_offer):
        """End the episode, or replace ``params['state']`` with the next subgame's state.

        The episode ends (``episode_done = True``, state untouched) when any of these
        hold: already in the final subgame, the two offers are equal, or the episode is
        already done.

        Otherwise a fresh ``get_initial_state()`` is seeded from the current state's
        offers. The passed-in offers are only used for the check above. When the second
        branch is taken, ``final_subgame`` is set.
        NOTE(review): the constants 10, 6, 4, 12500 and 12600 are game-specific; their
        meaning is not visible here.
        """
        if params.get('final_subgame', False) or buyer_offer == seller_offer or params.get('episode_done', False):
            params['episode_done'] = True
            return params

        seller_offer = self.game.get_seller_offer(params['state'])
        buyer_offer = self.game.get_buyer_offer(params['state'])
        state = self.game.get_initial_state()
        if params['state'][2] == 10:
            state[2] = 6
            state[3] = seller_offer
            state[4] = max(buyer_offer, 12500)
        else:
            state[1] = 2
            state[2] = 4
            if seller_offer != 12600:
                state[3] = seller_offer
                # Mirror seller_offer around 12600 (equivalently 25200 - seller_offer).
                state[4] = seller_offer - (seller_offer - 12600) * 2
            else:
                state[3] = params['state'][3]
                state[4] = params['state'][4]
            state[5] = seller_offer
            state[6] = buyer_offer
            params['final_subgame'] = True

        params['state'] = state
        return params

    def get_response(self, message, history, params):
        """Entry point for one turn.

        Steps:
        - Normalise ``params['state']`` to a numpy array, defaulting to the game's
          initial state.
        - Default ``episode_done`` to False if missing.
        - Dispatch to ``opener`` when *history* is empty, otherwise to ``counter``.
        """
        params['state'] = np.array(params.get('state', self.game.get_initial_state()))
        params.setdefault('episode_done', False)
        if not history:
            return self.opener(message, history, params)
        return self.counter(message, history, params)
        
    
if __name__ == "__main__":
    # Interactive console session: the seller opens, then each typed line is a buyer
    # message. Type "quit" (or press Ctrl-D / Ctrl-C at the prompt) to stop.
    seller = Seller()
    # Start the conversation
    history = []
    params = dict(episode_done=False)
    user_text = seller.prompts['system_prompts']['instructional_prompt']
    output = seller.get_response(user_text, history, params)
    # Rebind from the output, as the loop below does, rather than relying on in-place mutation.
    history = output['history']
    params = output['parameters']
    print(f"Bot: {output['response']}")

    while True:
        try:
            user_text = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of dying with a traceback.
            print()
            break
        if user_text.strip() == "quit":
            break
        output = seller.get_response(user_text, history, params)
        history = output['history']
        params = output['parameters']
        print(params['state'])
        print(params['episode_done'])
        print(f"Bot: {output['response']}")