Load trained weights into agent and get predicted actions #2

Closed
jobesu14 opened this issue Feb 18, 2023 · 11 comments
Comments

@jobesu14

Hello, thanks for sharing this amazing piece of work!

Is there an easy way to load the trained weights from the checkpoint.pkl into an agent and get the predicted action from it (agent.policy(obs, state, mode='eval')['action'])? The idea would be to visualize it online in a standard pygame loop, for instance.

Looking at the code, I guess the easiest would be to use the dreamerv3.Agent class, but I don't understand how to load the weights from the pickle file 😅

@danijar
Owner

danijar commented Feb 19, 2023

Hey, I updated the checkpoint code and run scripts to make this easy. You can now train an agent as normal:

python dreamerv3/train.py --run.logdir ~/logdir/train --configs crafter --run.script train

And then load the agent to evaluate it in an environment without further training:

python dreamerv3/train.py --run.logdir ~/logdir/eval --configs crafter \
  --run.script eval_only --run.from_checkpoint ~/logdir/train/checkpoint.pkl

You also asked for a minimal example to load the agent yourself. The relevant code is in dreamerv3/train.py and run/eval_only.py and boils down to:

env = ...     # the embodied environment you trained on
config = ...  # the config that was used for training
step = embodied.Counter()
agent = Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load('path/to/checkpoint.pkl', keys=['agent'])
state = None
act, state = agent.policy(obs, state, mode='eval')

@jobesu14
Author

jobesu14 commented Feb 19, 2023

Great, thank you so much.

@ThomasRochefortB

Hello! Any idea how the initial state should be formatted? I am trying to run the minimal code you provided above with a gym environment. However:

env = ...  # Replace this with your Gym env.
env = from_gym.FromGym(env)
obs = env._env.reset()
obs = env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
act, state = agent.policy(obs, state=[None], mode='eval')

returns an error.

From the policy() function I can see that it is expecting:

def policy(self, obs, state, mode='train'):
  self.config.jax.jit and print('Tracing policy function.')
  obs = self.preprocess(obs)
  (prev_latent, prev_action), task_state, expl_state = state

Is there a way to initialize the state?

@danijar
Owner

danijar commented Feb 20, 2023

You can just pass in None as the first state and from then on pass back the state that it returns.

This is done in jaxagent.py.
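
In code, the pattern looks roughly like this (a sketch only; episode stands in for whatever loop feeds you already-batched observation dicts):

state = None  # nothing to carry over before the first step
for obs in episode:  # hypothetical stream of batched observation dicts
  act, state = agent.policy(obs, state, mode='eval')  # feed the returned state back in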

@danijar danijar pinned this issue Feb 20, 2023
@danijar danijar closed this as completed Feb 20, 2023
@jobesu14
Author

jobesu14 commented Feb 20, 2023

@ThomasRochefortB did you manage to run the minimal snippet successfully?

On my side, I run into an error that seems to come from the observation data not being formatted as expected when passed to the agent policy.

Here is what I did:

Training, where everything works well:

python dreamerv3/train.py --logdir ~/logdir/test_1 --configs crafter

And then, when I try to run the minimal snippet for inference like this:

LOGDIR = Path('~/logdir/test_1')
config = embodied.Config.load(str(LOGDIR / 'config.yaml'))
env = crafter.Env()  # Replace this with your Gym env.
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)
# env = embodied.BatchEnv([env], parallel=False)

step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(str(LOGDIR / 'checkpoint.ckpt'), keys=['agent'])

obs = env._env.reset()
obs = env._obs(obs, 0.0, is_first=True)
obs = {k: embodied.convert(v) for k, v in obs.items()}
state = None

while True:
    act, state = agent.policy(obs, state, mode='eval')  # error comes from that line
    acts = {k: v for k, v in act.items() if not k.startswith('log_')}
    obs, reward, done, _ = env.step(acts)  # act['action'])

I get an error from the line act, state = agent.policy(obs, state, mode='eval') that points to jaxagent.py line 144: IndexError: tuple index out of range.

@danijar danijar changed the title Minimal code to load trained wieghts into agent and get predicted actions Load trained weights into agent and get predicted actions Feb 21, 2023
@danijar
Owner

danijar commented Feb 21, 2023

@jobesu14 The easiest way is to take example.py and replace embodied.run.train(...) at the end with embodied.run.eval_only(agent, env, logger, args). You can also look at embodied/run/eval_only.py for the details and simplify that further as needed.
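
For concreteness, that swap at the end of example.py looks roughly like this (the rest of example.py is assumed unchanged):

# embodied.run.train(...)  # original training entry point at the end of example.py
embodied.run.eval_only(agent, env, logger, args)  # roll out the trained agent without further training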

I think the issue in your snippet is that the policy expects a batch size. I think it should look something like the following, but I don't have the time to test it right now:

logdir = embodied.Path('~/logdir/test_1')
config = embodied.Config.load(logdir / 'config.yaml')

env = crafter.Env()
env = from_gym.FromGym(env)
env = dreamerv3.wrap_env(env, config)

step = embodied.Counter()
agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)
checkpoint = embodied.Checkpoint()
checkpoint.agent = agent
checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])

state = None
act = {'action': env.act_space['action'].sample(), 'reset': np.array(True)}
while True:
    obs = env.step(act)
    obs = {k: v[None] for k, v in obs.items()}  # add a leading batch dimension of 1
    act, state = agent.policy(obs, state, mode='eval')
    act = {'action': act['action'][0], 'reset': obs['is_last'][0]}  # drop the batch dimension again

@jobesu14
Author

I didn't manage to make the above snippets work. I kept having issues with the parameters of the agent initial_policy calls.
Here is what worked for me.

Basically, I added a callback for each step in the embodied/run/eval_only.py script, roughly as sketched below. It is a very minimal change to the original codebase and it also decouples the inner workings of Dreamer quite nicely from the pygame rendering code.
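
The callback looks roughly like this (a sketch only: it assumes eval_only.py steps the env through an embodied.Driver with an on_step hook and that observations contain an 'image' key; make_render_callback and its internals are my own illustrative names, not code from the repo):

import numpy as np
import pygame


def make_render_callback(size=(512, 512)):
  """Build a per-step callback that shows the 'image' observation in a pygame window."""
  pygame.init()
  screen = pygame.display.set_mode(size)

  def render(tran, worker):
    # `tran` is assumed to be the per-step transition dict handed out by embodied.Driver.
    frame = np.asarray(tran['image'], dtype=np.uint8)
    surface = pygame.surfarray.make_surface(frame.swapaxes(0, 1))  # pygame wants (width, height, 3)
    screen.blit(pygame.transform.scale(surface, size), (0, 0))
    pygame.display.flip()

  return render


# Inside a copy of embodied/run/eval_only.py, after the driver is created (assumed):
# driver.on_step(make_render_callback())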

Hope this helps.

@cameronberg

cameronberg commented Mar 8, 2023

@danijar Any chance you've been able to figure out rendering from your snippet above? This still produces the error mentioned above for me. It would be amazing to have code for example.py that can generically render gym environments + roll out trained policies.

@danijar
Owner

danijar commented Mar 8, 2023

If your env returns an image key as part of the observation dictionary, it will already get rendered and can be viewed in TensorBoard. Does that work for your use case?

@vtopiatech

Thanks for such a great research algo @danijar! Wondering if there's any good way now to render the AI playing the game?

@vtopiatech

After 2 long days, found the answer based on this issue! Leaving it here for anyone who wants to render their DRL AIs playing:

In from_gym.py, add the import and the four lines that start with plt:

import matplotlib.pyplot as plt

  def step(self, action):
    if action['reset'] or self._done:
      self._done = False
      obs = self._env.reset()
      return self._obs(obs, 0.0, is_first=True)
    if self._act_dict:
      action = self._unflatten(action)
    else:
      action = action[self._act_key]
    obs, reward, self._done, self._info = self._env.step(action)
    plt.imshow(obs)
    plt.show(block=False)
    plt.pause(0.001)  # Pause to ensure the plot updates
    plt.clf()  # Clear the plot so that the next image replaces this one
    return self._obs(
        obs, reward,
        is_last=bool(self._done),
        is_terminal=bool(self._info.get('is_terminal', self._done)))
