-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More NumPy operation implementations #1498
Open
tbennun
wants to merge
21
commits into
master
Choose a base branch
from
numpy-extension
base: master
Could not load branches
Branch not found: {{ refName }}
Could not load tags
Nothing to show
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
fdd73cf
Implement numpy concatenation and stacking
tbennun 531576f
Implement numpy.linspace, add argument checks to numpy.arange
tbennun 9687cb6
Fix complex to scalar type inference
tbennun 73413da
Fix parsing of nested attributes
tbennun 06e5cac
Implement np.clip ufunc
tbennun abd0fe6
Fix attribute evaluation for non-arrays
tbennun a2e47c6
Implement numpy.split and its variants
tbennun 66ef621
Fix fast transposition for mismatched input/output types
tbennun 80ce041
Fix arange result type, fix numpy.full variants for scalar shapes
tbennun a052d42
Fix another case of attribute misparsing
tbennun e7ebad5
Safer creation of complex values in codegen
tbennun df33fdd
Support complex gemv in CPU BLAS replacement
tbennun 0eb1398
Fix parsing of builtin values and symbolic expressions
tbennun b4bd72e
Implement len builtin for constants
tbennun b835cd2
Further fix for callbacks
tbennun 7574fe9
Fix tests
tbennun 21b191d
Implement `numpy.fft.{fft,ifft}` and library node
tbennun 2119c42
Cast cblas_transpose correctly
tbennun 293f354
Merge branch 'master' into numpy-extension
alexnick83 949572f
Merge branch 'master' into numpy-extension
tbennun b89b33a
Merge branch 'master' into numpy-extension
tbennun File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,7 +49,6 @@ | |
Shape = Union[ShapeTuple, ShapeList] | ||
DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]] | ||
|
||
|
||
if sys.version_info < (3, 8): | ||
_simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) | ||
BytesConstant = ast.Bytes | ||
|
@@ -65,15 +64,13 @@ | |
NumConstant = ast.Constant | ||
StrConstant = ast.Constant | ||
|
||
|
||
if sys.version_info < (3, 9): | ||
Index = ast.Index | ||
ExtSlice = ast.ExtSlice | ||
else: | ||
Index = type(None) | ||
ExtSlice = type(None) | ||
|
||
|
||
if sys.version_info < (3, 12): | ||
TypeAlias = type(None) | ||
else: | ||
|
@@ -4330,7 +4327,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): | |
func = node.func.value | ||
|
||
if func is None: | ||
funcname = rname(node) | ||
func_result = self.visit(node.func) | ||
if isinstance(func_result, str): | ||
if isinstance(node.func, ast.Attribute): | ||
funcname = f'{func_result}.{node.func.attr}' | ||
else: | ||
funcname = func_result | ||
else: | ||
funcname = rname(node) | ||
# Check if the function exists as an SDFG in a different module | ||
modname = until(funcname, '.') | ||
if ('.' in funcname and len(modname) > 0 and modname in self.globals | ||
|
@@ -4426,7 +4430,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): | |
arg = self.scope_vars[modname] | ||
else: | ||
# Fallback to (name, object) | ||
arg = (modname, self.defined[modname]) | ||
arg = modname | ||
args.append(arg) | ||
# Otherwise, try to find a default implementation for the SDFG | ||
elif not found_ufunc: | ||
|
@@ -4623,12 +4627,18 @@ def _visitname(self, name: str, node: ast.AST): | |
self.sdfg.add_symbol(result.name, result.dtype) | ||
return result | ||
|
||
if name in self.closure.callbacks: | ||
return name | ||
|
||
if name in self.sdfg.arrays: | ||
return name | ||
|
||
if name in self.sdfg.symbols: | ||
return name | ||
|
||
if name in __builtins__: | ||
return name | ||
|
||
if name not in self.scope_vars: | ||
raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name) | ||
rname = self.scope_vars[name] | ||
|
@@ -4673,30 +4683,39 @@ def visit_NameConstant(self, node: NameConstant): | |
return self.visit_Constant(node) | ||
|
||
def visit_Attribute(self, node: ast.Attribute): | ||
# If visiting an attribute, return attribute value if it's of an array or global | ||
name = until(astutils.unparse(node), '.') | ||
result = self._visitname(name, node) | ||
result = self.visit(node.value) | ||
if isinstance(result, (tuple, list, dict)): | ||
if len(result) > 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the fix for attributes on expressions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly |
||
raise DaceSyntaxError( | ||
self, node.value, f'{type(result)} object cannot use attributes. Try storing the ' | ||
'object to a different variable first (e.g., ``a = result; a.attribute``') | ||
else: | ||
result = result[0] | ||
|
||
if isinstance(result, str) and result in self.sdfg.arrays: | ||
arr = self.sdfg.arrays[result] | ||
elif isinstance(result, str) and result in self.scope_arrays: | ||
arr = self.scope_arrays[result] | ||
else: | ||
return result | ||
arr = None | ||
|
||
# Try to find sub-SDFG attribute | ||
func = oprepo.Replacements.get_attribute(type(arr), node.attr) | ||
if func is not None: | ||
# A new state is likely needed here, e.g., for transposition (ndarray.T) | ||
self._add_state('%s_%d' % (type(node).__name__, node.lineno)) | ||
self.last_state.set_default_lineinfo(self.current_lineinfo) | ||
result = func(self, self.sdfg, self.last_state, result) | ||
self.last_state.set_default_lineinfo(None) | ||
return result | ||
if arr is not None: | ||
func = oprepo.Replacements.get_attribute(type(arr), node.attr) | ||
if func is not None: | ||
# A new state is likely needed here, e.g., for transposition (ndarray.T) | ||
self._add_state('%s_%d' % (type(node).__name__, node.lineno)) | ||
self.last_state.set_default_lineinfo(self.current_lineinfo) | ||
result = func(self, self.sdfg, self.last_state, result) | ||
self.last_state.set_default_lineinfo(None) | ||
return result | ||
|
||
# Otherwise, try to find compile-time attribute (such as shape) | ||
try: | ||
return getattr(arr, node.attr) | ||
except KeyError: | ||
if arr is not None: | ||
return getattr(arr, node.attr) | ||
return getattr(result, node.attr) | ||
except (AttributeError, KeyError): | ||
return result | ||
|
||
def visit_List(self, node: ast.List): | ||
|
@@ -4718,7 +4737,7 @@ def visit_Dict(self, node: ast.Dict): | |
def visit_Lambda(self, node: ast.Lambda): | ||
# Return a string representation of the function | ||
return astutils.unparse(node) | ||
|
||
def visit_TypeAlias(self, node: TypeAlias): | ||
raise NotImplementedError('Type aliases are not supported in DaCe') | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change handled already in the code somehow, i.e., is
self.defined
queried on demand?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s a code path that we didn’t cover before in coverage. The rest of the code assumes a string return type rather than a tuple, and crashes somewhere else.