
[MLIR] Enable multi threaded compilation #655

Open

wants to merge 35 commits into base: main
Conversation

grwlf
Contributor

@grwlf grwlf commented Apr 10, 2024

Context: MLIR supports multi-threaded compilation, which was previously disabled with a TODO notice.

Description of the Change: This PR adds a qjit option enabling MLIR multi-threaded compilation. We also take some precautions and make infrastructure updates:

  • Verbose dump handlers are protected from races on the filename counter.
  • A multi-threading flag has been added to the top-level qjit API.
  • Timing stats are now printed to dedicated diagnostic streams.
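The race protection for the dump filename counter can be sketched roughly as follows. This is a minimal illustrative sketch, assuming a mutex-guarded counter; the class and function names are hypothetical, not the PR's actual API:

```python
import threading

class DumpCounter:
    """Hypothetical sketch of a race-free filename counter for verbose
    dump handlers. Names are illustrative, not the PR's actual API."""

    def __init__(self, start=1):
        self._value = start
        self._lock = threading.Lock()

    def next(self):
        # Atomically fetch-and-increment so concurrent dump handlers
        # never receive the same filename index.
        with self._lock:
            value = self._value
            self._value += 1
            return value

counter = DumpCounter()
names = []

def dump(stage):
    names.append(f"{counter.next()}_{stage}.mlir")

threads = [threading.Thread(target=dump, args=("PipelineA",)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With the lock in place, all four concurrent dumps receive distinct indices, so no two handlers write to the same file.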

Note: measurement results (internal document)

Benefits:

  • Faster compilation of programs containing multiple functions
  • Capturing timings works from Python

Possible Drawbacks: Some multi-threading MLIR issues might be uncovered.

Related GitHub Issues:

[sc-59511]

@grwlf grwlf force-pushed the enable-multi-threaded-compilation branch from 7493f15 to 5143cad Compare April 10, 2024 10:30

codecov bot commented Apr 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.10%. Comparing base (a723cd1) to head (1bbf6e0).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #655      +/-   ##
==========================================
- Coverage   99.95%   98.10%   -1.85%     
==========================================
  Files          20       69      +49     
  Lines        4444     9556    +5112     
  Branches        0      747     +747     
==========================================
+ Hits         4442     9375    +4933     
- Misses          2      147     +145     
- Partials        0       34      +34     


@grwlf grwlf marked this pull request as ready for review April 10, 2024 13:29
@grwlf grwlf requested a review from dime10 April 10, 2024 13:29

Hello. You may have forgotten to update the changelog!
Please edit doc/changelog.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@dime10
Collaborator

dime10 commented Apr 10, 2024

Possible Drawbacks: Some multi-threading MLIR issues might be uncovered?

Did you find any?

@grwlf
Contributor Author

grwlf commented Apr 16, 2024

Possible Drawbacks: Some multi-threading MLIR issues might be uncovered?

Did you find any?

Nope, everything seems to be fine so far.

@grwlf
Contributor Author

grwlf commented Apr 17, 2024

Possible Drawbacks: Some multi-threading MLIR issues might be uncovered?

Did you find any?

Nope, everything seems to be fine so far.

One thing we should test this with is the instrumentation. I don't think we have frontend tests for that feature yet though, until #597 is merged.

Do you want me to do this or are you going to check?

@erick-xanadu
Contributor

@dime10 @grwlf what do you guys think about turning it off by default and having it in the release and then turning it on after the release as default to test more thoroughly?

@dime10
Collaborator

dime10 commented Apr 24, 2024

@dime10 @grwlf what do you guys think about turning it off by default and having it in the release and then turning it on after the release as default to test more thoroughly?

Great idea!

@erick-xanadu erick-xanadu added this to the v0.6.0 milestone Apr 24, 2024
@dime10 dime10 removed this from the v0.6.0 milestone Apr 25, 2024
@grwlf grwlf requested a review from erick-xanadu April 29, 2024 13:07
@grwlf
Contributor Author

grwlf commented Apr 29, 2024

what do you guys think about turning it off

@erick-xanadu I believe it is off by default currently

@@ -303,7 +302,7 @@ def circuit():
assert "While processing 'TestPass' pass of the 'PipelineB' pipeline" in e.value.args[0]
assert "PipelineA" not in e.value.args[0]
assert "Trace" not in e.value.args[0]
assert isfile(os.path.join(str(compiled.workspace), "2_PipelineB_FAILED.mlir"))
assert isfile(os.path.join(str(compiled.workspace), "3_1_PipelineB_FAILED.mlir"))
Contributor

Why are there two numbers now? E.g., 3_1 vs 2?

Contributor Author

  • The first number iterates over the hand-coded stages that we have. We increase it manually, e.g. here, relying on the fact that the compiler driver itself is not multi-threaded.
  • The second number is the index of the pipeline in the list of pipelines provided by the user. These dumps might happen from different threads, but the file names will be different.
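The two-part numbering described above can be sketched as a pure naming function. This is a hypothetical sketch (the function name and signature are illustrative), grounded in the filenames visible in the test diff:

```python
def dump_name(stage: int, pipeline_idx: int, pipeline: str, failed: bool = False) -> str:
    # First number: hand-coded compilation stage, bumped by the
    # single-threaded compiler driver. Second number: index of the
    # pipeline in the user-provided list; dumps for different
    # pipelines may come from different threads, but their indices
    # (and thus filenames) never collide.
    suffix = "_FAILED" if failed else ""
    return f"{stage}_{pipeline_idx}_{pipeline}{suffix}.mlir"

# Two pipelines dumped concurrently in the same stage get unique names:
a = dump_name(3, 0, "PipelineA")
b = dump_name(3, 1, "PipelineB", failed=True)
```

Because the stage number comes first, lexicographic sorting of the filenames also reflects the order of compilation stages, which is the property the author says actually matters.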

Contributor

I don't fully understand it, but is there a way to just enumerate them with a single number instead of this pair? I don't think there is a difference between the pipelines provided by us and those provided by the user.

Contributor Author

@grwlf grwlf May 14, 2024

It is possible, of course, but it would complicate the design with arithmetic that I think is unnecessary: we would need to add the number of pipelines to the counter at some point. In reality we only want the filenames to be sorted.

// Install signal handler to catch user interrupts (e.g. CTRL-C).
signal(SIGINT,
[](int code) { throw std::runtime_error("KeyboardInterrupt (SIGINT)"); });

std::unique_ptr<CompilerOutput> output(new CompilerOutput());
std::unique_ptr<CompilerOutput> output(new CompilerOutput(initDumpCounter(workspace)));
Contributor

I see, here is where you create it. How about instead of having the initDumpCounter function, just have run_compiler_driver take an int parameter which defaults to 1, and have the canonicalizer set it to 0.

I would like to avoid having the initDumpCounter function.

Contributor Author

@grwlf grwlf May 21, 2024

@erick-xanadu in this case we would need to write a Python analog of initDumpCounter that counts the canonicalizer's artifacts and sets the integer to the right value. I'm OK with this solution, but do you think it is worth the time to implement? Or did you mean something else?

Contributor

The canonicalizer will always have 1 single output. I don't think we need to count anything.

Contributor Author

Why not? What if someone changes it?

@erick-xanadu
Contributor

Looking into MLIR's documentation:

The structure of the pass manager, and the concept of nesting, is detailed further below. All passes in MLIR derive from OperationPass and adhere to the following restrictions; any noncompliance will lead to problematic behavior in multithreaded and other advanced scenarios

and later below

Must not modify the state of operations other than the operations that are nested under the current operation. This includes adding, modifying or removing other operations from an ancestor/parent block.

Other threads may be operating on these operations simultaneously.
As an exception, the attributes of the current operation may be modified freely. This is the only way that the current operation may be modified. (I.e., modifying operands, etc. is not allowed.)
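The restriction quoted above can be illustrated with a generic analogy (this is plain Python, not MLIR code; the names are purely illustrative): each worker thread behaves like an OperationPass run on one function and mutates only state nested under its own operation. Reaching into a sibling entry, as the gradient passes do by adding functions to the module, would be exactly the unsafe case the MLIR docs forbid.

```python
import threading

# Toy stand-in for a module containing three functions. Each pass
# thread is allowed to touch only the IR nested under the function it
# was scheduled on, mirroring MLIR's multi-threading contract.
module = {"f1": {"ops": []}, "f2": {"ops": []}, "f3": {"ops": []}}

def run_pass(func_name):
    # Safe: modify only state nested under the current operation.
    # Unsafe (forbidden) would be e.g. module["f_new"] = {...}, i.e.
    # adding a sibling while other threads iterate over the module.
    module[func_name]["ops"].append("canonicalized")

threads = [threading.Thread(target=run_pass, args=(name,)) for name in module]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Since every thread writes to a disjoint part of the structure, the result is deterministic regardless of scheduling, which is the property MLIR's per-operation isolation is designed to guarantee.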

I believe we have not been following this. Specifically in gradients and other passes, we add functions to the module.

I think we should not make this the default yet. But if we are willing to increase testing time, we could add a small set of compilation tests every so often to see if we catch any errors.

I'll be running this PR in a bit looking for potential issues (or a reason why they haven't been seen). Thanks!

@erick-xanadu
Contributor

It looks like one of the reasons why we are not taking advantage of parallelism is that all our passes act at module scope:

include/Catalyst/Transforms/Passes.td:def CatalystBufferizationPass : Pass<"catalyst-bufferize"> {
include/Catalyst/Transforms/Passes.td:def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> {
include/Catalyst/Transforms/Passes.td:def CatalystConversionPass : Pass<"convert-catalyst-to-llvm"> {
include/Catalyst/Transforms/Passes.td:def ScatterLoweringPass : Pass<"scatter-lowering"> {
include/Catalyst/Transforms/Passes.td:def HloCustomCallLoweringPass : Pass<"hlo-custom-call-lowering"> {
include/Catalyst/Transforms/Passes.td:def QnodeToAsyncLoweringPass : Pass<"qnode-to-async-lowering"> {
include/Catalyst/Transforms/Passes.td:def AddExceptionHandlingPass : Pass<"add-exception-handling"> {
include/Catalyst/Transforms/Passes.td:def GEPInboundsPass : Pass<"gep-inbounds"> {
include/Gradient/Transforms/Passes.td:def GradientBufferizationPass : Pass<"gradient-bufferize"> {
include/Gradient/Transforms/Passes.td:def GradientLoweringPass : Pass<"lower-gradients"> {
include/Gradient/Transforms/Passes.td:def GradientConversionPass : Pass<"convert-gradient-to-llvm"> {
include/Mitigation/Transforms/Passes.td:def MitigationLoweringPass : Pass<"lower-mitigation"> {
include/Quantum/Transforms/Passes.td:def QuantumBufferizationPass : Pass<"quantum-bufferize"> {
include/Quantum/Transforms/Passes.td:def QuantumConversionPass : Pass<"convert-quantum-to-llvm"> {
include/Quantum/Transforms/Passes.td:def EmitCatalystPyInterfacePass : Pass<"emit-catalyst-py-interface"> {
include/Quantum/Transforms/Passes.td:def CopyGlobalMemRefPass : Pass<"cp-global-memref"> {
include/Quantum/Transforms/Passes.td:def AdjointLoweringPass : Pass<"adjoint-lowering"> {
include/Quantum/Transforms/Passes.td:def RemoveChainedSelfInversePass : Pass<"remove-chained-self-inverse"> {
include/Quantum/Transforms/Passes.td:def AnnotateFunctionPass : Pass<"annotate-function"> {
include/Test/Transforms/Passes.td:def TestPass : Pass<"test"> {

I think in order to take advantage, we would need to change these to Pass<"foo", FuncOp> or similar. Let me give it a try.

@erick-xanadu
Contributor

erick-xanadu commented May 8, 2024

I am currently running into a weird error 😅 All tests pass except one for printing the IR. When the IR is printed, it gets printed nine times. Even if it were not printed nine times, the test would fail.

Can you reproduce it locally? I've re-triggered a run of the code.

FAILED frontend/test/pytest/test_debug.py::TestPrintStage::test_hlo_lowering_stage - AssertionError: assert 'stablehlo.constant' not in 'module @fun...rn\n  }\n}\n'          

Yes, also seen on CI

@grwlf
Contributor Author

grwlf commented May 21, 2024

I am currently running into a weird error 😅 All tests pass except one for printing the IR. When the IR is printed, it gets printed nine times. Even if it were not printed nine times, the test would fail.

Can you reproduce it locally? I've re-triggered a run of the code.

FAILED frontend/test/pytest/test_debug.py::TestPrintStage::test_hlo_lowering_stage - AssertionError: assert 'stablehlo.constant' not in 'module @fun...rn\n  }\n}\n'          

Yes, also seen on CI

Oh, it turns out I broke the PR with my 'simplification' commit. I have reverted it and added comments. The problem was that the dump handler was triggered after every pass instead of only after passes that end a pipeline.
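The bug and its fix can be sketched abstractly (a hypothetical toy model, not the actual compiler driver code): the dump handler must fire only after the last pass of each pipeline. Firing it after every pass is what produced the nine-fold printing of the IR.

```python
# Toy model: two pipelines with their passes, in execution order.
# Names are illustrative only.
pipelines = {
    "PipelineA": ["CSE", "Canonicalize"],
    "PipelineB": ["TestPass"],
}
dumps = []

def after_pass(pipeline, pass_name):
    # Fixed behavior: dump only when this pass ends its pipeline.
    # The broken 'simplification' effectively dropped this check,
    # dumping after every pass.
    if pass_name == pipelines[pipeline][-1]:
        dumps.append(f"{pipeline}.mlir")

for pipeline, passes in pipelines.items():
    for p in passes:
        after_pass(pipeline, p)
```

With the check in place there is exactly one dump per pipeline rather than one per pass.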
