From a446005b33ccd52cdd235da619921f4274f0b479 Mon Sep 17 00:00:00 2001 From: Josep Maria Salvia Hornos Date: Mon, 13 Nov 2023 01:15:46 +0100 Subject: [PATCH] Efficient Tensor conversion from list of numpy arrays (#1071) --- pomegranate/_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pomegranate/_utils.py b/pomegranate/_utils.py index 197866aa..c68a6d36 100644 --- a/pomegranate/_utils.py +++ b/pomegranate/_utils.py @@ -56,7 +56,10 @@ def _cast_as_tensor(value, dtype=None): return value else: return value.type(dtype) - + + if isinstance(value, list) and all(isinstance(v, numpy.ndarray) for v in value): + value = numpy.array(value) + if isinstance(value, (float, int, list, tuple, numpy.ndarray)): if dtype is None: return torch.tensor(value)