import tqdm
from pathlib import Path
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

# Absolute directory holding this script (symlinks in the file path resolved);
# the split input/output files are read from and written next to it.
current_dir = Path(__file__).resolve().parents[0]


def convert_label(label_name: str) -> int:
    """Return the integer class id for a SNIPS intent name.

    Ids are assigned in the fixed order listed below (0 through 6).

    Raises:
        KeyError: if ``label_name`` is not one of the seven known intents.
    """
    intent_names = (
        "PlayMusic",
        "GetWeather",
        "BookRestaurant",
        "AddToPlaylist",
        "RateBook",
        "SearchCreativeWork",
        "SearchScreeningEvent",
    )
    name_to_id = {name: idx for idx, name in enumerate(intent_names)}
    return name_to_id[label_name]


def reader(file_path: Path):
    """Yield ``(intent_id, sentence)`` pairs from a SNIPS-format text file.

    Data is downloaded from
    https://github.com/LeePleased/StackPropagation-SLU/blob/master/data/snips.
    We only use the train & test splits.

    Expected layout: each utterance is written one token per line as
    ``<word> <slot-tag>``, followed by a line holding a single field, the
    intent name. Lines with any other number of whitespace-separated fields
    (e.g. blank separator lines) are ignored. Tokens that are not followed by
    an intent line (e.g. at end of file) are dropped.

    Args:
        file_path: Path to the UTF-8 encoded split file.

    Yields:
        ``(intent, text)`` where ``intent`` is the class id from
        ``convert_label`` rendered as a string, and ``text`` is the
        utterance's words joined by single spaces.

    Raises:
        KeyError: (from ``convert_label``) for an unknown intent name.
    """
    words = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    # Iterate the file lazily instead of materialising it with readlines().
    with open(file_path, "r", encoding="utf-8") as fr:
        for line in fr:
            items = line.split()  # split() already discards surrounding whitespace
            if len(items) == 1:
                # A single field marks the end of an utterance: it is the intent.
                yield (str(convert_label(items[0])), " ".join(words))
                words = []
            elif len(items) == 2:
                # "<word> <slot-tag>"; slot tags are not needed downstream.
                words.append(items[0])


if __name__ == "__main__":
    # Preprocess the SNIPS dataset into sentence embeddings & labels.
    # For each split this writes, next to this script:
    #   {split}.csv    - header row, then one row per utterance: the embedding
    #                    values followed by the intent id
    #   {split}_xy.txt - one "<intent id>\t<utterance>" line per utterance
    # It takes < 20 minutes with CPU to download the model & get embeddings.

    sentence_model = SentenceTransformer("paraphrase-mpnet-base-v2")

    # Take the CSV width from the model instead of hard-coding 768, so the
    # header stays consistent with the rows if the model is swapped.
    embedding_dim = sentence_model.get_sentence_embedding_dimension()
    header = ",".join(map(str, list(range(embedding_dim)) + ["label"])) + "\n"

    for dsplit in ["train", "test"]:
        raw_file = current_dir / f"{dsplit}.txt"
        svector_savefile = current_dir / f"{dsplit}.csv"
        xy_savefile = current_dir / f"{dsplit}_xy.txt"
        # `with` guarantees both outputs are flushed and closed even if
        # reading or encoding fails part-way through a split.
        with open(svector_savefile, "w", encoding="utf-8") as svector_file:
            with open(xy_savefile, "w", encoding="utf-8") as xy_file:
                svector_file.write(header)
                for intent, text in tqdm.tqdm(reader(raw_file), desc=dsplit):
                    sentence_embedding = sentence_model.encode(text)
                    semb_str = ",".join(str(w) for w in sentence_embedding.tolist())
                    svector_file.write(f"{semb_str},{intent}\n")
                    xy_file.write(f"{intent}\t{text}\n")
