from langchain.chat_models import ChatOpenAI
from api_key import *
from output_parser import CharacterMappingParser
from langchain.prompts.chat import SystemMessagePromptTemplate, ChatPromptTemplate
from langchain import LLMChain
from langchain.llms import OpenAI
from utlis import load_llama, get_prompt, B_INST, E_INST
from data_creator import CharacterMappingDataCreator
from prompt_template_class import CharacterMappingGenerator
from tqdm import tqdm
import argparse
from utlis import str2bool
import random
import numpy as np
from langchain.prompts import PromptTemplate

# Command-line options: (flag, converter, default), registered in this order.
parser = argparse.ArgumentParser()
_CLI_OPTIONS = (
    ("--num_examples", int, 150),
    ("--model_name", str, "meta-llama/Llama-2-70b-chat-hf"),
    ("--few_shot", str2bool, True),
    ("--cot", str2bool, False),
    ("--multiple_run", str2bool, False),
)
for _flag, _converter, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_converter, default=_default)
args = parser.parse_args()
# Choose the LLM backend based on substrings of the requested model name.
if any(tag in args.model_name for tag in ("gpt-3.5", "gpt-4")):
    # OpenAI chat-completion models use LangChain's chat wrapper.
    llm = ChatOpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
elif "llama" not in args.model_name:
    # Any other non-Llama name is treated as an OpenAI completion model.
    llm = OpenAI(model=args.model_name, openai_api_key=openai.api_key, temperature=0)
else:
    # NOTE(review): every Llama request is overridden to the 13b chat model,
    # regardless of the --model_name value (whose default is the 70b model).
    # Confirm this override is intentional.
    args.model_name = "meta-llama/Llama-2-13b-chat-hf"
    llm = load_llama(args.model_name)
output_parser = CharacterMappingParser()
# One entry per seed, appended by the evaluation loop.
meta_full_acc, meta_average_acc = [], []
print(args)
# Three fixed seeds when repeated runs are requested, otherwise one.
if args.multiple_run == True:
    seed_list = [714, 123, 456]
else:
    seed_list = [714]
def _char_accuracy(predicted, gold):
    """Return the per-character accuracy of ``predicted`` against ``gold``.

    Characters are compared position by position (``zip`` stops at the
    shorter string). A position counts as correct only when both characters
    are equal AND alphabetic. The match count is divided by ``len(gold)``.
    Returns 0.0 for an empty ``gold`` instead of raising ZeroDivisionError.
    """
    if not gold:
        return 0.0
    matches = sum(
        1 for p_char, g_char in zip(predicted, gold) if p_char == g_char and p_char.isalpha()
    )
    # NOTE(review): non-alphabetic positions can never count as correct but
    # still appear in the denominator. A perfect prediction therefore scores
    # < 1.0 whenever gold contains spaces/punctuation. Confirm the metric.
    return matches / len(gold)


def _build_llama_template(chat_prompt, rules_text):
    """Flatten a LangChain chat prompt into one Llama-2 style prompt string.

    The first message becomes the system prompt, followed by a fixed
    output-format instruction. When the prompt has more than four messages,
    the intermediate messages tagged "example_user" / "example_assistant"
    (via ``additional_kwargs["name"]``) are rendered as few-shot turns. The
    last message is formatted with ``rules_text``, and ``{Original}`` is left
    as a literal placeholder for ``PromptTemplate``.

    Returns whatever ``get_prompt(instruction, system)`` produces.
    """
    system = chat_prompt.messages[0].format().content
    # NOTE(review): no separator is inserted before this sentence; confirm the
    # system message already ends with whitespace.
    system += "Additionally, please just output the answer with provided format and don't add any other explanation."
    parts = []
    if len(chat_prompt.messages) > 4:
        example_idx = 0
        for message in chat_prompt.messages[1:-1]:
            role = message.additional_kwargs["name"]
            if role == "example_user":
                # Only later user turns get an explicit "<s>"+B_INST opener.
                # Presumably get_prompt opens the first [INST] itself (TODO confirm).
                if example_idx == 0:
                    parts.append(message.format().content + E_INST)
                else:
                    parts.append("<s>" + B_INST + message.format().content + E_INST)
                example_idx += 1
            elif role == "example_assistant":
                # NOTE(review): this emits a literal backslash after "</s>",
                # which looks like a typo. Kept as-is to avoid changing prompts.
                parts.append(message.format().content + "</s>\\")
    parts.append(chat_prompt.messages[-1].format(Rules=rules_text, Original="{Original}").content)
    return get_prompt("".join(parts), system)


# Models whose prompts are run directly through LangChain's chat template.
# Every other model gets the flattened Llama-2 prompt instead.
# NOTE(review): this test differs from the backend selection above, which keys
# on "gpt-3.5"/"gpt-4"/"llama". For example, a non-davinci OpenAI completion
# model would receive the Llama prompt format. Confirm.
use_chat_template = "gpt" in args.model_name or "davinci" in args.model_name

for seed in seed_list:
    random.seed(seed)
    data_creator = CharacterMappingDataCreator(dataset="app_reviews", num_examples=args.num_examples)
    mapping_prompt = CharacterMappingGenerator()
    data_sample, fewshot_sample = data_creator.create()
    mapping_table = data_creator.generatele_mapping_table(shift=3, reveal=26)
    full_correct_count = 0
    acc_list = []
    all_prompt = mapping_prompt.instruction_following(
        [x[0] for x in fewshot_sample],
        [x[1] for x in fewshot_sample],
        mapping_table,
        data_creator,
        chain_of_thought=args.cot,
        few_shot=args.few_shot,
    )
    chat_prompt = ChatPromptTemplate.from_messages(all_prompt)
    if use_chat_template:
        chain = LLMChain(llm=llm, prompt=chat_prompt)
    else:
        template = _build_llama_template(chat_prompt, data_creator.mappingtable2text(mapping_table))
        prompt = PromptTemplate(template=template, input_variables=["Original"])
        chain = LLMChain(prompt=prompt, llm=llm)
    for ori, alt in tqdm(data_sample):
        if use_chat_template:
            # mappingtable2text is called per example, as before. It is not
            # hoisted because SOURCE does not show whether it is pure.
            output = chain.run(Original=ori, Rules=data_creator.mappingtable2text(mapping_table))
        else:
            output = chain.run(Original=ori)
        print(output)
        predicts = output_parser.parse_instruction(output)
        print(predicts)
        if predicts.strip() == alt.strip():
            full_correct_count += 1
        acc_list.append([_char_accuracy(predicts, alt), predicts, alt])
    # Guard against an empty sample so a bad run reports 0.0 instead of crashing.
    num_samples = len(data_sample)
    average_acc = sum(x[0] for x in acc_list) / len(acc_list) if acc_list else 0.0
    full_acc = full_correct_count / num_samples if num_samples else 0.0
    print(f"Average Acc: {average_acc}")
    print(f"Full Accuracy: {full_acc}")
    print("---------------")
    meta_full_acc.append(full_acc)
    meta_average_acc.append(average_acc)
# Per-seed breakdown. Both metric lists were appended in seed order, so
# zip pairs each seed with its own (full, average) accuracy.
# Fix: the unpacking order now matches the zip order. Previously "Full
# Accuracy" printed the average accuracy and vice versa.
for seed, full_acc, average_acc in zip(seed_list, meta_full_acc, meta_average_acc):
    print(f"Full Accuracy: {full_acc}")
    print(f"Average: {average_acc}")
    print(f"------------------{seed}------------------")
# Aggregates across seeds. np.std uses its default ddof=0 (population std).
print(f"Meta Full Accuracy: {sum(meta_full_acc) / len(meta_full_acc)}")
print(f"Meta Average: {sum(meta_average_acc) / len(meta_average_acc)}")
print(f"STD of Full Accuracy: {np.std(meta_full_acc)}")
print(f"STD of Average: {np.std(meta_average_acc)}")





