from tqdm import tqdm
from prompt_template_class import GroupingGenerator
from langchain import LLMChain
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from api_key import *
from langchain.prompts.chat import SystemMessagePromptTemplate, ChatPromptTemplate
from output_parser import GroupingOutputParser
from data_creator import GroupingDataCreator
import argparse
from utlis import str2bool
import random
import numpy as np
from langchain.prompts import PromptTemplate
from utlis import load_llama, get_prompt, B_INST, E_INST

# Command-line options for the grouping error-correction evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument("--num_examples", type=int, default=200,
                    help="Number of evaluation examples generated for each seed.")
parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo-16k-0613",
                    help="Model identifier. Names containing 'gpt-3.5' or 'gpt-4' use "
                         "ChatOpenAI, names containing 'llama' load "
                         "meta-llama/Llama-2-13b-chat-hf, anything else uses the OpenAI "
                         "completion wrapper.")
parser.add_argument("--few_shot", type=str2bool, default=True,
                    help="Forwarded to the prompt generator as few_shot.")
parser.add_argument("--cot", type=str2bool, default=False,
                    help="Forwarded to the prompt generator as chain_of_thought.")
parser.add_argument("--multiple_run", type=str2bool, default=False,
                    help="Evaluate seeds 714 and 123 instead of seed 714 only.")
args = parser.parse_args()

# Pick the LLM backend from substrings of the requested model name.
# NOTE(review): `openai` is not imported by name in this file; presumably it is
# re-exported through `from api_key import *` -- confirm.
if "gpt-3.5" in args.model_name or "gpt-4" in args.model_name:
    llm = ChatOpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
elif "llama" in args.model_name:
    # Any name containing "llama" is replaced by this fixed checkpoint, so the
    # `print(args)` below shows the overridden name.
    args.model_name = "meta-llama/Llama-2-13b-chat-hf"
    llm = load_llama(args.model_name)
else:
    llm = OpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)

output_parser = GroupingOutputParser()
prompt_generator = GroupingGenerator()

# Historical notes kept from the original file. Presumably three recorded
# accuracies followed by the few_shot and cot settings -- TODO confirm.
# 1.0 0.0 0.0  False False
# 0.567 0.0 0.0 True False
# 0.27 0 0 True True

# Sizes and attribute vocabularies handed to the data creator and prompt generator.
num_polygons = 30
num_rules = 15
sides_options = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# NOTE(review): "while" looks like a typo for "white". Left unchanged because the
# value is fed into generated prompts and changing it would alter results.
colors_options = ['red', 'blue', "while", "black", "yellow", "purple", "gray", "cyan", "brown", "indigo"]
# NOTE(review): "sliver" looks like a typo for "silver". Left unchanged for the
# same reason as above.
materials_options = ['metal', 'plastic', "glass", "sliver", "gold", "copper", "bronze", "diamond", "jade"]
print(args)

# One entry per seed is appended to each of these lists by the evaluation loop.
meta_validation_acc = []
meta_full_acc = []
meta_average_acc = []

# Seeds to evaluate; a second seed is added only when --multiple_run is set.
seed_list = [714, 123] if args.multiple_run else [714]

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero.

    The evaluation counters can legitimately stay at zero (for example when no
    sampled example contains perturbed rules), and a plain division would then
    abort the whole run with ZeroDivisionError after all model calls were made.
    """
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def _normalize_rule(rule_text):
    """Convert one rule sentence from the model output into a rule key.

    The text is split on commas into exactly three fields (sides, color,
    material); the surrounding wording is stripped from each field and the
    result is formatted as ``"<sides>-sides+<color>+<material>"``. Text that
    does not split into exactly three fields yields ``"--sides+-+-"``.
    """
    try:
        sides_part, color_part, material_part = rule_text.split(",")
    except ValueError:
        # Not exactly three comma-separated fields: use placeholder values.
        sides_part = color_part = material_part = "-"
    else:
        sides_part = sides_part.replace(" Sides", "").replace("Polygons with ", "").strip()
        color_part = color_part.replace(" Color", "").strip()
        material_part = (
            material_part.replace(" and ", "")
            .replace(" Material", "")
            .replace(" should be grouped together", "")
            .strip()
            .strip(".")
        )
    return f"{sides_part}-sides+{color_part}+{material_part}"


# Run the full evaluation once per seed and record the three accuracies.
for seed in seed_list:
    random.seed(seed)
    grouping_creator = GroupingDataCreator(num_polygons, num_rules, sides_options, colors_options, materials_options)

    # Build two few-shot demonstrations: the first is requested with
    # deterministic="Clean", the second with deterministic="Noise".
    few_shot_polygons = []
    few_shot_groups = []
    few_shot_noised_groups = []
    few_shot_selected_rules = []
    for idx in range(2):
        determinstic = "Clean" if idx <= 0 else "Noise"
        polygons, groups, selected_rules = grouping_creator.create()
        _, noised_rules = prompt_generator.noised_rules(groups, sides_options, colors_options, materials_options,
                                                        deterministic=determinstic)
        few_shot_polygons.append(polygons)
        few_shot_groups.append(groups)
        few_shot_noised_groups.append(noised_rules)
        few_shot_selected_rules.append(selected_rules)

    validation_count = 0
    explanation_correct_count = 0
    total_explanation_correct_count = 0
    explanation_count = 0
    error_count = 0
    num_examples = args.num_examples

    all_prompt = prompt_generator.grouping_error_correction(few_shot_polygons, few_shot_groups, few_shot_noised_groups,
                                                            few_shot_selected_rules, chain_of_thought=args.cot,
                                                            few_shot=args.few_shot)
    chat_prompt = ChatPromptTemplate.from_messages(all_prompt)
    if "gpt" in args.model_name or "davinci" in args.model_name:
        chain = LLMChain(llm=llm, prompt=chat_prompt)
    else:
        # Flatten the chat messages into a single text prompt using the
        # B_INST / E_INST markers and get_prompt().
        sides_text = str(sides_options).rstrip("]").lstrip("[")
        materials_text = str(materials_options).rstrip("]").lstrip("[")
        colors_text = str(colors_options).rstrip("]").lstrip("[")
        system = chat_prompt.messages[0].format(Colors=colors_text, Materials=materials_text,
                                                SidesNumber=sides_text).content
        instruction = ""
        example_idx = 0
        if len(chat_prompt.messages) > 2:
            # Messages between the system message and the final query are the
            # few-shot user/assistant turns.
            for message in chat_prompt.messages[1:-1]:
                if message.additional_kwargs["name"] == "example_user":
                    if example_idx == 0:
                        # The first user turn gets no opening marker here;
                        # presumably get_prompt() supplies it -- TODO confirm.
                        instruction += message.format().content + " " + E_INST + " "
                    else:
                        instruction += "<s>" + B_INST + " " + message.format().content + " " + E_INST + " "
                    example_idx += 1
                elif message.additional_kwargs["name"] == "example_assistant":
                    # NOTE(review): this appends a literal backslash before the
                    # newline; confirm that is intended.
                    instruction += message.format().content + " " + "</s>\\" + "\n"
        system += "Additionally, please just output the answer with provided format and don't add any other explanation."
        # Keep the three input variables as literal placeholders so that
        # PromptTemplate can fill them in for every example.
        instruction += chat_prompt.messages[-1].format(Polygons="{Polygons}", Rules="{Rules}",
                                                       GroupingResult="{GroupingResult}"
                                                       ).content
        template = get_prompt(instruction, system)
        prompt = PromptTemplate(template=template, input_variables=["Polygons", "Rules", "GroupingResult"])
        chain = LLMChain(prompt=prompt, llm=llm)

    for _ in tqdm(range(num_examples)):  # Adjust the range for your needs
        polygons, groups, selected_rules = grouping_creator.create()
        rules, noised_rules = prompt_generator.noised_rules(groups, sides_options, colors_options, materials_options)
        # Both calls are kept in their original order; only the rule text of
        # the first (noised) rendering is used.
        _, _, _, _, _, noised_rules_text = prompt_generator.generate_grouping_results(
            polygons, sides_options, colors_options, materials_options, noised_rules)
        polygons_text, sides, materials, colors, results_text, _ = prompt_generator.generate_grouping_results(
            polygons, sides_options, colors_options, materials_options, groups)

        # The model sees the grouping result of the original groups together
        # with the (possibly perturbed) rules.
        if "gpt" in args.model_name or "davinci" in args.model_name:
            output = chain.run(Polygons=polygons_text, SidesNumber=sides, Colors=colors, Materials=materials,
                               GroupingResult=results_text, Rules=noised_rules_text)
        else:
            output = chain.run(Polygons=polygons_text,
                               GroupingResult=results_text,
                               Rules=noised_rules_text)
        print(output)
        predict, rectified_rules = output_parser.parse_error_correction(output)

        rules_were_perturbed = rules != noised_rules
        if rules_were_perturbed:
            # Presumably each perturbed example contains three altered rules,
            # hence three expected corrections -- TODO confirm against noised_rules().
            explanation_count += 3
            error_count += 1

        if predict == "Yes" and not rules_were_perturbed:
            validation_count += 1
        elif predict == "No" and rules_were_perturbed:
            validation_count += 1
            # Normalise every rule sentence, then keep the second entry of each
            # item (presumably the corrected rule -- TODO confirm parser output).
            normalized_items = [[_normalize_rule(rule_text) for rule_text in item] for item in rectified_rules]
            rectified_rule_list = [item[1] for item in normalized_items]

            local_explanation_correct_count = 0
            for key in rules:
                # A rule that is missing from the perturbed set counts as
                # recovered when the model proposed it. The original code also
                # re-looked-up the key and compared rules[key] with itself,
                # which was always true, so membership is the effective test.
                if key not in noised_rules and key in rectified_rule_list:
                    explanation_correct_count += 1
                    local_explanation_correct_count += 1
            if local_explanation_correct_count == 3 and len(rectified_rule_list) == 3:
                total_explanation_correct_count += 1

    # NaN marks a ratio whose denominator stayed at zero for this seed.
    seed_validation_acc = _safe_ratio(validation_count, num_examples)
    seed_partial_acc = _safe_ratio(explanation_correct_count, explanation_count)
    seed_full_acc = _safe_ratio(total_explanation_correct_count, error_count)
    print(f"The Correct Validation Accuracy: {seed_validation_acc}")
    print(f"The Partial Explanation Accuracy: {seed_partial_acc}")
    print(f"The Full Explanation Accuracy: {seed_full_acc}")
    print("---------------")
    meta_validation_acc.append(seed_validation_acc)
    meta_average_acc.append(seed_partial_acc)
    meta_full_acc.append(seed_full_acc)
# Per-seed recap: print the three accuracies recorded for every seed.
per_seed_rows = zip(seed_list, meta_validation_acc, meta_average_acc, meta_full_acc)
for run_seed, seed_valid, seed_partial, seed_full in per_seed_rows:
    print(f"Validation Accuracy: {seed_valid}")
    print(f"Average Grouping Accuracy: {seed_partial}")
    print(f"Full Correct Accuracy: {seed_full}")
    print(f"----Seed {run_seed}-------")

# Aggregate over seeds: arithmetic mean and numpy standard deviation per metric.
aggregate_table = (
    ("Meta Average Validation Accuracy: ",
     "STD of Meta Average Validation Accuracy: ",
     meta_validation_acc),
    ("Meta Average Explanation Accuracy: ",
     "STD of Meta Average Explanation Accuracy: ",
     meta_average_acc),
    ("Meta Average Full Accuracy: ",
     "STD of Meta Average Full Accuracy: ",
     meta_full_acc),
)
for mean_label, std_label, per_seed_scores in aggregate_table:
    mean_score = sum(per_seed_scores) / len(per_seed_scores)
    print(mean_label, mean_score)
    print(std_label, np.std(per_seed_scores))



