Skip to content

Commit

Permalink
Fix device type issue in _get_device_handle (#124390)
Browse files Browse the repository at this point in the history
Fix #124327

`device_type`, the first argument of [init_device_mesh()](https://github.com/pytorch/pytorch/blob/a0466061e17358fb621cfde3f85e0bd6d13cfc55/torch/distributed/device_mesh.py#L503), does not support device types that include an index, such as `cuda:0`.
If `cuda:0` is passed as this argument, `_get_device_handle()` will not correctly return `torch.cuda`.
An exception should therefore be raised before the DeviceMesh object is created.

> See #124327 (comment),

Pull Request resolved: #124390
Approved by: https://github.com/wz337, https://github.com/wanchaol
  • Loading branch information
shink authored and pytorchmergebot committed Apr 30, 2024
1 parent 5e5f890 commit e0d2c24
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/distributed/test_device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,17 @@ def test_from_group(self):
ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
)

def test_raises_invalid_device_type(self):
    """Check that init_device_mesh rejects a device type carrying a GPU index."""
    expected_msg = "Device type with GPU index is not supported"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        # "cuda:0" includes a device index, so this call is expected to raise.
        shape_2d = (2, self.world_size // 2)
        init_device_mesh("cuda:0", mesh_shape=shape_2d, mesh_dim_names=("dp", "tp"))


class DeviceMeshTestNDim(DTensorTestBase):
@property
Expand Down
8 changes: 8 additions & 0 deletions torch/distributed/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def init_device_mesh(
Args:
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
Passing in a device type with a GPU index, such as "cuda:0", is not allowed.
mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
describing the layout of devices.
mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
Expand Down Expand Up @@ -565,6 +566,13 @@ def init_device_mesh(
f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
)

# A bare device type is assumed to consist of letters only, so any other
# character (e.g. the ":0" in "cuda:0") is treated as an unsupported index.
if device_type and not device_type.isalpha():
    # The two literals below are joined by implicit string concatenation
    # (no comma between them) so RuntimeError receives one message string;
    # passing them as two positional arguments makes the error render as a
    # tuple instead of a readable sentence.
    raise RuntimeError(
        f"Device type with GPU index is not supported but got {device_type}. "
        "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'."
    )

# Always initialize the mesh's tensor on CPU, regardless of what the
# external device type has been set to be (e.g. meta)
with torch.device("cpu"):
Expand Down

0 comments on commit e0d2c24

Please sign in to comment.