Optimize RAM to VRAM transfer #6312
Conversation
This is a huge speed-up!!! Awesome. Will wait for @RyanJDick to take a look.
This is awesome! I tested it out with some simple T2I workflows, and saw the speedup, as promised.
I left a few comments. Once those are addressed, I'll run it through its paces with a bunch of model types to make sure there aren't any weird edge cases.
I ran into a bug during testing. I hadn't thought about this before, but this approach breaks if a model is moved between devices while a patch is applied. I can trigger this by using a TI. The series of events is:
- The text encoder is loaded and registered with the model cache.
- We apply the TI to the text encoder while the text encoder is on the CPU. This patch creates a new tensor of token embeddings with a different shape.
- We attempt to move the text encoder to the GPU. This operation fails because the state_dict tensor sizes no longer match.
We could probably find a quick way to solve this particular problem, but it makes me worry about the risk of similar bugs. We need clear rules for how the model cache and model patching are intended to interact.
One approach would be to require that models are patched and unpatched during the span of a model cache lock. TIs are a little weird in that the patch is applied on the CPU before copying the model to GPU. We should look into whether we can just do all of this on the GPU. If not, we may have to consider splitting the concepts of model access locking and model device locking.
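The shape-mismatch failure can be reproduced in isolation with a few lines of torch. This is a hedged sketch, not InvokeAI code; the module and tensor sizes are made-up stand-ins for the text encoder's embedding table:

```python
import torch
import torch.nn as nn

# Minimal stand-in for the text encoder's token embedding table.
model = nn.Embedding(num_embeddings=10, embedding_dim=4)

# The model cache snapshots the original weights.
cached_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# A TI-style patch replaces the table with a larger one (extra trigger tokens).
model.weight = nn.Parameter(torch.zeros(12, 4))

# Restoring/moving the model via the cached state dict now fails, because the
# tensor shapes no longer match.
try:
    model.load_state_dict(cached_state_dict)
except RuntimeError as err:
    print("load failed:", "size mismatch" in str(err))  # prints: load failed: True
```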
@RyanJDick I'm not all that familiar with model patching. Is patching done prior to every generation and then reversed? If so, the trick would be to refresh the cached
Patching (for LoRA or TI) is managed using context managers (applied on entry, and reversed on exit). Examples:
Now that the model cache has the power to modify a model's weights (restore them to a previous state), we need clearer ownership semantics (i.e. who can modify a model? when can they modify it? what guarantees do they have to offer?). Designing this well would take more thought/effort than I can spend on it right now. We might be able to take a shortcut to get this working now, though. I think this might be achievable with some combination of:
More investigation needed to figure out which of those makes the most sense.
@RyanJDick I finally got back to this after an interlude. It was a relatively minor fix to get all the model patching done after loading the model into the target device, and the code is cleaner too. I've tested LoRA, TI, and clip skip, and they all seem to be working as expected. Seamless doesn't seem to do much of anything, either with this PR or on current
I understand the overall strategy, but I'm having trouble wrapping my head around the fix for models changing device. If I understand correctly, the solution is very simple - re-order the context managers. Can you ELI5 how changing the order of the context managers fixes this?
Also curious about this edge case - say we have two compel nodes:
- We execute compel node 1. At this time, the models are in VRAM.
- Time passes and we load other models, evicting the UNet and CLIP from VRAM. Maybe they are in RAM, maybe they aren't cached at all.
- We execute compel node 2. Is this a problem?
Note: I tested this PR to see if it fixed #6375. It does not.
The context managers were reordered so that the context manager calls that lock the model in VRAM are executed before the patches are applied, and it is the locked model that is passed to the patchers. I also switched the relative order of the TI and LoRA patchers, but only because it made the code formatting easier to read. I tested both orders and got identical images. Here's the edge case:
RAM->VRAM operations are about twice as fast as VRAM->RAM on my system. I am tempted to remove the VRAM cache entirely so that we are guaranteed to have a fresh copy of the model weights each time. However, if the patchers are unpatching correctly, this shouldn't be an issue.
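The reordering described above can be sketched with stub context managers (illustrative only; these names are not the real InvokeAI API). The key point is that the VRAM lock wraps the patchers, so patches are applied to, and reverted from, the on-device copy:

```python
from contextlib import contextmanager

events = []

@contextmanager
def lock_in_vram(model):
    events.append("lock (move to VRAM)")
    try:
        yield model  # the locked, on-device model is what gets patched
    finally:
        events.append("unlock (move to RAM)")

@contextmanager
def apply_patch(model, kind):
    events.append(f"patch {kind}")
    try:
        yield
    finally:
        events.append(f"unpatch {kind}")

with lock_in_vram("text_encoder") as locked:
    with apply_patch(locked, "lora"), apply_patch(locked, "ti"):
        events.append("run inference")

print(events)
# ['lock (move to VRAM)', 'patch lora', 'patch ti', 'run inference',
#  'unpatch ti', 'unpatch lora', 'unlock (move to RAM)']
```

Because the contexts unwind in reverse order, every patch is reverted before the model leaves the device, which is what avoids the shape-mismatch problem on later device moves.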
Rats. I was rather hoping it would. I'm digging into the LoRA loading issue now.
It is working for me as well. I just had to adjust the image dimensions to see the effect. Seamless is not something I ever use.
@psychedelicious @RyanJDick I have included a fix for #6375 in this PR. There was some old model cache code originally written by Stalker that traversed the garbage collector and forcibly deleted local variables from unused stack frames. This code was written to work around a Python 3.9 GC bug, but it seems to wreak havoc on context managers. The low RAM cache setting simply triggered the problem. I suspect this may have caused rare failures in other contexts as well (pun intended). I removed the code and tested for signs of memory leaks. I didn't see any, but please keep an eye out. Going off on a tangent, while reviewing the patching code, I discovered that lora patching uses the following pattern:
I think it would be more performant to:
The downside is that this will transiently use more VRAM because all the LoRA layers are loaded at once. Another potential optimization would be to stop saving the original model's weights on entry to the patcher context and restoring them on exit. Since we are now keeping a virgin copy of the state dictionary in the RAM cache, the patched model in VRAM is cleared out at the end of a node's invocation and will be replaced with a fresh copy the next time it is needed. I gave both of these things a quick try and the system felt snappier, but I didn't do timing or any stress tests. If you think this is worth pursuing, I'll submit a new PR for them. [EDIT] I can shave off ~2s of generation walltime (from 10.8s to 8.8s) by avoiding the unnecessary step of restoring weights to the VRAM copy of the model.
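The restore-skipping idea could look roughly like this (a sketch using a plain dict as a stand-in for real model weights; not the PR's actual code):

```python
from contextlib import contextmanager

@contextmanager
def apply_patch(weights, patch, device):
    # On CPU/MPS we patch the cached copy itself, so the originals must be
    # saved on entry and restored on exit. On CUDA the patched copy is
    # throwaway: a virgin state dict lives in the RAM cache and will be
    # reloaded the next time the model is needed.
    saved = {k: weights[k] for k in patch} if device != "cuda" else None
    weights.update(patch)
    try:
        yield weights
    finally:
        if saved is not None:
            weights.update(saved)  # revert the patch
        # device == "cuda": skip the revert entirely
```

Skipping the revert on the CUDA path is where the restore time mentioned in the edit above is saved.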
The latest commit implements an optimization that circumvents the LoRA unpatching step when working with a model that is resident in CUDA VRAM. This works because the new scheme never copies the model weights back from VRAM to RAM, but instead reinitializes the VRAM copy from a fresh RAM state dict the next time the model is needed. The behavior for CPU and MPS devices has not changed, since these operate on the RAM copy. When generating with SDXL models, this optimization saves roughly 1s per LoRA per generation, which I think makes the special casing worth it. The other optimization I tried was to let the model manager load the LoRA into VRAM using its usual model locking mechanism rather than manually moving each layer into VRAM before patching. However, this did not give a performance gain and needed special casing for LoRAs in the model manager because LoRAs don't have
Other changes in this commit:
Did you test the effect of removing the VRAM cache with a large VRAM cache size (e.g. large enough to hold all working models)? For this usage pattern, I'm afraid that there is going to be a significant speed regression from removing it.
The behavior of the `apply_lora(...)` and model locker context managers now changes significantly depending on the environment in which they run.

`apply_lora()`:
- Without CUDA GPU:
  - `__enter__`: apply LoRA weights
  - `__exit__`: revert LoRA weights
- With CUDA GPU:
  - `__enter__`: apply LoRA weights
  - `__exit__`: do nothing

Model locker:
- Without CUDA GPU:
  - `__enter__`: move model to target device
  - `__exit__`: move model to RAM
- With CUDA GPU:
  - `__enter__`: move model to target device
  - `__exit__`: move model to RAM, and revert any changes made to the weights
Given these major differences in behavior depending on environment, the caller of these context managers needs to be deeply familiar with their implementation details to use them correctly. It might be better to force the caller to explicitly specify the desired behavior. For example:
```python
with (
    ModelCache.model_on_device(model_info, target_device, copy_weights_to_device=True) as model,
    ModelPatcher.apply_lora(model, _lora_loader(), lora_prefix, revert_on_exit=False),
):
    ...
```
What do you think? It would be a breaking change to the API, but we're making a major breaking change to the behaviour either way.
Separately, we may also want to consider making `model_info.model` a private attribute. I can't think of a good reason to access the model directly outside of a model locker context.
For my own reference, here's a rough checklist of the tests we should run once the code settles down to check for performance and behavior regressions:
- Context managers:
  - LoRA
  - TI
  - FreeU
  - Seamless
  - Clip Skip
- HW:
  - CUDA (multiple device types in case this impacts copy speeds)
  - CPU
  - MPS
@RyanJDick I’ve undone the model patching changes and the removal of the VRAM cache, and what’s left is the original CPU->VRAM optimization, the fix to the TI patching, and the fix for the weird context manager bug that was causing LoRAs not to patch. It is a fairly minimal PR now, so I hope we can get it merged. I’ll work on the LoRA patching optimization separately.
Approved so my requested changes aren't a blocker for this PR
Awesome. Thanks for splitting up the PRs.
I did some quick manual regression testing - everything looked good. I tried:
- Text-to-image, LoRA, TI
- CPU-only
- A bunch of model switching - no obvious signs of a memory leak.
I also ran some performance tests.
With `vram: 0.25`:
- SDXL T2I, cold cache: 10.4s -> 9.6s
- SDXL T2I, warm cache: 6.9s -> 6.1s
- SDXL T2I + 2 LoRA, warm cache: 9.0 -> 8.6s
With `vram: 16` (no significant change, as expected):
- SDXL T2I, cold cache: 8.0s -> 8.0s
- SDXL T2I, warm cache: 4.7s -> 4.6s
- SDXL T2I + 2 LoRA, warm cache: 6.9s -> 6.9s
@lstein There are a few torch features that might stack nicely on this PR to give even more speedup for Host-to-Device copies:
Have you looked into these at all? I don't want to expand the scope of this PR, but these could be an easy follow-up if you're interested in trying them out (or I can do it).
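For reference, the torch features most often suggested for faster host-to-device copies are pinned (page-locked) host memory and non-blocking transfers. That these are what's meant here is my assumption, and any speedup is untested:

```python
import torch

def host_to_device(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    if device.type == "cuda":
        # pin_memory() page-locks the host buffer so the copy can be a true
        # async DMA transfer; non_blocking=True lets it overlap with compute.
        return t.pin_memory().to(device, non_blocking=True)
    # No CUDA device: plain synchronous copy/no-op.
    return t.to(device)
```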
I'm going to merge this in and then will start working on further optimizations, including the LoRA loading/unloading.
Thanks for doing the timings. It's not as big a speedup as I saw, but that's probably very dependent on hardware.
Summary
This PR speeds up the model manager’s system for moving models back and forth between RAM and VRAM. Instead of calling `model.to()` to accomplish the transfer, the model manager now stores a copy of the model’s state dict in RAM. When the model needs to be moved into VRAM for inference, the manager makes a VRAM copy of the state dict and assigns it to the model using `load_state_dict()`. When inference is done, the model is cleared from VRAM by calling `load_state_dict()` with the CPU copy of the state dict.

Benchmarking an SDXL model shows an improvement from 3 seconds to 0.81 seconds for a model load/unload cycle. Most of the improvement comes from the unload step, as shown in the table below:
Thanks to @RyanJDick for suggesting this load/unload scheme.
Related Issues / Discussions
QA Instructions
Change models a number of times. Monitor RAM and VRAM for memory leaks.
Merge Plan
Merge when approved
Checklist