Let resolver ops get external artifacts.
PiperOrigin-RevId: 623548131
tfx-copybara committed Apr 24, 2024
1 parent 01a4b3d commit c200142
Showing 5 changed files with 141 additions and 37 deletions.
74 changes: 58 additions & 16 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict
@@ -23,6 +24,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
@@ -344,7 +346,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
@@ -372,8 +384,38 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):

# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
# keep the values of model_artifact_ids as sets.
are_models_external = [m.is_external for m in models]
if any(are_models_external) and not all(are_models_external):
raise exceptions.InvalidArgument(
          'Inputs to the LatestPolicyModel are from both the current pipeline'
          ' and an external pipeline. LatestPolicyModel does not support such usage.'
)
if all(are_models_external):
pipeline_assets = set([
external_artifact_utils.get_pipeline_asset_from_external_id(
m.mlmd_artifact.external_id
)
for m in models
])
if len(pipeline_assets) != 1:
raise exceptions.InvalidArgument(
            'Input models to the LatestPolicyModel are from multiple'
            ' pipelines. LatestPolicyModel does not support such usage.'
)

model_by_external_id = {m.mlmd_artifact.external_id: m for m in models}
deduped_models = list(model_by_external_id.values())
model_artifact_ids = sorted(
set([
external_artifact_utils.get_id_from_external_id(i)
for i in model_by_external_id.keys()
])
)
else:
model_by_id = {m.id: m for m in models}
deduped_models = list(model_by_id.values())
model_artifact_ids = sorted(set(model_by_id.keys()))

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
@@ -420,7 +462,7 @@ def event_filter(event):
mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
model_relations_by_model_identifier = collections.defaultdict(
ModelRelations
)
artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}
@@ -429,34 +471,34 @@ def event_filter(event):
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
# as starting artifact ids.
for id_index in range(0, len(model_artifact_ids), ops_utils.BATCH_SIZE):
batch_model_artifact_ids = model_artifact_ids[
for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
batch_model_artifacts = deduped_models[
id_index : id_index + ops_utils.BATCH_SIZE
]
# Set `max_num_hops` to 50, which should be enough for this use case.
batch_downstream_artifacts_and_types_by_model_ids = (
mlmd_resolver.get_downstream_artifacts_by_artifact_ids(
batch_model_artifact_ids,
batch_downstream_artifacts_and_types_by_model_identifier = (
mlmd_resolver.get_downstream_artifacts_by_artifacts(
batch_model_artifacts,
max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
filter_query=filter_query,
event_filter=event_filter,
)
)
for (
model_artifact_id,
model_identifier,
artifacts_and_types,
) in batch_downstream_artifacts_and_types_by_model_ids.items():
) in batch_downstream_artifacts_and_types_by_model_identifier.items():
for downstream_artifact, artifact_type in artifacts_and_types:
artifact_type_by_name[artifact_type.name] = artifact_type
model_relations = model_relations_by_model_artifact_id[
model_artifact_id
]
model_relations.add_downstream_artifact(downstream_artifact)
model_relations_by_model_identifier[
model_identifier
].add_downstream_artifact(downstream_artifact)

    # Find the latest model and ModelRelations that meet the policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
identifier = external_artifact_utils.identifier(model)
model_relations = model_relations_by_model_identifier[identifier]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
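Editor's note on the hunk above: external models are now deduped by their MLMD external id rather than the local artifact id, so multiple references to the same external artifact collapse to one entry. A minimal, self-contained sketch of that branch, under stated assumptions — the `Model` class and `get_id_from_external_id` parser below are illustrative stand-ins, not TFX APIs, and the `'<prefix>:<id>'` external-id suffix format is a guess:

```python
import dataclasses
from typing import List, Tuple


@dataclasses.dataclass(frozen=True)
class Model:
  """Illustrative stand-in for a TFX Model artifact."""
  id: int
  external_id: str = ''

  @property
  def is_external(self) -> bool:
    return bool(self.external_id)


def get_id_from_external_id(external_id: str) -> int:
  # Assumption: the numeric artifact id is the last ':'-separated token.
  return int(external_id.rsplit(':', 1)[-1])


def dedup_models(models: List[Model]) -> Tuple[List[Model], List[int]]:
  """Mirrors the branch above: key by external_id iff all models are external."""
  are_external = [m.is_external for m in models]
  if any(are_external) and not all(are_external):
    raise ValueError('Mixing local and external models is unsupported.')
  if models and all(are_external):
    by_key = {m.external_id: m for m in models}
    ids = sorted({get_id_from_external_id(k) for k in by_key})
  else:
    by_key = {m.id: m for m in models}
    ids = sorted(by_key)
  return list(by_key.values()), ids


# Two references to the same external artifact collapse to one model.
ms = [Model(1, 'mlmd://p:1001'), Model(2, 'mlmd://p:1001')]
assert dedup_models(ms) == ([Model(2, 'mlmd://p:1001')], [1001])
```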
31 changes: 25 additions & 6 deletions tfx/dsl/input_resolution/resolver_op.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for ResolverOp and its related definitions."""

from __future__ import annotations

import abc
from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union
from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union, cast

import attr
from tfx import types
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import json_utils
from tfx.utils import typing_utils
@@ -28,13 +30,30 @@

# Mark frozen as context instance may be used across multiple operator
# invocations.
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class Context:
"""Context for running ResolverOp."""
# MetadataStore for MLMD read access.
store: mlmd.MetadataStore
# TODO(jjong): Add more context such as current pipeline, current pipeline
# run, and current running node information.

def __init__(
self,
      store: mlmd.MetadataStore,
mlmd_handle_like: Optional[mlmd_cm.HandleLike] = None,
):
self._store = store
self._mlmd_handle_like = mlmd_handle_like

@property
def store(self):
return self._store

@property
def mlmd_connection_manager(self):
if isinstance(self._mlmd_handle_like, mlmd_cm.MLMDConnectionManager):
return cast(mlmd_cm.MLMDConnectionManager, self._mlmd_handle_like)
else:
return None

  # TODO(jjong): Add more context such as current pipeline, current pipeline
  # run, and current running node information.


# Note that to use DataType as a generic type parameter (e.g.
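Editor's note: the reworked `Context` keeps `store` as the plain read path while optionally carrying an `mlmd_handle_like`; only an `MLMDConnectionManager` can reach external MLMD stores, so the new property narrows the union or returns `None`. A runnable reduction of the pattern — the stub classes here stand in for the real MLMD types, and the real `HandleLike` union lives in `mlmd_connection_manager`:

```python
from typing import Optional, Union


class MetadataStore:
  """Stand-in for mlmd.MetadataStore."""


class MLMDConnectionManager:
  """Stand-in for mlmd_cm.MLMDConnectionManager."""


HandleLike = Union[MetadataStore, MLMDConnectionManager]


class Context:
  """Read-only context; frozen semantics preserved via property accessors."""

  def __init__(
      self,
      store: MetadataStore,
      mlmd_handle_like: Optional[HandleLike] = None,
  ):
    self._store = store
    self._mlmd_handle_like = mlmd_handle_like

  @property
  def store(self) -> MetadataStore:
    return self._store

  @property
  def mlmd_connection_manager(self) -> Optional[MLMDConnectionManager]:
    # isinstance() already narrows the type here, so the cast() in the
    # diff mainly documents intent; ops that get None must fall back to
    # the primary store.
    if isinstance(self._mlmd_handle_like, MLMDConnectionManager):
      return self._mlmd_handle_like
    return None


ctx = Context(store=MetadataStore(), mlmd_handle_like=MLMDConnectionManager())
assert ctx.mlmd_connection_manager is not None
```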
@@ -29,14 +29,14 @@
import collections
import dataclasses
import functools
from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable
from typing import Callable, Iterable, List, Mapping, Sequence, Tuple, Union

from tfx import types
from tfx.dsl.components.common import resolver
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.orchestration.portable.input_resolution import exceptions
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import topsort
@@ -52,8 +52,12 @@

@dataclasses.dataclass
class _Context:
mlmd_handle: metadata.Metadata
input_graph: pipeline_pb2.InputGraph
mlmd_handle_like: mlmd_cm.HandleLike

@property
def mlmd_handle(self):
return mlmd_cm.get_handle(self.mlmd_handle_like)


def _topologically_sorted_node_ids(
@@ -131,7 +135,12 @@ def _evaluate_op_node(
f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e
if issubclass(op_type, resolver_op.ResolverOp):
op: resolver_op.ResolverOp = op_type.create(**kwargs)
op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store))
op.set_context(
resolver_op.Context(
store=mlmd_cm.get_handle(ctx.mlmd_handle_like).store,
mlmd_handle_like=ctx.mlmd_handle_like,
)
)
return op.apply(*args)
elif issubclass(op_type, resolver.ResolverStrategy):
if len(args) != 1:
@@ -207,7 +216,7 @@ def new_graph_fn(data: Mapping[str, _Data]):


def build_graph_fn(
mlmd_handle: metadata.Metadata,
handle_like: mlmd_cm.HandleLike,
input_graph: pipeline_pb2.InputGraph,
) -> Tuple[_GraphFn, List[str]]:
"""Build a functional interface for the `input_graph`.
@@ -222,7 +231,7 @@ def build_graph_fn(
z = graph_fn({'x': inputs['x'], 'y': inputs['y']})
Args:
mlmd_handle: A `Metadata` instance.
handle_like: A `mlmd_cm.HandleLike` instance.
    input_graph: A `pipeline_pb2.InputGraph` proto.
Returns:
@@ -235,7 +244,7 @@
f'result_node {input_graph.result_node} does not exist in input_graph. '
f'Valid node ids: {list(input_graph.nodes.keys())}')

context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph)
context = _Context(mlmd_handle_like=handle_like, input_graph=input_graph)

input_key_to_node_id = {}
for node_id in input_graph.nodes:
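Editor's note: the `_Context` change keeps `ctx.mlmd_handle` working for existing call sites by turning the field into a property over the stored `HandleLike`. A runnable reduction, under stated assumptions — all classes are stand-ins for the real TFX/MLMD types, and `get_handle` only approximates `mlmd_cm.get_handle`, whose exact behavior is not shown in this commit:

```python
import dataclasses
from typing import Optional, Union


class Metadata:
  """Stand-in for tfx.orchestration.metadata.Metadata."""


class MLMDConnectionManager:
  """Stand-in for mlmd_cm.MLMDConnectionManager."""

  def __init__(self) -> None:
    self.primary = Metadata()


HandleLike = Union[Metadata, MLMDConnectionManager]


def get_handle(h: HandleLike) -> Metadata:
  """Approximates mlmd_cm.get_handle: a manager unwraps to its primary handle."""
  return h.primary if isinstance(h, MLMDConnectionManager) else h


@dataclasses.dataclass
class _Context:
  input_graph: Optional[object]
  mlmd_handle_like: HandleLike

  @property
  def mlmd_handle(self) -> Metadata:
    # Existing call sites keep reading ctx.mlmd_handle unchanged.
    return get_handle(self.mlmd_handle_like)


mgr = MLMDConnectionManager()
ctx = _Context(input_graph=None, mlmd_handle_like=mgr)
assert ctx.mlmd_handle is mgr.primary
```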
@@ -341,7 +341,7 @@ def _join_artifacts(


def _resolve_input_graph_ref(
mlmd_handle: metadata.Metadata,
handle_like: mlmd_cm.HandleLike,
node_inputs: pipeline_pb2.NodeInputs,
input_key: str,
resolved: Dict[str, List[_Entry]],
@@ -352,12 +352,12 @@
(i.e. `InputGraphRef` with the same `graph_id`).
Args:
mlmd_handle: A `Metadata` instance.
handle_like: A `mlmd_cm.HandleLike` instance.
node_inputs: A `NodeInputs` proto.
input_key: A target input key whose corresponding `InputSpec` has an
      `InputGraphRef`.
resolved: A dict that contains the already resolved inputs, and to which the
      resolved result would be written from this function.
"""
graph_id = node_inputs.inputs[input_key].input_graph_ref.graph_id
input_graph = node_inputs.input_graphs[graph_id]
@@ -372,7 +372,8 @@
}

graph_fn, graph_input_keys = input_graph_resolver.build_graph_fn(
mlmd_handle, node_inputs.input_graphs[graph_id])
handle_like, node_inputs.input_graphs[graph_id]
)
for partition, input_dict in _join_artifacts(resolved, graph_input_keys):
result = graph_fn(input_dict)
if graph_output_type == _DataType.ARTIFACT_LIST:
@@ -514,9 +515,7 @@ def resolve(
(partition_utils.NO_PARTITION, _filter_live(artifacts))
]
elif input_spec.input_graph_ref.graph_id:
_resolve_input_graph_ref(
mlmd_cm.get_handle(handle_like), node_inputs, input_key,
resolved)
_resolve_input_graph_ref(handle_like, node_inputs, input_key, resolved)
elif input_spec.mixed_inputs.input_keys:
_resolve_mixed_inputs(node_inputs, input_key, resolved)
elif input_spec.HasField('static_inputs'):
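Editor's note: the `resolve()` change above is the thread-through point — instead of narrowing to a plain `Metadata` handle before calling `_resolve_input_graph_ref`, the `HandleLike` is passed intact so `build_graph_fn` can still see an `MLMDConnectionManager` when one exists. A runnable toy demonstrating why early narrowing loses the capability; every class and function here is a stand-in, not a TFX API:

```python
from typing import Union


class Handle:
  """Stand-in for metadata.Metadata (a plain MLMD handle)."""


class Manager:
  """Stand-in for mlmd_cm.MLMDConnectionManager."""

  def __init__(self) -> None:
    self.primary = Handle()


HandleLike = Union[Handle, Manager]


def get_handle(h: HandleLike) -> Handle:
  # Stand-in for mlmd_cm.get_handle.
  return h.primary if isinstance(h, Manager) else h


def build_graph_fn(handle_like: HandleLike) -> str:
  # Downstream code can still detect a Manager and reach external stores.
  return 'manager' if isinstance(handle_like, Manager) else 'handle'


def resolve_early(handle_like: HandleLike) -> str:
  return build_graph_fn(get_handle(handle_like))  # before this commit


def resolve_late(handle_like: HandleLike) -> str:
  return build_graph_fn(handle_like)  # after this commit


mgr = Manager()
assert resolve_early(mgr) == 'handle'   # capability lost by early narrowing
assert resolve_late(mgr) == 'manager'   # capability preserved
```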
35 changes: 35 additions & 0 deletions tfx/types/external_artifact_utils.py
@@ -0,0 +1,35 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Third party version of external_artifact_utils.py."""


def get_id_from_external_id(external_id: str):
  """No-op OSS stub; internal builds parse the artifact id from the external id."""
  del external_id


def get_pipeline_asset_from_external_id(
    external_id: str,
):
  """No-op OSS stub; internal builds parse the pipeline asset."""
  del external_id


def get_external_connection_config(
    external_id: str,
):
  """No-op OSS stub; internal builds derive an MLMD connection config."""
  del external_id


def identifier(artifact):
  """No-op OSS stub; internal builds return the artifact's identifier."""
  del artifact
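
Editor's note: every function in this OSS stub falls through to an implicit `None`, so external-artifact code paths are effectively disabled outside internal builds, where real implementations are substituted. A hedged illustration of the calling convention — `identifier_for` is a made-up helper, and the claim about what internal builds return is an assumption, not something this commit confirms:

```python
from tfx.types import external_artifact_utils


def identifier_for(artifact):
  """Hypothetical helper mirroring the dedup-key choice in latest_policy_model_op."""
  if artifact.is_external:
    # Presumably the internal build parses a numeric id out of the external
    # id; this OSS stub returns None, but OSS-built pipelines are not
    # expected to carry external artifacts.
    return external_artifact_utils.get_id_from_external_id(
        artifact.mlmd_artifact.external_id
    )
  return artifact.id
```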
