"""Code for sampling properties from a database of ground-truth semantic knowledge."""
import argparse
import csv
import json
import pathlib
import random

from collections import defaultdict
from semantic_memory import memory, list_utils
from tqdm import trange
from ordered_set import OrderedSet


def get_prop_space(
    concept_space,
    applicable_props=None,
    neg_sampling_space=None,
    min_neg_samples=4,
    world=None,
):
    """Return the sorted properties that can be sampled for ``concept_space``.

    A property qualifies when all of the following hold:

    * it is in ``applicable_props``;
    * at least ``min_neg_samples`` of the concepts that
      ``world.feature_space[prop]["negative"]`` lists also belong to
      ``neg_sampling_space``;
    * it appears in ``world.concept_features`` of at least one concept in
      ``concept_space`` (union taken via ``list_utils.union``).

    Args:
        concept_space: Concepts whose features bound the result.
        applicable_props: Candidate properties to filter. Required despite
            the ``None`` default.
        neg_sampling_space: Concepts that negatives may be drawn from.
            Required despite the ``None`` default.
        min_neg_samples: Minimum number of eligible negatives per property.
        world: Knowledge base exposing ``feature_space`` and
            ``concept_features`` (presumably a created ``memory.Memory`` --
            confirm against callers). Required despite the ``None`` default.

    Returns:
        A sorted list of the qualifying properties.

    Raises:
        ValueError: if ``applicable_props``, ``neg_sampling_space`` or
            ``world`` is not supplied.
    """
    if applicable_props is None or neg_sampling_space is None or world is None:
        raise ValueError(
            "get_prop_space requires applicable_props, neg_sampling_space "
            "and world"
        )

    # Build the membership set once: callers pass plain lists of many
    # concepts, and this test runs for every negative of every candidate.
    neg_pool = set(neg_sampling_space)
    prop_space = [
        f
        for f in applicable_props
        if sum(c in neg_pool for c in world.feature_space[f]["negative"])
        >= min_neg_samples
    ]

    # Features carried by at least one concept in ``concept_space``.
    concept_props = list(
        list_utils.union([world.concept_features[c] for c in concept_space])
    )
    return sorted(list_utils.intersect([concept_props, prop_space]))


def sample_prompts(
    cat_space,
    label="ID",
    applicable_props=None,
    neg_samples=None,
    n_prompts=6,
    n_repeats=5,
    world=None,
):
    """Sample groups of (property, positive, negative) prompt records.

    Each of the ``n_repeats`` repeats draws ``n_prompts`` records. A record
    consists of:

    * a property from ``get_prop_space``, which requires at least 4 eligible
      negatives in ``neg_samples``;
    * a positive concept from ``cat_space`` that is listed as positive for
      that property;
    * a negative concept from ``neg_samples`` that is listed as negative for
      it.

    After each draw, the next property is picked from the properties still
    available once that draw's positive and negative concepts are removed
    from ``cat_space``, excluding the property just drawn.

    Args:
        cat_space: Concepts that positives are drawn from.
        label: Value stored under ``"stimuli_type"`` in every record.
        applicable_props: Candidate properties, forwarded to
            ``get_prop_space``.
        neg_samples: Concepts that negatives are drawn from.
        n_prompts: Records drawn per repeat.
        n_repeats: Number of repeats; the repeat index is stored as ``"id"``.
        world: Knowledge base exposing ``feature_space`` and
            ``concept_features``.

    Returns:
        A flat list of ``n_repeats * n_prompts`` dicts with keys ``id``,
        ``stimuli_type``, ``property_phrase``, ``positive`` and ``negative``.

    Raises:
        ValueError: if no property is left to sample from.
    """
    # Hoisted membership sets: these checks run once per candidate concept of
    # every sampled property. Candidate order still comes from
    # ``world.feature_space``, so the sampled data for a seed is unchanged.
    cat_members = set(cat_space)
    neg_members = set(neg_samples)
    prompt_data = []

    for k in trange(n_repeats):
        sampling_space = get_prop_space(
            cat_space, applicable_props, neg_samples, 4, world
        )
        for _ in range(n_prompts):
            if not sampling_space:
                raise ValueError(
                    f"No property left to sample for {label!r} in repeat {k}."
                )
            sampled_property = random.sample(sampling_space, 1)[0]
            feature = world.feature_space[sampled_property]

            positive_space = [
                c for c in feature["positive"] if c in cat_members
            ]
            positive = random.sample(positive_space, 1)[0]

            negative_space = [
                c for c in feature["negative"] if c in neg_members
            ]
            negative = random.sample(negative_space, 1)[0]

            prompt_data.append(
                {
                    "id": k,
                    "stimuli_type": label,
                    "property_phrase": sampled_property,
                    "positive": positive,
                    "negative": negative,
                }
            )

            # NOTE(review): only the immediately preceding draw is excluded.
            # ``concept_space`` is rebuilt from the full ``cat_space`` each
            # time, and positives are still drawn from ``cat_space``, so a
            # property or concept used two or more draws earlier can reappear
            # within the same repeat. If all draws in a repeat are meant to be
            # distinct, accumulate the exclusions. Doing so would change the
            # sampled data for a given seed.
            concept_space = OrderedSet(cat_space) - OrderedSet(
                [positive, negative]
            )
            sampling_space = list(
                OrderedSet(
                    get_prop_space(
                        concept_space, applicable_props, neg_samples, 4, world
                    )
                )
                - OrderedSet([sampled_property])
            )

    return prompt_data


def save_results(lst, name) -> None:
    """Serialize each record of ``lst`` as one JSON line.

    The output goes to ``data/prompt_data/<name>.jsonl``, relative to the
    working directory. Any existing file is overwritten. The directory is
    not created here, so it must already exist.
    """
    target = pathlib.Path(f"data/prompt_data/{name}.jsonl")
    with target.open("w", encoding="utf-8") as sink:
        for record in lst:
            json.dump(record, sink)
            sink.write("\n")


def _applicable_properties(world, concepts):
    """Return the sorted features of ``concepts`` usable as prompt properties.

    Starts from the ``list_utils.union`` of ``world.concept_features`` over
    ``concepts``. A feature is kept only if ``world.feature_space`` lists
    more than one positive concept for it (anywhere in the world, not just
    in ``concepts``) and its name does not contain ``"color"``.
    """
    features = sorted(
        list_utils.union([list(world.concept_features[c]) for c in concepts])
    )
    return [
        feature
        for feature in features
        if len(world.feature_space[feature]["positive"]) > 1
        and "color" not in feature
    ]


def _sample_test_set(world, test_set_space, test_prop_space, n):
    """Draw ``n`` TEST records and return them with the concepts they used.

    Each record picks a property from ``test_prop_space``, then one positive
    and one negative concept for it, both restricted to ``test_set_space``.

    Returns:
        ``(records, used)``. ``records`` is a list of dicts with keys ``id``,
        ``stimuli_type``, ``property_phrase``, ``positive`` and ``negative``.
        ``used`` is an ``OrderedSet`` of every concept drawn as a positive
        or a negative.
    """
    members = set(test_set_space)
    records = []
    used = OrderedSet()
    for i in range(n):
        test_prop = random.sample(test_prop_space, 1)[0]
        feature = world.feature_space[test_prop]

        positive_space = [c for c in feature["positive"] if c in members]
        positive = random.sample(positive_space, 1)[0]

        negative_space = [c for c in feature["negative"] if c in members]
        negative = random.sample(negative_space, 1)[0]

        records.append(
            {
                "id": i,
                "stimuli_type": "TEST",
                "property_phrase": test_prop,
                "positive": positive,
                "negative": negative,
            }
        )
        used.add(positive)
        used.add(negative)
    return records, used


def main(args):
    """Sample the TEST, in-domain and out-of-domain prompt sets and save them.

    Steps:

    1. Seed ``random`` with ``args.seed``.
    2. Build the ``memory.Memory`` world from the files under
       ``../world/data/``, relative to the working directory.
    3. Draw ``args.test_concepts`` animals as the test pool, then sample
       ``args.n`` TEST records from that pool.
    4. Sample in-domain prompts from the animals that no TEST record uses.
    5. Sample out-of-domain prompts from the non-animal descendants of
       ``"entity"``.
    6. Write the ``test``, ``indomain`` and ``outofdomain`` JSONL files to
       ``data/prompt_data/``.
    """
    random.seed(args.seed)

    world = memory.Memory(
        concept_path="../world/data/concept_senses.csv",
        feature_path="../world/data/xcslb_compressed.csv",
        matrix_path="../world/data/concept_matrix.txt",
        feature_metadata="../world/data/feature_lexicon.csv",
    )
    world.create()

    animals = world.taxonomy["animal.n.01"].descendants()
    animal_members = set(animals)

    # Applicable properties: no color features, and at least 2 positive
    # concepts.
    applicable_properties = _applicable_properties(world, animals)

    # The size of the test pool comes from --test_concepts (default 92).
    test_set_space = random.sample(animals, args.test_concepts)
    # TEST negatives are drawn from test_set_space, so the >= 4 negatives
    # requirement is checked against that same pool. Checking against all
    # animals could admit a property with no negative in the test pool, and
    # random.sample would then fail on an empty population.
    test_prop_space = get_prop_space(
        test_set_space, applicable_properties, test_set_space, 4, world
    )
    test_set, actual_test_set_space = _sample_test_set(
        world, test_set_space, test_prop_space, args.n
    )

    # In-domain prompts only use animals that appear in no TEST record.
    cat_space = list(OrderedSet(animals) - actual_test_set_space)
    id_prompt_data = sample_prompts(
        cat_space,
        label="ID PROMPT",
        applicable_props=applicable_properties,
        neg_samples=cat_space,
        world=world,
    )

    # Out-of-domain prompts use every descendant of "entity" that is not an
    # animal.
    non_animals = [
        c
        for c in world.taxonomy["entity"].descendants()
        if c not in animal_members
    ]
    ood_prompt_data = sample_prompts(
        non_animals,
        label="OOD PROMPT",
        applicable_props=_applicable_properties(world, non_animals),
        neg_samples=non_animals,
        world=world,
    )

    pathlib.Path("data/prompt_data/").mkdir(exist_ok=True, parents=True)
    save_results(test_set, "test")
    save_results(id_prompt_data, "indomain")
    save_results(ood_prompt_data, "outofdomain")

    print("Saved all sampled subsets!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Sample TEST, in-domain and out-of-domain property "
        "prompts and write them as JSONL files under data/prompt_data/."
    )
    parser.add_argument(
        "--seed", "-s", default=42, type=int, help="Random seed (default: 42)."
    )
    parser.add_argument(
        "--n",
        "-n",
        default=256,
        type=int,
        help="Number of TEST items to sample (default: 256).",
    )
    parser.add_argument(
        "--test_concepts",
        default=92,
        type=int,
        help="Number of animal concepts in the pool TEST items are drawn "
        "from (default: 92).",
    )

    args = parser.parse_args()
    main(args)
