remove _shard_tensor() call (#111687)
Co-authored-by: Andrey Talman <atalman@fb.com>
wz337 and atalman committed Nov 8, 2023
1 parent f58669b, commit 4c55dc5
Showing 1 changed file with 9 additions and 3 deletions.
torch/distributed/checkpoint/optimizer.py: 9 additions & 3 deletions
@@ -22,6 +22,8 @@
     TensorStorageMetadata,
     ChunkStorageMetadata,
 )
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.checkpoint.planner_helpers import (
     create_read_items_for_chunk_list,
     _create_read_items,
@@ -32,7 +34,7 @@
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
 )
-from torch.distributed._shard.api import _shard_tensor
+
 from torch.distributed.checkpoint.planner import LoadPlanner
 
 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
@@ -293,8 +295,12 @@ def load_sharded_optimizer_state_dict(
         if value.size.numel() == 1:
             state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
-            state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
+            state_dict[key] = _create_chunk_sharded_tensor(
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
+                rank=dist.get_rank(),
+                world_size=dist.get_world_size(),
+                num_devices_per_node=device_module.device_count(),
+                pg=_get_default_group(),
             )
         else:
             spec_key = key_path[2]
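
For context, here is a minimal standalone sketch (not part of the commit) of how the replacement call is driven. Everything outside the _create_chunk_sharded_tensor() signature is an assumption for illustration: the process-group setup, the backend choice, the tensor shape, and the max(..., 1) guard for CPU-only hosts (the real code uses device_module.device_count() to stay device-agnostic).

    # Hypothetical driver for the call added in this commit; run under torchrun.
    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_default_group
    from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor

    dist.init_process_group(backend="gloo")  # assumed setup; "nccl" on GPU hosts

    # Stand-in for the _alloc_tensor(...) result in the diff above: every rank
    # allocates the full tensor locally before it is turned into shards.
    full = torch.empty(16, 4)

    sharded = _create_chunk_sharded_tensor(
        full,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_devices_per_node=max(torch.cuda.device_count(), 1),  # assumed guard
        pg=_get_default_group(),
    )

Unlike the removed _shard_tensor() path, which went through the sharding-spec API, _create_chunk_sharded_tensor() assembles the ShardedTensor from locally computed chunks. Note that it is a private FSDP helper, so its exact behavior may change between releases.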
