Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the resolver op to get external artifacts. #6750

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 80 additions & 18 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict, List
from typing import Dict, List, Optional, Tuple

from tfx import types
from tfx.dsl.input_resolution import resolver_op
Expand All @@ -24,6 +25,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -204,6 +206,33 @@ def _build_result_dictionary(
return result


def _dedpupe_model_artifacts(
models: Optional[List[artifact_utils.Artifact]],
) -> Tuple[List[artifact_utils.Artifact], List[int]]:
"""Dedupes a list of Model artifacts."""
if not models:
return [], []

model_by_external_id = {}
model_by_id = {}

for m in models:
if m.external_id:
model_by_external_id[m.external_id] = m
else:
model_by_id[m.id] = m

deduped_models = list(model_by_external_id.values()) + list(
model_by_id.values()
)
model_artifact_ids = [
external_artifact_utils.get_id_from_external_id(i)
for i in model_by_external_id.keys()
] + list(model_by_id.keys())

return (deduped_models, model_artifact_ids)


class LatestPolicyModel(
resolver_op.ResolverOp,
canonical_name='tfx.LatestPolicyModel',
Expand Down Expand Up @@ -325,6 +354,25 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
if self.policy == Policy.LATEST_EXPORTED:
return {ops_utils.MODEL_KEY: [models[0]]}

are_models_external = [m.is_external for m in models]
if any(are_models_external) and not all(are_models_external):
raise exceptions.InvalidArgument(
'Inputs to the LastestPolicyModel are from both current pipeline and'
' external pipeline. LastestPolicyModel does not support such usage.'
)
if all(are_models_external):
pipeline_assets = set([
external_artifact_utils.get_pipeline_asset_from_external_id(
m.mlmd_artifact.external_id
)
for m in models
])
if len(pipeline_assets) != 1:
raise exceptions.InvalidArgument(
'Input models to the LastestPolicyModel are from multiple'
' pipelines. LastestPolicyModel does not support such usage.'
)

# If ModelBlessing and/or ModelInfraBlessing artifacts were included in
# input_dict, then we will only consider those child artifacts.
specifies_child_artifacts = (
Expand All @@ -334,7 +382,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
Expand Down Expand Up @@ -362,8 +420,8 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):

# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
# need to deduplicate the Model artifacts.
deduped_models, model_artifact_ids = _dedpupe_model_artifacts(models)

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
Expand Down Expand Up @@ -407,10 +465,13 @@ def event_filter(event):
else:
return event_lib.is_valid_output_event(event)

mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
mlmd_resolver = metadata_resolver.MetadataResolver(
self.context.store,
mlmd_connection_manager=self.context.mlmd_connection_manager,
)
# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
model_relations_by_model_identifier = collections.defaultdict(
ModelRelations
)
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
Expand All @@ -419,34 +480,35 @@ def event_filter(event):
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
# as starting artifact ids.
for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE):
batch_model_artifact_ids = model_artifact_ids[
for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
batch_model_artifacts = deduped_models[
id_index : id_index + ops_utils.BATCH_SIZE
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_ids = (
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
batch_model_artifact_ids,
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
event_filter=event_filter,
)
)

for (
model_artifact_id,
model_identifier,
artifacts_and_types,
) in batch_downstream_artifacts_and_types_by_model_ids.items():
) in batch_downstream_artifacts_and_types_by_model_identifier.items():
for downstream_artifact, artifact_type in artifacts_and_types:
artifact_type_by_name[artifact_type.name] = artifact_type
model_relations = model_relations_by_model_artifact_id[
model_artifact_id
]
model_relations.add_downstream_artifact(downstream_artifact)
model_relations_by_model_identifier[
model_identifier
].add_downstream_artifact(downstream_artifact)

# Find the latest model and ModelRelations that meets the Policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
identifier = external_artifact_utils.identifier(model)
model_relations = model_relations_by_model_identifier[identifier]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
Expand Down
5 changes: 5 additions & 0 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.input_resolution.ops.latest_policy_model_op."""
import os
from typing import Dict, List, Optional
from unittest import mock

from absl.testing import parameterized
import tensorflow as tf
Expand All @@ -22,6 +24,7 @@
from tfx.dsl.input_resolution.ops import ops
from tfx.dsl.input_resolution.ops import ops_utils
from tfx.dsl.input_resolution.ops import test_utils
from tfx.orchestration import metadata
from tfx.orchestration.portable.input_resolution import exceptions

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -146,6 +149,7 @@ def _run_latest_policy_model(self, *args, **kwargs):
args=args,
kwargs=kwargs,
store=self.store,
mlmd_handle_like=self.mlmd_cm,
)

def setUp(self):
Expand All @@ -158,6 +162,7 @@ def setUp(self):

self.artifacts = [self.model_1, self.model_2, self.model_3]


def assertDictKeysEmpty(
self,
output_dict: Dict[str, List[types.Artifact]],
Expand Down