Skip to content

Commit

Permalink
refactor memoization data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Nov 29, 2023
1 parent 0ca6fc2 commit 0914a30
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 48 deletions.
33 changes: 19 additions & 14 deletions core/ipyflow/data_model/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ipyflow.analysis.resolved_symbols import ResolvedSymbol
from ipyflow.config import ExecutionSchedule, FlowDirection, Interface
from ipyflow.data_model.timestamp import Timestamp
from ipyflow.memoization import MemoizedCellExecution, MemoizedInput, MemoizedOutput
from ipyflow.models import _CodeCellContainer, cells, statements
from ipyflow.singletons import flow, shell
from ipyflow.slicing.context import SlicingContext
Expand Down Expand Up @@ -118,13 +119,7 @@ def __init__(
self._placeholder_id = placeholder_id
self.is_memoized = is_memoized
self.skipped_due_to_memoization_ctr = -1
self.memoized_params: List[
Tuple[
Dict["Symbol", Tuple[int, Any]],
Dict["Symbol", Tuple[Any, Timestamp]],
int,
]
] = []
self.memoized_executions: List[MemoizedCellExecution] = []
self._force_tracking = force_tracking

@property
Expand Down Expand Up @@ -284,23 +279,33 @@ def get_reactive_ids_for_tag(cls, tag: str) -> Set[IdType]:
return cls._reactive_cells_by_tag.get(tag, set())

def _maybe_memoize_params(self) -> None:
memoized_params: Dict["Symbol", Tuple[int, Any]] = {}
inputs: Dict["Symbol", MemoizedInput] = {}
for _ in SlicingContext.iter_slicing_contexts():
for edges in self.parents.values():
for sym in edges:
if sym.timestamp.cell_num >= self.cell_ctr:
return
if sym not in memoized_params:
memoized_params[sym] = (
sym.timestamp.cell_num,
if sym not in inputs:
inputs[sym] = MemoizedInput(
sym,
sym.timestamp,
sym.make_memoize_comparable()[0],
)
outputs = {}
outputs: Dict["Symbol", MemoizedOutput] = {}
for sym in flow().updated_symbols:
if not sym.is_user_accessible or not sym.containing_scope.is_global:
continue
outputs[sym] = (sym.obj, sym.timestamp_excluding_ns_descendents)
self.memoized_params.append((memoized_params, outputs, self.cell_ctr))
outputs[sym] = MemoizedOutput(
sym, sym.timestamp_excluding_ns_descendents, sym.obj
)
self.memoized_executions.append(
MemoizedCellExecution(
self.executed_content,
list(inputs.values()),
list(outputs.values()),
self.cell_ctr,
)
)

@classmethod
def create_and_track(
Expand Down
25 changes: 25 additions & 0 deletions core/ipyflow/memoization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from typing import TYPE_CHECKING, Any, List, NamedTuple

if TYPE_CHECKING:
from ipyflow.data_model.symbol import Symbol
from ipyflow.data_model.timestamp import Timestamp


class MemoizedInput(NamedTuple):
    """Record of one symbol a memoized cell execution depended on.

    Instances are plain tuples, so callers may unpack them positionally as
    ``(symbol, ts_at_execution, comparable)``; keep the field order stable.
    """

    # The symbol that was used as an input to the cell.
    symbol: "Symbol"
    # The symbol's timestamp captured when the record was created; a later
    # lookup can compare it with the symbol's current timestamp.
    ts_at_execution: "Timestamp"
    # Snapshot used for equality checks against the symbol's current state
    # (callers store the first element of ``make_memoize_comparable()``).
    comparable: Any


class MemoizedOutput(NamedTuple):
    """Record of one symbol a memoized cell execution produced or updated.

    Instances are plain tuples, so callers may unpack them positionally as
    ``(symbol, ts_at_execution, value)``; keep the field order stable.
    """

    # The symbol that was updated by the cell.
    symbol: "Symbol"
    # Timestamp associated with the symbol when the record was created.
    ts_at_execution: "Timestamp"
    # The object the symbol referred to after the cell ran; it is restored
    # onto the symbol when the memoized result is reused.
    value: Any


class MemoizedCellExecution(NamedTuple):
    """Everything recorded about one execution of a memoized cell.

    Instances are plain tuples and are unpacked positionally by consumers as
    ``(content_at_execution, inputs, outputs, cell_ctr)``, so the field order
    is part of the interface.
    """

    # Source text of the cell as it was executed; a stored execution is only
    # reusable for a cell whose content matches.
    content_at_execution: str
    # One entry per symbol the execution depended on.
    inputs: List[MemoizedInput]
    # One entry per symbol the execution updated.
    outputs: List[MemoizedOutput]
    # Execution counter of the cell run that produced this record.
    cell_ctr: int
68 changes: 34 additions & 34 deletions core/ipyflow/shell/interactiveshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,39 +482,39 @@ def before_run_cell(
if not flow_.mut_settings.dataflow_enabled:
return None

if cell.is_memoized:
prev_cell = cell.prev_cell
identical_result_ctr: Optional[int] = None
memoized_outputs = {}
if (
prev_cell is not None
and prev_cell.executed_content == cell.executed_content
):
for params, outputs, ctr in prev_cell.memoized_params:
for param, (ts, comparable_obj) in params.items():
if param.timestamp.cell_num == ts:
continue
elif comparable_obj is Symbol.NULL:
break
comparable_param, eq = param.make_memoize_comparable()
if eq is None or not eq(comparable_param, comparable_obj):
break
else:
identical_result_ctr = ctr
memoized_outputs = outputs
identical_result_ctr: Optional[int] = None
memoized_outputs = []

prev_cell = cell.prev_cell
if cell.is_memoized and prev_cell is not None:
for content, inputs, outputs, ctr in prev_cell.memoized_executions:
if content != cell.executed_content:
continue
for (sym, in_ts, comparable) in inputs:
if sym.timestamp.cell_num == in_ts.cell_num:
continue
elif comparable is Symbol.NULL:
break
if identical_result_ctr is not None:
cell.skipped_due_to_memoization_ctr = identical_result_ctr
print_purple(
"Detected identical symbol usages to previous run; reusing memoized result..."
)
for sym, (obj, mem_ts) in memoized_outputs.items():
if sym.obj is not obj:
self.user_ns[sym.name] = obj
sym.update_obj_ref(obj)
new_updated_ts = Timestamp(self.cell_counter(), mem_ts.stmt_num)
sym.refresh(timestamp=new_updated_ts)
return f"Out.get({identical_result_ctr})"
current_comp, eq = sym.make_memoize_comparable()
if eq is None or not eq(current_comp, comparable):
break
else:
identical_result_ctr = ctr
memoized_outputs = outputs
break

if identical_result_ctr is not None:
cell.skipped_due_to_memoization_ctr = identical_result_ctr
print_purple(
"Detected identical symbol usages to previous run; reusing memoized result..."
)
for (sym, out_ts, value) in memoized_outputs:
if sym.obj is not value:
self.user_ns[sym.name] = value
sym.update_obj_ref(value)
new_updated_ts = Timestamp(self.cell_counter(), out_ts.stmt_num)
sym.refresh(timestamp=new_updated_ts)
return f"Out.get({identical_result_ctr})"

# Stage 1: Precheck.
if DataflowTracer in self.registered_tracers:
Expand Down Expand Up @@ -546,8 +546,8 @@ def _handle_memoization(self) -> None:
prev_cell = cell.prev_cell
if prev_cell is not None:
if cell.executed_content == prev_cell.executed_content and cell.is_memoized:
cell.memoized_params = prev_cell.memoized_params
prev_cell.memoized_params = []
cell.memoized_executions = prev_cell.memoized_executions
prev_cell.memoized_executions = []
if cell.skipped_due_to_memoization_ctr > 0:
prev_cell = Cell.at_counter(cell.skipped_due_to_memoization_ctr)
assert prev_cell is not None
Expand Down

0 comments on commit 0914a30

Please sign in to comment.