Source code for minerva.utils.tensor

from collections.abc import Iterator
from typing import Iterable

import numpy as np
import torch


def _ndarray_to_tensor(arr: np.ndarray) -> torch.Tensor:
    """Wrap ``arr`` in a tensor via :func:`torch.from_numpy`.

    ``torch.from_numpy`` raises ``ValueError`` for arrays with negative strides
    (e.g. ``arr[::-1]`` or ``np.flip(arr)``), so only those arrays are first
    copied into a fresh C-ordered buffer. Every other array is wrapped without
    copying, so the returned tensor shares memory with ``arr``.
    """
    if any(stride < 0 for stride in arr.strides):
        arr = arr.copy()
    return torch.from_numpy(arr)


def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert ``x`` to a :class:`torch.Tensor`, optionally casting its type.

    Args:
        x: A ``torch.Tensor`` (returned as-is when ``dtype`` is None), a
            ``numpy.ndarray`` (wrapped via ``torch.from_numpy``; shares memory
            unless the array has negative strides), or any other non-string
            iterable such as a list, tuple, or generator (converted through
            ``np.array``).
        dtype: Optional target type forwarded to ``torch.Tensor.type``; when
            None the tensor keeps the dtype produced by the conversion.

    Returns:
        The resulting tensor.

    Raises:
        TypeError: If ``x`` is a ``str``/``bytes``, a non-iterable object, or
            converts to a numpy array whose dtype torch cannot represent.
    """
    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, np.ndarray):
        tensor = _ndarray_to_tensor(x)
    elif isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
        if isinstance(x, Iterator):
            # np.array() does not consume iterators/generators; it would wrap
            # them in a 0-d object array that torch cannot convert.
            x = list(x)
        tensor = _ndarray_to_tensor(np.array(x))
    else:
        raise TypeError(f"Unsupported type: {type(x)}")

    if dtype is not None:
        tensor = tensor.type(dtype)
    return tensor