Skip to content

Commit

Permalink
Set span property for ExampleAnomalies artifact in `DistributionVal…
Browse files Browse the repository at this point in the history
…idator`.

PiperOrigin-RevId: 628166134
  • Loading branch information
tfx-copybara committed Apr 30, 2024
1 parent 81164a1 commit 0737d3c
Show file tree
Hide file tree
Showing 17 changed files with 148 additions and 1,983 deletions.
1 change: 1 addition & 0 deletions tfx/components/distribution_validator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def Do(
anomalies_artifact.split_names = artifact_utils.encode_split_names(
['%s_%s' % (test, baseline) for test, baseline in split_pairs]
)
anomalies_artifact.span = test_statistics.span

validation_metrics_artifact = None
if standard_component_specs.VALIDATION_METRICS_KEY in output_dict:
Expand Down
28 changes: 16 additions & 12 deletions tfx/components/distribution_validator/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def testSplitPairs(self, split_pairs, expected_split_pair_names,
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] Feature-level anomalies '
'[train_eval][span 2] Feature-level anomalies '
'present'
),
alert_body=(
'[train_eval][span 0] Feature(s) company, '
'[train_eval][span 2] Feature(s) company, '
'dropoff_census_tract contain(s) anomalies. See '
'Anomalies artifact for more details.'
),
Expand Down Expand Up @@ -260,11 +260,11 @@ def testSplitPairs(self, split_pairs, expected_split_pair_names,
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] High num examples in '
'[train_eval][span 2] High num examples in '
'current dataset versus the previous span.'
),
alert_body=(
'[train_eval][span 0] The ratio of num examples '
'[train_eval][span 2] The ratio of num examples '
'in the current dataset versus the previous span '
'is 2.02094 (up to six significant digits), '
'which is above the threshold 1.'
Expand Down Expand Up @@ -372,11 +372,11 @@ def testSplitPairs(self, split_pairs, expected_split_pair_names,
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] Feature-level anomalies '
'[train_eval][span 2] Feature-level anomalies '
'present'
),
alert_body=(
'[train_eval][span 0] Feature(s) company '
'[train_eval][span 2] Feature(s) company '
'contain(s) anomalies. See Anomalies artifact '
'for more details.'
),
Expand All @@ -401,6 +401,7 @@ def testAnomaliesGenerated(
stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
stats_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval'])
stats_artifact.span = 2

output_data_dir = os.path.join(
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
Expand Down Expand Up @@ -557,6 +558,7 @@ def testStructData(self):
stats_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval']
)
stats_artifact.span = 3

struct_stats_train = text_format.Parse(
"""
Expand Down Expand Up @@ -684,9 +686,9 @@ def testStructData(self):
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] Feature-level anomalies present'),
'[train_eval][span 3] Feature-level anomalies present'),
alert_body=(
'[train_eval][span 0] Feature(s) '
'[train_eval][span 3] Feature(s) '
'parent_feature.value_feature contain(s) anomalies. See '
'Anomalies artifact for more details.'),
)
Expand Down Expand Up @@ -1015,6 +1017,7 @@ def testEmptyData(self, stats_train, stats_eval, expected_anomalies):
stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
stats_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval'])
stats_artifact.span = 4

validation_config = text_format.Parse(
"""
Expand Down Expand Up @@ -1100,10 +1103,10 @@ def testEmptyData(self, stats_train, stats_eval, expected_anomalies):
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] Feature-level anomalies present'
'[train_eval][span 4] Feature-level anomalies present'
),
alert_body=(
'[train_eval][span 0] Feature(s) first_feature contain(s) '
'[train_eval][span 4] Feature(s) first_feature contain(s) '
'anomalies. See Anomalies artifact for more details.'
),
),
Expand All @@ -1127,6 +1130,7 @@ def testAddOutput(self):
stats_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval']
)
stats_artifact.span = 5

validation_config = text_format.Parse(
"""
Expand Down Expand Up @@ -1193,10 +1197,10 @@ def testAddOutput(self):
component_generated_alert_list=[
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name=(
'[train_eval][span 0] Feature-level anomalies present'
'[train_eval][span 5] Feature-level anomalies present'
),
alert_body=(
'[train_eval][span 0] Feature(s) '
'[train_eval][span 5] Feature(s) '
'parent_feature.value_feature contain(s) anomalies. See '
'Anomalies artifact for more details.'
),
Expand Down
1 change: 0 additions & 1 deletion tfx/components/example_validator/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]],
output_dict[standard_component_specs.ANOMALIES_KEY])
anomalies_artifact.split_names = artifact_utils.encode_split_names(
split_names)
anomalies_artifact.span = stats_artifact.span

schema = io_utils.SchemaReader().read(
io_utils.get_only_uri_in_dir(
Expand Down
6 changes: 2 additions & 4 deletions tfx/components/example_validator/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,15 @@ def test_create_anomalies_alerts(self):
alert_name='Feature-level anomalies present',
alert_body=(
'Feature(s) company contain(s) anomalies for split '
'train, span 11. See Anomalies artifact for more '
'train, span 0. See Anomalies artifact for more '
'details.'
),
),
component_generated_alert_pb2.ComponentGeneratedAlertInfo(
alert_name='Feature-level anomalies present',
alert_body=(
'Feature(s) company contain(s) anomalies for split '
'eval, span 11. See Anomalies artifact for more '
'eval, span 0. See Anomalies artifact for more '
'details.'
),
),
Expand All @@ -190,7 +190,6 @@ def testDo(
eval_stats_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
eval_stats_artifact.split_names = artifact_utils.encode_split_names(
['train', 'eval', 'test'])
eval_stats_artifact.span = 11

schema_artifact = standard_artifacts.Schema()
schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')
Expand Down Expand Up @@ -232,7 +231,6 @@ def testDo(
self.assertEqual(
artifact_utils.encode_split_names(['train', 'eval']),
validation_output.split_names)
self.assertEqual(eval_stats_artifact.span, validation_output.span)

# Check example_validator outputs.
train_anomalies_path = os.path.join(validation_output.uri, 'Split-train',
Expand Down

0 comments on commit 0737d3c

Please sign in to comment.