no-op
PiperOrigin-RevId: 625128214
tfx-copybara committed May 3, 2024
1 parent 281e5e5 commit d8917ee
Showing 3 changed files with 146 additions and 8 deletions.
30 changes: 30 additions & 0 deletions tfx/orchestration/datahub_utils.py
@@ -0,0 +1,30 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to log Tflex/MLMD entities to Datahub."""
from typing import Optional

from tfx.orchestration.experimental.core import task as task_lib
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2


def log_node_execution(
    execution: metadata_store_pb2.Execution,
    task: Optional[task_lib.ExecNodeTask] = None,
    output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
):
  """Logs a Tflex node execution and its input/output artifacts."""
  del execution, task, output_artifacts
  return
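
For orientation, a minimal sketch of how this new stub can be invoked directly (the execution proto below is fabricated for illustration; TFX and ML Metadata are assumed installed):

from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration import datahub_utils

execution = metadata_store_pb2.Execution(id=123)  # illustrative id only
execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
# Currently a no-op: the stub discards its arguments and returns None.
datahub_utils.log_node_execution(execution)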
19 changes: 18 additions & 1 deletion tfx/orchestration/portable/execution_publish_utils.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Portable library for registering and publishing executions."""

from typing import Mapping, Optional, Sequence
import uuid

from tfx import types
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration import datahub_utils
from tfx.orchestration.portable import merge_utils
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.proto.orchestration import execution_result_pb2
@@ -75,6 +78,7 @@ def publish_succeeded_execution(
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None,
    task: Optional[task_lib.ExecNodeTask] = None,
) -> tuple[
    Optional[typing_utils.ArtifactMultiMap],
    metadata_store_pb2.Execution,
@@ -85,6 +89,9 @@
  will also merge the executor produced info into system generated output
  artifacts. The `last_known_state` of the execution will be changed to
  `COMPLETE` and the output artifacts will be marked as `LIVE`.
  This method will also publish the execution and its input/output artifacts to
  Datahub in best-effort mode if `enable_datahub_logging` in
  TflexProjectPlatformConfig is set to True.

  Args:
    metadata_handle: A handler to access MLMD.
@@ -95,11 +102,12 @@
      event with type OUTPUT.
    executor_output: Executor outputs. `executor_output.output_artifacts` will
      be used to update system-generated output artifacts passed in through
      `output_artifacts` arg. There are three constraints to the update: 1. The
      keys in `executor_output.output_artifacts` are expected to be a subset of
      the system-generated output artifacts dict. 2. An update to a certain key
      should contain all the artifacts under that key. 3. An update to an
      artifact should not change the type of the artifact.
    task: the task that just completed for the given node execution.

  Returns:
    The tuple containing the maybe updated output_artifacts (note that only
@@ -108,7 +116,14 @@
    execution.

  Raises:
    RuntimeError: if the executor output to an output channel is partial.
    ValueError: if `execution_id` is inconsistent with `task.execution_id`.
"""
if task and task.execution_id != execution_id:
raise ValueError(
f'Task execution_id {task.execution_id} does not match MLMD execution'
f' id {execution_id}'
)

  unpacked_output_artifacts = (
      None  # pylint: disable=g-long-ternary
      if executor_output is None
@@ -155,6 +170,8 @@ def publish_succeeded_execution(
      output_artifacts=output_artifacts_to_publish,
  )

  datahub_utils.log_node_execution(execution, task, output_artifacts_to_publish)

  return output_artifacts_to_publish, execution


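For illustration, a hedged usage sketch of the updated entry point; the names m, contexts, output_example, executor_output, and exec_node_task are placeholders standing in for real objects, not part of this commit:

output_dict, execution = execution_publish_utils.publish_succeeded_execution(
    m,                            # metadata_handle
    exec_node_task.execution_id,  # must equal task.execution_id, else ValueError
    contexts,
    {'examples': [output_example]},
    executor_output,
    exec_node_task,
)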
105 changes: 98 additions & 7 deletions tfx/orchestration/portable/execution_publish_utils_test.py
@@ -13,11 +13,14 @@
# limitations under the License.
"""Tests for tfx.orchestration.portable.execution_publish_utils."""
import copy
from unittest import mock

from absl.testing import parameterized
import tensorflow as tf
from tfx import version
from tfx.orchestration import metadata
from tfx.orchestration.experimental.core import task as task_lib
from tfx.orchestration import datahub_utils
from tfx.orchestration.portable import execution_publish_utils
from tfx.orchestration.portable import outputs_utils
from tfx.orchestration.portable.mlmd import context_lib
@@ -33,13 +36,62 @@
from ml_metadata.proto import metadata_store_pb2


_DEFAULT_EXECUTOR_OUTPUT_URI = '/fake/path/to/executor_output.pb'
_DEFAULT_NODE_ID = 'example_node'
_DEFAULT_OWNER = 'owner'
_DEFAULT_PROJECT_NAME = 'project_name'
_DEFAULT_PIPELINE_NAME = 'pipeline_name'
_DEFAULT_PIPELINE_RUN_ID = 'run-123'
_DEFAULT_TEMP_DIR = '/fake/path/to/tmp_dir/'
_DEFAULT_STATEFUL_WORKING_DIR = '/fake/path/to/stateful_working_dir/'


def _create_pipeline() -> pipeline_pb2.Pipeline:
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  pipeline = pipeline_pb2.Pipeline(
      pipeline_info=pipeline_pb2.PipelineInfo(id=_DEFAULT_PIPELINE_NAME),
      nodes=[
          pipeline_pb2.Pipeline.PipelineOrNode(
              pipeline_node=pipeline_pb2.PipelineNode(
                  node_info=pipeline_pb2.NodeInfo(id=_DEFAULT_NODE_ID)
              ),
          ),
      ],
  )
  pipeline.deployment_config.Pack(deployment_config)
  return pipeline


def _create_exec_node_task(
    pipeline: pipeline_pb2.Pipeline,
    execution_id: int,
) -> task_lib.ExecNodeTask:
  return task_lib.ExecNodeTask(
      pipeline=pipeline,
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
          node_id=_DEFAULT_NODE_ID,
      ),
      execution_id=execution_id,
      contexts=[],
      exec_properties={},
      input_artifacts={},
      output_artifacts={},
      executor_output_uri=_DEFAULT_EXECUTOR_OUTPUT_URI,
      stateful_working_dir=_DEFAULT_STATEFUL_WORKING_DIR,
      tmp_dir=_DEFAULT_TEMP_DIR,
  )


class ExecutionPublisherTest(test_case_utils.TfxTest, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self._connection_config = metadata_store_pb2.ConnectionConfig()
    self._connection_config.sqlite.SetInParent()
    self._execution_type = metadata_store_pb2.ExecutionType(name='my_ex_type')
    self._mock_log_node_execution = self.enter_context(
        mock.patch.object(datahub_utils, 'log_node_execution'))

  def _generate_contexts(self, metadata_handle):
    context_spec = pipeline_pb2.NodeContexts()
@@ -191,13 +243,15 @@ def testPublishSuccessfulExecution(self):
            value {int_value: 1}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())
      task = _create_exec_node_task(_create_pipeline(), execution_id)
      output_dict, execution = (
          execution_publish_utils.publish_succeeded_execution(
              m,
              execution_id,
              contexts,
              {output_key: [output_example]},
              executor_output,
              task,
          )
      )
      self.assertProtoPartiallyEquals(
@@ -283,6 +337,11 @@ def testPublishSuccessfulExecution(self):
      self.assertCountEqual([c.id for c in contexts], [
          c.id for c in m.store.get_contexts_by_artifact(output_example.id)
      ])
      self._mock_log_node_execution.assert_called_once_with(
          execution,
          task,
          output_dict,
      )

  def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
@@ -307,10 +366,17 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
            value {{int_value: 1}}
          }}
          """, executor_output.output_artifacts[output_key].artifacts.add())

      task = _create_exec_node_task(_create_pipeline(), execution_id)
      output_dict, execution = (
          execution_publish_utils.publish_succeeded_execution(
              m,
              execution_id,
              contexts,
              {output_key: [output_example]},
              executor_output,
              task,
          )
      )
      self.assertLen(output_dict[output_key], 2)
      self.assertEqual(output_dict[output_key][0].uri, '/examples_uri/1')
      self.assertEqual(output_dict[output_key][1].uri, '/examples_uri/2')
@@ -337,6 +403,11 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
          """,
          event,
          ignored_fields=['milliseconds_since_epoch'])
      self._mock_log_node_execution.assert_called_once_with(
          execution,
          task,
          output_dict,
      )

  def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
      self):
@@ -366,12 +437,26 @@ def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
            value {{int_value: 1}}
          }}
          """, executor_output.output_artifacts['key1'].artifacts.add())
      task = _create_exec_node_task(_create_pipeline(), execution_id)
      output_dict, execution = (
          execution_publish_utils.publish_succeeded_execution(
              m,
              execution_id,
              contexts,
              original_artifacts,
              executor_output,
              task,
          )
      )
      self.assertEmpty(output_dict['key1'])
      self.assertNotEmpty(output_dict['key2'])
      self.assertLen(output_dict['key2'], 1)
      self.assertEqual(output_dict['key2'][0].uri, '/foo/bar')
      self._mock_log_node_execution.assert_called_once_with(
          execution,
          task,
          output_dict,
      )

  def testPublishSuccessExecutionFailNewKey(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:
@@ -418,14 +503,15 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
            value {int_value: 2}
          }
          """, executor_output.output_artifacts[output_key].artifacts.add())

      task = _create_exec_node_task(_create_pipeline(), execution_id)
      output_dict, execution = (
          execution_publish_utils.publish_succeeded_execution(
              m,
              execution_id,
              contexts,
              {output_key: [output_example]},
              executor_output,
              task,
          )
      )
      self.assertProtoPartiallyEquals(
@@ -541,6 +627,11 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
          output_example.get_string_custom_property(
              artifact_utils.ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY),
          version.__version__)
      self._mock_log_node_execution.assert_called_once_with(
          execution,
          task,
          output_dict,
      )

  def testPublishSuccessExecutionFailChangedType(self):
    with metadata.Metadata(connection_config=self._connection_config) as m:

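For readers adapting these tests elsewhere, a self-contained sketch of the stubbing pattern the updated tests rely on; run_with_stubbed_datahub_logging and publish_fn are illustrative helpers, not part of this commit:

from unittest import mock

from tfx.orchestration import datahub_utils


def run_with_stubbed_datahub_logging(publish_fn):
  # Patch the hook so the test never attempts real Datahub logging.
  with mock.patch.object(datahub_utils, 'log_node_execution') as mock_log:
    output_dict, execution = publish_fn()
    # publish_succeeded_execution is expected to call the hook exactly once,
    # with the published execution, the task, and the final output artifacts.
    mock_log.assert_called_once_with(execution, mock.ANY, output_dict)
  return output_dict, execution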