forked from numba/numba
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_overload.py
324 lines (231 loc) · 7.98 KB
/
test_overload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
from numba import cuda, njit, types
from numba.core.errors import TypingError
from numba.core.extending import overload, overload_attribute
from numba.cuda.testing import CUDATestCase, skip_on_cudasim, unittest
from numba.tests.test_extending import mydummy_type, MyDummyType
import numpy as np
# Dummy function definitions to overload
def generic_func_1():
    """Placeholder to be overloaded; has no behaviour of its own.

    This module registers an overload for it on the 'generic' target.
    """
def cuda_func_1():
    """Placeholder to be overloaded; has no behaviour of its own.

    This module registers an overload for it on the 'cuda' target.
    """
def generic_func_2():
    """Second placeholder overloaded on the 'generic' target in this module.

    It has no behaviour of its own.
    """
def cuda_func_2():
    """Second placeholder overloaded on the 'cuda' target in this module.

    It has no behaviour of its own.
    """
def generic_calls_generic():
    """Placeholder whose 'generic' overload calls ``generic_func_1``.

    It has no behaviour of its own.
    """
def generic_calls_cuda():
    """Placeholder whose 'generic' overload calls ``cuda_func_1``.

    It has no behaviour of its own.
    """
def cuda_calls_generic():
    """Placeholder whose 'cuda' overload calls ``generic_func_1``.

    It has no behaviour of its own.
    """
def cuda_calls_cuda():
    """Placeholder whose 'cuda' overload calls ``cuda_func_1``.

    It has no behaviour of its own.
    """
def target_overloaded():
    """Placeholder overloaded separately for 'generic' and for 'cuda'.

    It has no behaviour of its own; this module registers one overload per
    target, each multiplying by a different prime.
    """
def generic_calls_target_overloaded():
    """Placeholder whose 'generic' overload calls ``target_overloaded``.

    It has no behaviour of its own.
    """
def cuda_calls_target_overloaded():
    """Placeholder whose 'cuda' overload calls ``target_overloaded``.

    It has no behaviour of its own.
    """
def target_overloaded_calls_target_overloaded():
    """Placeholder with per-target overloads that call ``target_overloaded``.

    It has no behaviour of its own; this module registers both a 'generic'
    and a 'cuda' overload for it.
    """
# To recognise which functions are resolved for a call, we identify each with a
# prime number. Each function called multiplies a value by its prime (starting
# with the value 1), and we can check that the result is as expected based on
# the final value after all multiplications.
# Every marker below is a distinct prime, so a product of markers can only be
# factored one way and therefore identifies which overloads ran.

# Markers for the simple overloads (one target each, no nested call)
GENERIC_FUNCTION_1 = 2
CUDA_FUNCTION_1 = 3
GENERIC_FUNCTION_2 = 5
CUDA_FUNCTION_2 = 7
# Markers for overloads that call one of the simple overloads
GENERIC_CALLS_GENERIC = 11
GENERIC_CALLS_CUDA = 13
CUDA_CALLS_GENERIC = 17
CUDA_CALLS_CUDA = 19
# Markers for the two per-target overloads of target_overloaded
GENERIC_TARGET_OL = 23
CUDA_TARGET_OL = 29
# Markers for single-target overloads that call target_overloaded
GENERIC_CALLS_TARGET_OL = 31
CUDA_CALLS_TARGET_OL = 37
# Markers for the per-target overloads of
# target_overloaded_calls_target_overloaded
GENERIC_TARGET_OL_CALLS_TARGET_OL = 41
CUDA_TARGET_OL_CALLS_TARGET_OL = 43
# Overload implementations
@overload(generic_func_1, target='generic')
def ol_generic_func_1(x):
    """Overload of ``generic_func_1`` registered for the 'generic' target.

    The returned implementation multiplies ``x[0]`` in place by
    ``GENERIC_FUNCTION_1`` so a test can tell that this overload ran.
    """
    def mark_generic_func_1(x):
        x[0] *= GENERIC_FUNCTION_1

    return mark_generic_func_1
@overload(cuda_func_1, target='cuda')
def ol_cuda_func_1(x):
    """Overload of ``cuda_func_1`` registered for the 'cuda' target.

    The returned implementation multiplies ``x[0]`` in place by
    ``CUDA_FUNCTION_1`` so a test can tell that this overload ran.
    """
    def mark_cuda_func_1(x):
        x[0] *= CUDA_FUNCTION_1

    return mark_cuda_func_1
@overload(generic_func_2, target='generic')
def ol_generic_func_2(x):
    """Overload of ``generic_func_2`` registered for the 'generic' target.

    The returned implementation multiplies ``x[0]`` in place by
    ``GENERIC_FUNCTION_2`` so a test can tell that this overload ran.
    """
    def mark_generic_func_2(x):
        x[0] *= GENERIC_FUNCTION_2

    return mark_generic_func_2
@overload(cuda_func_2, target='cuda')
def ol_cuda_func(x):
    """Overload of ``cuda_func_2`` registered for the 'cuda' target.

    The returned implementation multiplies ``x[0]`` in place by
    ``CUDA_FUNCTION_2`` so a test can tell that this overload ran.
    """
    # NOTE: despite the unnumbered name, this is the overload for cuda_func_2.
    def mark_cuda_func_2(x):
        x[0] *= CUDA_FUNCTION_2

    return mark_cuda_func_2
@overload(generic_calls_generic, target='generic')
def ol_generic_calls_generic(x):
    """Overload of ``generic_calls_generic`` for the 'generic' target.

    The returned implementation multiplies ``x[0]`` by
    ``GENERIC_CALLS_GENERIC`` and then calls ``generic_func_1`` on ``x``.
    """
    def mark_then_call_generic(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= GENERIC_CALLS_GENERIC
        generic_func_1(x)

    return mark_then_call_generic
@overload(generic_calls_cuda, target='generic')
def ol_generic_calls_cuda(x):
    """Overload of ``generic_calls_cuda`` for the 'generic' target.

    The returned implementation multiplies ``x[0]`` by
    ``GENERIC_CALLS_CUDA`` and then calls ``cuda_func_1`` on ``x``.
    """
    def mark_then_call_cuda(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= GENERIC_CALLS_CUDA
        cuda_func_1(x)

    return mark_then_call_cuda
@overload(cuda_calls_generic, target='cuda')
def ol_cuda_calls_generic(x):
    """Overload of ``cuda_calls_generic`` for the 'cuda' target.

    The returned implementation multiplies ``x[0]`` by
    ``CUDA_CALLS_GENERIC`` and then calls ``generic_func_1`` on ``x``.
    """
    def mark_then_call_generic(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= CUDA_CALLS_GENERIC
        generic_func_1(x)

    return mark_then_call_generic
@overload(cuda_calls_cuda, target='cuda')
def ol_cuda_calls_cuda(x):
    """Overload of ``cuda_calls_cuda`` for the 'cuda' target.

    The returned implementation multiplies ``x[0]`` by ``CUDA_CALLS_CUDA``
    and then calls ``cuda_func_1`` on ``x``.
    """
    def mark_then_call_cuda(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= CUDA_CALLS_CUDA
        cuda_func_1(x)

    return mark_then_call_cuda
@overload(target_overloaded, target='generic')
def ol_target_overloaded_generic(x):
    """'generic'-target overload of ``target_overloaded``.

    The returned implementation multiplies ``x[0]`` in place by
    ``GENERIC_TARGET_OL``. A separate overload of the same function is
    registered for the 'cuda' target in this module.
    """
    def mark_generic_variant(x):
        x[0] *= GENERIC_TARGET_OL

    return mark_generic_variant
@overload(target_overloaded, target='cuda')
def ol_target_overloaded_cuda(x):
    """'cuda'-target overload of ``target_overloaded``.

    The returned implementation multiplies ``x[0]`` in place by
    ``CUDA_TARGET_OL``. A separate overload of the same function is
    registered for the 'generic' target in this module.
    """
    def mark_cuda_variant(x):
        x[0] *= CUDA_TARGET_OL

    return mark_cuda_variant
@overload(generic_calls_target_overloaded, target='generic')
def ol_generic_calls_target_overloaded(x):
    """Overload of ``generic_calls_target_overloaded`` for 'generic'.

    The returned implementation multiplies ``x[0]`` by
    ``GENERIC_CALLS_TARGET_OL`` and then calls ``target_overloaded``, which
    has a different overload per target.
    """
    def mark_then_call_target_overloaded(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= GENERIC_CALLS_TARGET_OL
        target_overloaded(x)

    return mark_then_call_target_overloaded
@overload(cuda_calls_target_overloaded, target='cuda')
def ol_cuda_calls_target_overloaded(x):
    """Overload of ``cuda_calls_target_overloaded`` for 'cuda'.

    The returned implementation multiplies ``x[0]`` by
    ``CUDA_CALLS_TARGET_OL`` and then calls ``target_overloaded``, which has
    a different overload per target.
    """
    def mark_then_call_target_overloaded(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= CUDA_CALLS_TARGET_OL
        target_overloaded(x)

    return mark_then_call_target_overloaded
@overload(target_overloaded_calls_target_overloaded, target='generic')
def ol_generic_calls_target_overloaded_generic(x):
    """'generic' overload of ``target_overloaded_calls_target_overloaded``.

    The returned implementation multiplies ``x[0]`` by
    ``GENERIC_TARGET_OL_CALLS_TARGET_OL`` and then calls
    ``target_overloaded``.
    """
    def mark_generic_then_call(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= GENERIC_TARGET_OL_CALLS_TARGET_OL
        target_overloaded(x)

    return mark_generic_then_call
@overload(target_overloaded_calls_target_overloaded, target='cuda')
def ol_generic_calls_target_overloaded_cuda(x):
    """'cuda' overload of ``target_overloaded_calls_target_overloaded``.

    The returned implementation multiplies ``x[0]`` by
    ``CUDA_TARGET_OL_CALLS_TARGET_OL`` and then calls ``target_overloaded``.
    """
    # NOTE: the "generic" in this function's name is historical; the
    # decorator registers it for the 'cuda' target.
    def mark_cuda_then_call(x):
        # Record this overload first; the callee then adds its own factor.
        x[0] *= CUDA_TARGET_OL_CALLS_TARGET_OL
        target_overloaded(x)

    return mark_cuda_then_call
@skip_on_cudasim('Overloading not supported in cudasim')
class TestOverload(CUDATestCase):
    """Check which overload implementations are picked for each target.

    Every overload in this module multiplies ``x[0]`` by its own prime, so
    the final value of ``x[0]`` (which starts at 1) identifies the
    implementations that were executed.
    """

    def check_overload(self, kernel, expected):
        """Run ``kernel`` through ``cuda.jit`` and compare the result.

        The kernel is launched with a ``[1, 1]`` configuration on a
        one-element int32 array of ones; element 0 must equal ``expected``.
        """
        result = np.ones(1, dtype=np.int32)
        compiled = cuda.jit(kernel)
        compiled[1, 1](result)
        self.assertEqual(result[0], expected)

    def check_overload_cpu(self, kernel, expected):
        """Run ``kernel`` through ``njit`` and compare the result.

        The function is called on a one-element int32 array of ones;
        element 0 must equal ``expected``.
        """
        result = np.ones(1, dtype=np.int32)
        compiled = njit(kernel)
        compiled(result)
        self.assertEqual(result[0], expected)

    def test_generic(self):
        """A 'generic' overload is used when compiling for CUDA."""
        def kernel(x):
            generic_func_1(x)

        self.check_overload(kernel, GENERIC_FUNCTION_1)

    def test_cuda(self):
        """A 'cuda' overload is used when compiling for CUDA."""
        def kernel(x):
            cuda_func_1(x)

        self.check_overload(kernel, CUDA_FUNCTION_1)

    def test_generic_and_cuda(self):
        """One kernel can call a 'generic' and a 'cuda' overload."""
        def kernel(x):
            generic_func_1(x)
            cuda_func_1(x)

        product = GENERIC_FUNCTION_1 * CUDA_FUNCTION_1
        self.check_overload(kernel, product)

    def test_call_two_generic_calls(self):
        """One kernel can call two different 'generic' overloads."""
        def kernel(x):
            generic_func_1(x)
            generic_func_2(x)

        product = GENERIC_FUNCTION_1 * GENERIC_FUNCTION_2
        self.check_overload(kernel, product)

    def test_call_two_cuda_calls(self):
        """One kernel can call two different 'cuda' overloads."""
        def kernel(x):
            cuda_func_1(x)
            cuda_func_2(x)

        product = CUDA_FUNCTION_1 * CUDA_FUNCTION_2
        self.check_overload(kernel, product)

    def test_generic_calls_generic(self):
        """A 'generic' overload may call another 'generic' overload."""
        def kernel(x):
            generic_calls_generic(x)

        product = GENERIC_CALLS_GENERIC * GENERIC_FUNCTION_1
        self.check_overload(kernel, product)

    def test_generic_calls_cuda(self):
        """A 'generic' overload may call a 'cuda' overload."""
        def kernel(x):
            generic_calls_cuda(x)

        product = GENERIC_CALLS_CUDA * CUDA_FUNCTION_1
        self.check_overload(kernel, product)

    def test_cuda_calls_generic(self):
        """A 'cuda' overload may call a 'generic' overload."""
        def kernel(x):
            cuda_calls_generic(x)

        product = CUDA_CALLS_GENERIC * GENERIC_FUNCTION_1
        self.check_overload(kernel, product)

    def test_cuda_calls_cuda(self):
        """A 'cuda' overload may call another 'cuda' overload."""
        def kernel(x):
            cuda_calls_cuda(x)

        product = CUDA_CALLS_CUDA * CUDA_FUNCTION_1
        self.check_overload(kernel, product)

    def test_call_target_overloaded(self):
        """With overloads for both targets, CUDA picks the 'cuda' one."""
        def kernel(x):
            target_overloaded(x)

        self.check_overload(kernel, CUDA_TARGET_OL)

    def test_generic_calls_target_overloaded(self):
        """A 'generic' caller still reaches the 'cuda' callee on CUDA."""
        def kernel(x):
            generic_calls_target_overloaded(x)

        product = GENERIC_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, product)

    def test_cuda_calls_target_overloaded(self):
        """A 'cuda' caller reaches the 'cuda' callee on CUDA."""
        def kernel(x):
            cuda_calls_target_overloaded(x)

        product = CUDA_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, product)

    def test_target_overloaded_calls_target_overloaded(self):
        """Caller and callee both resolve per target, on CUDA and on CPU."""
        def kernel(x):
            target_overloaded_calls_target_overloaded(x)

        # On CUDA both the caller and the callee must be the CUDA overloads
        on_cuda = CUDA_TARGET_OL_CALLS_TARGET_OL * CUDA_TARGET_OL
        self.check_overload(kernel, on_cuda)

        # On the CPU both must be the generic overloads
        on_cpu = GENERIC_TARGET_OL_CALLS_TARGET_OL * GENERIC_TARGET_OL
        self.check_overload_cpu(kernel, on_cpu)

    def test_overload_attribute_target(self):
        """An attribute overloaded for 'cuda' only is rejected on the CPU."""
        @overload_attribute(MyDummyType, 'cuda_only', target='cuda')
        def ov_dummy_cuda_attr(obj):
            def get_cuda_only(obj):
                return 42

            return get_cuda_only

        # Compiling for the CPU must fail with a typing error, because the
        # attribute is only registered for the CUDA target
        with self.assertRaisesRegex(TypingError,
                                    "Unknown attribute 'cuda_only'"):
            @njit(types.void(mydummy_type))
            def illegal_target_attr_use(x):
                return x.cuda_only

        # Compiling for CUDA must succeed. Passing a signature makes
        # compilation eager, so an error would surface at the decorator.
        @cuda.jit(types.void(types.int64[::1], mydummy_type))
        def cuda_target_attr_use(res, dummy):
            res[0] = dummy.cuda_only
# Run this module's tests when the file is executed directly as a script
if __name__ == '__main__':
    unittest.main()