Skip to content

Commit

Permalink
Encode producer component id and output key when a ChannelWrappedPlaceholder (CWP) is created from …
Browse files Browse the repository at this point in the history
…an OutputChannel (continuation of the commit title above).

PiperOrigin-RevId: 619667393
  • Loading branch information
kmonte authored and tfx-copybara committed Apr 17, 2024
1 parent ef4dd95 commit a3c998d
Show file tree
Hide file tree
Showing 15 changed files with 448 additions and 144 deletions.
9 changes: 8 additions & 1 deletion tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,17 @@ def testCompileAdditionalCustomPropertyNameConflictError(self):
def testCompileDynamicExecPropTypeError(self):
dsl_compiler = compiler.Compiler()
test_pipeline = dynamic_exec_properties_pipeline.create_test_pipeline()
upstream_component = next(
c
for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.UpstreamComponent)
)
downstream_component = next(
c for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent))
test_wrong_type_channel = channel.Channel(_MyType).future().value
test_wrong_type_channel = (
channel.OutputChannel(_MyType, upstream_component, "foo").future().value
)
downstream_component.exec_properties["input_num"] = test_wrong_type_channel
with self.assertRaisesRegex(
ValueError, ".*channel must be of a value artifact type.*"
Expand Down
42 changes: 28 additions & 14 deletions tfx/dsl/compiler/compiler_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,24 @@
import itertools

import tensorflow as tf
from tfx import components
from tfx import types
from tfx.components import CsvExampleGen
from tfx.components import StatisticsGen
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import base_executor
from tfx.dsl.components.base import executor_spec
from tfx.dsl.components.base.testing import test_node
from tfx.dsl.components.common import importer
from tfx.dsl.components.common import resolver
from tfx.dsl.input_resolution.strategies import latest_blessed_model_strategy
from tfx.dsl.placeholder import placeholder as ph
from tfx.orchestration import pipeline
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import channel
from tfx.types import standard_artifacts
from tfx.types.artifact import Artifact
from tfx.types.artifact import Property
from tfx.types.artifact import PropertyType
from tfx.types.channel import Channel
from tfx.types.channel import OutputChannel
from tfx.types.channel_utils import external_pipeline_artifact_query

from google.protobuf import text_format
Expand Down Expand Up @@ -98,7 +97,7 @@ def testIsResolver(self):
strategy_class=latest_blessed_model_strategy.LatestBlessedModelStrategy)
self.assertTrue(compiler_utils.is_resolver(resv))

example_gen = CsvExampleGen(input_base="data_path")
example_gen = components.CsvExampleGen(input_base="data_path")
self.assertFalse(compiler_utils.is_resolver(example_gen))

def testHasResolverNode(self):
Expand All @@ -116,7 +115,7 @@ def testIsImporter(self):
source_uri="uri/to/schema", artifact_type=standard_artifacts.Schema)
self.assertTrue(compiler_utils.is_importer(impt))

example_gen = CsvExampleGen(input_base="data_path")
example_gen = components.CsvExampleGen(input_base="data_path")
self.assertFalse(compiler_utils.is_importer(example_gen))

def testEnsureTopologicalOrder(self):
Expand All @@ -128,9 +127,9 @@ def testEnsureTopologicalOrder(self):
valid_orders = {"abc", "acb"}
for order in itertools.permutations([a, b, c]):
if "".join([c.id for c in order]) in valid_orders:
self.assertTrue(compiler_utils.ensure_topological_order(order))
self.assertTrue(compiler_utils.ensure_topological_order(list(order)))
else:
self.assertFalse(compiler_utils.ensure_topological_order(order))
self.assertFalse(compiler_utils.ensure_topological_order(list(order)))

def testIncompatibleExecutionMode(self):
p = pipeline.Pipeline(
Expand All @@ -143,8 +142,10 @@ def testIncompatibleExecutionMode(self):
compiler_utils.resolve_execution_mode(p)

def testHasTaskDependency(self):
example_gen = CsvExampleGen(input_base="data_path")
statistics_gen = StatisticsGen(examples=example_gen.outputs["examples"])
example_gen = components.CsvExampleGen(input_base="data_path")
statistics_gen = components.StatisticsGen(
examples=example_gen.outputs["examples"]
)
p1 = pipeline.Pipeline(
pipeline_name="fake_name",
pipeline_root="fake_root",
Expand Down Expand Up @@ -204,7 +205,14 @@ class ValidateExecPropertyPlaceholderTest(tf.test.TestCase):
def test_accepts_canonical_dynamic_exec_prop_placeholder(self):
# .future()[0].uri is how we tell users to hook up a dynamic exec prop.
compiler_utils.validate_exec_property_placeholder(
"testkey", Channel(type=_MyType).future()[0].value
"testkey",
channel.OutputChannel(
artifact_type=_MyType,
producer_component=test_node.TestNode("producer"),
output_key="foo",
)
.future()[0]
.value,
)

def test_accepts_complex_exec_prop_placeholder(self):
Expand All @@ -219,7 +227,13 @@ def test_accepts_complex_exec_prop_placeholder(self):
def test_accepts_complex_dynamic_exec_prop_placeholder(self):
compiler_utils.validate_exec_property_placeholder(
"testkey",
Channel(type=_MyType).future()[0].value
channel.OutputChannel(
artifact_type=_MyType,
producer_component=test_node.TestNode("producer"),
output_key="foo",
)
.future()[0]
.value
+ "foo"
+ ph.input("someartifact").uri
+ "/somefile.txt",
Expand Down Expand Up @@ -265,14 +279,14 @@ def test_rejects_exec_property_dependency(self):
)

def testOutputSpecFromChannel_AsyncOutputChannel(self):
channel = OutputChannel(
ch = channel.OutputChannel(
artifact_type=standard_artifacts.Model,
output_key="model",
producer_component="trainer",
is_async=True,
)

actual = compiler_utils.output_spec_from_channel(channel, "trainer")
actual = compiler_utils.output_spec_from_channel(ch, "trainer")
expected = text_format.Parse(
"""
artifact_spec {
Expand Down
4 changes: 2 additions & 2 deletions tfx/dsl/compiler/node_inputs_compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def testCompileConditionals(self):
index_op {
expression {
placeholder {
key: "%s"
key: "CondNode_x"
}
}
}
Expand All @@ -354,7 +354,7 @@ def testCompileConditionals(self):
}
}
}
""" % cond_input_key, cond.placeholder_expression)
""", cond.placeholder_expression)

def testCompileInputsForDynamicProperties(self):
producer = DummyNode('Producer')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,7 @@ nodes {
index_op {
expression {
placeholder {
key: "blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -2983,7 +2983,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2145,7 +2145,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -3351,7 +3351,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1264,7 +1264,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1301,7 +1301,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_InfraValidator.blessing"
key: "InfraValidator_blessing"
}
}
}
Expand Down Expand Up @@ -1333,7 +1333,7 @@ nodes {
index_op {
expression {
placeholder {
key: "model"
key: "Trainer_model"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
31 changes: 31 additions & 0 deletions tfx/dsl/components/base/testing/test_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to provide a node for tests."""

from tfx.dsl.components.base import base_node


class TestNode(base_node.BaseNode):
  """Node purely for testing, intentionally empty.

  DO NOT USE in real pipelines.
  """

  # Intentionally-empty node I/O contract: this node declares no inputs,
  # outputs, or execution properties. (Class-level defaults presumably
  # satisfy BaseNode's abstract interface — confirm against BaseNode.)
  inputs = {}
  outputs = {}
  exec_properties = {}

  def __init__(self, name: str):
    """Creates a test node with the given node id.

    Args:
      name: The node id to assign via `with_id` (used as the node's id in
        compiled pipelines, e.g. as the producer component id of an
        OutputChannel).
    """
    super().__init__()
    # with_id sets this node's id; the fluent return value is ignored here.
    self.with_id(name)
13 changes: 11 additions & 2 deletions tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def trigger_by_property(self, *property_keys: str):
return self._with_input_trigger(TriggerByProperty(property_keys))

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self)
raise NotImplementedError()

def __eq__(self, other):
return self is other
Expand Down Expand Up @@ -557,6 +557,11 @@ def set_external(self, predefined_artifact_uris: List[str]) -> None:
def set_as_async_channel(self) -> None:
self._is_async = True

  def future(self) -> ChannelWrappedPlaceholder:
    """Returns a placeholder wrapping this channel for dynamic exec properties.

    The placeholder key encodes both the producer component id and the
    output key (joined with "_"), so that distinct output channels compiled
    into the same node produce distinct placeholder keys.
    """
    return ChannelWrappedPlaceholder(
        self, f'{self.producer_component_id}_{self.output_key}'
    )


@doc_controls.do_not_generate_docs
class UnionChannel(BaseChannel):
Expand Down Expand Up @@ -703,6 +708,9 @@ def trigger_by_property(self, *property_keys: str):
'trigger_by_property is not implemented for PipelineInputChannel.'
)

  def future(self) -> ChannelWrappedPlaceholder:
    """Returns a placeholder wrapping this channel for dynamic exec properties.

    Unlike OutputChannel.future, the placeholder key is the output key alone
    (no producer component id prefix).
    """
    # NOTE(review): the f-string only stringifies _output_key; if _output_key
    # is always a str this is a redundant wrap — confirm against __init__.
    return ChannelWrappedPlaceholder(self, f'{self._output_key}')


class ExternalPipelineChannel(BaseChannel):
"""Channel subtype that is used to get artifacts from external MLMD db."""
Expand Down Expand Up @@ -787,7 +795,8 @@ def set_key(self, key: Optional[str]):
Args:
key: The new key for the channel.
"""
self._key = key
del key # unused.
return

def __getitem__(self, index: int) -> ChannelWrappedPlaceholder:
if self._index is not None:
Expand Down
4 changes: 1 addition & 3 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,8 @@ def unwrap_simple_channel_placeholder(
# proto paths above and been getting default messages all along. If this
# sub-message is present, then the whole chain was correct.
not index_op.expression.HasField('placeholder')
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason, and has
# no key when encoded with encode().
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason.
or cwp.type != placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT
or cwp.key
# For the `[0]` part of the desired shape.
or index_op.index != 0
):
Expand Down

0 comments on commit a3c998d

Please sign in to comment.