#!/usr/bin/env python
# coding: utf-8

import argparse
import contextlib
import os
import re
import subprocess

import pandas as pd
import torch
from langchain.prompts import PromptTemplate
from peft import PeftModel, LoraConfig
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
)

# Explicit "... answer is X" phrasing, the usual way a chain-of-thought
# completion concludes. Only "answer is" is case-insensitive; the option letter
# must be uppercase so the article "a" is never read as option A.
_ANSWER_PHRASE_RE = re.compile(r"(?i:answer\s+is)\s*:?\s*\(?([ABCD])\b")
# A bare option letter that is not the tail of a word and is followed by ")",
# ".", ":", "," or the end of a line, e.g. " B", "C) ...", "(D)".
_OPTION_LETTER_RE = re.compile(r"(?<![A-Za-z])\(?([ABCD])(?=[).:,]|[ \t]*$)", re.MULTILINE)


def _extract_choice(completion):
    """Return the option letter ("A"-"D") a completion commits to, or None.

    Heuristic: the last explicit "answer is X" wins (chain-of-thought output
    tends to end that way); otherwise the first bare option letter is used
    (a direct answer starts with it). Returns None when no letter is found.
    """
    phrased = _ANSWER_PHRASE_RE.findall(completion)
    if phrased:
        return phrased[-1]
    bare = _OPTION_LETTER_RE.search(completion)
    return bare.group(1) if bare else None


def _load_pipeline(model_name):
    """Load ``model_name`` in 4-bit NF4 and wrap it in a text-generation pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Pad with EOS; presumably the tokenizer defines no pad token of its own -- confirm.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
    )


def evaluation(
    dataset_path,
    token,
    mode="normal",
    teacher_mode=False,
    *,
    answer_column="answerKey",
    max_new_tokens=128,
    model_name="meta-llama/Llama-2-70b-chat-hf",
    teacher_path="teacher_Llama-2-70.txt",
):
    """Score a chat model on a four-option multiple-choice CSV.

    Each row is rendered into a prompt and the model generates a continuation.
    The option letter parsed from that continuation is compared with the row's
    gold letter.

    Args:
        dataset_path: CSV with ``question_stem``, ``A``, ``B``, ``C``, ``D`` and
            the gold-letter column named by ``answer_column``.
        token: Hugging Face access token handed to ``huggingface-cli login``.
        mode: ``"normal"`` (answer directly) or ``"CoT"`` (zero-shot
            "think step by step" prompt).
        teacher_mode: if true, append each prompt + continuation to
            ``teacher_path``.
        answer_column: column holding the gold option letter. ``"answerKey"``
            is the OpenBookQA field name; this assumes the CSV kept it, so
            confirm against your data. Values are compared after
            ``str().strip().upper()``.
        max_new_tokens: generation budget per question, excluding the prompt.
        model_name: Hub id of the model to evaluate.
        teacher_path: file that teacher-mode output is appended to (UTF-8).

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: unknown ``mode``, missing columns, or an empty dataset.
            All are checked before the model is loaded.
        subprocess.CalledProcessError: ``huggingface-cli login`` failed.
    """
    # NOTE(review): continuation lines in these templates keep their source
    # indentation, so the model sees those leading spaces. They are left
    # unchanged to keep prompts identical to earlier runs.
    templates = {
        "normal": """Choose the answer to the question only from options A, B, C, D.
        Question: {prompt}
        Choices:
        A) {a}.
        B) {b}.
        C) {c}.
        D) {d}.
        Answer:""",
        "CoT": """Choose the answer to the question only from options A, B, C, D.
        Question: {prompt}
        Choices:
        A) {a}.
        B) {b}.
        C) {c}.
        D) {d}.
        Answer: Let’s think step by step """
    }
    if mode not in templates:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(templates)}")

    # Validate the data before the (very expensive) model load so bad input fails fast.
    df = pd.read_csv(dataset_path)
    required = ["question_stem", "A", "B", "C", "D", answer_column]
    missing = [column for column in required if column not in df.columns]
    if missing:
        raise ValueError(f"{dataset_path} is missing required column(s): {missing}")
    if df.empty:
        raise ValueError(f"{dataset_path}: dataset has no rows to evaluate")

    prompt = PromptTemplate(template=templates[mode], input_variables=['prompt', 'a', 'b', 'c', 'd'])
    prompts = df.apply(lambda example: prompt.format(
        prompt=example["question_stem"],
        a=example['A'],
        b=example['B'],
        c=example['C'],
        d=example['D'],
    ), axis=1).tolist()

    # Argument list, no shell: the token is never shell-parsed. check=True
    # surfaces a failed login here instead of as a later download error.
    subprocess.run(["huggingface-cli", "login", "--token", token], check=True)
    generation_pipe = _load_pipeline(model_name)

    correct_preds = 0
    # Open the teacher log once for the whole run; nullcontext() yields None when disabled.
    teacher_log = (
        open(teacher_path, "a", encoding="utf-8") if teacher_mode else contextlib.nullcontext()
    )
    with teacher_log as log:
        for text, gold in zip(prompts, df[answer_column]):
            # return_full_text=False yields only the continuation, so the letter
            # search cannot match the options listed in the prompt itself.
            completion = generation_pipe(
                text, max_new_tokens=max_new_tokens, return_full_text=False
            )[0]['generated_text']
            if log is not None:
                # One record per row: the prompt followed by the model's continuation.
                log.write(text + completion + '\n')
                log.flush()  # keep finished rows on disk if a long run is killed
            if _extract_choice(completion) == str(gold).strip().upper():
                correct_preds += 1

    return (correct_preds * 100) / len(prompts)

if __name__ == "__main__":
    # Command-line entry point: collect the evaluation settings, run, report accuracy.
    cli = argparse.ArgumentParser(description="Run model evaluation.")
    cli.add_argument(
        "--dataset_path",
        type=str,
        default="openbookqa_train.csv",
        help="Path to the dataset CSV file.",
    )
    cli.add_argument(
        "--token",
        type=str,
        required=True,
        help="HuggingFace CLI token.",
    )
    cli.add_argument(
        "--mode",
        type=str,
        choices=["normal", "CoT"],
        default="normal",
        help="Evaluation mode: normal or CoT.",
    )
    cli.add_argument(
        "--teacher_mode",
        action="store_true",
        help="Enable teacher mode to save generated text.",
    )
    options = cli.parse_args()

    score = evaluation(
        dataset_path=options.dataset_path,
        token=options.token,
        mode=options.mode,
        teacher_mode=options.teacher_mode,
    )
    print(f"Accuracy: {score}%")
