
Why doesn't The calling iterator did not fully read the dataset being cached. appear on Google Colab? #66347

Open
miticollo opened this issue Apr 24, 2024 · 3 comments
Assignees
Labels
comp:data tf.data related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.16 type:bug Bug

Comments

@miticollo
Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

v2.16.1-0-g5bc9d26649c 2.16.1

Custom code

Yes

OS platform and distribution

Linux Manjaro

Mobile device

No response

Python version

3.11.8

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

CUDA Version: 12.4 (from optirun nvidia-smi)

GPU model and memory

No response

Current behavior?

Here is the output from running the notebook on my laptop.
Below you can find another Colab link showing that the warning (see Relevant log output) doesn't appear on Google Colab.

I don't know if this is an issue. Maybe my laptop isn't powerful enough, or maybe the culprit is the Python version: Colab uses Python 3.10, while I tested the notebook on Python 3.11.8.

Standalone code to reproduce the issue

https://colab.research.google.com/drive/1Ns5P1aPDcFjuoPohSDO567rPL1jluSun?usp=sharing

Relevant log output

W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
@SuryanarayanaY
Collaborator

Hi @miticollo ,

Colab usually filters Info and Warning logs; that's why you are not able to see it there. The warning describes what happens if you cache the dataset and then apply take() or batch(). If your dataset is large, caching the whole dataset might be a problem and can lead to truncation of the dataset. This is just a warning to inform the user about possible data truncation.

Since tf.data.Dataset is a pipeline object, you can batch the dataset first and then cache it:

dataset = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))
dataset = dataset.map(lambda x, y: (x, y), num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(32)     # batch before caching
dataset = dataset.cache()       # cache the (smaller) batched dataset
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.prefetch_to_device(device="/gpu:0"))
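As a rough pure-Python illustration (not tf.data itself, just a sketch of the idea) of why batching before caching keeps the cache manageable: the cache then holds one entry per batch rather than one per element.

```python
def batch(items, size):
    """Group an iterable into lists of at most `size` elements,
    loosely mirroring dataset.batch(size) with drop_remainder=False."""
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) == size:
            yield buf
            buf = []
    if buf:
        yield buf  # final, possibly smaller, batch

# Caching after batching stores one cache entry per batch:
cached = list(batch(range(100), 32))
```

Here 100 elements collapse into 4 cached batches (of 32, 32, 32 and 4 elements), so the cache holds far fewer, larger entries.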

@SuryanarayanaY SuryanarayanaY added the stat:awaiting response Status - Awaiting response from author label Apr 24, 2024
@miticollo
Author

miticollo commented Apr 24, 2024

I tried reducing the dataset to 64 bytes and 70 bytes (I hope I did it correctly).
Then I ran the notebook several times and noticed that the warning sometimes appears and sometimes doesn't.
I don't know why; could this be an issue?

Anyway, these are the outputs from 6 sequential runs (each time I turned the IPython runtime off and back on):

  • 1: the warning appears twice.
  • 2: the warning appears only in the second case.
  • 3: see case 2.
  • 4: see case 2.
  • 5: see case 2.
  • 6: see case 1.

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label Apr 24, 2024
@SuryanarayanaY SuryanarayanaY added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 29, 2024
@fylux

fylux commented May 2, 2024

It seems to be working as intended. The cache() operator collects data from upstream into an in-memory cache as the downstream iterator pulls elements. Once the user iterates through the full dataset, all of it has been processed and loaded into the in-memory cache.

The Colab version runs in eager mode and waits on further user input while holding the partial dataset in the cache. When the notebook runs in a shell, the full session is evaluated, and since the code only asks for a single element, it never iterates through the full dataset. The warning just informs the user that the cached data (i.e. the data read so far) will be discarded and the iterators reset to the beginning.

If you were to replace:

x_batch, y_batch = next(iter(dataset))

with:

for (x_batch, y_batch) in dataset:
  pass

You shouldn't see this warning in the shell execution.
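The behaviour described above can be sketched in plain Python (a toy model with hypothetical names, not TensorFlow internals): the cache survives only if a single pass reads the upstream data to the end, and a partial read discards it, which is the situation the warning reports.

```python
class ToyCache:
    """Toy model of a cache() stage: keep the in-memory cache only if
    one iteration reads the source all the way to the end."""

    def __init__(self, source):
        self._source = source
        self._cache = None   # filled in after one complete pass
        self.discards = 0    # counts the "partial read" warning events

    def __iter__(self):
        if self._cache is not None:   # already fully cached
            yield from self._cache
            return
        partial, complete = [], False
        try:
            for item in self._source:
                partial.append(item)
                yield item
            complete = True
        finally:
            if complete:
                self._cache = partial   # full read: cache retained
            else:
                self.discards += 1      # partial read: cache discarded


# Full iteration (like `for batch in dataset: pass`): cache is kept.
full = ToyCache(range(5))
for _ in full:
    pass

# Partial read (like `next(iter(dataset))`): cache is discarded.
part = ToyCache(range(5))
it = iter(part)
next(it)
it.close()  # the epoch ends before the dataset is exhausted
```

In this toy model, `full._cache` ends up populated while `part` records one discard, mirroring why the single-element shell run triggers the warning but a full loop does not.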
