diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py
index 5ffcc94e700d..ba6ab2c734b2 100644
--- a/torch/distributed/checkpoint/optimizer.py
+++ b/torch/distributed/checkpoint/optimizer.py
@@ -22,6 +22,8 @@
     TensorStorageMetadata,
     ChunkStorageMetadata,
 )
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.checkpoint.planner_helpers import (
     create_read_items_for_chunk_list,
     _create_read_items,
@@ -32,7 +34,7 @@
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
 )
-from torch.distributed._shard.api import _shard_tensor
+
 from torch.distributed.checkpoint.planner import LoadPlanner
 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
 
@@ -293,8 +295,12 @@ def load_sharded_optimizer_state_dict(
         if value.size.numel() == 1:
             state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
-            state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
+            state_dict[key] = _create_chunk_sharded_tensor(
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
+                rank=dist.get_rank(),
+                world_size=dist.get_world_size(),
+                num_devices_per_node=device_module.device_count(),
+                pg=_get_default_group(),
             )
         else:
             spec_key = key_path[2]
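For context on the call being swapped in: `_create_chunk_sharded_tensor` (from `torch.distributed.fsdp._shard_utils`) builds a `ShardedTensor` from the locally allocated tensor using chunk metadata it computes on each rank, whereas the removed `_shard_tensor` path sharded via a `ChunkShardingSpec`. Below is a minimal sketch of how the new call is invoked outside of `optimizer.py`; it is not part of this PR, and the helper name `build_sharded_optimizer_value` and the launch setup are illustrative assumptions only.

```python
# Hedged sketch (not part of the PR): constructing a ShardedTensor from a locally
# allocated tensor with _create_chunk_sharded_tensor, mirroring the arguments the
# diff passes in load_sharded_optimizer_state_dict.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor


def build_sharded_optimizer_value(size, dtype=torch.float32):
    # Allocate the full-size tensor on this rank (stand-in for _alloc_tensor),
    # then wrap it as a ShardedTensor whose local shard is this rank's chunk.
    full = torch.empty(size, dtype=dtype)
    return _create_chunk_sharded_tensor(
        full,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        # optimizer.py uses device_module.device_count(); fall back to 1 on CPU here.
        num_devices_per_node=torch.cuda.device_count() if torch.cuda.is_available() else 1,
        pg=_get_default_group(),
    )


if __name__ == "__main__":
    dist.init_process_group("gloo")
    st = build_sharded_optimizer_value((16, 4))
    # Each rank holds only its own chunk as a local shard.
    print(dist.get_rank(), [shard.metadata for shard in st.local_shards()])
    dist.destroy_process_group()
```

The sketch can be launched with, for example, `torchrun --nproc_per_node=2 sketch.py`.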