add proto2.message for accept type of python custom component #6693

Open · wants to merge 1 commit into base: master
RELEASE.md (2 changes: 0 additions & 2 deletions)

@@ -13,8 +13,6 @@
 * `ph.make_proto()` allows constructing proto-valued placeholders, e.g. for
   larger config protos fed to a component.
 * `ph.join_path()` is like `os.path.join()` but for placeholders.
-* Support passing in `experimental_debug_stripper` into the Transform
-  pipeline runner.

 ## Breaking Changes
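For context on the two release-note bullets retained above, here is a minimal sketch of how the placeholder helpers might be combined. The call signatures and field choices are assumptions inferred from the wording of the notes (`ph.make_proto()` taking a base message plus field overrides, `ph.join_path()` mirroring `os.path.join()`), not verified against this branch:

```python
from tfx.dsl.placeholder import placeholder as ph
from tfx.proto import infra_validator_pb2

# Assumed usage: a base proto plus keyword overrides, producing a
# proto-valued placeholder that is resolved at pipeline run time.
serving_spec = ph.make_proto(
    infra_validator_pb2.ServingSpec(),
    model_name=ph.exec_property('model_name'),
)

# Assumed to behave like os.path.join(), but over placeholder values.
stats_path = ph.join_path(ph.output('statistics').uri, 'Split-train')
```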
tfx/components/infra_validator/model_server_runners/kubernetes_runner.py

@@ -93,17 +93,10 @@ def _convert_to_kube_env(
 def _convert_to_resource_requirements(
     resources: infra_validator_pb2.Resources
 ) -> k8s_client.V1ResourceRequirements:
-  if hasattr(k8s_client.V1ResourceRequirements, 'claims'):
-    return k8s_client.V1ResourceRequirements(
-        requests=dict(resources.requests),
-        limits=dict(resources.limits),
-        claims=dict(resources.claims),
-    )
-  else:
-    return k8s_client.V1ResourceRequirements(
-        requests=dict(resources.requests),
-        limits=dict(resources.limits),
-    )
+  return k8s_client.V1ResourceRequirements(
+      requests=dict(resources.requests),
+      limits=dict(resources.limits),
+  )


 class KubernetesRunner(base_runner.BaseModelServerRunner):
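The branch deleted above existed to tolerate kubernetes client libraries that predate the `claims` field on `V1ResourceRequirements` (per the test's TODO, the field arrived in kubernetes client >= 26). As a standalone reference, the feature-detection pattern it used looks like this; the helper name is hypothetical, while the `hasattr` probe on the generated client class is the piece taken from the removed code:

```python
from kubernetes import client as k8s_client


def make_resource_requirements(requests, limits, claims=None):
  """Builds V1ResourceRequirements, passing `claims` only when supported.

  The `claims` attribute is only generated on newer kubernetes clients;
  probing the class with hasattr() avoids a TypeError on older ones.
  """
  if claims is not None and hasattr(k8s_client.V1ResourceRequirements, 'claims'):
    return k8s_client.V1ResourceRequirements(
        requests=requests, limits=limits, claims=claims)
  return k8s_client.V1ResourceRequirements(requests=requests, limits=limits)
```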
tfx/components/infra_validator/model_server_runners/kubernetes_runner_test.py

@@ -31,7 +31,7 @@
 from google.protobuf import json_format


-def _CreateServingSpec(payload: Dict[str, Any]):
+def _create_serving_spec(payload: Dict[str, Any]):
   result = infra_validator_pb2.ServingSpec()
   json_format.ParseDict(payload, result)
   return result
@@ -194,30 +194,35 @@ def testBuildPodManifest_InsideKfp_OverrideConfig(self):
         'service_account_name': 'chocolate-latte',
         'active_deadline_seconds': 123,
         'serving_pod_overrides': {
-            'annotations': {'best_ticker': 'goog'},
-            'env': [
-                {'name': 'TICKER', 'value': 'GOOG'},
-                {'name': 'NAME_ONLY'},
-                {
-                    'name': 'SECRET',
-                    'value_from': {
-                        'secret_key_ref': {'name': 'my_secret', 'key': 'my_key'}
-                    },
-                },
-            ],
-            'resources': {
-                # TODO(b/328171600): Uncomment when version of kubernetes
-                # is matched with TFX. Kubernetes >= 26 supports 'claims' field,
-                # while TFX is at version 12.
-                'claims': {},
-                'requests': {'memory': '2Gi', 'cpu': '1'},
-                'limits': {'memory': '4Gi', 'cpu': '2'},
-            },
+            'annotations': {
+                'best_ticker': 'goog'
+            },
+            'env': [{
+                'name': 'TICKER',
+                'value': 'GOOG'
+            }, {
+                'name': 'NAME_ONLY'
+            }, {
+                'name': 'SECRET',
+                'value_from': {
+                    'secret_key_ref': {
+                        'name': 'my_secret',
+                        'key': 'my_key'
+                    }
+                }
+            }],
+            'resources': {
+                'requests': {
+                    'memory': '2Gi',
+                    'cpu': '1'
+                },
+                'limits': {
+                    'memory': '4Gi',
+                    'cpu': '2'
+                },
+            }
         }
     }
-    if not hasattr(k8s_client.V1ResourceRequirements, 'claims'):
-      k8s_config_dict['serving_pod_overrides']['resources'].pop('claims')

     runner = self._CreateKubernetesRunner(k8s_config_dict=k8s_config_dict)

     # Act.
tfx/components/statistics_gen/executor.py (36 changes: 2 additions & 34 deletions)

@@ -72,10 +72,6 @@ def Do(
         not also contain a schema.
       - exclude_splits: JSON-serialized list of names of splits where
         statistics and sample should not be generated.
-      - sample_rate_by_split: Optionally, A dict mapping split_name to sample
-        rate, which is used to apply a different sample rate to the
-        corresponding split. When this is supplied, it will overwrite the
-        single sample rate on stats_options_json.

     Raises:
       ValueError when a schema is provided both as an input and as part of the
@@ -103,16 +99,6 @@ def Do(
           % type(exclude_splits)
       )

-    # Load sample_rate_by_split from execution properties.
-    sample_rate_by_split = (
-        json_utils.loads(
-            exec_properties.get(
-                standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY, 'null'
-            )
-        )
-        or {}
-    )
-
     # Setup output splits.
     examples = artifact_utils.get_single_instance(
         input_dict[standard_component_specs.EXAMPLES_KEY]
@@ -131,14 +117,6 @@ def Do(
     splits = artifact_utils.decode_split_names(examples.split_names)

     split_names = [split for split in splits if split not in exclude_splits]
-
-    # Check if sample_rate_by_split contains invalid split names
-    for split in sample_rate_by_split:
-      if split not in split_names:
-        logging.error(
-            'Split %s provided in sample_rate_by_split is not valid.', split
-        )
-
     statistics_artifact = artifact_utils.get_single_instance(
         output_dict[standard_component_specs.STATISTICS_KEY]
     )
@@ -155,14 +133,13 @@ def Do(
       )
     except Exception as e:  # pylint: disable=broad-except
       # log on failures to not bring down Statsgen jobs
-      logging.exception('Failed to generate stats dashboard link because %s', e)
+      logging.error('Failed to generate stats dashboard link because %s', e)
       statistics_artifact.set_string_custom_property(STATS_DASHBOARD_LINK, '')

     stats_options = options.StatsOptions()
     stats_options_json = exec_properties.get(
         standard_component_specs.STATS_OPTIONS_JSON_KEY
     )
-
     if stats_options_json:
       # TODO(b/150802589): Move jsonable interface to tfx_bsl and use
       # json_utils
@@ -232,15 +209,6 @@ def Do(
       )
       binary_stats_output_path = os.path.join(output_uri, DEFAULT_FILE_NAME)

-      # Update sample rate for each split in stats_options if
-      # sample_rate_by_split is provided
-      split_stats_options = tfdv.StatsOptions.from_json(
-          stats_options.to_json())
-      if sample_rate_by_split:
-        sample_rate = sample_rate_by_split.get(split, None)
-        if sample_rate is not None:
-          split_stats_options.sample_rate = sample_rate
-
       data = p | 'TFXIORead[%s]' % split >> tfxio.BeamSource()
       if write_sharded_output:
         sharded_stats_output_prefix = os.path.join(
@@ -259,7 +227,7 @@ def Do(
       _ = (
           data
           | 'GenerateStatistics[%s]' % split
-          >> tfdv.GenerateStatistics(split_stats_options)
+          >> tfdv.GenerateStatistics(stats_options)
          | 'WriteStatsOutput[%s]' % split >> write_transform
       )
      logging.info(
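Most of the lines removed from this executor implemented per-split sample-rate overrides. Reconstructed from the deleted code, the pattern amounted to the sketch below; the helper name is hypothetical, while `StatsOptions.to_json()`/`from_json()` and the `sample_rate` attribute are the same TFDV surface the removed lines used:

```python
from typing import Dict, Optional

import tensorflow_data_validation as tfdv


def _stats_options_for_split(
    split: str,
    stats_options: tfdv.StatsOptions,
    sample_rate_by_split: Dict[str, float],
) -> tfdv.StatsOptions:
  """Returns a copy of stats_options with the split's sample rate applied."""
  # Round-trip through JSON so each split mutates its own independent copy.
  split_options = tfdv.StatsOptions.from_json(stats_options.to_json())
  sample_rate: Optional[float] = sample_rate_by_split.get(split)
  if sample_rate is not None:
    split_options.sample_rate = sample_rate
  return split_options
```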
tfx/components/statistics_gen/executor_test.py (30 changes: 5 additions & 25 deletions)

@@ -34,34 +34,18 @@
         'testcase_name': 'no_sharded_output',
         'sharded_output': False,
         'custom_split_uri': False,
-        'sample_rate_by_split': 'null',
     },
     {
         'testcase_name': 'custom_split_uri',
         'sharded_output': False,
         'custom_split_uri': True,
-        'sample_rate_by_split': 'null',
     },
-    {
-        'testcase_name': 'sample_rate_by_split',
-        'sharded_output': False,
-        'custom_split_uri': False,
-        # set a higher sample rate since test data is small
-        'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
-    },
-    {
-        'testcase_name': 'sample_rate_split_nonexist',
-        'sharded_output': False,
-        'custom_split_uri': False,
-        'sample_rate_by_split': '{"test": 0.05}',
-    },
 ]
 if tfdv.default_sharded_output_supported():
   _EXECUTOR_TEST_PARAMS.append({
       'testcase_name': 'yes_sharded_output',
       'sharded_output': True,
       'custom_split_uri': False,
-      'sample_rate_by_split': 'null',
   })
 _TEST_SPAN_NUMBER = 16000

@@ -91,12 +75,7 @@ def _validate_sharded_stats_output(self, stats_prefix):
     self._validate_stats(stats)

   @parameterized.named_parameters(*_EXECUTOR_TEST_PARAMS)
-  def testDo(
-      self,
-      sharded_output: bool,
-      custom_split_uri: bool,
-      sample_rate_by_split: str,
-  ):
+  def testDo(self, sharded_output: bool, custom_split_uri: bool):
     source_data_dir = os.path.join(
         os.path.dirname(os.path.dirname(__file__)), 'testdata')
     output_data_dir = os.path.join(
@@ -129,9 +108,10 @@ def testDo(

     exec_properties = {
         # List needs to be serialized before being passed into Do function.
-        standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps(['test']),
-        standard_component_specs.SHARDED_STATS_OUTPUT_KEY: sharded_output,
-        standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: sample_rate_by_split,
+        standard_component_specs.EXCLUDE_SPLITS_KEY:
+            json_utils.dumps(['test']),
+        standard_component_specs.SHARDED_STATS_OUTPUT_KEY:
+            sharded_output,
     }

     # Create output dict.
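The comment kept in the exec_properties block is worth unpacking: list-valued execution properties must cross the component boundary as JSON strings. A minimal round-trip, using the same `json_utils` helpers that appear in the diff; the literal key and the 'null' default are assumptions that mirror the executor's own loading code shown earlier:

```python
from tfx.utils import json_utils

# Test side: serialize the list before handing it to the executor's Do().
exec_properties = {'exclude_splits': json_utils.dumps(['test'])}

# Executor side: decode, defaulting to an empty list when the key is absent.
exclude_splits = json_utils.loads(
    exec_properties.get('exclude_splits', 'null')) or []
assert exclude_splits == ['test']
```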