Perform appropriate CUDA stream synchronization in distributed autograd. (#53929) (#54358)

Summary:
Pull Request resolved: #53929

The local autograd engine performs appropriate stream synchronization
between autograd nodes in the graph to ensure a consumer's stream is
synchronized with the producer's stream before executing the consumer.

However, in the case of distributed autograd, the SendRpcBackward function receives
gradients over the wire, and TensorPipe uses its own pool of streams for this
purpose. As a result, the tensors are received on TensorPipe's stream pool, but
SendRpcBackward runs on a different stream during the backward pass, and there
is no logic to synchronize these streams.

To fix this, I've enhanced DistEngine to synchronize these streams
appropriately when it receives grads over the wire.
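
For illustration, here is a minimal sketch of the synchronization pattern this change applies (the standalone helper and its name are invented for this example; in the actual change the logic lives inline in DistEngine::executeSendFunctionAsync, shown in the first diff below):

#include <vector>
#include <ATen/ATen.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
#include <c10/core/impl/VirtualGuardImpl.h>

// Illustrative helper: make `consumer_stream` (the stream the backward node
// will run on) wait for all work already enqueued on each gradient's current
// CUDA stream, so the node never reads tensors that are still being written.
void syncGradsWithConsumerStream(
    const std::vector<at::Tensor>& grads,
    c10::Stream consumer_stream) {
  const c10::impl::VirtualGuardImpl guard{c10::DeviceType::CUDA};
  for (const auto& grad : grads) {
    // Stream the gradient was produced/received on.
    const auto producer_stream = guard.getStream(grad.device());
    if (producer_stream != consumer_stream) {
      // Record an event on the producer stream and have the consumer stream
      // wait on it; this is a device-side wait, the CPU is not blocked.
      c10::Event event{c10::DeviceType::CUDA};
      event.record(producer_stream);
      consumer_stream.wait(event);
    }
  }
}

In executeSendFunctionAsync the consumer stream is sendFunction->stream(c10::DeviceType::CUDA) and the gradients come from the new SendRpcBackward::getGrads() accessor.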
ghstack-source-id: 124055277

(Note: this ignores all push blocking failures!)

Test Plan:
1) Added unit test which reproduced the issue.
2) waitforbuildbot.

Reviewed By: walterddr, wanchaol

Differential Revision: D27025307

fbshipit-source-id: 2944854e688e001cb3989d2741727b30d9278414

Co-authored-by: Pritam Damania <pritam.damania@fb.com>
pritamdamania87 and pritamdamania committed Mar 24, 2021
1 parent 6c39461 commit 56b43f4
Showing 5 changed files with 85 additions and 2 deletions.
22 changes: 21 additions & 1 deletion torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -1,6 +1,7 @@
 #include <queue>
 
 #include <ATen/Parallel.h>
+#include <c10/core/Event.h>
 #include <torch/csrc/autograd/functions/accumulate_grad.h>
 #include <torch/csrc/autograd/input_buffer.h>
 #include <torch/csrc/distributed/autograd/context/container.h>
@@ -423,8 +424,27 @@ std::shared_ptr<c10::ivalue::Future> DistEngine::

 std::shared_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
     const ContextPtr& autogradContext,
-    const std::shared_ptr<Node>& sendFunction,
+    const std::shared_ptr<SendRpcBackward>& sendFunction,
     bool retainGraph) {
+
+  // Typically the local autograd engine ensures stream synchronizations between
+  // nodes in the graph. However, for distributed autograd the sendFunction
+  // inputs might have been retrieved over the wire on a separate stream and the
+  // sendFunction itself runs on a different stream. As a result, we need to
+  // manually synchronize those two streams here.
+  const auto& send_backward_stream = sendFunction->stream(c10::DeviceType::CUDA);
+  if (send_backward_stream) {
+    for (const auto& grad : sendFunction->getGrads()) {
+      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
+      const auto default_stream = guard.getStream(grad.device());
+      if (send_backward_stream != default_stream) {
+        auto event = c10::Event{c10::DeviceType::CUDA};
+        event.record(default_stream);
+        send_backward_stream->wait(event);
+      }
+    }
+  }
+
   std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
   if (initializedContextIds_.find(autogradContext->contextId()) ==
       initializedContextIds_.end()) {
2 changes: 1 addition & 1 deletion torch/csrc/distributed/autograd/engine/dist_engine.h
@@ -46,7 +46,7 @@ class TORCH_API DistEngine {
   // The gradients are accumulated in the provided autograd context.
   std::shared_ptr<c10::ivalue::Future> executeSendFunctionAsync(
       const ContextPtr& autogradContext,
-      const std::shared_ptr<torch::autograd::Node>& sendFunction,
+      const std::shared_ptr<SendRpcBackward>& sendFunction,
       bool retainGraph);
 
   // Number of backward passes currently running for the Distributed Engine.
4 changes: 4 additions & 0 deletions torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp
@@ -23,6 +23,10 @@ void SendRpcBackward::setGrads(const torch::autograd::variable_list& grads) {
   grads_ = grads;
 }
 
+const torch::autograd::variable_list& SendRpcBackward::getGrads() const {
+  return grads_;
+}
+
 } // namespace autograd
 } // namespace distributed
 } // namespace torch
3 changes: 3 additions & 0 deletions torch/csrc/distributed/autograd/functions/sendrpc_backward.h
@@ -25,6 +25,9 @@ struct TORCH_API SendRpcBackward : public torch::autograd::Node {
   // computation.
   void setGrads(const torch::autograd::variable_list& grads);
 
+  // Retrieve the grads for the function.
+  const torch::autograd::variable_list& getGrads() const;
+
  private:
   torch::autograd::variable_list grads_;
 };
56 changes: 56 additions & 0 deletions torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -3,6 +3,7 @@
 import time
 import unittest
 from enum import Enum
+import random
 import torch
 from datetime import timedelta
 import torch.distributed as dist
@@ -2266,3 +2267,58 @@ def test_device_maps_backward_pass(self):
         self.assertEqual(t2.device, grads[t2].device)
 
         rpc.shutdown()
+
+    class MyRemoteCompute(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            input = input * 2.0
+            return input
+
+    class MyLocalCompute(torch.nn.Module):
+        def __init__(self, next_stage):
+            super().__init__()
+            self.next_stage = next_stage
+
+        def forward(self, input):
+            return self.next_stage.rpc_sync().forward(input)
+
+    @skip_if_lt_x_gpu(4)
+    def test_dist_autograd_sync_streams(self):
+
+        options = self.rpc_backend_options
+        dst = worker_name((self.rank + 1) % self.world_size)
+
+        # The reverse of this device mapping should be used for the backward pass.
+        options.set_device_map(dst, {self.rank: (self.rank + 1) % self.world_size})
+
+        rpc.init_rpc(
+            name=worker_name(self.rank),
+            backend=self.rpc_backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            rpc_backend_options=options,
+        )
+
+        remote_compute = rpc.remote(dst, TensorPipeDistAutogradTest.MyRemoteCompute)
+        local_compute = TensorPipeDistAutogradTest.MyLocalCompute(remote_compute)
+        for _ in range(10):
+            input = torch.rand([1000, 10000], device=self.rank, requires_grad=True)
+            # Run local autograd
+            result = input * 2.0
+            r = random.random()
+            loss = result.sum() * r
+            loss.backward()
+
+            # Run distributed autograd
+            with dist_autograd.context() as context_id:
+                result = local_compute(input)
+                loss = result.sum() * r
+                dist_autograd.backward(context_id, [loss])
+
+            # Compare grads.
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(input.grad, grads[input])
+
+        rpc.shutdown()
