import argparse
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
from tqdm import tqdm
import asyncio
import json
import random
import multiprocessing, queue
import sys
import os

sys.path.append(os.environ.get('PROJECTPATH'))
from src.utils import *

def parse_args():
    """Parse the command-line options for this evaluation script.

    Options:
        --scr_model_name  key selecting the scorer model (default "codellama")
        --num_sample      integer sample count (default -1)
        --port            port of the local OpenAI-compatible server used for
                          non-"gpt" models (default 8000)
        --task            task name; selects the task helper and output folder

    Returns:
        argparse.Namespace holding the four options above.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--scr_model_name", str, "codellama"),
        ("--num_sample", int, -1),
        ("--port", int, 8000),
        ("--task", str, None),  # argparse's implicit default is also None
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def load_prompt(args):
    """Return the raw text of the PAL prompt template.

    The template is read from ``prompts/pal_template.txt``, resolved relative
    to the current working directory. ``args`` is accepted for interface
    compatibility but is not used.
    """
    template_path = "prompts/pal_template.txt"
    with open(template_path, "r") as template_file:
        return template_file.read()

def remove_example_usage(code_text):
    """Cut ``code_text`` right after the line holding its last ``"  return "``.

    Everything up to and including the newline that ends that final return
    line is kept; anything after it is dropped (presumably example calls or
    prints the model appended after the body, per this function's name).
    The text is returned unchanged when it contains no ``"  return "`` at all,
    or when that last return line is not followed by a newline.
    """
    before, marker, after = code_text.rpartition("  return ")
    if not marker:
        return code_text
    return_line, newline, _discarded = after.partition("\n")
    if not newline:
        return code_text
    return before + marker + return_line + newline

def execute_output(output, result_queue):
    """Execute a generated PAL program and report ``str(solution())``.

    Used as the target of a child process by calculate_score. ``output`` is
    expected to define a top-level ``solution`` function. It is executed, then
    ``solution()`` is called and its result, converted with ``str``, is put on
    ``result_queue``. Note that calculate_score appends a ``solution()`` call
    to the program text, so ``solution`` may run twice: once inside ``exec``
    and once here.

    If executing the program or calling ``solution`` fails, the string "None"
    is reported instead. This includes a missing ``solution`` and generated
    code that calls ``exit()``/``sys.exit()``, so the parent still receives a
    result rather than waiting for a child that died silently.

    SECURITY: ``output`` is model-generated, untrusted code. It runs with the
    full privileges of this process and with this module's globals visible.
    Running it in a child process only bounds its runtime; it is not a
    sandbox.
    """
    try:
        local_scope = {}
        exec(output, globals(), local_scope)
        report = str(local_scope['solution']())
    except (Exception, SystemExit):
        # Best effort: any failure of the generated program counts as "None".
        # SystemExit is a BaseException, so it must be listed explicitly.
        report = "None"
    result_queue.put(report)

def _run_solution_with_timeout(code, timeout=1.0, reap_grace=0.5):
    """Run ``code`` through execute_output in a child process; return its report.

    Returns the string the child put on the queue, or "None" if no result
    arrives within ``timeout`` seconds (e.g. infinite loop, hang, or a child
    that died without reporting). The child is always reaped, and terminated
    if it is still running.
    """
    result_queue = multiprocessing.Queue()
    process = multiprocessing.Process(target=execute_output, args=(code, result_queue))
    process.start()
    try:
        # Drain the queue BEFORE joining. Per the multiprocessing guidelines, a
        # child that has put data on a queue may not exit until that data is
        # consumed. Joining first could misreport a finished run with a large
        # result as a timeout.
        result = result_queue.get(timeout=timeout)
        got_result = True
    except queue.Empty:
        result = "None"
        got_result = False
    if got_result:
        # The child is finishing; give it a moment to exit on its own.
        process.join(timeout=reap_grace)
    if process.is_alive():
        process.terminate()  # timed out (or lingering): kill it
        process.join()
    result_queue.close()
    return result


async def calculate_score(llm, data, prompt, helper, args=None, timeout=1.0):
    '''Evaluate ``prompt`` on ``data`` using the scorer model ``llm``.

    Each example's ``'input'`` is substituted into ``prompt`` as ``question``.
    The completions are generated concurrently. Each completion is wrapped as
    the body of ``def solution():`` and executed in its own child process,
    limited to ``timeout`` seconds.

    Args:
        llm: scorer model handed to ``generate_concurrently``.
        data: sequence of examples, each a mapping with an ``'input'`` key.
        prompt: format string with a ``{question}`` placeholder.
        helper: task helper whose ``evaluate_prediction`` scores the results.
        args: options forwarded to ``generate_concurrently``. When omitted,
            the module-level ``args`` bound by the ``__main__`` entry point
            is used (backward compatibility).
        timeout: seconds each generated program may run (default 1.0).

    Returns:
        ``(result_score, individual_score, clean_outputs, outputs,
        interpret_results)``. ``clean_outputs`` holds the executed programs,
        ``outputs`` the raw completions, and ``interpret_results`` the
        ``str`` of each program's result, or "None" on failure or timeout.

    Raises:
        NameError: if ``args`` is omitted and no module-level ``args`` exists.
    '''
    if args is None:
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "calculate_score() needs `args`; pass it explicitly when this "
                "module is not run as a script"
            ) from None

    # Construct model inputs using the instances in the evaluation set.
    list_of_model_inputs = [prompt.format(question=d['input']) for d in data]
    outputs = await generate_concurrently(llm, list_of_model_inputs, args)

    clean_outputs = []
    interpret_results = []
    for output in outputs:
        # The trailing "solution()" is kept so the saved program text is unchanged.
        program = "def solution():\n" + remove_example_usage(output) + "\nsolution()"
        clean_outputs.append(program)
        interpret_results.append(_run_solution_with_timeout(program, timeout))

    result_score, individual_score = helper.evaluate_prediction(interpret_results)
    return result_score, individual_score, clean_outputs, outputs, interpret_results

################################################################
# MAIN FUNCTION
################################################################
def _build_scorer_llm(args):
    """Return the LangChain client for the scorer model named by ``args.scr_model_name``.

    Keys containing "gpt" use ``ChatOpenAI`` with its default endpoint
    configuration. Every other key uses ``OpenAI`` pointed at an
    OpenAI-compatible server on ``http://localhost:<args.port>/v1``.

    Raises:
        KeyError: if ``args.scr_model_name`` is not a known model key.
    """
    model_name_dict = {
        "codellama": "codellama/CodeLlama-7b-Instruct-hf",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
        "gpt": "gpt-3.5-turbo",
        "codellama-13b": "codellama/CodeLlama-13b-Instruct-hf"
    }
    model_name = model_name_dict[args.scr_model_name]
    if 'gpt' in args.scr_model_name:
        return ChatOpenAI(
            model_name=model_name,
            temperature=0.0,
            max_retries=100,
            max_tokens=3000,
        )
    return OpenAI(
        model_name=model_name,
        temperature=0.0,
        max_retries=100,
        openai_api_key='EMPTY',
        openai_api_base=f"http://localhost:{args.port}/v1",
        max_tokens=1000,
    )


async def main(args):
    """Run a zero-shot PAL evaluation for ``args.task`` and save a JSON report.

    Builds the scorer LLM, loads the task's test split and the PAL prompt, and
    scores every test example with calculate_score. The report is written to
    ``tasks/<task>/pal/zero-shot_<scr_model_name>.json``. It holds the
    aggregate ``score`` and one ``inference`` record per example: the raw
    output, the cleaned program, the input, the execution result (each split
    into lines), and the per-example score.
    """
    scr_llm = _build_scorer_llm(args)

    task_helper = helper_dict[args.task](args)
    # Load data and templates.
    # NOTE(review): the load_data() result and the "train" split are both
    # discarded. The calls are kept only in case they have side effects on
    # task_helper -- confirm, and drop them if not.
    _, test_data = task_helper.load_data()
    _, test_data = task_helper.load_and_prepare_data("train"), task_helper.load_and_prepare_data("test")
    prompt = load_prompt(args)

    out_dir = f"tasks/{args.task}/pal"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)
    save_path = f"{out_dir}/zero-shot_{args.scr_model_name}.json"

    # Score the prompt on the test set using the scorer model.
    result_score, individual_score, clean_outputs, raw_outputs, interpret_results = await calculate_score(
        scr_llm, test_data, prompt, task_helper
    )

    output = {
        'score': result_score,
        'inference': [
            {
                "raw_output": raw_outputs[i].split("\n"),
                "clean_output": clean.strip().split("\n"),
                "input": test_data[i]['input'],
                "interpret_results": interpret_results[i].split("\n"),
                "score": individual_score[i],
            }
            for i, clean in enumerate(clean_outputs)
        ],
    }

    with open(save_path, "w") as f:
        json.dump(output, f, indent=4)
if __name__ == "__main__":
    # Parse the CLI options and run the async evaluation pipeline to completion.
    # `args` is deliberately bound at module level: calculate_score() can read
    # this global, so do not move it into a function or rename it.
    args = parse_args()
    asyncio.run(main(args))
