"""
Code for creating COMPS prompts that follow position-based heuristics.

For instance, a set of prompts could be compatible with a heuristic that
consistently associates the queried property with the first-mentioned concept,
as opposed to the most recent one.

Template:

{optional_instruction}

{optional_prompt} A {pseudoword1} {relation-verbalizer} {POS}. A {pseudoword2} {relation-verbalizer} {NEG}. Therefore, a {pseudoword1} {property}.\n x 6

COMPS-QA: {same template} Which of {pseudoword1} and {pseudoword2} {property-phrase}?

'first': all positive are first
'second': all positive are second
'alt-first': alternate (1-2-1) 
'alt-second': alternate (2-1-2)
"""
import config
import json
import pathlib
import random
import utils

from collections import defaultdict
from dataclasses import dataclass
from semantic_memory import memory


@dataclass()
class Comps:
    """One COMPS record, built from a parsed JSONL line via ``Comps(**entry)``.

    Field names mirror the JSON keys of the input files, so they (including
    ``id``) must not be renamed.
    """

    id: int  # record identifier; read_prompts groups raw entries by this key
    stimuli_type: str
    property_phrase: str  # text appended after "Therefore, a <pseudoword>"
    positive: str  # concept key looked up in the world lexicon (acceptable side)
    negative: str  # concept key looked up in the world lexicon (unacceptable side)


def read_prompts(path):
    """Load a JSONL prompt file and group its records by ``id``.

    Args:
        path: Path to a JSON Lines file. Each non-blank line is one JSON
            object that must contain an ``"id"`` key.

    Returns:
        A ``defaultdict(list)`` mapping each ``id`` to the list of records
        carrying that id, in file order.
    """
    prompts = defaultdict(list)
    # Explicit encoding so results do not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. an extra newline at EOF) rather than
            # crashing in json.loads.
            if not line.strip():
                continue
            entry = json.loads(line)
            prompts[entry["id"]].append(entry)
    return prompts


def read_test(path):
    """Load a JSONL file into a list of parsed objects, in file order.

    Args:
        path: Path to a JSON Lines file; blank lines are ignored.

    Returns:
        A list with one parsed JSON value per non-blank line.
    """
    # Explicit encoding so results do not depend on the platform locale;
    # blank lines are skipped instead of crashing json.loads.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


class CompsPrompt:
    """Build COMPS-style premise/continuation strings over a semantic-memory world.

    Args:
        verbalizer: Relation word placed between a pseudoword and a concept in
            each premise (e.g. ``"is"``).
        world_path: Directory containing ``concept_senses.csv``,
            ``xcslb_compressed.csv``, ``concept_matrix.txt`` and
            ``feature_lexicon.csv``. A trailing slash is allowed.
    """

    def __init__(self, verbalizer, world_path="../world/data/"):
        self.verbalizer = verbalizer

        # Drop trailing slashes so the joined paths do not contain "//"
        # (the default value itself ends with "/").
        base = str(world_path).rstrip("/")

        self.world = memory.Memory(
            concept_path=f"{base}/concept_senses.csv",
            feature_path=f"{base}/xcslb_compressed.csv",
            matrix_path=f"{base}/concept_matrix.txt",
            feature_metadata=f"{base}/feature_lexicon.csv",
        )
        self.world.create()

    def generate_premise(self, concept, subordinate):
        """Return the premise ``"A {subordinate} {verbalizer} {concept}."``.

        ``concept`` is inserted verbatim, so callers pass it with its article.
        """
        return f"A {subordinate} {self.verbalizer} {concept}."

    def generate_prefix(self, comp_obj, pseudowords, correct="first"):
        """Build the two-premise prefix and its two candidate continuations.

        ``pseudowords[0]`` heads the first premise and ``pseudowords[1]`` the
        second. The positive concept goes in the first premise unless
        ``correct`` is "second"/"alt-second", in which case the concepts are
        swapped so the positive premise comes second.

        Args:
            comp_obj: Object with ``positive`` and ``negative`` attributes that
                are keys into ``self.world.lexicon``.
            pseudowords: Exactly two pseudowords.
            correct: Position of the positive premise: "first", "second",
                "alt-first" or "alt-second". The "alt-" prefix is ignored
                here (per-item alternation is decided by the caller).

        Returns:
            ``(acceptable, unacceptable)``. Both strings are the shared
            prefix followed by "Therefore, a <pseudoword>". The acceptable
            string names the pseudoword paired with the positive concept.

        Raises:
            ValueError: If ``correct`` is not one of the recognized values.
        """
        position = correct.replace("alt-", "").strip()
        if position not in ("first", "second"):
            # Previously any unknown value silently behaved like "first",
            # which would mislabel the generated data.
            raise ValueError(
                "correct must be 'first', 'second', 'alt-first' or "
                f"'alt-second', got {correct!r}"
            )

        # Materialize once: the pseudowords are iterated twice below.
        pseudowords = list(pseudowords)
        concepts = [comp_obj.positive, comp_obj.negative]
        pw_order = list(pseudowords)
        if position == "second":
            # Positive concept goes with the second pseudoword, which is then
            # the acceptable continuation.
            concepts.reverse()
            pw_order.reverse()

        first, second = [
            self.generate_premise(self.world.lexicon[c].article, sub)
            for c, sub in zip(concepts, pseudowords)
        ]
        premise_prefix = f"{first} {second}"

        acceptable, unacceptable = [
            f"{premise_prefix} Therefore, a {sub}" for sub in pw_order
        ]

        return acceptable, unacceptable


random.seed(42)
sampled_pseudowords = utils.random_pairs(config._PSEUDOWORDS)
# All pairs but the last feed the in-context prompt examples; the last pair is
# held out for the test stimuli.
prompt_pseudowords = sampled_pseudowords[:-1]
stimuli_pseudowords = sampled_pseudowords[-1]

pathlib.Path("data/stimuli/comps/prompts").mkdir(exist_ok=True, parents=True)

test_data = read_test("data/prompt_data/test.jsonl")

# Loading the semantic-memory world does not depend on the verbalizer, so build
# one prompter and swap its verbalizer instead of reloading the world for every
# prompt type.
comps_prompter = CompsPrompt("is")

# save prompts
for prompt_type in config._PROMPT_TYPES:
    # load prompt
    prompts = read_prompts(f"data/prompt_data/{prompt_type}.jsonl")

    # set the verbalizer for this prompt type
    comps_prompter.verbalizer = config._VERBALIZERS[prompt_type]

    for heuristic, pattern in config._HEURISTICS.items():
        for k, v in prompts.items():
            prompt_strings = []
            for instance, heuristic_pattern, pseudoword_pair in zip(
                v, pattern, prompt_pseudowords
            ):
                comps_obj = Comps(**instance)
                # Only the acceptable continuation is used in the prompt.
                acceptable, _ = comps_prompter.generate_prefix(
                    comps_obj, pseudoword_pair, heuristic_pattern
                )
                prompt_strings.append(
                    acceptable + f" {comps_obj.property_phrase}."
                )

            # save prompts
            prompt_string = "\n".join(prompt_strings)
            with open(
                f"data/stimuli/comps/prompts/{prompt_type}-{k}_{heuristic}.txt",
                "w",
                encoding="utf-8",
            ) as f:
                f.write(prompt_string)

# save test sets
comps_prompter.verbalizer = "is"
for pattern in config._TEST_PATTERNS:
    test_set = []
    for i, item in enumerate(test_data):
        comps_obj = Comps(**item)
        # Reverse the pseudoword order for even-indexed items so half of the
        # items use each order. The order is derived from the index alone
        # (previously the shared variable was reassigned in place, so the
        # order toggled cumulatively and leaked across test patterns).
        if i % 2 == 0:
            pseudoword_pair = list(reversed(stimuli_pseudowords))
        else:
            pseudoword_pair = list(stimuli_pseudowords)
        acceptable, unacceptable = comps_prompter.generate_prefix(
            comps_obj, pseudoword_pair, pattern
        )
        test_set_entry = {
            "id": comps_obj.id,
            "correct": pattern,
            "acceptable": acceptable,
            "unacceptable": unacceptable,
            "property_phrase": comps_obj.property_phrase,
        }
        test_set.append(test_set_entry)

    utils.write_jsonl(test_set, f"data/stimuli/comps/test_{pattern}.jsonl")
