Skip to content

Commit

Permalink
Let resolver op be able to get external artifacts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623548131
  • Loading branch information
tfx-copybara committed Apr 15, 2024
1 parent 9332479 commit 04c4cb6
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 20 deletions.
76 changes: 69 additions & 7 deletions tfx/dsl/input_resolution/ops/latest_policy_model_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for LatestPolicyModel operator."""

import collections
import enum
from typing import Dict
Expand All @@ -23,6 +24,7 @@
from tfx.orchestration.portable.mlmd import event_lib
from tfx.orchestration.portable.mlmd import filter_query_builder as q
from tfx.types import artifact_utils
from tfx.types import external_artifact_utils
from tfx.utils import typing_utils

from ml_metadata.proto import metadata_store_pb2
Expand Down Expand Up @@ -324,7 +326,17 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
input_child_artifacts = input_dict.get(
ops_utils.MODEL_BLESSSING_KEY, []
) + input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, [])
input_child_artifact_ids = set([a.id for a in input_child_artifacts])

input_child_artifact_ids = set()
for a in input_child_artifacts:
if a.is_external:
input_child_artifact_ids.add(
external_artifact_utils.get_id_from_external_id(
a.mlmd_artifact.external_id
)
)
else:
input_child_artifact_ids.add(a.id)

# If the ModelBlessing and ModelInfraBlessing lists are empty, then no
# child artifacts can be considered and we raise a SkipSignal. This can
Expand Down Expand Up @@ -353,7 +365,52 @@ def apply(self, input_dict: typing_utils.ArtifactMultiMap):
# There could be multiple events with the same execution ID but different
# artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
# keep the values of model_artifact_ids_by_execution_id as sets.
model_artifact_ids = sorted(set(m.id for m in models))
are_model_external = [m.is_external for m in models]
if any(are_model_external) and not all(are_model_external):
raise exceptions.InvalidArgument(
'Inputs to the LastestPolicyModel are from both current pipeline and'
' external pipeline. LastestPolicyModel does not support such usage.'
)

if not all(are_model_external):
model_artifact_ids = sorted(set(m.id for m in models))
store = self.context.store
else:
# If the input models are from an external pipeline, try to get an MLMD
# `store` which connects to the external MLMD instance.
model_external_ids = sorted(
set([m.mlmd_artifact.external_id for m in models])
)
model_artifact_ids = sorted(
set([
external_artifact_utils.get_id_from_external_id(i)
for i in model_external_ids
])
)

pipeline_assets = [
external_artifact_utils.get_pipeline_asset_from_external_id(i)
for i in model_external_ids
]
pipeline_assets = set([a.SerializeToString() for a in pipeline_assets])
if len(pipeline_assets) > 1:
raise exceptions.InvalidArgument(
'Input models to the LastestPolicyModel are from multiple'
' pipelines. LastestPolicyModel does not support such usage.'
)

external_connection_config = (
external_artifact_utils.get_external_connection_config(
model_external_ids[0]
)
)
if not self.context.mlmd_manager:
raise ValueError('Not able to connect to external MLMD instance.')
store = self.context.mlmd_manager.get_mlmd_handle(
external_connection_config
).store

mlmd_resolver = metadata_resolver.MetadataResolver(store)

downstream_artifact_type_names_filter_query = q.to_sql_string([
ops_utils.MODEL_BLESSING_TYPE_NAME,
Expand Down Expand Up @@ -397,9 +454,7 @@ def event_filter(event):
else:
return event_lib.is_valid_output_event(event)

mlmd_resolver = metadata_resolver.MetadataResolver(self.context.store)
downstream_artifacts_by_model_ids = {}

# Split `model_artifact_ids` into batches with batch size = 100 while
# fetching downstream artifacts, because
# `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
Expand All @@ -420,12 +475,12 @@ def event_filter(event):
downstream_artifacts_by_model_ids.update(
batch_downstream_artifacts_by_model_ids
)

# Populate the ModelRelations associated with each Model artifact and its
# children.
model_relations_by_model_artifact_id = collections.defaultdict(
ModelRelations
)

type_ids = set()
for (
model_artifact_id,
Expand Down Expand Up @@ -455,15 +510,22 @@ def event_filter(event):
# Find the latest model and ModelRelations that meets the Policy.
result = {}
for model in models:
model_relations = model_relations_by_model_artifact_id[model.id]
if model.is_external:
model_id = external_artifact_utils.get_id_from_external_id(
model.mlmd_artifact.external_id
)
else:
model_id = model.id
model_relations = model_relations_by_model_artifact_id[model_id]
if model_relations.meets_policy(self.policy):
result[ops_utils.MODEL_KEY] = [model]
break
else:
return self._raise_skip_signal_or_return_empty_dict(
f'No model found that meets the Policy {Policy(self.policy).name}'
)
artifact_types = self.context.store.get_artifact_types_by_id(type_ids)

artifact_types = store.get_artifact_types_by_id(type_ids)
artifact_type_by_name = {t.name: t for t in artifact_types}
return _build_result_dictionary(
result, model_relations, self.policy, artifact_type_by_name
Expand Down
8 changes: 8 additions & 0 deletions tfx/dsl/input_resolution/resolver_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for ResolverOp and its related definitions."""

from __future__ import annotations

import abc
from typing import Any, Generic, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union

import attr
from tfx import types
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import json_utils
from tfx.utils import typing_utils
Expand All @@ -31,8 +33,14 @@
@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class Context:
"""Context for running ResolverOp."""
# TODO(b/302730333) We could remove store and only use mlmd_manager. Keeping
# this for now to keep it backward compatible with other resolver ops.
# MetadataStore for MLMD read access.
store: mlmd.MetadataStore

# An MLMDConnectionManager instance. It can manage multiple MLMD connections.
mlmd_manager: Optional[mlmd_cm.MLMDConnectionManager] = None

# TODO(jjong): Add more context such as current pipeline, current pipeline
# run, and current running node information.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
import collections
import dataclasses
import functools
from typing import Union, Sequence, Mapping, Tuple, List, Iterable, Callable
from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast

from tfx import types
from tfx.dsl.components.common import resolver
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.input_resolution.ops import ops
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
from tfx.orchestration.portable.input_resolution import exceptions
from tfx.proto.orchestration import pipeline_pb2
from tfx.utils import topsort
Expand All @@ -54,6 +55,7 @@
class _Context:
mlmd_handle: metadata.Metadata
input_graph: pipeline_pb2.InputGraph
mlmd_manager: Optional[mlmd_cm.MLMDConnectionManager] = None


def _topologically_sorted_node_ids(
Expand Down Expand Up @@ -131,7 +133,11 @@ def _evaluate_op_node(
f'nodes[{node_id}] has unknown op_type {op_node.op_type}.') from e
if issubclass(op_type, resolver_op.ResolverOp):
op: resolver_op.ResolverOp = op_type.create(**kwargs)
op.set_context(resolver_op.Context(store=ctx.mlmd_handle.store))
op.set_context(
resolver_op.Context(
store=ctx.mlmd_handle.store, mlmd_manager=ctx.mlmd_manager
)
)
return op.apply(*args)
elif issubclass(op_type, resolver.ResolverStrategy):
if len(args) != 1:
Expand Down Expand Up @@ -207,7 +213,7 @@ def new_graph_fn(data: Mapping[str, _Data]):


def build_graph_fn(
mlmd_handle: metadata.Metadata,
handle_like: mlmd_cm.HandleLike,
input_graph: pipeline_pb2.InputGraph,
) -> Tuple[_GraphFn, List[str]]:
"""Build a functional interface for the `input_graph`.
Expand All @@ -222,7 +228,7 @@ def build_graph_fn(
z = graph_fn({'x': inputs['x'], 'y': inputs['y']})
Args:
mlmd_handle: A `Metadata` instance.
handle_like: A `mlmd_cm.HandleLike` instance.
input_graph: An `pipeline_pb2.InputGraph` proto.
Returns:
Expand All @@ -235,7 +241,11 @@ def build_graph_fn(
f'result_node {input_graph.result_node} does not exist in input_graph. '
f'Valid node ids: {list(input_graph.nodes.keys())}')

context = _Context(mlmd_handle=mlmd_handle, input_graph=input_graph)
context = _Context(
mlmd_handle=mlmd_cm.get_handle(handle_like), input_graph=input_graph
)
if isinstance(handle_like, mlmd_cm.MLMDConnectionManager):
context.mlmd_manager = cast(mlmd_cm.MLMDConnectionManager, handle_like)

input_key_to_node_id = {}
for node_id in input_graph.nodes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _join_artifacts(


def _resolve_input_graph_ref(
mlmd_handle: metadata.Metadata,
handle_like: mlmd_cm.HandleLike,
node_inputs: pipeline_pb2.NodeInputs,
input_key: str,
resolved: Dict[str, List[_Entry]],
Expand All @@ -352,12 +352,12 @@ def _resolve_input_graph_ref(
(i.e. `InputGraphRef` with the same `graph_id`).
Args:
mlmd_handle: A `Metadata` instance.
handle_like: A `mlmd_cm.HandleLike` instance.
node_inputs: A `NodeInputs` proto.
input_key: A target input key whose corresponding `InputSpec` has an
`InputGraphRef`.
`InputGraphRef`.
resolved: A dict that contains the already resolved inputs, and to which the
resolved result would be written from this function.
resolved result would be written from this function.
"""
graph_id = node_inputs.inputs[input_key].input_graph_ref.graph_id
input_graph = node_inputs.input_graphs[graph_id]
Expand All @@ -372,7 +372,8 @@ def _resolve_input_graph_ref(
}

graph_fn, graph_input_keys = input_graph_resolver.build_graph_fn(
mlmd_handle, node_inputs.input_graphs[graph_id])
handle_like, node_inputs.input_graphs[graph_id]
)
for partition, input_dict in _join_artifacts(resolved, graph_input_keys):
result = graph_fn(input_dict)
if graph_output_type == _DataType.ARTIFACT_LIST:
Expand Down Expand Up @@ -514,9 +515,7 @@ def resolve(
(partition_utils.NO_PARTITION, _filter_live(artifacts))
]
elif input_spec.input_graph_ref.graph_id:
_resolve_input_graph_ref(
mlmd_cm.get_handle(handle_like), node_inputs, input_key,
resolved)
_resolve_input_graph_ref(handle_like, node_inputs, input_key, resolved)
elif input_spec.mixed_inputs.input_keys:
_resolve_mixed_inputs(node_inputs, input_key, resolved)
elif input_spec.HasField('static_inputs'):
Expand Down
35 changes: 35 additions & 0 deletions tfx/types/external_artifact_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Third party version of external_artifact_utils.py."""


def get_artifact_id_from_external_id(external_id: str):
  """Placeholder for resolving an artifact id from an external id.

  This third-party version performs no lookup: it discards its argument and
  returns None.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None.
  """
  del external_id  # Unused in the third-party stub.
  return None


# `latest_policy_model_op.py` calls this helper as
# `external_artifact_utils.get_id_from_external_id(...)`. Expose that name as
# an alias so the attribute lookup does not fail with AttributeError.
get_id_from_external_id = get_artifact_id_from_external_id


def get_pipeline_asset_from_external_id(external_id: str):
  """Placeholder for looking up a pipeline asset from an external id.

  This third-party version performs no lookup: it discards its argument and
  returns None.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None.
  """
  del external_id  # Unused in the third-party stub.
  return None


def get_external_connection_config(external_id: str):
  """Placeholder for deriving an external MLMD connection config.

  This third-party version builds no config: it discards its argument and
  returns None.

  Args:
    external_id: The external id string of an artifact. Unused here.

  Returns:
    None.
  """
  del external_id  # Unused in the third-party stub.
  return None


def identifier(artifact):
  """Placeholder for computing an identifier of an artifact.

  This third-party version computes nothing: it discards its argument and
  returns None.

  Args:
    artifact: The artifact to identify. Unused here.

  Returns:
    None.
  """
  del artifact  # Unused in the third-party stub.
  return None

0 comments on commit 04c4cb6

Please sign in to comment.