Skip to content

Commit

Permalink
Merge pull request #73 from joaosferreira/random-routines
Browse files Browse the repository at this point in the history
Multimethods for the random module
  • Loading branch information
hameerabbasi committed Aug 28, 2020
2 parents 1607d9b + ba6a968 commit 5720e11
Show file tree
Hide file tree
Showing 9 changed files with 902 additions and 21 deletions.
3 changes: 3 additions & 0 deletions unumpy/__init__.py
Expand Up @@ -152,6 +152,9 @@
"""
from ._multimethods import *
from .lib import c_, r_, s_
from . import linalg
from . import lib
from . import random

from ._version import get_versions

Expand Down
10 changes: 10 additions & 0 deletions unumpy/_multimethods.py
Expand Up @@ -29,6 +29,9 @@ def _dtype_argreplacer(args, kwargs, dispatchables):
def replacer(*a, dtype=None, **kw):
out_kw = kw.copy()
out_kw["dtype"] = dispatchables[0]
if "out" in out_kw:
out_kw["out"] = dispatchables[1]

return a, out_kw

return replacer(*args, **kwargs)
Expand All @@ -45,6 +48,13 @@ def self_method(a, *args, **kwargs):
return self_method(*args, **kwargs)


def _skip_self_argreplacer(args, kwargs, dispatchables):
def replacer(self, *args, **kwargs):
return (self,) + dispatchables, kwargs

return replacer(*args, **kwargs)


def _ureduce_argreplacer(args, kwargs, dispatchables):
def ureduce(self, a, axis=0, dtype=None, out=None, keepdims=False):
return (
Expand Down
9 changes: 5 additions & 4 deletions unumpy/cupy_backend.py
Expand Up @@ -23,14 +23,15 @@ def overridden_class(self):

def _get_from_name_domain(name, domain):
module = cp
domain_hierarchy = domain.split(".")
name_hierarchy = name.split(".")
domain_hierarchy = domain.split(".") + name_hierarchy[0:-1]
for d in domain_hierarchy[1:]:
if hasattr(module, d):
module = getattr(module, d)
else:
return NotImplemented
if hasattr(module, name):
return getattr(module, name)
if hasattr(module, name_hierarchy[-1]):
return getattr(module, name_hierarchy[-1])
else:
return NotImplemented

Expand All @@ -48,7 +49,7 @@ def __ua_function__(method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

cupy_method = _get_from_name_domain(method.__name__, method.domain)
cupy_method = _get_from_name_domain(method.__qualname__, method.domain)
if cupy_method is NotImplemented:
return NotImplemented

Expand Down
9 changes: 5 additions & 4 deletions unumpy/dask_backend.py
Expand Up @@ -31,14 +31,15 @@ def overridden_class(self):

def _get_from_name_domain(name, domain):
module = da
domain_hierarchy = domain.split(".")
name_hierarchy = name.split(".")
domain_hierarchy = domain.split(".") + name_hierarchy[0:-1]
for d in domain_hierarchy[1:]:
if hasattr(module, d):
module = getattr(module, d)
else:
return NotImplemented
if hasattr(module, name):
return getattr(module, name)
if hasattr(module, name_hierarchy[-1]):
return getattr(module, name_hierarchy[-1])
else:
return NotImplemented

Expand Down Expand Up @@ -144,7 +145,7 @@ def __ua_function__(self, method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

dask_method = _get_from_name_domain(method.__name__, method.domain)
dask_method = _get_from_name_domain(method.__qualname__, method.domain)
if dask_method is NotImplemented:
return NotImplemented

Expand Down
11 changes: 6 additions & 5 deletions unumpy/numpy_backend.py
@@ -1,6 +1,6 @@
import numpy as np
from uarray import Dispatchable, wrap_single_convertor
from unumpy import ufunc, ufunc_list, ndarray, dtype, linalg
from unumpy import ufunc, ufunc_list, ndarray, dtype
import unumpy
import functools

Expand Down Expand Up @@ -29,11 +29,12 @@ def overridden_class(self):

def _get_from_name_domain(name, domain):
module = np
domain_hierarchy = domain.split(".")
name_hierarchy = name.split(".")
domain_hierarchy = domain.split(".") + name_hierarchy[0:-1]
for d in domain_hierarchy[1:]:
module = getattr(module, d)
if hasattr(module, name):
return getattr(module, name)
if hasattr(module, name_hierarchy[-1]):
return getattr(module, name_hierarchy[-1])
else:
return NotImplemented

Expand All @@ -45,7 +46,7 @@ def __ua_function__(method, args, kwargs):
if len(args) != 0 and isinstance(args[0], unumpy.ClassOverrideMeta):
return NotImplemented

method_numpy = _get_from_name_domain(method.__name__, method.domain)
method_numpy = _get_from_name_domain(method.__qualname__, method.domain)
if method_numpy is NotImplemented:
return NotImplemented

Expand Down
1 change: 1 addition & 0 deletions unumpy/random/__init__.py
@@ -0,0 +1 @@
from ._multimethods import *

0 comments on commit 5720e11

Please sign in to comment.