Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tfds.features.Optional in TFDS for tfds.data_source users. #5235

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 6 additions & 8 deletions tensorflow_datasets/core/example_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,15 @@ def run_with_reraise(fn, k, example_data, tensor_info):
# 'objects/tokens/flat_values': [0, 1, 2, 3, 4],
# 'objects/tokens/row_lengths_0': [3, 0, 2],
# }
features = utils.flatten_nest_dict(
{
k: run_with_reraise(
_add_ragged_fields, k, example_dict[k], tensor_info
)
for k, tensor_info in tensor_info_dict.items()
}
)
features = utils.flatten_nest_dict({
k: run_with_reraise(_add_ragged_fields, k, example_dict[k], tensor_info)
for k, tensor_info in tensor_info_dict.items()
})
features = {
k: run_with_reraise(_item_to_tf_feature, k, item, tensor_info)
for k, (item, tensor_info) in features.items()
# If the item is None, it doesn't appear in the proto at all.
if item is not None
}
return tf_example_pb2.Example(
features=tf_feature_pb2.Features(feature=features)
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_datasets/core/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""API defining dataset features (image, text, scalar,...).

See [the guide](https://www.tensorflow.org/datasets/features).

"""

from tensorflow_datasets.core.features.audio_feature import Audio
Expand All @@ -31,6 +30,7 @@
from tensorflow_datasets.core.features.features_dict import FeaturesDict
from tensorflow_datasets.core.features.image_feature import Image
from tensorflow_datasets.core.features.labeled_image import LabeledImage
from tensorflow_datasets.core.features.optional_feature import Optional
from tensorflow_datasets.core.features.scalar import Scalar
from tensorflow_datasets.core.features.sequence_feature import Sequence
from tensorflow_datasets.core.features.tensor_feature import Encoding
Expand All @@ -52,6 +52,7 @@
"FeatureConnector",
"FeaturesDict",
"LabeledImage",
"Optional",
"Tensor",
"TensorInfo",
"Scalar",
Expand Down
9 changes: 4 additions & 5 deletions tensorflow_datasets/core/features/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def decode_example(self, tfexample_data):

def decode_example_np(
self, example_data: type_utils.NpArrayOrScalar
) -> type_utils.NpArrayOrScalar:
) -> type_utils.NpArrayOrScalar | None:
"""Encode the feature dict into NumPy-compatible input.

Args:
Expand Down Expand Up @@ -1122,12 +1122,11 @@ def _name2proto_cls(cls_name: str) -> Type[message.Message]:

def _proto2oneof_field_name(proto: message.Message) -> str:
"""Returns the field name associated with the class."""
for field in _feature_content_fields():
fields = _feature_content_fields()
for field in fields:
if field.message_type._concrete_class == type(proto): # pylint: disable=protected-access
return field.name
supported_cls = [
f.message_type._concrete_class.name for f in _feature_content_fields() # pylint: disable=protected-access
]
supported_cls = [field.message_type._concrete_class for field in fields] # pylint: disable=protected-access
raise ValueError(f'Unknown proto {type(proto)}. Supported: {supported_cls}.')


Expand Down
38 changes: 25 additions & 13 deletions tensorflow_datasets/core/features/features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def decode_example(self, tfexample_dict):
# Merge the two values
return tfexample_dict['a'] + tfexample_dict['b']

def decode_example_np(self, example):
return example['a'] + example['b']


class AnOutputConnector(features_lib.FeatureConnector):
"""Simple FeatureConnector implementing the based methods used for test."""
Expand All @@ -64,6 +67,9 @@ def encode_example(self, example_data):
def decode_example(self, tfexample_data):
return tfexample_data / 10.0

def decode_example_np(self, example):
return example / 10.0


class FeatureDictTest(
parameterized.TestCase, testing.FeatureExpectationsTestCase
Expand Down Expand Up @@ -236,6 +242,18 @@ def test_fdict(self, image_dtype, metadata_dtype, output_metadata_dtype):
'metadata/path': tf.compat.as_bytes('path/to/xyz.jpg'),
},
},
expected_np={
# See explanations above.
'input': 12,
'output': -1.0,
'img': {
'size': {
'height': 256,
'width': 128,
},
'metadata/path': b'path/to/xyz.jpg',
},
},
),
],
)
Expand Down Expand Up @@ -268,25 +286,21 @@ def test_feature_getitem(self, features_dict):
def test_feature__repr__(self):
label = features_lib.ClassLabel(names=['m', 'f'])
feature_dict = features_lib.FeaturesDict({
'metadata': features_lib.Sequence(
{
'frame': features_lib.Image(shape=(32, 32, 3)),
}
),
'metadata': features_lib.Sequence({
'frame': features_lib.Image(shape=(32, 32, 3)),
}),
'label': features_lib.Sequence(label),
})

self.assertEqual(
repr(feature_dict),
textwrap.dedent(
"""\
textwrap.dedent("""\
FeaturesDict({
'label': Sequence(ClassLabel(shape=(), dtype=int64, num_classes=2)),
'metadata': Sequence({
'frame': Image(shape=(32, 32, 3), dtype=uint8),
}),
})"""
),
})"""),
)

def test_feature_save_load_metadata_slashes(self):
Expand Down Expand Up @@ -325,14 +339,12 @@ class ChildTensor(features_lib.Tensor):
'child': ChildTensor(shape=(), dtype=dtype),
})
),
textwrap.dedent(
"""\
textwrap.dedent("""\
FeaturesDict({
'child': ChildTensor(shape=(), dtype=int32),
'colapsed': int32,
'noncolapsed': Tensor(shape=(1,), dtype=int32),
})"""
),
})"""),
)


Expand Down
149 changes: 149 additions & 0 deletions tensorflow_datasets/core/features/optional_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optional feature connector."""

from __future__ import annotations

from typing import Any

from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.features import features_dict
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.proto import feature_pb2
from tensorflow_datasets.core.utils import py_utils
from tensorflow_datasets.core.utils import type_utils

Json = type_utils.Json


class Optional(feature_lib.FeatureConnector):
r"""Optional feature connector.

Warning: This feature is under active development. For now,
`tfds.features.Optional` can only wrap `tfds.features.Scalar`:

```python
features = tfds.features.FeaturesDict({
'myfeature': tfds.features.Optional(np.int32),
})
```

In `_generate_examples`, you can yield `None` instead of the actual feature
value:

```python
def _generate_examples(self):
yield {
'myfeature': None, # Accept `None` or feature
}
```

For Grain/NumPy users, `tfds.data_source` will output None values.

```python
ds = tfds.data_source('dataset_with_optional', split='train')
for element in ds:
if element is None:
# ...
else:
# ...
```

For tf.data users, None values don't make sense in tensors, so we haven't
implemented the feature yet. If you're a tf.data user with a use case for
optional values, we'd like to hear from you.
"""

def __init__(
self,
feature: feature_lib.FeatureConnectorArg,
*,
doc: feature_lib.DocArg = None,
):
"""Constructor.

Args:
feature: The feature to wrap (any TFDS feature is supported).
doc: Documentation of this feature (e.g. description).
"""
self._feature = features_dict.to_feature(feature)
if not isinstance(self._feature, tensor_feature.Tensor):
raise NotImplementedError(
'tfds.features.Optional only supports Tensors. Refer to its'
' documentation for more information.'
)
super().__init__(doc=doc)

@py_utils.memoize()
def get_tensor_info(self):
"""See base class for details."""
return self._feature.get_tensor_info()

@py_utils.memoize()
def get_serialized_info(self):
"""See base class for details."""
return self._feature.get_serialized_info()

def __getitem__(self, key: str) -> feature_lib.FeatureConnector:
"""Allows to access the underlying features directly."""
return self._feature[key]

def __contains__(self, key: str) -> bool:
return key in self._feature

def save_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.save_metadata(*args, **kwargs)

def load_metadata(self, *args, **kwargs):
"""See base class for details."""
self._feature.load_metadata(*args, **kwargs)

@classmethod
def from_json_content(cls, value: feature_pb2.Optional) -> Optional:
"""See base class for details."""
feature = feature_lib.FeatureConnector.from_proto(value.feature)
return cls(feature)

def to_json_content(self) -> feature_pb2.Optional:
"""See base class for details."""
return feature_pb2.Optional(feature=self._feature.to_proto())

@property
def feature(self):
"""The inner feature."""
return self._feature

def encode_example(self, example: Any) -> Any:
"""See base class for details."""
if example is None:
return None
else:
return self._feature.encode_example(example)

def decode_example(self, example: Any) -> Any:
"""See base class for details."""
raise NotImplementedError(
'tfds.features.Optional only supports tfds.data_source.'
)

def decode_example_np(
self, example: Any
) -> type_utils.NpArrayOrScalar | None:
if example is None:
return None
else:
return self._feature.decode_example_np(example)
80 changes: 80 additions & 0 deletions tensorflow_datasets/core/features/optional_feature_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for optional_feature."""

from absl.testing import parameterized
import numpy as np
from tensorflow_datasets import testing
from tensorflow_datasets.core import features


class ScalarFeatureTest(
parameterized.TestCase, testing.FeatureExpectationsTestCase
):

@parameterized.parameters(
(np.int64, 42, 42),
(np.int64, None, testing.TestValue.NONE),
(np.str_, 'foo', 'foo'),
(np.str_, None, testing.TestValue.NONE),
)
def test_scalar(self, dtype, value, expected_np):
self.assertFeature(
feature=features.Optional(
features.Scalar(dtype=dtype, doc='Some description')
),
shape=(),
dtype=dtype,
tests=[
testing.FeatureExpectationItem(
value=value,
raise_cls=NotImplementedError,
raise_msg='supports tfds.data_source',
),
testing.FeatureExpectationItem(
value=value,
expected_np=expected_np,
),
],
)

def test_dict(self):
self.assertFeature(
feature=features.FeaturesDict({'a': features.Optional(np.int32)}),
shape={'a': ()},
dtype={'a': np.int32},
tests=[
testing.FeatureExpectationItem(
value={'a': None},
raise_cls=NotImplementedError,
raise_msg='supports tfds.data_source',
),
testing.FeatureExpectationItem(
value={'a': None},
expected_np={'a': None},
),
# You cannot ommit the key, you do have to specify {'a': None}.
testing.FeatureExpectationItem(
value={},
raise_cls_np=RuntimeError,
raise_msg="'a'",
),
],
)


if __name__ == '__main__':
testing.test_main()