import numpy as np
import openai
import os 
import time
import random
import logging
logging.getLogger().setLevel(logging.CRITICAL)

import llm_reward_fns
import sys

# Set the openai client library's own logger to WARNING level (separate from
# the root-logger level configured above).
openai.util.logger.setLevel(logging.WARNING)
# Module-level API-key placeholder. NOTE(review): prompt_openai_and_get_response
# takes a parameter with this same name, which shadows this global inside that
# function, so this value is not read there — confirm whether anything else
# in the project uses it.
openai_key = ""
def prompt_openai_and_get_response(message_list, prompt_chat_gpt=True, model_to_run=None, temperature_to_use=0, prompt_local_model=True, openai_url=None, openai_key=None):
    '''Send ``message_list`` to a chat-completion endpoint and return the reply text.

    Two backends are supported:

    * a local OpenAI-compatible server (``prompt_local_model=True``, the
      default), reached at ``openai_url`` or, when that is None, at a
      hard-coded fallback address. It is always sent the model name
      'davinci' with ``max_tokens=100``. Only a single message is accepted;
      the process exits with status 1 if more are given.
    * the hosted OpenAI API (``prompt_local_model=False``, or
      ``openai_url='chatgpt'``), authenticated with ``openai_key``.

    Args:
        message_list: list of ``{'role': ..., 'content': ...}`` dicts. For the
            hosted API only the first message's content is sent, re-wrapped
            as a single ``user`` message.
        prompt_chat_gpt: if both this and ``prompt_local_model`` are False,
            no request is made and None is returned.
        model_to_run: hosted model name. When None it is chosen from the
            total content length: 'gpt-3.5-turbo-16k' above 4500 characters,
            otherwise 'gpt-3.5-turbo'.
        temperature_to_use: sampling temperature forwarded to the endpoint.
        prompt_local_model: use the local server instead of the hosted API.
        openai_url: base URL of the local server, or the sentinel 'chatgpt'
            to force the hosted API.
        openai_key: hosted-API key. When None, the ``OPENAI_API_KEY``
            environment variable is used instead.

    Returns:
        The reply text with every double quote removed, or None when no
        backend was selected or every attempt raised an exception.

    Side effects: assigns the module-level ``openai.api_key`` and
    ``openai.api_base`` settings.
    '''
    local_api_base_default = "http://128.10.12.202:47526/v1"
    max_tries = 5
    retry_delay_seconds = 5

    if openai_url == 'chatgpt':
        prompt_local_model = False
        prompt_chat_gpt = True

    if model_to_run is None:
        # Longer prompts need the larger-context model.
        total_chars = sum(len(msg['content']) for msg in message_list)
        model_to_run = 'gpt-3.5-turbo-16k' if total_chars > 4500 else 'gpt-3.5-turbo'

    if openai_key is None:
        # Assigning None to openai.api_key would make every hosted request fail
        # authentication, so fall back to the conventional environment variable.
        openai_key = os.environ.get("OPENAI_API_KEY")
    openai.api_key = openai_key

    if not (prompt_chat_gpt or prompt_local_model):
        return None

    local_api_base = openai_url if openai_url is not None else local_api_base_default
    if prompt_local_model:
        if len(message_list) > 1:
            print("We have more than a single message, we can't run the prompt!")
            sys.exit(1)
        openai.api_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"  # can be anything

    for attempt in range(max_tries):
        try:
            if prompt_local_model:
                openai.api_base = local_api_base
                response = openai.ChatCompletion.create(model='davinci', temperature=temperature_to_use, messages=message_list, top_p=1, max_tokens=100)
            else:
                openai.api_base = "https://api.openai.com/v1"
                openai.api_key = openai_key
                # Built inside the try so an empty message_list is retried and
                # ends in None rather than escaping as an IndexError.
                hosted_messages = [{'role': 'user', 'content': str(message_list[0]['content'])}]
                response = openai.ChatCompletion.create(model=model_to_run, temperature=temperature_to_use, messages=hosted_messages, top_p=1)
            return response["choices"][0]["message"]['content'].replace('"', '')
        except Exception as e:
            print("Exception running prompt_to_run, going to retry" + str(e))
            # No point waiting after the final attempt.
            if attempt < max_tries - 1:
                time.sleep(retry_delay_seconds)
    return None

def parse_openai_output(openai_response):
    '''Extract the two communities' usernames from an LLM_task response.

    Two layouts are recognised:

    1. A header line containing "community 1" / "community 2" (any case)
       followed by one username per line. Header lines themselves are not
       scanned for names.
    2. If layout 1 yields nothing, single lines of the form
       "Community 1: alice, bob" carrying comma-separated names. The label
       is removed case-insensitively.

    Each name is cleaned of markdown '*', a leading '-' or '•' bullet, a
    leading '@', the words "username"/"Username" and ':' characters. Names
    that end up empty are dropped.

    Returns:
        "<community 1 names joined by ','>;;;;<community 2 names joined by ','>"
    '''
    import re

    comm_1_usernames = []
    comm_2_usernames = []
    # The list that non-header lines are currently appended to; None until a
    # header has been seen.
    current_community = None
    for given_line in openai_response.split('\n'):
        lowered = given_line.lower()
        if 'community 1' in lowered:
            current_community = comm_1_usernames
        elif 'community 2' in lowered:
            current_community = comm_2_usernames
        elif given_line.strip() and current_community is not None:
            current_community.append(given_line)

    if not comm_1_usernames and not comm_2_usernames:
        # Fallback: names listed inline on the header line, comma-separated.
        for given_line in openai_response.split('\n'):
            lowered = given_line.lower()
            if 'community 1' in lowered:
                comm_1_usernames = re.sub('community 1', '', given_line, flags=re.IGNORECASE).replace(':', '').split(',')
            elif 'community 2' in lowered:
                comm_2_usernames = re.sub('community 2', '', given_line, flags=re.IGNORECASE).replace(':', '').split(',')
        if not comm_1_usernames and not comm_2_usernames:
            print("Warning, There are no users parsed!!! ")
            print("Parsing openai_response  " + str(openai_response) + " and got: " + str(','.join(comm_1_usernames) + ';;;;' + ','.join(comm_2_usernames)))

    def fix_username_pairing(given_usernames):
        '''Clean raw name fragments and drop any that become empty.'''
        return_usernames = []
        for raw in given_usernames:
            name = raw.replace('*', '').strip()
            # Slicing (not indexing) so an empty string cannot raise IndexError;
            # strip after the bullet so "- @alice" still loses its '@'.
            if name[:1] in ('-', '•'):
                name = name[1:].strip()
            if name[:1] == '@':
                name = name[1:]
            name = name.replace('username', '').replace('Username', '').replace(':', '').strip()
            if name:
                return_usernames.append(name)
        return return_usernames

    comm_1_usernames = fix_username_pairing(comm_1_usernames)
    comm_2_usernames = fix_username_pairing(comm_2_usernames)
    return ','.join(comm_1_usernames) + ';;;;' + ','.join(comm_2_usernames)
            
            
def create_user_summary(username, user_representation, using_reddit=True, entity_used=None, use_llama_for_gold_focus_areas=False, temperature_to_use=None):
    '''Return a one-sentence LLM summary of what a user discusses and their perspective.

    When ``using_reddit`` is False (Twitter), a previously saved summary is
    first looked up in ``<prefix><username>.txt`` in the current working
    directory. ``<prefix>`` is 'llama_' when ``use_llama_for_gold_focus_areas``
    is set and empty otherwise. If that file exists, its UTF-8 contents are
    returned unchanged and no model is queried.

    Args:
        username: user whose summary is wanted; used for the saved-file name.
        user_representation: text describing the user. It should already
            contain the username.
        using_reddit: False enables the saved-summary lookup above.
        entity_used: optional entity name appended to the question.
        use_llama_for_gold_focus_areas: selects the 'llama_' file prefix.
        temperature_to_use: sampling temperature; None means 0.0.

    Returns:
        The summary text, or None if the hosted-model request failed.
    '''
    summary_path = ('llama_' if use_llama_for_gold_focus_areas else '') + str(username) + '.txt'
    if not using_reddit and os.path.isfile(summary_path):
        # Must be read mode: 'w' would truncate the saved summary.
        with open(summary_path, 'r', encoding='utf-8') as summary_file:
            return summary_file.read()

    summarize_prompt = 'What is this user discussing and what is their perspective? Please summarize in one sentence.'
    if entity_used is not None:
        summarize_prompt += ' They are discussing about entity ' + str(entity_used)
    summarize_prompt += '\n' + user_representation

    message_list = [{"role": "user", "content": summarize_prompt}]
    if temperature_to_use is None:
        temperature_to_use = 0.0
    # The summary is returned to the caller; nothing is written to disk here.
    return prompt_openai_and_get_response(message_list, prompt_local_model=False, temperature_to_use=temperature_to_use)
    

def prompt_lm_determine_communities_given_input(given_input_text, openai_url=None):
    '''Ask the LLM_task model to split users into communities for each prompt.

    ``given_input_text`` is an iterable of fully-built prompt strings
    (community question, optional focus area, user summaries). Each prompt is
    sent as a single user message. The reply is parsed into the
    ``"<comm 1>;;;;<comm 2>"`` format. A failed request yields the empty
    result ``';;;;'``. Results are returned in input order.
    '''
    def _communities_for(prompt_text):
        reply = prompt_openai_and_get_response(message_list=[{"role": "user", "content": prompt_text}], openai_url=openai_url)
        return ';;;;' if reply is None else parse_openai_output(reply)

    return [_communities_for(prompt_text) for prompt_text in given_input_text]


def get_chat_gpt_prompt_given_summaries(curr_users_to_run, user_summary_mapping):
    '''Build the "which users share a perspective" prompt from per-user summaries.

    After a fixed question line, one ``Username: <user> <summary>`` line is
    emitted per user, with newlines inside the summary removed. Users missing
    from ``user_summary_mapping`` are reported on stdout and left out.
    '''
    prompt_lines = ['Which users have the same perspective?\n']
    for user in curr_users_to_run:
        if user in user_summary_mapping:
            flattened_summary = str(user_summary_mapping[user]).replace('\n', '')
            prompt_lines.append('Username: ' + str(user) + " " + flattened_summary + '\n')
        else:
            print("Why don't we have a summary for this user " + str(user))
    return ''.join(prompt_lines)