from langchain.chat_models import ChatOpenAI
from api_key import *
from extrapolating_method import ceaser_extrapolating
from prompt_template_class import extrapolating_template
from langchain.output_parsers import PydanticOutputParser
from output_parser import ExtrapolatingParser
import matplotlib.pyplot as plt
from random import shuffle
from tqdm import tqdm, trange

# Parser for the model's replies; its format instructions are embedded in
# every prompt built below.
output_parser = ExtrapolatingParser()
format_instructions = output_parser.get_format_instructions()
# NOTE(review): `parser` is not referenced anywhere else in this file —
# confirm whether it is still needed before removing it.
parser = PydanticOutputParser(pydantic_object=ExtrapolatingParser)

# accuracies[i][k] is the score for the i-th reveal value at shift k + 1.
# Every row gets exactly one entry per shift so it stays aligned with the
# x axis used when plotting.
accuracies = []
# One tuple per mismatched letter:
# (reveal, shift, ori, alt, len(altered), all_count).
wrong_example = []
llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=openai.api_key, temperature=0)
for reveal in tqdm([2, 3, 5, 7, 10], desc="Reveals"):
    reveal_scores = []
    accuracies.append(reveal_scores)
    for shift in trange(1, 25, desc=f"Reveal {reveal}"):
        # presumably builds the example pairs for this shift/reveal setting;
        # see extrapolating_method for the exact format.
        examples = ceaser_extrapolating(shift=shift, reveal=reveal)
        prompt = extrapolating_template(examples, format_instructions)
        output = llm(prompt.to_messages())
        original, altered = output_parser.parse(output.content)
        all_count = 26

        # A malformed answer cannot be scored letter by letter. Record NaN
        # and move on to the next shift instead of abandoning the remaining
        # shifts, which would leave this row shorter than the others.
        if len(altered) != len(original):
            print('Wrong Length')
            reveal_scores.append(float("nan"))
            continue
        if len(altered) != all_count:
            print('Not Full Length')
            reveal_scores.append(float("nan"))
            continue

        correct_count = 0
        for ori, alt in zip(original, altered):
            offset = ord(alt) - ord(ori)
            # `shift - all_count` is the offset seen when the shifted letter
            # wraps around past the end of the alphabet.
            if offset in (shift, shift - all_count):
                correct_count += 1
            else:
                wrong_example.append((reveal, shift, ori, alt, len(altered), all_count))
        reveal_scores.append(correct_count / all_count)

# Plot accuracy against shift, one line per reveal setting.
x = list(range(1, 25))
plt.xlabel("Shift")
plt.ylabel("Accuracy")
plt.title("Accuracy in different Shift in Right Order")

# (label, linestyle, marker, color) for each row of `accuracies`, in order.
line_styles = [
    ("reveal 2", '-', 'o', 'blue'),
    ("reveal 3", '--', 's', 'red'),
    ("reveal 5", ':', '^', 'green'),
    ("reveal 7", '-.', 'd', 'orange'),
    ("reveal 10", '-', 'x', 'purple'),
]
for scores, (label, linestyle, marker, color) in zip(accuracies, line_styles):
    # Rows are filled in shift order, so a row with fewer than 24 entries
    # covers shifts 1..len(scores). Slicing x keeps x and y the same length,
    # which plt.plot requires.
    plt.plot(
        x[:len(scores)],
        scores,
        label=label,
        linestyle=linestyle,
        marker=marker,
        markersize=5,
        color=color,
    )

plt.xticks(range(1, 25, 2))
plt.grid(True)
plt.legend()
plt.tight_layout()

plt.show()



