New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix random_mps_impl to accept non-contiguous tensors #125231
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125231
Note: Links to docs will display an error until the docs builds have been completed. ❌ 5 New Failures, 4 Unrelated FailuresAs of commit 3b7efae with merge base b03fb49 (): NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
random_mps_impl is used in other operations in Distribution. Should I add tests for those too? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
@pytorchbot merge |
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra. |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Successfully rebased |
357dda6
to
a569b52
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64-mps / test (mps, 1, 1, macos-m1-stable) Details for Dev Infra teamRaised by workflow job |
Hey @kulinseth , wanted to double check removing this test is okay since we can accept multiple elements referencing the same location now |
@@ -2003,13 +2003,6 @@ def sample_inputs_bernoulli(self, device, dtype, requires_grad, **kwargs): | |||
requires_grad=requires_grad) | |||
yield SampleInput(t) | |||
|
|||
def error_inputs_bernoulli(op_info, device, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this test removed ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bernuoulli uses random_mps_impl under the hood so the error wasn't valid anymore since it support non-contiguous tenors now too. So I removed it but on second thought, it might be useful to keep the test as a passing test for documentation.
I added it back to the sample_inputs_bernoulli method. Open to other ideas too!
Fixes #124029
Follows the pattern of allocating contiguous memory and copying the results in https://github.com/pytorch/pytorch/pull/123049/files