
MultiSteps doesn't work with DPSGD #515

Open
long21wt opened this issue Mar 21, 2023 · 1 comment

long21wt commented Mar 21, 2023

Hi,

I'm trying to use DP-SGD with MultiSteps to train BART.
Normally I can only fit batch size 8 with DP-SGD on an A100 80 GB, so gradient accumulation would be a good choice.
I followed the MultiSteps tutorial; it works with SGD but not with DP-SGD.
Here is part of my stack trace:

  File "/mnt/beegfs/work/me/works/projects/experiments.py", line 155, in train
    state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
  File "/mnt/beegfs/work/me/works/projects/experiments.py", line 232, in train_step
    new_state = state.apply_gradients(grads=grad)
  File "/storage/trust/work/me/miniconda3/envs/projects_a180/lib/python3.10/site-packages/flax/training/train_state.py", line 73, in apply_gradients
    updates, new_opt_state = self.tx.update(
  File "/storage/trust/work/me/miniconda3/envs/projects_a180/lib/python3.10/site-packages/optax/_src/wrappers.py", line 421, in update
    new_updates, new_state = jax.lax.cond(
TypeError: true_fun and false_fun output must have identical types, got
({'final_logits_bias': 'DIFFERENT ShapedArray(float32[8,1,250027]) vs. ShapedArray(float32[1,250027])',
 'model': 
 {'decoder':
  {'embed_positions': 
  {'embedding': 'DIFFERENT ShapedArray(float32[8,1026,1024]) vs. ShapedArray(float32[1026,1024])'}, 
 'layer_norm': {'bias': 'DIFFERENT ShapedArray(float32[8,1024]) vs. ShapedArray(float32[1024])',

For example, for final_logits_bias my grad is ShapedArray(float32[8,1,250027]) while multi_state_when_skip is ShapedArray(float32[1,250027]).

Thanks.

@long21wt (Author)

For reference, it does work with optax.apply_every.
