Skip to content

Commit

Permalink
Extend PipelineState.trace_proto to orchestrator loop and ExecNodeTask.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613141676
  • Loading branch information
chongkong authored and tfx-copybara committed Mar 13, 2024
1 parent ad74d1b commit 1e2e063
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 72 deletions.
6 changes: 1 addition & 5 deletions test_constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,4 @@
# dependencies.

# TODO(b/321609768): Remove pinned Flask-session version after resolving the issue.
Flask-session<0.6.0

#TODO(b/329181965): Remove once we migrate TFX to 2.16.
tensorflow<2.16
tensorflow-text<2.16
Flask-session<0.6.0
53 changes: 35 additions & 18 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tfx.utils import io_utils
from tfx.utils import status as status_lib

from tfx.utils import tracing
from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2

Expand Down Expand Up @@ -82,6 +83,11 @@ def _wrapper(*args, **kwargs):
with contextlib.ExitStack() as stack:
if lock:
stack.enter_context(_PIPELINE_OPS_LOCK)
stack.enter_context(
tracing.LocalTraceSpan(
'Tflex.Orchestrator.PipelineOps', fn.__name__
)
)

health_status = env.get_env().health_status()
if health_status.code != status_lib.Code.OK:
Expand Down Expand Up @@ -220,10 +226,11 @@ def initiate_pipeline_start(
# TODO: b/323912217 - Support putting multiple subpipeline executions
# into MLMD to handle the ForEach case.
with pstate.PipelineState.new(
mlmd_handle,
subpipeline,
pipeline_run_metadata,
reused_subpipeline_view,
mlmd_handle=mlmd_handle,
pipeline=subpipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_subpipeline_view,
trace_context=tracing.current_trace_context(),
) as subpipeline_state:
# TODO: b/320535460 - The new pipeline run should not be stopped if
# there are still nodes to run in it.
Expand All @@ -250,7 +257,11 @@ def initiate_pipeline_start(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_pipeline_view,
trace_context=tracing.current_trace_context(),
)


Expand Down Expand Up @@ -939,9 +950,7 @@ def resume_pipeline(
from_nodes=pipeline_nodes,
to_nodes=pipeline_nodes,
skip_nodes=previously_succeeded_nodes,
skip_snapshot_nodes=_get_previously_skipped_nodes(
latest_pipeline_view
),
skip_snapshot_nodes=_get_previously_skipped_nodes(latest_pipeline_view),
snapshot_settings=snapshot_settings,
)
except ValueError as e:
Expand All @@ -961,7 +970,10 @@ def resume_pipeline(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, reused_pipeline_view=latest_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
reused_pipeline_view=latest_pipeline_view,
trace_context=tracing.current_trace_context(),
)


Expand Down Expand Up @@ -1267,9 +1279,9 @@ def orchestrate(
logging.info('No active pipelines to run.')
return False

active_pipeline_states = []
stop_initiated_pipeline_states = []
update_initiated_pipeline_states = []
active_pipeline_states: list[pstate.PipelineState] = []
stop_initiated_pipeline_states: list[pstate.PipelineState] = []
update_initiated_pipeline_states: list[pstate.PipelineState] = []
for pipeline_state in pipeline_states:
with pipeline_state:
if pipeline_state.is_stop_initiated():
Expand Down Expand Up @@ -1357,12 +1369,17 @@ def orchestrate(
for pipeline_state in active_pipeline_states:
logging.info('Orchestrating pipeline: %s', pipeline_state.pipeline_uid)
try:
_orchestrate_active_pipeline(
mlmd_connection_manager,
task_queue,
service_job_manager,
pipeline_state,
)
with tracing.LocalTraceSpan(
'Tflex.Orchestrator.PipelineOps',
'_orchestrate_active_pipeline',
parent_proto=pipeline_state.trace_proto,
):
_orchestrate_active_pipeline(
mlmd_connection_manager,
task_queue,
service_job_manager,
pipeline_state,
)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating active pipeline %s',
Expand Down
70 changes: 55 additions & 15 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

from tfx.utils import telemetry_utils
from google.protobuf import message
from tfx.utils import tracecontext_pb2
from tfx.utils import tracing
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Expand All @@ -74,6 +76,7 @@
_PIPELINE_EXEC_MODE = 'pipeline_exec_mode'
_PIPELINE_EXEC_MODE_SYNC = 'sync'
_PIPELINE_EXEC_MODE_ASYNC = 'async'
_PIPELINE_RUN_TRACE_PROTO = 'trace_proto'

_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()
Expand Down Expand Up @@ -485,7 +488,28 @@ def wrapper(*args, **kwargs):
return wrapper


class PipelineState:
def _set_trace_proto(
execution: metadata_store_pb2.Execution,
trace_proto: tracecontext_pb2.TraceContextProto,
):
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Pack(
trace_proto
)


def _get_trace_proto(
execution: metadata_store_pb2.Execution,
) -> Optional[tracecontext_pb2.TraceContextProto]:
if _PIPELINE_RUN_TRACE_PROTO in execution.custom_properties:
result = tracecontext_pb2.TraceContextProto()
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Unpack(
result
)
return result
return None


class PipelineState(contextlib.ExitStack):
"""Context manager class for dealing with pipeline state.
The state of a pipeline is stored as an MLMD execution and this class provides
Expand All @@ -512,6 +536,7 @@ class PipelineState:
execution_id: Id of the underlying execution in MLMD.
pipeline_uid: Unique id of the pipeline.
pipeline_run_id: pipeline_run_id in case of sync pipeline, `None` otherwise.
trace_proto: Optional TraceContextProto for tracing the pipeline run.
"""

def __init__(
Expand All @@ -521,6 +546,7 @@ def __init__(
pipeline_id: str,
):
"""Constructor. Use one of the factory methods to initialize."""
super().__init__()
self.mlmd_handle = mlmd_handle
# TODO(b/201294315): Fix self.pipeline going out of sync with the actual
# pipeline proto stored in the underlying MLMD execution in some cases.
Expand All @@ -540,13 +566,19 @@ def __init__(
self.pipeline_uid = task_lib.PipelineUid.from_pipeline_id_and_run_id(
pipeline_id, self.pipeline_run_id
)
# Since TraceContextProto is immutable, use the constructor Execution arg to
# retrieve the value instead of lazy parsing self._execution.
self._trace_proto = _get_trace_proto(execution)

# Only set within the pipeline state context.
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None

@property
def trace_proto(self) -> Optional[tracecontext_pb2.TraceContextProto]:
return self._trace_proto

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
@_synchronized
Expand All @@ -556,6 +588,7 @@ def new(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
reused_pipeline_view: Optional['PipelineView'] = None,
trace_context: Optional[tracing.TraceContext] = None,
) -> 'PipelineState':
"""Creates a `PipelineState` object for a new pipeline.
Expand All @@ -568,6 +601,9 @@ def new(
pipeline_run_metadata: Pipeline run metadata.
reused_pipeline_view: PipelineView of the previous pipeline reused for a
partial run.
trace_context: Optional trace context to attach to the pipeline run. This
trace context will be extended from the child job during the pipeline
run.
Returns:
A `PipelineState` object.
Expand Down Expand Up @@ -670,6 +706,8 @@ def new(
exec_properties=exec_properties,
execution_name=str(uuid.uuid4()),
)
if trace_context is not None and trace_context.is_traced():
_set_trace_proto(execution, trace_context.to_proto())
if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
data_types_utils.set_metadata_value(
execution.custom_properties[_PIPELINE_RUN_ID],
Expand Down Expand Up @@ -1113,6 +1151,7 @@ def get_orchestration_options(
return env.get_env().get_orchestration_options(self.pipeline)

def __enter__(self) -> 'PipelineState':
super().__enter__()

def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
del pre_commit_execution
Expand All @@ -1121,23 +1160,24 @@ def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
for on_commit_cb in self._on_commit_callbacks:
on_commit_cb()

mlmd_execution_atomic_op_context = mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks)
execution = mlmd_execution_atomic_op_context.__enter__()
self._mlmd_execution_atomic_op_context = mlmd_execution_atomic_op_context
self._execution = execution
self._node_states_proxy = _NodeStatesProxy(execution)
@contextlib.contextmanager
def mlmd_execution():
try:
with mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks
) as execution:
yield execution
finally:
self._on_commit_callbacks.clear()

self._execution = self.enter_context(mlmd_execution())
self._node_states_proxy = _NodeStatesProxy(self._execution)
self.callback(self._cleanup)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def _cleanup(self):
self._node_states_proxy.save()
mlmd_execution_atomic_op_context = self._mlmd_execution_atomic_op_context
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()

def _check_context(self) -> None:
if self._execution is None:
Expand Down

0 comments on commit 1e2e063

Please sign in to comment.