From 9226dc90accb63cd376ad2e90165558d029a699c Mon Sep 17 00:00:00 2001 From: Stephen Macke Date: Mon, 29 Jan 2024 13:32:22 -0800 Subject: [PATCH] be more precise about mutated symbols in custom handlers --- .../tracing/external_calls/__init__.py | 7 ++- .../tracing/external_calls/base_handlers.py | 45 ++++++++++++++----- core/ipyflow/tracing/ipyflow_tracer.py | 7 ++- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/core/ipyflow/tracing/external_calls/__init__.py b/core/ipyflow/tracing/external_calls/__init__.py index cca39072..ff996f4f 100644 --- a/core/ipyflow/tracing/external_calls/__init__.py +++ b/core/ipyflow/tracing/external_calls/__init__.py @@ -2,7 +2,7 @@ import ast import logging from types import ModuleType -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional # force handler registration by exec()ing the handler modules here import ipyflow.tracing.external_calls.base_handlers # noqa: F401 @@ -17,6 +17,9 @@ StandardMutation, ) +if TYPE_CHECKING: + from ipyflow.data_model.symbol import Symbol + def resolve_external_call( module: Optional[ModuleType], @@ -25,6 +28,7 @@ def resolve_external_call( method: Optional[str], call_node: Optional[ast.Call] = None, use_standard_default: bool = True, + calling_symbol: Optional["Symbol"] = None, ) -> Optional[ExternalCallHandler]: if hasattr(function_or_method, "__self__") and hasattr( function_or_method, "__name__" @@ -87,4 +91,5 @@ def resolve_external_call( caller_self=caller_self, function_or_method=function_or_method, call_node=call_node, + calling_symbol=calling_symbol, ) diff --git a/core/ipyflow/tracing/external_calls/base_handlers.py b/core/ipyflow/tracing/external_calls/base_handlers.py index 56da2ad0..096e8540 100644 --- a/core/ipyflow/tracing/external_calls/base_handlers.py +++ b/core/ipyflow/tracing/external_calls/base_handlers.py @@ -61,11 +61,13 @@ def create(cls, **kwargs) -> "ExternalCallHandler": caller_self = kwargs.pop("caller_self", None) function_or_method = kwargs.pop("function_or_method", None) call_node = kwargs.pop("call_node", None) + calling_symbol = kwargs.pop("calling_symbol", None) return cls( module=module, caller_self=caller_self, function_or_method=function_or_method, call_node=call_node, + calling_symbol=calling_symbol, )._initialize_impl(**kwargs) def _initialize_impl(self, **kwargs) -> "ExternalCallHandler": @@ -86,6 +88,7 @@ def __init__( caller_self: Any = None, function_or_method: Any = None, call_node: Optional[ast.Call] = None, + calling_symbol: Optional["Symbol"] = None, ) -> None: self.module = module self.caller_self = caller_self @@ -95,6 +98,7 @@ def __init__( self._arg_syms: Optional[Set["Symbol"]] = None self.return_value: Any = self.not_yet_defined self.call_node = call_node + self.calling_symbol = calling_symbol self.stmt_node = tracer().prev_trace_stmt_in_cur_frame.stmt_node def __init_subclass__(cls): @@ -154,7 +158,26 @@ def _handle_impl(self) -> None: def mutate_caller(self, should_propagate: bool) -> None: if self.caller_self is None: return - self.mutate_aliases(self.caller_self_obj_id, should_propagate=should_propagate) + if self.calling_symbol is None: + syms_to_mutate = [] + if isinstance(self.call_node.func, ast.Attribute) and isinstance( + self.call_node.func.value, ast.Name + ): + syms_to_mutate = [ + sym + for sym in flow().aliases.get(self.caller_self_obj_id, []) + if sym.name == self.call_node.func.value.id + ] + for sym in syms_to_mutate: + self._mutate_calling_symbol(sym, should_propagate=should_propagate) + if len(syms_to_mutate) == 0: + self.mutate_aliases( + self.caller_self_obj_id, should_propagate=should_propagate + ) + else: + self._mutate_calling_symbol( + self.calling_symbol, should_propagate=should_propagate + ) def mutate_module(self, should_propagate: bool) -> None: if self.module is None: @@ -163,15 +186,17 @@ def mutate_module(self, should_propagate: bool) -> None: def mutate_aliases(self, obj_id: Optional[int], should_propagate: bool) -> None: mutated_syms = flow().aliases.get(obj_id, set()) - Timestamp.update_usage_info(mutated_syms) - for mutated_sym in mutated_syms: - mutated_sym.update_deps( - self.arg_syms, - overwrite=False, - mutated=True, - propagate_to_namespace_descendents=should_propagate, - refresh=should_propagate, - ) + for sym in mutated_syms: + self._mutate_calling_symbol(sym, should_propagate=should_propagate) + + def _mutate_calling_symbol(self, sym: "Symbol", should_propagate: bool) -> None: + sym.update_deps( + self.arg_syms, + overwrite=False, + mutated=True, + propagate_to_namespace_descendents=should_propagate, + refresh=should_propagate, + ) def handle(self) -> Optional[Union["Symbol", Iterable["Symbol"]]]: pass diff --git a/core/ipyflow/tracing/ipyflow_tracer.py b/core/ipyflow/tracing/ipyflow_tracer.py index ec60bf75..8e46139b 100644 --- a/core/ipyflow/tracing/ipyflow_tracer.py +++ b/core/ipyflow/tracing/ipyflow_tracer.py @@ -1162,7 +1162,11 @@ def _save_external_call_candidate( call_node: ast.Call, ) -> None: self.external_call_candidate = resolve_external_call( - module, obj, function_or_method, method_name, call_node + module, + obj, + function_or_method, + method_name, + call_node, ) @pyc.before_call @@ -1491,6 +1495,7 @@ def after_module_stmt(self, _ret, stmt: ast.stmt, *_, **__) -> Optional[Any]: ret, resolve_rval_symbols(stmt, should_update_usage_info=False), stmt, + propagate=False, ) self._module_stmt_counter += 1 self.tracing_disabled_since_last_module_stmt = False