Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RNN support for Pytorch #850

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

JanFSchulte
Copy link
Contributor

Adds support for RNN layers (GRU, LSTM, RNN) to the pytorch parser.

Caveat: We currently lack implementation for getitem operations, so we can currently not return the hidden state after the calculations

Caveat 2: We currently only support a single recurrent layers, whereas multiple within the same RNN instance are supported by pytorch

Caveat 3: We currently don't support the passing of non-zero initial values for the hidden states to the RNN

So this implementation is slightly hacky at the moment, but might serve as a starting point for discussion, and can be used by interested parties if they can life with the current limitations.

Also, this contains parts of #848 because I was inattentive.

Type of change

For a new feature or function, please create an issue first to discuss it
with us before submitting a pull request.

Note: Please delete options that are not relevant.

  • New feature (non-breaking change which adds functionality)

Tests

Added pytests to confirm that the layers work.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@JanFSchulte JanFSchulte marked this pull request as ready for review August 17, 2023 14:15
@vloncar vloncar added the please test Trigger testing by creating local PR branch label Aug 17, 2023
@vloncar
Copy link
Contributor

vloncar commented Aug 17, 2023

pre-commit.ci autofix

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels Oct 20, 2023
@jmitrevs jmitrevs added this to the v1.0.0 milestone Oct 20, 2023
@jmitrevs
Copy link
Contributor

The tests fail with:

FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_parallel-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Vivado] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'
FAILED test_pytorch_api.py::test_skipped_layers[io_stream-Quartus] - TypeError: config_from_pytorch_model() got an unexpected keyword argument 'inputs_channel_last'

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 3, 2024
@JanFSchulte
Copy link
Contributor Author

All test failures the last time around seemed to be related to issues with the tests themselves, which I have mostly fixed. The only change I made was to add missing includes to some Quartus templates to fix compiliation errors when uint_8 was used.

There are currently still some remaining test failures with the case when activations are used in their nn.functionals implementation instead of as classes. Here I can't reproduce the failures in a standalone file, the exact same code that fails in the pytest works fine running in standalone python. Have not figured out how to debug it in those circumstances.

@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 9, 2024
@jmitrevs jmitrevs added please test Trigger testing by creating local PR branch and removed please test Trigger testing by creating local PR branch labels May 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants