Skip to content

Commit

Permalink
uninstrument functionality for decorated functions, including with cl…
Browse files Browse the repository at this point in the history
…osures
  • Loading branch information
smacke committed Feb 24, 2024
1 parent fe89643 commit ba80697
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 45 deletions.
2 changes: 2 additions & 0 deletions core/ipyflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ipyflow.shell import load_ipython_extension as load_ipyflow_extension, unload_ipython_extension as unload_ipyflow_extension
from ipyflow.models import cell_above, cell_below, cell_at_offset, cells, last_run_cell, namespaces, scopes, statements, symbols, timestamps
from ipyflow.singletons import flow, kernel, shell, tracer
from ipyflow.tracing.uninstrument import uninstrument

from . import _version
__version__ = _version.get_versions()['version']
Expand Down Expand Up @@ -84,6 +85,7 @@ def unload_ipython_extension(ipy: "InteractiveShell") -> None:
"symbols",
"timestamps",
"tracer",
"uninstrument",
]


Expand Down
3 changes: 3 additions & 0 deletions core/ipyflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def __init__(self, **kwargs) -> None:
# Note: explicitly adding the types helps PyCharm intellisense
self.namespaces: Dict[int, Namespace] = {}
self.aliases: Dict[int, Set[Symbol]] = {}
self.deco_metadata_by_obj_id: Dict[
int, Tuple[Union[ast.FunctionDef, ast.AsyncFunctionDef], int]
] = {}
self.starred_import_modules: Set[str] = set()
self.stmt_deferred_static_parents: Dict[
Timestamp, Dict[Timestamp, Set[Symbol]]
Expand Down
38 changes: 6 additions & 32 deletions core/ipyflow/tracing/external_calls/cloudpickle_patch.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import ast
from types import FunctionType, LambdaType
from typing import TYPE_CHECKING, Any, Dict, Type, Union
from typing import TYPE_CHECKING, Type, Union

from ipyflow.singletons import flow
from ipyflow.tracing.uninstrument import uninstrument

if TYPE_CHECKING:
import astunparse
from cloudpickle.cloudpickle_fast import CloudPickler
elif hasattr(ast, "unparse"):
astunparse = ast
else:
import astunparse


def _function_reduce(self_, obj) -> None:
Expand All @@ -20,30 +14,10 @@ def _function_reduce(self_, obj) -> None:
def _patched_function_reduce(
self_: "CloudPickler", obj: Union[FunctionType, LambdaType]
) -> None:
for alias in flow().aliases.get(id(obj), []):
if not alias.is_function and not alias.is_lambda:
continue
try:
local_env: Dict[str, Any] = {}
func_defn = astunparse.unparse(alias.func_def_stmt)
if isinstance(alias.func_def_stmt, (ast.AsyncFunctionDef, ast.FunctionDef)):
func_name = alias.func_def_stmt.name
elif isinstance(alias.func_def_stmt, ast.Lambda):
func_name = "lambda_sym"
func_defn = f"{func_name} = {func_defn}"
else:
continue
exec(func_defn, obj.__globals__, local_env)
new_obj = local_env[func_name]
except: # noqa
continue
if isinstance(new_obj, (FunctionType, LambdaType)):
obj = new_obj
break
return _function_reduce(self_, obj)


def patch_cloudpickle_function_getstate(pickler_cls: Type["CloudPickler"]) -> None:
return _function_reduce(self_, uninstrument(obj))


def patch_cloudpickle_function_reduce(pickler_cls: Type["CloudPickler"]) -> None:
global _function_reduce
_function_reduce = pickler_cls._function_reduce
pickler_cls._function_reduce = _patched_function_reduce
32 changes: 20 additions & 12 deletions core/ipyflow/tracing/ipyflow_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ipyflow.tracing.external_calls import resolve_external_call
from ipyflow.tracing.external_calls.base_handlers import ExternalCallHandler
from ipyflow.tracing.external_calls.cloudpickle_patch import (
patch_cloudpickle_function_getstate,
patch_cloudpickle_function_reduce,
)
from ipyflow.tracing.flow_ast_rewriter import DataflowAstRewriter
from ipyflow.tracing.symbol_resolver import resolve_rval_symbols
Expand Down Expand Up @@ -824,18 +824,10 @@ def after_import(self, *_, module: ModuleType, **__):
compile_and_register_handlers_for_module(module)
modname = getattr(module, "__name__", "")
if modname == "numpy":
if TYPE_CHECKING:
import numpy
else:
numpy = module
# TODO: convert these to Python ints when used on Python objects
SubscriptIndices.types += (numpy.int32, numpy.int64)
elif modname == "cloudpickle.cloudpickle_fast":
if TYPE_CHECKING:
from cloudpickle import cloudpickle_fast
else:
cloudpickle_fast = module
patch_cloudpickle_function_getstate(cloudpickle_fast.CloudPickler)
SubscriptIndices.types += (module.int32, module.int64)
elif modname.endswith("cloudpickle.cloudpickle_fast"):
patch_cloudpickle_function_reduce(module.CloudPickler)

@pyc.register_raw_handler(
(
Expand Down Expand Up @@ -1426,6 +1418,22 @@ def after_lambda(self, obj: Any, lambda_node_id: int, frame: FrameType, *_, **__
sym.func_def_stmt = node
self.node_id_to_loaded_symbols.setdefault(lambda_node_id, []).append(sym)

@pyc.register_raw_handler(pyc.decorator)
def decorator(self, deco: Any, *_, func_node_id: int, decorator_idx: int, **__):
def tracing_decorator(func):
flow_ = flow()
try:
flow_.deco_metadata_by_obj_id[id(func)] = (
self.ast_node_by_id[func_node_id],
decorator_idx,
)
except KeyError:
if flow_.is_dev_mode:
logger.exception("failed to lookup node for func id %s", func)
return deco(func)

return tracing_decorator

@pyc.register_raw_handler(pyc.after_stmt)
def after_stmt(self, ret_expr: Any, stmt_id: int, frame: FrameType, *_, **__):
self._saved_stmt_ret_expr = ret_expr
Expand Down
78 changes: 78 additions & 0 deletions core/ipyflow/tracing/uninstrument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import ast
import copy
import textwrap
from types import FunctionType, LambdaType
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from ipyflow.singletons import flow

if TYPE_CHECKING:
import astunparse
elif hasattr(ast, "unparse"):
astunparse = ast
else:
import astunparse


def _make_uninstrumented_function(
obj: Union[FunctionType, LambdaType], func_text: str, func_node: ast.AST
):
if isinstance(func_node, (ast.AsyncFunctionDef, ast.FunctionDef)):
func_name = func_node.name
elif isinstance(func_node, ast.Lambda):
func_name = "lambda_sym"
func_text = f"{func_name} = {func_text}"
else:
return None
local_env: Dict[str, Any] = {}
if obj.__closure__ is not None:
kwargs: Dict[str, Any] = {}
for cell_name, cell in zip(obj.__code__.co_freevars, obj.__closure__):
kwargs[cell_name] = cell.cell_contents
func_text = textwrap.indent(func_text, " ")
func_text = f"""
def _Xix_make_closure({", ".join(kwargs.keys())}):
{func_text}
return {func_name}
{func_name} = _Xix_make_closure(**kwargs)"""
local_env["kwargs"] = kwargs
try:
exec(func_text, obj.__globals__, local_env)
new_obj = local_env[func_name]
except: # noqa
return None
if hasattr(obj, "__name__") and hasattr(new_obj, "__name__"):
new_obj.__name__ = obj.__name__
if isinstance(new_obj, (FunctionType, LambdaType)):
return new_obj
else:
return None


def _get_uninstrumented_decorator(obj: Union[FunctionType, LambdaType]):
func_node, decorator_idx = flow().deco_metadata_by_obj_id.get(id(obj), (None, None))
if func_node is None:
return None
func_node = copy.deepcopy(func_node)
func_node.decorator_list = func_node.decorator_list[:decorator_idx]
func_text = astunparse.unparse(func_node)
return _make_uninstrumented_function(obj, func_text, func_node)


def uninstrument(
obj: Union[FunctionType, LambdaType]
) -> Optional[Union[FunctionType, LambdaType]]:
try:
new_obj = _get_uninstrumented_decorator(obj)
except: # noqa
new_obj = None
if new_obj is not None:
return new_obj
for alias in flow().aliases.get(id(obj), []):
if not alias.is_function and not alias.is_lambda:
continue
func_text = astunparse.unparse(alias.func_def_stmt)
new_obj = _make_uninstrumented_function(obj, func_text, alias.func_def_stmt)
if new_obj is not None:
return new_obj
return None
2 changes: 1 addition & 1 deletion core/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ install_requires =
ipython <= 7.16; python_version < '3.8'
ipywidgets
nest_asyncio
pyccolo==0.0.52
pyccolo==0.0.53
traitlets
[options.packages.find]
Expand Down

0 comments on commit ba80697

Please sign in to comment.