Skip to content

Commit

Permalink
Make PipelineState carry tracing context proto.
Browse files Browse the repository at this point in the history
For now we will use no-op tracing context as a placeholder.

PiperOrigin-RevId: 613141675
  • Loading branch information
chongkong authored and tfx-copybara committed Mar 18, 2024
1 parent 2283b52 commit c4097bf
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 26 deletions.
34 changes: 25 additions & 9 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from tfx.utils import io_utils
from tfx.utils import status as status_lib

from tfx.utils import tracecontext_pb2
from tfx.utils import tracing
from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2

Expand Down Expand Up @@ -82,6 +84,11 @@ def _wrapper(*args, **kwargs):
with contextlib.ExitStack() as stack:
if lock:
stack.enter_context(_PIPELINE_OPS_LOCK)
stack.enter_context(
tracing.LocalTraceSpan(
'TflexOrchestrator.pipeline_ops', fn.__name__
)
)

health_status = env.get_env().health_status()
if health_status.code != status_lib.Code.OK:
Expand Down Expand Up @@ -115,6 +122,7 @@ def initiate_pipeline_start(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
partial_run_option: Optional[pipeline_pb2.PartialRun] = None,
trace_proto: Optional[tracecontext_pb2.TraceContextProto] = None,
) -> pstate.PipelineState:
"""Initiates a pipeline start operation.
Expand All @@ -125,6 +133,9 @@ def initiate_pipeline_start(
pipeline: IR of the pipeline to start.
pipeline_run_metadata: Pipeline run metadata.
partial_run_option: Options for partial pipeline run.
trace_proto: Optional trace context proto to attach to the pipeline run.
Task scheduling and child job executions extend this trace context.
Returns:
The `PipelineState` object upon success.
Expand Down Expand Up @@ -220,10 +231,11 @@ def initiate_pipeline_start(
# TODO: b/323912217 - Support putting multiple subpipeline executions
# into MLMD to handle the ForEach case.
with pstate.PipelineState.new(
mlmd_handle,
subpipeline,
pipeline_run_metadata,
reused_subpipeline_view,
mlmd_handle=mlmd_handle,
pipeline=subpipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_subpipeline_view,
trace_proto=trace_proto,
) as subpipeline_state:
# TODO: b/320535460 - The new pipeline run should not be stopped if
# there are still nodes to run in it.
Expand All @@ -250,7 +262,11 @@ def initiate_pipeline_start(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_pipeline_view,
trace_proto=trace_proto,
)


Expand Down Expand Up @@ -937,9 +953,7 @@ def resume_pipeline(
from_nodes=pipeline_nodes,
to_nodes=pipeline_nodes,
skip_nodes=previously_succeeded_nodes,
skip_snapshot_nodes=_get_previously_skipped_nodes(
latest_pipeline_view
),
skip_snapshot_nodes=_get_previously_skipped_nodes(latest_pipeline_view),
snapshot_settings=snapshot_settings,
)
except ValueError as e:
Expand All @@ -961,7 +975,9 @@ def resume_pipeline(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, reused_pipeline_view=latest_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
reused_pipeline_view=latest_pipeline_view,
)


Expand Down
69 changes: 54 additions & 15 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from tfx.utils import telemetry_utils
from google.protobuf import message
from tfx.utils import tracecontext_pb2
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Expand All @@ -74,6 +75,7 @@
_PIPELINE_EXEC_MODE = 'pipeline_exec_mode'
_PIPELINE_EXEC_MODE_SYNC = 'sync'
_PIPELINE_EXEC_MODE_ASYNC = 'async'
_PIPELINE_RUN_TRACE_PROTO = 'trace_proto'

_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()
Expand Down Expand Up @@ -485,7 +487,28 @@ def wrapper(*args, **kwargs):
return wrapper


class PipelineState:
def _set_trace_proto(
    execution: metadata_store_pb2.Execution,
    trace_proto: tracecontext_pb2.TraceContextProto,
):
  """Packs `trace_proto` into a custom property of `execution`.

  Args:
    execution: MLMD execution whose `custom_properties` map is updated in
      place under the `_PIPELINE_RUN_TRACE_PROTO` key.
    trace_proto: Trace context proto packed into that property's
      `proto_value` field.
  """
  trace_property = execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO]
  trace_property.proto_value.Pack(trace_proto)


def _get_trace_proto(
    execution: metadata_store_pb2.Execution,
) -> Optional[tracecontext_pb2.TraceContextProto]:
  """Reads the trace context proto stored on `execution`, if any.

  Args:
    execution: MLMD execution whose `custom_properties` map is inspected.

  Returns:
    A new `TraceContextProto` unpacked from the `_PIPELINE_RUN_TRACE_PROTO`
    custom property, or `None` when that property is not present.
  """
  if _PIPELINE_RUN_TRACE_PROTO not in execution.custom_properties:
    return None
  trace_proto = tracecontext_pb2.TraceContextProto()
  packed_value = execution.custom_properties[
      _PIPELINE_RUN_TRACE_PROTO
  ].proto_value
  # The result of `Unpack` is not checked here; `trace_proto` is returned
  # as-is regardless.
  packed_value.Unpack(trace_proto)
  return trace_proto


class PipelineState(contextlib.ExitStack):
"""Context manager class for dealing with pipeline state.
The state of a pipeline is stored as an MLMD execution and this class provides
Expand All @@ -512,6 +535,7 @@ class PipelineState:
execution_id: Id of the underlying execution in MLMD.
pipeline_uid: Unique id of the pipeline.
pipeline_run_id: pipeline_run_id in case of sync pipeline, `None` otherwise.
trace_proto: Optional TraceContextProto for tracing the pipeline run.
"""

def __init__(
Expand All @@ -521,6 +545,7 @@ def __init__(
pipeline_id: str,
):
"""Constructor. Use one of the factory methods to initialize."""
super().__init__()
self.mlmd_handle = mlmd_handle
# TODO(b/201294315): Fix self.pipeline going out of sync with the actual
# pipeline proto stored in the underlying MLMD execution in some cases.
Expand All @@ -540,13 +565,19 @@ def __init__(
self.pipeline_uid = task_lib.PipelineUid.from_pipeline_id_and_run_id(
pipeline_id, self.pipeline_run_id
)
# Since TraceContextProto is immutable, use the constructor Execution arg to
# retrieve the value instead of lazy parsing self._execution.
self._trace_proto = _get_trace_proto(execution)

# Only set within the pipeline state context.
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None

@property
def trace_proto(self) -> Optional[tracecontext_pb2.TraceContextProto]:
return self._trace_proto

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
@_synchronized
Expand All @@ -556,6 +587,7 @@ def new(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
reused_pipeline_view: Optional['PipelineView'] = None,
trace_proto: Optional[tracecontext_pb2.TraceContextProto] = None,
) -> 'PipelineState':
"""Creates a `PipelineState` object for a new pipeline.
Expand All @@ -568,6 +600,9 @@ def new(
pipeline_run_metadata: Pipeline run metadata.
reused_pipeline_view: PipelineView of the previous pipeline reused for a
partial run.
trace_proto: Optional trace context proto to attach to the pipeline run.
Task scheduling and child job executions extend this trace context.
Returns:
A `PipelineState` object.
Expand Down Expand Up @@ -670,6 +705,8 @@ def new(
exec_properties=exec_properties,
execution_name=str(uuid.uuid4()),
)
if trace_proto is not None:
_set_trace_proto(execution, trace_proto)
if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
data_types_utils.set_metadata_value(
execution.custom_properties[_PIPELINE_RUN_ID],
Expand Down Expand Up @@ -1113,6 +1150,7 @@ def get_orchestration_options(
return env.get_env().get_orchestration_options(self.pipeline)

def __enter__(self) -> 'PipelineState':
super().__enter__()

def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
del pre_commit_execution
Expand All @@ -1121,23 +1159,24 @@ def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
for on_commit_cb in self._on_commit_callbacks:
on_commit_cb()

mlmd_execution_atomic_op_context = mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks)
execution = mlmd_execution_atomic_op_context.__enter__()
self._mlmd_execution_atomic_op_context = mlmd_execution_atomic_op_context
self._execution = execution
self._node_states_proxy = _NodeStatesProxy(execution)
@contextlib.contextmanager
def mlmd_execution():
try:
with mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks
) as execution:
yield execution
finally:
self._on_commit_callbacks.clear()

self._execution = self.enter_context(mlmd_execution())
self._node_states_proxy = _NodeStatesProxy(self._execution)
self.callback(self._cleanup)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def _cleanup(self):
self._node_states_proxy.save()
mlmd_execution_atomic_op_context = self._mlmd_execution_atomic_op_context
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()

def _check_context(self) -> None:
if self._execution is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ def schedule(self) -> task_scheduler.TaskSchedulerResult:
# was already started so the execution should already exist.
self._put_begin_node_execution()
logging.info('[Subpipeline Task Scheduler]: start subpipeline.')
pipeline_ops.initiate_pipeline_start(self.mlmd_handle,
self._sub_pipeline, None, None)
pipeline_ops.initiate_pipeline_start(
mlmd_handle=self.mlmd_handle,
pipeline=self._sub_pipeline,
trace_proto=None,
)
except status_lib.StatusNotOkError as e:
return task_scheduler.TaskSchedulerResult(status=e.status())

Expand Down
18 changes: 18 additions & 0 deletions tfx/utils/tracecontext_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracecontext module."""


class TraceContextProto:
  """Empty placeholder type for a trace context; defines no fields or methods."""
38 changes: 38 additions & 0 deletions tfx/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracing module."""

import contextlib
from typing import Any

from tfx.utils import tracecontext_pb2


class TraceContext:
  """No-op tracing context: never traced and converts to an empty proto."""

  def is_traced(self) -> bool:
    """Returns whether tracing is active; this implementation always says no."""
    return False

  def to_proto(self) -> tracecontext_pb2.TraceContextProto:
    """Returns a freshly constructed `TraceContextProto`."""
    return tracecontext_pb2.TraceContextProto()


class LocalTraceSpan(contextlib.ContextDecorator):
  """No-op trace span usable as a context manager or as a decorator.

  `contextlib.ContextDecorator` only supplies `__call__`; it relies on the
  subclass to provide `__enter__` and `__exit__`. Without them, entering the
  span (via `with`, `ExitStack.enter_context`, or a decorated call) fails, so
  both are defined here as no-ops.
  """

  def __init__(self, *args: Any, **kwargs: Any):
    """Accepts and ignores any positional and keyword arguments."""
    del args, kwargs  # Unused by the no-op implementation.

  def __enter__(self) -> 'LocalTraceSpan':
    """Enters the span; does nothing and returns the span itself."""
    return self

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    """Exits the span; does nothing and never suppresses exceptions."""
    # Returning None (falsy) lets any in-flight exception propagate.
    return None


def current_tracing_context() -> TraceContext:
  """Returns a newly constructed `TraceContext` on every call."""
  return TraceContext()

0 comments on commit c4097bf

Please sign in to comment.