no-op #6760

Merged (1 commit) on May 21, 2024

30 changes: 30 additions & 0 deletions tfx/orchestration/datahub_utils.py
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to log Tflex/MLMD entities."""
from typing import Optional

from tfx.orchestration.experimental.core import task as task_lib
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2


def log_node_execution(
execution: metadata_store_pb2.Execution,
task: Optional[task_lib.ExecNodeTask] = None,
output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
):
"""Logs a Tflex node execution and its input/output artifacts."""
del execution, task, output_artifacts
return
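
Note: the hook body above is intentionally empty; the `del` statement only marks the arguments as deliberately unused. As a rough sketch (not part of this PR, with stdlib logging standing in for whatever Datahub client eventually backs this hook), a concrete implementation could look like:

# Hypothetical sketch; nothing below is introduced by this PR.
import logging
from typing import Optional

from tfx.orchestration.experimental.core import task as task_lib
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2


def log_node_execution(
    execution: metadata_store_pb2.Execution,
    task: Optional[task_lib.ExecNodeTask] = None,
    output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
):
  """Logs a Tflex node execution and its input/output artifacts."""
  node_id = task.node_uid.node_id if task else None
  num_outputs = sum(len(v) for v in (output_artifacts or {}).values())
  # A real implementation would emit to Datahub; this sketch only logs.
  logging.info(
      'Node %s completed execution %d with %d output artifact(s).',
      node_id, execution.id, num_outputs,
  )
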
6 changes: 4 additions & 2 deletions tfx/orchestration/experimental/core/post_execution_utils.py
@@ -92,7 +92,8 @@ def _update_state(
execution_id=task.execution_id,
contexts=task.contexts,
output_artifacts=task.output_artifacts,
executor_output=executor_output)
executor_output=executor_output,
task=task)
garbage_collection.run_garbage_collection_for_node(mlmd_handle,
task.node_uid,
task.get_node())
@@ -125,7 +126,8 @@ def _update_state(
mlmd_handle,
execution_id=task.execution_id,
contexts=task.contexts,
output_artifacts=output_artifacts)
output_artifacts=output_artifacts,
task=task)
elif isinstance(result.output, ts.ResolverNodeOutput):
resolved_input_artifacts = result.output.resolved_input_artifacts
# TODO(b/262040844): Instead of directly using the context manager here, we
25 changes: 24 additions & 1 deletion tfx/orchestration/portable/execution_publish_utils.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Portable library for registering and publishing executions."""

import logging
from typing import Mapping, Optional, Sequence
import uuid

from tfx import types
from tfx.orchestration import data_types_utils
from tfx.orchestration import datahub_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.portable import merge_utils
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.proto.orchestration import execution_result_pb2
@@ -75,6 +79,7 @@ def publish_succeeded_execution(
contexts: Sequence[metadata_store_pb2.Context],
output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
executor_output: Optional[execution_result_pb2.ExecutorOutput] = None,
task: Optional[task_lib.ExecNodeTask] = None,
) -> tuple[
Optional[typing_utils.ArtifactMultiMap],
metadata_store_pb2.Execution,
@@ -85,6 +90,9 @@
will also merge the executor produced info into system generated output
artifacts. The `last_known_state` of the execution will be changed to
`COMPLETE` and the output artifacts will be marked as `LIVE`.
This method will also publish the execution and its input/output artifacts to
Datahub in best-effort mode if `enable_datahub_logging` in
TflexProjectPlatformConfig is set to True.

Args:
metadata_handle: A handler to access MLMD.
@@ -95,11 +103,12 @@
event with type OUTPUT.
executor_output: Executor outputs. `executor_output.output_artifacts` will
be used to update system-generated output artifacts passed in through
`output_artifacts` arg. There are three contraints to the update: 1. The
`output_artifacts` arg. There are three constraints to the update: 1. The
keys in `executor_output.output_artifacts` are expected to be a subset of
the system-generated output artifacts dict. 2. An update to a certain key
should contain all the artifacts under that key. 3. An update to an
artifact should not change the type of the artifact.
task: The task that just completed for the given node execution.

Returns:
The tuple containing the maybe updated output_artifacts (note that only
@@ -108,7 +117,14 @@
execution.
Raises:
RuntimeError: if the executor output to an output channel is partial.
ValueError: if `execution_id` is inconsistent with `task.execution_id`.
"""
if task and task.execution_id != execution_id:
raise ValueError(
f'Task execution_id {task.execution_id} does not match MLMD execution'
f' id {execution_id}'
)

unpacked_output_artifacts = (
None # pylint: disable=g-long-ternary
if executor_output is None
@@ -155,6 +171,13 @@
output_artifacts=output_artifacts_to_publish,
)

try:
datahub_utils.log_node_execution(
execution, task, output_artifacts_to_publish
)
except Exception: # pylint: disable=broad-except
logging.exception('Failed to log node execution.')

return output_artifacts_to_publish, execution


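For orientation, a hedged sketch of the call shape after this change; `mlmd_handle`, `contexts`, `output_artifacts`, and `executor_output` are placeholders for values the caller (e.g. post_execution_utils above) already holds:

# Illustrative call site; variable names are placeholders, not PR code.
output_dict, execution = execution_publish_utils.publish_succeeded_execution(
    mlmd_handle,
    execution_id=task.execution_id,  # Must match task.execution_id, else ValueError.
    contexts=contexts,
    output_artifacts=output_artifacts,
    executor_output=executor_output,
    task=task,  # New optional arg; Datahub logging stays best-effort.
)
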
106 changes: 99 additions & 7 deletions tfx/orchestration/portable/execution_publish_utils_test.py
@@ -13,11 +13,14 @@
# limitations under the License.
"""Tests for tfx.orchestration.portable.execution_publish_utils."""
import copy
from unittest import mock

from absl.testing import parameterized
import tensorflow as tf
from tfx import version
from tfx.orchestration import datahub_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable import outputs_utils
from tfx.orchestration.portable.mlmd import context_lib
Expand All @@ -33,13 +36,63 @@
from ml_metadata.proto import metadata_store_pb2


_DEFAULT_EXECUTOR_OUTPUT_URI = '/fake/path/to/executor_output.pb'
_DEFAULT_NODE_ID = 'example_node'
_DEFAULT_OWNER = 'owner'
_DEFAULT_PROJECT_NAME = 'project_name'
_DEFAULT_PIPELINE_NAME = 'pipeline_name'
_DEFAULT_PIPELINE_RUN_ID = 'run-123'
_DEFAULT_TEMP_DIR = '/fake/path/to/tmp_dir/'
_DEFAULT_STATEFUL_WORKING_DIR = '/fake/path/to/stateful_working_dir/'


def _create_pipeline() -> pipeline_pb2.Pipeline:
deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
pipeline = pipeline_pb2.Pipeline(
pipeline_info=pipeline_pb2.PipelineInfo(id=_DEFAULT_PIPELINE_NAME),
nodes=[
pipeline_pb2.Pipeline.PipelineOrNode(
pipeline_node=pipeline_pb2.PipelineNode(
node_info=pipeline_pb2.NodeInfo(id=_DEFAULT_NODE_ID)
),
),
],
)
pipeline.deployment_config.Pack(deployment_config)
return pipeline


def _create_exec_node_task(
pipeline: pipeline_pb2.Pipeline,
execution_id: int,
) -> task_lib.ExecNodeTask:
return task_lib.ExecNodeTask(
pipeline=pipeline,
node_uid=task_lib.NodeUid(
pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
node_id=_DEFAULT_NODE_ID,
),
execution_id=execution_id,
contexts=[],
exec_properties={},
input_artifacts={},
output_artifacts={},
executor_output_uri=_DEFAULT_EXECUTOR_OUTPUT_URI,
stateful_working_dir=_DEFAULT_STATEFUL_WORKING_DIR,
tmp_dir=_DEFAULT_TEMP_DIR,
)


class ExecutionPublisherTest(test_case_utils.TfxTest, parameterized.TestCase):

def setUp(self):
super().setUp()
self._connection_config = metadata_store_pb2.ConnectionConfig()
self._connection_config.sqlite.SetInParent()
self._execution_type = metadata_store_pb2.ExecutionType(name='my_ex_type')
self._mock_log_node_execution = self.enter_context(
mock.patch.object(datahub_utils, 'log_node_execution')
)

def _generate_contexts(self, metadata_handle):
context_spec = pipeline_pb2.NodeContexts()
@@ -191,13 +244,15 @@ def testPublishSuccessfulExecution(self):
value {int_value: 1}
}
""", executor_output.output_artifacts[output_key].artifacts.add())
task = _create_exec_node_task(_create_pipeline(), execution_id)
output_dict, execution = (
execution_publish_utils.publish_succeeded_execution(
m,
execution_id,
contexts,
{output_key: [output_example]},
executor_output,
task,
)
)
self.assertProtoPartiallyEquals(
@@ -283,6 +338,11 @@ def testPublishSuccessfulExecution(self):
self.assertCountEqual([c.id for c in contexts], [
c.id for c in m.store.get_contexts_by_artifact(output_example.id)
])
self._mock_log_node_execution.assert_called_once_with(
execution,
task,
output_dict,
)

def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
@@ -307,10 +367,17 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
value {{int_value: 1}}
}}
""", executor_output.output_artifacts[output_key].artifacts.add())

output_dict, _ = execution_publish_utils.publish_succeeded_execution(
m, execution_id, contexts, {output_key: [output_example]},
executor_output)
task = _create_exec_node_task(_create_pipeline(), execution_id)
output_dict, execution = (
execution_publish_utils.publish_succeeded_execution(
m,
execution_id,
contexts,
{output_key: [output_example]},
executor_output,
task,
)
)
self.assertLen(output_dict[output_key], 2)
self.assertEqual(output_dict[output_key][0].uri, '/examples_uri/1')
self.assertEqual(output_dict[output_key][1].uri, '/examples_uri/2')
@@ -337,6 +404,11 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
""",
event,
ignored_fields=['milliseconds_since_epoch'])
self._mock_log_node_execution.assert_called_once_with(
execution,
task,
output_dict,
)

def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
self):
@@ -366,12 +438,26 @@ def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
value {{int_value: 1}}
}}
""", executor_output.output_artifacts['key1'].artifacts.add())
output_dict, _ = execution_publish_utils.publish_succeeded_execution(
m, execution_id, contexts, original_artifacts, executor_output)
task = _create_exec_node_task(_create_pipeline(), execution_id)
output_dict, execution = (
execution_publish_utils.publish_succeeded_execution(
m,
execution_id,
contexts,
original_artifacts,
executor_output,
task,
)
)
self.assertEmpty(output_dict['key1'])
self.assertNotEmpty(output_dict['key2'])
self.assertLen(output_dict['key2'], 1)
self.assertEqual(output_dict['key2'][0].uri, '/foo/bar')
self._mock_log_node_execution.assert_called_once_with(
execution,
task,
output_dict,
)

def testPublishSuccessExecutionFailNewKey(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
@@ -418,14 +504,15 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
value {int_value: 2}
}
""", executor_output.output_artifacts[output_key].artifacts.add())

task = _create_exec_node_task(_create_pipeline(), execution_id)
output_dict, execution = (
execution_publish_utils.publish_succeeded_execution(
m,
execution_id,
contexts,
{output_key: [output_example]},
executor_output,
task,
)
)
self.assertProtoPartiallyEquals(
@@ -541,6 +628,11 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
output_example.get_string_custom_property(
artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY),
version.__version__)
self._mock_log_node_execution.assert_called_once_with(
execution,
task,
output_dict,
)

def testPublishSuccessExecutionFailChangedType(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
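As an aside, the setUp above leans on absltest's enter_context to scope mock.patch.object to each test. A self-contained sketch of that pattern (the test class below is hypothetical, not from this PR):

# Standalone illustration of the stubbing pattern used in setUp.
from unittest import mock

from absl.testing import absltest

from tfx.orchestration import datahub_utils


class DatahubLoggingStubTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    # enter_context unwinds the patch automatically when the test ends.
    self._mock_log = self.enter_context(
        mock.patch.object(datahub_utils, 'log_node_execution')
    )

  def test_hook_is_stubbed(self):
    datahub_utils.log_node_execution('execution', 'task', {})
    self._mock_log.assert_called_once_with('execution', 'task', {})


if __name__ == '__main__':
  absltest.main()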