Skip to content

Commit

Permalink
simple patching framework
Browse files Browse the repository at this point in the history
  • Loading branch information
smacke committed Feb 25, 2024
1 parent 3e83336 commit 2012da6
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 29 deletions.
42 changes: 42 additions & 0 deletions core/ipyflow/patches/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
import logging
import sys
from types import ModuleType
from typing import Callable, Set, Tuple

from ipyflow.patches.cloudpickle_patch import patch_cloudpickle_function_reduce
from ipyflow.patches.pyspark_patch import patch_pyspark_udf
from ipyflow.singletons import flow

logger = logging.getLogger(__name__)

# Readable names for the two callables that make up a registry entry.
_ModnamePredicate = Callable[[str], bool]
_ModulePatch = Callable[[ModuleType], None]

# Registry of (predicate, patch) pairs: a patch is run on a module when its
# predicate returns True for that module's name.
_predicate_patch_pairs: Tuple[Tuple[_ModnamePredicate, _ModulePatch], ...] = (
    (
        lambda modname: modname.endswith("cloudpickle.cloudpickle_fast"),
        patch_cloudpickle_function_reduce,
    ),
    (
        lambda modname: modname == "pyspark.sql.udf",
        patch_pyspark_udf,
    ),
)

# Names of modules for which at least one patch has been applied successfully.
_patched_modules: Set[str] = set()


def apply_patches(modname: str, module: ModuleType) -> None:
    """Run every registered patch whose predicate accepts ``modname``.

    Modules already recorded in ``_patched_modules`` are skipped. A module is
    recorded only after a patch has run on it without raising. Exceptions
    raised by a predicate or a patch are swallowed; they are logged only when
    the flow singleton reports dev mode.

    Args:
        modname: Name of the module being considered.
        module: The module object handed to each matching patch.
    """
    already_patched = modname in _patched_modules
    if already_patched:
        return

    flow_instance = flow()
    for matches, do_patch in _predicate_patch_pairs:
        try:
            if not matches(modname):
                continue
            do_patch(module)
            _patched_modules.add(modname)
        except Exception:  # noqa
            # Best effort: a failing patch must not break the caller.
            if flow_instance.is_dev_mode:
                logger.exception("Failed to apply patch to module %s", modname)


def patch_all() -> None:
    """Apply the registered patches to every entry currently in ``sys.modules``."""
    # Iterate over a snapshot so the loop is unaffected if ``sys.modules``
    # changes while patches are being applied.
    loaded_modules = list(sys.modules.items())
    for name_and_module in loaded_modules:
        apply_patches(*name_and_module)
23 changes: 23 additions & 0 deletions core/ipyflow/patches/cloudpickle_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from types import FunctionType, LambdaType, ModuleType
from typing import TYPE_CHECKING, Type, Union

from ipyflow.tracing.uninstrument import uninstrument

if TYPE_CHECKING:
from cloudpickle.cloudpickle_fast import CloudPickler


def patch_cloudpickle_function_reduce(module: ModuleType) -> None:
    """Wrap ``module.CloudPickler._function_reduce`` to reduce uninstrumented functions.

    The replacement passes each function through ``uninstrument`` before
    delegating to the original reducer; when ``uninstrument`` returns ``None``
    the function is forwarded unchanged.

    The wrapper keeps the original method's metadata (``__name__``,
    ``__doc__``, ``__wrapped__``), and patching is idempotent: calling this
    again on an already-patched class leaves it untouched rather than
    stacking a second wrapper.

    Args:
        module: A module object exposing a ``CloudPickler`` class.
    """
    import functools

    pickler_cls: Type["CloudPickler"] = module.CloudPickler
    _function_reduce = pickler_cls._function_reduce
    if getattr(_function_reduce, "_ipyflow_patched", False):
        # Already wrapped by a previous call; wrapping again would only add
        # a redundant layer of indirection.
        return

    @functools.wraps(_function_reduce)
    def _patched_function_reduce(
        self_: "CloudPickler", obj: Union[FunctionType, LambdaType]
    ) -> object:
        uninstrumented = uninstrument(obj)
        # Hand back whatever the original reducer produces.
        return _function_reduce(
            self_, obj if uninstrumented is None else uninstrumented
        )

    # Marker consulted by the idempotence check above.
    setattr(_patched_function_reduce, "_ipyflow_patched", True)
    pickler_cls._function_reduce = _patched_function_reduce
21 changes: 21 additions & 0 deletions core/ipyflow/patches/pyspark_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
from types import ModuleType
from typing import TYPE_CHECKING, Type

from ipyflow.tracing.uninstrument import uninstrument

if TYPE_CHECKING:
from pyspark.sql.udf import UserDefinedFunction


def patch_pyspark_udf(module: ModuleType) -> None:
    """Wrap ``module.UserDefinedFunction.__init__`` to store uninstrumented functions.

    The replacement initializer passes ``func`` through ``uninstrument``
    before delegating to the original ``__init__``; when ``uninstrument``
    returns ``None`` the function is forwarded unchanged. All other
    positional and keyword arguments are passed through untouched.

    The wrapper keeps the original initializer's metadata, so its signature
    and docstring remain visible to introspection, and patching is
    idempotent: calling this again on an already-patched class leaves it
    untouched rather than stacking a second wrapper.

    Args:
        module: A module object exposing a ``UserDefinedFunction`` class.
    """
    import functools

    udf_cls: Type["UserDefinedFunction"] = module.UserDefinedFunction
    udf_cls_init = udf_cls.__init__
    if getattr(udf_cls_init, "_ipyflow_patched", False):
        # Already wrapped by a previous call; nothing more to do.
        return

    @functools.wraps(udf_cls_init)
    def _patched_init(self_: "UserDefinedFunction", func, *args, **kwargs) -> None:
        uninstrumented = uninstrument(func)
        return udf_cls_init(
            self_, func if uninstrumented is None else uninstrumented, *args, **kwargs
        )

    # Marker consulted by the idempotence check above.
    setattr(_patched_init, "_ipyflow_patched", True)
    udf_cls.__init__ = _patched_init
24 changes: 0 additions & 24 deletions core/ipyflow/tracing/external_calls/cloudpickle_patch.py

This file was deleted.

7 changes: 2 additions & 5 deletions core/ipyflow/tracing/ipyflow_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@
from ipyflow.data_model.symbol import Symbol
from ipyflow.data_model.timestamp import Timestamp
from ipyflow.models import symbols as api_symbols
from ipyflow.patches import apply_patches
from ipyflow.singletons import SingletonBaseTracer, flow, shell
from ipyflow.tracing.external_calls import resolve_external_call
from ipyflow.tracing.external_calls.base_handlers import ExternalCallHandler
from ipyflow.tracing.external_calls.cloudpickle_patch import (
patch_cloudpickle_function_reduce,
)
from ipyflow.tracing.flow_ast_rewriter import DataflowAstRewriter
from ipyflow.tracing.symbol_resolver import resolve_rval_symbols
from ipyflow.tracing.utils import match_container_obj_or_namespace_with_literal_nodes
Expand Down Expand Up @@ -823,11 +821,10 @@ def _clear_info_and_maybe_lookup_or_create_complex_symbol(
def after_import(self, *_, module: ModuleType, **__):
    # Runs import-time hooks for ``module``: registers handlers for it, applies
    # any matching patches, then performs module-specific setup.
    # NOTE(review): this text comes from a rendered diff in which added and
    # removed lines are interleaved; the trailing ``elif`` branch appears to be
    # the removed side of the hunk -- confirm against the actual file.
    compile_and_register_handlers_for_module(module)
    # Fall back to an empty name for objects that lack ``__name__``.
    modname = getattr(module, "__name__", "")
    apply_patches(modname, module)
    if modname == "numpy":
        # TODO: convert these to Python ints when used on Python objects
        SubscriptIndices.types += (module.int32, module.int64)
    elif modname.endswith("cloudpickle.cloudpickle_fast"):
        patch_cloudpickle_function_reduce(module.CloudPickler)

@pyc.register_raw_handler(
(
Expand Down

0 comments on commit 2012da6

Please sign in to comment.