Cannot compute the gradients of EmbeddingBagCollection. #1614
We can't really inspect the gradients directly; see #1293. I believe Colin was saying that if you switch to the dense compute kernel, you can see them, though I haven't tried it myself.
@colin2328 Hi, I encountered the same problem, and after searching through the issues I found that gradients are available with the dense compute kernel. Is there any other way to run an EmbeddingCollection in dense mode?
One way I found is to implement my own sharder:

```python
class EmbeddingCollectionSharderDense(EmbeddingCollectionSharder):
    def compute_kernels(self, sharding_type: str, compute_device_type: str) -> List[str]:
        return super().compute_kernels(sharding_type, compute_device_type) + [
            EmbeddingComputeKernel.DENSE.value
        ]
```
@Ye-Tian-Zero see my answer in #1741. tl;dr: for non-data-parallel sharding types, you get better performance with FUSED. We didn't disable DENSE itself; we only disabled it in most tests, since it's slower than FUSED. For data parallel, you must use DENSE.
@henrylhtsang Hi, thank you, but I think the dense compute kernel was disabled here: 35c05fa#diff-72b4a4f205f4de558b2f5731a6697ae80483ee532eb2012d59999e5d7c462200L302-R302 which means the DENSE kernel is no longer a valid option for non-data-parallel modules. What I mean by "disabled" is that it can no longer be used with sharded modules.
@Ye-Tian-Zero It can still be used; let me know if you run into any problem using it. You can configure it through ParameterConstraints with EmbeddingComputeKernel. The reason it was disabled in the tests is to save testing capacity: since DENSE is slower than FUSED, we usually recommend FUSED. But for debugging purposes, feel free to use DENSE.
@henrylhtsang Sorry for the late reply. You can try this script on a GPU machine:

```python
#!/usr/bin/env python
# coding: utf-8

# In[1]:
import os

import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "21500"

# In[2]:
from typing import Dict

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
    )
    for i in range(small_table_cnt)
]


def gen_constraints(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
) -> Dict[str, ParameterConstraints]:
    # Force the DENSE compute kernel for both tables.
    large_table_constraints = {
        "product_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        )
    }
    small_table_constraints = {
        "user_table": ParameterConstraints(
            sharding_types=[sharding_type.value],
            compute_kernels=[EmbeddingComputeKernel.DENSE.value],
        )
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints


# In[3]:
def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> None:
    import os
    from typing import List, cast

    import torch
    import torch.distributed as dist
    import torchrec
    from torchrec.distributed.embedding import EmbeddingCollectionSharder
    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
    print(plan)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    mb = torchrec.KeyedJaggedTensor(
        keys=["product", "user"],
        values=torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
        lengths=torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
    )
    print(sharded_model(mb)["product"].to_padded_dense().sum())
    (
        sharded_model(mb)["product"].to_padded_dense().sum()
        + sharded_model(mb)["user"].to_padded_dense().sum()
    ).backward()
    print([p.grad.sum() for p in sharded_model.parameters()])
    print(f"rank:{rank}, sharding plan: {plan}")
    return sharded_model


# In[4]:
import multiprocess


def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size: int = 2,
) -> None:
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                gen_constraints(sharding_type),
                ebc,
                "nccl",
            ),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
        assert 0 == p.exitcode


# In[5]:
ebc = torchrec.EmbeddingCollection(
    device=torch.device("cuda"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)

# In[6]:
spmd_sharing_simulation(ShardingType.ROW_WISE)
```

This code will eventually fail with the error below:
@Ye-Tian-Zero okay, you are right, we did ban that combination. Can you try allowing DENSE here? https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L415
I tried to test the demo script with the following command:

```
torchx run -s local_cwd dist.ddp -j 1x1 --script test_installation.py
```

I tried to print the gradients of the embedding_bag_collection. I can observe the gradients of the MLP linear layers; however, the gradients of the embedding_bag_collection appear to be None.
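For contrast, here is a minimal sketch in plain PyTorch (no torchrec) showing what the thread means by gradients being visible for a dense embedding module: an ordinary `nn.EmbeddingBag` materializes its gradient on `.weight.grad` after `backward()`. The suggestion in this thread is that the FUSED torchrec kernel does not surface per-parameter gradients this way, which would explain the `None` values observed above.

```python
import torch

# Dense baseline: a plain nn.EmbeddingBag exposes its gradient on
# .weight.grad after backward(). The table sizes here are arbitrary.
emb = torch.nn.EmbeddingBag(num_embeddings=16, embedding_dim=4, mode="sum")

values = torch.tensor([1, 2, 4, 5])   # flattened indices for all bags
offsets = torch.tensor([0, 2])        # two bags: [1, 2] and [4, 5]

out = emb(values, offsets)            # shape (2, 4)
out.sum().backward()

print(emb.weight.grad is not None)    # gradients are visible in dense mode
print(emb.weight.grad.shape)          # torch.Size([16, 4])
```

Only the four looked-up rows of `emb.weight.grad` are nonzero, since the other rows did not participate in the forward pass.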