AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices' when using row-wise sharding #1890

Open
jiannanWang opened this issue Apr 17, 2024 · 2 comments

@jiannanWang

Description

I’m using torch.compile with DistributedModelParallel. Running the code below results in AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'. Note that this seems to happen only when using row-wise sharding. I would expect the code to run without errors.

Environment:

python 3.11.8, torch 2.2.2+cu121, torchrec 0.6.0+cu121
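
The failure seems specific to the sharding-type string passed to create_test_sharder at the bottom of the repro below; that call is the only place a sharding strategy is selected. A sketch of the comparison runs, where only that string changes (the alternative values are assumptions based on torchrec's ShardingType enum; only the row-wise case reproduced the error for me):

# Sketch: only the sharding-type string changes between runs.
# Alternative values are taken from torchrec.distributed.types.ShardingType.
sharders = [create_test_sharder("embedding_bag_collection", "row_wise", "dense")]        # raises the AttributeError
# sharders = [create_test_sharder("embedding_bag_collection", "table_wise", "dense")]    # did not hit the error in my runs
# sharders = [create_test_sharder("embedding_bag_collection", "data_parallel", "dense")]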

Reproduction code:

import os
from typing import Callable, List, Union, Tuple
import multiprocessing

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61

        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables, 
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
        # over layer
        in_features_concat = (
            self.dense_out_feature
            + sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
            + sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)

    def forward(
        self,
        input: ModelInput,
        print_intermediate_layer: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # concat dense, sparse, weighted sparse layer output
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))

        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled: bool = True,
) -> None:

    with MultiProcessContext(rank, world_size, backend) as ctx:
        
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)

            
        planner = EmbeddingShardingPlanner(
            topology=Topology(
                world_size, ctx.device.type
            ),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)

        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )

        # Run a single training step of the sharded model.
        local_input = inputs[0][1][rank].to(ctx.device)
        local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)

        # record the local prediction
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

        # record the local model's layer output
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)

        # print(sparse_r.to_dict())
        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)

        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
            dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)

        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        dist.all_gather(all_over_r, over_r, group=ctx.pg)


def setUp():
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"

    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    model = TestModel()
    inputs = [ModelInput.generate(
        batch_size=1200,
        world_size=world_size,
        num_float_features=model.dense_in_feature,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
    )]

    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )


if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "row_wise", "dense")]
    backend = "nccl"
    world_size = 2
    main_test(
        sharders = sharders,
        backend = backend,
        world_size = world_size,
        compiled = True,
    )

Log:

The error message is copied below.

Traceback (most recent call last):
File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
  self.run()
File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
  self._target(*self._args, **self._kwargs)
File "/mnt/tests/reproduce_nccl_row_wise.py", line 154, in sharding_single_rank_test
  local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)
                                                               ^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/model_parallel.py", line 288, in forward
  return self._dmp_wrapped_module(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
  else self._run_ddp_forward(*inputs, **kwargs)
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
  return self.module(*inputs, **kwargs)  # type: ignore[index]
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
  return fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/tests/reproduce_nccl_row_wise.py", line 86, in forward
  def forward(
File "/mnt/tests/reproduce_nccl_row_wise.py", line 93, in resume_in_forward
  sparse_r = self.sparse(input.idlist_features)
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/types.py", line 747, in forward
  dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 756, in input_dist
  def input_dist(
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/embeddingbag.py", line 782, in resume_in_input_dist
  awaitables.append(input_dist(features_by_shard))
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  return forward_call(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/distributed/sharding/rw_sharding.py", line 292, in forward
  def forward(
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
  return fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
  return fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
  return compiled_fn(full_args)
         ^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
  return f(*args)
         ^^^^^^^^
File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 259, in runtime_wrapper
  t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute '_dynamo_weak_dynamic_indices'
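
A possible interim workaround (not verified) might be to compile only the purely dense submodules and leave the EmbeddingBagCollection modules in eager mode, so the sharded input_dist path is never traced by Dynamo. A minimal sketch of that change to sharding_single_rank_test:

# Untested sketch: compile the dense layers only, rather than the whole model,
# so the sharded embedding input_dist stays in eager mode.
if compiled:
    model.dense = torch.compile(model.dense)
    model.over = torch.compile(model.over)
local_model = model.to(ctx.device)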

@PaulZhang12 (Contributor)

Dynamo isn't yet supported for every sharding combination in TorchRec; cc @IvanKobzarev to verify.

@jiannanWang (Author)

Thank you for your reply! Could you tell me which sharding types are currently supported with Dynamo?
