Source code for minerva.utils.tensor
from typing import Iterable
import numpy as np
import torch
[docs]
def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert a tensor, NumPy array, or iterable into a ``torch.Tensor``.

    Args:
        x: The value to convert. A ``torch.Tensor`` is used as-is, a
            ``numpy.ndarray`` is wrapped with ``torch.from_numpy``, and any
            other iterable is first turned into an array with ``np.array``.
            One-shot iterators (generators, ``map``, ``zip``, ...) are
            materialised into a list before that conversion.
        dtype: Optional target type handed to ``Tensor.type``. When ``None``
            the tensor keeps whatever type it has after conversion.

    Returns:
        The resulting tensor, cast with ``Tensor.type(dtype)`` if ``dtype``
        was given.

    Raises:
        TypeError: If ``x`` is not a tensor, a NumPy array, or an iterable.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import list.
    from collections.abc import Iterator

    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, np.ndarray):
        tensor = torch.from_numpy(x)
    elif isinstance(x, Iterable):
        if isinstance(x, Iterator):
            # np.array() does not consume iterators; it would wrap the
            # iterator object itself in a 0-d object array, which
            # torch.from_numpy then rejects. Materialise the items first.
            x = list(x)
        tensor = torch.from_numpy(np.array(x))
    else:
        raise TypeError(f"Unsupported type: {type(x)}")

    # Single cast point shared by every input kind.
    if dtype is not None:
        tensor = tensor.type(dtype)
    return tensor