forked from Quansight-Labs/unumpy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cupy_backend.py
100 lines (75 loc) · 2.89 KB
/
cupy_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
try:
import numpy as np
import cupy as cp
from uarray import Dispatchable, wrap_single_convertor
from unumpy import ufunc, ufunc_list, ndarray
import unumpy
import functools
from typing import Dict
# Table from unumpy ufunc objects to concrete ufuncs. It starts out empty and
# is only read by `replace_self` in this file.
_ufunc_mapping: Dict[ufunc, np.ufunc] = {}
# Domain string for this backend.
# NOTE(review): presumably the uarray backend-protocol attribute that selects
# which multimethods are routed here -- confirm against the uarray docs.
__ua_domain__ = "numpy"
def overridden_class(self):
    """Look up the CuPy counterpart of the class ``self``.

    Every ``_multimethods`` component is dropped from the dotted
    ``self.__module__``; the remaining path and ``self.__name__`` are then
    resolved through ``_get_from_name_domain``.

    Returns:
        Whatever ``_get_from_name_domain`` returns: the attribute found under
        ``cp``, or ``NotImplemented`` when the lookup fails.
    """
    kept_parts = [
        part for part in self.__module__.split(".") if part != "_multimethods"
    ]
    return _get_from_name_domain(self.__name__, ".".join(kept_parts))
# Registry of explicit overrides, keyed by the unumpy object being overridden.
# `__ua_function__` consults it before falling back to a lookup by name, and
# `_implements` adds further entries.
_implementations: Dict = {
# Seed entry: the `fget` of `ClassOverrideMeta.overridden_class` is served by
# the module-level `overridden_class` function.
unumpy.ClassOverrideMeta.overridden_class.fget: overridden_class
}
def _get_from_name_domain(name, domain):
    """Resolve a dotted ``name`` within ``domain`` to an attribute under ``cp``.

    The lookup path is the components of ``domain`` followed by every
    component of ``name`` except the last; the final component of ``name`` is
    the attribute that gets returned.

    Args:
        name: Dotted attribute name, e.g. a ``__qualname__``.
        domain: Dotted domain/module path.

    Returns:
        The resolved attribute, or ``NotImplemented`` if any step of the
        lookup is missing.
    """
    *parents, leaf = name.split(".")
    # The leading component of ``domain`` is skipped: the walk starts at ``cp``
    # itself (presumably standing in for that top-level package -- confirm).
    path = domain.split(".")[1:] + parents

    missing = object()  # private sentinel so any real attribute value is kept
    current = cp
    for attr in path:
        current = getattr(current, attr, missing)
        if current is missing:
            return NotImplemented
    return getattr(current, leaf, NotImplemented)
def _implements(np_func):
    """Return a decorator that registers a function as the override for ``np_func``.

    The decorated function is stored in ``_implementations`` under the key
    ``np_func`` and handed back unchanged, so it stays usable under its own
    name.
    """

    def register(impl):
        _implementations[np_func] = impl
        return impl

    return register
def __ua_function__(method, args, kwargs):
    """Execute ``method`` using CuPy, or report that it is unsupported.

    Resolution order:
      1. An explicit override registered in ``_implementations``.
      2. Otherwise, unless the first positional argument is a
         ``unumpy.ClassOverrideMeta`` instance, an attribute of ``cp`` found
         via ``method.__qualname__`` and ``method.domain``.

    Args:
        method: The multimethod being dispatched.
        args: Positional arguments for the call.
        kwargs: Keyword arguments for the call.

    Returns:
        The result of the selected callable, or ``NotImplemented`` when no
        implementation is available.
    """
    try:
        override = _implementations[method]
    except KeyError:
        pass
    else:
        return override(*args, **kwargs)

    # Class-level calls without an explicit override are not looked up by name.
    if args and isinstance(args[0], unumpy.ClassOverrideMeta):
        return NotImplemented

    fallback = _get_from_name_domain(method.__qualname__, method.domain)
    if fallback is NotImplemented:
        return NotImplemented
    return fallback(*args, **kwargs)
@wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
    """Convert a single dispatchable ``value`` for use with CuPy.

    Args:
        value: The object to convert; ``None`` is passed through untouched.
        dispatch_type: The type the value was marked with (``ufunc``,
            ``ndarray`` or anything else).
        coerce: Whether converting a non-CuPy array is allowed.

    Returns:
        * ``None`` when ``value`` is ``None``.
        * The same-named attribute of ``cp`` when ``dispatch_type`` is
          ``ufunc`` and ``cp`` has an attribute called ``value.name``.
        * ``cp.asarray(value)`` when ``dispatch_type`` is ``ndarray``, or
          ``NotImplemented`` if ``coerce`` is false and ``value`` is not
          already a ``cp.ndarray``.
        * ``value`` unchanged in every other case.
    """
    # The None guard must run before anything touches ``value.name``:
    # otherwise a None value dispatched as ``ufunc`` raises AttributeError
    # instead of passing through like it does for every other dispatch type.
    if value is None:
        return None

    if dispatch_type is ufunc and hasattr(cp, value.name):
        return getattr(cp, value.name)

    if dispatch_type is ndarray:
        # Without permission to coerce, only genuine CuPy arrays are accepted.
        if not coerce and not isinstance(value, cp.ndarray):
            return NotImplemented
        return cp.asarray(value)

    return value
def replace_self(func):
    """Decorate ``func`` so its first argument is swapped via ``_ufunc_mapping``.

    The returned wrapper looks up its ``self`` argument in ``_ufunc_mapping``.
    If there is no entry it returns ``NotImplemented``; otherwise it calls
    ``func`` with the mapped object in place of ``self`` and forwards the
    remaining arguments unchanged.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Only the table lookup is guarded, so a KeyError raised inside
        # ``func`` itself still propagates to the caller.
        try:
            mapped = _ufunc_mapping[self]
        except KeyError:
            return NotImplemented
        return func(mapped, *args, **kwargs)

    return wrapper
@_implements(unumpy.ascontiguousarray)
def _ascontiguousarray(arr, dtype=None):
    """Override for ``unumpy.ascontiguousarray``.

    Delegates to ``cp.asarray`` with ``order="C"``, forwarding ``dtype``.
    """
    c_ordered = cp.asarray(arr, order="C", dtype=dtype)
    return c_ordered
@_implements(unumpy.asfortranarray)
def _asfortranarray(arr, dtype=None):
    """Override for ``unumpy.asfortranarray``.

    Delegates to ``cp.asarray`` with ``order="F"``, forwarding ``dtype``.
    """
    f_ordered = cp.asarray(arr, order="F", dtype=dtype)
    return f_ordered
@_implements(unumpy.ufunc.__call__)
def _ufunc_call(self, *args, **kwargs):
    """Override for ``unumpy.ufunc.__call__``.

    Looks up the attribute of ``cp`` named ``self.name`` and calls it with the
    given arguments. Returns ``NotImplemented`` when ``cp`` has no attribute
    of that name.
    """
    ufunc_name = self.name
    if not hasattr(cp, ufunc_name):
        return NotImplemented
    return getattr(cp, ufunc_name)(*args, **kwargs)
except ImportError:
pass