Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PipelineState carry tracing context proto. #6702

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 25 additions & 9 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from tfx.utils import io_utils
from tfx.utils import status as status_lib

from tfx.utils import tracecontext_pb2
from tfx.utils import tracing
from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2

Expand Down Expand Up @@ -82,6 +84,11 @@ def _wrapper(*args, **kwargs):
with contextlib.ExitStack() as stack:
if lock:
stack.enter_context(_PIPELINE_OPS_LOCK)
stack.enter_context(
tracing.LocalTraceSpan(
'TflexOrchestrator.pipeline_ops', fn.__name__
)
)

health_status = env.get_env().health_status()
if health_status.code != status_lib.Code.OK:
Expand Down Expand Up @@ -115,6 +122,7 @@ def initiate_pipeline_start(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
partial_run_option: Optional[pipeline_pb2.PartialRun] = None,
trace_proto: Optional[tracecontext_pb2.TraceContextProto] = None,
) -> pstate.PipelineState:
"""Initiates a pipeline start operation.

Expand All @@ -125,6 +133,9 @@ def initiate_pipeline_start(
pipeline: IR of the pipeline to start.
pipeline_run_metadata: Pipeline run metadata.
partial_run_option: Options for partial pipeline run.
trace_proto: Optional trace context proto to attach to the pipeline run.
This trace context will be extended from the task scheduling and the
children job executions.

Returns:
The `PipelineState` object upon success.
Expand Down Expand Up @@ -220,10 +231,11 @@ def initiate_pipeline_start(
# TODO: b/323912217 - Support putting multiple subpipeline executions
# into MLMD to handle the ForEach case.
with pstate.PipelineState.new(
mlmd_handle,
subpipeline,
pipeline_run_metadata,
reused_subpipeline_view,
mlmd_handle=mlmd_handle,
pipeline=subpipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_subpipeline_view,
trace_proto=trace_proto,
) as subpipeline_state:
# TODO: b/320535460 - The new pipeline run should not be stopped if
# there are still nodes to run in it.
Expand All @@ -250,7 +262,11 @@ def initiate_pipeline_start(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_pipeline_view,
trace_proto=trace_proto,
)


Expand Down Expand Up @@ -937,9 +953,7 @@ def resume_pipeline(
from_nodes=pipeline_nodes,
to_nodes=pipeline_nodes,
skip_nodes=previously_succeeded_nodes,
skip_snapshot_nodes=_get_previously_skipped_nodes(
latest_pipeline_view
),
skip_snapshot_nodes=_get_previously_skipped_nodes(latest_pipeline_view),
snapshot_settings=snapshot_settings,
)
except ValueError as e:
Expand All @@ -961,7 +975,9 @@ def resume_pipeline(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, reused_pipeline_view=latest_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
reused_pipeline_view=latest_pipeline_view,
)


Expand Down
69 changes: 54 additions & 15 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from tfx.utils import telemetry_utils
from google.protobuf import message
from tfx.utils import tracecontext_pb2
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Expand All @@ -74,6 +75,7 @@
_PIPELINE_EXEC_MODE = 'pipeline_exec_mode'
_PIPELINE_EXEC_MODE_SYNC = 'sync'
_PIPELINE_EXEC_MODE_ASYNC = 'async'
_PIPELINE_RUN_TRACE_PROTO = 'trace_proto'

_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()
Expand Down Expand Up @@ -485,7 +487,28 @@ def wrapper(*args, **kwargs):
return wrapper


class PipelineState:
def _set_trace_proto(
execution: metadata_store_pb2.Execution,
trace_proto: tracecontext_pb2.TraceContextProto,
):
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Pack(
trace_proto
)


def _get_trace_proto(
execution: metadata_store_pb2.Execution,
) -> Optional[tracecontext_pb2.TraceContextProto]:
if _PIPELINE_RUN_TRACE_PROTO in execution.custom_properties:
result = tracecontext_pb2.TraceContextProto()
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Unpack(
result
)
return result
return None


class PipelineState(contextlib.ExitStack):
"""Context manager class for dealing with pipeline state.

The state of a pipeline is stored as an MLMD execution and this class provides
Expand All @@ -512,6 +535,7 @@ class PipelineState:
execution_id: Id of the underlying execution in MLMD.
pipeline_uid: Unique id of the pipeline.
pipeline_run_id: pipeline_run_id in case of sync pipeline, `None` otherwise.
trace_proto: Optional TraceContextProto for tracing the pipeline run.
"""

def __init__(
Expand All @@ -521,6 +545,7 @@ def __init__(
pipeline_id: str,
):
"""Constructor. Use one of the factory methods to initialize."""
super().__init__()
self.mlmd_handle = mlmd_handle
# TODO(b/201294315): Fix self.pipeline going out of sync with the actual
# pipeline proto stored in the underlying MLMD execution in some cases.
Expand All @@ -540,13 +565,19 @@ def __init__(
self.pipeline_uid = task_lib.PipelineUid.from_pipeline_id_and_run_id(
pipeline_id, self.pipeline_run_id
)
# Since TraceContextProto is immutable, use the constructor Execution arg to
# retrieve the value instead of lazy parsing self._execution.
self._trace_proto = _get_trace_proto(execution)

# Only set within the pipeline state context.
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None

@property
def trace_proto(self) -> Optional[tracecontext_pb2.TraceContextProto]:
return self._trace_proto

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
@_synchronized
Expand All @@ -556,6 +587,7 @@ def new(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
reused_pipeline_view: Optional['PipelineView'] = None,
trace_proto: Optional[tracecontext_pb2.TraceContextProto] = None,
) -> 'PipelineState':
"""Creates a `PipelineState` object for a new pipeline.

Expand All @@ -568,6 +600,9 @@ def new(
pipeline_run_metadata: Pipeline run metadata.
reused_pipeline_view: PipelineView of the previous pipeline reused for a
partial run.
trace_proto: Optional trace context proto to attach to the pipeline run.
This trace context will be extended from the task scheduling and the
children job executions.

Returns:
A `PipelineState` object.
Expand Down Expand Up @@ -670,6 +705,8 @@ def new(
exec_properties=exec_properties,
execution_name=str(uuid.uuid4()),
)
if trace_proto is not None:
_set_trace_proto(execution, trace_proto)
if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
data_types_utils.set_metadata_value(
execution.custom_properties[_PIPELINE_RUN_ID],
Expand Down Expand Up @@ -1113,6 +1150,7 @@ def get_orchestration_options(
return env.get_env().get_orchestration_options(self.pipeline)

def __enter__(self) -> 'PipelineState':
super().__enter__()

def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
del pre_commit_execution
Expand All @@ -1121,23 +1159,24 @@ def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
for on_commit_cb in self._on_commit_callbacks:
on_commit_cb()

mlmd_execution_atomic_op_context = mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks)
execution = mlmd_execution_atomic_op_context.__enter__()
self._mlmd_execution_atomic_op_context = mlmd_execution_atomic_op_context
self._execution = execution
self._node_states_proxy = _NodeStatesProxy(execution)
@contextlib.contextmanager
def mlmd_execution():
try:
with mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks
) as execution:
yield execution
finally:
self._on_commit_callbacks.clear()

self._execution = self.enter_context(mlmd_execution())
self._node_states_proxy = _NodeStatesProxy(self._execution)
self.callback(self._cleanup)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def _cleanup(self):
self._node_states_proxy.save()
mlmd_execution_atomic_op_context = self._mlmd_execution_atomic_op_context
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()

def _check_context(self) -> None:
if self._execution is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ def schedule(self) -> task_scheduler.TaskSchedulerResult:
# was already started so the execution should already exist.
self._put_begin_node_execution()
logging.info('[Subpipeline Task Scheduler]: start subpipeline.')
pipeline_ops.initiate_pipeline_start(self.mlmd_handle,
self._sub_pipeline, None, None)
pipeline_ops.initiate_pipeline_start(
mlmd_handle=self.mlmd_handle,
pipeline=self._sub_pipeline,
trace_proto=None,
)
except status_lib.StatusNotOkError as e:
return task_scheduler.TaskSchedulerResult(status=e.status())

Expand Down
18 changes: 18 additions & 0 deletions tfx/utils/tracecontext_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracecontext module."""


class TraceContextProto:
pass
38 changes: 38 additions & 0 deletions tfx/utils/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracing module."""

import contextlib
from typing import Any

from tfx.utils import tracecontext_pb2


class TraceContext:

def is_traced(self) -> bool:
return False

def to_proto(self) -> tracecontext_pb2.TraceContextProto:
return tracecontext_pb2.TraceContextProto()


class LocalTraceSpan(contextlib.ContextDecorator):

def __init__(self, *args: Any, **kwargs: Any):
pass


def current_tracing_context() -> TraceContext:
return TraceContext()