import torch
import json
import time

def gettime(timestamp=None):
    """Return a ``MM-DD-HH:MM`` time string in the fixed UTC+8 timezone.

    Args:
        timestamp: Seconds since the epoch. Defaults to the current time.

    The +8h offset is applied to UTC (``time.gmtime``), not to the host's
    local time. That way the result is UTC+8 whatever timezone the machine is
    configured with. Adding 8h to ``localtime`` is only right on UTC hosts.
    """
    if timestamp is None:
        timestamp = time.time()
    return time.strftime("%m-%d-%H:%M", time.gmtime(timestamp + 8 * 3600))

def loads_from_file(path):
    """Read the JSON document stored at ``path`` and return the decoded object.

    The file is read as UTF-8, which is the encoding JSON text is required to
    use. The ``with`` block closes the file even if reading or decoding raises.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def get_input(slice, tokenizer, types, isBlank=False, without_ent=False):
    """Tokenize a batch of JSON-encoded mention records into padded tensors.

    Each element of ``slice`` is a JSON string containing the keys
    ``left_context_token``, ``mention_span``, ``right_context_token`` and
    ``y_str``. The mention is wrapped in ``<ent>`` markers and the text is
    prefixed with ``[CLS]`` before tokenization.

    Args:
        slice: Iterable of JSON strings, one record each.
        tokenizer: Object providing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        types: Ordered sequence of label names. It defines the columns of
            ``ans``. Labels in ``y_str`` that are not in ``types`` are ignored.
        isBlank: If True, the 1st, 3rd, 5th, ... record (1-based) has its
            mention replaced by ``<blank>``.
        without_ent: If True, both ``<ent>`` markers are removed from the
            token sequence and their positions are returned instead.

    Returns:
        ``(model_input, mask_tensor, ans, pos)``:

        - ``model_input``: token ids right-padded with 0.
        - ``mask_tensor``: 1 for real tokens, 0 for padding.
        - ``ans``: multi-hot float labels, one row of ``len(types)`` per
          record.
        - ``pos``: by default, the index of the first ``<ent>`` per record.
          With ``without_ent``, a ``[left, right]`` pair per record, where
          ``left`` is the first mention token and ``right`` is the index just
          past the mention, both in the marker-free sequence.

    Raises:
        ValueError: if ``slice`` yields no records, or if a tokenized record
            lacks the expected ``<ent>`` markers.
    """
    model_input = []
    mask_tensor = []
    pos = []
    ans = []

    # Map each label to its column once instead of scanning ``types`` twice
    # per label. setdefault keeps the first index, matching list.index() when
    # ``types`` contains duplicates.
    type_index = {}
    for col, name in enumerate(types):
        type_index.setdefault(name, col)

    for ind, line in enumerate(slice, start=1):
        obj = json.loads(line)

        if isBlank and ind % 2 == 1:
            text = '[CLS]' + obj['left_context_token'] + ' <ent> <blank>  <ent> ' + obj['right_context_token']
        else:
            text = '[CLS]' + obj['left_context_token'] + ' <ent> ' + obj['mention_span'] + ' <ent> ' + obj['right_context_token']

        tokens = tokenizer.tokenize(text)

        if without_ent:
            # After the first marker is removed, the second marker's index
            # is the position just past the mention.
            lpos = tokens.index('<ent>')
            tokens.remove('<ent>')
            rpos = tokens.index('<ent>')
            tokens.remove('<ent>')
            pos.append([lpos, rpos])
        else:
            pos.append(tokens.index('<ent>'))

        model_input.append(tokenizer.convert_tokens_to_ids(tokens))
        mask_tensor.append([1] * len(tokens))

        labels = [0.] * len(types)
        for ty in obj['y_str']:
            col = type_index.get(ty)
            if col is not None:
                labels[col] = 1.
        ans.append(labels)

    if not model_input:
        raise ValueError('get_input() requires a non-empty slice')

    maxlen = max(len(ids) for ids in model_input)
    model_input = [ids + [0] * (maxlen - len(ids)) for ids in model_input]
    mask_tensor = [m + [0] * (maxlen - len(m)) for m in mask_tensor]

    return (
        torch.tensor(model_input),
        torch.tensor(mask_tensor),
        torch.tensor(ans),
        torch.tensor(pos),
    )

def get_sim(slice):
    """Build a symmetric 0/1 matrix marking records that share a label.

    Each element of ``slice`` is a JSON string with a ``y_str`` list. Entry
    ``[a][b]`` (and ``[b][a]``) is 1 when records ``a`` and ``b`` have at
    least one label in common. The diagonal is therefore 1 for any record
    with a non-empty ``y_str``. All other entries are 0.
    """
    size = len(slice)
    records = [json.loads(raw) for raw in slice]

    sim = torch.zeros(size, size)
    for row, record in enumerate(records):
        # Only the lower triangle (col <= row) is visited; each hit is mirrored.
        for col in range(row + 1):
            other_labels = records[col]['y_str']
            if any(label in other_labels for label in record['y_str']):
                sim[row][col] = 1
                sim[col][row] = 1
    return sim


def loss_function(model_output, ans):
    """Summed binary cross-entropy between ``model_output`` and targets ``ans``.

    ``model_output`` is treated as per-element probabilities. A ``1e-8``
    epsilon inside each log keeps it finite when a prediction is exactly 0
    or 1. The result is the negative sum over all elements.
    """
    eps = 1e-8
    positive_term = torch.log(model_output + eps) * ans
    negative_term = torch.log(1 - model_output + eps) * (1 - ans)
    return -torch.sum(positive_term + negative_term)

def is_sim(x, y):
    """Return True if the JSON records ``x`` and ``y`` share any ``y_str`` label.

    Both arguments are JSON strings. ``y``'s label list is only looked up
    once ``x`` has at least one label to compare.
    """
    first = json.loads(x)
    second = json.loads(y)
    return any(label in second['y_str'] for label in first['y_str'])

def sim_loss_function(model_output, sim, scale=2, self_penalty=1000):
    """Contrastive loss rewarding probability mass on rows marked similar.

    Steps:

    1. L2-normalize each row of ``model_output`` (``normalize`` defaults to
       dim=1) and multiply by ``scale``.
    2. Take pairwise dot products.
    3. Subtract ``self_penalty`` from the diagonal so a row effectively
       ignores itself.
    4. Softmax down each column.
    5. Return ``-sum(log(sum_i p[i, col] * sim[i, col]))``.

    Args:
        model_output: 2-D tensor with one embedding per row (``.t()`` and the
            square similarity matrix require 2-D input).
        sim: Square 0/1 tensor matching the batch size, e.g. from ``get_sim``.
        scale: Multiplier applied after normalization (default 2, as before).
        self_penalty: Amount subtracted from the diagonal before the softmax
            (default 1000, as before).

    Returns:
        Scalar loss tensor.
    """
    emb = torch.nn.functional.normalize(model_output) * scale
    logits = emb.matmul(emb.t())
    # Allocate the identity directly on the target device rather than on the
    # host followed by a copy each call. The dtype stays float32, exactly as
    # before, so dtype promotion of the result is unchanged.
    eye = torch.eye(logits.size(0), device=logits.device)
    logits = logits - eye * self_penalty

    probs = logits.softmax(0)
    pos_mass = (probs * sim).sum(0)
    # The tiny epsilon keeps log() finite for columns with no positive pair.
    return -(pos_mass + 1e-30).log().sum()
