From 4c55dc50355d5e923642c59ad2a23d6ad54711e7 Mon Sep 17 00:00:00 2001
From: Iris Z <31293777+wz337@users.noreply.github.com>
Date: Wed, 8 Nov 2023 04:49:29 -0800
Subject: [PATCH] remove _shard_tensor() call (#111687)

Co-authored-by: Andrey Talman
---
 torch/distributed/checkpoint/optimizer.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py
index 5ffcc94e700d..ba6ab2c734b2 100644
--- a/torch/distributed/checkpoint/optimizer.py
+++ b/torch/distributed/checkpoint/optimizer.py
@@ -22,6 +22,8 @@
     TensorStorageMetadata,
     ChunkStorageMetadata,
 )
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.checkpoint.planner_helpers import (
     create_read_items_for_chunk_list,
     _create_read_items,
@@ -32,7 +34,7 @@
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
 )
-from torch.distributed._shard.api import _shard_tensor
+
 from torch.distributed.checkpoint.planner import LoadPlanner
 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
 
@@ -293,8 +295,12 @@ def load_sharded_optimizer_state_dict(
         if value.size.numel() == 1:
             state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
-            state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
+            state_dict[key] = _create_chunk_sharded_tensor(
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
+                rank=dist.get_rank(),
+                world_size=dist.get_world_size(),
+                num_devices_per_node=device_module.device_count(),
+                pg=_get_default_group(),
             )
         else:
             spec_key = key_path[2]
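
For reviewers, a minimal sketch of the new sharding path in isolation. It
assumes a multi-process launch with a process group initialized up front
(e.g. `torchrun --nproc-per-node=2 sketch.py`; the script name, launch
command, and tensor shape are illustrative). _create_chunk_sharded_tensor
and _get_default_group are private PyTorch APIs, so this is a sketch
against the signatures used in this patch, not a stable recipe:

    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor

    # Assumes a multi-process launch, e.g.:
    #   torchrun --nproc-per-node=2 sketch.py
    dist.init_process_group("gloo")

    # Stand-in for the _alloc_tensor(...) result in optimizer.py: a freshly
    # allocated tensor holding the full (unsharded) optimizer state value.
    full = torch.empty(16, 4)

    # Chunk `full` along dim 0 and wrap this rank's chunk in a ShardedTensor,
    # mirroring the call the patch adds to load_sharded_optimizer_state_dict.
    st = _create_chunk_sharded_tensor(
        full,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        # The patch uses device_module.device_count(); falling back to 1
        # keeps this sketch runnable on CPU-only hosts.
        num_devices_per_node=torch.cuda.device_count() or 1,
        pg=_get_default_group(),
    )

    print(f"rank {dist.get_rank()}: global shape {tuple(st.size())}")
    dist.destroy_process_group()

Unlike the removed _shard_tensor() path, which shards through collective
communication over the process group, _create_chunk_sharded_tensor()
derives the local chunk and the ShardedTensor metadata directly from
rank/world_size, so each rank can materialize its shard without extra
synchronization or a sharding_spec.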