Distributed training and max_recycling_iters #415

Open
pujaltes opened this issue Mar 6, 2024 · 0 comments

pujaltes commented Mar 6, 2024

I have a question regarding the number of recycling iterations used during training. In the AF2 paper they mention that the number of recycling iterations is a "shared value across the batch". However, from what I can tell, batch-level attributes during distributed training are actually defined at the micro-batch level here:

def _add_batch_properties(self, batch):
    gt_features = batch.pop('gt_features', None)
    # Draw one value per batch property (e.g. no_recycling_iters) for this batch
    samples = torch.multinomial(
        self.prop_probs_tensor,
        num_samples=1,  # 1 per row
        replacement=True,
        generator=self.generator
    )

    aatype = batch["aatype"]
    batch_dims = aatype.shape[:-2]
    recycling_dim = aatype.shape[-1]
    no_recycling = recycling_dim
    for i, key in enumerate(self.prop_keys):
        sample = int(samples[i][0])
        sample_tensor = torch.tensor(
            sample,
            device=aatype.device,
            requires_grad=False
        )
        orig_shape = sample_tensor.shape
        sample_tensor = sample_tensor.view(
            (1,) * len(batch_dims) + sample_tensor.shape + (1,)
        )
        sample_tensor = sample_tensor.expand(
            batch_dims + orig_shape + (recycling_dim,)
        )
        batch[key] = sample_tensor

        if key == "no_recycling_iters":
            no_recycling = sample

    # Truncate the recycling dimension to the sampled number of iterations
    resample_recycling = lambda t: t[..., :no_recycling + 1]
    batch = tensor_tree_map(resample_recycling, batch)
    batch['gt_features'] = gt_features

    return batch

From my understanding, in both DDP and DeepSpeed each batch is split into micro-batches that are each sent to one GPU. The issue is that the batch splitting happens in the DistributedSampler before the data ever reaches the OpenFoldDataLoader. Ergo, all these properties that should be fixed at the batch level are actually defined at the micro-batch level, meaning that each GPU process could be running a different number of recycling iterations. Please let me know if I am reading this incorrectly, but apart from not matching the paper, wouldn't this be extremely wasteful, since all GPUs would have to wait for the micro-batch with the largest recycling_iters?
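
To check this empirically, one could gather the sampled value from every rank right after _add_batch_properties runs. A rough sketch (the helper name and call site are mine, just for illustration; it assumes torch.distributed has already been initialized by the trainer):

import torch
import torch.distributed as dist

def check_recycling_sync(no_recycling: int, device: torch.device) -> None:
    # Gather the recycling count sampled on each rank and print them from rank 0.
    # If the values differ, the ranks are running different numbers of recycling
    # iterations for the same global batch.
    local = torch.tensor([no_recycling], dtype=torch.long, device=device)
    gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    if dist.get_rank() == 0:
        print("no_recycling_iters per rank:", [int(t.item()) for t in gathered])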

For DDP we could simply use the broadcast API to send the sampled recycling_iters from rank 0 to the rest of the processes. Looking at the DeepSpeedStrategy code in Lightning, it seems that it inherits from DDPStrategy, along with its broadcast method. The inherited method is actually used throughout the DeepSpeedStrategy class, so we should be fine using it for both distributed training strategies.
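
Something along these lines is what I have in mind (not tested, just a sketch assuming torch.distributed is initialized; the sync would happen right after the multinomial draw for "no_recycling_iters"):

import torch
import torch.distributed as dist

def sync_sampled_value(sample: int, device: torch.device, src: int = 0) -> int:
    # Overwrite the locally sampled value with the one drawn on rank `src`,
    # so every process runs the same number of recycling iterations for this batch.
    if dist.is_available() and dist.is_initialized():
        t = torch.tensor(sample, dtype=torch.long, device=device)
        dist.broadcast(t, src=src)
        sample = int(t.item())
    return sample

Alternatively, since the Lightning strategy object already exposes broadcast(obj, src=0), the same thing could probably be done without touching torch.distributed directly.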

Thanks for your help in advance :)
