Exporting the JAX policy as a TF model #15

Open
LipJ01 opened this issue Feb 24, 2023 · 7 comments

LipJ01 commented Feb 24, 2023

Congratulations, Danijar, on this project and your paper!
Again, not really an issue per se. Having read and executed example.py and most of the code, I understand that this project doesn't use TensorFlow in the way I'm familiar with and instead uses JAX. I will endeavour to understand it myself, but I was wondering if it is "simple" to use jax2tf to obtain a SavedModel, ideally after training? Then, if I'm feeling really brave, I intend to use tfjs-converter to run inference in a web demo.

Update:
A) I realise I can go straight from JAX to tfjs.
B) I also realise/think I understand that I'm actually going to have to get three nets converted: the world model, the actor, and the critic. Then implement Dreamer in client-side JavaScript.
I'm becoming ever more doubtful of my ability to pull this off, but the payoff has this occupying my full attention (calendar emptied for the next 3 days).

@EelcoHoogendoorn

jax2tf allegedly has some limitations, but none that I think would be relevant here. The trained models are simply JAX functions, and it should be a few lines of code to save them using jax2tf. Certainly worth a try, I'd say.
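
Something along these lines should work (an untested sketch; `apply_fn` and `params` here are made up, not names from this repo):

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # hypothetical trained model: a pure function plus its parameters
    def apply_fn(params, x):
        return jnp.tanh(x @ params['w'])

    params = {'w': jnp.ones((4, 4))}

    # wrap the converted function and its weights in a tf.Module and save it
    module = tf.Module()
    module.params = tf.nest.map_structure(tf.Variable, params)
    tf_fn = jax2tf.convert(apply_fn)
    module.f = tf.function(
        lambda x: tf_fn(module.params, x),
        autograph=False,
        input_signature=[tf.TensorSpec((1, 4), tf.float32)],
    )
    tf.saved_model.save(module, '/tmp/saved_model')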


LipJ01 commented Feb 24, 2023

I feel close. The conversion/export is as simple as the following:

    tfjs.converters.convert_jax(
        apply_fn=jax.jit(???),
        params=params,
        input_signatures=[tf.TensorSpec(???)],
        model_dir='/path/to/tfjs_models_directory',
    )
So far I've managed to get the params from data['agent'] (probably need different keys; I've not figured that out yet).
I can probably figure out the TensorSpec using 'heuristic exploration'.
I'm really struggling to find where in the code the JAX functions are 😣
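
To illustrate the call shape with a toy function (made-up `apply_fn` and `params`, nothing DreamerV3-specific):

    import jax.numpy as jnp
    import tensorflow as tf
    import tensorflowjs as tfjs

    # made-up stand-ins for the agent's apply function and parameters
    def apply_fn(params, x):
        return jnp.dot(x, params['w']) + params['b']

    params = {'w': jnp.ones((4, 2)), 'b': jnp.zeros((2,))}

    tfjs.converters.convert_jax(
        apply_fn=apply_fn,
        params=params,
        input_signatures=[tf.TensorSpec(shape=(1, 4), dtype=tf.float32)],
        model_dir='/tmp/tfjs_model',
    )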

@EelcoHoogendoorn

I've only used conversion from the jax2tf side, not via tfjs, so I can't comment on that.

jaxagent.policy seems relevant as a function to export, for instance; though I'm still in the process of figuring out the docstrings myself.


danijar commented Mar 1, 2023

@LipJ01 The JAX functions are in jaxagent.py. For inference, you'll only need self._policy and self._init_policy.

danijar changed the title from "Possible to convert to a TF Model?" to "Exporting the JAX policy as a TF model" on Mar 7, 2023
@ChrisAGBlake

@LipJ01 Have you had any success with this? I'm also looking to export models trained on a custom environment to TensorFlow and then ONNX.

@danijar Thank you so much for publishing this. It's game-changing, particularly for hard-exploration environments or working with "slow" simulators. Other methods I've tested just aren't feasible and are far too time-consuming and expensive to train.


LipJ01 commented Jul 30, 2023

@ChrisAGBlake afraid not 🙃


ChrisAGBlake commented Sep 14, 2023

I managed to convert it to TensorFlow but wasn't able to convert to TFLite or ONNX, as it seems to use some operations that those don't support.

For anyone interested, here's a bit of an example. This is hard-coded for my observation space, which is a vector and no image. I'm not sure if this is the best way of doing it, but it seems to be working. Manually specifying and modifying the function signatures was necessary to be able to use the TensorFlow model from the C++ API, but it wasn't required for using it with TensorFlow in Python.

    # assumed from the surrounding training script: `agent` (the DreamerV3
    # JAX agent), `config`, and `env` (the custom environment)
    import embodied
    import tensorflow as tf
    from jax.experimental import jax2tf

    # load the trained weights
    checkpoint = embodied.Checkpoint()
    checkpoint.agent = agent
    logdir = embodied.Path(config.logdir)
    checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])

    # modify the init policy function signature so it's compatible with C++:
    # flatten the nested agent state into a single dict of named tensors
    def mod_init_policy(weights, rng, is_first):
        ((latent, action), task_state, expl_state), _ = agent._init_policy(weights, rng, is_first)
        out = latent
        out['action'] = action
        return out

    # modify the policy function signature so it's compatible with C++: take
    # flat named tensors and rebuild the nested obs/state the agent expects
    def mod_policy(weights, rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action):
        obs = {'vector': vector, 'reward': reward, 'is_first': is_first, 'is_last': is_last, 'is_terminal': is_terminal}
        state = (({'deter': deter, 'logit': logit, 'stoch': stoch}, action), {}, {})
        (outs, state), _ = agent._policy(weights, rng, obs, state)
        (latent, action), _, _ = state
        out = latent
        out['action'] = action
        for k, v in outs.items():
            out[f'outs_{k}'] = v
        return out

    # export the policy functions to TensorFlow: wrap the converted JAX
    # functions and the weights in a tf.Module
    tf_agent = tf.Module()
    weights = tf.nest.map_structure(tf.Variable, agent.varibs)
    init_f = lambda rng, is_first: jax2tf.convert(mod_init_policy)(weights, rng, is_first)
    policy_f = lambda rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action: jax2tf.convert(mod_policy)(weights, rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action)
    tf_agent._variables = tf.nest.flatten(weights)
    tf_agent.init_policy = tf.function(init_f, autograph=False)
    tf_agent.policy = tf.function(policy_f, autograph=False)

    # trace init_policy once with dummy inputs and grab a concrete function
    rng = agent._next_rngs(agent.policy_devices)
    obs = agent._dummy_batch(env.obs_space, (1,))
    out = tf_agent.init_policy(rng, obs['is_first'])
    prev_latent = {k: v for k, v in out.items() if k != 'action'}
    prev_action = out['action']
    init_call = tf_agent.init_policy.get_concrete_function(rng, obs['is_first'])

    # trace the policy once with dummy inputs and grab a concrete function
    rng = agent._next_rngs(agent.policy_devices)
    obs = agent._dummy_batch(env.obs_space, (1,))
    vector = obs['vector']
    reward = obs['reward']
    is_first = obs['is_first']
    is_last = obs['is_last']
    is_terminal = obs['is_terminal']
    deter = prev_latent['deter']
    logit = prev_latent['logit']
    stoch = prev_latent['stoch']
    tf_agent.policy(rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, prev_action)
    call = tf_agent.policy.get_concrete_function(rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, prev_action)

    # save both concrete functions as named signatures in a SavedModel
    tf.saved_model.save(tf_agent, 'models/tf_agent', signatures={'policy_init': init_call, 'policy': call})
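
In case it helps, the saved signatures can then be called back from Python like this (the shapes and the PRNG-key placeholder below are illustrative guesses, not values from my setup):

    import numpy as np
    import tensorflow as tf

    loaded = tf.saved_model.load('models/tf_agent')

    # exported signatures are invoked with keyword arguments named after
    # the wrapped functions' parameters and return a dict of output tensors
    rng = np.zeros((1, 2), np.uint32)   # placeholder for a JAX PRNG key
    is_first = np.ones((1,), bool)
    out = loaded.signatures['policy_init'](rng=rng, is_first=is_first)
    prev_action = out['action']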
