State Handling and Integration in the forward Function #121

Open
MichaelFYang opened this issue Oct 5, 2023 · 1 comment

Comments

@MichaelFYang

Hi Albert,

I came across something in the S4Block class that I'd like to understand better. You mentioned in the forward function that state: (H N) is never needed unless you know what you're doing. However, in the kernel class, specifically in the SSMKernelDiag (s4d) class, the forward function seems to accept a state value which is later augmented with B. I have a few questions regarding this:

  1. How and why is the state value augmented with B in the SSMKernelDiag kernel class?
  2. Can I initialize a state using the default_state function and then pass it directly to the forward function, without invoking the _setup_step() function for discretizing the parameters?
  3. I aim to retain a state matrix after processing an entire sequence. Since the original forward function doesn't produce a state output when a state isn't provided, do I need to run the forward_state() function separately to obtain the state?
  4. Is there a specific reason why the forward function of the S4Block class doesn't require a state, while the kernel class does?

I'd really appreciate any insights or explanations you can offer. Thank you!

@albertfgu
Contributor

> Is there a specific reason why the forward function of the S4Block class doesn't require a state, while the kernel class does?

The kernel class doesn't require it; it's optional.

> How and why is the state value augmented with B in the SSMKernelDiag kernel class?

This supports the state forwarding that lets you compute things "chunkwise", accepting an initial state and returning the final state. See the README in models/s4/.
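To make the chunkwise idea concrete, here is a minimal pure-Python sketch of a toy diagonal SSM. It is not the repository's code: the parameters `A`, `B`, `C` and the helpers `run_recurrent` and `run_convolutional` are hypothetical names chosen for illustration, and the recurrence is assumed to be already discretized. It shows that the response to an initial state has the same form as the kernel (`C A^k v`), which is one way to see why a state-derived vector can be stacked alongside `B`, and that processing a sequence in two chunks with a carried state matches processing it in one pass.

```python
# Toy diagonal SSM: one input channel, N = 2 complex state dimensions.
# Assumed discrete recurrence: x_k = A * x_{k-1} + B * u_k,  y_k = Re(sum_n C_n * x_k[n])
A = [0.9 + 0.1j, 0.5 - 0.3j]
B = [1.0 + 0.0j, 0.5 + 0.5j]
C = [0.3 - 0.2j, 1.0 + 0.4j]


def run_recurrent(u, state):
    """Step through u one input at a time; return (outputs, final_state)."""
    x = list(state)
    ys = []
    for u_k in u:
        x = [a * x_n + b * u_k for a, b, x_n in zip(A, B, x)]
        ys.append(sum(c * x_n for c, x_n in zip(C, x)).real)
    return ys, x


def run_convolutional(u, state):
    """Same outputs as run_recurrent, computed as kernel convolution plus a state term."""
    L = len(u)
    # Kernel:         K[k] = sum_n C_n * A_n^k * B_n
    # State response: S[k] = sum_n C_n * A_n^k * (A_n * x0_n)
    # Both have the form C A^k v, with v = B for the kernel and v = A * x0 for the state.
    K = [sum(c * a**k * b for a, b, c in zip(A, B, C)) for k in range(L)]
    S = [sum(c * a**k * (a * s) for a, c, s in zip(A, C, state)) for k in range(L)]
    # Causal convolution of the kernel with the input, plus the state response.
    return [(sum(K[j] * u[k - j] for j in range(k + 1)) + S[k]).real for k in range(L)]


u = [1.0, -0.5, 2.0, 0.0, 0.25, -1.0]
zero = [0j, 0j]

# One pass over the whole sequence.
y_full, x_full = run_recurrent(u, zero)

# Two chunks: the final state of chunk 1 is the initial state of chunk 2.
y1, x_mid = run_recurrent(u[:3], zero)
y2 = run_convolutional(u[3:], x_mid)
_, x_end = run_recurrent(u[3:], x_mid)
y_chunked = y1 + y2
```

In this sketch `y_chunked` agrees with `y_full` up to floating-point error, and `x_end` agrees with `x_full`, so the carried state is all that needs to be passed between chunks.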

> I aim to retain a state matrix after processing an entire sequence. Since the original forward function doesn't produce a state output when a state isn't provided, do I need to run the forward_state() function separately to obtain the state?

No, just pass the initial state into the S4Block.

> Can I initialize a state using the default_state function and then pass it directly to the forward function, without invoking the _setup_step() function for discretizing the parameters?

I don't actually remember, but I think you still need to call .setup_step() somewhere. I think you can do this in the S4Block or FFTConv module instead of directly on the S4 kernel level.
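As a rough illustration of why a setup call might be needed before a state can be advanced, the sketch below discretizes a diagonal continuous-time SSM and then takes one recurrent step. It assumes zero-order hold discretization, and the helpers `discretize_zoh` and `step` are hypothetical names, not the repository's API. The point is only that stepping uses the discrete parameters, so they have to be computed first.

```python
import cmath


def discretize_zoh(A, B, dt):
    """Zero-order hold for diagonal A: dA = exp(dt * A), dB = (dA - 1) / A * B."""
    dA = [cmath.exp(dt * a) for a in A]
    dB = [(da - 1) / a * b for da, a, b in zip(dA, A, B)]
    return dA, dB


def step(dA, dB, C, state, u_k):
    """Advance the state by one input and return (output, new_state)."""
    new_state = [da * x + db * u_k for da, db, x in zip(dA, dB, state)]
    y_k = sum(c * x for c, x in zip(C, new_state)).real
    return y_k, new_state


# Continuous-time parameters (illustrative values only).
A = [-1.0 + 0.0j]
B = [1.0 + 0.0j]
C = [1.0 + 0.0j]

# The discrete parameters must exist before any step can be taken.
dA, dB = discretize_zoh(A, B, dt=0.1)

# Start from a zero state and feed a single input.
y0, state = step(dA, dB, C, [0j], 1.0)
```

With a zero initial state and a unit input, the new state equals `dB`, which makes the dependence on the discretization explicit.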
