Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend PipelineState.trace_proto to orchestrator loop and ExecNodeTask. #6703

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 1 addition & 5 deletions test_constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,4 @@
# dependencies.

# TODO(b/321609768): Remove pinned Flask-session version after resolving the issue.
Flask-session<0.6.0

#TODO(b/329181965): Remove once we migrate TFX to 2.16.
tensorflow<2.16
tensorflow-text<2.16
Flask-session<0.6.0
53 changes: 35 additions & 18 deletions tfx/orchestration/experimental/core/pipeline_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tfx.utils import io_utils
from tfx.utils import status as status_lib

from tfx.utils import tracing
from ml_metadata import errors as mlmd_errors
from ml_metadata.proto import metadata_store_pb2

Expand Down Expand Up @@ -82,6 +83,11 @@ def _wrapper(*args, **kwargs):
with contextlib.ExitStack() as stack:
if lock:
stack.enter_context(_PIPELINE_OPS_LOCK)
stack.enter_context(
tracing.LocalTraceSpan(
'Tflex.Orchestrator.PipelineOps', fn.__name__
)
)

health_status = env.get_env().health_status()
if health_status.code != status_lib.Code.OK:
Expand Down Expand Up @@ -220,10 +226,11 @@ def initiate_pipeline_start(
# TODO: b/323912217 - Support putting multiple subpipeline executions
# into MLMD to handle the ForEach case.
with pstate.PipelineState.new(
mlmd_handle,
subpipeline,
pipeline_run_metadata,
reused_subpipeline_view,
mlmd_handle=mlmd_handle,
pipeline=subpipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_subpipeline_view,
trace_context=tracing.current_trace_context(),
) as subpipeline_state:
# TODO: b/320535460 - The new pipeline run should not be stopped if
# there are still nodes to run in it.
Expand All @@ -250,7 +257,11 @@ def initiate_pipeline_start(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, pipeline_run_metadata, reused_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
pipeline_run_metadata=pipeline_run_metadata,
reused_pipeline_view=reused_pipeline_view,
trace_context=tracing.current_trace_context(),
)


Expand Down Expand Up @@ -939,9 +950,7 @@ def resume_pipeline(
from_nodes=pipeline_nodes,
to_nodes=pipeline_nodes,
skip_nodes=previously_succeeded_nodes,
skip_snapshot_nodes=_get_previously_skipped_nodes(
latest_pipeline_view
),
skip_snapshot_nodes=_get_previously_skipped_nodes(latest_pipeline_view),
snapshot_settings=snapshot_settings,
)
except ValueError as e:
Expand All @@ -961,7 +970,10 @@ def resume_pipeline(
)

return pstate.PipelineState.new(
mlmd_handle, pipeline, reused_pipeline_view=latest_pipeline_view
mlmd_handle=mlmd_handle,
pipeline=pipeline,
reused_pipeline_view=latest_pipeline_view,
trace_context=tracing.current_trace_context(),
)


Expand Down Expand Up @@ -1267,9 +1279,9 @@ def orchestrate(
logging.info('No active pipelines to run.')
return False

active_pipeline_states = []
stop_initiated_pipeline_states = []
update_initiated_pipeline_states = []
active_pipeline_states: list[pstate.PipelineState] = []
stop_initiated_pipeline_states: list[pstate.PipelineState] = []
update_initiated_pipeline_states: list[pstate.PipelineState] = []
for pipeline_state in pipeline_states:
with pipeline_state:
if pipeline_state.is_stop_initiated():
Expand Down Expand Up @@ -1357,12 +1369,17 @@ def orchestrate(
for pipeline_state in active_pipeline_states:
logging.info('Orchestrating pipeline: %s', pipeline_state.pipeline_uid)
try:
_orchestrate_active_pipeline(
mlmd_connection_manager,
task_queue,
service_job_manager,
pipeline_state,
)
with tracing.LocalTraceSpan(
'Tflex.Orchestrator.PipelineOps',
'_orchestrate_active_pipeline',
parent_proto=pipeline_state.trace_proto,
):
_orchestrate_active_pipeline(
mlmd_connection_manager,
task_queue,
service_job_manager,
pipeline_state,
)
except Exception as e: # pylint: disable=broad-except
logging.exception(
'Exception raised while orchestrating active pipeline %s',
Expand Down
70 changes: 55 additions & 15 deletions tfx/orchestration/experimental/core/pipeline_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@

from tfx.utils import telemetry_utils
from google.protobuf import message
from tfx.utils import tracecontext_pb2
from tfx.utils import tracing
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2

Expand All @@ -74,6 +76,7 @@
_PIPELINE_EXEC_MODE = 'pipeline_exec_mode'
_PIPELINE_EXEC_MODE_SYNC = 'sync'
_PIPELINE_EXEC_MODE_ASYNC = 'async'
_PIPELINE_RUN_TRACE_PROTO = 'trace_proto'

_last_state_change_time_secs = -1.0
_state_change_time_lock = threading.Lock()
Expand Down Expand Up @@ -485,7 +488,28 @@ def wrapper(*args, **kwargs):
return wrapper


class PipelineState:
def _set_trace_proto(
execution: metadata_store_pb2.Execution,
trace_proto: tracecontext_pb2.TraceContextProto,
):
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Pack(
trace_proto
)


def _get_trace_proto(
execution: metadata_store_pb2.Execution,
) -> Optional[tracecontext_pb2.TraceContextProto]:
if _PIPELINE_RUN_TRACE_PROTO in execution.custom_properties:
result = tracecontext_pb2.TraceContextProto()
execution.custom_properties[_PIPELINE_RUN_TRACE_PROTO].proto_value.Unpack(
result
)
return result
return None


class PipelineState(contextlib.ExitStack):
"""Context manager class for dealing with pipeline state.

The state of a pipeline is stored as an MLMD execution and this class provides
Expand All @@ -512,6 +536,7 @@ class PipelineState:
execution_id: Id of the underlying execution in MLMD.
pipeline_uid: Unique id of the pipeline.
pipeline_run_id: pipeline_run_id in case of sync pipeline, `None` otherwise.
trace_proto: Optional TraceContextProto for tracing the pipeline run.
"""

def __init__(
Expand All @@ -521,6 +546,7 @@ def __init__(
pipeline_id: str,
):
"""Constructor. Use one of the factory methods to initialize."""
super().__init__()
self.mlmd_handle = mlmd_handle
# TODO(b/201294315): Fix self.pipeline going out of sync with the actual
# pipeline proto stored in the underlying MLMD execution in some cases.
Expand All @@ -540,13 +566,19 @@ def __init__(
self.pipeline_uid = task_lib.PipelineUid.from_pipeline_id_and_run_id(
pipeline_id, self.pipeline_run_id
)
# Since TraceContextProto is immutable, use the constructor Execution arg to
# retrieve the value instead of lazy parsing self._execution.
self._trace_proto = _get_trace_proto(execution)

# Only set within the pipeline state context.
self._mlmd_execution_atomic_op_context = None
self._execution: Optional[metadata_store_pb2.Execution] = None
self._on_commit_callbacks: List[Callable[[], None]] = []
self._node_states_proxy: Optional[_NodeStatesProxy] = None

@property
def trace_proto(self) -> Optional[tracecontext_pb2.TraceContextProto]:
return self._trace_proto

@classmethod
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
@_synchronized
Expand All @@ -556,6 +588,7 @@ def new(
pipeline: pipeline_pb2.Pipeline,
pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
reused_pipeline_view: Optional['PipelineView'] = None,
trace_context: Optional[tracing.TraceContext] = None,
) -> 'PipelineState':
"""Creates a `PipelineState` object for a new pipeline.

Expand All @@ -568,6 +601,9 @@ def new(
pipeline_run_metadata: Pipeline run metadata.
reused_pipeline_view: PipelineView of the previous pipeline reused for a
partial run.
trace_context: Optional trace context to attach to the pipeline run. This
trace context will be extended from the child job during the pipeline
run.

Returns:
A `PipelineState` object.
Expand Down Expand Up @@ -670,6 +706,8 @@ def new(
exec_properties=exec_properties,
execution_name=str(uuid.uuid4()),
)
if trace_context is not None and trace_context.is_traced():
_set_trace_proto(execution, trace_context.to_proto())
if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
data_types_utils.set_metadata_value(
execution.custom_properties[_PIPELINE_RUN_ID],
Expand Down Expand Up @@ -1113,6 +1151,7 @@ def get_orchestration_options(
return env.get_env().get_orchestration_options(self.pipeline)

def __enter__(self) -> 'PipelineState':
super().__enter__()

def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
del pre_commit_execution
Expand All @@ -1121,23 +1160,24 @@ def _run_on_commit_callbacks(pre_commit_execution, post_commit_execution):
for on_commit_cb in self._on_commit_callbacks:
on_commit_cb()

mlmd_execution_atomic_op_context = mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks)
execution = mlmd_execution_atomic_op_context.__enter__()
self._mlmd_execution_atomic_op_context = mlmd_execution_atomic_op_context
self._execution = execution
self._node_states_proxy = _NodeStatesProxy(execution)
@contextlib.contextmanager
def mlmd_execution():
try:
with mlmd_state.mlmd_execution_atomic_op(
self.mlmd_handle, self.execution_id, _run_on_commit_callbacks
) as execution:
yield execution
finally:
self._on_commit_callbacks.clear()

self._execution = self.enter_context(mlmd_execution())
self._node_states_proxy = _NodeStatesProxy(self._execution)
self.callback(self._cleanup)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def _cleanup(self):
self._node_states_proxy.save()
mlmd_execution_atomic_op_context = self._mlmd_execution_atomic_op_context
self._mlmd_execution_atomic_op_context = None
self._execution = None
try:
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
finally:
self._on_commit_callbacks.clear()

def _check_context(self) -> None:
if self._execution is None:
Expand Down