Skip to content

Commit

Permalink
be more precise about mutated symbols in custom handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Jan 29, 2024
1 parent 48a1ee9 commit 9226dc9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 12 deletions.
7 changes: 6 additions & 1 deletion core/ipyflow/tracing/external_calls/__init__.py
Expand Up @@ -2,7 +2,7 @@
import ast
import logging
from types import ModuleType
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

# force handler registration by exec()ing the handler modules here
import ipyflow.tracing.external_calls.base_handlers # noqa: F401
Expand All @@ -17,6 +17,9 @@
StandardMutation,
)

if TYPE_CHECKING:
from ipyflow.data_model.symbol import Symbol


def resolve_external_call(
module: Optional[ModuleType],
Expand All @@ -25,6 +28,7 @@ def resolve_external_call(
method: Optional[str],
call_node: Optional[ast.Call] = None,
use_standard_default: bool = True,
calling_symbol: Optional["Symbol"] = None,
) -> Optional[ExternalCallHandler]:
if hasattr(function_or_method, "__self__") and hasattr(
function_or_method, "__name__"
Expand Down Expand Up @@ -87,4 +91,5 @@ def resolve_external_call(
caller_self=caller_self,
function_or_method=function_or_method,
call_node=call_node,
calling_symbol=calling_symbol,
)
45 changes: 35 additions & 10 deletions core/ipyflow/tracing/external_calls/base_handlers.py
Expand Up @@ -61,11 +61,13 @@ def create(cls, **kwargs) -> "ExternalCallHandler":
caller_self = kwargs.pop("caller_self", None)
function_or_method = kwargs.pop("function_or_method", None)
call_node = kwargs.pop("call_node", None)
calling_symbol = kwargs.pop("calling_symbol", None)
return cls(
module=module,
caller_self=caller_self,
function_or_method=function_or_method,
call_node=call_node,
calling_symbol=calling_symbol,
)._initialize_impl(**kwargs)

def _initialize_impl(self, **kwargs) -> "ExternalCallHandler":
Expand All @@ -86,6 +88,7 @@ def __init__(
caller_self: Any = None,
function_or_method: Any = None,
call_node: Optional[ast.Call] = None,
calling_symbol: Optional["Symbol"] = None,
) -> None:
self.module = module
self.caller_self = caller_self
Expand All @@ -95,6 +98,7 @@ def __init__(
self._arg_syms: Optional[Set["Symbol"]] = None
self.return_value: Any = self.not_yet_defined
self.call_node = call_node
self.calling_symbol = calling_symbol
self.stmt_node = tracer().prev_trace_stmt_in_cur_frame.stmt_node

def __init_subclass__(cls):
Expand Down Expand Up @@ -154,7 +158,26 @@ def _handle_impl(self) -> None:
def mutate_caller(self, should_propagate: bool) -> None:
if self.caller_self is None:
return
self.mutate_aliases(self.caller_self_obj_id, should_propagate=should_propagate)
if self.calling_symbol is None:
syms_to_mutate = []
if isinstance(self.call_node.func, ast.Attribute) and isinstance(
self.call_node.func.value, ast.Name
):
syms_to_mutate = [
sym
for sym in flow().aliases.get(self.caller_self_obj_id, [])
if sym.name == self.call_node.func.value.id
]
for sym in syms_to_mutate:
self._mutate_calling_symbol(sym, should_propagate=should_propagate)
if len(syms_to_mutate) == 0:
self.mutate_aliases(
self.caller_self_obj_id, should_propagate=should_propagate
)
else:
self._mutate_calling_symbol(
self.calling_symbol, should_propagate=should_propagate
)

def mutate_module(self, should_propagate: bool) -> None:
if self.module is None:
Expand All @@ -163,15 +186,17 @@ def mutate_module(self, should_propagate: bool) -> None:

def mutate_aliases(self, obj_id: Optional[int], should_propagate: bool) -> None:
mutated_syms = flow().aliases.get(obj_id, set())
Timestamp.update_usage_info(mutated_syms)
for mutated_sym in mutated_syms:
mutated_sym.update_deps(
self.arg_syms,
overwrite=False,
mutated=True,
propagate_to_namespace_descendents=should_propagate,
refresh=should_propagate,
)
for sym in mutated_syms:
self._mutate_calling_symbol(sym, should_propagate=should_propagate)

def _mutate_calling_symbol(self, sym: "Symbol", should_propagate: bool) -> None:
sym.update_deps(
self.arg_syms,
overwrite=False,
mutated=True,
propagate_to_namespace_descendents=should_propagate,
refresh=should_propagate,
)

def handle(self) -> Optional[Union["Symbol", Iterable["Symbol"]]]:
pass
Expand Down
7 changes: 6 additions & 1 deletion core/ipyflow/tracing/ipyflow_tracer.py
Expand Up @@ -1162,7 +1162,11 @@ def _save_external_call_candidate(
call_node: ast.Call,
) -> None:
self.external_call_candidate = resolve_external_call(
module, obj, function_or_method, method_name, call_node
module,
obj,
function_or_method,
method_name,
call_node,
)

@pyc.before_call
Expand Down Expand Up @@ -1491,6 +1495,7 @@ def after_module_stmt(self, _ret, stmt: ast.stmt, *_, **__) -> Optional[Any]:
ret,
resolve_rval_symbols(stmt, should_update_usage_info=False),
stmt,
propagate=False,
)
self._module_stmt_counter += 1
self.tracing_disabled_since_last_module_stmt = False
Expand Down

0 comments on commit 9226dc9

Please sign in to comment.