From d93a45e15b2085edbd502eb0568cf8f7a5431d08 Mon Sep 17 00:00:00 2001 From: joaosferreira Date: Mon, 24 Aug 2020 13:27:13 +0100 Subject: [PATCH 1/3] Add default_rng, BitGenerator and SeedSequence --- unumpy/random/_multimethods.py | 27 +++++++++++++++++---------- unumpy/tests/test_numpy.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/unumpy/random/_multimethods.py b/unumpy/random/_multimethods.py index 8060355..9bdfeab 100644 --- a/unumpy/random/_multimethods.py +++ b/unumpy/random/_multimethods.py @@ -7,21 +7,35 @@ from .._multimethods import ( ClassOverrideMetaWithConstructor, + ClassOverrideMetaWithGetAttr, + ClassOverrideMetaWithConstructorAndGetAttr, ndarray, _identity_argreplacer, _self_argreplacer, - _dtype_argreplacer, _first2argreplacer, _first3argreplacer, mark_dtype, ) +@create_numpy(_identity_argreplacer) +def default_rng(seed=None): + return () + + class RandomState(metaclass=ClassOverrideMetaWithConstructor): pass -class Generator(metaclass=ClassOverrideMetaWithConstructor): +class Generator(metaclass=ClassOverrideMetaWithConstructorAndGetAttr): + pass + + +class BitGenerator(metaclass=ClassOverrideMetaWithGetAttr): + pass + + +class SeedSequence(metaclass=ClassOverrideMetaWithConstructorAndGetAttr): pass @@ -217,14 +231,7 @@ def logseries(p, size=None): return (p,) -def _multinomial_argreplacer(args, kwargs, dispatchables): - def replacer(n, pvals, size=None): - return (n, dispatchables[0]), dict(size=size) - - return replacer(*args, **kwargs) - - -@create_numpy(_multinomial_argreplacer) +@create_numpy(_identity_argreplacer) @all_of_type(ndarray) def multinomial(n, pvals, size=None): return (pvals,) diff --git a/unumpy/tests/test_numpy.py b/unumpy/tests/test_numpy.py index 5844879..34013e8 100644 --- a/unumpy/tests/test_numpy.py +++ b/unumpy/tests/test_numpy.py @@ -18,7 +18,14 @@ LIST_BACKENDS = [ ( NumpyBackend, - (onp.ndarray, onp.generic, onp.ufunc, onp.random.mtrand.RandomState), + ( + onp.ndarray, + onp.generic, + onp.ufunc, + onp.random.RandomState, + onp.random.Generator, + onp.random.SeedSequence, + ), ), (DaskBackend(), (da.Array, onp.generic, da.ufunc.ufunc, da.random.RandomState)), ( @@ -583,8 +590,11 @@ def test_linalg(backend, method, args, kwargs): @pytest.mark.parametrize( "method, args, kwargs", [ + (np.random.default_rng, (42,), {}), (np.random.RandomState, (42,), {}), # (np.random.Generator, (), {}), + # (np.random.BitGenerator, (), {}), + (np.random.SeedSequence, (42,), {}), (np.random.rand, (1, 2), {}), (np.random.randn, (1, 2), {}), (np.random.randint, ([1, 2],), {}), From 9da7501f0639a60d30247be6c0628c633e63ae3e Mon Sep 17 00:00:00 2001 From: joaosferreira Date: Tue, 25 Aug 2020 19:33:23 +0100 Subject: [PATCH 2/3] Add multimethods for Generator's methods --- unumpy/_multimethods.py | 10 ++ unumpy/cupy_backend.py | 9 +- unumpy/dask_backend.py | 9 +- unumpy/numpy_backend.py | 9 +- unumpy/random/_multimethods.py | 306 ++++++++++++++++++++++++++++++++- unumpy/sparse_backend.py | 9 +- unumpy/tests/test_numpy.py | 69 +++++++- 7 files changed, 397 insertions(+), 24 deletions(-) diff --git a/unumpy/_multimethods.py b/unumpy/_multimethods.py index 3a51a89..61ebe65 100644 --- a/unumpy/_multimethods.py +++ b/unumpy/_multimethods.py @@ -29,6 +29,9 @@ def _dtype_argreplacer(args, kwargs, dispatchables): def replacer(*a, dtype=None, **kw): out_kw = kw.copy() out_kw["dtype"] = dispatchables[0] + if "out" in out_kw: + out_kw["out"] = dispatchables[1] + return a, out_kw return replacer(*args, **kwargs) @@ -41,6 +44,13 @@ def self_method(a, *args, **kwargs): return self_method(*args, **kwargs) +def _skip_self_argreplacer(args, kwargs, dispatchables): + def replacer(self, *args, **kwargs): + return (self,) + dispatchables, kwargs + + return replacer(*args, **kwargs) + + def _ureduce_argreplacer(args, kwargs, dispatchables): def ureduce(self, a, axis=0, dtype=None, out=None, keepdims=False): return ( diff --git a/unumpy/cupy_backend.py b/unumpy/cupy_backend.py index f05ee24..dacf5c1 100644 --- a/unumpy/cupy_backend.py +++ b/unumpy/cupy_backend.py @@ -23,14 +23,15 @@ def overridden_class(self): def _get_from_name_domain(name, domain): module = cp - domain_hierarchy = domain.split(".") + name_hierarchy = name.split(".") + domain_hierarchy = domain.split(".") + name_hierarchy[0:-1] for d in domain_hierarchy[1:]: if hasattr(module, d): module = getattr(module, d) else: return NotImplemented - if hasattr(module, name): - return getattr(module, name) + if hasattr(module, name_hierarchy[-1]): + return getattr(module, name_hierarchy[-1]) else: return NotImplemented @@ -48,7 +49,7 @@ def __ua_function__(method, args, kwargs): if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta): return NotImplemented - cupy_method = _get_from_name_domain(method.__name__, method.domain) + cupy_method = _get_from_name_domain(method.__qualname__, method.domain) if cupy_method is NotImplemented: return NotImplemented diff --git a/unumpy/dask_backend.py b/unumpy/dask_backend.py index 53a0d75..d149056 100644 --- a/unumpy/dask_backend.py +++ b/unumpy/dask_backend.py @@ -31,14 +31,15 @@ def overridden_class(self): def _get_from_name_domain(name, domain): module = da - domain_hierarchy = domain.split(".") + name_hierarchy = name.split(".") + domain_hierarchy = domain.split(".") + name_hierarchy[0:-1] for d in domain_hierarchy[1:]: if hasattr(module, d): module = getattr(module, d) else: return NotImplemented - if hasattr(module, name): - return getattr(module, name) + if hasattr(module, name_hierarchy[-1]): + return getattr(module, name_hierarchy[-1]) else: return NotImplemented @@ -144,7 +145,7 @@ def __ua_function__(self, method, args, kwargs): if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta): return NotImplemented - dask_method = _get_from_name_domain(method.__name__, method.domain) + dask_method = _get_from_name_domain(method.__qualname__, method.domain) if dask_method is NotImplemented: return NotImplemented diff --git a/unumpy/numpy_backend.py b/unumpy/numpy_backend.py index 666e127..e4cfbc5 100644 --- a/unumpy/numpy_backend.py +++ b/unumpy/numpy_backend.py @@ -29,11 +29,12 @@ def overridden_class(self): def _get_from_name_domain(name, domain): module = np - domain_hierarchy = domain.split(".") + name_hierarchy = name.split(".") + domain_hierarchy = domain.split(".") + name_hierarchy[0:-1] for d in domain_hierarchy[1:]: module = getattr(module, d) - if hasattr(module, name): - return getattr(module, name) + if hasattr(module, name_hierarchy[-1]): + return getattr(module, name_hierarchy[-1]) else: return NotImplemented @@ -45,7 +46,7 @@ def __ua_function__(method, args, kwargs): if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta): return NotImplemented - method_numpy = _get_from_name_domain(method.__name__, method.domain) + method_numpy = _get_from_name_domain(method.__qualname__, method.domain) if method_numpy is NotImplemented: return NotImplemented diff --git a/unumpy/random/_multimethods.py b/unumpy/random/_multimethods.py index 9bdfeab..0003fe7 100644 --- a/unumpy/random/_multimethods.py +++ b/unumpy/random/_multimethods.py @@ -11,10 +11,13 @@ ClassOverrideMetaWithConstructorAndGetAttr, ndarray, _identity_argreplacer, + _dtype_argreplacer, _self_argreplacer, + _skip_self_argreplacer, _first2argreplacer, _first3argreplacer, mark_dtype, + mark_non_coercible, ) @@ -27,10 +30,6 @@ class RandomState(metaclass=ClassOverrideMetaWithConstructor): pass -class Generator(metaclass=ClassOverrideMetaWithConstructorAndGetAttr): - pass - - class BitGenerator(metaclass=ClassOverrideMetaWithGetAttr): pass @@ -231,15 +230,15 @@ def logseries(p, size=None): return (p,) -@create_numpy(_identity_argreplacer) +@create_numpy(_self_argreplacer) @all_of_type(ndarray) def multinomial(n, pvals, size=None): - return (pvals,) + return (n,) @create_numpy(_first2argreplacer) @all_of_type(ndarray) -def multivariate_normal(mean, cov, size=None, check_valid=None, tol=None): +def multivariate_normal(mean, cov, size=None, check_valid="warn", tol=1e-8): return (mean, cov) @@ -387,3 +386,296 @@ def get_state(): @create_numpy(_identity_argreplacer) def set_state(state): return () + + +def _integers_argreplacer(args, kwargs, dispatchables): + def replacer(self, low, high=None, size=None, dtype=int, endpoint=False): + return ( + (self, dispatchables[0],), + dict( + high=dispatchables[1], + size=size, + dtype=dispatchables[2], + endpoint=endpoint, + ), + ) + + return replacer(*args, **kwargs) + + +def _Generator_choice_argreplacer(args, kwargs, dispatchables): + def replacer(self, a, size=None, replace=True, p=None, axis=0, shuffle=True): + return ( + (self, dispatchables[0],), + dict( + size=size, + replace=replace, + p=dispatchables[1], + axis=axis, + shuffle=shuffle, + ), + ) + + return replacer(*args, **kwargs) + + +def _Generator_exponential_argreplacer(args, kwargs, dispatchables): + def replacer(self, scale=1.0, size=None): + return (self,), dict(scale=dispatchables[0], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_gamma_argreplacer(args, kwargs, dispatchables): + def replacer(self, shape, scale=1.0, size=None): + return (self, dispatchables[0],), dict(scale=dispatchables[1], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_loc_scale_argreplacer(args, kwargs, dispatchables): + def replacer(self, loc=0.0, scale=1.0, size=None): + return (self,), dict(loc=dispatchables[0], scale=dispatchables[1], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_lognormal_argreplacer(args, kwargs, dispatchables): + def replacer(self, mean=0.0, sigma=1.0, size=None): + return (self,), dict(mean=dispatchables[0], sigma=dispatchables[1], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_multinomial_argreplacer(args, kwargs, dispatchables): + def replacer(self, n, pvals, size=None): + return (self, dispatchables[0], pvals), dict(size=size) + + return replacer(*args, **kwargs) + + +def _Generator_poisson_argreplacer(args, kwargs, dispatchables): + def replacer(self, lam=1.0, size=None): + return (self,), dict(lam=dispatchables[0], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_rayleigh_argreplacer(args, kwargs, dispatchables): + def replacer(self, scale=1.0, size=None): + return (self,), dict(scale=dispatchables[0], size=size) + + return replacer(*args, **kwargs) + + +def _Generator_uniform_argreplacer(args, kwargs, dispatchables): + def replacer(self, low=0.0, high=1.0, size=None): + return (self,), dict(low=dispatchables[0], high=dispatchables[1], size=size) + + return replacer(*args, **kwargs) + + +class Generator(metaclass=ClassOverrideMetaWithConstructorAndGetAttr): + @create_numpy(_integers_argreplacer) + @all_of_type(ndarray) + def integers(self, low, high=None, size=None, dtype=int, endpoint=False): + return (low, high, mark_dtype(dtype)) + + @create_numpy(_dtype_argreplacer) + def random(self, size=None, dtype=float, out=None): + return (mark_dtype(dtype), mark_non_coercible(out)) + + @create_numpy(_Generator_choice_argreplacer) + @all_of_type(ndarray) + def choice(self, a, size=None, replace=True, p=None, axis=0, shuffle=True): + return (a, p) + + @create_numpy(_identity_argreplacer) + def bytes(self, length): + return () + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def shuffle(self, x, axis=0): + return (x,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def permutation(self, x, axis=0): + return (x,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def beta(self, a, b, size=None): + return (a, b) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def binomial(self, n, p, size=None): + return (n, p) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def chisquare(self, df, size=None): + return (df,) + + @create_numpy(_identity_argreplacer) + def dirichlet(self, alpha, size=None): + return () + + @create_numpy(_Generator_exponential_argreplacer) + @all_of_type(ndarray) + def exponential(self, scale=1.0, size=None): + return (scale,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def f(self, dfnum, dfden, size=None): + return (dfnum, dfden) + + @create_numpy(_Generator_gamma_argreplacer) + @all_of_type(ndarray) + def gamma(self, shape, scale=1.0, size=None): + return (shape, scale) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def geometric(self, p, size=None): + return (p,) + + @create_numpy(_Generator_loc_scale_argreplacer) + @all_of_type(ndarray) + def gumbel(self, loc=0.0, scale=1.0, size=None): + return (loc, scale) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def hypergeometric(self, ngood, nbad, nsample, size=None): + return (ngood, nbad, nsample) + + @create_numpy(_Generator_loc_scale_argreplacer) + @all_of_type(ndarray) + def laplace(self, loc=0.0, scale=1.0, size=None): + return (loc, scale) + + @create_numpy(_Generator_loc_scale_argreplacer) + @all_of_type(ndarray) + def logistic(self, loc=0.0, scale=1.0, size=None): + return (loc, scale) + + @create_numpy(_Generator_lognormal_argreplacer) + @all_of_type(ndarray) + def lognormal(self, mean=0.0, sigma=1.0, size=None): + return (mean, sigma) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def logseries(self, p, size=None): + return (p,) + + @create_numpy(_Generator_multinomial_argreplacer) + @all_of_type(ndarray) + def multinomial(self, n, pvals, size=None): + return (n,) + + @create_numpy(_identity_argreplacer) + def multivariate_hypergeometric( + self, colors, nsample, size=None, method="marginals" + ): + return () + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def multivariate_normal(self, mean, cov, size=None, check_valid="warn", tol=1e-8): + return (mean, cov) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def negative_binomial(self, n, p, size=None): + return (n, p) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def noncentral_chisquare(self, df, nonc, size=None): + return (df, nonc) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def noncentral_f(self, dfnum, dfden, nonc, size=None): + return (dfnum, dfden, nonc) + + @create_numpy(_Generator_loc_scale_argreplacer) + @all_of_type(ndarray) + def normal(self, loc=0.0, scale=1.0, size=None): + return (loc, scale) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def pareto(self, a, size=None): + return (a,) + + @create_numpy(_Generator_poisson_argreplacer) + @all_of_type(ndarray) + def poisson(self, lam=1.0, size=None): + return (lam,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def power(self, a, size=None): + return (a,) + + @create_numpy(_Generator_rayleigh_argreplacer) + @all_of_type(ndarray) + def rayleigh(self, scale=1.0, size=None): + return (scale,) + + @create_numpy(_identity_argreplacer) + def standard_cauchy(self, size=None): + return () + + @create_numpy(_dtype_argreplacer) + def standard_exponential(self, size=None, dtype=float, method="zig", out=None): + return (mark_dtype(dtype), mark_non_coercible(out)) + + @create_numpy(_dtype_argreplacer) + def standard_gamma(self, shape, size=None, dtype=float, out=None): + return (mark_dtype(dtype), mark_non_coercible(out)) + + @create_numpy(_dtype_argreplacer) + def standard_normal(self, size=None, dtype=float, out=None): + return (mark_dtype(dtype), mark_non_coercible(out)) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def standard_t(self, df, size=None): + return (df,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def triangular(self, left, mode, right, size=None): + return (left, mode, right) + + @create_numpy(_Generator_uniform_argreplacer) + @all_of_type(ndarray) + def uniform(self, low=0.0, high=1.0, size=None): + return (low, high) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def vonmises(self, mu, kappa, size=None): + return (mu, kappa) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def wald(self, mean, scale, size=None): + return (mean, scale) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def weibull(self, a, size=None): + return (a,) + + @create_numpy(_skip_self_argreplacer) + @all_of_type(ndarray) + def zipf(self, a, size=None): + return (a,) diff --git a/unumpy/sparse_backend.py b/unumpy/sparse_backend.py index f6b6b38..9d4970c 100644 --- a/unumpy/sparse_backend.py +++ b/unumpy/sparse_backend.py @@ -51,14 +51,15 @@ def overridden_class(self): def _get_from_name_domain(name, domain): module = sparse - domain_hierarchy = domain.split(".") + name_hierarchy = name.split(".") + domain_hierarchy = domain.split(".") + name_hierarchy[0:-1] for d in domain_hierarchy[1:]: if hasattr(module, d): module = getattr(module, d) else: return NotImplemented - if hasattr(module, name): - return getattr(module, name) + if hasattr(module, name_hierarchy[-1]): + return getattr(module, name_hierarchy[-1]) else: return NotImplemented @@ -70,7 +71,7 @@ def __ua_function__(method, args, kwargs): if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta): return NotImplemented - sparse_method = _get_from_name_domain(method.__name__, method.domain) + sparse_method = _get_from_name_domain(method.__qualname__, method.domain) if sparse_method is NotImplemented: return NotImplemented diff --git a/unumpy/tests/test_numpy.py b/unumpy/tests/test_numpy.py index 34013e8..dda6552 100644 --- a/unumpy/tests/test_numpy.py +++ b/unumpy/tests/test_numpy.py @@ -609,7 +609,7 @@ def test_linalg(backend, method, args, kwargs): (np.random.permutation, ([1, 2, 3, 4],), {}), (np.random.beta, (1, 2), {"size": 2}), (np.random.binomial, (10, 0.5), {"size": 2}), - (np.random.chisquare, (2, 4), {}), + (np.random.chisquare, (2,), {"size": 2}), (np.random.dirichlet, ((10, 5, 3),), {}), (np.random.exponential, (), {"size": 2}), (np.random.f, (1.0, 48.0), {"size": 2}), @@ -670,6 +670,73 @@ def test_random(backend, method, args, kwargs): ret.compute() +@pytest.mark.parametrize( + "method, args, kwargs", + [ + (np.random.Generator.random, (), {"size": 2}), + (np.random.Generator.choice, ([1, 2],), {}), + (np.random.Generator.bytes, (10,), {}), + (np.random.Generator.shuffle, ([1, 2, 3, 4],), {}), + (np.random.Generator.permutation, ([1, 2, 3, 4],), {}), + (np.random.Generator.beta, (1, 2), {"size": 2}), + (np.random.Generator.binomial, (10, 0.5), {"size": 2}), + (np.random.Generator.chisquare, (2,), {"size": 2}), + (np.random.Generator.dirichlet, ((10, 5, 3),), {}), + (np.random.Generator.exponential, (), {"size": 2}), + (np.random.Generator.f, (1.0, 48.0), {"size": 2}), + (np.random.Generator.gamma, (2.0, 2.0), {"size": 2}), + (np.random.Generator.geometric, (0.35,), {"size": 2}), + (np.random.Generator.gumbel, (0.0, 0.1), {"size": 2}), + (np.random.Generator.hypergeometric, (100, 2, 10), {"size": 2}), + (np.random.Generator.laplace, (0.0, 1.0), {"size": 2}), + (np.random.Generator.logistic, (10, 1), {"size": 2}), + (np.random.Generator.lognormal, (3.0, 1.0), {"size": 2}), + (np.random.Generator.logseries, (0.6,), {"size": 2}), + (np.random.Generator.multinomial, (20, [1 / 6.0] * 6), {}), + (np.random.Generator.multivariate_normal, ([0, 0], [[1, 0], [0, 100]]), {}), + (np.random.Generator.negative_binomial, (1, 0.1), {"size": 2}), + (np.random.Generator.noncentral_chisquare, (3, 20), {"size": 2}), + (np.random.Generator.noncentral_f, (3, 20, 3.0), {"size": 2}), + (np.random.Generator.normal, (0, 0.1), {"size": 2}), + (np.random.Generator.pareto, (3.0,), {"size": 2}), + (np.random.Generator.poisson, (5,), {"size": 2}), + (np.random.Generator.power, (5.0,), {"size": 2}), + (np.random.Generator.rayleigh, (3,), {"size": 2}), + (np.random.Generator.standard_cauchy, (), {"size": 2}), + (np.random.Generator.standard_exponential, (), {"size": 2}), + (np.random.Generator.standard_gamma, (2.0,), {"size": 2}), + (np.random.Generator.standard_normal, (), {"size": 2}), + (np.random.Generator.standard_t, (10,), {"size": 2}), + (np.random.Generator.triangular, (-3, 0, 8), {"size": 2}), + (np.random.Generator.uniform, (-1, 0), {"size": 2}), + (np.random.Generator.vonmises, (0.0, 4.0), {"size": 2}), + (np.random.Generator.wald, (3, 2), {"size": 2}), + (np.random.Generator.weibull, (5.0,), {"size": 2}), + (np.random.Generator.zipf, (2.0,), {"size": 2}), + ], +) +def test_Generator(backend, method, args, kwargs): + backend, types = backend + try: + with ua.set_backend(backend, coerce=True): + rng = np.random.default_rng() + ret = method(rng, *args, **kwargs) + except ua.BackendNotImplementedError: + if backend in FULLY_TESTED_BACKENDS and (backend, method) not in EXCEPTIONS: + raise + pytest.xfail(reason="The backend has no implementation for this ufunc.") + + if method is np.random.Generator.bytes: + assert isinstance(ret, bytes) + elif method is np.random.Generator.shuffle: + assert ret is None + else: + assert isinstance(ret, types) + + if isinstance(ret, da.Array): + ret.compute() + + @pytest.mark.parametrize( "method, args, kwargs", [ From ba6a9685bae2aa3ec0102e25310f051d82cecb07 Mon Sep 17 00:00:00 2001 From: joaosferreira Date: Thu, 27 Aug 2020 15:52:02 +0100 Subject: [PATCH 3/3] Add further tests to RandomState and Generator --- unumpy/sparse_backend.py | 5 +++-- unumpy/tests/test_numpy.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/unumpy/sparse_backend.py b/unumpy/sparse_backend.py index 9d4970c..3677c79 100644 --- a/unumpy/sparse_backend.py +++ b/unumpy/sparse_backend.py @@ -2,7 +2,7 @@ import sparse from uarray import Dispatchable, wrap_single_convertor from unumpy import ufunc, ufunc_list, ndarray, dtype -from unumpy.random import RandomState +from unumpy.random import RandomState, Generator import unumpy import functools @@ -28,7 +28,8 @@ def array(x, *args, **kwargs): ndarray: sparse.SparseArray, dtype: np.dtype, ufunc: np.ufunc, - RandomState: np.random.mtrand.RandomState, + RandomState: np.random.RandomState, + Generator: np.random.Generator, } diff --git a/unumpy/tests/test_numpy.py b/unumpy/tests/test_numpy.py index dda6552..dc8517c 100644 --- a/unumpy/tests/test_numpy.py +++ b/unumpy/tests/test_numpy.py @@ -804,21 +804,31 @@ def test_class_overriding(): assert isinstance(onp.dtype("float64"), np.dtype) assert np.dtype("float64") == onp.float64 assert isinstance(np.dtype("float64"), onp.dtype) + assert isinstance(onp.random.RandomState(), np.random.RandomState) + assert isinstance(onp.random.Generator(onp.random.PCG64()), np.random.Generator) assert issubclass(onp.ufunc, np.ufunc) + assert issubclass(onp.random.RandomState, np.random.RandomState) + assert issubclass(onp.random.Generator, np.random.Generator) with ua.set_backend(DaskBackend(), coerce=True): assert isinstance(da.add, np.ufunc) assert isinstance(onp.dtype("float64"), np.dtype) assert np.dtype("float64") == onp.float64 assert isinstance(np.dtype("float64"), onp.dtype) + assert isinstance(da.random.RandomState(), np.random.RandomState) assert issubclass(da.ufunc.ufunc, np.ufunc) + assert issubclass(da.random.RandomState, np.random.RandomState) with ua.set_backend(SparseBackend, coerce=True): assert isinstance(onp.add, np.ufunc) assert isinstance(onp.dtype("float64"), np.dtype) assert np.dtype("float64") == onp.float64 assert isinstance(np.dtype("float64"), onp.dtype) + assert isinstance(onp.random.RandomState(), np.random.RandomState) + assert isinstance(onp.random.Generator(onp.random.PCG64()), np.random.Generator) assert issubclass(onp.ufunc, np.ufunc) + assert issubclass(onp.random.RandomState, np.random.RandomState) + assert issubclass(onp.random.Generator, np.random.Generator) if hasattr(CupyBackend, "__ua_function__"): with ua.set_backend(CupyBackend, coerce=True): @@ -826,4 +836,6 @@ def test_class_overriding(): assert isinstance(cp.dtype("float64"), np.dtype) assert np.dtype("float64") == cp.float64 assert isinstance(np.dtype("float64"), cp.dtype) + assert isinstance(cp.random.RandomState(), np.random.RandomState) assert issubclass(cp.ufunc, np.ufunc) + assert issubclass(cp.random.RandomState, np.random.RandomState)