Encode producer component id and output key when CWP is created from an OutputChannel #6726

Status: Open. Wants to merge 1 commit into base: master.
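In brief: when a ChannelWrappedPlaceholder (CWP) is created from an OutputChannel, the compiler now encodes the producer component id and output key into the placeholder's input key, instead of relying on whatever key the consuming node happens to assign. A minimal sketch of the key format, assuming hypothetical ids 'Trainer' and 'model' (the actual logic lives in the diffs below):

```python
def encoded_cwp_input_key(producer_component_id: str, output_key: str) -> str:
  """Sketch of the key format this PR uses for CWP-derived inputs."""
  # e.g. ('Trainer', 'model') -> '_Trainer.model'
  return f'_{producer_component_id}.{output_key}'
```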
4 changes: 2 additions & 2 deletions tfx/dsl/compiler/compiler.py
@@ -214,8 +214,8 @@ def _compile_node(
 
     # Step 3: Node inputs
     node_inputs_compiler.compile_node_inputs(
-        pipeline_ctx, tfx_node, node.inputs)
-
+        pipeline_ctx, tfx_node, node.inputs
+    )
     # Step 4: Node outputs
     if (isinstance(tfx_node, base_component.BaseComponent) or
         compiler_utils.is_importer(tfx_node)):
5 changes: 5 additions & 0 deletions tfx/dsl/compiler/compiler_utils.py
@@ -204,6 +204,11 @@ def node_context_name(pipeline_context_name: str, node_id: str):
 
 def implicit_channel_key(channel: types.BaseChannel):
   """Key of a channel to the node that consumes the channel as input."""
+  if (
+      isinstance(channel, channel_types.ChannelWrappedPlaceholder)
+      and channel.key
+  ):
+    return channel.key
   if isinstance(channel, channel_types.PipelineInputChannel):
     channel = cast(channel_types.PipelineInputChannel, channel)
     return f"_{channel.pipeline.id}.{channel.output_key}"
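To make the new precedence in `implicit_channel_key` concrete, here is a hedged usage sketch; `trainer` is a hypothetical component with id 'Trainer', and the assumption that `.future()` now sets the CWP key comes from this PR's title rather than from code shown in this diff:

```python
# Hypothetical: a CWP created from an OutputChannel carries an encoded key,
# which implicit_channel_key now returns before falling through to the
# PipelineInputChannel branch.
cwp = trainer.outputs['model'].future()  # ChannelWrappedPlaceholder
implicit_channel_key(cwp)                # -> '_Trainer.model' (channel.key)

# Channels without a CWP key keep the existing behavior, e.g. a
# PipelineInputChannel still maps to f"_{pipeline.id}.{output_key}".
```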
19 changes: 16 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler.py
@@ -421,20 +421,33 @@ def _compile_conditionals(
 
     contexts = context.dsl_context_registry.get_contexts(tfx_node)
   except ValueError:
     return
-
   for dsl_context in contexts:
     if not isinstance(dsl_context, conditional.CondContext):
       continue
     cond_context = cast(conditional.CondContext, dsl_context)
     for channel in channel_utils.get_dependent_channels(cond_context.predicate):
+      # The channels here always come from a CWP, whose key is now set by
+      # default when it wraps an OutputChannel, so we must re-create the input
+      # key whenever an output channel is used; otherwise `get_input_key` may
+      # return the wrong key (e.g. when the producer component is also used as
+      # a data input to this component).
+      # Note that this means we potentially end up with several inputs backed
+      # by identical artifact queries under the hood, which should be
+      # optimized away if we run into performance issues.
+      if isinstance(channel, channel_types.OutputChannel):
+        input_key = f'_{channel.producer_component_id}.{channel.output_key}'
+      else:
+        input_key = context.get_node_context(tfx_node).get_input_key(channel)
       _compile_input_spec(
           pipeline_ctx=context,
           tfx_node=tfx_node,
-          input_key=context.get_node_context(tfx_node).get_input_key(channel),
+          input_key=input_key,
           channel=channel,
           hidden=False,
           min_count=1,
-          result=result)
+          result=result,
+      )
     cond_id = context.get_conditional_id(cond_context)
     expr = channel_utils.encode_placeholder_with_channels(
         cond_context.predicate, context.get_node_context(tfx_node).get_input_key
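The collision this guards against is easiest to see in a small pipeline sketch (component names are hypothetical; `Cond` is TFX's experimental conditional context):

```python
from tfx.dsl.experimental.conditionals import conditional

trainer = Trainer(...)  # hypothetical producer with component id 'Trainer'

with conditional.Cond(trainer.outputs['model'].future()[0].uri != ''):
  # The same output channel also feeds a regular data input keyed 'model'.
  pusher = Pusher(model=trainer.outputs['model'])

# Previously the predicate's placeholder could resolve through get_input_key
# to the data-input key 'model'; with this change it is compiled under its own
# key '_Trainer.model', at the cost of two inputs holding identical artifact
# queries.
```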
10 changes: 7 additions & 3 deletions tfx/dsl/compiler/node_inputs_compiler_test.py
@@ -577,7 +577,8 @@ def testCompileConditionals(self):
 
     self.assertEqual(result.inputs[cond_input_key].min_count, 1)
     self.assertLen(result.conditionals, 1)
     cond = list(result.conditionals.values())[0]
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
         operator {
           compare_op {
             op: EQUAL
@@ -594,7 +595,7 @@ def testCompileConditionals(self):
               index_op {
                 expression {
                   placeholder {
-                    key: "%s"
+                    key: "_CondNode.x"
                   }
                 }
               }
@@ -605,7 +606,9 @@ def testCompileConditionals(self):
             }
           }
         }
-        """ % cond_input_key, cond.placeholder_expression)
+        """,
+        cond.placeholder_expression,
+    )
 
   def testCompileInputsForDynamicProperties(self):
     producer = DummyNode('Producer')
@@ -786,6 +789,7 @@ def testMainInputsShouldNotBeHidden(self, explicit_x, explicit_y, dynamic_z):
       self.assertEqual(result.inputs['y'].hidden, False)
     else:
       self.assertNotIn('y', result.inputs)
     self.assertEqual(result.inputs['_Producer.z'].hidden, not dynamic_z)
 
   def test_min_count_with_allow_empty_from(self):
@@ -1942,6 +1942,43 @@ nodes {
       }
     }
     inputs {
+      inputs {
+        key: "_Evaluator.blessing"
+        value {
+          channels {
+            producer_node_query {
+              id: "Evaluator"
+            }
+            context_queries {
+              type {
+                name: "pipeline"
+              }
+              name {
+                field_value {
+                  string_value: "composable-pipeline"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "node"
+              }
+              name {
+                field_value {
+                  string_value: "composable-pipeline.Evaluator"
+                }
+              }
+            }
+            artifact_query {
+              type {
+                name: "ModelBlessing"
+              }
+            }
+            output_key: "blessing"
+          }
+          min_count: 1
+        }
+      }
       inputs {
         key: "blessing"
         value {
@@ -2109,7 +2146,7 @@ nodes {
               index_op {
                 expression {
                   placeholder {
-                    key: "blessing"
+                    key: "_Evaluator.blessing"
                   }
                 }
               }
51 changes: 50 additions & 1 deletion tfx/dsl/compiler/testdata/conditional_pipeline_input_v2_ir.pbtxt
@@ -1202,6 +1202,55 @@ nodes {
           min_count: 1
         }
       }
+      inputs {
+        key: "_Trainer.model"
+        value {
+          channels {
+            producer_node_query {
+              id: "Trainer"
+            }
+            context_queries {
+              type {
+                name: "pipeline"
+              }
+              name {
+                field_value {
+                  string_value: "cond"
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "pipeline_run"
+              }
+              name {
+                runtime_parameter {
+                  name: "pipeline-run-id"
+                  type: STRING
+                }
+              }
+            }
+            context_queries {
+              type {
+                name: "node"
+              }
+              name {
+                field_value {
+                  string_value: "cond.Trainer"
+                }
+              }
+            }
+            artifact_query {
+              type {
+                name: "Model"
+                base_type: MODEL
+              }
+            }
+            output_key: "model"
+          }
+          min_count: 1
+        }
+      }
       inputs {
         key: "model"
         value {
@@ -1333,7 +1382,7 @@ nodes {
               index_op {
                 expression {
                   placeholder {
-                    key: "model"
+                    key: "_Trainer.model"
                  }
                }
              }
66 changes: 32 additions & 34 deletions tfx/orchestration/kubeflow/v2/compiler_utils_test.py
@@ -266,36 +266,38 @@ def setUp(self):
 
   @parameterized.named_parameters(
       {
-          'testcase_name':
-              'two_sides_placeholder',
-          'predicate':
-              _TEST_CHANNEL.future()[0].property('int1') <
-              _TEST_CHANNEL.future()[0].property('int2'),
-          'expected_cel':
-              '(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int1\'] < '
-              'inputs.artifacts[\'key\'].artifacts[0].metadata[\'int2\'])',
+          'testcase_name': 'two_sides_placeholder',
+          'predicate': _TEST_CHANNEL.future()[0].property(
+              'int1'
+          ) < _TEST_CHANNEL.future()[0].property('int2'),
+          'expected_cel': (
+              "(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int1'] < "
+              "inputs.artifacts['_producer.foo'].artifacts[0].metadata['int2'])"
+          ),
       },
       {
-          'testcase_name':
-              'left_side_placeholder_right_side_int',
-          'predicate':
-              _TEST_CHANNEL.future()[0].property('int') < 1,
-          'expected_cel':
-              '(inputs.artifacts[\'key\'].artifacts[0].metadata[\'int\'] < 1.0)',
+          'testcase_name': 'left_side_placeholder_right_side_int',
+          'predicate': _TEST_CHANNEL.future()[0].property('int') < 1,
+          'expected_cel': (
+              "(inputs.artifacts['_producer.foo'].artifacts[0].metadata['int']"
+              ' < 1.0)'
+          ),
       },
       {
          'testcase_name': 'left_side_placeholder_right_side_float',
          'predicate': _TEST_CHANNEL.future()[0].property('float') < 1.1,
-          'expected_cel':
-              '(inputs.artifacts[\'key\'].artifacts[0].metadata[\'float\'] < '
-              '1.1)',
+          'expected_cel': (
+              "(inputs.artifacts['_producer.foo'].artifacts[0].metadata['float']"
+              ' < 1.1)'
+          ),
      },
      {
          'testcase_name': 'left_side_placeholder_right_side_string',
          'predicate': _TEST_CHANNEL.future()[0].property('str') == 'test_str',
-          'expected_cel':
-              '(inputs.artifacts[\'key\'].artifacts[0].metadata[\'str\'] == '
-              '\'test_str\')',
+          'expected_cel': (
+              "(inputs.artifacts['_producer.foo'].artifacts[0].metadata['str']"
+              " == 'test_str')"
+          ),
      },
  )
  def testComparison(self, predicate, expected_cel):
@@ -310,8 +312,9 @@ def testComparison(self, predicate, expected_cel):
 
   def testArtifactUri(self):
     predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
-    expected_cel = ('(inputs.artifacts[\'key\'].artifacts[0].uri == '
-                    '\'test_str\')')
+    expected_cel = (
+        "(inputs.artifacts['_producer.foo'].artifacts[0].uri == 'test_str')"
+    )
     channel_to_key_map = {
         _TEST_CHANNEL: 'key',
     }
@@ -323,8 +326,10 @@ def testArtifactUri(self):
 
   def testNegation(self):
     predicate = _TEST_CHANNEL.future()[0].property('int') != 1
-    expected_cel = ('!((inputs.artifacts[\'key\'].artifacts[0]'
-                    '.metadata[\'int\'] == 1.0))')
+    expected_cel = (
+        "!((inputs.artifacts['_producer.foo'].artifacts[0]"
+        ".metadata['int'] == 1.0))"
+    )
     channel_to_key_map = {
         _TEST_CHANNEL: 'key',
     }
@@ -337,8 +342,9 @@ def testNegation(self):
 
   def testConcat(self):
     predicate = _TEST_CHANNEL.future()[0].uri + 'something' == 'test_str'
     expected_cel = (
-        '((inputs.artifacts[\'key\'].artifacts[0].uri + \'something\') == '
-        '\'test_str\')')
+        "((inputs.artifacts['_producer.foo'].artifacts[0].uri + 'something') =="
+        " 'test_str')"
+    )
     channel_to_key_map = {
        _TEST_CHANNEL: 'key',
    }
@@ -360,14 +366,6 @@ def testUnsupportedOperator(self):
         ValueError, 'Got unsupported placeholder operator base64_encode_op.'):
       compiler_utils.placeholder_to_cel(placeholder_pb)
 
-  def testPlaceholderWithoutKey(self):
-    predicate = _TEST_CHANNEL.future()[0].uri == 'test_str'
-    placeholder_pb = predicate.encode()
-    with self.assertRaisesRegex(
-        ValueError,
-        'Only supports accessing placeholders with a key on KFPv2.'):
-      compiler_utils.placeholder_to_cel(placeholder_pb)
-
 
 if __name__ == '__main__':
   tf.test.main()
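Taken together, these expectations imply that `_TEST_CHANNEL` has producer id 'producer' and output key 'foo': the CWP's encoded key, not the `channel_to_key_map` entry `'key'`, now determines the CEL input name. A hedged sketch of the round trip, reusing only names that appear in the tests above:

```python
# Sketch: predicate -> encoded placeholder -> CEL string.
predicate = _TEST_CHANNEL.future()[0].property('int') < 1
placeholder_pb = predicate.encode()

# The placeholder carries the encoded key '_producer.foo'
# (producer id + output key), so placeholder_to_cel yields:
#   (inputs.artifacts['_producer.foo'].artifacts[0].metadata['int'] < 1.0)
print(compiler_utils.placeholder_to_cel(placeholder_pb))
```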
@@ -35,7 +35,7 @@ inputs {
     }
   }
   trigger_policy {
-    condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
+    condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
   }
   component_ref {
     name: "DummyConsumerComponent"
@@ -35,7 +35,7 @@ inputs {
     }
   }
   trigger_policy {
-    condition: "!((inputs.artifacts['input1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
+    condition: "!((inputs.artifacts['_producer_task_1.output1'].artifacts[0].uri == 'uri')) && (inputs.artifacts['_producer_task_2.output1'].artifacts[0].metadata['property'] == 'value1')"
   }
   component_ref {
     name: "DummyConsumerComponent"
@@ -715,21 +715,23 @@ def testConditionals(self):
 
     with self.subTest('blessed == 1'):
       node_inputs = pipeline_pb2.NodeInputs(
-          inputs={'x': x},
+          inputs={'_foo.x': x},
           input_graphs={'graph_1': graph_1},
-          conditionals={'cond_1': cond_1})
+          conditionals={'cond_1': cond_1},
+      )
 
       result = node_inputs_resolver.resolve(self._mlmd_handle, node_inputs)
-      self.assertEqual(result, [{'x': [a1]}, {'x': [a4]}])
+      self.assertEqual(result, [{'_foo.x': [a1]}, {'_foo.x': [a4]}])
 
     with self.subTest('blessed == 1 and tag == foo'):
       node_inputs = pipeline_pb2.NodeInputs(
-          inputs={'x': x},
+          inputs={'_foo.x': x},
           input_graphs={'graph_1': graph_1},
-          conditionals={'cond_1': cond_1, 'cond_2': cond_2})
+          conditionals={'cond_1': cond_1, 'cond_2': cond_2},
+      )
 
       result = node_inputs_resolver.resolve(self._mlmd_handle, node_inputs)
-      self.assertEqual(result, [{'x': [a1]}])
+      self.assertEqual(result, [{'_foo.x': [a1]}])
 
   def testConditionals_FalseCondAlwaysReturnsEmpty(self):
     a = self.create_artifacts(1)
@@ -778,7 +780,7 @@ def testConditionals_FalseCondAlwaysReturnsEmpty(self):
     node_inputs = NodeInputs(
         inputs={
             'a': x1,
-            'b': x2,
+            '_foo.x': x2,
         },
         conditionals={'cond': cond},
     )
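For completeness, a hedged sketch of the proto shape these resolver tests exercise (field values are illustrative, not taken from the tests): the conditional's placeholder key and the `inputs` map key must now both use the encoded form.

```python
from tfx.proto.orchestration import pipeline_pb2

node_inputs = pipeline_pb2.NodeInputs()
channel = node_inputs.inputs['_foo.x'].channels.add()
channel.producer_node_query.id = 'foo'
channel.output_key = 'x'
# A conditional predicate that references placeholder key '_foo.x' resolves
# against this entry without colliding with other inputs of the same node.
```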