This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Cannot train Seeker with batch size > 1 #4531

Open
zhangmozhi opened this issue May 7, 2022 · 11 comments

Labels: donotreap (Avoid automatically marking as stale.)

@zhangmozhi
Contributor

Bug description
It seems that the Seeker training command does not support batch size > 1. I ran into an FSDP error when training Seeker-400M with -bs 2.

Reproduction steps

python -m parlai.scripts.multiprocessing_train \
--task projects.seeker.tasks.knowledge,projects.seeker.tasks.dialogue,projects.seeker.tasks.search_query \
--multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 \
--init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model \
--model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 \
--text-truncate 1000 --label-truncate 128 --truncate 1000 \
--fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True \
--warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 \
--attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 \
--checkpoint-activations true --model-file /tmp/my_seeker_dialogue_model

Expected behavior
I was hoping that training could succeed.

Logs
Please paste the command line output:

Asserting FSDP instance is: FullyShardedDataParallel(
  world_size=8, flatten_parameters=True, mixed_precision=True, 
  (_fsdp_wrapped_module): FlattenParamsWrapper(
    (_fpw_module): TransformerEncoderLayer_Swappable(
      (attention): MultiHeadAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (ffn): TransformerFFN(
        (relu_dropout): Dropout(p=0, inplace=False)
        (lin1): Linear(in_features=1024, out_features=4096, bias=True)
        (lin2): Linear(in_features=4096, out_features=1024, bias=True)
      )
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
ERROR: expected to be in states [<TrainingState.BACKWARD_POST: 4>] but current state is TrainingState.BACKWARD_PRE
2022-04-30 23:01:08,795 CRITICAL | Traceback (most recent call last):
  File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
    return single_train.TrainLoop(opt).train()
  File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
    for _train_log in self.train_steps():
  File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
    world.parley()
  File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
    batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
  File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
    batch_actions = a.batch_act(batch_observation)
  File "/data/kai/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
    batch_reply = super().batch_act(observations)
  File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2238, in batch_act
    output = self.train_step(batch)
  File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
    self.backward(loss)
  File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2324, in backward
    self.optimizer.backward(loss, update_main_grads=False)
  File "/data/kai/ParlAI/parlai/utils/fp16.py", line 194, in backward
    loss.backward()
  File "/data/kai/miniconda3/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/data/kai/miniconda3/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f0ce3a8e9e0> returned NULL without setting an error

Additional context
Not sure if this is a bug or a feature request.

@stephenroller
Contributor

Can you report your fairscale version?
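
For reference, a quick way to confirm which fairscale and torch versions a run actually picks up (a minimal check; it assumes both packages expose a __version__ attribute, otherwise pip show fairscale works too):

import fairscale
import torch

# Print the versions the current Python environment will import.
print("fairscale:", fairscale.__version__)
print("torch:", torch.__version__)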

@zhangmozhi
Contributor Author

The above error message was from fairscale 0.3.7.

Also tried fairscale 0.4.6 and got a similar error:

2022-05-08 14:35:33,820 INFO   | training...
rank: 3 | 2022-05-08 14:35:40,808 CRITICAL | Traceback (most recent call last):
 File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
  return single_train.TrainLoop(opt).train()
 File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
  for _train_log in self.train_steps():
 File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
  world.parley()
 File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
  batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
 File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
  batch_actions = a.batch_act(batch_observation)
 File "/data/kai/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
  batch_reply = super().batch_act(observations)
 File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2239, in batch_act
  output = self.train_step(batch)
 File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
  self.backward(loss)
 File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2325, in backward
  self.optimizer.backward(loss, update_main_grads=False)
 File "/data/kai/ParlAI/parlai/utils/fp16.py", line 194, in backward
  loss.backward()
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
  Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f0aa7bfc8d0> returned NULL without setting an error

@stephenroller
Contributor

I wonder if it's multiprocessing_train... Does that work with a transformer/generator?

@zhangmozhi
Contributor Author

zhangmozhi commented May 11, 2022

No. We just tried running this (fairscale 0.4.6):
parlai multiprocessing_train -t projects.seeker.tasks.knowledge,projects.seeker.tasks.dialogue,projects.seeker.tasks.search_query --multitask-weights 2,2,1 -veps 0.25 --attention-dropout 0.0 --batchsize 32 --model transformer/generator --embedding-size 2560 --ffn-size 10240 --variant prelayernorm --n-heads 32 --n-positions 128 --n-encoder-layers 2 --n-decoder-layers 24 --history-add-global-end-token end --delimiter ' ' --dict-tokenizer bytelevelbpe --dropout 0.1 --fp16 True --init-model zoo:blender/reddit_3B/model --dict-file zoo:blender/reddit_3B/model.dict --label-truncate 128 --log_every_n_secs 30 -lr 7e-06 --lr-scheduler reduceonplateau --lr-scheduler-patience 3 --optimizer adam --relu-dropout 0.0 --activation gelu --ddp-backend zero2 --learn-positional-embeddings true --save-after-valid True --text-truncate 128 --truncate 128 --warmup_updates 100 --fp16-impl mem_efficient --update-freq 1 --gradient-clip 0.1 --skip-generation True -vp 10 -vmt ppl -vmm min --tensorboard-log true --model-file /data/kai/modelfiles/test_train_3B/test_train_27B

And got this:

rank: 5 | 11:18:15 | Traceback (most recent call last):
 File "/data/kai/ParlAI/parlai/scripts/multiprocessing_train.py", line 45, in multiprocess_train
  return single_train.TrainLoop(opt).train()
 File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 1000, in train
  for _train_log in self.train_steps():
 File "/data/kai/ParlAI/parlai/scripts/train_model.py", line 907, in train_steps
  world.parley()
 File "/data/kai/ParlAI/parlai/core/worlds.py", line 880, in parley
  batch_act = self.batch_act(agent_idx, batch_observations[agent_idx])
 File "/data/kai/ParlAI/parlai/core/worlds.py", line 848, in batch_act
  batch_actions = a.batch_act(batch_observation)
 File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2239, in batch_act
  output = self.train_step(batch)
 File "/data/kai/ParlAI/parlai/core/torch_generator_agent.py", line 736, in train_step
  self.backward(loss)
 File "/data/kai/ParlAI/parlai/core/torch_agent.py", line 2325, in backward
  self.optimizer.backward(loss, update_main_grads=False)
 File "/data/kai/ParlAI/parlai/utils/fp16.py", line 522, in backward
  loss.backward()
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
  Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1485, in _pre_backward_hook
  self._use_full_params()
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
 File "/data/kai/minicond/envs/parlai/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1992, in _use_full_params
  assert self.has_full_params
AssertionError

@stephenroller
Contributor

Sorry, one more thing. Can you roll back to 0.4.4?

@zhangmozhi
Contributor Author

Sorry, it turns out that transformer/generator works fine with bs > 1. We ran into the above error because we had turned off flatten_parameters (which is also strange, but I suppose that is a fairscale problem).
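
For context, flatten_parameters is the flag visible in the FSDP dump in the original log (flatten_parameters=True). Below is a minimal, hypothetical single-GPU sketch (not ParlAI's actual wrapping code) of constructing fairscale's FullyShardedDataParallel with that flag; it assumes fairscale is installed, one CUDA device is available, and a single-process group is acceptable for illustration:

import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process "distributed" setup so FSDP can be constructed at all.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

layer = torch.nn.Linear(1024, 1024).cuda()
# flatten_parameters=True matches the FSDP dump in the log above; False is the
# setting that coincided with the has_full_params assertion in our run.
wrapped = FSDP(layer, flatten_parameters=True)

out = wrapped(torch.randn(2, 1024, device="cuda"))
out.sum().backward()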

We still couldn't train Seeker with bs > 1 with fairscale 0.4.6. We're trying 0.4.4 now and will report back.

@zhangmozhi
Contributor Author

So we just tried training Seeker with fairscale 0.4.4 and got the same error.

@klshuster self-assigned this May 11, 2022
@klshuster
Contributor

klshuster commented May 12, 2022

I'm able to repro on my end, so I'll look into it a bit more and report back here with findings.

Update 1

The model is able to train with multiprocessing_train, --batchsize 2, and only 1 exposed GPU. Bumping up to 2 GPUs, it fails.

Passing command:

CUDA_VISIBLE_DEVICES=0 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.search_query --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file

Failing command:

CUDA_VISIBLE_DEVICES=0,1 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.search_query --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model projects.seeker.agents.seeker:ComboFidGoldDocumentAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file
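
As a hypothetical sanity check (assuming multiprocessing_train spawns one worker per visible GPU), the only intended difference between the two runs is the visible device count:

import torch

# Run under the same CUDA_VISIBLE_DEVICES as the commands above:
# the passing case (CUDA_VISIBLE_DEVICES=0) reports 1 device,
# the failing case (CUDA_VISIBLE_DEVICES=0,1) reports 2.
print("visible GPUs:", torch.cuda.device_count())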

Update 2

This fails with the standard gold-document FiD agent as well:

CUDA_VISIBLE_DEVICES=0,1 CUDA_LAUNCH_BLOCKING=1 python -m parlai.scripts.multiprocessing_train --task projects.seeker.tasks.knowledge:WoiKnowledgeTeacher --multitask-weights 2,2,1 -bs 2 -vstep 1000 -vmt ppl -vp 5 -vmm min -vme 100000 -lstep 50 --init-opt arch/r2c2_base_400M --init-model zoo:seeker/r2c2_base_400M/model --model parlai.agents.fid.fid:WizIntGoldDocRetrieverFiDAgent --n-docs 5 --text-truncate 1000 --label-truncate 128 --truncate 1000 --fp16 True -lr 1e-06 --lr-scheduler reduceonplateau --optimizer adamw --save-after-valid True --warmup-updates 100 --update-freq 1 --gradient-clip 1.0 --skip-generation True --dropout 0.1 --attention-dropout 0.0 --load-from-checkpoint true --ddp-backend zero2 --checkpoint-activations true --model-file
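
For what it's worth, the common thread between this agent and the Seeker agents above is the FiD-style encoder expansion. A rough, hypothetical sketch of the shape bookkeeping (not ParlAI's implementation) is below, just to illustrate that with -bs 2 and --n-docs 5 the shared encoder processes batchsize * n_docs sequences per step:

import torch

# Toy shapes mirroring the repro flags: -bs 2, --n-docs 5, --text-truncate 1000.
batchsize, n_docs, seq_len = 2, 5, 1000
contexts = torch.randint(0, 8008, (batchsize, seq_len))  # toy token ids
# FiD encodes each (context, document) pair independently, so the encoder's
# batch dimension grows to batchsize * n_docs (document tokens omitted here).
expanded = contexts.repeat_interleave(n_docs, dim=0)
print(expanded.shape)  # torch.Size([10, 1000])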

@stephenroller
Contributor

Should we try w/ slurm to rule out it being multiprocessing?

@klshuster
Contributor

Should we try w/ slurm to rule out it being multiprocessing?

Tried this; it still fails. Something is hanging somewhere...

@github-actions

This issue has not had activity in 30 days. Please feel free to reopen if you have more issues. You may apply the "never-stale" tag to prevent this from happening.

@github-actions bot added the stale label Jun 13, 2022
@klshuster reopened this Jun 21, 2022
@klshuster added the donotreap label and removed the stale label Jun 21, 2022