Skip to content

Commit

Permalink
Call pipeline_start_postprocess for revive, update, and resume
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623306545
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 15, 2024
1 parent 9332479 commit d6eecd3
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tfx/orchestration/subpipeline_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic utilities for orchestrating subpipelines."""


from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants as compiler_constants
from tfx.orchestration import pipeline as dsl_pipeline
from tfx.proto.orchestration import pipeline_pb2

# Placeholder pipeline handed to the compiler_utils begin/end node type-name
# helpers inside is_subpipeline(), which take a DSL pipeline object.
# NOTE(review): this presumes the resulting type names do not depend on the
# pipeline's name -- confirm against compiler_utils.
_DUMMY_PIPELINE = dsl_pipeline.Pipeline(pipeline_name="Dummy-Pipeline")


def is_subpipeline(pipeline: pipeline_pb2.Pipeline) -> bool:
  """Reports whether `pipeline` is bracketed by begin and end nodes.

  The check is purely structural. It passes only when the proto holds at least
  two nodes, the first node is a `pipeline_node` whose id is the pipeline id
  followed by PIPELINE_BEGIN_NODE_SUFFIX and whose type name equals the
  begin-node type name, and the last node is a `pipeline_node` whose id and
  type name equal the corresponding end-node values.

  Args:
    pipeline: The pipeline proto to inspect.

  Returns:
    True when every condition above holds, False otherwise.
  """
  all_nodes = pipeline.nodes
  if len(all_nodes) < 2:
    return False

  first = all_nodes[0]
  last = all_nodes[-1]

  # --- First node: must be a plain node carrying the begin id and type. ---
  if first.WhichOneof("node") != "pipeline_node":
    return False
  begin_info = first.pipeline_node.node_info
  expected_begin_id = (
      f"{pipeline.pipeline_info.id}"
      f"{compiler_constants.PIPELINE_BEGIN_NODE_SUFFIX}"
  )
  if begin_info.id != expected_begin_id:
    return False
  # Expected type names are derived from the module-level placeholder
  # pipeline, not from `pipeline` itself.
  if begin_info.type.name != compiler_utils.pipeline_begin_node_type_name(
      _DUMMY_PIPELINE
  ):
    return False

  # --- Last node: must be a plain node carrying the end id and type. ---
  if last.WhichOneof("node") != "pipeline_node":
    return False
  end_info = last.pipeline_node.node_info
  expected_end_id = compiler_utils.pipeline_end_node_id_from_pipeline_id(
      pipeline.pipeline_info.id
  )
  if end_info.id != expected_end_id:
    return False
  return end_info.type.name == compiler_utils.pipeline_end_node_type_name(
      _DUMMY_PIPELINE
  )
149 changes: 149 additions & 0 deletions tfx/orchestration/subpipeline_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.orchestration.subpipeline_utils."""

from absl.testing import absltest
from absl.testing import parameterized
from tfx.dsl.compiler import compiler_utils
from tfx.orchestration import pipeline as dsl_pipeline
from tfx.orchestration import subpipeline_utils
from tfx.proto.orchestration import pipeline_pb2

# Pipeline id assigned to every pipeline proto built in these tests.
_PIPELINE_NAME = 'test_pipeline'
# DSL pipeline from which the node fixtures derive the expected begin/end node
# ids and type names via compiler_utils.
_TEST_PIPELINE = dsl_pipeline.Pipeline(pipeline_name=_PIPELINE_NAME)


def _get_begin_node(
    correct_id: bool = True,
    correct_type: bool = True,
    is_subpipeline: bool = False,
) -> pipeline_pb2.Pipeline.PipelineOrNode:
  """Builds a first-node fixture, optionally with a wrong id, type or kind.

  Args:
    correct_id: If True, use the begin node id derived from _TEST_PIPELINE;
      otherwise use a placeholder id.
    correct_type: If True, use the begin node type name derived from
      _TEST_PIPELINE; otherwise use a placeholder type name.
    is_subpipeline: If True, populate the `sub_pipeline` field instead of
      `pipeline_node`; the other two flags are then ignored.

  Returns:
    The populated PipelineOrNode proto.
  """
  entry = pipeline_pb2.Pipeline.PipelineOrNode()
  if is_subpipeline:
    # Mark the sub_pipeline field as present without filling anything in.
    entry.sub_pipeline.SetInParent()
    return entry

  info = entry.pipeline_node.node_info
  info.id = (
      compiler_utils.pipeline_begin_node_id(_TEST_PIPELINE)
      if correct_id
      else 'not-a-begin-node'
  )
  info.type.name = (
      compiler_utils.pipeline_begin_node_type_name(_TEST_PIPELINE)
      if correct_type
      else 'not-a-begin-node-type'
  )
  return entry


def _get_end_node(
    correct_id: bool = True,
    correct_type: bool = True,
    is_subpipeline: bool = False,
) -> pipeline_pb2.Pipeline.PipelineOrNode:
  """Builds a last-node fixture, optionally with a wrong id, type or kind.

  Args:
    correct_id: If True, use the end node id derived from _TEST_PIPELINE;
      otherwise use a placeholder id.
    correct_type: If True, use the end node type name derived from
      _TEST_PIPELINE; otherwise use a placeholder type name.
    is_subpipeline: If True, populate the `sub_pipeline` field instead of
      `pipeline_node`; the other two flags are then ignored.

  Returns:
    The populated PipelineOrNode proto.
  """
  entry = pipeline_pb2.Pipeline.PipelineOrNode()
  if is_subpipeline:
    # Mark the sub_pipeline field as present without filling anything in.
    entry.sub_pipeline.SetInParent()
    return entry

  info = entry.pipeline_node.node_info
  info.id = (
      compiler_utils.pipeline_end_node_id(_TEST_PIPELINE)
      if correct_id
      else 'not-a-end-node'
  )
  info.type.name = (
      compiler_utils.pipeline_end_node_type_name(_TEST_PIPELINE)
      if correct_type
      else 'not-a-end-node-type'
  )
  return entry


class SubpipelineUtilsTest(parameterized.TestCase):
  """Runs subpipeline_utils.is_subpipeline against hand-built pipeline protos."""

  @staticmethod
  def _build_pipeline(nodes) -> pipeline_pb2.Pipeline:
    """Returns a proto with id _PIPELINE_NAME holding copies of `nodes`, in order."""
    proto = pipeline_pb2.Pipeline()
    proto.pipeline_info.id = _PIPELINE_NAME
    for entry in nodes:
      proto.nodes.add().CopyFrom(entry)
    return proto

  def test_is_subpipeline_with_subpipeline(self):
    # A matching begin node followed by a matching end node is accepted.
    proto = self._build_pipeline([_get_begin_node(), _get_end_node()])
    self.assertTrue(subpipeline_utils.is_subpipeline(proto))

  @parameterized.named_parameters(
      {
          'testcase_name': 'incorrect_id',
          'node': _get_begin_node(correct_id=False),
      },
      {
          'testcase_name': 'incorrect_type',
          'node': _get_begin_node(correct_type=False),
      },
      {
          'testcase_name': 'is_subpipeline',
          'node': _get_begin_node(is_subpipeline=True),
      },
  )
  def test_is_subpipeline_with_no_subpipeline_incorrect_begin(self, node):
    # Any defect in the first node must cause a rejection.
    proto = self._build_pipeline([node, _get_end_node()])
    self.assertFalse(subpipeline_utils.is_subpipeline(proto))

  @parameterized.named_parameters(
      {
          'testcase_name': 'incorrect_id',
          'node': _get_end_node(correct_id=False),
      },
      {
          'testcase_name': 'incorrect_type',
          'node': _get_end_node(correct_type=False),
      },
      {
          'testcase_name': 'is_subpipeline',
          'node': _get_end_node(is_subpipeline=True),
      },
  )
  def test_is_subpipeline_with_no_subpipeline_incorrect_end(self, node):
    # Any defect in the last node must cause a rejection.
    proto = self._build_pipeline([_get_begin_node(), node])
    self.assertFalse(subpipeline_utils.is_subpipeline(proto))

  @parameterized.named_parameters(
      {
          'testcase_name': 'no_nodes',
          'nodes': [],
      },
      {
          'testcase_name': 'one_node',
          'nodes': [_get_begin_node()],
      },
      {
          'testcase_name': 'wrong_order',
          'nodes': [_get_end_node(), _get_begin_node()],
      },
  )
  def test_is_subpipeline_nodes_incorrect(self, nodes):
    # Too few nodes, or begin/end swapped, must cause a rejection.
    proto = self._build_pipeline(nodes)
    self.assertFalse(subpipeline_utils.is_subpipeline(proto))


# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()

0 comments on commit d6eecd3

Please sign in to comment.