Skip to content

Commit

Permalink
actual bugfix for tracing reenablement in module stmt
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Dec 5, 2023
1 parent f83c14c commit 72ca706
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 80 deletions.
13 changes: 12 additions & 1 deletion core/ipyflow/data_model/timestamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
import ast
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Generator, Iterable, NamedTuple, Optional, Union
from typing import (
TYPE_CHECKING,
Generator,
Iterable,
NamedTuple,
Optional,
Tuple,
Union,
)

from ipyflow.models import _TimestampContainer, cells, timestamps
from ipyflow.singletons import flow, tracer, tracer_initialized
Expand Down Expand Up @@ -74,6 +82,9 @@ def offset(
_cell_offset -= cell_offset
_stmt_offset -= stmt_offset

def as_tuple(self) -> Tuple[int, int]:
return (self.cell_num, self.stmt_num)

def __eq__(self, other) -> bool:
if other is None:
return False
Expand Down
152 changes: 73 additions & 79 deletions core/ipyflow/tracing/ipyflow_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,88 +415,80 @@ def _handle_return_transition(
self._tracked_disable_tracing(frame)
return
assert return_to_stmt is not None
if self.prev_event != pyc.exception:
if self.prev_event == pyc.exception:
# exception events are followed by return events until we hit an except clause
# no need to track dependencies in this case
if isinstance(return_to_stmt.stmt_node, ast.ClassDef):
return_to_stmt.class_scope = cast(
Namespace, self.cur_frame_original_scope
)
elif (
isinstance(trace_stmt.stmt_node, ast.Return)
or inside_anonymous_call
return
if isinstance(return_to_stmt.stmt_node, ast.ClassDef):
return_to_stmt.class_scope = cast(
Namespace, self.cur_frame_original_scope
)
elif (
isinstance(trace_stmt.stmt_node, ast.Return) or inside_anonymous_call
) and not trace_stmt.lambda_call_point_deps_done_once:
trace_stmt.lambda_call_point_deps_done_once = True
maybe_lambda_sym = flow().statement_to_func_sym.get(
id(trace_stmt.stmt_node), None
)
maybe_lambda_node = None
if maybe_lambda_sym is not None:
maybe_lambda_node = maybe_lambda_sym.func_def_stmt
if (
inside_anonymous_call
and maybe_lambda_node is not None
and isinstance(maybe_lambda_node, ast.Lambda)
):
if not trace_stmt.lambda_call_point_deps_done_once:
trace_stmt.lambda_call_point_deps_done_once = True
maybe_lambda_sym = flow().statement_to_func_sym.get(
id(trace_stmt.stmt_node), None
rvals = resolve_rval_symbols(maybe_lambda_node.body)
else:
rvals = resolve_rval_symbols(trace_stmt.stmt_node)
dsym_to_attach = None
if len(rvals) == 1:
dsym_to_attach = next(iter(rvals))
if dsym_to_attach.obj_id != id(ret):
dsym_to_attach = None
if dsym_to_attach is None and len(rvals) > 0:
dsym_to_attach = (
self.cur_frame_original_scope.upsert_symbol_for_name(
"<return_sym_%d>" % id(ret),
ret,
rvals,
trace_stmt.stmt_node,
is_anonymous=True,
)
maybe_lambda_node = None
if maybe_lambda_sym is not None:
maybe_lambda_node = maybe_lambda_sym.func_def_stmt
if (
inside_anonymous_call
and maybe_lambda_node is not None
and isinstance(maybe_lambda_node, ast.Lambda)
):
rvals = resolve_rval_symbols(maybe_lambda_node.body)
)
if dsym_to_attach is None:
return
return_to_node_id = self.call_stack.get_field(
"prev_node_id_in_cur_frame"
)
# logger.error("prev seen: %s", ast.dump(self.ast_node_by_id[return_to_node_id]))
try:
call_node_id = self.call_stack.get_field(
"lexical_call_stack"
).get_field("prev_node_id_in_cur_frame_lexical")
call_node = cast(ast.Call, self.ast_node_by_id[call_node_id])
# logger.error("prev seen outer: %s", ast.dump(self.ast_node_by_id[call_node_id]))
total_args = len(call_node.args) + len(call_node.keywords)
num_args_seen = self.call_stack.get_field("num_args_seen")
logger.warning("num args seen: %d", num_args_seen)
if total_args == num_args_seen:
return_to_node_id = call_node_id
else:
assert num_args_seen < total_args
if num_args_seen < len(call_node.args):
return_to_node_id = id(call_node.args[num_args_seen])
else:
rvals = resolve_rval_symbols(trace_stmt.stmt_node)
dsym_to_attach = None
if len(rvals) == 1:
dsym_to_attach = next(iter(rvals))
if dsym_to_attach.obj_id != id(ret):
dsym_to_attach = None
if dsym_to_attach is None and len(rvals) > 0:
dsym_to_attach = (
self.cur_frame_original_scope.upsert_symbol_for_name(
"<return_sym_%d>" % id(ret),
ret,
rvals,
trace_stmt.stmt_node,
is_anonymous=True,
)
return_to_node_id = id(
call_node.keywords[
num_args_seen - len(call_node.args)
].value
)
if dsym_to_attach is not None:
return_to_node_id = self.call_stack.get_field(
"prev_node_id_in_cur_frame"
)
# logger.error("prev seen: %s", ast.dump(self.ast_node_by_id[return_to_node_id]))
try:
call_node_id = self.call_stack.get_field(
"lexical_call_stack"
).get_field("prev_node_id_in_cur_frame_lexical")
call_node = cast(
ast.Call, self.ast_node_by_id[call_node_id]
)
# logger.error("prev seen outer: %s", ast.dump(self.ast_node_by_id[call_node_id]))
total_args = len(call_node.args) + len(
call_node.keywords
)
num_args_seen = self.call_stack.get_field(
"num_args_seen"
)
logger.warning("num args seen: %d", num_args_seen)
if total_args == num_args_seen:
return_to_node_id = call_node_id
else:
assert num_args_seen < total_args
if num_args_seen < len(call_node.args):
return_to_node_id = id(
call_node.args[num_args_seen]
)
else:
return_to_node_id = id(
call_node.keywords[
num_args_seen - len(call_node.args)
].value
)
except IndexError:
pass
# logger.error("use node %s", ast.dump(self.ast_node_by_id[return_to_node_id]))
self.node_id_to_loaded_symbols.setdefault(
return_to_node_id, []
).append(dsym_to_attach)
except IndexError:
pass
# logger.error("use node %s", ast.dump(self.ast_node_by_id[return_to_node_id]))
self.node_id_to_loaded_symbols.setdefault(return_to_node_id, []).append(
dsym_to_attach
)
finally:
if self.is_tracing_enabled:
self.call_stack.pop()
Expand Down Expand Up @@ -1437,8 +1429,9 @@ def before_stmt(self, _ret: None, stmt_id: int, frame: FrameType, *_, **__) -> N
self.after_stmt(None, prev_trace_stmt_in_cur_frame.stmt_id, frame)
self.prev_trace_stmt_in_cur_frame = trace_stmt
if not self.is_tracing_enabled and (
trace_stmt.is_module_stmt()
or self._try_reenable_tracing(frame, dry_run=True)
self._try_reenable_tracing(
frame, dry_run=True, is_module_stmt=trace_stmt.is_module_stmt()
)
):
self.after_stmt_reset_hook()
self._tracked_enable_tracing()
Expand Down Expand Up @@ -1480,11 +1473,12 @@ def _try_reenable_tracing(
frame: FrameType,
empty_stack_call_depth: Optional[int] = None,
dry_run: bool = False,
is_module_stmt: bool = False,
) -> bool:
tracing_reenabled_call_stack_length = (
self._call_stack_length_for_reenabling_tracing(frame)
)
if tracing_reenabled_call_stack_length == -1:
if not is_module_stmt and tracing_reenabled_call_stack_length == -1:
return False
assert tracing_reenabled_call_stack_length <= len(self.call_stack)
while len(self.call_stack) > tracing_reenabled_call_stack_length:
Expand Down

0 comments on commit 72ca706

Please sign in to comment.