Meta-Dataset in TFDS: Getting as_numpy_iterator() from dataset returned from api.meta_dataset takes a very long time #83

Open · jfb54 opened this issue Feb 13, 2022 · 2 comments

jfb54 commented Feb 13, 2022

I am trying to use the new Meta-Dataset in TFDS APIs, and I have hit a critical performance problem.

When I run the sample code to "Train on Meta-Dataset episodes" (with some added lines to record times), it takes about 3 to 4 minutes for api.meta_dataset to return the dataset, and roughly 1 hour for the first episode to come out of episode_dataset.take(4).as_numpy_iterator(). Here is the code I am running:

import gin
import meta_dataset
from meta_dataset.data.tfds import api
import tensorflow_datasets as tfds
import time

# Set up a TFDS-compatible data configuration.
gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                      'learn/gin/setups/data_config_tfds.gin')

# 'v1' here refers to the Meta-Dataset protocol version and means that we
# are using the protocol defined in the original Meta-Dataset paper
# (rather than in the VTAB+MD paper, which is the 'v2' protocol; see the
# VTAB+MD paper for a detailed explanation). This is not to be confused
# with the (unrelated) arXiv version of the Meta-Dataset paper.
md_version = 'v2'
md_sources = ('aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012', 'omniglot', 'quickdraw')
if md_version == 'v1':
    md_sources += ('vgg_flower',)

print("\nStarting TFDS reader.")
start_time = time.time()

episode_dataset = api.meta_dataset(
    md_sources,
    md_version,
    # This is where the meta-split ('train', 'valid', or 'test') is specified.
    'train',
    data_dir='<path to tensorflow datasets>'
)

t1 = time.time()  # it takes about 4 minutes to get here
print("Created training dataset. Time = {0:.1f} seconds".format(t1 - start_time))

# We sample 4 episodes here for demonstration. `source_id` is the index (in
# `md_sources`) of the source that was sampled for the episode.
for episode, source_id in episode_dataset.take(4).as_numpy_iterator():
    stop_time = time.time()  # it takes about an hour to get here.
    print("Created training iterator. Time = {0:.1f} seconds".format(stop_time - t1))
    print("Total reader initialization time = {0:.1f} seconds".format(stop_time - start_time), flush=True)
    support_images, support_labels, _ = episode[:3]
    query_images, query_labels, _ = episode[3:]

I have run this on Linux and Windows with similar results. The time seems to be spent in:

_result = pywrap_tfe.TFE_Py_FastPathExecute(_ctx, "MakeIterator", name, dataset, iterator)

in gen_dataset_ops.py, which drops into C++ code that I did not debug into.
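
For what it's worth, here is a minimal way to confirm where the time goes (a sketch; episode_dataset is the dataset created above, and the context-manager form of cProfile.Profile needs Python 3.8+):

import cProfile
import pstats

# Profile retrieval of the first episode; iterator creation happens here.
with cProfile.Profile() as pr:
    next(iter(episode_dataset.take(1).as_numpy_iterator()))

# Show the 20 most expensive calls by cumulative time.
pstats.Stats(pr).sort_stats('cumtime').print_stats(20)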

Note that creating an iterator on api.episode_dataset for evaluation is reasonably quick - omniglot takes the longest at about 3 minutes, but the others take only a few seconds.

This issue makes training on MDv2 from TFDS more or less impossible.

vdumoulin (Collaborator) commented

I believe most of that time is spent reading data and filling shuffle buffers. Each training class in each training source is instantiated as its own dataset with its own shuffle buffer (this is how examples are sampled from specific classes to form episodes), and by default in learn/gin/setups/data_config_tfds.gin the shuffle buffer size is upper-bounded by 1000.
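
To illustrate the structure (a sketch only, not the actual Meta-Dataset implementation): one tf.data pipeline per class, each with its own shuffle buffer, so creating an iterator forces every buffer to fill before the first episode can be produced:

import tensorflow as tf

NUM_CLASSES = 8      # the real train split spans thousands of classes
BUFFER_SIZE = 1000   # default upper bound in data_config_tfds.gin

# One pipeline per class; tf.data.Dataset.range stands in for real examples.
per_class = [
    tf.data.Dataset.range(c * 10000, (c + 1) * 10000)
    .repeat()
    .shuffle(BUFFER_SIZE)  # one buffer per class
    for c in range(NUM_CLASSES)
]

# Episodes draw examples across the per-class pipelines; startup cost
# scales with (number of classes) x (buffer size).
mixed = tf.data.experimental.sample_from_datasets(per_class)
print(next(iter(mixed.as_numpy_iterator())))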

What happens if you set DataConfig.shuffle_buffer_size = 10 and DataConfig.num_prefetch = 1 in the configuration file? Does it speed up iterator creation? For training you could strike a balance between shuffling quality and startup time by changing the default values. For evaluation I would advise against changing the default values, as they could impact the data distribution of sampled episodes and introduce an unwanted confounding factor when comparing against competing approaches.
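
Concretely, the override could be applied after parsing the default config, e.g. (a sketch; parameter names as above):

import gin

# Override the two DataConfig parameters after parsing data_config_tfds.gin.
# These values trade shuffling quality for startup time; training only.
gin.parse_config("""
DataConfig.shuffle_buffer_size = 10
DataConfig.num_prefetch = 1
""")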

jfb54 (Author) commented Mar 10, 2022

Thanks for looking into this. I made the change you suggested (DataConfig.shuffle_buffer_size = 10 and DataConfig.num_prefetch = 1), and it did not help, other than making training dataset creation somewhat faster. Here are the timings I get:

Starting TFDS reader.
Created training dataset. Time = 143.3 seconds (got a bit faster)
Created training iterator. Time = 3347.3 seconds (this takes close to 1 hour!)
Created validation iterators omniglot: Time = 5.7 seconds
Created validation iterators aircraft: Time = 0.9 seconds
Created validation iterators cu_birds: Time = 1.7 seconds
Created validation iterators dtd: Time = 0.5 seconds
Created validation iterators quickdraw: Time = 3.4 seconds
Created validation iterators fungi: Time = 20.2 seconds
Created validation iterators mscoco: Time = 2.7 seconds
Total validation iterator creation time. Time = 35.1 seconds
Created test iterator: omniglot, Time = 155.0 seconds
Created test iterator: aircraft, Time = 0.9 seconds
Created test iterator: cu_birds, Time = 1.7 seconds
Created test iterator: dtd, Time = 0.5 seconds
Created test iterator: quickdraw, Time = 3.4 seconds
Created test iterator: fungi, Time = 22.5 seconds
Created test iterator: traffic_sign, Time = 2.6 seconds
Created test iterator: mscoco, Time = 2.3 seconds
Total test iterator creation time = 189.0 seconds
Total reader initialization time = 3714.6 seconds

Thus creating an iterator over a single dataset is acceptably fast, but creating an iterator over multiple datasets (which you need to meta-train on MDv2) is unacceptably slow. As I mentioned above, the time seems to be spent in _result = pywrap_tfe.TFE_Py_FastPathExecute(_ctx, "MakeIterator", name, dataset, iterator), which is tricky to debug into.
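
To narrow down whether one source dominates, per-source iterator creation could be timed like this (a sketch; I am assuming api.episode_dataset accepts the same source/version/split/data_dir arguments as api.meta_dataset, which I have not verified):

import time

for source in md_sources:
    t0 = time.time()
    # Assumed signature, mirroring api.meta_dataset but for a single source.
    ds = api.episode_dataset(source, md_version, 'train',
                             data_dir='<path to tensorflow datasets>')
    next(iter(ds.take(1).as_numpy_iterator()))
    print("{0}: {1:.1f} seconds".format(source, time.time() - t0))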

This is a major blocker for us. If I can help debug in any way, I would be happy to.

John
