
ValueError: Tensors must be contiguous when running a specific model in DistributedModelParallel with world size 2 #1889

Open
jiannanWang opened this issue Apr 17, 2024 · 1 comment

Description

I'm using torch.compile together with DistributedModelParallel. Running the code below fails with ValueError: Tensors must be contiguous. The error appears to be specific to this model and to a world size of 2: with other world sizes the same code runs without error, which is what I would expect here as well.

Environment:

python 3.11.8, torch 2.2.2+cu121, torchrec 0.6.0+cu121
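
For context, the NCCL backend rejects non-contiguous tensors outright, so any collective that receives a non-contiguous input fails with this exact message. A minimal standalone sketch (illustrative only; assumes a NCCL process group is already initialized and a CUDA device is set on each rank):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has already run on every rank.
x = torch.randn(4, 8, device="cuda").t()  # .t() returns a non-contiguous 8x4 view
out = [torch.empty(8, 4, device="cuda") for _ in range(dist.get_world_size())]

# dist.all_gather(out, x)              # raises ValueError: Tensors must be contiguous
dist.all_gather(out, x.contiguous())   # succeeds once the input is made contiguous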

Reproduction code:

import os
from typing import Callable, List, Tuple
import multiprocessing

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 820
        self.dense_out_feature = 784
        self.table_params = [
            [311, 108],
            [739, 408],
        ]
        self.weighted_table_params = [
            [159, 96],
            [69, 24],
            [412, 564],
            [940, 300],
        ]
        self.over_out_feature = 61

        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables, 
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
        # over layer
        in_features_concat = (
            self.dense_out_feature
            + sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
            + sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)

    def forward(
        self,
        input: ModelInput,
        print_intermediate_layer: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, KeyedTensor, KeyedTensor, torch.Tensor]]:
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # concat dense, sparse, weighted sparse layer output
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))

        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model: nn.Module,
    inputs: List[Tuple[ModelInput, List[ModelInput]]],
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    compiled: bool = True,
) -> None:
    with MultiProcessContext(rank, world_size, backend) as ctx:
        # Optionally wrap the model with torch.compile before sharding it.
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)

        planner = EmbeddingShardingPlanner(
            topology=Topology(
                world_size, ctx.device.type
            ),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)

        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )

        # Run a single forward pass of the sharded model on this rank's shard of
        # the input (ModelInput.generate returns (global_input, per_rank_inputs)).
        local_input = inputs[0][1][rank].to(ctx.device)

        with torch.no_grad():
            local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)

        # gather every rank's local prediction
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

        # gather each rank's intermediate layer outputs
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)

        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)

        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
            dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)

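        # NOTE: unlike the sparse outputs above, over_r is gathered without
        # .contiguous(); per the traceback, this is the call that fails.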
        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        dist.all_gather(all_over_r, over_r, group=ctx.pg)


def setUp():
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"

    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
    world_size: int,
    compiled: bool,
) -> None:
    model = TestModel()
    inputs = [ModelInput.generate(
        batch_size=1200,
        world_size=world_size,
        num_float_features=model.dense_in_feature,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
    )]

    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )


if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "column_wise", "dense")]
    backend = "nccl"
    world_size = 2
    main_test(
        sharders=sharders,
        backend=backend,
        world_size=world_size,
        compiled=True,
    )
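
To localize the non-contiguous tensor, one option (a hypothetical diagnostic, not part of the original repro) is to check each forward output inside sharding_single_rank_test, right after the no_grad forward pass:

# Hypothetical diagnostic: report which outputs of the compiled forward come
# back non-contiguous; place this directly after local_model(local_input).
for name, t in [("pred", local_pred), ("dense_r", dense_r), ("over_r", over_r)]:
    print(f"rank {rank}: {name}.is_contiguous() = {t.is_contiguous()}")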

Log:

The error message is copied below.

Traceback (most recent call last):
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/envs/torchrec/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/mnt/tests/reproduce_nccl_tensor_must_be_contiguous.py", line 203, in sharding_single_rank_test
    dist.all_gather(all_over_r, over_r, group=ctx.pg)
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2617, in all_gather
    work = group.allgather([tensor_list], [tensor])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Tensors must be contiguous
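
Given that the repro already calls .contiguous() on the sparse and weighted-sparse outputs before gathering them, a plausible workaround (an assumption on my part, not a confirmed fix) is to do the same for over_r, whose gather is the failing frame in the traceback:

# Hypothetical workaround: force contiguity before handing the compiled model's
# output to the NCCL collective, mirroring the existing .contiguous() calls.
dist.all_gather(all_over_r, over_r.contiguous(), group=ctx.pg)

Even if that avoids the crash, the underlying question of why torch.compile produces a non-contiguous over_r only at world size 2 would remain open.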
colin2328 (Contributor) commented:

cc @IvanKobzarev
