from tqdm import tqdm
from prompt_template_class import GroupingGenerator
from langchain import LLMChain
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from api_key import *
from langchain.prompts.chat import SystemMessagePromptTemplate, ChatPromptTemplate
from output_parser import GroupingOutputParser
from data_creator import GroupingDataCreator
import argparse
from utlis import str2bool
import random
import numpy as np
from utlis import load_llama, get_prompt
from langchain.prompts import PromptTemplate
from utlis import load_llama, get_prompt, B_INST, E_INST

# Command-line interface: each entry is (flag, value parser, default).
parser = argparse.ArgumentParser()
for _flag, _parse, _default in (
    ("--num_examples", int, 150),
    ("--model_name", str, "llama2"),
    ("--few_shot", str2bool, False),
    ("--cot", str2bool, False),
    ("--multiple_run", str2bool, False),
):
    parser.add_argument(_flag, type=_parse, default=_default)
args = parser.parse_args()

# Pick the LLM backend from substrings of the requested model name:
#   * "gpt-3.5" / "gpt-4" -> chat-style OpenAI wrapper,
#   * "llama"             -> locally loaded model (the name is rewritten to a
#                            fixed checkpoint id before loading),
#   * anything else       -> completion-style OpenAI wrapper.
_is_openai_chat_model = any(tag in args.model_name for tag in ("gpt-3.5", "gpt-4"))
if _is_openai_chat_model:
    llm = ChatOpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
elif "llama" in args.model_name:
    # Whatever llama variant was requested, this fixed checkpoint is loaded.
    args.model_name = "meta-llama/Llama-2-13b-chat-hf"
    llm = load_llama(args.model_name)
else:
    llm = OpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)

# Parser used later to turn raw model text (and the gold text) into groupings.
output_parser = GroupingOutputParser()
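# Previously recorded results.
# NOTE(review): columns look like <full accuracy> <average accuracy> <few_shot> <cot> -- confirm.
#0.3667 0.8978 False False
#0.2333 0.8571 True False
#0.2333 0.8692 True True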
# for num_polygons in [110, 120, 130]:
#     #num_polygons = 30  # Generate more polygons
#     num_rules = 25
#     sides_options = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
#     colors_options = ['red', 'blue', "white", "black", "yellow", "purple", "gray", "cyan", "brown", "indigo", "pink", "orange", "green", "violet", "magenta"]
#     materials_options = ['metal', 'plastic', "glass", "sliver", "gold", "copper", "bronze", "diamond", "jade"]
#     meta_full_acc = []
#     meta_average_acc = []
#     print(args)
#     if args.multiple_run == True and "gpt-3.5" in args.model_name:
#         seed_list = [714, 123, 456, 770, 493]
#     elif args.multiple_run == True and "gpt-4" in args.model_name:
#         seed_list = [714]
#     elif args.multiple_run == False:
#         seed_list = [714]
#     for seed in seed_list:
#         random.seed(seed)
#         grouping_creator = GroupingDataCreator(num_polygons, num_rules, sides_options, colors_options, materials_options)
#         prompt_generator = GroupingGenerator()
#         # Generate polygons and groups
#         few_shot_polygons = []
#         few_shot_groups = []
#         few_shot_selected_rules = []
#         for _ in tqdm(range(5)):
#             polygons, groups, attributes = grouping_creator.create()
#             few_shot_polygons.append(polygons)
#             few_shot_groups.append(groups)
#             few_shot_selected_rules.append(attributes)
#         acc_list = []
#         num_examples = args.num_examples
#         all_prompt = prompt_generator.instruction_following(few_shot_polygons, few_shot_groups, few_shot_selected_rules,
#                                                             chain_of_thought=args.cot,
#                                                             few_shot=args.few_shot)
#         full_correct_count = 0
# ---------------------------------------------------------------------------
# Grouping evaluation.
#
# For every seed: sample few-shot demonstrations, build the prompt and chain
# once, then query the model on `args.num_examples` freshly generated tasks
# and score each predicted grouping against the gold grouping. Per-seed
# accuracies are collected and summarised once after all seeds have run.
# ---------------------------------------------------------------------------
num_polygons = 30  # Generate more polygons
num_rules = 15
sides_options = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# "white" was previously misspelled as "while", which put a non-colour word
# into the colour vocabulary shown to the model.
colors_options = ['red', 'blue', "white", "black", "yellow", "purple", "gray", "cyan", "brown", "indigo"]
# NOTE(review): "sliver" is probably meant to be "silver"; left unchanged
# because it alters the generated data -- confirm before fixing.
materials_options = ['metal', 'plastic', "glass", "sliver", "gold", "copper", "bronze", "diamond", "jade"]
meta_full_acc = []      # one "fully correct" accuracy per seed
meta_average_acc = []   # one mean per-example grouping accuracy per seed
print(args)
seed_list = [714, 123] if args.multiple_run else [714]
# The model name is fixed for the whole run, so decide the calling convention once.
uses_openai = "gpt" in args.model_name or "davinci" in args.model_name

for seed in seed_list:
    random.seed(seed)
    grouping_creator = GroupingDataCreator(num_polygons, num_rules, sides_options, colors_options, materials_options)
    prompt_generator = GroupingGenerator()

    # Generate the polygons/groups used as few-shot demonstrations.
    few_shot_polygons = []
    few_shot_groups = []
    few_shot_selected_rules = []
    for _ in tqdm(range(2)):
        polygons, groups, attributes = grouping_creator.create()
        few_shot_polygons.append(polygons)
        few_shot_groups.append(groups)
        few_shot_selected_rules.append(attributes)

    acc_list = []
    num_examples = args.num_examples
    all_prompt = prompt_generator.instruction_following(few_shot_polygons, few_shot_groups, few_shot_selected_rules,
                                                        chain_of_thought=args.cot,
                                                        few_shot=args.few_shot)
    full_correct_count = 0

    # The prompt depends only on the few-shot examples, so the chain is built
    # once per seed instead of once per evaluated example.
    chat_prompt = ChatPromptTemplate.from_messages(all_prompt)
    if uses_openai:
        chain = LLMChain(llm=llm, prompt=chat_prompt)
    else:
        # Flatten the chat messages into a single text template:
        # messages[0] is used as the system text, messages[-1] as the final
        # query, and everything in between as few-shot example turns.
        system = ""
        sides_text = str(sides_options).rstrip("]").lstrip("[")
        materials_text = str(materials_options).rstrip("]").lstrip("[")
        colors_text = str(colors_options).rstrip("]").lstrip("[")
        system += chat_prompt.messages[0].format(Colors=colors_text, Materials=materials_text, SidesNumber=sides_text).content
        instruction = ""
        example_idx = 0
        if len(chat_prompt.messages) > 2:
            for message in chat_prompt.messages[1:-1]:
                if message.additional_kwargs["name"] == "example_user":
                    # The first user turn gets no "<s>" + B_INST prefix here
                    # (presumably get_prompt supplies the opening one -- confirm).
                    if example_idx == 0:
                        instruction += message.format().content + E_INST
                    else:
                        instruction += "<s>" + B_INST + message.format().content + E_INST
                    example_idx += 1
                elif message.additional_kwargs["name"] == "example_assistant":
                    # NOTE(review): this appends a literal backslash after
                    # "</s>"; kept as-is, but confirm it is intentional.
                    instruction += message.format().content + "</s>\\"
        system += "Additionally, please just output the answer with provided format and don't add any other explanation."
        # Keep {Polygons}/{Rules} as placeholders so PromptTemplate fills them per example.
        instruction += chat_prompt.messages[-1].format(Polygons="{Polygons}", Rules="{Rules}",
                                                       ).content
        template = get_prompt(instruction, system)
        prompt = PromptTemplate(template=template, input_variables=["Polygons", "Rules"])
        chain = LLMChain(prompt=prompt, llm=llm)

    # Evaluate exactly `num_examples` tasks for this seed. This used to be two
    # nested loops, which ran num_examples**2 queries and let the counters
    # grow past their denominator.
    for _ in tqdm(range(num_examples)):  # Adjust the range for your needs
        polygons, groups, selected_rules = grouping_creator.create()
        polygons_text, sides, materials, colors, results_text, rules_text = prompt_generator.generate_grouping_results(
            polygons, sides_options, colors_options, materials_options, groups)
        if uses_openai:
            output = chain.run(Polygons=polygons_text, SidesNumber=sides, Colors=colors, Materials=materials, Rules=rules_text)
        else:
            output = chain.run(Polygons=polygons_text, Rules=rules_text)
        print(output)

        # Parse prediction and gold with the same parser so they are comparable.
        pred_results = output_parser.parse_instruction(output)
        gold_results_text = prompt_generator.generate_example_text(groups)
        gold_results = output_parser.parse_instruction(gold_results_text)

        # A predicted group counts as correct if it equals any gold group.
        gold_groups = list(gold_results.values())
        local_correct_count = 0
        for pred in pred_results:
            if pred_results[pred] in gold_groups:
                local_correct_count += 1
        if local_correct_count == len(gold_results):
            full_correct_count += 1
        acc_list.append([local_correct_count / len(gold_results), output, gold_results])

    # Per-seed metrics: reported and recorded exactly once per seed so the
    # meta lists stay aligned with seed_list.
    seed_full_acc = full_correct_count / num_examples
    seed_average_acc = sum(x[0] for x in acc_list) / len(acc_list)
    print(f"Full Correct Accuracy: {seed_full_acc}")
    print(f"Average Grouping Accuracy: {seed_average_acc}")
    print("---------------")
    meta_full_acc.append(seed_full_acc)
    meta_average_acc.append(seed_average_acc)

# Cross-seed summary, printed once after every seed has finished.
for idx, avg_acc, full_acc in zip(seed_list, meta_average_acc, meta_full_acc):
    print(f"Average Grouping Accuracy: {avg_acc}")
    print(f"Full Correct Accuracy: {full_acc}")
    print(f"----Seed {idx}-------")
print(f"Meta Full Accuracy: {sum(meta_full_acc) / len(meta_full_acc)}")
print(f"Meta Average: {sum(meta_average_acc) / len(meta_average_acc)}")
print(f"STD of Full Accuracy: {np.std(meta_full_acc)}")
print(f"STD of Average: {np.std(meta_average_acc)}")







