import torch
import torch.nn as nn
# Shared log-softmax module; it normalizes along dimension 0 of its input.
LOGSOFTMAX = nn.LogSoftmax(dim=0)


def imgPrior(predFeature, gtFeatures, lda=1):
    '''
    Score every candidate feature against each predicted feature.

    For each row of predFeature, the Euclidean distance to every row of
    gtFeatures is computed, negated, scaled by 10 and passed through a
    log-softmax over the candidates. The result is multiplied by lda and
    detached from the autograd graph.

    Arguments:
        predFeature : tensor whose first dimension indexes predictions;
                      each row must broadcast against gtFeatures
        gtFeatures  : candidate features, expected to be 2D
                      (numCandidates x featureSize) -- TODO confirm
        lda         : multiplier applied to the log-probabilities

    Return:
        priors: one row of scaled log-probabilities per prediction,
                stacked along dimension 0. An empty predFeature yields an
                empty (0 x numCandidates) tensor.
    '''
    if predFeature.size(0) == 0:
        # torch.stack / torch.cat reject an empty list, so an empty batch
        # is answered explicitly instead of raising.
        return predFeature.new_empty((0, gtFeatures.size(0)))

    # The loop is kept per prediction on purpose: broadcasting the whole
    # batch at once would materialize a (batch x candidates x featureSize)
    # intermediate tensor.
    priors = []
    for feature in predFeature:
        diff = feature - gtFeatures
        score = -torch.sum(diff * diff, 1).sqrt()
        # Functional log-softmax over dim 0 (the candidate axis of `score`);
        # detach() replaces the legacy `.data` attribute.
        logProb = nn.functional.log_softmax(score * 10, dim=0).detach()
        priors.append(lda * logProb)
    return torch.stack(priors, 0)


def prepareBatch(dataset, batch):
    '''
    Build a new batch dict with its tensor-like entries readied for use.

    Entries exposing a `cuda` attribute are moved with `.cuda()` when
    dataset.useGPU is truthy, and passed through `.contiguous()`
    otherwise. Every other entry is carried over unchanged.

    Arguments:
        dataset : object providing the `useGPU` flag
        batch   : dict of entries to prepare

    Return:
        prepared: new dict with the same keys as batch
    '''
    onGpu = dataset.useGPU
    prepared = {}
    for name, value in batch.items():
        # The presence of `.cuda` is what marks an entry as tensor-like.
        if hasattr(value, 'cuda'):
            value = value.cuda() if onGpu else value.contiguous()
        prepared[name] = value
    return prepared


def normProb(logProbs, numFix=True):
    '''
    Convert log-probabilities into probabilities that sum to one along
    the last dimension.

    Arguments:
        logProbs : 1D tensor of (unnormalized) log-probabilities
        numFix   : if True, subtract the maximum before exponentiating
                   so that exp() does not overflow

    Return:
        normProbs: Same size with normalized probability
    '''
    if numFix:
        # Out-of-place subtraction: the caller's tensor must not be
        # modified, and an in-place update would also fail on a leaf
        # tensor that requires grad.
        logProbs = logProbs - torch.max(logProbs)
    probs = torch.exp(logProbs)
    sumProbs = torch.sum(probs, dim=-1, keepdim=True)
    normProbs = probs / sumProbs
    return normProbs
