
Adding writeable to flags #2616

Open · jakirkham opened this issue Nov 8, 2019 · 15 comments · May be fixed by #8118
Labels: cat:feature New features/APIs, prio:low

@jakirkham (Member)

NumPy arrays have a writeable option in flags. This is useful for checking whether one can write to an array, and also for ensuring an array is not written to (by setting the flag to False). It would be useful to have this flag on CuPy arrays as well.
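For reference, here is the behavior being requested, in plain NumPy (nothing CuPy-specific is assumed):

import numpy as np

a = np.arange(3)
print(a.flags.writeable)   # True by default

a.flags.writeable = False  # mark the array read-only
try:
    a[0] = 42              # any write now fails
except ValueError as e:
    print(e)               # "assignment destination is read-only"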

@leofang (Member) commented Nov 8, 2019

I suppose you know this already?

Changing WRITEABLE (to True) in Python is now considered dangerous and deprecated in NumPy:
https://docs.scipy.org/doc/numpy/release.html#writeable-flag-of-c-api-wrapped-arrays

But are you suggesting the opposite operation should be allowed?

I could dig into this myself, but perhaps it's faster to just ask 🙂: How does this flag get honored at the C level in NumPy? Does every NumPy operation need to consult it? The change sounds like a non-trivial task to me, but perhaps I'm overthinking it.

FYI, in CUDA there's constant memory that is guaranteed not to be writable once initialized, but it's a scarce hardware resource and almost never enough for serious computations. It'd be nice for RawKernel/RawModule users to be able to access it (see #1703 and #2510), but that's a different story.

@jakirkham (Member, Author)

Right, the interest here would be to mark arrays as read-only. 🙂

Yeah, NumPy honors this at the C level as well. It does bit masking against the flags (an int) to check, and only the operations that manipulate the data itself consult it.
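In rough Python terms, the check looks like this (the bit value matches NumPy's NPY_ARRAY_WRITEABLE constant, but the helper itself is only an illustration, not NumPy code):

NPY_ARRAY_WRITEABLE = 0x0400  # the writeable bit in NumPy's flags int

def fail_unless_writeable(flags: int, name: str = "array") -> None:
    # Data-mutating operations test this one bit before touching the buffer.
    if not (flags & NPY_ARRAY_WRITEABLE):
        raise ValueError(f"{name} is read-only")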

I'm not sure. I'd think __setitem__ and the in-place math operations would be the main issues. Are there other cases outside of these that would manipulate data in CuPy? You know the code here better than I do. 😉

That seems like a useful feature to include, though I agree it's a different concern.

@kmaehashi (Member)

There are some cases where NumPy returns a non-writable ndarray (e.g. xp.imag(xp.ones(5))), so it would be better to have this feature in terms of compatibility as well. As Leo suggested, though, I also think the change is non-trivial. Specifically, I can't come up with an idea for protecting array memory passed to RawKernel/RawModule.
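For reference, the example above does return a read-only array in NumPy:

>>> import numpy as np
>>> np.imag(np.ones(5)).flags.writeable
False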

@leofang (Member) commented Nov 8, 2019

In addition to RawKernel, I suppose any CUDA library that permits in-place operations (like cuFFT) would also require some non-trivial guard.

I fantasize that asking the NVIDIA folks to allow setting a writable flag in the pointer attributes would probably be the easiest (though highly unlikely) approach 🤣

@kmaehashi added the cat:feature New features/APIs and prio:low labels and removed the cat:enhancement Improvements to existing features label on Nov 12, 2019
@kmaehashi (Member)

We discussed this in the dev team today; the reasonable implementation is to support the writable flag only in ElementwiseKernel/ReductionKernel (i.e., read-only arrays cannot be specified as output arrays).
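A minimal sketch of what that guard could look like (hypothetical; writeable is not an existing cupy.ndarray attribute):

def _check_writeable_outputs(out_args):
    # Hypothetical pre-launch check for ElementwiseKernel/ReductionKernel:
    # reject any read-only array passed as an output.
    for arr in out_args:
        if not getattr(arr, "writeable", True):
            raise ValueError("output array is marked read-only")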

@jakirkham Are there any immediate needs for this?

@jakirkham (Member, Author)

It's a good question. Periodically there are situations where one would like to hand a buffer from lower-level code (like C/C++) to a Python user, but the buffer shouldn't be modified by that user. Today we solve this by copying before handing the buffer off, but it would be nice to avoid the copy.

@jakirkham (Member, Author)

Another use case is related to __cuda_array_interface__, where a user can mark the buffer as read-only. It would be nice to respect that if we can. The relevant text is copied below.

  • data: (integer, boolean)

    The data is a 2-tuple. The first element is the data pointer
    as a Python int (or long). The data must be device-accessible.
    For zero-size arrays, use 0 here.
    The second element is the read-only flag as a Python bool.
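On the consumer side, reading the flag would look like this (exporter is a placeholder for any object exposing the protocol):

cai = exporter.__cuda_array_interface__
ptr, readonly = cai["data"]  # per the spec quoted above
if readonly:
    print("the producer marked this buffer read-only")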

@leofang (Member) commented Nov 25, 2019

I think what @kmaehashi meant is that it is hard to respect that flag in almost all basic CuPy operations, except in ElementwiseKernel/ReductionKernel, where the input/output array attributes are inspected before processing.

I think this puts an extra layer of burden on all Python GPU libraries in general. Note that Numba does not respect that flag either. @jakirkham, could you point us to examples of where (and how) this is implemented?

@hawkinsp

I also noticed that CuPy does not respect the readonly flag for __cuda_array_interface__.

For example, using a copy of jax and jaxlib with GPU support built at GitHub head, together with CuPy 7.1.1:

In [1]: import jax, jax.numpy as jnp, cupy

In [2]: x = jnp.array([1,2,3])

In [3]: x.__cuda_array_interface__
Out[3]:
{'shape': (3,),
 'typestr': '<i4',
 'data': (140074664592896, True),
 'version': 2}

In [4]: y = cupy.asarray(x)

In [5]: y.__cuda_array_interface__
Out[5]:
{'shape': (3,),
 'typestr': '<i4',
 'descr': [('', '<i4')],
 'version': 2,
 'strides': None,
 'data': (140074664592896, False)}

"Laundering" the array via CuPy has dropped the readonly flag.

@jakirkham (Member, Author)

That's true. Short term, it would be good to just error if a read-only array is provided via __cuda_array_interface__.
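A sketch of what that check could look like, imagined at the point where CuPy consumes __cuda_array_interface__ (the names are illustrative, not actual CuPy internals):

def _import_cuda_array(obj):
    cai = obj.__cuda_array_interface__
    ptr, readonly = cai["data"]
    if readonly:
        raise ValueError(
            "read-only buffers via __cuda_array_interface__ are not supported")
    # ...otherwise proceed with the usual zero-copy import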

@hawkinsp

My specific use case is JAX -> CuPy export, and JAX arrays are immutable. I guess I could lie and declare them mutable, and leave it up to the user to deal with the consequences.

I note that PyTorch doesn't handle readonly arrays either:
pytorch/pytorch#32868

@leofang (Member) commented Feb 4, 2020

Numba doesn't handle read-only arrays either. I think supporting read-only device arrays across all participating libraries is just too hard: it requires significant infrastructure changes in every library whose original design didn't take this into consideration, as JAX's did. Had I been asked, I would have suggested leaving this flag out of the __cuda_array_interface__ protocol from the very beginning.

Taking slicing as an example, how do we honor the flag in CuPy when arr is read-only and a user does the following?

arr[..., 0:30] = another_arr

(I'm not saying it's impossible; I'll be more than happy to learn if I'm wrong and there's an easy way.)

@leofang (Member) commented Feb 4, 2020

cc: @gmarkall @sklam @stuartarchibald for awareness

@jakirkham (Member, Author)

At least in the example above, wouldn't it be sufficient to add a check in __setitem__ and error (regardless of the selection) if the array is read-only?
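For illustration, a minimal sketch of such a check (a toy container, not CuPy's actual implementation):

class GuardedArray:
    # Toy wrapper demonstrating the __setitem__ guard described above.
    def __init__(self, data, writeable=True):
        self._data = data
        self._writeable = writeable

    def __setitem__(self, key, value):
        # Error on any write, regardless of the selection used.
        if not self._writeable:
            raise ValueError("assignment destination is read-only")
        self._data[key] = value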

@leofang (Member) commented Feb 24, 2020

Hi John, sorry I dropped the ball. Indeed, for the simple case here it might be easy to fix. But wouldn't this kind of fix only apply on a case-by-case basis?

AFAIK, CuPy has at least a few big categories to consider: ufuncs / elementwise kernels, reduction kernels, raw kernels, and direct interaction with low-level CUDA APIs and libraries. As @kmaehashi pointed out earlier, the first two categories might be fixable (with considerable effort), but when it comes to passing raw pointers around for the rest, I can't imagine how difficult (or easy?) it would be to mark something read-only.

For example, in the CUDA C space, even __constant__ arrays aren't protected (in the strictest sense) and can be modified by a simple memcpy! Not to mention that, moving one level up to the Python-C interface, you'd need to set up a guard on all operations of cupy.cuda.MemoryPointer, the workhorse behind cupy.ndarray (where the __cuda_array_interface__ attribute lives), so that you don't accidentally do a memcpy on a read-only array. Other examples include any CUDA libraries (like cuFFT) that allow in-place operations, as I mentioned earlier.

I think Numba needs to provide a reference implementation for the writable feature first, so that other libraries can adapt. If Numba can't find a way to do it, there's no reason for CuPy to take the initiative, given the difficulty and the rarity of the use case.
