
Request for a map function like map_fn in TF and vmap in Jax #19708

Open
yifengshao opened this issue May 10, 2024 · 5 comments
Labels
type:feature The user is asking for a new feature.

Comments

@yifengshao

Currently, there seems to be no function in Keras 3.0 that maps a function over a tensor, the way map_fn does in TF and vmap does in JAX. Without such a function, it is very challenging to switch between backends.

Perhaps I missed something here; could anyone provide a hint? Thanks!

@fchollet
Member

Do you mean keras.ops.vectorized_map?

@yifengshao
Author

Hi François,

Thank you for your quick response.

Sorry, I am not so familiar with JAX. I have now found that vmap is similar to vectorized_map in TF and Keras.

I am particularly interested in map_fn because my operation cannot be vectorized due to the large intermediate variables generated during the computation. I used map_fn extensively in my project to simulate some physical processes.

I understand that (1) tf.map_fn uses while_loop under the hood, and (2) both TF and JAX will convert a Python loop to a graph using while_loop.

However, my issue is that with (2) I cannot set parallel_iterations, and (1) is currently unavailable in Keras 3.0. I am now trying to implement a map_fn function myself using while_loop. One puzzle for me is that tf.map_fn uses tf.TensorArray to accumulate results during the iteration, and TensorArray is also unavailable in Keras 3.0. It would be very useful if Keras had some examples for this task.
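The while_loop-plus-accumulator pattern described above can be sketched in plain NumPy. This is only an illustration of the control flow; under Keras 3 the loop would presumably map to keras.ops.while_loop and the indexed write to keras.ops.slice_update, and whether that gives any control analogous to parallel_iterations is an open question:

```python
import numpy as np

def map_fn_sketch(fn, elems):
    """Apply fn to each leading-axis slice of elems, sequentially."""
    n = elems.shape[0]
    # Pre-allocate the output buffer from the shape of the first result
    # (tf.map_fn uses a TensorArray for this accumulation instead).
    first = fn(elems[0])
    out = np.zeros((n,) + first.shape, dtype=first.dtype)
    out[0] = first
    i = 1
    while i < n:              # cond: i < n
        out[i] = fn(elems[i])  # body: write fn(elems[i]) into slot i
        i += 1
    return out

data = np.random.randn(3, 4)
result = map_fn_sketch(lambda x: x * 2.0, data)
print(result.shape)  # (3, 4)
```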

Here is an example of my code about using map_fn:

```python
import numpy as np
import tensorflow as tf
from keras import ops

data = np.random.randn(3, 1024, 1024)
data = ops.convert_to_tensor(data)

dataFT = tf.map_fn(
    lambda elem: ops.fft2(elem),
    elems=(ops.real(data), ops.imag(data)),
    fn_output_signature=(data.dtype, data.dtype),
)
```

As you can see here, this code does not work with Keras using the Jax backend. Thank you very much for any hints.
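For this particular example, a per-element map may not even be needed: NumPy's 2D FFT transforms the last two axes and batches over any leading axes, and the backend FFT ops follow the same last-two-axes convention. A quick NumPy check of that equivalence:

```python
import numpy as np

data = np.random.randn(3, 8, 8) + 1j * np.random.randn(3, 8, 8)

# np.fft.fft2 transforms the last two axes and batches over the rest,
# so no explicit per-element map is needed for this shape.
batched = np.fft.fft2(data)

# Same result as mapping over the leading axis by hand.
looped = np.stack([np.fft.fft2(data[i]) for i in range(data.shape[0])])
print(np.allclose(batched, looped))  # True
```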

P.S. Another change in Keras that significantly affects my application is that layers no longer support complex variables. This is very strange because keras.ops provides complex operations such as the conjugate, yet I cannot pass complex variables from one layer to another.

Kind regards,
Yifeng Shao

@fchollet
Member

fchollet commented May 12, 2024

Another op you can try is keras.ops.vectorize, which is equivalent to np.vectorize and is effectively the same as vmap but with a nicer syntax.

```python
def myfunc(a, b):
    return a + b

vfunc = keras.ops.vectorize(myfunc)
y = vfunc([1, 2, 3, 4], 2)  # Returns tensor([3, 4, 5, 6])
```

Now, if you want to use tf.map_fn specifically, you can also use that with the TF backend.

@edwardyehuang
Contributor

edwardyehuang commented May 12, 2024

> (quoting @yifengshao's comment above)

In TensorFlow, tf.map_fn is different from tf.vectorized_map:
tf.map_fn
tf.vectorized_map

In JAX, jax.vmap is similar to tf.vectorized_map in TensorFlow.
In NumPy, np.vectorize is similar to tf.map_fn in TensorFlow.
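The distinction can be checked directly in NumPy: np.vectorize calls the Python function once per element, sequentially, rather than actually vectorizing it (otypes is passed here so NumPy does not make an extra type-probing call):

```python
import numpy as np

calls = []

def f(x):
    calls.append(x)  # record each scalar invocation
    return x + 1

vf = np.vectorize(f, otypes=[np.int64])
out = vf(np.array([10, 20, 30]))

print(out)         # [11 21 31]
print(len(calls))  # 3 -- one Python-level call per element, like map_fn
```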

@yifengshao
Author

Dear Edward,

Thank you for your further clarification.

It seems that the map_fn function is unique to TensorFlow, and I cannot find a similar function in the other frameworks. In physics simulations, I believe such a function is very important.

Could you let me know what happens when a Python loop (e.g. pre-allocating memory by initializing an empty variable and then filling in the elements through a loop) is converted to a graph? Is this equivalent to map_fn?

```python
import numpy as np
from keras import ops

data = np.random.randn(3, 1024, 1024)

data_real = np.zeros_like(data)
data_imag = np.zeros_like(data)

for ind in range(data.shape[0]):
    data_real[ind], data_imag[ind] = ops.fft2(
        (ops.real(data[ind]), ops.imag(data[ind]))
    )
```

It seems that such a practice is not common in the machine-learning community... Thanks a lot for any help here.
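One backend-agnostic variant of the loop above, shown here as a NumPy sketch, is to collect per-slice results in a Python list and stack them afterwards instead of writing into a pre-allocated array (under a Keras backend, np.stack would become keras.ops.stack, and tracing would simply unroll the loop):

```python
import numpy as np

data = np.random.randn(3, 8, 8)

real_parts, imag_parts = [], []
for ind in range(data.shape[0]):
    ft = np.fft.fft2(data[ind])  # per-slice 2D FFT
    real_parts.append(ft.real)
    imag_parts.append(ft.imag)

# Stack the per-slice results back into (3, 8, 8) arrays.
data_real = np.stack(real_parts)
data_imag = np.stack(imag_parts)
print(data_real.shape, data_imag.shape)  # (3, 8, 8) (3, 8, 8)
```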

Kind regards,
Yifeng

@SuryanarayanaY SuryanarayanaY added the type:feature The user is asking for a new feature. label May 14, 2024