import torch
import json

from .function import *

from random import shuffle

class zero_shot_calculator:
    """Score embeddings against prototype embeddings built from sampled rows.

    At construction time one example row is drawn from ``data`` for each
    entry of ``types``. Those examples are encoded with ``get_input`` and
    ``model``, and the model's second output is stored as ``self.M``.
    ``get_output`` then compares new embeddings against the rows of
    ``self.M``.
    """

    def __init__(self, config, model, tokenizer, data, types):
        """Build the prototype matrix ``self.M``.

        Args:
            config: object with a ``dev`` attribute; the encoded inputs are
                moved there via ``.to(config.dev)``.
            model: callable invoked as
                ``model(model_input, pos, mask_tensor=mask_tensor)`` that
                returns a pair; the second element is kept as ``self.M``.
            tokenizer: passed through to ``get_input``.
            data: sequence of JSON strings, each holding a ``'y_str'`` list
                of type labels.
            types: type labels to draw one example for, in order.
                Presumably row i of ``self.M`` corresponds to ``types[i]`` --
                TODO confirm against ``get_input``.

        Raises:
            ValueError: if some type has no matching row and ``data`` is
                empty, so no fallback example exists.
        """
        samples = self._pick_samples(data, types)

        # The meaning of get_input's trailing ``True`` flag is not visible
        # from this file -- TODO confirm.
        model_input, mask_tensor, _ans, pos = get_input(samples, tokenizer, types, True)
        model_input = model_input.to(config.dev)
        mask_tensor = mask_tensor.to(config.dev)

        # NOTE(review): this forward pass runs with autograd enabled, so when
        # gradients are on, self.M keeps its computation graph alive. Wrap
        # it in torch.no_grad() if M is only used for inference -- confirm
        # whether gradients through M are intended.
        _model_output, embeddings = model(model_input, pos, mask_tensor=mask_tensor)

        self.M = embeddings

    @staticmethod
    def _pick_samples(data, types):
        """Return one row of ``data`` per entry of ``types``.

        For each type, a random row whose ``'y_str'`` list contains it is
        chosen. A type with no such row falls back to ``data[0]``.

        Raises:
            ValueError: if a fallback is needed but ``data`` is empty.
        """
        # Map each type label to the indices of the rows that mention it.
        rows_by_type = {}
        for row_idx, row in enumerate(data):
            for ty in json.loads(row)['y_str']:
                rows_by_type.setdefault(ty, []).append(row_idx)

        samples = []
        for ty in types:
            candidates = rows_by_type.get(ty)
            if candidates:
                # shuffle + take-first (rather than random.choice) keeps the
                # original RNG consumption, so seeded runs still reproduce.
                shuffle(candidates)
                samples.append(data[candidates[0]])
            elif data:
                # Placeholder example for a type never seen in ``data``.
                samples.append(data[0])
            else:
                raise ValueError(
                    f"no example available for type {ty!r}: data is empty"
                )
        return samples

    def get_output(self, embeddings, scale=1.5):
        """Return scaled softmax similarities of ``embeddings`` to ``self.M``.

        Args:
            embeddings: tensor, assumed 2-D, whose last dimension matches the
                last dimension of ``self.M``.
            scale: multiplier applied to the softmax, so each output row sums
                to ``scale``. The default of 1.5 reproduces the original
                behaviour; its rationale is not documented in this code.

        Returns:
            Tensor of shape ``(embeddings.shape[0], self.M.shape[0])``. It holds
            the softmax (over the rows of ``self.M``) of the dot products,
            multiplied by ``scale``.
        """
        # embeddings @ M^T equals the former (M @ embeddings^T)^T without the
        # two extra transposes and the .contiguous() copy.
        scores = embeddings.matmul(self.M.t())
        # Alternative scoring that was left commented out in this code:
        #   torch.sigmoid(scores / 50)
        return torch.softmax(scores, dim=1) * scale
