Obtain World Model Predictions during Inference. #110

Open
defrag-bambino opened this issue Feb 27, 2024 · 0 comments

@defrag-bambino

Hi,

How can I obtain the world model's predictions during inference?
I tried the following call in a simple inference loop, but it throws an error:

agent.agent.wm.imagine(agent.policy, obs, 10)

Error & Stacktrace

│                                                                             │
│   57 │   act = {'action': act['action'][0], 'reset': obs['is_last'][0]}     │
│   58 │                                                                      │
│   59 │   if i > 100:                                                        │
│ ❱ 60 │     agent.agent.wm.imagine(agent.policy, obs, 10)                    │
│   61                                                                        │
│   62                                                                        │
│   63                                                                        │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper    │
│                                                                             │
│   377   def wrapper(self, *args, **kwargs):                                 │
│   378 │   with scope(self._path, absolute=True):                            │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                          │
│   381   return wrapper                                                      │
│   382                                                                       │
│   383                                                                       │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/agent.py:183 in imagine     │
│                                                                             │
│   180                                                                       │
│   181   def imagine(self, policy, start, horizon):                          │
│   182 │   first_cont = (1.0 - start['is_terminal']).astype(jnp.float32)     │
│ ❱ 183 │   keys = list(self.rssm.initial(1).keys())                          │
│   184 │   start = {k: v for k, v in start.items() if k in keys}             │
│   185 │   start['action'] = policy(start)                                   │
│   186 │   def step(prev, _):                                                │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/ninjax.py:380 in wrapper    │
│                                                                             │
│   377   def wrapper(self, *args, **kwargs):                                 │
│   378 │   with scope(self._path, absolute=True):                            │
│   379 │     with jax.named_scope(self._path.split('/')[-1]):                │
│ ❱ 380 │   │   return method(self, *args, **kwargs)                          │
│   381   return wrapper                                                      │
│   382                                                                       │
│   383                                                                       │
│                                                                             │
│ /home/fabian/Desktop/fpv/py/dreamerv3/dreamerv3/nets.py:34 in initial       │
│                                                                             │
│    31   def initial(self, bs):                                              │
│    32 │   if self._classes:                                                 │
│    33 │     state = dict(                                                   │
│ ❱  34 │   │     deter=jnp.zeros([bs, self._deter], f32),                    │
│    35 │   │     logit=jnp.zeros([bs, self._stoch, self._classes], f32),     │
│    36 │   │     stoch=jnp.zeros([bs, self._stoch, self._classes], f32))     │
│    37 │   else:                                                             │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/num │
│ py/lax_numpy.py:2317 in zeros                                               │
│                                                                             │
│   2314   if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise  │
│   2315   dtypes.check_user_dtype_supported(dtype, "zeros")                  │
│   2316   shape = canonicalize_shape(shape)                                  │
│ ❱ 2317   return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_t │
│   2318                                                                      │
│   2319 @util.implements(np.ones)                                            │
│   2320 def ones(shape: Any, dtype: DTypeLike | None = None, *,              │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:1226 in full                                                        │
│                                                                             │
│   1223 │   return dtype._rules.full(shape, fill_value, dtype)  # type: igno │
│   1224   weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)   │
│   1225   dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))     │
│ ❱ 1226   fill_value = _convert_element_type(fill_value, dtype, weak_type)   │
│   1227   out = broadcast(fill_value, shape)                                 │
│   1228   if sharding is not None:                                           │
│   1229 │   return array.make_array_from_callback(shape, sharding, lambda id │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/lax │
│ /lax.py:560 in _convert_element_type                                        │
│                                                                             │
│    557 │   │      isinstance(core.get_aval(operand), core.ConcreteArray))): │
│    558 │   return type_cast(Array, operand)                                 │
│    559   else:                                                              │
│ ❱  560 │   return convert_element_type_p.bind(operand, new_dtype=new_dtype, │
│    561 │   │   │   │   │   │   │   │   │      weak_type=bool(weak_type))    │
│    562                                                                      │
│    563 def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) - │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:444 in bind                                                            │
│                                                                             │
│    441   def bind(self, *args, **params):                                   │
│    442 │   assert (not config.enable_checks.value or                        │
│    443 │   │   │   all(isinstance(arg, Tracer) or valid_jaxtype(arg) for ar │
│ ❱  444 │   return self.bind_with_trace(find_top_trace(args), args, params)  │
│    445                                                                      │
│    446   def bind_with_trace(self, trace, args, params):                    │
│    447 │   out = trace.process_primitive(self, map(trace.full_raise, args), │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:447 in bind_with_trace                                                 │
│                                                                             │
│    444 │   return self.bind_with_trace(find_top_trace(args), args, params)  │
│    445                                                                      │
│    446   def bind_with_trace(self, trace, args, params):                    │
│ ❱  447 │   out = trace.process_primitive(self, map(trace.full_raise, args), │
│    448 │   return map(full_lower, out) if self.multiple_results else full_l │
│    449                                                                      │
│    450   def def_impl(self, impl):                                          │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/cor │
│ e.py:935 in process_primitive                                               │
│                                                                             │
│    932   lift = sublift = pure                                              │
│    933                                                                      │
│    934   def process_primitive(self, primitive, tracers, params):           │
│ ❱  935 │   return primitive.impl(*tracers, **params)                        │
│    936                                                                      │
│    937   def process_call(self, primitive, f, tracers, params):             │
│    938 │   return primitive.impl(f, *tracers, **params)                     │
│                                                                             │
│ /home/fabian/miniconda3/envs/drmV3/lib/python3.9/site-packages/jax/_src/dis │
│ patch.py:87 in apply_primitive                                              │
│                                                                             │
│    84   if xla_extension_version >= 218:                                    │
│    85 │   prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)     │
│    86 │   try:                                                              │
│ ❱  87 │     outs = fun(*args)                                               │
│    88 │   finally:                                                          │
│    89 │     lib.jax_jit.swap_thread_local_state_disable_jit(prev)           │
│    90   else:                                                               │
╰─────────────────────────────────────────────────────────────────────────────╯
XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer: 
aval=ShapedArray(float32[]), dst_sharding=GSPMDSharding({replicated})
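
The last line of the trace is raised by JAX's transfer guard. At the "disallow" level, JAX rejects implicit host-to-device transfers but still permits explicit ones such as jax.device_put. Where this project sets that level is an assumption here, since it is not visible in the trace. The minimal sketch below is independent of DreamerV3 and uses a hypothetical jitted function `double` to reproduce the behavior with plain JAX:

```python
import jax
import numpy as np


@jax.jit
def double(x):
    # Hypothetical example function; the constant 2 is embedded in the compiled program.
    return x * 2


host = np.arange(3, dtype=np.float32)  # data in host (NumPy) memory

with jax.transfer_guard("disallow"):
    try:
        double(host)  # passing a NumPy array requires an implicit host-to-device transfer
        blocked = False
    except Exception:  # JAX raises a "Disallowed host-to-device transfer" error here
        blocked = True

    # An explicit transfer is still permitted under "disallow".
    dev = jax.device_put(host)
    out = double(dev)

# Outside the context manager the global transfer-guard setting applies again.
print(blocked, np.asarray(out))
```

In the traceback, the rejected transfer is the Python scalar 0 that jnp.zeros sends to the device inside rssm.initial. The eager dispatch frames (process_primitive, apply_primitive) suggest that imagine is running outside any jax.jit, where such constants would be traced into the program instead of being transferred. Wrapping the call in `with jax.transfer_guard("allow"):` should lift this specific restriction. Whether imagine then works when called directly on the agent is untested here.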
