"""
Queries a language model with prompts and stimuli saved in data/stimuli

save format:
model, prompt_domain, prompt_id, heuristic, prompt_length, correct, accuracy
TODO: write code.
"""

import argparse
import config
import csv
import glob
import os
import pathlib
import random
import torch
import utils

import numpy as np

from dataclasses import dataclass
from minicons import scorer
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from transformers import AutoTokenizer, BitsAndBytesConfig


@dataclass
class Prompt:
    """A prompt read from a text file, with metadata parsed from its file name."""

    string: str  # full text content of the prompt file
    domain: str  # file-name part before the "-"
    id: int  # integer that follows the "-" in the file name
    heuristic: str  # file-name part after the "_"


def read_prompt(path):
    """Load a prompt file and parse its metadata from the file name.

    The file name is expected to look like ``<domain>-<id>_<heuristic>.txt``,
    e.g. ``animals-3_first.txt``.

    Args:
        path: Path to the prompt text file.

    Returns:
        A ``Prompt`` holding the file's text, the domain, the integer id and
        the heuristic name.

    Raises:
        ValueError: If the file name does not follow the expected pattern.
        OSError: If the file cannot be opened.
    """
    # PurePath.name honours the platform's path separator, so this also works
    # for the backslash-joined paths that glob returns on Windows.
    name = pathlib.PurePath(path).name
    # Drop only the trailing extension, not every ".txt" occurrence.
    if name.endswith(".txt"):
        name = name[: -len(".txt")]

    try:
        prefix, heuristic = name.split("_")
        domain, raw_id = prefix.split("-")
        prompt_id = int(raw_id)
    except ValueError as err:
        raise ValueError(
            f"Prompt file name {name!r} does not match "
            "'<domain>-<id>_<heuristic>.txt'"
        ) from err

    # Explicit encoding so the prompt text does not depend on the locale.
    with open(path, "r", encoding="utf-8") as f:
        string = f.read()
    return Prompt(string, domain, prompt_id, heuristic)


def _score_batches(lm, batches, prompt_string=None):
    """Return a 0/1 correctness flag for every item yielded by ``batches``.

    An item counts as correct when the property phrase receives a strictly
    higher conditional score after the acceptable prefix than after the
    unacceptable one. When ``prompt_string`` is given, it is prepended to each
    prefix, separated by a newline.
    """
    correctness = []
    for batch in batches:
        acceptable = batch["acceptable"]
        unacceptable = batch["unacceptable"]
        properties = batch["property_phrase"]
        if prompt_string is not None:
            acceptable = [f"{prompt_string}\n{a}" for a in acceptable]
            unacceptable = [f"{prompt_string}\n{u}" for u in unacceptable]

        # logprobs
        acc_scores = lm.conditional_score(acceptable, properties)
        unacc_scores = lm.conditional_score(unacceptable, properties)

        correctness.extend(
            int(acc_score > unacc_score)
            for acc_score, unacc_score in zip(acc_scores, unacc_scores)
        )
    return correctness


def _record(results, predictions, condition, pattern, correctness):
    """Append one accuracy row and the per-item prediction rows.

    ``condition`` is ``[model, prompt_domain, prompt_id, heuristic,
    prompt_length]``; it is shared by the summary row and the item rows.
    """
    model_name, *prompt_info = condition
    results.append([*condition, pattern, np.mean(correctness)])
    predictions.extend(
        [model_name, i, *prompt_info, pattern, c] for i, c in enumerate(correctness)
    )


def _write_csv(path, header, rows):
    """Write ``header`` followed by ``rows`` to ``path`` as CSV."""
    # newline="" is what the csv module requires; without it, platforms that
    # translate "\n" (Windows) end up with a blank line after every row.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)


def main(args):
    """Evaluate a language model on the COMPS test sets and save the results.

    The model is first scored zero-shot on every pattern in
    ``config._TEST_PATTERNS``. Then, for every prompt file, it is scored again
    with the first 1, 2, ..., N lines of the prompt prepended to each stimulus.
    Accuracies are written to ``<results_dir>/<model>.csv`` and per-item
    outcomes to ``<results_dir>/predictions/<model>_predictions.csv``.

    Args:
        args: Parsed command-line namespace providing ``model``,
            ``batch_size``, ``device``, ``n_workers``, ``prompts``,
            ``results_dir``, ``quantize`` and ``randomize``.
    """
    random.seed(42)
    indices = list(range(6))
    random.shuffle(indices)

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    model_name = args.model
    batch_size = args.batch_size
    device = args.device
    n_workers = args.n_workers
    prompt_path = args.prompts

    # Work on a local copy: mutating args.results_dir would nest
    # ".../randomized/randomized" if main() were called twice with one namespace.
    results_dir = args.results_dir
    if args.randomize:
        results_dir = f"{results_dir}/randomized"

    file_stem = model_name.replace("/", "_")
    results_file = f"{results_dir}/{file_stem}.csv"
    predictions_dir = f"{results_dir}/predictions"
    predictions_file = f"{predictions_dir}/{file_stem}_predictions.csv"

    # Create the output directories before the expensive scoring so that a bad
    # output path fails immediately instead of after the whole run.
    pathlib.Path(results_dir).mkdir(exist_ok=True, parents=True)
    pathlib.Path(predictions_dir).mkdir(exist_ok=True, parents=True)

    # model
    if args.quantize:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        lm = scorer.IncrementalLMScorer(
            model_name, device=device, quantization_config=bnb_config
        )
    else:
        lm = scorer.IncrementalLMScorer(model_name, device=device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="right")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    lm.tokenizer = tokenizer

    # Sorted so that the row order of the output does not depend on the
    # filesystem's directory listing order.
    prompt_files = sorted(glob.glob(f"{prompt_path}/*.txt"))
    if args.randomize:
        prompt_files = [p for p in prompt_files if "alt-first" in p]
    prompts = [read_prompt(path) for path in prompt_files]

    # Each test set is read once and its loader reused for every condition;
    # the loaders do not shuffle, so every pass sees the items in the same order.
    loaders = {
        pattern: DataLoader(
            utils.read_jsonl(f"data/stimuli/comps/test_{pattern}.jsonl"),
            batch_size=batch_size,
            num_workers=n_workers,
        )
        for pattern in config._TEST_PATTERNS
    }

    results = []
    predictions = []

    print("Zero-shot results")
    for pattern in config._TEST_PATTERNS:
        correctness = _score_batches(lm, tqdm(loaders[pattern]))
        _record(
            results,
            predictions,
            [model_name, "zero-shot", 0, "none", 0],
            pattern,
            correctness,
        )

    for prompt in prompts:
        print(f"Processing {prompt.id}")
        prompt_strings = prompt.string.split("\n")
        if args.randomize:
            # NOTE: the fixed permutation covers exactly six lines; a prompt
            # with fewer lines raises IndexError and lines past the sixth are
            # dropped.
            prompt_strings = [prompt_strings[i] for i in indices]
        prompt_length = len(prompt_strings)
        for n_prompt in trange(1, prompt_length + 1):
            prompt_string = "\n".join(prompt_strings[:n_prompt])
            for pattern in config._TEST_PATTERNS:
                correctness = _score_batches(lm, loaders[pattern], prompt_string)
                _record(
                    results,
                    predictions,
                    [model_name, prompt.domain, prompt.id, prompt.heuristic, n_prompt],
                    pattern,
                    correctness,
                )

    # write results to csv
    # NOTE: the "correct" column holds the test pattern name; the header text
    # is kept unchanged so that existing consumers of these files keep working.
    _write_csv(
        results_file,
        [
            "model",
            "prompt_domain",
            "prompt_id",
            "heuristic",
            "prompt_length",
            "correct",
            "accuracy",
        ],
        results,
    )

    _write_csv(
        predictions_file,
        [
            "model",
            "idx",
            "prompt_domain",
            "prompt_id",
            "heuristic",
            "prompt_length",
            "correct",
            "predictions",
        ],
        predictions,
    )


if __name__ == "__main__":
    # Command-line options as (long flag, short flag, add_argument keywords),
    # listed in the order they are registered with the parser.
    cli_options = (
        ("--model", "-m", {"type": str, "default": "gpt2"}),
        ("--batch_size", "-b", {"type": int, "default": 64}),
        ("--device", "-d", {"type": str, "default": "cuda:0"}),
        ("--n_workers", "-n", {"type": int, "default": 8}),
        ("--prompts", "-p", {"type": str, "default": "data/stimuli/comps/prompts"}),
        ("--results_dir", "-r", {"type": str, "default": "data/results/comps"}),
        ("--quantize", "-q", {"action": "store_true"}),
        ("--randomize", "-w", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for long_flag, short_flag, keywords in cli_options:
        cli.add_argument(long_flag, short_flag, **keywords)

    main(cli.parse_args())
