Skip to content

Commit

Permalink
Let resolver op be able to get external artifacts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623548131
  • Loading branch information
tfx-copybara committed May 2, 2024
1 parent 472f30b commit 3b265c3
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 59 deletions.
80 changes: 63 additions & 17 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict, List
Expand All @@ -24,6 +25,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -334,7 +336,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
Expand Down Expand Up @@ -362,8 +374,38 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):

# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
# keep the values of model_artifact_ids as sets.
are_models_external = [m.is_external for m in models]
if any(are_models_external) and not all(are_models_external):
raise exceptions.InvalidArgument(
'Inputs to the LastestPolicyModel are from both current pipeline and'
' external pipeline. LastestPolicyModel does not support such usage.'
)
if all(are_models_external):
pipeline_assets = set([
external_artifact_utils.get_pipeline_asset_from_external_id(
m.mlmd_artifact.external_id
)
for m in models
])
if len(pipeline_assets) != 1:
raise exceptions.InvalidArgument(
'Input models to the LastestPolicyModel are from multiple'
' pipelines. LastestPolicyModel does not support such usage.'
)

model_by_external_id = {m.mlmd_artifact.external_id: m for m in models}
deduped_models = list(model_by_external_id.values())
model_artifact_ids = sorted(
set([
external_artifact_utils.get_id_from_external_id(i)
for i in model_by_external_id.keys()
])
)
else:
model_by_id = {m.id: m for m in models}
deduped_models = list(model_by_id.values())
model_artifact_ids = sorted(set(model_by_id.keys()))

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
Expand Down Expand Up @@ -407,10 +449,13 @@ def event_filter(event):
else:
return event_lib.is_valid_output_event(event)

mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
mlmd_resolver = metadata_resolver.MetadataResolver(
self.context.store,
mlmd_connection_manager=self.context.mlmd_connection_manager,
)
# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
model_relations_by_model_identifier = collections.defaultdict(
ModelRelations
)
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
Expand All @@ -419,34 +464,35 @@ def event_filter(event):
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
# as starting artifact ids.
for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE):
batch_model_artifact_ids = model_artifact_ids[
for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
batch_model_artifacts = deduped_models[
id_index : id_index + ops_utils.BATCH_SIZE
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_ids = (
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
batch_model_artifact_ids,
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
event_filter=event_filter,
)
)

for (
model_artifact_id,
model_identifier,
artifacts_and_types,
) in batch_downstream_artifacts_and_types_by_model_ids.items():
) in batch_downstream_artifacts_and_types_by_model_identifier.items():
for downstream_artifact, artifact_type in artifacts_and_types:
artifact_type_by_name[artifact_type.name] = artifact_type
model_relations = model_relations_by_model_artifact_id[
model_artifact_id
]
model_relations.add_downstream_artifact(downstream_artifact)
model_relations_by_model_identifier[
model_identifier
].add_downstream_artifact(downstream_artifact)

# Find the latest model and ModelRelations that meets the Policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
identifier = external_artifact_utils.identifier(model)
model_relations = model_relations_by_model_identifier[identifier]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
Expand Down
5 changes: 5 additions & 0 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.input_resolution.ops.latest_policy_model_op."""
import os
from typing import Dict, List, Optional
from unittest import mock

from absl.testing import parameterized
import tensorflow as tf
Expand All @@ -22,6 +24,7 @@
from tfx.dsl.input_resolution.ops import ops
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.dsl.input_resolution.ops import test_utils
from tfx.orchestration import metadata
from tfx.orchestration.portable.input_resolution import exceptions

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -146,6 +149,7 @@ def _run_latest_policy_model(self, *args, **kwargs):
args=args,
kwargs=kwargs,
store=self.store,
mlmd_handle_like=self.mlmd_cm,
)

def setUp(self):
Expand All @@ -158,6 +162,7 @@ def setUp(self):

self.artifacts = [self.model_1, self.model_2, self.model_3]


def assertDictKeysEmpty(
self,
output_dict: Dict[str, List[types.Artifact]],
Expand Down
53 changes: 42 additions & 11 deletions tfx/dsl/input_resolution/ops/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utility for builtin resolver ops."""
from typing import Type, Any, Dict, List, Optional, Sequence, Tuple, Union, Mapping
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
from unittest import mock

from absl.testing import parameterized

from tfx import types
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import node_inputs_compiler
Expand All @@ -27,6 +26,7 @@
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.orchestration import pipeline
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import artifact as tfx_artifact
from tfx.types import artifact_utils
Expand Down Expand Up @@ -201,15 +201,19 @@ def prepare_tfx_artifact(
properties: Optional[Dict[str, Union[int, str]]] = None,
custom_properties: Optional[Dict[str, Union[int, str]]] = None,
state: metadata_store_pb2.Artifact.State = metadata_store_pb2.Artifact.State.LIVE,
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
) -> types.Artifact:
"""Adds a single artifact to MLMD and returns the TFleX Artifact object."""
mlmd_artifact = self.put_artifact(
artifact.TYPE_NAME,
properties=properties,
custom_properties=custom_properties,
state=state,
connection_config=connection_config,
)
artifact_type = self.store.get_artifact_type(artifact.TYPE_NAME)

store = self.get_store(connection_config)
artifact_type = store.get_artifact_type(artifact.TYPE_NAME)
return artifact_utils.deserialize_artifact(artifact_type, mlmd_artifact)

def unwrap_tfx_artifacts(
Expand All @@ -222,48 +226,59 @@ def build_node_context(
self,
pipeline_name: str,
node_id: str,
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
):
"""Returns a "node" Context with name "pipeline_name.node_id."""
context = self.put_context(
context_type='node', context_name=f'{pipeline_name}.{node_id}'
context_type='node',
context_name=f'{pipeline_name}.{node_id}',
connection_config=connection_config,
)
return context

def create_examples(
self,
spans_and_versions: Sequence[Tuple[int, int]],
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
) -> List[types.Artifact]:
"""Build Examples artifacts and add an ExampleGen execution to MLMD."""
examples = []
for span, version in spans_and_versions:
examples.append(
self.prepare_tfx_artifact(
Examples, properties={'span': span, 'version': version}
)
Examples,
properties={'span': span, 'version': version},
connection_config=connection_config,
),
)
self.put_execution(
'ExampleGen',
inputs={},
outputs={'examples': self.unwrap_tfx_artifacts(examples)},
contexts=contexts,
connection_config=connection_config,
)
return examples

def transform_examples(
self,
examples: List[types.Artifact],
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
) -> types.Artifact:
inputs = {'examples': self.unwrap_tfx_artifacts(examples)}
transform_graph = self.prepare_tfx_artifact(TransformGraph)
transform_graph = self.prepare_tfx_artifact(
TransformGraph, connection_config=connection_config
)
self.put_execution(
'Transform',
inputs=inputs,
outputs={
'transform_graph': self.unwrap_tfx_artifacts([transform_graph])
},
contexts=contexts,
connection_config=connection_config,
)
return transform_graph

Expand All @@ -273,6 +288,7 @@ def train_on_examples(
examples: List[types.Artifact],
transform_graph: Optional[types.Artifact] = None,
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
):
"""Add an Execution to MLMD where a Trainer trains on the examples."""
inputs = {'examples': self.unwrap_tfx_artifacts(examples)}
Expand All @@ -283,6 +299,7 @@ def train_on_examples(
inputs=inputs,
outputs={'model': self.unwrap_tfx_artifacts([model])},
contexts=contexts,
connection_config=connection_config,
)

def evaluator_bless_model(
Expand All @@ -291,10 +308,13 @@ def evaluator_bless_model(
blessed: bool = True,
baseline_model: Optional[types.Artifact] = None,
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
) -> types.Artifact:
"""Add an Execution to MLMD where the Evaluator blesses the model."""
model_blessing = self.prepare_tfx_artifact(
ModelBlessing, custom_properties={'blessed': int(blessed)}
ModelBlessing,
custom_properties={'blessed': int(blessed)},
connection_config=connection_config,
)

inputs = {'model': self.unwrap_tfx_artifacts([model])}
Expand All @@ -306,6 +326,7 @@ def evaluator_bless_model(
inputs=inputs,
outputs={'blessing': self.unwrap_tfx_artifacts([model_blessing])},
contexts=contexts,
connection_config=connection_config,
)

return model_blessing
Expand All @@ -315,21 +336,25 @@ def infra_validator_bless_model(
model: types.Artifact,
blessed: bool = True,
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
) -> types.Artifact:
"""Add an Execution to MLMD where the InfraValidator blesses the model."""
if blessed:
custom_properties = {'blessing_status': 'INFRA_BLESSED'}
else:
custom_properties = {'blessing_status': 'INFRA_NOT_BLESSED'}
model_infra_blessing = self.prepare_tfx_artifact(
ModelInfraBlessing, custom_properties=custom_properties
ModelInfraBlessing,
custom_properties=custom_properties,
connection_config=connection_config,
)

self.put_execution(
'InfraValidator',
inputs={'model': self.unwrap_tfx_artifacts([model])},
outputs={'result': self.unwrap_tfx_artifacts([model_infra_blessing])},
contexts=contexts,
connection_config=connection_config,
)

return model_infra_blessing
Expand All @@ -339,15 +364,19 @@ def push_model(
model: types.Artifact,
model_push: Optional[types.Artifact] = None,
contexts: Sequence[metadata_store_pb2.Context] = (),
connection_config: Optional[metadata_store_pb2.ConnectionConfig] = None,
):
"""Add an Execution to MLMD where the Pusher pushes the model."""
if model_push is None:
model_push = self.prepare_tfx_artifact(ModelPush)
model_push = self.prepare_tfx_artifact(
ModelPush, connection_config=connection_config
)
self.put_execution(
'ServomaticPusher',
inputs={'model_export': self.unwrap_tfx_artifacts([model])},
outputs={'model_push': self.unwrap_tfx_artifacts([model_push])},
contexts=contexts,
connection_config=connection_config,
)
return model_push

Expand All @@ -370,6 +399,7 @@ def strict_run_resolver_op(
args: Tuple[Any, ...],
kwargs: Mapping[str, Any],
store: Optional[mlmd.MetadataStore] = None,
mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None,
):
"""Runs ResolverOp with strict type checking."""
if len(args) != len(op_type.arg_data_types):
Expand All @@ -396,7 +426,8 @@ def strict_run_resolver_op(
context = resolver_op.Context(
store=store
if store is not None
else mock.MagicMock(spec=mlmd.MetadataStore)
else mock.MagicMock(spec=mlmd.MetadataStore),
mlmd_handle_like=mlmd_handle_like,
)
op.set_context(context)
result = op.apply(*args)
Expand Down

0 comments on commit 3b265c3

Please sign in to comment.