Skip to content

Commit

Permalink
Support tfds.features.Optional in TFDS for tfds.data_source users.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599089261
  • Loading branch information
marcenacp authored and The TensorFlow Datasets Authors committed Jan 18, 2024
1 parent eb8582d commit 6a34b0c
Show file tree
Hide file tree
Showing 10 changed files with 340 additions and 96 deletions.
14 changes: 6 additions & 8 deletions tensorflow_datasets/core/example_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,15 @@ def run_with_reraise(fn, k, example_data, tensor_info):
# 'objects/tokens/flat_values': [0, 1, 2, 3, 4],
# 'objects/tokens/row_lengths_0': [3, 0, 2],
# }
features = utils.flatten_nest_dict(
{
k: run_with_reraise(
_add_ragged_fields, k, example_dict[k], tensor_info
)
for k, tensor_info in tensor_info_dict.items()
}
)
features = utils.flatten_nest_dict({
k: run_with_reraise(_add_ragged_fields, k, example_dict[k], tensor_info)
for k, tensor_info in tensor_info_dict.items()
})
features = {
k: run_with_reraise(_item_to_tf_feature, k, item, tensor_info)
for k, (item, tensor_info) in features.items()
# If the item is None, it doesn't appear in the proto at all.
if item is not None
}
return tf_example_pb2.Example(
features=tf_feature_pb2.Features(feature=features)
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""API defining dataset features (image, text, scalar,...).
See [the guide](https://www.tensorflow.org/datasets/features).
"""

from tensorflow_datasets.core.features.audio_feature import Audio
Expand All @@ -31,6 +30,7 @@
from tensorflow_datasets.core.features.features_dict import FeaturesDict
from tensorflow_datasets.core.features.image_feature import Image
from tensorflow_datasets.core.features.labeled_image import LabeledImage
from tensorflow_datasets.core.features.optional_feature import Optional
from tensorflow_datasets.core.features.scalar import Scalar
from tensorflow_datasets.core.features.sequence_feature import Sequence
from tensorflow_datasets.core.features.tensor_feature import Encoding
Expand All @@ -52,6 +52,7 @@
"FeatureConnector",
"FeaturesDict",
"LabeledImage",
"Optional",
"Tensor",
"TensorInfo",
"Scalar",
Expand Down
9 changes: 4 additions & 5 deletions tensorflow_datasets/core/features/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def decode_example(self, tfexample_data):

def decode_example_np(
self, example_data: type_utils.NpArrayOrScalar
) -> type_utils.NpArrayOrScalar:
) -> type_utils.NpArrayOrScalar | None:
"""Encode the feature dict into NumPy-compatible input.
Args:
Expand Down Expand Up @@ -1072,12 +1072,11 @@ def _name2proto_cls(cls_name: str) -> Type[message.Message]:

def _proto2oneof_field_name(proto: message.Message) -> str:
"""Returns the field name associated with the class."""
for field in _feature_content_fields():
fields = _feature_content_fields()
for field in fields:
if field.message_type._concrete_class == type(proto): # pylint: disable=protected-access
return field.name
supported_cls = [
f.message_type._concrete_class.name for f in _feature_content_fields() # pylint: disable=protected-access
]
supported_cls = [field.message_type._concrete_class for field in fields] # pylint: disable=protected-access
raise ValueError(f'Unknown proto {type(proto)}. Supported: {supported_cls}.')


Expand Down
158 changes: 158 additions & 0 deletions tensorflow_datasets/core/features/optional_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optional feature connector."""

from __future__ import annotations

from typing import Any, TypedDict, Union

import numpy as np
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.features import features_dict
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.proto import feature_pb2
from tensorflow_datasets.core.utils import py_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

Json = type_utils.Json

# TODO: review all docstrings!!


class OptionalExample(TypedDict):
has_value: bool
value: Any


class Optional(feature_lib.FeatureConnector):
r"""Optional feature connector.
Warning: This feature is under active development. For now,
`tfds.features.Optional` can only wrap `tfds.features.Scalar`:
```python
features = tfds.features.FeaturesDict({
'myfeature': tfds.features.Optional(np.int32),
})
```
In `_generate_examples`, you can yield `None` instead of the actual feature
value:
```python
def _generate_examples(self):
yield {
'myfeature': None, # Accept `None` or feature
}
```
For Grain/NumPy users, `tfds.data_source` will output None values.
```python
ds = tfds.data_source('dataset_with_optional', split='train')
for element in ds:
if element is None:
# ...
else:
# ...
```
For tf.data users, None values don't make sense in tensors, so we haven't
implemented the feature yet. If you're a tf.data user with a use case for
optional values, we'd like to hear from you.
"""

def __init__(
self,
feature: feature_lib.FeatureConnectorArg,
*,
doc: feature_lib.DocArg = None,
):
"""Constructor.
Args:
feature: The feature to wrap (any TFDS feature is supported).
doc: Documentation of this feature (e.g. description).
"""
self._feature = features_dict.to_feature(feature)
if not isinstance(self._feature, tensor_feature.Tensor):
raise NotImplementedError(
'tfds.features.Optional only supports Tensors. Refer to its'
' documentation for more information.'
)
super().__init__(doc=doc)

@py_utils.memoize()
def get_tensor_info(self):
"""See base class for details."""
return self._feature.get_tensor_info()

@py_utils.memoize()
def get_serialized_info(self):
"""See base class for details."""
return self._feature.get_serialized_info()

def __getitem__(self, key: str) -> feature_lib.FeatureConnector:
"""Allows to access the underlying features directly."""
return self._feature[key]

def __contains__(self, key: str) -> bool:
return key in self._feature

def save_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.save_metadata(*args, **kwargs)

def load_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.load_metadata(*args, **kwargs)

@classmethod
def from_json_content(cls, value: feature_pb2.Optional) -> Optional:
"""See base class for details."""
feature = feature_lib.FeatureConnector.from_proto(value.feature)
return cls(feature)

def to_json_content(self) -> feature_pb2.Optional:
"""See base class for details."""
return feature_pb2.Optional(feature=self._feature.to_proto())

@property
def feature(self):
"""The inner feature."""
return self._feature

def encode_example(self, example: Any) -> Any:
"""See base class for details."""
if example is None:
return None
else:
return self._feature.encode_example(example)

def decode_example(self, example: Any) -> Any:
"""See base class for details."""
raise NotImplementedError(
'tfds.features.Optional only supports tfds.data_source.'
)

def decode_example_np(
self, example: Any
) -> type_utils.NpArrayOrScalar | None:
if example is None:
return None
else:
return self._feature.decode_example_np(example)
56 changes: 56 additions & 0 deletions tensorflow_datasets/core/features/optional_feature_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for optional_feature."""

from absl.testing import parameterized
import numpy as np
from tensorflow_datasets import testing
from tensorflow_datasets.core import features


class ScalarFeatureTest(
parameterized.TestCase, testing.FeatureExpectationsTestCase
):

@parameterized.parameters(
(np.int64, 42, 42),
(np.int64, None, testing.TestValue.NONE),
(np.str_, 'foo', 'foo'),
(np.str_, None, testing.TestValue.NONE),
)
def test_scalar(self, dtype, value, expected_np):
self.assertFeature(
feature=features.Optional(
features.Scalar(dtype=dtype, doc='Some description')
),
shape=(),
dtype=dtype,
tests=[
testing.FeatureExpectationItem(
value=value,
raise_cls=NotImplementedError,
raise_msg='supports tfds.data_source',
),
testing.FeatureExpectationItem(
value=value,
expected_np=expected_np,
),
],
)


if __name__ == '__main__':
testing.test_main()
7 changes: 7 additions & 0 deletions tensorflow_datasets/core/proto/feature.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ message Feature {
TextFeature text = 10;
TranslationFeature translation = 11;
Sequence sequence = 12;
Optional optional = 13;
}
}

Expand Down Expand Up @@ -125,3 +126,9 @@ message Sequence {
// Optional length of the sequence.
int64 length = 2;
}

// An optional feature.
message Optional {
// The optional feature.
Feature feature = 1;
}
56 changes: 30 additions & 26 deletions tensorflow_datasets/core/proto/feature_generated_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
b'\n\rfeature.proto\x12\x13tensorflow_datasets"\xa0\x01\n\x0c\x46\x65\x61turesDict\x12\x41\n\x08\x66\x65\x61tures\x18\x01'
b' \x03(\x0b\x32/.tensorflow_datasets.FeaturesDict.FeaturesEntry\x1aM\n\rFeaturesEntry\x12\x0b\n\x03key\x18\x01'
b' \x01(\t\x12+\n\x05value\x18\x02'
b' \x01(\x0b\x32\x1c.tensorflow_datasets.Feature:\x02\x38\x01"\xbf\x05\n\x07\x46\x65\x61ture\x12\x19\n\x11python_class_name\x18\x01'
b' \x01(\x0b\x32\x1c.tensorflow_datasets.Feature:\x02\x38\x01"\xf2\x05\n\x07\x46\x65\x61ture\x12\x19\n\x11python_class_name\x18\x01'
b' \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x0e'
b' \x01(\t\x12\x13\n\x0bvalue_range\x18\x0f'
b' \x01(\t\x12\x38\n\x0cjson_feature\x18\x02 \x01(\x0b\x32'
Expand All @@ -47,7 +47,8 @@
b' \x01(\x0b\x32'
b' .tensorflow_datasets.TextFeatureH\x00\x12>\n\x0btranslation\x18\x0b'
b" \x01(\x0b\x32'.tensorflow_datasets.TranslationFeatureH\x00\x12\x31\n\x08sequence\x18\x0c"
b' \x01(\x0b\x32\x1d.tensorflow_datasets.SequenceH\x00\x42\t\n\x07\x63ontent"\x1b\n\x0bJsonFeature\x12\x0c\n\x04json\x18\x01'
b' \x01(\x0b\x32\x1d.tensorflow_datasets.SequenceH\x00\x12\x31\n\x08optional\x18\r'
b' \x01(\x0b\x32\x1d.tensorflow_datasets.OptionalH\x00\x42\t\n\x07\x63ontent"\x1b\n\x0bJsonFeature\x12\x0c\n\x04json\x18\x01'
b' \x01(\t"\x1b\n\x05Shape\x12\x12\n\ndimensions\x18\x01'
b' \x03(\x03"[\n\rTensorFeature\x12)\n\x05shape\x18\x01'
b' \x01(\x0b\x32\x1a.tensorflow_datasets.Shape\x12\r\n\x05\x64type\x18\x02'
Expand Down Expand Up @@ -75,7 +76,8 @@
b' \x03(\t\x12&\n\x1evariable_languages_per_example\x18\x02'
b' \x01(\x08"I\n\x08Sequence\x12-\n\x07\x66\x65\x61ture\x18\x01'
b' \x01(\x0b\x32\x1c.tensorflow_datasets.Feature\x12\x0e\n\x06length\x18\x02'
b' \x01(\x03\x42\x03\xf8\x01\x01\x62\x06proto3'
b' \x01(\x03"9\n\x08Optional\x12-\n\x07\x66\x65\x61ture\x18\x01'
b' \x01(\x0b\x32\x1c.tensorflow_datasets.FeatureB\x03\xf8\x01\x01\x62\x06proto3'
)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
Expand All @@ -90,27 +92,29 @@
_FEATURESDICT_FEATURESENTRY._serialized_start = 122
_FEATURESDICT_FEATURESENTRY._serialized_end = 199
_FEATURE._serialized_start = 202
_FEATURE._serialized_end = 905
_JSONFEATURE._serialized_start = 907
_JSONFEATURE._serialized_end = 934
_SHAPE._serialized_start = 936
_SHAPE._serialized_end = 963
_TENSORFEATURE._serialized_start = 965
_TENSORFEATURE._serialized_end = 1056
_CLASSLABEL._serialized_start = 1058
_CLASSLABEL._serialized_end = 1091
_IMAGEFEATURE._serialized_start = 1094
_IMAGEFEATURE._serialized_end = 1261
_VIDEOFEATURE._serialized_start = 1264
_VIDEOFEATURE._serialized_end = 1410
_AUDIOFEATURE._serialized_start = 1413
_AUDIOFEATURE._serialized_end = 1566
_BOUNDINGBOXFEATURE._serialized_start = 1568
_BOUNDINGBOXFEATURE._serialized_end = 1646
_TEXTFEATURE._serialized_start = 1648
_TEXTFEATURE._serialized_end = 1661
_TRANSLATIONFEATURE._serialized_start = 1663
_TRANSLATIONFEATURE._serialized_end = 1742
_SEQUENCE._serialized_start = 1744
_SEQUENCE._serialized_end = 1817
_FEATURE._serialized_end = 956
_JSONFEATURE._serialized_start = 958
_JSONFEATURE._serialized_end = 985
_SHAPE._serialized_start = 987
_SHAPE._serialized_end = 1014
_TENSORFEATURE._serialized_start = 1016
_TENSORFEATURE._serialized_end = 1107
_CLASSLABEL._serialized_start = 1109
_CLASSLABEL._serialized_end = 1142
_IMAGEFEATURE._serialized_start = 1145
_IMAGEFEATURE._serialized_end = 1312
_VIDEOFEATURE._serialized_start = 1315
_VIDEOFEATURE._serialized_end = 1461
_AUDIOFEATURE._serialized_start = 1464
_AUDIOFEATURE._serialized_end = 1617
_BOUNDINGBOXFEATURE._serialized_start = 1619
_BOUNDINGBOXFEATURE._serialized_end = 1697
_TEXTFEATURE._serialized_start = 1699
_TEXTFEATURE._serialized_end = 1712
_TRANSLATIONFEATURE._serialized_start = 1714
_TRANSLATIONFEATURE._serialized_end = 1793
_SEQUENCE._serialized_start = 1795
_SEQUENCE._serialized_end = 1868
_OPTIONAL._serialized_start = 1870
_OPTIONAL._serialized_end = 1927
# @@protoc_insertion_point(module_scope)

0 comments on commit 6a34b0c

Please sign in to comment.