Skip to content

Commit

Permalink
Encode producer component id and output key when CWP is created from …
Browse files Browse the repository at this point in the history
…an OutputChannel

PiperOrigin-RevId: 619667393
  • Loading branch information
kmonte authored and tfx-copybara committed Mar 28, 2024
1 parent 92fef51 commit 5ef7ee6
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 23 deletions.
9 changes: 8 additions & 1 deletion tfx/dsl/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,17 @@ def testCompileAdditionalCustomPropertyNameConflictError(self):
def testCompileDynamicExecPropTypeError(self):
dsl_compiler = compiler.Compiler()
test_pipeline = dynamic_exec_properties_pipeline.create_test_pipeline()
upstream_component = next(
c
for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.UpstreamComponent)
)
downstream_component = next(
c for c in test_pipeline.components
if isinstance(c, dynamic_exec_properties_pipeline.DownstreamComponent))
test_wrong_type_channel = channel.Channel(_MyType).future().value
test_wrong_type_channel = (
channel.OutputChannel(_MyType, upstream_component, "foo").future().value
)
downstream_component.exec_properties["input_num"] = test_wrong_type_channel
with self.assertRaisesRegex(
ValueError, ".*channel must be of a value artifact type.*"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,7 @@ nodes {
index_op {
expression {
placeholder {
key: "blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -2983,7 +2983,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
16 changes: 10 additions & 6 deletions tfx/dsl/compiler/testdata/composable_pipeline_input_v2_ir.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ nodes {
}
}
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down Expand Up @@ -1169,7 +1170,8 @@ nodes {
upstream_nodes: "data-ingestion-pipeline"
downstream_nodes: "Trainer"
execution_options {
caching_options {}
caching_options {
}
strategy: LAZILY_ALL_UPSTREAM_NODES_SUCCEEDED
max_execution_retries: 10
}
Expand Down Expand Up @@ -2143,7 +2145,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -2206,7 +2208,8 @@ nodes {
downstream_nodes: "Pusher"
downstream_nodes: "infra-validator-pipeline"
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down Expand Up @@ -2507,7 +2510,8 @@ nodes {
upstream_nodes: "validate-and-push-pipeline_begin"
downstream_nodes: "InfraValidator"
execution_options {
caching_options {}
caching_options {
}
}
}
}
Expand Down Expand Up @@ -3347,7 +3351,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_infra-validator-pipeline.blessing"
key: "infra-validator-pipeline_blessing"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1264,7 +1264,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_Evaluator.blessing"
key: "Evaluator_blessing"
}
}
}
Expand Down Expand Up @@ -1301,7 +1301,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_InfraValidator.blessing"
key: "InfraValidator_blessing"
}
}
}
Expand Down Expand Up @@ -1333,7 +1333,7 @@ nodes {
index_op {
expression {
placeholder {
key: "model"
key: "Trainer_model"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ nodes {
index_op {
expression {
placeholder {
key: "_UpstreamComponent.num"
key: "UpstreamComponent_num"
}
}
}
Expand Down
13 changes: 11 additions & 2 deletions tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def trigger_by_property(self, *property_keys: str):
return self._with_input_trigger(TriggerByProperty(property_keys))

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self)
raise NotImplementedError()

def __eq__(self, other):
return self is other
Expand Down Expand Up @@ -557,6 +557,11 @@ def set_external(self, predefined_artifact_uris: List[str]) -> None:
def set_as_async_channel(self) -> None:
self._is_async = True

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(
self, f'{self.producer_component_id}_{self.output_key}'
)


@doc_controls.do_not_generate_docs
class UnionChannel(BaseChannel):
Expand Down Expand Up @@ -703,6 +708,9 @@ def trigger_by_property(self, *property_keys: str):
'trigger_by_property is not implemented for PipelineInputChannel.'
)

def future(self) -> ChannelWrappedPlaceholder:
return ChannelWrappedPlaceholder(self, f'{self._output_key}')


class ExternalPipelineChannel(BaseChannel):
"""Channel subtype that is used to get artifacts from external MLMD db."""
Expand Down Expand Up @@ -787,7 +795,8 @@ def set_key(self, key: Optional[str]):
Args:
key: The new key for the channel.
"""
self._key = key
del key # unused.
return

def __getitem__(self, index: int) -> ChannelWrappedPlaceholder:
if self._index is not None:
Expand Down
4 changes: 1 addition & 3 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,8 @@ def unwrap_simple_channel_placeholder(
# proto paths above and been getting default messages all along. If this
# sub-message is present, then the whole chain was correct.
not index_op.expression.HasField('placeholder')
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason, and has
# no key when encoded with encode().
# ChannelWrappedPlaceholder uses INPUT_ARTIFACT for some reason.
or cwp.type != placeholder_pb2.Placeholder.Type.INPUT_ARTIFACT
or cwp.key
# For the `[0]` part of the desired shape.
or index_op.index != 0
):
Expand Down
6 changes: 3 additions & 3 deletions tfx/types/channel_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Tests for tfx.utils.channel."""

import tensorflow as tf
from absl.testing import absltest
from tfx.dsl.placeholder import placeholder as ph
from tfx.types import artifact
from tfx.types import channel
Expand All @@ -25,7 +25,7 @@ class _MyArtifact(artifact.Artifact):
TYPE_NAME = 'MyTypeName'


class ChannelUtilsTest(tf.test.TestCase):
class ChannelUtilsTest(absltest.TestCase):

def testArtifactCollectionAsChannel(self):
instance_a = _MyArtifact()
Expand Down Expand Up @@ -125,4 +125,4 @@ def testUnwrapSimpleChannelPlaceholderRejectsComplexPlaceholders(self):


if __name__ == '__main__':
tf.test.main()
absltest.main()

0 comments on commit 5ef7ee6

Please sign in to comment.