Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #6717

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 41 additions & 18 deletions tfx/components/distribution_validator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow_data_validation.utils import path
from tensorflow_data_validation.utils import schema_util
from tfx import types
from tfx.components.distribution_validator import utils
from tfx.components.statistics_gen import stats_artifact_utils
from tfx.dsl.components.base import base_executor
from tfx.orchestration.experimental.core import component_generated_alert_pb2
Expand All @@ -38,6 +39,7 @@
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2


# Default file name for anomalies output.
DEFAULT_FILE_NAME = 'SchemaDiff.pb'

Expand Down Expand Up @@ -176,17 +178,14 @@ def _add_anomalies_for_missing_comparisons(


def _generate_alerts_info_proto(
anomaly_info: anomalies_pb2.AnomalyInfo,
split_pair: str
anomaly_info: anomalies_pb2.AnomalyInfo, split_pair: str
) -> list[component_generated_alert_pb2.ComponentGeneratedAlertInfo]:
"""Generates a list of ComponentGeneratedAlertInfo from AnomalyInfo."""
result = []
for reason in anomaly_info.reason:
result.append(
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
f'[{split_pair}] {reason.short_description}'
),
alert_name=f'[{split_pair}] {reason.short_description}',
alert_body=f'[{split_pair}] {reason.description}',
)
)
Expand Down Expand Up @@ -278,9 +277,37 @@ def Do(
].artifacts.append(anomalies_artifact.mlmd_artifact)
return executor_output

config = exec_properties.get(
standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY
)
if (
input_dict.get(
standard_component_specs.ARTIFACT_DISTRIBUTION_VALIDATOR_CONFIG_KEY
)
is not None
and exec_properties.get(
standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY
)
is not None
):
raise ValueError(
'artifact_distribution_validator_config and'
' distribution_validator_config are provided at the same time.'
)
elif (
input_dict.get(
standard_component_specs.ARTIFACT_DISTRIBUTION_VALIDATOR_CONFIG_KEY
)
is not None
):
config_artifact = artifact_utils.get_single_instance(
input_dict[
standard_component_specs.ARTIFACT_DISTRIBUTION_VALIDATOR_CONFIG_KEY
]
)
config = utils.load_config_from_artifact(config_artifact)
else:
config = exec_properties.get(
standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY
)

custom_validation_config = exec_properties.get(
standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY
)
Expand Down Expand Up @@ -309,12 +336,10 @@ def Do(
if missing_split_pairs:
raise ValueError(
'Missing split pairs identified in include_split_pairs: %s'
% ', '.join(
[
'%s_%s' % (test, baseline)
for test, baseline in missing_split_pairs
]
)
% ', '.join([
'%s_%s' % (test, baseline)
for test, baseline in missing_split_pairs
])
)

anomalies_artifact.split_names = artifact_utils.encode_split_names(
Expand Down Expand Up @@ -353,9 +378,7 @@ def Do(
anomalies = _get_comparison_only_anomalies(full_anomalies)
anomalies = _add_anomalies_for_missing_comparisons(anomalies, config)

if anomalies.anomaly_info or anomalies.HasField(
'dataset_anomaly_info'
):
if anomalies.anomaly_info or anomalies.HasField('dataset_anomaly_info'):
blessed_value_dict[split_pair] = NOT_BLESSED_VALUE
else:
blessed_value_dict[split_pair] = BLESSED_VALUE
Expand Down Expand Up @@ -386,7 +409,7 @@ def Do(

executor_output.output_artifacts[
standard_component_specs.ANOMALIES_KEY
].artifacts.append(anomalies_artifact.mlmd_artifact)
].artifacts.append(anomalies_artifact.mlmd_artifact)

# Set component generated alerts execution property in ExecutorOutput if
# any anomalies alerts exist.
Expand Down
196 changes: 196 additions & 0 deletions tfx/components/distribution_validator/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,202 @@ def testAddOutput(self):
].proto_value.Unpack(actual_alerts)
self.assertEqual(actual_alerts, expected_alerts)

def testUseArtifactDVConfig(self):
  """Checks Do() loads the DV config from a Config input artifact.

  A DistributionValidatorConfig proto is serialized into the directory
  backing a Config artifact; the executor must read it from there (no
  config is passed through exec_properties) and still produce the same
  anomalies and blessing result as the exec-property path.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata'
  )
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName,
  )

  # Reuse the same statistics for test and baseline so the only drift comes
  # from the canned testdata itself.
  stats_artifact = standard_artifacts.ExampleStatistics()
  stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
  stats_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval']
  )

  validation_output = standard_artifacts.ExampleAnomalies()
  validation_output.uri = os.path.join(output_data_dir, 'output')

  dv_config = text_format.Parse(
      """
      default_slice_config: {
        feature: {
          path: {
            step: 'company'
          }
          distribution_comparator: {
            infinity_norm: {
              threshold: 0.0
            }
          }
        }
      }
      """,
      distribution_validator_pb2.DistributionValidatorConfig(),
  )

  # Serialize the config into the Config artifact's directory; the executor
  # is expected to discover and parse the single file under the uri.
  config_dir = os.path.join(output_data_dir, 'test_custom_component')
  io_utils.write_bytes_file(
      os.path.join(config_dir, 'DVconfig.pb'), dv_config.SerializeToString()
  )
  config_artifact = standard_artifacts.Config()
  config_artifact.uri = config_dir

  input_dict = {
      standard_component_specs.STATISTICS_KEY: [stats_artifact],
      standard_component_specs.BASELINE_STATISTICS_KEY: [stats_artifact],
      standard_component_specs.ARTIFACT_DISTRIBUTION_VALIDATOR_CONFIG_KEY: [
          config_artifact
      ],
  }

  # The analyzed splits are set for this test to get a single result proto.
  exec_properties = {
      # List needs to be serialized before being passed into Do function.
      standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY: json_utils.dumps(
          [('train', 'eval')]
      ),
      standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY: None,
  }
  output_dict = {
      standard_component_specs.ANOMALIES_KEY: [validation_output],
  }

  _ = executor.Executor().Do(input_dict, output_dict, exec_properties)

  self.assertEqual(
      artifact_utils.encode_split_names(['train_eval']),
      validation_output.split_names,
  )

  anomalies_path = os.path.join(
      validation_output.uri, 'SplitPair-train_eval', 'SchemaDiff.pb'
  )
  self.assertTrue(fileio.exists(anomalies_path))
  actual_anomalies = anomalies_pb2.Anomalies()
  actual_anomalies.ParseFromString(io_utils.read_bytes_file(anomalies_path))

  expected_anomalies = text_format.Parse(
      """anomaly_info {
        key: "company"
        value {
          severity: ERROR
          reason {
            type: COMPARATOR_L_INFTY_HIGH
            short_description: "High Linfty distance between current and previous"
            description: "The Linfty distance between current and previous is 0.0122771 (up to six significant digits), above the threshold 0. The feature value with maximum difference is: Dispatch Taxi Affiliation"
          }
          path {
            step: "company"
          }
        }
      }
      anomaly_name_format: SERIALIZED_PATH
      drift_skew_info {
        path {
          step: "company"
        }
        drift_measurements {
          type: L_INFTY
          value: 0.012277129468474923
          threshold: 0.0
        }
      }
      """,
      anomalies_pb2.Anomalies(),
  )

  self.assertEqualExceptBaseline(expected_anomalies, actual_anomalies)
  # The configured threshold of 0.0 is exceeded, so the split pair must be
  # marked not blessed (0).
  self.assertEqual(
      validation_output.get_json_value_custom_property(
          executor.ARTIFACT_PROPERTY_BLESSED_KEY
      ),
      {'train_eval': 0},
  )

def testInvalidArtifactDVConfigAndParameterConfig(self):
  """Checks Do() rejects a config given both as artifact and exec property.

  Supplying both an artifact-backed DistributionValidatorConfig and one via
  exec_properties is ambiguous, so the executor must raise ValueError.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata'
  )

  stats_artifact = standard_artifacts.ExampleStatistics()
  stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
  stats_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval']
  )

  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName,
  )

  validation_output = standard_artifacts.ExampleAnomalies()
  validation_output.uri = os.path.join(output_data_dir, 'output')

  validation_config = text_format.Parse(
      """
      default_slice_config: {
        feature: {
          path: {
            step: 'company'
          }
          distribution_comparator: {
            infinity_norm: {
              threshold: 0.0
            }
          }
        }
      }
      """,
      distribution_validator_pb2.DistributionValidatorConfig(),
  )
  # Back the Config artifact with a real serialized proto so the failure
  # can only come from the both-configs-supplied check, not a load error.
  binary_proto_filepath = os.path.join(
      output_data_dir, 'test_custom_component', 'DVconfig.pb'
  )
  io_utils.write_bytes_file(
      binary_proto_filepath, validation_config.SerializeToString()
  )
  config_artifact = standard_artifacts.Config()
  config_artifact.uri = os.path.join(output_data_dir, 'test_custom_component')

  input_dict = {
      standard_component_specs.STATISTICS_KEY: [stats_artifact],
      standard_component_specs.BASELINE_STATISTICS_KEY: [stats_artifact],
      standard_component_specs.ARTIFACT_DISTRIBUTION_VALIDATOR_CONFIG_KEY: [
          config_artifact
      ],
  }

  # The analyzed splits are set for this test to get a single result proto.
  exec_properties = {
      # List needs to be serialized before being passed into Do function.
      standard_component_specs.INCLUDE_SPLIT_PAIRS_KEY: json_utils.dumps(
          [('train', 'eval')]
      ),
      # Deliberately also provide the config as an exec property to trigger
      # the conflict.
      standard_component_specs.DISTRIBUTION_VALIDATOR_CONFIG_KEY: (
          validation_config
      ),
      standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY: None,
  }

  output_dict = {
      standard_component_specs.ANOMALIES_KEY: [validation_output],
  }

  distribution_validator_executor = executor.Executor()
  # Pin the specific error: a bare assertRaises(ValueError) would also pass
  # on unrelated ValueErrors raised during setup inside Do().
  with self.assertRaisesRegex(ValueError, 'provided at the same time'):
    _ = distribution_validator_executor.Do(
        input_dict, output_dict, exec_properties
    )


# Script entry point: delegate to the absltest runner when executed directly.
if __name__ == '__main__':
  absltest.main()
29 changes: 29 additions & 0 deletions tfx/components/distribution_validator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DistributionValidator utils."""

from tfx.proto import distribution_validator_pb2
from tfx.types import artifact
from tfx.utils import io_utils


def load_config_from_artifact(
    config_artifact: artifact.Artifact,
) -> distribution_validator_pb2.DistributionValidatorConfig:
  """Load a serialized DistributionValidatorConfig proto from artifact.

  The artifact's uri is expected to contain exactly one file: the binary
  serialization of a DistributionValidatorConfig proto.

  Args:
    config_artifact: Artifact whose uri directory holds the serialized
      config.

  Returns:
    The deserialized DistributionValidatorConfig proto.
  """
  dv_config = distribution_validator_pb2.DistributionValidatorConfig()
  # get_only_uri_in_dir raises if the directory holds zero or many files,
  # so a malformed artifact payload fails loudly here.
  config_path = io_utils.get_only_uri_in_dir(config_artifact.uri)
  dv_config.ParseFromString(io_utils.read_bytes_file(config_path))
  return dv_config