
[MLIR] Specialize active callbacks to their own function #735

Closed

Conversation

@erick-xanadu (Contributor) commented on May 13, 2024:

Context: Enzyme allows one to specify custom gradients for specific functions. In order to specify custom gradients for callbacks, each callback needs to be specialized into its own function. E.g., instead of having the following code:

func.call @callback(%identifier, %argc, %retc, %0, %1, ..., %m, %n)

and being unable to register a custom gradient for an individual callback (every callback shares the single @callback symbol), specialize callbacks according to their identifiers like so:

func.func @active_callback_123(%arg0, %arg1, ..., %argm, %argn) {
  %identifier = llvm.mlir.constant ...
  %argc = llvm.mlir.constant ...
  %retc = llvm.mlir.constant ...
  llvm.call @callback(%identifier, %argc, %retc, %arg0, %arg1, ..., %argm, %argn)
  return
}

  // ...
  func.call @active_callback_123(%0, %1, ..., %m, %n)
  // ..

Now we can register a custom gradient for @active_callback_123 independently of any other specialized callback (e.g. @active_callback_456).
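For context, Enzyme's documented custom-gradient mechanism is keyed on a specific function symbol, typically through a global named __enzyme_register_gradient_<name> that points at the original function, an augmented forward pass, and a reverse pass. The C++ sketch below only illustrates that convention; the callback signature, the _fwd/_rev helper names, and the exact array layout are assumptions for illustration, not Catalyst's actual registration code (consult the Enzyme documentation for the authoritative details).

// Hypothetical sketch of Enzyme's custom-gradient registration convention.
// All signatures below are assumed for illustration only.
extern "C" {
// Specialized callback outlined by the pass (assumed signature).
void active_callback_123(double *in, double *out);
// Augmented forward pass and reverse (gradient) pass supplied separately.
void active_callback_123_fwd(double *in, double *out);
void active_callback_123_rev(double *in, double *d_in, double *out, double *d_out);

// Because @active_callback_123 is its own symbol, a custom gradient can be
// associated with it without touching @active_callback_456 or the generic @callback.
void *__enzyme_register_gradient_active_callback_123[3] = {
    (void *)active_callback_123,
    (void *)active_callback_123_fwd,
    (void *)active_callback_123_rev,
};
} // extern "C"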

Description of the Change:

  • Iterate over all ActiveCallbackOps and create the specialized function; each ActiveCallbackOp is annotated with its specialized function (a rough sketch of this outlining step is shown after this list).
  • During the lowering of ActiveCallbackOps to LLVM IR, replace each one with a call to its specialized function.
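
Below is a minimal C++ sketch of the outlining step using upstream MLIR builder APIs. It assumes the identifier and the argument/result counts are already known as integers, that results are passed through memref operands (so the specialized function returns nothing), and that the generic entry point is named @callback as in the example above; the helper specializeActiveCallback is illustrative and not the actual Catalyst pass.

// Minimal sketch of outlining one active callback into its own function
// (hypothetical helper, not the actual Catalyst implementation).
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

static func::FuncOp specializeActiveCallback(ModuleOp module, Operation *activeCallback,
                                             int64_t identifier, int64_t argc, int64_t retc)
{
    Location loc = activeCallback->getLoc();
    OpBuilder builder(module.getBodyRegion());

    // One function per callback identifier, e.g. @active_callback_123, so that a
    // custom gradient can later be registered with Enzyme for this symbol alone.
    std::string name = "active_callback_" + std::to_string(identifier);
    auto fnType = builder.getFunctionType(activeCallback->getOperandTypes(),
                                          /*results=*/TypeRange());
    auto fn = builder.create<func::FuncOp>(loc, name, fnType);
    fn.setPrivate();

    // Body: materialize the identifier and argument/result counts as constants
    // and forward the data operands to the generic @callback entry point.
    Block *entry = fn.addEntryBlock();
    builder.setInsertionPointToStart(entry);
    Value id = builder.create<arith::ConstantIntOp>(loc, identifier, /*width=*/64);
    Value nArgs = builder.create<arith::ConstantIntOp>(loc, argc, /*width=*/64);
    Value nRets = builder.create<arith::ConstantIntOp>(loc, retc, /*width=*/64);

    SmallVector<Value> callArgs{id, nArgs, nRets};
    callArgs.append(entry->args_begin(), entry->args_end());
    builder.create<func::CallOp>(loc, "callback", TypeRange(), callArgs);
    builder.create<func::ReturnOp>(loc);

    // The original ActiveCallbackOp is later replaced by a call to this
    // specialized function during lowering (not shown here).
    return fn;
}

A lowering pattern for ActiveCallbackOp would then rewrite each op into a call to the function returned here, mirroring the second bullet above.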

[sc-60494]

@erick-xanadu force-pushed the eochoa/2024-05-09/specialize-callbacks-part-two branch from de89893 to 09d4f02 on May 13, 2024 at 20:56
codecov bot commented on May 13, 2024:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.04%. Comparing base (7248c12) to head (00cf68c).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #735   +/-   ##
=======================================
  Coverage   98.04%   98.04%           
=======================================
  Files          69       69           
  Lines        9536     9538    +2     
  Branches      762      763    +1     
=======================================
+ Hits         9350     9352    +2     
  Misses        151      151           
  Partials       35       35           


@erick-xanadu force-pushed the eochoa/2024-05-09/specialize-callbacks-part-two branch from 09d4f02 to e1a1cc4 on May 24, 2024 at 19:16
@dime10 (Collaborator) left a comment:


Thanks Erick, looks good! Regarding reusing the transformation from results into arguments and memrefs into struct pointers, is that something that would be applicable to this PR?

Comment on lines +145 to 148
"specialize-active-callback-pass",
"annotate-function",
"lower-mitigation",
"lower-gradients",
Collaborator:

Should this go directly before gradient lowering?

Contributor Author:

Why?

Collaborator:

I think I was tying the callbacks to gradients because of the registration with Enzyme. I was also thinking about what the lowest point in the pipeline is at which it is first needed.

Is there a reason it should be the first thing in the pipeline?

@@ -142,6 +142,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
QUANTUM_COMPILATION_PASS = (
"QuantumCompilationPass",
[
"specialize-active-callback-pass",
"annotate-function",
Collaborator:

Unrelated to this PR, but I cannot tell at all what this pass is for or relates to (from its name) 😅

Contributor Author:

Do you have a different name that you prefer? Would outline-active-callback-pass make more sense?

Collaborator:

Oh, I was referring to annotate-function.

Resolved review threads: mlir/include/Catalyst/IR/CatalystOps.td, mlir/include/Catalyst/Transforms/Passes.td
@@ -110,8 +110,47 @@ struct BufferizePythonCallOp : public OpConversionPattern<PythonCallOp> {
bufferArgs.push_back(newBuffer);
}
Collaborator:

An inactive callback has no results, right?

Collaborator:

Ah, maybe I got this wrong because I thought debug.callback -> inactive_callback, but I guess the active_callback also calls the inactive callback.

Contributor Author:

I am not sure whether you were saying that I should get rid of the results in the InactiveCallbackOp; I just did. The reason I had not removed them earlier is that I was still exploring different ways to implement this. To be honest, I have a better way in mind now 😅, but I will prioritize getting this PR in, and subsequent PRs will focus on cleaning up.

Collaborator:

> but I will prioritize getting this PR in and subsequent PRs will focus on cleaning up

Generally I will say that I don't much like the practice of splitting up PRs along a "rough first implementation" / "cleanup" line, because it makes it almost impossible to track the deficiencies of the first PR and make sure they all got addressed in the second PR.
A split along functionality is much more sensible imo.

This PR is fine but something to keep in mind for the future.

Contributor Author:

I agree with you, but it is difficult to reconcile "please close tickets before the end of the iteration" with "no rough-draft implementation PRs" when:

  1. we did not have time to design the whole epic in advance to foresee everything in each story
  2. stories build upon previous stories that may not be used until later.

I do not know if it would be possible to have tickets that do not correspond to merging PRs but instead just to getting the implementation done, plus a final ticket for merging the implementation.

That, to me, seems appropriate for large features like this one. It would satisfy the opinions that:

  1. Small PRs are better (as they help with reviews)
  2. No draft PRs (as everything is already finished except for merging)

But it would put a lower priority on "close tickets before the end of the story" and "each ticket must be a PR".

@erick-xanadu (Contributor Author) commented:

@dime10

> Thanks Erick, looks good! Regarding reusing the transformation from results into arguments and memrefs into struct pointers, is that something that would be applicable to this PR?

No, but a future PR might change this. I'm considering whether using memrefs directly in the specialization and only later undergoing the pointer-to-struct ABI transform would be better in terms of readability / invariants. I think so, but I will implement it in a different PR.

@erick-xanadu (Contributor Author) commented:

Closing in favour of #782
