Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial prefetch for simple single chunked dim #161

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

ljstrnadiii
Copy link

@ljstrnadiii ljstrnadiii commented Jan 18, 2023

POC Prefetch Generator:

This is a draft pr to articulate one possible approach to "prefetching" dask arrays or xarray arrays with dask.

The goals were to simultaneously:

  • support batching over the first dimension, which is the only chunked dimension
  • support prefetching by making sure prefetch number of batches are always loading
  • use the typical dask mechanics to transfer data by submitting tasks, working with futures and as_completed
  • super basic profile of Gb/s. I want to be able to feed my gpu at around a few gb/s.
  • additional considerations are written in the docstring.

I also tried one approach using a Queue on the workers. This felt weird and found myself reinventing features that dask already has.

Results

Using helm to deploy a cluster on kubernetes with 8 workers (4cpu and 16gb each and relatively standard network configurations), I am able to see:

  • 1.6Gb/s without any adjustment in prefetch or array size
  • data transferring between workers on the daskui

What Next?

No clue. I would like to investigate

  • does gb/s scale as number of workers increases?
  • how do we determine a good prefetch size and avoid memory issues?
  • does this work with tensorflow? Will I get an error about things being pickled?
  • Can the deterministic batch idx approach be generalized easily?
  • Does this even belong in xbatcher? I feel like it was functionality I was hoping xbatcher would support.

note:

  • pre-commit is being funny--I skipped verification on commit for now.
  • running the script will save a 33gb zarr dataset if you specify a remote_path!

@ljstrnadiii
Copy link
Author

In another attempt to simplify an example and profile transferring data from multiple workers to a single worker (where an ml tasks would iterate over batches) I have created this example:

from more_itertools import chunked
import dask.array as da
from distributed import Client, get_client, wait
import seaborn as sns
import pandas as pd

# 16 workers (4cores, 16gb each) on kubernetes; no special network configuration
client = Client("...")                  


chunk_gbps = {}
for max_gather in [.5e9, 1e9, 2e9, 4e9, 6e9]:
    for chunk in [128, 256, 512, 1028, 2048]:
        print(f"looking at chunk {chunk}, {max_gather / 1e9}")
        _ = client.restart()
        array = da.random.random((25000, 100, 100, 9), chunks=(chunk, 100, 100, 9)).persist()
        wait(array)
        del client.datasets['test']
        client.publish_dataset(test=array)

        # determine block batch size to control transfered data with gather
        ex = array.blocks[0]
        batch_size = max(int(np.floor(max_gather / ex.nbytes)), 1)

        def compute_bytes():
            client = get_client()
            array = client.get_dataset('test')
            blocks = list(array.blocks)
            nbytes = 0
            t0 = time.time()
            for block_batch in chunked(blocks, batch_size):
                fs = [client.compute(b) for b in block_batch]
                arrays = client.gather(fs)
                for array in arrays:
                    nbytes += array.nbytes
            elapsed = time.time() - t0
            return (nbytes / elapsed) / 1e9

        # blocks = client.submit(get_blocks, pure=False)
        f = client.submit(compute_bytes, pure=False)
        chunk_gbps[(max_gather / 1e9, chunk, batch_size)] = f.result()

# plot for some trends
data = [(*k,v) for k,v in chunk_gbps.items()]
df = pd.DataFrame(data, columns=['gb_gather','chunk_size', 'actul_batch', 'gbps'])
sns.lineplot(x="gb_gather", y="gbps",hue="chunk_size",data=df)

Screen Shot 2023-01-18 at 10 29 38 AM

@ljstrnadiii
Copy link
Author

ljstrnadiii commented Jan 18, 2023

@jhamman @maxrjones this is sort of the approach am considering developing.

I think 2gbps should be fine, but I was able to get 8+gbps with https://github.com/NVlabs/tensorcom using basic k8s pods and a manifest, which uses msgpack with pyzmq. I am trying to avoid using that and stick with the dask mechanics, but I am tempted to mock up a quick profile script of using zmq to bypass dask entirely, but within dask tasks.

This all might not belong in xbatcher, but I wanted to put it out there to get ay feedback people might have.

@ljstrnadiii
Copy link
Author

ljstrnadiii commented Jan 20, 2023

Here is an example of using the prefetch generator with tf.data.Dataset

import tensorflow as tf
from xbatcher.prefetch_generators import PrefetchBatchGenerator

# let array be chunked only along first dim
array = ...
batch_size=128

def do_tf_ml():
    batch_gen = lambda : PrefetchBatchGenerator(array=array, batch_size=batch_size, prefetch=20)
    ds_counter = tf.data.Dataset.from_generator(batch_gen, output_types=tf.int32, output_shapes=(array[:batch_size].shape))
    nbytes = 0
    t0 = time.time()
    for count_batch in ds_counter.repeat().take(128):
        nbytes += count_batch.numpy().nbytes
    elapsed = time.time() - t0
    return nbytes / elapsed

f = client.submit(do_tf_ml)
f.result() / 1e9

@cmdupuis3
Copy link

Can the test at the bottom be wrapped as a function? I'm guessing it's not supposed to run for everyone.

@ljstrnadiii
Copy link
Author

Can the test at the bottom be wrapped as a function? I'm guessing it's not supposed to run for everyone.

@cmdupuis3 I am not sure I understand what you are asking by wrapped as a function. Do you mean be able to submit to dask?

The BatchGenerator should be available on this branch if you check it out and install in editable mode.

@cmdupuis3
Copy link

Actually I think I was confused. I read if __name__ == "__main__" as the entry point rather than as a conditional entry point. Pythonisms are not for forte lol

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants