from encoding_method import colors
from tqdm import tqdm
from prompt_template_class import OrderingGenerator
from langchain.prompts.chat import SystemMessagePromptTemplate, ChatPromptTemplate
from langchain import LLMChain
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from api_key import *
from output_parser import OrderingOutputParser
from data_creator import OrderingDataCreator
import argparse
from utlis import str2bool
import random
import numpy as np
from langchain.prompts import PromptTemplate
from utlis import load_llama, get_prompt, B_INST, E_INST
# Command-line options for the colour-ordering validation experiment.
# Each spec is (flag, converter, default); boolean flags go through str2bool so
# values like "True"/"False" can be passed on the command line.
_CLI_OPTION_SPECS = (
    ("--num_examples", int, 200),
    ("--model_name", str, "text-davinci-003"),
    ("--few_shot", str2bool, False),
    ("--cot", str2bool, False),
    ("--multiple_run", str2bool, False),
)
parser = argparse.ArgumentParser()
for _flag, _converter, _default_value in _CLI_OPTION_SPECS:
    parser.add_argument(_flag, type=_converter, default=_default_value)
args = parser.parse_args()
print(args)
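# Earlier recorded results (columns presumably: validation accuracy, explanation
# accuracy, full-correction accuracy, --few_shot, --cot):
#1.0 0.2593 0.00 False False
#1.0 0.1852 0.00 False False
#1.0 0.1852 0.04 True True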

# Select the LLM backend from --model_name. The OpenAI backends are created with
# temperature=0. `openai` is expected to be provided by `from api_key import *`
# (it is not imported explicitly in this file) — confirm if that module changes.
if "gpt-3.5" in args.model_name or "gpt-4" in args.model_name:
    llm = ChatOpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
elif "llama" in args.model_name:
    # A bare "llama"-style shorthand maps to the default Llama-2 13B chat
    # checkpoint. An explicit Hugging Face repo id or local path (anything
    # containing "/") is used as given instead of being silently replaced.
    if "/" not in args.model_name:
        args.model_name = "meta-llama/Llama-2-13b-chat-hf"
    llm = load_llama(args.model_name)
else:
    llm = OpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)

# NOTE: this rebinds `parser`; the argparse parser is no longer needed once
# parse_args() has run.
parser = OrderingOutputParser()
prompt_generator = OrderingGenerator()

# Per-seed metric histories, summarised (mean/std) after all seeds finish.
meta_full_acc = []
meta_average_acc = []
meta_validation_acc = []
seed_list = [714, 123, 526] if args.multiple_run else [714]
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is zero.

    A zero denominator occurs when no evaluated example was noised (nothing to
    explain or fully correct), or when ``--num_examples`` leaves no examples
    after the few-shot split. NaN marks the metric as undefined instead of
    aborting the whole run with ``ZeroDivisionError``.
    """
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def _build_llama_chain(chat_prompt):
    """Flatten ``chat_prompt`` into a single Llama-2 style prompt and wrap it in an LLMChain.

    The pieces of ``chat_prompt.messages`` are used as follows:

    * ``messages[0]`` is rendered as the system text, with the colour list filled in.
    * Every message strictly between the first and the last is a few-shot turn.
      Its role is taken from ``additional_kwargs["name"]``.
    * The last message is the query. ``{color_preference}`` and
      ``{OrderedLists}`` are left as template variables for ``chain.run``.

    Returns an ``LLMChain`` over the module-level ``llm``.
    """
    system = chat_prompt.messages[0].format(colors=str(colors[:20])).content
    # The instruction-tuned models tend to add commentary; ask for the bare answer.
    system += "Additionally, please just output the answer with provided format and don't add any other explanation."

    instruction = ""
    example_idx = 0
    for message in chat_prompt.messages[1:-1]:
        role = message.additional_kwargs["name"]
        if role == "example_user":
            content = message.format().content
            if example_idx == 0:
                # No "<s>[INST]" prefix on the first user turn. Presumably
                # get_prompt supplies the opening tokens around the system
                # text — TODO confirm.
                instruction += content + " " + E_INST + " "
            else:
                instruction += "<s>" + B_INST + " " + content + " " + E_INST + " "
            example_idx += 1
        elif role == "example_assistant":
            # NOTE(review): this appends a literal backslash after "</s>". It is
            # kept byte-for-byte so prompts match earlier runs, but it looks
            # unintended.
            instruction += message.format().content + " " + "</s>\\" + "\n"

    # Render the query turn while keeping its runtime inputs as template
    # placeholders. NOTE(review): if few-shot contents contain literal braces,
    # PromptTemplate's f-string parsing may misread them — confirm.
    instruction += chat_prompt.messages[-1].format(
        color_preference="{color_preference}",
        OrderedLists="{OrderedLists}",
    ).content
    template = get_prompt(instruction, system)
    prompt = PromptTemplate(template=template,
                            input_variables=["color_preference", "OrderedLists"])
    return LLMChain(prompt=prompt, llm=llm)


# OpenAI gpt/davinci models consume the ChatPromptTemplate directly. Every other
# model gets a flattened Llama-2 prompt.
# NOTE(review): an OpenAI completion model whose name contains neither "gpt"
# nor "davinci" would also take the Llama path here — confirm that is intended.
uses_openai_chain = "gpt" in args.model_name or "davinci" in args.model_name

for seed in seed_list:
    # NOTE(review): only Python's `random` is seeded. If the data creator or
    # noised_ordering draws from numpy's RNG, runs are not reproducible —
    # TODO confirm.
    random.seed(seed)
    ordering_creator = OrderingDataCreator(colors[:20])
    prioritys, examples, unsorted_example, explanations = ordering_creator.create_normal(
        num_examples=args.num_examples,
        repeat=False,
        repeat_num=5,
        missing=True,
        missing_num=False,
    )
    # The first five generated items are reserved as few-shot demonstrations.
    few_shot_prioritys, prioritys = prioritys[:5], prioritys[5:]
    few_shot_examples, examples = examples[:5], examples[5:]
    few_shot_explanations, explanations = explanations[:5], explanations[5:]

    # Demonstrations with index <= len // 2 are forced clean; the rest are forced noisy.
    few_shot_noised_examples = []
    few_shot_noised_index_list = []
    for idx, priority in enumerate(few_shot_prioritys):
        determinstic = "Clean" if idx <= len(few_shot_prioritys) // 2 else "Noise"
        noised_dict, noised_index_list = prompt_generator.noised_ordering(priority, determinstic=determinstic)
        few_shot_noised_examples.append(noised_dict)
        few_shot_noised_index_list.append(noised_index_list)

    validation_correct_count = 0   # correct Yes/No verdicts
    explanation_correct_count = 0  # individual corruptions correctly repaired
    explanation_count = 0          # total corruptions across noised examples
    full_correct_count = 0         # noised examples repaired exactly
    full_explanation_count = 0     # number of noised examples

    all_prompt = prompt_generator.ordering_results_validating(few_shot_examples, few_shot_noised_examples,
                                                              few_shot_noised_index_list, few_shot_explanations,
                                                              few_shot_prioritys,
                                                              chain_of_thought=args.cot, few_shot=args.few_shot)
    chat_prompt = ChatPromptTemplate.from_messages(all_prompt)
    if uses_openai_chain:
        chain = LLMChain(llm=llm, prompt=chat_prompt)
    else:
        chain = _build_llama_chain(chat_prompt)

    for priority, example, explanation in tqdm(zip(prioritys, examples, explanations), total=len(prioritys)):
        noised_example, noised_index_list = prompt_generator.noised_ordering(priority)
        example_text = prompt_generator.generate_example_text(noised_example)
        color_preferece_text = prompt_generator.generate_preference_text(priority)
        if uses_openai_chain:
            output = chain.run(colors=str(colors[:20]), OrderedLists=example_text, color_preference=color_preferece_text)
        else:
            output = chain.run(OrderedLists=example_text, color_preference=color_preferece_text)
        print(output)
        predict, rectifiled_results = parser.parse_validating(output)

        # An empty index list means the example was left clean. Otherwise,
        # presumably, noised_index_list[0] holds positions in the noised list
        # and noised_index_list[1] the matching positions in the original.
        is_noised = noised_index_list != []
        if is_noised:
            explanation_count += len(noised_index_list[0])
            full_explanation_count += 1
        if predict == "Yes" and not is_noised:
            validation_correct_count += 1
        elif predict == "No" and is_noised:
            validation_correct_count += 1
            repaired_here = 0
            for noised_idx, original_index in zip(noised_index_list[0], noised_index_list[1]):
                if (noised_example[noised_idx], example[0][original_index]) in rectifiled_results:
                    explanation_correct_count += 1
                    repaired_here += 1
            # Full credit only when every corruption was repaired and no extra
            # corrections were claimed.
            num_corruptions = len(noised_index_list[0])
            if repaired_here == num_corruptions and len(rectifiled_results) == num_corruptions:
                full_correct_count += 1

    validation_acc = _safe_ratio(validation_correct_count, len(prioritys))
    explanation_acc = _safe_ratio(explanation_correct_count, explanation_count)
    full_acc = _safe_ratio(full_correct_count, full_explanation_count)
    print(f"Validation Accuracy: {validation_acc}")
    print(f"Explanation Accuracy:{explanation_acc}")
    print(f"Full Correction Accuracy: {full_acc}")
    print("---------------")
    meta_validation_acc.append(validation_acc)
    meta_average_acc.append(explanation_acc)
    meta_full_acc.append(full_acc)
# Report each seed's metrics, then the mean and (population) standard
# deviation of every metric across seeds.
_per_seed_rows = zip(seed_list, meta_validation_acc, meta_average_acc, meta_full_acc)
for _run_seed, _val_score, _expl_score, _full_score in _per_seed_rows:
    print(f"Validation Accuracy: {_val_score}")
    print(f"Explanation Accuracy: {_expl_score}")
    print(f"Full Explanation Accuracy: {_full_score}")
    print(f"------Seed {_run_seed}------")

_SUMMARY_METRICS = (
    ("Meta Validation Accuracy", meta_validation_acc),
    ("Meta Explanation Accuracy", meta_average_acc),
    ("Meta Full Correction Accuracy", meta_full_acc),
)
for _label, _history in _SUMMARY_METRICS:
    print(f"{_label}: {np.mean(_history)}")
    print(f"STD of {_label}: {np.std(_history)}")



