"""
Query a language model with the prompts and stimuli saved in data/stimuli.

Save format (summary CSV columns):
model, instruction_type, prompt_domain, prompt_id, heuristic, prompt_length, correct, accuracy
TODO: finish the implementation.
"""

import argparse
import config
import csv
import glob
import json
import os
import pathlib
import utils

import numpy as np

from dataclasses import dataclass
from instructions import Instruction
from minicons import scorer
from torch.utils.data import DataLoader
from tqdm import tqdm, trange


@dataclass
class Prompt:
    """A prompt file's text together with the metadata encoded in its filename.

    Filenames follow the pattern ``<domain>-<id>_<heuristic>.txt`` and are
    parsed by :func:`read_prompt`.
    """

    string: str  # full text of the prompt file
    domain: str  # prompt domain, parsed from the filename
    id: int  # numeric prompt id, parsed from the filename
    heuristic: str  # heuristic label, parsed from the filename


def read_prompt(path):
    """Load a prompt file and parse its metadata from the filename.

    The file name (directory and trailing ``.txt`` removed) must have the form
    ``<domain>-<id>_<heuristic>``, e.g. ``animals-3_taxonomic.txt``.

    Args:
        path: path to the prompt file (``str`` or ``os.PathLike``).

    Returns:
        A :class:`Prompt` holding the file contents and the parsed metadata.

    Raises:
        ValueError: if the filename does not match the expected pattern or
            the id part is not an integer.
    """
    file_path = pathlib.Path(path)
    # Use the OS-aware path API instead of splitting on "/" (breaks on Windows),
    # and strip only the trailing extension rather than every ".txt" in the path.
    name = file_path.name
    if name.endswith(".txt"):
        name = name[: -len(".txt")]

    try:
        prompt_id, heuristic = name.split("_")
        domain, number = prompt_id.split("-")
        number = int(number)
    except ValueError as err:
        raise ValueError(
            f"prompt filename {file_path.name!r} does not match "
            "'<domain>-<id>_<heuristic>.txt'"
        ) from err

    with open(file_path, "r", encoding="utf-8") as f:
        string = f.read()
    return Prompt(string, domain, number, heuristic)


# Column headers for the two output CSVs.
# NOTE(review): the "correct" column holds the test pattern name and
# "prediction" holds the 0/1 correctness; kept as drafted for compatibility
# with downstream readers — consider renaming.
_RESULTS_HEADER = [
    "model",
    "instruction_type",
    "prompt_domain",
    "prompt_id",
    "heuristic",
    "prompt_length",
    "correct",
    "accuracy",
]

_PREDICTIONS_HEADER = [
    "model",
    "idx",
    "instruction_type",
    "prompt_domain",
    "prompt_id",
    "heuristic",
    "prompt_length",
    "correct",
    "prediction",
    "acc_logprob",
    "unacc_logprob",
]


def _zero_shot_prompts(instruction, zero_shot_kwargs, qa_prompts):
    """Wrap each QA prompt in the zero-shot instruction template.

    The construct -> zero_shot -> construct sequence mutates ``instruction``
    and is applied per item exactly as before; ``instruction.string`` is
    read after each sequence.
    """
    wrapped = []
    for qap in qa_prompts:
        instruction.construct("{prompt}", qap)
        instruction.zero_shot(**zero_shot_kwargs)
        instruction.construct("", qap)
        wrapped.append(instruction.string)
    return wrapped


def _score_pattern(lm, instruction, zero_shot_kwargs, test_set, batch_size, n_workers):
    """Score one test set.

    Returns:
        ``(correctness, scores)``: per-item 0/1 flags (1 when the acceptable
        completion outscores the unacceptable one) and the matching
        ``(acceptable_score, unacceptable_score)`` pairs.
    """
    batches = DataLoader(
        test_set,
        batch_size=min(batch_size, 256),
        num_workers=n_workers,
    )
    correctness = []
    scores = []
    for batch in tqdm(batches):
        qa_prompts = _zero_shot_prompts(instruction, zero_shot_kwargs, batch["prompt"])
        # logprobs of each completion conditioned on the wrapped prompt
        acc_scores = lm.conditional_score(qa_prompts, batch["acceptable"])
        unacc_scores = lm.conditional_score(qa_prompts, batch["unacceptable"])
        for acc_score, unacc_score in zip(acc_scores, unacc_scores):
            correctness.append(int(acc_score > unacc_score))
            scores.append((acc_score, unacc_score))
    return correctness, scores


def _write_csv(path, header, rows):
    """Write ``header`` followed by ``rows`` to ``path`` as CSV (overwrites)."""
    # newline="" is required by the csv module to avoid blank lines on Windows.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def main(args):
    """Run the zero-shot comps-qa evaluation for one model and save results.

    For every instruction template (``*.json``) in ``data/instructions/comps-qa``
    and every pattern in ``config._TEST_PATTERNS``, each test item's acceptable
    and unacceptable completions are scored with ``lm.conditional_score``; an
    item is correct when the acceptable completion scores higher.

    Writes:
        ``<results_dir>/<model>.csv``: one accuracy row per (instruction, pattern).
        ``<results_dir>/predictions/<model>_predictions.csv``: one row per item.

    Args:
        args: parsed CLI namespace with ``model``, ``batch_size``, ``device``,
            ``n_workers``, ``prompts`` and ``results_dir``.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    model_name = args.model
    model_slug = model_name.replace("/", "_")

    results_file = f"{args.results_dir}/{model_slug}.csv"
    predictions_dir = f"{args.results_dir}/predictions"
    predictions_file = f"{predictions_dir}/{model_slug}_predictions.csv"
    instruction_dir = "data/instructions/comps-qa"

    lm = scorer.IncrementalLMScorer(model_name, device=args.device)

    # TODO: few-shot evaluation — load prompts from args.prompts with
    # read_prompt() and prepend the first 1..N prompt lines to each QA item.

    # Sorted so runs are reproducible (os.listdir order is arbitrary).
    instruction_files = sorted(
        name for name in os.listdir(instruction_dir) if name.endswith(".json")
    )

    results = []
    predictions = []

    print("Zero-shot results")
    for file_name in instruction_files:
        instruction_id = os.path.splitext(file_name)[0]
        filepath = os.path.join(instruction_dir, file_name)
        with open(filepath, "r", encoding="utf-8") as f:
            instruction_obj = json.load(f)["indomain"]
        instruction = Instruction(**instruction_obj["incontext"])

        for pattern in config._TEST_PATTERNS:
            test_set = utils.read_jsonl(f"data/stimuli/comps-qa/test_{pattern}.jsonl")
            correctness, scores = _score_pattern(
                lm,
                instruction,
                instruction_obj["zero-shot"],
                test_set,
                args.batch_size,
                args.n_workers,
            )

            # np.mean([]) warns and returns nan; make the empty case explicit.
            accuracy = float(np.mean(correctness)) if correctness else float("nan")
            # Zero-shot rows: domain "zero-shot", prompt id 0, heuristic
            # "none", prompt length 0.
            results.append(
                [model_name, instruction_id, "zero-shot", 0, "none", 0, pattern, accuracy]
            )
            for i, (c, (a, u)) in enumerate(zip(correctness, scores)):
                predictions.append(
                    [
                        model_name,
                        i,
                        instruction_id,
                        "zero-shot",
                        0,
                        "none",
                        0,
                        pattern,
                        c,
                        a,
                        u,
                    ]
                )

    print(results)

    # parents=True also creates results_dir itself.
    pathlib.Path(predictions_dir).mkdir(exist_ok=True, parents=True)
    _write_csv(results_file, _RESULTS_HEADER, results)
    _write_csv(predictions_file, _PREDICTIONS_HEADER, predictions)


if __name__ == "__main__":
    # Command-line entry point: every option has a default, so a bare run
    # evaluates gpt2 on cuda:0 with the default data and output locations.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", "-m", default="gpt2", type=str)
    cli.add_argument("--batch_size", "-b", default=64, type=int)
    cli.add_argument("--device", "-d", default="cuda:0", type=str)
    cli.add_argument("--n_workers", "-n", default=8, type=int)
    cli.add_argument("--prompts", "-p", default="data/stimuli/comps-qa/prompts", type=str)
    cli.add_argument(
        "--results_dir",
        "-r",
        default="data/results/comps-qa-instructions",
        type=str,
    )
    main(cli.parse_args())
