
How to use Optimizer State Sharding with Sharpness-Aware Minimization? #989

Open

kenmbkr opened this issue May 20, 2022 · 15 comments

@kenmbkr commented May 20, 2022

I am trying to set up OSS with sharpness-aware minimization (SAM).
https://github.com/davda54/sam
https://github.com/SamsungLabs/ASAM

I followed this guide to set up OSS, but I am having difficulty wrapping SAM with OSS.
Since both OSS and SAM are optimizer wrappers with their own step functions, I am not sure how to combine them and call those functions. My initial idea is to wrap them like this: OSS(SAM(Adam)).
https://fairscale.readthedocs.io/en/stable/tutorials/oss.html

Is there a minimal working example showing how to do that?

@blefaudeux (Contributor) commented May 22, 2022

It's a really good question, and SAM is pretty great, I think!

I would wrap the other way, SAM(OSS(Adam)). From the outside, OSS should be transparent: you get something strictly equivalent to a non-sharded optimizer. The opposite is not true, since SAM is not strictly equivalent to a vanilla optimizer given its double step. I actually think that the other way around would be broken (i.e., SAM does two updates per step that OSS would not know how to handle), but you can prove me wrong!
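For what it's worth, here is a minimal sketch of what that wrap order could look like, assuming the davda54/sam interface (SAM(params, base_optimizer, **kwargs)) and fairscale's OSS(params, optim=<optimizer class>, **defaults); the distributed process group is assumed to be initialized already (e.g. via torchrun), and this is untested illustration rather than an official recipe:

```python
import functools

import torch
from fairscale.optim.oss import OSS
from sam import SAM  # assumes sam.py from https://github.com/davda54/sam is on the path

# Placeholder model; in practice this is the module you wrap in DDP/ShardedDDP.
model = torch.nn.Linear(16, 16).cuda()

# SAM instantiates its base optimizer as base_optimizer(param_groups, **kwargs),
# so bind OSS's inner optimizer class with functools.partial.
base_optimizer = functools.partial(OSS, optim=torch.optim.Adam)
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=1e-3)
```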

I would suggest you give the upstreamed PyTorch implementation a go (here); the interface should be the same, and I'm not sure that the version in fairscale is still up to date (cc @anj-s @min-xu-ai). For context, I'm one of the original authors of OSS (among others).
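Assuming the upstreamed implementation referred to is torch.distributed.optim.ZeroRedundancyOptimizer, a rough sketch of the equivalent construction (same placeholder model as above, process group already initialized):

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Shards Adam's optimizer state across ranks, mirroring
# fairscale's OSS(params, optim=torch.optim.Adam, lr=...).
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,
    lr=1e-3,
)
```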

@kenmbkr (Author) commented May 23, 2022

@blefaudeux Thank you for your suggestion. I was able to get my code running following your idea of SAM(OSS(Adam)).

However, when I use AMP with ShardedGradScaler, I am hitting the following assertion when calling scaler.update():

assert len(found_infs) > 0, "No inf checks were recorded prior to update."

The assertion is raised only with AMP, not with full precision. Any insights on what I am missing?

@blefaudeux (Contributor) commented May 23, 2022

> @blefaudeux Thank you for your suggestion. I was able to get my code running following your idea of SAM(OSS(Adam)).

np, and great to hear that the basics work!

> However, when I use AMP with ShardedGradScaler, I am hitting the following assertion when calling scaler.update():
>
> assert len(found_infs) > 0, "No inf checks were recorded prior to update."
>
> The assertion is raised only with AMP, not with full precision. Any insights on what I am missing?

This kind of makes sense. With AMP, optimizer.step() is sometimes a no-op: the algorithm tries to find the right scaling between fp32 and fp16 (which has a very short dynamic range, anything above ~65k is inf). The idea is to (1) look for infs prior to the step, so that inf gradients, if present, don't kill the model, and (2) if there are any, adapt the scale and skip this step.

From a distance, it looks like one of the SAM optimizer steps did not go through the grad scaler (yes, this is yet another .step() overload). I think the doc is not complete there, but my guess is that the correct wrap order is probably SAM(ShardedGradScaler(OSS(Adam))) (phew!). I'm not completely sure that SAM would react correctly to a step being skipped though, so it could be that you have to swap the two (SAM and the grad scaler), but I would try that first.

It's quite possible that it does not work (I only have a high-level understanding of how SAM works); in that case it would require some plumbing in some of these pieces, it would not be turnkey I believe.

@kenmbkr (Author) commented May 23, 2022

@blefaudeux Thank you for explaining the mechanics of the grad scaler for mixed-precision training.

As per your suggestion SAM(ShardedGradScaler(OSS(Adam))), I looked into the constructors of SAM and ShardedGradScaler. SAM does not take a TorchGradScaler argument in its constructor, and ShardedGradScaler does not take an optimizer argument in its constructor. Either way (SAM(ShardedGradScaler) or ShardedGradScaler(SAM)), the wrap is not trivial.

Could you kindly elaborate further on how the wrap SAM(ShardedGradScaler(OSS(Adam))) can be done?

@blefaudeux (Contributor) commented May 23, 2022

> @blefaudeux Thank you for explaining the mechanics of the grad scaler for mixed-precision training.
>
> As per your suggestion SAM(ShardedGradScaler(OSS(Adam))), I looked into the constructors of SAM and ShardedGradScaler. SAM does not take a TorchGradScaler argument in its constructor, and ShardedGradScaler does not take an optimizer argument in its constructor. Either way (SAM(ShardedGradScaler) or ShardedGradScaler(SAM)), the wrap is not trivial.
>
> Could you kindly elaborate further on how the wrap SAM(ShardedGradScaler(OSS(Adam))) can be done?

Hey, sorry, that was indeed not too clear; I mixed up two things here (the actual steps would be wrapped this way, but not the constructors, you're right).

The .step() call looks like this: it takes the optimizer as a parameter and does what I tentatively explained above (also explained here). There are examples in the PyTorch doc, and the sharded version should behave in the same way. As for why it exists: it just needs to consolidate the gradient check over all agents, so that if one of them has to skip a step(), they all do.
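For reference, the standard (non-SAM) AMP pattern looks roughly like this, where the scaler wraps the step call rather than the optimizer's constructor (model, optimizer, data_loader and loss_function are placeholders):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # or fairscale's ShardedGradScaler when using ShardedDDP

for inputs, targets in data_loader:
    optimizer.zero_grad()
    with autocast():
        loss = loss_function(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skipped internally if inf/NaN gradients were found
    scaler.update()         # adjusts the loss scale for the next iteration
```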

Update:
You don't need to use the ShardedGradScaler; you can stick to the original (non-sharded) PyTorch implementation if you use OSS with plain DDP. You only need the sharded scaler with a shard-aware DDP, like ShardedDDP.

@kenmbkr (Author) commented May 23, 2022

> Hey, sorry, that was indeed not too clear; I mixed up two things here (the actual steps would be wrapped this way, but not the constructors, you're right).
>
> The .step() call looks like this: it takes the optimizer as a parameter and does what I tentatively explained above (also explained here). There are examples in the PyTorch doc, and the sharded version should behave in the same way. As for why it exists: it just needs to consolidate the gradient check over all agents, so that if one of them has to skip a step(), they all do.
>
> Update: You don't need to use the ShardedGradScaler; you can stick to the original (non-sharded) PyTorch implementation if you use OSS with plain DDP. You only need the sharded scaler with a shard-aware DDP, like ShardedDDP.

Thank you for the clarifications and for explaining the differences between GradScaler and its sharded version. I noticed that the step function calls only optimizer.step() internally, while SAM or ASAM calls its step function twice (first/second step for SAM, ascent/descent step for ASAM). I had to implement a wrapping step function that calls both steps at once. AMP does run under such a configuration, but while the loss in full precision goes down, the loss with AMP fluctuates. Where should I look for possible errors?

# Wrapping step function added to the SAM/ASAM optimizer: it runs the ascent
# and descent steps in one call, re-running the closure for the second
# forward-backward pass.
@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
    closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

    self.ascent_step()
    closure()
    self.descent_step()

# Training loop with AMP:
for input, output in data:
    def closure():
        with autocast():
            loss = loss_function(output, model(input))
        scaler.scale(loss).backward()
        return loss

    with autocast():
        loss = loss_function(output, model(input))
    scaler.scale(loss).backward()
    scaler.step(optimizer, closure)  # <-- optimizer here is SAM(OSS(Adam))
    scaler.update()
...

Update:
Upon investigation, the wrapper step function is apparently not called when AMP is used (it is called in full precision). I wonder if it has something to do with the check here stating that closure is not supported:

if "closure" in kwargs:
    raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

@blefaudeux (Contributor) commented:

> […]
>
> Update: Upon investigation, the wrapper step function is apparently not called when AMP is used (it is called in full precision). I wonder if it has something to do with the check here stating that closure is not supported:
>
>     if "closure" in kwargs:
>         raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

Hmm ok, thanks for the pointer. So it's the "it will require more plumbing" part that I mentioned: I don't think it can work out of the box. One of the issues is that there's a scaling to be done before the backward pass, which a closure would mask. Something is probably doable through backward hooks (for instance, making sure the scaling is done automatically by attaching an appropriate hook), but it's not something I have time to investigate, unfortunately :/

A low-hanging fruit that you may try is to train using bfloat16, if your hardware supports it. This does not require any scaler, and you get the same memory savings as with float16. The tradeoff is that bfloat16 is not very precise, but depending on the workload the influence can be rather small (and it could well be that the round minimum that SAM finds is perfect for that).
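To illustrate, a minimal sketch of the bfloat16 route, assuming hardware with bfloat16 support and the closure-based SAM step discussed above (model, optimizer, data_loader and loss_function are placeholders); no scaler is involved:

```python
import torch

for inputs, targets in data_loader:
    def closure():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = loss_function(model(inputs), targets)
        loss.backward()
        return loss

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.step(closure)  # e.g. the SAM-style wrapper step(closure) shown earlier
    optimizer.zero_grad()
```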

@blefaudeux (Contributor) commented:

Some information on bfloat16 and PyTorch here, by the excellent @stas00.

@kenmbkr (Author) commented May 24, 2022

> Hmm ok, thanks for the pointer. So it's the "it will require more plumbing" part that I mentioned: I don't think it can work out of the box. One of the issues is that there's a scaling to be done before the backward pass, which a closure would mask. Something is probably doable through backward hooks (for instance, making sure the scaling is done automatically by attaching an appropriate hook), but it's not something I have time to investigate, unfortunately :/
>
> A low-hanging fruit that you may try is to train using bfloat16, if your hardware supports it. This does not require any scaler, and you get the same memory savings as with float16. The tradeoff is that bfloat16 is not very precise, but depending on the workload the influence can be rather small (and it could well be that the round minimum that SAM finds is perfect for that).

Thank you for the directions. I tried bfloat16 on both a 2080 Ti and a V100 using autocast(dtype=torch.bfloat16), but I got the following message on both GPUs.

RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.

I also came across some attempts at using SAM with AMP, but they don't seem to handle inf/NaN gradients.

@blefaudeux (Contributor) commented:

> Thank you for the directions. I tried bfloat16 on both a 2080 Ti and a V100 using autocast(dtype=torch.bfloat16), but I got the following message on both GPUs.
>
> RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.

Yes, you need Ampere cards for that, or TPUs; I had no idea what you had available.

> I also came across some attempts at using SAM with AMP, but they don't seem to handle inf/NaN gradients.

Yep, that option is to skip the grad scaler altogether, but then you don't have overflow/underflow protection, and float16 is not a gentle type to work with.

@kenmbkr (Author) commented Jun 9, 2022

@blefaudeux I have been looking at different variations of SAM these days, and apparently none of them supports AMP at the moment. I guess I will put AMP on hold, as it is not trivial.

Currently, I am more concerned with how different implementations handle gradient synchronization differently. I am confused about which one is the correct usage when paired with OSS and ShardedDataParallel.

The unofficial SAM implementation moves the parameters to the same device before computing the gradient norm.

The official Adaptive SAM (ASAM) implementation computes the gradient norm on individual workers and does not do any gradient synchronization.

The Gap-guided SAM (GSAM) implementation also computes the gradient norm on individual workers but has an explicit gradient synchronization at the end.

So far my experience is that ASAM converges quite slowly, while GSAM converges much more quickly but the gradients explode after a few epochs.

I have been considering different factors, including the neighborhood size in SAM, the loss functions, and gradient synchronization, but so far I do not have a good answer. Unfortunately, I am not able to do a parameter sweep due to limited resources and the scale of my experiment. I would appreciate it if you could kindly shed light on the issue.

  • Neighborhood size: since it affects the flatness of the loss landscape, it may be related to the convergence problem.
  • Loss function: I use multiple loss functions in my experiments; I first compute the mean of each individual loss and then sum them together. I am not sure whether this affects gradient reduction.
  • Gradient synchronization: when using OSS and ShardedDataParallel, both the optimizer and the model are sharded. Given that different SAM implementations handle gradients differently, I have difficulty telling which one is compatible.

@blefaudeux (Contributor) commented Jun 14, 2022

Hey @kenmbkr, sorry for the delay, I've had very little time these days (and I'm not at Facebook anymore). I would need to dive into this to say something halfway smart, but I think you're right in assuming that coupling this with AMP is probably for another day.

Purely with SAM, I think it should work with OSS as long as SAM is the outer wrap (but this is on a conceptual basis; I'm not sure the API will help you there).

  • With DDP, all the agents have the same gradients at the end of the backward pass, so all the norm computations will be the same by default.
  • With ShardedDDP, the gradients are sharded, but each agent sees a consolidated version of the gradients it manages, so in short the norm computations will be de-duplicated but should still be correct (either no gradient present, or the reduced one).

The short answer is that I don't think you would need a gradient sync step, unless the SAM method looks at gradient values across the whole model (I need to re-read the paper). In that case you would still need no sync with DDP, but you would need one with ShardedDDP (sync across the agents <> sync across the model). If SAM only considers the optimizer state tensor per tensor, this is not required with either option.

edit: ----> below is the real TL;DR
So in principle it feels like it should work, but if I remember one of your links correctly, tying the two APIs together could be complicated. May I suggest another option? "SAM for free": I haven't read it thoroughly, but it looks like a good match for you, and it should work out of the box.

@kenmbkr (Author) commented Jun 17, 2022

> Hey @kenmbkr, sorry for the delay, I've had very little time these days (and I'm not at Facebook anymore). I would need to dive into this to say something halfway smart, but I think you're right in assuming that coupling this with AMP is probably for another day.
>
> Purely with SAM, I think it should work with OSS as long as SAM is the outer wrap (but this is on a conceptual basis; I'm not sure the API will help you there).
>
> • With DDP, all the agents have the same gradients at the end of the backward pass, so all the norm computations will be the same by default.
>
> • With ShardedDDP, the gradients are sharded, but each agent sees a consolidated version of the gradients it manages, so in short the norm computations will be de-duplicated but should still be correct (either no gradient present, or the reduced one).
>
> The short answer is that I don't think you would need a gradient sync step, unless the SAM method looks at gradient values across the whole model (I need to re-read the paper). In that case you would still need no sync with DDP, but you would need one with ShardedDDP (sync across the agents <> sync across the model). If SAM only considers the optimizer state tensor per tensor, this is not required with either option.

Thank you very much for getting back to me; I really appreciate your time looking into these issues.

From my understanding of your words, whether individual implementations sync or do not sync the gradients should not affect the results (theoretically). Please correct me if I am wrong.

The unofficial SAM implementation actually recommends no_sync for the first step due to the following statement in the paper:

"To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update."

> edit: ----> below is the real TL;DR. So in principle it feels like it should work, but if I remember one of your links correctly, tying the two APIs together could be complicated. May I suggest another option? "SAM for free": I haven't read it thoroughly, but it looks like a good match for you, and it should work out of the box.

I have also come across this paper while experimenting with SAM, and I think it is a promising direction for a good balance between training efficiency and the loss landscape. However, their implementation is not publicly available, and it is not trivial for me to implement it myself. I hope they will release their implementation someday.

@blefaudeux (Contributor) commented:

> From my understanding of your words, whether individual implementations sync or do not sync the gradients should not affect the results (theoretically). Please correct me if I am wrong.
>
> The unofficial SAM implementation actually recommends no_sync for the first step due to the following statement in the paper:
>
> "To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update."

Oh interesting, it makes sense if the SAM gradients are smaller than those of the model (and the SAM computation commutes with the reduction); otherwise maybe it's an approximation, but a good-enough one. It could explain (I've not read the papers, disclaimer) why the different xSAM variants suggest something different in that particular case?

In any case:

  • no_sync brings some speed benefit, as the communications are not cheap in general, so if you can do without them, that's better.
  • OSS with DDP (and ShardedDDP) should respect the no_sync context (the backward will not trigger a reduction of the gradients across all agents in that case), so all good as far as I can see; a sketch of this pattern follows below. By the way, in that case I think FSDP should also work fine if you don't trigger all the bells and whistles.
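A minimal sketch of that no_sync pattern, assuming the davda54/sam interface (first_step / second_step) and a model wrapped in DDP or fairscale's ShardedDDP (both expose a no_sync() context manager); model, optimizer, data_loader and loss_function are placeholders, and this is an illustration rather than a tested recipe:

```python
for inputs, targets in data_loader:
    # Ascent step: compute the per-worker SAM perturbation without
    # all-reducing gradients across workers.
    with model.no_sync():
        loss = loss_function(model(inputs), targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

    # Descent step: gradients are reduced across workers as usual, so the
    # sub-batch SAM gradients get averaged into the final update.
    loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.second_step(zero_grad=True)
```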

@kenmbkr (Author) commented Jun 17, 2022

> Oh interesting, it makes sense if the SAM gradients are smaller than those of the model (and the SAM computation commutes with the reduction); otherwise maybe it's an approximation, but a good-enough one. It could explain (I've not read the papers, disclaimer) why the different xSAM variants suggest something different in that particular case?
>
> In any case:
>
> • no_sync brings some speed benefit, as the communications are _not_ cheap in general, so if you can do without them, that's better.
>
> • OSS with DDP (and [ShardedDDP](https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/sharded_ddp.py#L380)) should respect the no_sync context (the backward will not trigger a reduction of the gradients across all agents in that case), so all good as far as I can see? By the way, in that case I think FSDP should also work fine if you don't trigger all the bells and whistles.

Thank you for your valuable insights into the gradient synchronization problem :)

As for FSDP, it is actually the first thing I tried when I came across fairscale.
However, it wraps only the layers with gradients and does not handle layers without gradients automatically the way ShardedDDP does.
While many recommend separating layers with and without gradients, that is basically a major refactoring of the code.
In my case, the layers without gradients are pre-trained networks used for loss computation.
I have yet to adopt FSDP in my code due to the steep learning curve, and I hope I can pick it up someday.
