diff --git a/pomegranate/_utils.py b/pomegranate/_utils.py index 197866aa..c68a6d36 100644 --- a/pomegranate/_utils.py +++ b/pomegranate/_utils.py @@ -56,7 +56,10 @@ def _cast_as_tensor(value, dtype=None): return value else: return value.type(dtype) - + + if isinstance(value, list) and all(isinstance(v, numpy.ndarray) for v in value): + value = numpy.array(value) + if isinstance(value, (float, int, list, tuple, numpy.ndarray)): if dtype is None: return torch.tensor(value)