# -*- coding: utf-8 -*-

import numpy as np
import torch


def process_batch(batch, plugins=None, device=None):
    """Convert NumPy entries of ``batch`` to torch tensors, updating it in place.

    For each entry:

    * A ``np.ndarray`` value becomes a tensor via ``torch.from_numpy``.
      Per the torch docs, that tensor shares memory with the array.
    * If that entry's key ends with ``'mask'`` but not ``'instances_mask'``,
      the new tensor gets a trailing size-1 axis (``unsqueeze(-1)``) and is
      cast with ``.float()``.
    * If ``device`` is not None, every tensor value is moved with
      ``.to(device)``. This includes tensors that were already tensors
      before the call.
    * Any other value is stored back unchanged.

    Args:
        batch: Mutable mapping of name -> value. It is modified in place.
        plugins: Accepted but not used by this function.
        device: Optional target passed to ``Tensor.to``. ``None`` leaves
            tensors where they are.

    Returns:
        The same ``batch`` object that was passed in.
    """
    def _convert(key, item):
        # Only fresh NumPy arrays get the tensor conversion and mask reshape.
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item)
            is_mask = key.endswith('mask')
            # Instance masks keep their original shape and dtype.
            if is_mask and not key.endswith('instances_mask'):
                item = item.unsqueeze(-1).float()
        if device is None or not torch.is_tensor(item):
            return item
        return item.to(device)

    # Only existing keys are reassigned, so the dict size never changes
    # while iterating its items view.
    for key, item in batch.items():
        batch[key] = _convert(key, item)
    return batch
