import argparse
from langchain_community.chat_models import ChatOpenAI
from langchain_community.llms import OpenAI
from tqdm import tqdm
import asyncio
import json
import random
import copy
import sys
import os
from langchain.callbacks import wandb_tracing_enabled
import wandb
from wandb.sdk.data_types.trace_tree import Trace


# Make the project root importable so `src.utils` resolves. Only append when
# PROJECTPATH is actually set; os.environ.get returns None otherwise, and a
# None entry in sys.path breaks tools that treat entries as path strings.
_project_path = os.environ.get('PROJECTPATH')
if _project_path is not None:
    sys.path.append(_project_path)
from src.utils import *

def parse_args():
    """Parse the command-line options for this run.

    ``--task`` is mandatory; every other option has a default.

    Returns:
        argparse.Namespace with attributes ``planner``, ``scr_model_name``,
        ``num_sample``, ``port``, ``retry_threshold`` and ``task``.
    """
    # (flags, add_argument keyword options), registered in this order.
    option_specs = [
        (("--planner",), {"type": str, "default": "gpt"}),
        (("--scr_model_name",), {"type": str, "default": "codellama"}),
        (("--num_sample",), {"type": int, "default": -1}),
        (("--port",), {"type": int, "default": 8000}),
        (("-r", "--retry_threshold"), {"type": int, "default": 10}),
        (("--task",), {"type": str, "required": True}),
    ]
    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()

def load_prompt():
    """Read and return the instance-level meta prompt text.

    The path is relative to the current working directory.
    """
    prompt_path = "prompts/code_prompt_on_instance.txt"
    with open(prompt_path, "r") as prompt_file:
        return prompt_file.read()

def load_templates(args):
    """Return the code-scoring prompt template for ``args.task``.

    Reads ``tasks/<task>/scoring_prompt_template_code.txt`` relative to the
    current working directory.
    """
    template_path = f"tasks/{args.task}/scoring_prompt_template_code.txt"
    with open(template_path, "r") as template_file:
        return template_file.read()

def load_demonstration(task_name):
    """Return the ``input`` field of the first example for ``task_name``.

    Reads ``tasks/<task_name>/data.json`` (relative to the working directory)
    and looks at its top-level ``examples`` list.
    """
    with open(f"tasks/{task_name}/data.json", "r") as data_file:
        examples = json.load(data_file)['examples']
    first_example = examples[0]
    return first_example['input']

def load_example_instances_and_code_prompt(task_name):
    """Return ``(example_instance, code_prompt)`` for ``task_name``.

    ``example_instance`` comes from ``load_demonstration``; ``code_prompt`` is
    the text of ``tasks/<task_name>/prompts/code_prompt_for_instance.txt``.
    """
    # The demonstration is loaded first, matching the original read order.
    demonstration = load_demonstration(task_name)
    prompt_path = f"tasks/{task_name}/prompts/code_prompt_for_instance.txt"
    with open(prompt_path, "r") as prompt_file:
        return demonstration, prompt_file.read()

def prepare_meta_prompt(meta_prompt, task_name, output_format_dict, target_instance, function_name):
    """Fill the planner meta prompt with three few-shot exemplars plus the target.

    Three tasks other than ``task_name`` are sampled from
    ``output_format_dict``. Each contributes an example instance, its output
    format and its code prompt as ``[Example 1]`` .. ``[Example 3]``. The
    target instance follows as ``[Example 4]``, with an empty "Code prompt:"
    slot for the planner to complete.

    Args:
        meta_prompt: Template containing ``{exemplars}`` and ``{function_name}``
            placeholders (filled with ``str.format``).
        task_name: Task of the target instance; excluded from the exemplars.
        output_format_dict: Maps task name to its answer-format description.
            It must contain ``task_name`` and at least three other tasks
            (``random.sample`` raises ``ValueError`` otherwise).
        target_instance: Text of the instance to generate a code prompt for.
        function_name: Substituted for ``{function_name}`` in ``meta_prompt``.

    Returns:
        The formatted meta prompt string.
    """
    # Same list (and therefore the same random draw) as removing task_name
    # from the key list.
    candidate_tasks = [task for task in output_format_dict if task != task_name]
    sampled_example_tasks = random.sample(candidate_tasks, 3)

    parts = ["[Example 1]\nExample task instance:\n"]
    for i, task in enumerate(sampled_example_tasks):
        task_instance, code_prompt = load_example_instances_and_code_prompt(task)
        parts.append(task_instance + "\n\nOutput Format:\n" + output_format_dict[task] + "\n\nCode prompt:\n" + code_prompt)
        parts.append(f"\n\n[Example {i+2}]\nExample task instances:\n")

    # Append (not overwrite) the target instance. The loop above already
    # emitted its "[Example 4]\nExample task instances:\n" header. The
    # previous code reassigned `exemplar` here, which silently dropped all
    # three few-shot exemplars.
    parts.append(f"{target_instance}\n\n")
    parts.append(f"Output format:\n{output_format_dict[task_name]}\n\n")
    parts.append("Code prompt:")

    return meta_prompt.format(exemplars="".join(parts), function_name=function_name)

def remove_example_usage(code_text):
    """Trim planner output down to its function definitions.

    Processing steps:
      1. Drop everything from a ``[Example 5]`` marker onward (the planner ran
         past the target example).
      2. If no ``def `` remains, return the text unchanged.
      3. Otherwise keep the span from the first ``def `` through the end of the
         line holding the last indented ``return``/``pass`` of the final
         function. This discards leading prose and trailing example calls.
         If no such terminator exists, keep everything through the end.

    Args:
        code_text: Raw planner completion.

    Returns:
        The trimmed code, stripped of surrounding whitespace (except in the
        no-``def`` case, where the text is returned as-is).
    """
    if "[Example 5]" in code_text:
        code_text = code_text.split("[Example 5]")[0].strip()
    last_def_index = code_text.rfind("def ")
    if last_def_index == -1:
        return code_text

    # Terminating statement of the final function. The two-space prefix
    # requires it to be indented, i.e. inside a function body.
    end_index = code_text.rfind("  return ", last_def_index)
    if end_index == -1:
        end_index = code_text.rfind("  pass", last_def_index)

    if end_index == -1:
        # No terminator: keep through the end. Previously this fell into
        # find("\n", -1), whose result depended on whether the very last
        # character happened to be a newline.
        cut_index = len(code_text)
    else:
        cut_index = code_text.find("\n", end_index)
        if cut_index == -1:
            # Terminator is on the last line. Still drop any leading prose
            # before the first def, consistent with the other paths.
            cut_index = len(code_text)
    return code_text[code_text.find("def "):cut_index].strip()

async def calculate_score_for_optimized_prompt(llm, data, scoring_prompt, optimized_prompts, helper, args=None):
    """Evaluate per-instance optimized prompts with the scorer model.

    Each ``data[i]`` is paired with ``optimized_prompts[i]`` to fill
    ``scoring_prompt``. The filled prompts are sent to ``llm`` and the outputs
    are graded by ``helper.evaluate_prediction``.

    Args:
        llm: Scorer LLM client passed to ``generate_concurrently``.
        data: Instances; each needs an ``'input_text'`` key.
        scoring_prompt: Template with ``{input_text}``, ``{prompt}`` and
            ``{function_name}`` placeholders.
        optimized_prompts: One prompt per instance, aligned with ``data``.
        helper: Task helper providing ``function_name`` and
            ``evaluate_prediction``.
        args: Run configuration forwarded to ``generate_concurrently``. If
            omitted, the module-level ``args`` bound by the ``__main__`` guard
            is used (the previous behaviour).

    Returns:
        ``(result_score, individual_score, outputs, list_of_model_inputs)``.

    Raises:
        ValueError: If ``data`` and ``optimized_prompts`` differ in length, or
            no ``args`` was passed and no module-level ``args`` exists (e.g.
            when this module is imported rather than run as a script).
    """
    # zip() would silently truncate and mis-pair instances with prompts.
    if len(data) != len(optimized_prompts):
        raise ValueError(
            f"data has {len(data)} instances but {len(optimized_prompts)} optimized prompts were given"
        )
    if args is None:
        args = globals().get("args")
        if args is None:
            raise ValueError("args must be passed when no module-level `args` is defined")

    # Construct model inputs using instances in the evaluation set.
    list_of_model_inputs = [
        scoring_prompt.format(input_text=d['input_text'], prompt=optimized_prompt, function_name=helper.function_name)
        for d, optimized_prompt in zip(data, optimized_prompts)
    ]

    outputs = await generate_concurrently(llm, list_of_model_inputs, args)
    result_score, individual_score = helper.evaluate_prediction(outputs)

    return result_score, individual_score, outputs, list_of_model_inputs

################################################################
# MAIN FUNCTION
################################################################
async def main(args):
    """Generate instance-specific code prompts with a planner LLM and score them.

    Processing steps:
      1. For every test instance, build a meta prompt with
         ``prepare_meta_prompt`` and send it to the planner LLM
         (``args.planner``).
      2. Trim each completion with ``remove_example_usage``.
      3. Score the resulting code prompts with the scorer LLM
         (``args.scr_model_name``).
      4. Write the score and per-instance model inputs/outputs/scores as JSON
         to ``tasks/<task>/results/instance_specific_cp_planner_<planner>_<scorer>.json``.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
    """
    model_name_dict = {
        "codellama": "codellama/CodeLlama-7b-Instruct-hf",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.2",
        "gpt": "gpt-3.5-turbo",
        "codellama-13b": "codellama/CodeLlama-13b-Instruct-hf",
        "llama-13b": "meta-llama/Llama-2-13b-hf",
    }

    # Scorer LLM. Aliases containing "gpt" use ChatOpenAI; everything else goes
    # to an OpenAI-compatible completion server on localhost:<port>.
    scorer_model_name = model_name_dict[args.scr_model_name]
    if 'gpt' in args.scr_model_name:
        scr_llm = ChatOpenAI(
            model_name=scorer_model_name,
            temperature=0.0,
            max_retries=100,
            max_tokens=1500,
        )
    else:
        scr_llm = OpenAI(
            model_name=scorer_model_name,
            temperature=0.0,
            max_retries=100,
            openai_api_key='EMPTY',
            openai_api_base=f"http://localhost:{args.port}/v1",
            max_tokens=1500,
        )

    # Planner LLM. The chat variant stops generating at the next "[Example"
    # marker. The local variant has no stop sequence; remove_example_usage
    # trims its output afterwards.
    planner_model_name = model_name_dict[args.planner]
    if 'gpt' in args.planner:
        planner_llm = ChatOpenAI(
            model_name=planner_model_name,
            temperature=0.0,
            max_retries=100,
            stop=["[Example"],
            max_tokens=3000,
        )
    else:
        planner_llm = OpenAI(
            model_name=planner_model_name,
            temperature=0.0,
            max_retries=100,
            openai_api_key='EMPTY',
            openai_api_base=f"http://localhost:{args.port}/v1",
            max_tokens=2000,
        )

    # Answer-format descriptions that prepare_meta_prompt shows for each task.
    output_format_dict = {
        'temporal_sequences': "'(A)', '(B)', '(C)', ...",
        'reasoning_about_colored_objects': "'(A)', '(B)', '(C)', ...",
        'tracking_shuffled_objectives': "'(A)', '(B)', '(C)', ...",
        'dyck_languages': "A string of closing brakets seperated with a space.",
        'web_of_lies': "'Yes' or 'No'",
        'navigate': "'Yes' or 'No'",
        'geometric_shapes': "'(A)', '(B)', '(C)', ...",
    }
    task_helper = helper_dict[args.task](args)

    # Load data and templates.
    # NOTE(review): the load_data() result is immediately overwritten by the
    # next line, and the "train" split is loaded only to be discarded. Both
    # calls are kept in case they mutate task_helper; drop them if they don't.
    _, test_data = task_helper.load_data()
    _, test_data = task_helper.load_and_prepare_data("train"), task_helper.load_and_prepare_data("test")
    meta_prompt = load_prompt()
    scoring_prompt = load_templates(args)

    os.makedirs(f'tasks/{args.task}/results', exist_ok=True)

    save_path = f"tasks/{args.task}/results/instance_specific_cp_planner_{args.planner}_{args.scr_model_name}.json"
    print(save_path)

    # One meta prompt per test instance, so one planner code prompt each.
    meta_prompts = [prepare_meta_prompt(meta_prompt, args.task, output_format_dict, d['input'], task_helper.function_name) for d in test_data]
    # NOTE(review): generate_concurrently gets args.task here but the whole
    # args namespace inside calculate_score_for_optimized_prompt. Confirm
    # which its third parameter expects.
    raw_cp = await generate_concurrently(planner_llm, meta_prompts, args.task)
    refined_cp = [remove_example_usage(r) for r in raw_cp]
    score_dic, individual_score, raw_prediction, list_of_model_inputs = await calculate_score_for_optimized_prompt(scr_llm, test_data, scoring_prompt, refined_cp, task_helper)

    output = {
        'score': score_dic,
        # Store prompts/outputs as line lists so the JSON file is readable.
        'inference': [
            {"input": model_input.split("\n"), "output": prediction.strip().split("\n"), "score": score}
            for model_input, prediction, score in zip(list_of_model_inputs, raw_prediction, individual_score)
        ],
    }
    with open(save_path, "w") as f:
        json.dump(output, f, indent=4)
if __name__ == "__main__":
    # Keep `args` bound at module scope: calculate_score_for_optimized_prompt
    # may read it as a global.
    args = parse_args()
    asyncio.run(main(args=args))
