-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support
tfds.features.Optional
in TFDS for tfds.data_source users.
PiperOrigin-RevId: 599089261
- Loading branch information
Showing
10 changed files
with
337 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# coding=utf-8 | ||
# Copyright 2023 The TensorFlow Datasets Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Optional feature connector.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, TypedDict, Union | ||
|
||
import numpy as np | ||
from tensorflow_datasets.core.features import feature as feature_lib | ||
from tensorflow_datasets.core.features import features_dict | ||
from tensorflow_datasets.core.features import tensor_feature | ||
from tensorflow_datasets.core.proto import feature_pb2 | ||
from tensorflow_datasets.core.utils import py_utils | ||
from tensorflow_datasets.core.utils import type_utils | ||
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf | ||
|
||
# Local shorthand for the `Json` type declared in `type_utils`.
Json = type_utils.Json
|
||
# TODO: Review all docstrings in this module. | ||
|
||
|
||
class OptionalExample(TypedDict):
  """Dict pairing a presence flag with a payload.

  NOTE(review): nothing in this module references this type yet; the field
  meanings below are inferred from their names -- confirm before relying on
  them.
  """

  # Presumably True when `value` carries a real value -- TODO confirm.
  has_value: bool
  # The payload itself; deliberately left untyped (`Any`).
  value: Any
|
||
|
||
class Optional(feature_lib.FeatureConnector):
  r"""Optional feature connector.

  Wraps another feature so that individual examples may omit its value:
  `None` is passed through untouched when encoding and when decoding to
  NumPy, and every other value is delegated to the wrapped feature.

  Warning: This feature is under active development. For now,
  `tfds.features.Optional` can only wrap `tfds.features.Tensor` features
  (which includes `tfds.features.Scalar`); anything else raises
  `NotImplementedError` at construction time:

  ```python
  features = tfds.features.FeaturesDict({
      'myfeature': tfds.features.Optional(np.int32),
  })
  ```

  In `_generate_examples`, you can yield `None` instead of the actual feature
  value:

  ```python
  def _generate_examples(self):
    yield {
        'myfeature': None,  # Accept `None` or feature
    }
  ```

  For Grain/NumPy users, `tfds.data_source` will output None values.

  ```python
  ds = tfds.data_source('dataset_with_optional', split='train')
  for element in ds:
    if element['myfeature'] is None:
      # ...
    else:
      # ...
  ```

  For tf.data users, None values don't make sense in tensors, so we haven't
  implemented the feature yet (`decode_example` raises `NotImplementedError`).
  If you're a tf.data user with a use case for optional values, we'd like to
  hear from you.
  """

  def __init__(
      self,
      feature: feature_lib.FeatureConnectorArg,
      *,
      doc: feature_lib.DocArg = None,
  ):
    """Constructor.

    Args:
      feature: The feature to wrap. It is first normalized with
        `features_dict.to_feature`; the normalized feature must be a
        `tensor_feature.Tensor` instance.
      doc: Documentation of this feature (e.g. description).

    Raises:
      NotImplementedError: If the normalized feature is not a
        `tensor_feature.Tensor`.
    """
    self._feature = features_dict.to_feature(feature)
    if not isinstance(self._feature, tensor_feature.Tensor):
      # Name the offending type so the caller can see what was rejected.
      raise NotImplementedError(
          'tfds.features.Optional only supports Tensors. Refer to its'
          ' documentation for more information.'
          f' Got: {type(self._feature).__name__}.'
      )
    super().__init__(doc=doc)

  @py_utils.memoize()
  def get_tensor_info(self):
    """See base class for details."""
    # The tensor info is exactly the one reported by the wrapped feature.
    return self._feature.get_tensor_info()

  @py_utils.memoize()
  def get_serialized_info(self):
    """See base class for details."""
    # The serialized info is exactly the one reported by the wrapped feature.
    return self._feature.get_serialized_info()

  def __getitem__(self, key: str) -> feature_lib.FeatureConnector:
    """Returns `self.feature[key]`, i.e. indexes into the wrapped feature."""
    return self._feature[key]

  def __contains__(self, key: str) -> bool:
    """Returns whether `key` is contained in the wrapped feature."""
    return key in self._feature

  def save_metadata(self, *args, **kwargs):
    """See base class for details."""
    # Forwarded verbatim to the wrapped feature.
    self._feature.save_metadata(*args, **kwargs)

  def load_metadata(self, *args, **kwargs):
    """See base class for details."""
    # Forwarded verbatim to the wrapped feature.
    self._feature.load_metadata(*args, **kwargs)

  @classmethod
  def from_json_content(cls, value: feature_pb2.Optional) -> Optional:
    """See base class for details."""
    # Rebuild the wrapped feature from its proto, then wrap it again.
    feature = feature_lib.FeatureConnector.from_proto(value.feature)
    return cls(feature)

  def to_json_content(self) -> feature_pb2.Optional:
    """See base class for details."""
    return feature_pb2.Optional(feature=self._feature.to_proto())

  @property
  def feature(self):
    """The inner feature."""
    return self._feature

  def encode_example(self, example: Any) -> Any:
    """Encodes `example` with the wrapped feature, passing `None` through.

    Args:
      example: The value to encode, or `None` when the value is absent.

    Returns:
      `None` if `example` is `None`, otherwise the result of the wrapped
      feature's `encode_example`.
    """
    if example is None:
      return None
    return self._feature.encode_example(example)

  def decode_example(self, example: Any) -> Any:
    """Not supported: optional values are only available via NumPy decoding.

    Args:
      example: Ignored.

    Raises:
      NotImplementedError: Always.
    """
    del example  # Unused: this decoding path is not implemented.
    raise NotImplementedError(
        'tfds.features.Optional only supports tfds.data_source.'
    )

  def decode_example_np(
      self, example: Any
  ) -> type_utils.NpArrayOrScalar | None:
    """Decodes `example` to NumPy with the wrapped feature.

    Args:
      example: The serialized value, or `None` when the value is absent.

    Returns:
      `None` if `example` is `None`, otherwise the result of the wrapped
      feature's `decode_example_np`.
    """
    if example is None:
      return None
    return self._feature.decode_example_np(example)
56 changes: 56 additions & 0 deletions
56
tensorflow_datasets/core/features/optional_feature_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# coding=utf-8 | ||
# Copyright 2023 The TensorFlow Datasets Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for optional_feature.""" | ||
|
||
from absl.testing import parameterized | ||
import numpy as np | ||
from tensorflow_datasets import testing | ||
from tensorflow_datasets.core import features | ||
|
||
|
||
class ScalarFeatureTest(
    parameterized.TestCase, testing.FeatureExpectationsTestCase
):
  """Exercises `features.Optional` wrapped around a `features.Scalar`."""

  @parameterized.parameters(
      (np.int64, 42, 42),
      (np.int64, None, testing.TestValue.NONE),
      (np.str_, 'foo', 'foo'),
      (np.str_, None, testing.TestValue.NONE),
  )
  def test_scalar(self, dtype, value, expected_np):
    """Runs two expectation items on an optional scalar of `dtype`.

    Args:
      dtype: NumPy dtype of the wrapped scalar.
      value: Input value handed to both expectation items; may be `None`.
      expected_np: Expected NumPy-side result (`testing.TestValue.NONE` is
        presumably the sentinel for an expected `None` -- confirm against the
        testing helpers).
    """
    optional_scalar = features.Optional(
        features.Scalar(dtype=dtype, doc='Some description')
    )
    # Expects a NotImplementedError whose message mentions
    # 'supports tfds.data_source' (presumably raised by the tf.data decoding
    # path -- confirm against `assertFeature`).
    unsupported_item = testing.FeatureExpectationItem(
        value=value,
        raise_cls=NotImplementedError,
        raise_msg='supports tfds.data_source',
    )
    # Expects the NumPy-side result to equal `expected_np`.
    numpy_item = testing.FeatureExpectationItem(
        value=value,
        expected_np=expected_np,
    )
    self.assertFeature(
        feature=optional_scalar,
        shape=(),
        dtype=dtype,
        tests=[unsupported_item, numpy_item],
    )
|
||
|
||
# Run the tests through the TFDS testing entry point when executed as a script.
if __name__ == '__main__':
  testing.test_main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.