Support tfds.features.Optional in TFDS for tfds.data_source users.
PiperOrigin-RevId: 599089261
marcenacp authored and The TensorFlow Datasets Authors committed Jan 18, 2024
1 parent 4fa0835 commit 63ea0c6
Showing 11 changed files with 379 additions and 110 deletions.
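
This commit adds a new `Optional` feature connector and a NumPy decoding path that can return `None`. Below is a minimal sketch of the workflow it enables, assuming a TFDS build that includes this change; the dataset name `'my_dataset'` and the key `'myfeature'` are hypothetical placeholders (compare the docstring examples in `optional_feature.py` further down).

```python
import numpy as np
import tensorflow_datasets as tfds

# In a dataset builder, wrap a scalar feature so that `None` becomes a legal value:
features = tfds.features.FeaturesDict({
    'myfeature': tfds.features.Optional(np.int32),
})

# With tfds.data_source (the Grain/NumPy path), a missing value comes back as None:
ds = tfds.data_source('my_dataset', split='train')
for element in ds:
  if element['myfeature'] is None:
    pass  # handle the absent value
  else:
    pass  # use the decoded value
```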
14 changes: 6 additions & 8 deletions tensorflow_datasets/core/example_serializer.py
@@ -130,17 +130,15 @@ def run_with_reraise(fn, k, example_data, tensor_info):
# 'objects/tokens/flat_values': [0, 1, 2, 3, 4],
# 'objects/tokens/row_lengths_0': [3, 0, 2],
# }
features = utils.flatten_nest_dict(
{
k: run_with_reraise(
_add_ragged_fields, k, example_dict[k], tensor_info
)
for k, tensor_info in tensor_info_dict.items()
}
)
features = utils.flatten_nest_dict({
k: run_with_reraise(_add_ragged_fields, k, example_dict[k], tensor_info)
for k, tensor_info in tensor_info_dict.items()
})
features = {
k: run_with_reraise(_item_to_tf_feature, k, item, tensor_info)
for k, (item, tensor_info) in features.items()
# If the item is None, it doesn't appear in the proto at all.
if item is not None
}
return tf_example_pb2.Example(
features=tf_feature_pb2.Features(feature=features)
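
The visible hunk only reformats the `utils.flatten_nest_dict` call; the comment above it shows what the flattening produces (keys joined with `/`, e.g. `'objects/tokens/flat_values'`). The following is a toy stand-in to illustrate that behaviour, not the actual TFDS helper:

```python
# Toy sketch of '/'-joined flattening; not tensorflow_datasets.core.utils.flatten_nest_dict.
def flatten_nest_dict(d, prefix=''):
  """Flattens a nested dict into a flat dict with '/'-joined keys."""
  flat = {}
  for key, value in d.items():
    full_key = f'{prefix}/{key}' if prefix else key
    if isinstance(value, dict):
      flat.update(flatten_nest_dict(value, prefix=full_key))
    else:
      flat[full_key] = value
  return flat


print(flatten_nest_dict({'objects': {'tokens': {'flat_values': [0, 1, 2, 3, 4]}}}))
# -> {'objects/tokens/flat_values': [0, 1, 2, 3, 4]}
```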
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/features/__init__.py
@@ -16,7 +16,6 @@
"""API defining dataset features (image, text, scalar,...).
See [the guide](https://www.tensorflow.org/datasets/features).
"""

from tensorflow_datasets.core.features.audio_feature import Audio
@@ -31,6 +30,7 @@
from tensorflow_datasets.core.features.features_dict import FeaturesDict
from tensorflow_datasets.core.features.image_feature import Image
from tensorflow_datasets.core.features.labeled_image import LabeledImage
from tensorflow_datasets.core.features.optional_feature import Optional
from tensorflow_datasets.core.features.scalar import Scalar
from tensorflow_datasets.core.features.sequence_feature import Sequence
from tensorflow_datasets.core.features.tensor_feature import Encoding
@@ -52,6 +52,7 @@
"FeatureConnector",
"FeaturesDict",
"LabeledImage",
"Optional",
"Tensor",
"TensorInfo",
"Scalar",
9 changes: 4 additions & 5 deletions tensorflow_datasets/core/features/feature.py
@@ -758,7 +758,7 @@ def decode_example(self, tfexample_data):

def decode_example_np(
self, example_data: type_utils.NpArrayOrScalar
) -> type_utils.NpArrayOrScalar:
) -> type_utils.NpArrayOrScalar | None:
"""Encode the feature dict into NumPy-compatible input.
Args:
@@ -1122,12 +1122,11 @@ def _name2proto_cls(cls_name: str) -> Type[message.Message]:

def _proto2oneof_field_name(proto: message.Message) -> str:
"""Returns the field name associated with the class."""
for field in _feature_content_fields():
fields = _feature_content_fields()
for field in fields:
if field.message_type._concrete_class == type(proto): # pylint: disable=protected-access
return field.name
supported_cls = [
f.message_type._concrete_class.name for f in _feature_content_fields() # pylint: disable=protected-access
]
supported_cls = [field.message_type._concrete_class for field in fields] # pylint: disable=protected-access
raise ValueError(f'Unknown proto {type(proto)}. Supported: {supported_cls}.')


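
Two things change in `feature.py`: `decode_example_np` may now return `None` (hence the `| None` in the signature), and `_proto2oneof_field_name` computes the candidate fields once and reuses them in the error message. Here is a self-contained toy sketch of that lookup pattern; the classes and field names are stand-ins, not the real TFDS protos:

```python
# Stand-in classes; the real code matches protobuf message classes.
class JsonFeature:
  pass


class TensorFeature:
  pass


# Compute the candidate (class, field name) pairs once...
_CONTENT_FIELDS = [(JsonFeature, 'json_feature'), (TensorFeature, 'tensor_feature')]


def proto2oneof_field_name(proto) -> str:
  for cls, name in _CONTENT_FIELDS:
    if isinstance(proto, cls):
      return name
  # ...and reuse them in the error message instead of rebuilding the list.
  supported = [cls for cls, _ in _CONTENT_FIELDS]
  raise ValueError(f'Unknown proto {type(proto)}. Supported: {supported}.')


print(proto2oneof_field_name(TensorFeature()))  # -> 'tensor_feature'
```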
38 changes: 25 additions & 13 deletions tensorflow_datasets/core/features/features_test.py
@@ -51,6 +51,9 @@ def decode_example(self, tfexample_dict):
# Merge the two values
return tfexample_dict['a'] + tfexample_dict['b']

def decode_example_np(self, example):
return example['a'] + example['b']


class AnOutputConnector(features_lib.FeatureConnector):
"""Simple FeatureConnector implementing the based methods used for test."""
@@ -64,6 +67,9 @@ def encode_example(self, example_data):
def decode_example(self, tfexample_data):
return tfexample_data / 10.0

def decode_example_np(self, example):
return example / 10.0


class FeatureDictTest(
parameterized.TestCase, testing.FeatureExpectationsTestCase
@@ -236,6 +242,18 @@ def test_fdict(self, image_dtype, metadata_dtype, output_metadata_dtype):
'metadata/path': tf.compat.as_bytes('path/to/xyz.jpg'),
},
},
expected_np={
# See explanations above.
'input': 12,
'output': -1.0,
'img': {
'size': {
'height': 256,
'width': 128,
},
'metadata/path': b'path/to/xyz.jpg',
},
},
),
],
)
@@ -268,25 +286,21 @@ def test_feature_getitem(self, features_dict):
def test_feature__repr__(self):
label = features_lib.ClassLabel(names=['m', 'f'])
feature_dict = features_lib.FeaturesDict({
'metadata': features_lib.Sequence(
{
'frame': features_lib.Image(shape=(32, 32, 3)),
}
),
'metadata': features_lib.Sequence({
'frame': features_lib.Image(shape=(32, 32, 3)),
}),
'label': features_lib.Sequence(label),
})

self.assertEqual(
repr(feature_dict),
textwrap.dedent(
"""\
textwrap.dedent("""\
FeaturesDict({
'label': Sequence(ClassLabel(shape=(), dtype=int64, num_classes=2)),
'metadata': Sequence({
'frame': Image(shape=(32, 32, 3), dtype=uint8),
}),
})"""
),
})"""),
)

def test_feature_save_load_metadata_slashes(self):
@@ -325,14 +339,12 @@ class ChildTensor(features_lib.Tensor):
'child': ChildTensor(shape=(), dtype=dtype),
})
),
textwrap.dedent(
"""\
textwrap.dedent("""\
FeaturesDict({
'child': ChildTensor(shape=(), dtype=int32),
'colapsed': int32,
'noncolapsed': Tensor(shape=(1,), dtype=int32),
})"""
),
})"""),
)


149 changes: 149 additions & 0 deletions tensorflow_datasets/core/features/optional_feature.py
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optional feature connector."""

from __future__ import annotations

from typing import Any

from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.features import features_dict
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.proto import feature_pb2
from tensorflow_datasets.core.utils import py_utils
from tensorflow_datasets.core.utils import type_utils

Json = type_utils.Json


class Optional(feature_lib.FeatureConnector):
r"""Optional feature connector.
Warning: This feature is under active development. For now,
`tfds.features.Optional` can only wrap `tfds.features.Scalar`:
```python
features = tfds.features.FeaturesDict({
'myfeature': tfds.features.Optional(np.int32),
})
```
In `_generate_examples`, you can yield `None` instead of the actual feature
value:
```python
def _generate_examples(self):
yield {
'myfeature': None, # Accept `None` or feature
}
```
For Grain/NumPy users, `tfds.data_source` will output None values.
```python
ds = tfds.data_source('dataset_with_optional', split='train')
for element in ds:
if element is None:
# ...
else:
# ...
```
For tf.data users, None values don't make sense in tensors, so we haven't
implemented the feature yet. If you're a tf.data user with a use case for
optional values, we'd like to hear from you.
"""

def __init__(
self,
feature: feature_lib.FeatureConnectorArg,
*,
doc: feature_lib.DocArg = None,
):
"""Constructor.
Args:
feature: The feature to wrap (any TFDS feature is supported).
doc: Documentation of this feature (e.g. description).
"""
self._feature = features_dict.to_feature(feature)
if not isinstance(self._feature, tensor_feature.Tensor):
raise NotImplementedError(
'tfds.features.Optional only supports Tensors. Refer to its'
' documentation for more information.'
)
super().__init__(doc=doc)

@py_utils.memoize()
def get_tensor_info(self):
"""See base class for details."""
return self._feature.get_tensor_info()

@py_utils.memoize()
def get_serialized_info(self):
"""See base class for details."""
return self._feature.get_serialized_info()

def __getitem__(self, key: str) -> feature_lib.FeatureConnector:
"""Allows to access the underlying features directly."""
return self._feature[key]

def __contains__(self, key: str) -> bool:
return key in self._feature

def save_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.save_metadata(*args, **kwargs)

def load_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.load_metadata(*args, **kwargs)

@classmethod
def from_json_content(cls, value: feature_pb2.Optional) -> Optional:
"""See base class for details."""
feature = feature_lib.FeatureConnector.from_proto(value.feature)
return cls(feature)

def to_json_content(self) -> feature_pb2.Optional:
"""See base class for details."""
return feature_pb2.Optional(feature=self._feature.to_proto())

@property
def feature(self):
"""The inner feature."""
return self._feature

def encode_example(self, example: Any) -> Any:
"""See base class for details."""
if example is None:
return None
else:
return self._feature.encode_example(example)

def decode_example(self, example: Any) -> Any:
"""See base class for details."""
raise NotImplementedError(
'tfds.features.Optional only supports tfds.data_source.'
)

def decode_example_np(
self, example: Any
) -> type_utils.NpArrayOrScalar | None:
if example is None:
return None
else:
return self._feature.decode_example_np(example)
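
A hedged sketch of how the connector defined above behaves end to end, assuming a TFDS version that includes this commit (the `doc` string is arbitrary):

```python
import numpy as np
from tensorflow_datasets.core import features

optional = features.Optional(features.Scalar(dtype=np.int32, doc='A maybe-missing value'))

print(optional.encode_example(42))               # delegates to the wrapped Scalar
print(optional.encode_example(None))             # None: the value is simply absent
print(optional.decode_example_np(np.int32(42)))  # NumPy path decodes present values...
print(optional.decode_example_np(None))          # ...and propagates None to data_source users

# The tf.data path is intentionally unsupported for now:
try:
  optional.decode_example(42)
except NotImplementedError as e:
  print(e)  # tfds.features.Optional only supports tfds.data_source.
```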
80 changes: 80 additions & 0 deletions tensorflow_datasets/core/features/optional_feature_test.py
@@ -0,0 +1,80 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for optional_feature."""

from absl.testing import parameterized
import numpy as np
from tensorflow_datasets import testing
from tensorflow_datasets.core import features


class ScalarFeatureTest(
parameterized.TestCase, testing.FeatureExpectationsTestCase
):

@parameterized.parameters(
(np.int64, 42, 42),
(np.int64, None, testing.TestValue.NONE),
(np.str_, 'foo', 'foo'),
(np.str_, None, testing.TestValue.NONE),
)
def test_scalar(self, dtype, value, expected_np):
self.assertFeature(
feature=features.Optional(
features.Scalar(dtype=dtype, doc='Some description')
),
shape=(),
dtype=dtype,
tests=[
testing.FeatureExpectationItem(
value=value,
raise_cls=NotImplementedError,
raise_msg='supports tfds.data_source',
),
testing.FeatureExpectationItem(
value=value,
expected_np=expected_np,
),
],
)

def test_dict(self):
self.assertFeature(
feature=features.FeaturesDict({'a': features.Optional(np.int32)}),
shape={'a': ()},
dtype={'a': np.int32},
tests=[
testing.FeatureExpectationItem(
value={'a': None},
raise_cls=NotImplementedError,
raise_msg='supports tfds.data_source',
),
testing.FeatureExpectationItem(
value={'a': None},
expected_np={'a': None},
),
# You cannot omit the key; you do have to specify {'a': None}.
testing.FeatureExpectationItem(
value={},
raise_cls_np=RuntimeError,
raise_msg="'a'",
),
],
)


if __name__ == '__main__':
testing.test_main()
