pass inputs into the LDS model #329

Open
weigcdsb opened this issue Jul 19, 2023 · 3 comments

Comments

@weigcdsb

Hello,

I have a very basic question: how do I pass N × T × D inputs ("X") into the LDS model (N trials, T time steps, and D-dimensional inputs)?

In the linear_gaussian_ssm models.py file, inputs is typed as Optional[Float[Array, "ntime input_dim"]], so there's no dimension for trials (N)?

I tried to follow the Kalman filter/smoother example, but the problem is that I also need to include d latent trajectories in the model (i.e. the state dimension should be D + d if I encode the covariates into the emission matrix).

Not sure how to do it correctly...

@gileshd
Collaborator

gileshd commented Jul 25, 2023

Hi @weigcdsb, I'm not sure I totally understand your use case, would you be able to explain it in some more detail and we'll see if I can help 😄.

In general, it should be possible to use jax.vmap to map filtering/smoothing over additional dimensions (as described here); however, this might not be suitable for all scenarios.
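A minimal sketch of the jax.vmap idea, assuming the inputs are stored as (N, T, D) with trials on the leading axis. `rollout_one_trial` is a hypothetical stand-in written for this example (a plain linear recursion driven by the inputs), not a dynamax function; in practice its body would be whatever per-trial filtering or smoothing call you use, which only ever sees a (T, D) array.

```python
import jax
import jax.numpy as jnp

N, T, D, d = 4, 10, 3, 2  # trials, time steps, input dim, latent dim

F = 0.9 * jnp.eye(d)        # illustrative dynamics matrix, shape (d, d)
B = 0.1 * jnp.ones((d, D))  # illustrative input weights, shape (d, D)

def rollout_one_trial(inputs):
    """Hypothetical per-trial function: takes (T, D) inputs, returns (T, d)."""
    def step(z, u_t):
        z_next = F @ z + B @ u_t  # z_t = F z_{t-1} + B u_t (noise omitted)
        return z_next, z_next
    _, zs = jax.lax.scan(step, jnp.zeros(d), inputs)
    return zs

# Inputs for all trials, shape (N, T, D).
inputs = jax.random.normal(jax.random.PRNGKey(0), (N, T, D))

# vmap maps the per-trial function over the leading (trial) axis,
# so each call still receives a 2D (T, D) array.
latent_means = jax.vmap(rollout_one_trial)(inputs)
print(latent_means.shape)
```

The same pattern applies when emissions also carry a trial axis: a two-argument per-trial function can be mapped over both arrays at once with jax.vmap.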

@weigcdsb
Author

@gileshd, thanks for replying, and sorry for the confusion.

Using the notation from the comments in your models.py file:
$$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t),$$
where $p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$ and $p(z_1) = \mathcal{N}(z_1 \mid m, S)$, for $t=1,\ldots,T$. Here, $u_t$ is an input of size input_dim (assume input_dim = D; it defaults to 0). If there are $N$ observations, then emission_dim = N. So the total inputs (stacking all $u_t$ together) should have dimension $N\times D\times T$.

My question is: how can I pass the input $u_t$ into the LDS model? In the linear_gaussian_ssm models.py file, the comment says inputs: Optional[Float[Array, "ntime input_dim"]]=None, which means the dimension should be $T \times D$. So there's no option for multiple emissions, say $N>1$ (as we cannot pass a 3D array to the model)?

Hope this clarifies my question.

@murphyk
Member

murphyk commented Jul 26, 2023

Correct. The input vector u_t at each time step must be a D-dimensional vector. So inputs has shape (T,D) (or None). You can always flatten your 3d inputs outside of dynamax.
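One possible reading of "flatten", sketched under assumptions: the inputs are stored as (N, D, T) as in the question, each of the N emissions should only see its own D covariates, and the model is then told that input_dim = N * D. The per-emission weights `W` and the block-structured matrix `D_mat` are hypothetical names introduced for this example; nothing here calls dynamax itself.

```python
import jax
import jax.numpy as jnp

N, D, T = 3, 2, 5  # emissions, covariates per emission, time steps

key_u, key_w = jax.random.split(jax.random.PRNGKey(0))
u = jax.random.normal(key_u, (N, D, T))  # 3D inputs, layout from the question
W = jax.random.normal(key_w, (N, D))     # hypothetical per-emission weights

# Move time to the front, then merge the last two axes:
# (N, D, T) -> (T, N, D) -> (T, N * D). This is the 2D array the model expects.
u_tnd = jnp.transpose(u, (2, 0, 1))
inputs = u_tnd.reshape(T, N * D)

# Block-structured emission input weights, shape (N, N * D):
# row n is zero except for the D columns that belong to emission n.
D_mat = (jnp.eye(N)[:, :, None] * W[None, :, :]).reshape(N, N * D)

# The flattened product matches the per-emission computation.
flat_effect = inputs @ D_mat.T                      # (T, N)
direct_effect = jnp.einsum("nd,tnd->tn", W, u_tnd)  # (T, N)
print(inputs.shape, bool(jnp.allclose(flat_effect, direct_effect, atol=1e-5)))
```

If the input weights are learned rather than fixed, the zero blocks are not enforced automatically, so this layout only fixes the shapes, not the sparsity pattern.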
