from langchain.chat_models import ChatOpenAI
from api_key import *
from output_parser import CharacterMappingParser
from langchain import LLMChain
from langchain.llms import OpenAI
from langchain.prompts.chat import SystemMessagePromptTemplate, ChatPromptTemplate
from data_creator import CharacterMappingDataCreator
from prompt_template_class import CharacterMappingGenerator
import argparse
from utlis import str2bool
import random
import numpy as np
from utlis import load_llama, get_prompt, B_INST, E_INST
from langchain.prompts import PromptTemplate

# Command-line interface, declared as (flag, converter, default) rows.
#   --num_examples : forwarded to CharacterMappingDataCreator as num_examples
#   --model_name   : picks the LLM backend and is passed to the output parser
#   --few_shot     : forwarded to validating_template as few_shot_selection
#   --cot          : forwarded to validating_template as chain_of_thought
#   --multiple_run : run with two seeds instead of one
_CLI_OPTIONS = (
    ("--num_examples", int, 150),
    ("--model_name", str, "text-davinci-003"),
    ("--few_shot", str2bool, False),
    ("--cot", str2bool, False),
    ("--multiple_run", str2bool, False),
)

parser = argparse.ArgumentParser()
for _flag, _converter, _fallback in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_converter, default=_fallback)
args = parser.parse_args()
# Echo the resolved configuration so it appears at the top of the run log.
print(args)
# 0.56 0 0 False False 0.5667  0
# 0.6 0.019 True False 0.56 0.29
# True True
# Parser that turns the raw model output into (validation verdict, rectified text).
output_parser = CharacterMappingParser()

# Pick the LLM backend from the requested model name.
# NOTE(review): `openai` is not imported explicitly in this file; it is
# presumably re-exported by `from api_key import *` -- confirm.
if "gpt-3.5" in args.model_name or "gpt-4" in args.model_name:
    # Chat-completion models.
    llm = ChatOpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
elif "llama" in args.model_name:
    # Any name containing "llama" is replaced by this fixed checkpoint.
    args.model_name = "meta-llama/Llama-2-13b-chat-hf"
    llm = load_llama(args.model_name)
else:
    # Everything else is treated as an OpenAI completion model.
    llm = OpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)

# One entry per seed is appended to each of these by the evaluation loop.
meta_average_acc = []
meta_full_acc = []
meta_validation_acc = []

# A single seed by default; a second one is added for --multiple_run so that
# a mean and standard deviation across runs can be reported.
seed_list = [714, 123] if args.multiple_run else [714]
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A run without any noised example (or with an empty sample) would otherwise
    abort with ZeroDivisionError after all model calls were already made.
    """
    return numerator / denominator if denominator else float("nan")


def _trim(text):
    """Strip surrounding whitespace, then surrounding periods, before comparing."""
    return text.strip().strip(".")


def _build_llama_chain(chat_prompt, string_rules, llm):
    """Flatten the chat messages into one Llama-2 style prompt and wrap it in a chain.

    The first message is formatted with the rules and used as the system text,
    the middle messages are the few-shot example turns, and the last message
    is the query whose {Original}/{Altered} placeholders stay unfilled so that
    the resulting PromptTemplate can fill them per example.
    """
    system = chat_prompt.messages[0].format(Rules=string_rules).content
    pieces = []
    user_turns_seen = 0
    # With two or fewer messages this slice is empty, i.e. no few-shot examples.
    for message in chat_prompt.messages[1:-1]:
        role = message.additional_kwargs["name"]
        if role == "example_user":
            if user_turns_seen == 0:
                # The first user turn gets no "<s>" + B_INST prefix here.
                # NOTE(review): presumably get_prompt() supplies it -- confirm.
                pieces.append(message.format().content + " " + E_INST)
            else:
                pieces.append("<s>" + B_INST + " " + message.format().content + " " + E_INST)
            user_turns_seen += 1
        elif role == "example_assistant":
            # NOTE(review): the literal backslash before the newline is kept
            # exactly as it was; confirm it is intended in the prompt.
            pieces.append(" " + message.format().content + " </s>\\" + "\n")
    system += "Additionally, please just output the answer with provided format and don't add any other explanation."
    pieces.append(
        "<s>" + B_INST + " "
        + chat_prompt.messages[-1].format(Original="{Original}", Altered="{Altered}").content
    )
    template = get_prompt("".join(pieces), system)
    llama_prompt = PromptTemplate(template=template, input_variables=["Original", "Altered"])
    return LLMChain(prompt=llama_prompt, llm=llm)


# OpenAI-hosted models consume the chat prompt directly; every other model gets
# the hand-built Llama-style prompt. The same flag drives both chain
# construction and invocation: previously construction tested the misspelled
# "davincii", so davinci models were built with the Llama prompt but invoked
# as OpenAI chains.
uses_chat_prompt = "gpt" in args.model_name or "davinci" in args.model_name

for seed in seed_list:
    random.seed(seed)
    data_creator = CharacterMappingDataCreator(dataset="app_reviews", num_examples=args.num_examples)
    mapping_prompt = CharacterMappingGenerator()
    data_sample, fewshot_sample = data_creator.create()
    examples = data_creator.generatele_mapping_table(shift=3, reveal=26)
    string_rules = data_creator.mappingtable2text(examples)
    # Inject noise into the evaluation samples and into the few-shot samples.
    data_sample, clean_encoded_examples, noised_data_sample, sample_records = data_creator.random_example_noise(
        data_sample, noise_rate=0.5, rules=examples, noise_num=3, controller=False
    )
    fewshot_sample, few_shot_clean_encoded, noised_few_shot_example, few_shot_records = data_creator.random_example_noise(
        fewshot_sample, noise_rate=0.5, rules=examples, noise_num=3, controller=True
    )

    # Per-seed counters.
    rectify_count = 0             # noised characters restored by the model
    error_count = 0               # noised characters over all noised examples
    validation_correct_count = 0  # examples whose verdict matches the record
    full_correct = 0              # noised examples with every error restored
    full_explanation_count = 0    # examples that carry a noise record
    total_count = len(data_sample)

    # Only the combined message list is needed to build the chat prompt.
    _, _, _, all_prompt = mapping_prompt.validating_template(
        string_rules,
        fewshot_sample,
        few_shot_clean_encoded,
        noised_few_shot_example,
        few_shot_records,
        chain_of_thought=args.cot,
        few_shot_selection=args.few_shot,
    )
    chat_prompt = ChatPromptTemplate.from_messages(all_prompt)
    if uses_chat_prompt:
        chain = LLMChain(llm=llm, prompt=chat_prompt)
    else:
        chain = _build_llama_chain(chat_prompt, string_rules, llm)

    for example, clean_example, noised_example, record in zip(
        data_sample, clean_encoded_examples, noised_data_sample, sample_records
    ):
        if uses_chat_prompt:
            # The chat prompt still has the Rules placeholder open.
            output = chain.run(Original=example, Altered=noised_example, Rules=string_rules)
        else:
            # The rules were already baked into the Llama system text.
            output = chain.run(Original=example, Altered=noised_example)
        print(output)
        validation_result, rectified = output_parser.parse_validating(output, args.model_name)

        if record is not None:
            # A non-None record marks an example that was noised.
            full_explanation_count += 1
            clean_text = _trim(clean_example)
            noised_text = _trim(noised_example)
            local_error_count = sum(
                1 for clean_char, noised_char in zip(clean_text, noised_text)
                if clean_char != noised_char
            )
            error_count += local_error_count
            if "Invalid" in validation_result:
                validation_correct_count += 1
                rectified_text = _trim(rectified.replace("Altered: ", ""))
                # Count the noised positions the model restored to the clean character.
                restored = sum(
                    1
                    for clean_char, noised_char, recti_char in zip(clean_text, noised_text, rectified_text)
                    if clean_char != noised_char and clean_char == recti_char
                )
                rectify_count += restored
                if local_error_count == restored:
                    full_correct += 1
        elif "Valid" in validation_result:
            # Untouched example that the model accepted as valid.
            validation_correct_count += 1

    validation_acc = _safe_ratio(validation_correct_count, total_count)
    partial_rectify_acc = _safe_ratio(rectify_count, error_count)
    full_rectify_acc = _safe_ratio(full_correct, full_explanation_count)
    print(f"Validation Accuracy: {validation_acc}")
    print(f"Partial Rectify Acc: {partial_rectify_acc}")
    print(f"Full Rectify Acc: {full_rectify_acc}")
    print("---------------")
    meta_validation_acc.append(validation_acc)
    meta_average_acc.append(partial_rectify_acc)
    meta_full_acc.append(full_rectify_acc)
# Per-seed recap: one four-line group per run, ending with the seed banner.
for run_seed, seed_validation, seed_partial, seed_full in zip(
    seed_list, meta_validation_acc, meta_average_acc, meta_full_acc
):
    print(
        "\n".join(
            (
                f"Validation Accuracy: {seed_validation}",
                f"Explanation Accuracy: {seed_partial}",
                f"Full Explanation Accuracy: {seed_full}",
                f"------Seed {run_seed}------",
            )
        )
    )

# Aggregate over seeds: mean, then population standard deviation, per metric.
# The labels (including the "Rectifu" spelling) are kept exactly as emitted
# before so that existing log parsers keep matching.
for _metric_label, _per_seed_scores in (
    ("Validation Accuracy", meta_validation_acc),
    ("Partial Rectify Acc", meta_average_acc),
    ("Full Rectifu Acc", meta_full_acc),
):
    print(f"Meta {_metric_label}: {sum(_per_seed_scores)/len(_per_seed_scores)}")
    print(f"STD of {_metric_label}: {np.std(_per_seed_scores)}")






