Cannot compute the gradients of EmbeddingBagCollection. #1614

Open
ZhuYuJin opened this issue Jan 9, 2024 · 9 comments
Comments

@ZhuYuJin

ZhuYuJin commented Jan 9, 2024

I tried to run the demo script with the following command:
torchx run -s local_cwd dist.ddp -j 1x1 --script test_installation.py

I tried to print the gradients of embedding_bag_collection. I can observe the gradients of the MLP linear layers; however, the gradients of embedding_bag_collection are None.

[screenshot: the MLP linear layers show gradient tensors, while the embedding_bag_collection parameters print grad=None]
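
For reference, the check is roughly the following minimal sketch (model and loss are placeholder names for the model and training loss in test_installation.py, not the exact code):

# Minimal illustrative sketch (placeholder names, not the exact test_installation.py code):
# after the backward pass, inspect .grad on each parameter of the sharded model.
loss.backward()
for name, param in model.named_parameters():
    print(name, "grad=None" if param.grad is None else f"grad norm={param.grad.norm():.4f}")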

@henrylhtsang
Contributor

We can't really look into the gradients; see #1293.

I think Colin was saying that if you switch to the dense compute kernel, then you can see them, though I haven't tried it myself.

@Ye-Tian-Zero

@colin2328 Hi, I ran into the same problem. After searching through some issues, I found that the gradients are available with the dense compute kernel, but I think the dense compute kernel was disabled by this commit:
35c05fa

Do you have any other way to run an embedding collection in dense mode?

@Ye-Tian-Zero

Ye-Tian-Zero commented Mar 6, 2024

I found that one way is to implement my own sharder:

from typing import List
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel

class EmbeddingCollectionSharderDense(EmbeddingCollectionSharder):
    def compute_kernels(self, sharding_type: str, compute_device_type: str) -> List[str]:
        return super().compute_kernels(sharding_type, compute_device_type) + [EmbeddingComputeKernel.DENSE.value]
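
A minimal sketch of how this sharder could be wired in (assuming the EmbeddingCollectionSharderDense class above; ec, pg, and device are placeholders for an EmbeddingCollection, an initialized process group, and the local device):

import torch
from typing import cast
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ModuleSharder, ShardingEnv

# Advertise DENSE to both the planner and DistributedModelParallel via the custom sharder.
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharderDense())]
planner = EmbeddingShardingPlanner(topology=Topology(world_size=2, compute_device="cuda"))
plan = planner.collective_plan(ec, sharders, pg)  # ec: EmbeddingCollection, pg: process group
sharded_ec = DistributedModelParallel(
    ec,
    env=ShardingEnv.from_process_group(pg),
    plan=plan,
    sharders=sharders,
    device=device,
)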

@henrylhtsang
Contributor

@Ye-Tian-Zero see my answer in #1741.

tl;dr: for non-data-parallel sharding types, you get better performance with FUSED.

We didn't disable DENSE outright; we only disabled it in most tests, since it's slower than FUSED.

For data parallel, you must use DENSE.

@Ye-Tian-Zero

@henrylhtsang Hi, thank you, but I think the dense compute kernel was disabled here:

35c05fa#diff-72b4a4f205f4de558b2f5731a6697ae80483ee532eb2012d59999e5d7c462200L302-R302

which means the DENSE kernel is no longer a valid option for non-data-parallel modules.

@Ye-Tian-Zero

What I mean by 'disabled' is that it can no longer be used with sharded modules.

@henrylhtsang
Contributor

@Ye-Tian-Zero It can still be used; let me know if you encounter any problems using it. You can configure it through ParameterConstraints with EmbeddingComputeKernel.

The reason it was disabled in the tests is to save testing capacity: since DENSE is slower than FUSED, we usually recommend people use FUSED. But for debugging purposes, feel free to use DENSE.
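
For reference, a minimal sketch of that configuration (the table name and sharding type here are illustrative; the full script below does the same thing):

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Constrain one table to the DENSE compute kernel; the dict is keyed by table name
# and passed to EmbeddingShardingPlanner(topology=..., constraints=constraints).
constraints = {
    "product_table": ParameterConstraints(
        sharding_types=[ShardingType.TABLE_WISE.value],
        compute_kernels=[EmbeddingComputeKernel.DENSE.value],
    ),
}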

@Ye-Tian-Zero

Ye-Tian-Zero commented May 16, 2024

@henrylhtsang Sorry for the late reply. Maybe you can try this script on a GPU machine:

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "21500"


# In[2]:


from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables=[
  torchrec.EmbeddingConfig(
    name="large_table_" + str(i),
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["large_table_feature_" + str(i)],
  ) for i in range(large_table_cnt)
]
small_tables=[
  torchrec.EmbeddingConfig(
    name="small_table_" + str(i),
    embedding_dim=64,
    num_embeddings=1024,
    feature_names=["small_table_feature_" + str(i)],
  ) for i in range(small_table_cnt)
]

def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
  large_table_constraints = {
    "product_table": ParameterConstraints(
      sharding_types=[sharding_type.value],
      compute_kernels=[EmbeddingComputeKernel.DENSE.value],
    )
  }
  small_table_constraints = {
    "user_table": ParameterConstraints(
      sharding_types=[sharding_type.value],
      compute_kernels=[EmbeddingComputeKernel.DENSE.value],
    )
  }
  constraints = {**large_table_constraints, **small_table_constraints}
  return constraints


# In[3]:


def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> None:
    import os
    import torch
    import torch.distributed as dist
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
    from typing import cast
    import torchrec
    from torchrec.distributed.embedding_types import EmbeddingComputeKernel
    from torchrec.distributed.embedding import EmbeddingCollectionSharder
    from typing import List
    
    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device="cuda")
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
    print(plan)
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    mb = torchrec.KeyedJaggedTensor(
        keys = ["product", "user"],
        values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),
        lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),
    )
    print(sharded_model(mb)['product'].to_padded_dense().sum())
    (sharded_model(mb)['product'].to_padded_dense().sum() + sharded_model(mb)['user'].to_padded_dense().sum()).backward()
    print([p.grad.sum() for p in sharded_model.parameters()])
    print(f"rank:{rank},sharding plan: {plan}")
    return sharded_model


# In[4]:


import multiprocess

def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size = 2,
):
  ctx = multiprocess.get_context("spawn")
  processes = []
  for rank in range(world_size):
      p = ctx.Process(
          target=single_rank_execution,
          args=(
              rank,
              world_size,
              gen_constraints(sharding_type),
              ebc,
              "nccl"
          ),
      )
      p.start()
      processes.append(p)

  for p in processes:
      p.join()
      assert 0 == p.exitcode


# In[5]:


ebc = torchrec.EmbeddingCollection(
    device=torch.device("cuda"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        )
    ]
)


# In[6]:


spmd_sharing_simulation(ShardingType.ROW_WISE)

This code eventually results in the error below:

No available compute kernels after applying user provided constraints for product_table
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 53, in invoke_on_rank_and_broadcast_result
    res = func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 217, in plan
    search_space = self._enumerator.enumerate(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/enumerators.py", line 166, in enumerate
    raise RuntimeError(
RuntimeError: No available sharding type and compute kernel combination after applying user provided constraints for product_table
Process SpawnProcess-2:
Traceback (most recent call last):
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 315, in _bootstrap
    self.run()
  File "/data/click-attribution/venv/lib/python3.8/site-packages/multiprocess/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_30517/3096772556.py", line 43, in single_rank_execution
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/planner/planners.py", line 187, in collective_plan
    return invoke_on_rank_and_broadcast_result(
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torchrec/distributed/collective_utils.py", line 58, in invoke_on_rank_and_broadcast_result
    dist.broadcast_object_list(object_list, rank, group=pg)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2603, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/data/click-attribution/venv/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1906, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: [1] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Connection reset by peer. This may indicate a possible application crash on rank 0 or a network set up issue.
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 spmd_sharing_simulation(ShardingType.ROW_WISE)

Cell In[4], line 25, in spmd_sharing_simulation(sharding_type, world_size)
     23 for p in processes:
     24     p.join()
---> 25     assert 0 == p.exitcode

AssertionError: 

@henrylhtsang
Contributor

@Ye-Tian-Zero Okay, you are right, we did ban that combo. Can you try allowing DENSE here? https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L415
