from sentence_transformers import SentenceTransformer
import torch

# Shared Sentence-BERT encoder, constructed once when this module is imported
# and used by sentencebert_encode() below.
model = SentenceTransformer(
    model_name_or_path="sentence-transformers/paraphrase-TinyBERT-L6-v2"
)


def sentencebert_encode(input_text, *, encoder=None):
    """Encode text(s) with Sentence-BERT into a ``(batch, 1, dim)`` tensor.

    Args:
        input_text: A single string or a sequence of strings. A bare string
            is treated as a batch of one.
        encoder: Object exposing ``encode(list_of_str)`` and
            ``get_sentence_embedding_dimension()``. Defaults to the
            module-level ``model``.

    Returns:
        A ``torch.Tensor`` of shape ``(len(texts), 1, dim)``, where ``dim``
        is the encoder's embedding size (768 for the default model). The
        singleton middle axis is inserted for callers that expect it.

    Raises:
        AssertionError: If the encoder's output does not have the expected
            ``(len(texts), dim)`` shape.
    """
    if encoder is None:
        encoder = model

    # A bare string would make `encode` return a single 1-D vector and make
    # len() count characters, so normalise to a list of texts first.
    texts = [input_text] if isinstance(input_text, str) else list(input_text)

    # Ask the encoder for its width instead of hard-coding 768, so other
    # checkpoints work too.
    dim = encoder.get_sentence_embedding_dimension()

    if not texts:
        # Nothing to encode: return a well-formed empty batch.
        return torch.empty((0, 1, dim))

    # as_tensor avoids the extra copy torch.tensor() makes of an ndarray.
    embedding = torch.as_tensor(encoder.encode(texts))
    embedding = torch.unsqueeze(embedding, dim=1)

    assert embedding.shape == (len(texts), 1, dim), (
        f"unexpected embedding shape {tuple(embedding.shape)}, "
        f"expected {(len(texts), 1, dim)}"
    )
    return embedding


if __name__ == "__main__":
    # Sanity check: how similar are the embeddings of two prompts that differ
    # only by a trailing period?
    # NOTE(review): this prompt is long; if it exceeds the model's maximum
    # sequence length, the tail (including the '.') is truncated before
    # encoding and the two embeddings will be identical -- confirm against
    # model.max_seq_length.
    actions_prompt = (
        "You are allowed to take the following actions: "
        "go to cabinet 1, go to cabinet 10, go to cabinet 11, go to cabinet 12, "
        "go to cabinet 13, go to cabinet 14, go to cabinet 15, go to cabinet 16, "
        "go to cabinet 17, go to cabinet 18, go to cabinet 19, go to cabinet 2, "
        "go to cabinet 20, go to cabinet 21, go to cabinet 22, go to cabinet 23, "
        "go to cabinet 24, go to cabinet 25, go to cabinet 26, go to cabinet 3, "
        "go to cabinet 4, go to cabinet 5, go to cabinet 6, go to cabinet 7, "
        "go to cabinet 8, go to cabinet 9, go to coffeemachine 1, "
        "go to countertop 1, go to countertop 2, go to countertop 3, "
        "go to drawer 1, go to drawer 10, go to drawer 11, go to drawer 12, "
        "go to drawer 2, go to drawer 3, go to drawer 4, go to drawer 5, "
        "go to drawer 6, go to drawer 7, go to drawer 8, go to drawer 9, "
        "go to fridge 1, go to garbagecan 1, go to microwave 1, go to sinkbasin 1, "
        "go to stoveburner 1, go to stoveburner 2, go to stoveburner 3, "
        "go to stoveburner 4, go to toaster 1"
    )

    # Encode each variant separately (as before) and drop the size-1 axes to
    # get plain 1-D embedding vectors.
    tensor1, tensor2 = [
        torch.squeeze(sentencebert_encode([text]))
        for text in (actions_prompt + ".", actions_prompt)
    ]

    print(torch.norm(tensor1), torch.norm(tensor2))
    print(torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=0))
