I am certain that this functionality is implemented in JAX, but I couldn't find it. Could you help me?
Answered by jakevdp on Apr 20, 2024
You can do this:

```python
import jax.numpy as jnp

x = jnp.zeros((4, 4))
y = x
print(x is y)  # Python-level object identity
# True
print(x.unsafe_buffer_pointer() == y.unsafe_buffer_pointer())  # memory pointer equality
# True
```

But note that, as the method name states, this is not safe: it is implementation-dependent, and the exact behavior may change. Also keep in mind that this only works in eager mode: there's no way to do this kind of check within transformations like `jit`, because in general the compiler needs to be free to make choices about the storage of arrays (and whether to materialize them at all).
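A minimal sketch that wraps the pointer comparison in a small helper and illustrates the eager-mode limitation. It assumes arrays backed by a single device buffer (`unsafe_buffer_pointer()` is meant for unsharded arrays). The names `shares_buffer` and `pointer_inside_jit` are hypothetical illustrations, not part of JAX. Under `jit` the argument is an abstract tracer with no concrete buffer, so asking for its pointer is expected to raise rather than return an address.

```python
import jax
import jax.numpy as jnp


def shares_buffer(a: jax.Array, b: jax.Array) -> bool:
    """Hypothetical helper: True if two single-buffer arrays report the same
    device buffer pointer. Eager mode only; relies on the unsafe,
    implementation-dependent unsafe_buffer_pointer() method."""
    return a.unsafe_buffer_pointer() == b.unsafe_buffer_pointer()


x = jnp.zeros((4, 4))
y = x        # same Python object, therefore the same buffer
z = x + 1    # a new result; x is still alive, so z gets its own buffer
print(shares_buffer(x, y))
print(shares_buffer(x, z))


@jax.jit
def pointer_inside_jit(a):
    # Inside jit, `a` is a tracer, not a concrete on-device array,
    # so there is no buffer pointer to return and this call raises.
    a.unsafe_buffer_pointer()
    return a


try:
    pointer_inside_jit(x)
    raised = False
except Exception:  # exact exception type is an implementation detail
    raised = True
print(raised)
```

Because the result depends on how the runtime allocates buffers, `shares_buffer` is best treated as a debugging aid rather than something program logic relies on.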
Answer selected by sh0416