#!/usr/bin/env python
# coding: utf-8

import os
import openai
from dotenv import load_dotenv
from tqdm import tqdm
from utils import check_matching
import pandas as pd
from langchain.prompts import PromptTemplate
import argparse

def evaluation(dataset_path, api_key, mode="normal", teacher_mode=False):
    """Evaluate ``gpt-3.5-turbo`` on a multiple-choice CSV and return accuracy.

    Every row of the CSV is rendered into a prompt, sent to the OpenAI chat
    completion endpoint, and the reply is compared with the row's expected
    answer through ``check_matching``.

    Args:
        dataset_path: Path to a CSV file read with ``pandas.read_csv``. The
            columns ``question_stem``, ``A``, ``B``, ``C``, ``D`` and
            ``target`` are read from it.
        api_key: OpenAI API key; assigned to ``openai.api_key``.
        mode: ``"normal"`` for a plain answer prompt, ``"CoT"`` for a prompt
            ending in a "think step by step" cue.
        teacher_mode: When true, each model reply is appended as one line to
            ``teacher_mode_output.txt`` in the current working directory.

    Returns:
        Percentage (0-100) of rows for which ``check_matching`` returned a
        truthy value, or ``0.0`` when the dataset contains no rows.

    Raises:
        KeyError: If ``mode`` is not a supported mode, or a required column
            is missing from the CSV.
    """
    load_dotenv()

    openai.api_key = api_key

    # NOTE: the indentation inside these literals is part of the prompt text
    # sent to the model; it is kept exactly as is so results stay comparable.
    template = {
        "normal": """Choose the answer to the question only from options A, B, C, D.
        Question: {prompt}
        Choices:
        A) {a}.
        B) {b}.
        C) {c}.
        D) {d}.
        Answer:""",
        "CoT": """Choose the answer to the question only from options A, B, C, D.
        Question: {prompt}
        Choices:
        A) {a}.
        B) {b}.
        C) {c}.
        D) {d}.
        Answer: Let’s think step by step """
    }[mode]

    prompt = PromptTemplate(template=template, input_variables=['prompt', 'a', 'b', 'c', 'd'])

    df = pd.read_csv(dataset_path)
    if len(df) == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    dataset = df.apply(lambda example: prompt.format(
        prompt=example.question_stem,
        a=example['A'],
        b=example['B'],
        c=example['C'],
        d=example['D']
    ), axis=1)

    # The expected answers live in the dataframe, not in the formatted prompt
    # strings, so they are iterated alongside the prompts.
    # NOTE(review): assumes the gold label column is named 'target' -- confirm
    # against the dataset schema.
    targets = df['target']

    MODEL = "gpt-3.5-turbo"

    correct_preds = 0
    for text, target in tqdm(zip(dataset, targets), total=len(dataset)):
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": text},
            ],
            temperature=0
        )
        content = response["choices"][0]["message"]["content"]
        if teacher_mode:
            # Re-opened in append mode on every reply so partial output
            # survives a failure midway through the run. UTF-8 is explicit
            # because model replies may contain non-ASCII characters.
            with open("teacher_mode_output.txt", "a", encoding="utf-8") as file:
                file.write(content + '\n')
        if check_matching(content, target):
            correct_preds += 1

    acc = (correct_preds * 100) / len(dataset)
    return acc

if __name__ == "__main__":
    # Command-line entry point: parse the options, run the evaluation once,
    # and print the resulting accuracy percentage.
    cli = argparse.ArgumentParser(description="Run model evaluation.")
    cli.add_argument(
        "--dataset_path",
        type=str,
        default='openbookqa_train.csv',
        help="Path to the dataset CSV file.",
    )
    cli.add_argument(
        "--api_key",
        type=str,
        required=True,
        help="OpenAI API key.",
    )
    cli.add_argument(
        "--mode",
        type=str,
        choices=["normal", "CoT"],
        default="normal",
        help="Evaluation mode: normal or CoT.",
    )
    cli.add_argument(
        "--teacher_mode",
        action="store_true",
        help="Enable teacher mode to save generated text.",
    )
    options = cli.parse_args()

    accuracy = evaluation(
        dataset_path=options.dataset_path,
        api_key=options.api_key,
        mode=options.mode,
        teacher_mode=options.teacher_mode,
    )
    print(f"Accuracy: {accuracy}%")