
Documentation for RaggedTensors does not match Input layer documentation #65399

Closed
damienwojtowicz opened this issue Apr 10, 2024 · 5 comments
Labels
comp:keras Keras related issues stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author type:docs-bug Document issues type:feature Feature requests

Comments

@damienwojtowicz

Issue type

Documentation Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

2.16.1

Custom code

No

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

The RaggedTensor documentation, in the section TensorFlow APIs > Keras, claims the following and provides an example in which the argument ragged=True is passed to tf.keras.layers.Input:

Ragged tensors may be passed as inputs to a Keras model by setting ragged=True on tf.keras.Input or tf.keras.layers.InputLayer.

However, ragged is not an argument of tf.keras.Input, tf.keras.layers.InputLayer, or any of their parent classes, and their respective documentation does not mention it. Running the provided example with either TF 2.16.1 or nightly (see the MWE) raises a TypeError stating that ragged is an unexpected keyword argument for Input.

This is puzzling, as several sources on the internet, as well as issues in this very repository, mention it. As a consequence, it seems that ragged tensors cannot currently be used as Keras model inputs in TF.
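A quick way to confirm which keyword arguments the installed version actually declares is to inspect the signature. The helper below, declares_parameter, is a hypothetical name used only for this sketch; it relies solely on the standard inspect module and deliberately ignores a bare **kwargs catch-all, because Keras layers accept **kwargs but may still reject unknown keys at runtime.

```python
import inspect


def declares_parameter(fn, name):
    """Return True if `fn` explicitly declares `name` as a keyword-passable parameter.

    Hypothetical helper. A **kwargs catch-all is not counted, since a layer
    can accept **kwargs and still reject an unknown key when it is called.
    """
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) have no introspectable signature.
        return False
    param = params.get(name)
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
```

In the environment that produced the traceback below, declares_parameter(tf.keras.Input, "ragged") would be expected to return False, which is consistent with the TypeError.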

Standalone code to reproduce the issue

# This is the example provided by the RaggedTensor documentation

import tensorflow as tf
print(tf.version.GIT_VERSION, tf.version.VERSION)

# Task: predict whether each sentence is a question or not.
sentences = tf.constant(
    ['What makes you think she is a witch?',
     'She turned me into a newt.',
     'A newt?',
     'Well, I got better.'])
is_question = tf.constant([True, False, True, False])

# Preprocess the input strings.
hash_buckets = 1000
words = tf.strings.split(sentences, ' ')
hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)

# Build the Keras model.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
    tf.keras.layers.Embedding(hash_buckets, 16),
    tf.keras.layers.LSTM(32, use_bias=False),
    tf.keras.layers.Dense(32),
    tf.keras.layers.Activation(tf.nn.relu),
    tf.keras.layers.Dense(1)
])

keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
keras_model.fit(hashed_words, is_question, epochs=5)
print(keras_model.predict(hashed_words))

Relevant log output

2024-04-10 12:22:30.701897: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-10 12:22:30.733806: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-10 12:22:31.321292: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
v1.12.1-108775-gb453565cae0 2.17.0-dev20240409
Traceback (most recent call last):
  File "/home/dwojtowicz/PycharmProjects/japetus_ml/issue.py", line 19, in <module>
    tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Input() got an unexpected keyword argument 'ragged'
@SuryanarayanaY
Collaborator

Hi @damienwojtowicz ,

Starting from TF 2.16, TensorFlow bundles Keras 3. Since Keras 3 now supports multiple backends, and ragged tensors are not supported on the other frameworks, Keras 3 does not currently accept ragged inputs. The Keras team plans to support ragged tensors with the TensorFlow backend in the near future.

Until then, if you want to use ragged tensors, I recommend installing the tf_keras package and setting the environment variable TF_USE_LEGACY_KERAS=1.
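A minimal sketch of that workaround, assuming tf_keras is installed (for example with pip install tf_keras) and that nothing in the process has imported TensorFlow yet. The function name build_ragged_model is hypothetical; the layers are the ones from the example in the issue. TensorFlow is imported only after the variable is set, because the variable must be set before TensorFlow is imported for tf.keras to resolve to tf_keras (Keras 2).

```python
import os

# Assumption: the tf_keras package is installed. This must run before
# TensorFlow is first imported in the process.
os.environ["TF_USE_LEGACY_KERAS"] = "1"


def build_ragged_model(hash_buckets: int = 1000):
    """Hypothetical helper: rebuild the issue's model with a ragged int64 input."""
    import tensorflow as tf  # imported after TF_USE_LEGACY_KERAS is set above

    model = tf.keras.Sequential([
        # With tf_keras (Keras 2), Input still accepts ragged=True.
        tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
        tf.keras.layers.Embedding(hash_buckets, 16),
        tf.keras.layers.LSTM(32, use_bias=False),
        tf.keras.layers.Dense(32),
        tf.keras.layers.Activation(tf.nn.relu),
        tf.keras.layers.Dense(1),
    ])
    model.compile(loss="binary_crossentropy", optimizer="rmsprop")
    return model
```

With this in place, calling build_ragged_model() and then fit(hashed_words, is_question, epochs=5) as in the reproducer should accept the ragged hashed_words tensor.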

@SuryanarayanaY
Collaborator

Please refer to this for more details.

@SuryanarayanaY SuryanarayanaY added type:feature Feature requests comp:keras Keras related issues stat:awaiting response Status - Awaiting response from author and removed type:bug Bug labels Apr 15, 2024

This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 23, 2024

github-actions bot commented May 1, 2024

This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.

@github-actions github-actions bot closed this as completed May 1, 2024


2 participants