jax to tflite conversion
#20830
-
The following fails if I specify the backend as cpu/gpu:

```python
import jax
import jax.numpy as jnp
import tensorflow as tf

if __name__ == "__main__":
    x_input = jnp.zeros((4624, 3468, 3))
    predict = jax.jit(lambda x: x, backend='cpu')

    converter = tf.lite.TFLiteConverter.experimental_from_jax(
        [predict], [[('input1', x_input)]])
    tflite_model = converter.convert()

    # check
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], x_input)
    interpreter.invoke()
    result = interpreter.get_tensor(output_details[0]['index'])
```

What is the recommended way to convert jitted functions to tflite?
Answered by gnecula, Apr 24, 2024
-
@gnecula might have suggestions here.
-
JAX on cpu/gpu does use custom calls occasionally. I think that for the core JAX primitives the TFLite converter should handle them, so this is a request that should be made to the TFLite team.
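A possible workaround, under the assumption that the backend-specific lowering triggered by `backend='cpu'` is what introduces the custom calls the converter cannot handle, is to jit without the `backend=` argument so the converter traces the default, platform-independent lowering. A minimal sketch (with a small illustrative input shape instead of the original one):

```python
import jax
import jax.numpy as jnp
import tensorflow as tf

# Small input for illustration; the original used (4624, 3468, 3).
x_input = jnp.zeros((4, 3))

# Assumption: omitting backend= avoids the device-specific custom calls.
predict = jax.jit(lambda x: x)

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [predict], [[('input1', x_input)]])
tflite_model = converter.convert()  # bytes of the serialized TFLite model

# Sanity check: the model loads into the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
```

If the jitted function genuinely needs device-specific behavior, another commonly used path is `jax.experimental.jax2tf.convert` to obtain a `tf.function`, then the standard `tf.lite.TFLiteConverter.from_concrete_functions` flow; whether that avoids the same custom-call issue is not confirmed here.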
Answer selected by ASEM000