Long running time when using torch.compile with DistributedModelParallel and dynamo errors #1888

Open
jiannanWang opened this issue Apr 17, 2024 · 1 comment

Description

I’m using torch.compile with DistributedModelParallel. Since torch.compile can speed up PyTorch distributed models, I expected faster inference. Instead, the test below, which runs a single forward pass on each of two ranks, takes about 50 seconds end to end with torch.compile versus about 4.7 seconds without it. The compiled run also prints many dynamo guard-creation errors and "Adding a graph break" warnings. I expected it to run without errors and to be faster.
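torch.compile compiles lazily on the first call, so a single end-to-end measurement mixes one-off compilation cost with per-call execution time. A minimal sketch for separating the two, using only the standard library; `time_calls` is a hypothetical helper, not part of torch or torchrec:

```python
import statistics
import time
from typing import Callable, List, Tuple


def time_calls(fn: Callable[[], object], n_steady: int = 5) -> Tuple[float, float]:
    """Return (first_call_seconds, median_steady_state_seconds) for fn.

    The first call is timed on its own because, for a torch.compile'd module,
    it includes tracing and compilation. Later calls reuse compiled code
    unless a guard fails and triggers a recompile.
    """
    if n_steady < 1:
        raise ValueError("n_steady must be at least 1")

    start = time.perf_counter()
    fn()
    first = time.perf_counter() - start

    samples: List[float] = []
    for _ in range(n_steady):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return first, statistics.median(samples)
```

Inside `sharding_single_rank_test` it could be called as `time_calls(lambda: local_model(local_input))`. On CUDA, kernels run asynchronously, so the callable should also call `torch.cuda.synchronize()` before returning for the numbers to be meaningful. If the steady-state time is also slow, recompilation would be the next thing to look at rather than first-call compilation.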

Environment:

Python 3.11.8, torch 2.2.2+cu121, torchrec 0.6.0+cu121.
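To make the environment line easy to regenerate, a small stdlib sketch that looks up installed distribution versions with `importlib.metadata` (the helper name `package_versions` is hypothetical). PyTorch also ships `python -m torch.utils.collect_env`, which prints a fuller report including CUDA and driver details.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for pkg, ver in package_versions(["torch", "torchrec"]).items():
        print(pkg, ver)
```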

Reproduction code:

import os
import time
from typing import Callable, List, Union, Tuple
import multiprocessing

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.multi_process import MultiProcessContext
from torchrec.distributed.test_utils.test_sharding import create_test_sharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.test_utils import get_free_port

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define model parameters
        self.dense_in_feature = 914
        self.dense_out_feature = 930
        self.table_params = [
            [777, 912],
            [431, 44],
            [266, 992],
            [524, 500],
        ]
        self.weighted_table_params = [
            [941, 804],
            [850, 312],
            [992, 492],
            [367, 600],
        ]
        self.over_out_feature = 224

        # sparse layer
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=self.table_params[i][0],
                embedding_dim=self.table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
            )
            for i in range(len(self.table_params))
        ]
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        # weighted sparse layer
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=self.weighted_table_params[i][0],
                embedding_dim=self.weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(self.weighted_table_params))
        ]
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables, 
            is_weighted=True,
        )
        # dense layer
        self.dense = nn.Linear(in_features=self.dense_in_feature, out_features=self.dense_out_feature, bias=True)
        # over layer
        in_features_concat = (
            self.dense_out_feature
            + sum([table.embedding_dim * len(table.feature_names) for table in self.tables])
            + sum([table.embedding_dim * len(table.feature_names) for table in self.weighted_tables])
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=self.over_out_feature, bias=True)

    def forward(
        self,
        input: ModelInput,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # dense, sparse, weighted sparse layer output
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        # concat dense, sparse, weighted sparse layer output
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [feature for table in self.tables for feature in table.feature_names]
        _weighted_features = [feature for table in self.weighted_tables for feature in table.feature_names]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        # over layer output
        over_r = self.over(ret_concat)
        # sigmoid output
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        # return model's output and layers' output
        return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    rank: int,
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    backend: str,
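    compiled: bool = True,
) -> None:

    with MultiProcessContext(rank, world_size, backend) as ctx:
        # torch.compile is applied here, before sharding, so the compiled module
        # wraps the model that DistributedModelParallel shards below.
        if compiled:
            model = torch.compile(model)
        local_model = model.to(ctx.device)
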
        planner = EmbeddingShardingPlanner(
            topology=Topology(
                world_size, ctx.device.type
            ),
        )
        plan: ShardingPlan = planner.collective_plan(local_model, sharders, ctx.pg)

        local_model = DistributedModelParallel(
            local_model,
            env=ShardingEnv.from_process_group(ctx.pg),
            plan=plan,
            sharders=sharders,
            device=ctx.device,
        )

        # Run a single forward pass of the sharded model (no backward or optimizer step).
        local_input = inputs[0][1][rank].to(ctx.device)
        local_pred, (dense_r, sparse_r, sparse_weighted_r, over_r) = local_model(local_input)

        # record the local prediction
        all_local_pred = []
        for _ in range(world_size):
            all_local_pred.append(torch.empty_like(local_pred))
        dist.all_gather(all_local_pred, local_pred, group=ctx.pg)

        # record the local model's layer output
        all_dense_r = []
        for _ in range(world_size):
            all_dense_r.append(torch.empty_like(dense_r))
        dist.all_gather(all_dense_r, dense_r, group=ctx.pg)

        sparse_r_dict = sparse_r.to_dict()
        all_sparse_r_dict = {}
        for key in sparse_r_dict:
            all_sparse_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_r_dict[key].append(torch.empty_like(sparse_r_dict[key]))
            dist.all_gather(all_sparse_r_dict[key], sparse_r_dict[key].contiguous(), group=ctx.pg)

        sparse_weighted_r_dict = sparse_weighted_r.to_dict()
        all_sparse_weighted_r_dict = {}
        for key in sparse_weighted_r_dict:
            all_sparse_weighted_r_dict[key] = []
            for _ in range(world_size):
                all_sparse_weighted_r_dict[key].append(torch.empty_like(sparse_weighted_r_dict[key]))
            dist.all_gather(all_sparse_weighted_r_dict[key], sparse_weighted_r_dict[key].contiguous(), group=ctx.pg)

        all_over_r = []
        for _ in range(world_size):
            all_over_r.append(torch.empty_like(over_r))
        dist.all_gather(all_over_r, over_r, group=ctx.pg)


def setUp():
    os.environ["MASTER_ADDR"] = str("localhost")
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
    os.environ["NCCL_SOCKET_IFNAME"] = "lo"

    torch.use_deterministic_algorithms(True)
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def run_multi_process_test(
    callable: Callable[
        ...,
        None,
    ],
    world_size: int,
    # pyre-ignore
    **kwargs,
) -> None:
    setUp()
    ctx = multiprocessing.get_context("forkserver")
    processes = []
    for rank in range(world_size):
        kwargs["rank"] = rank
        kwargs["world_size"] = world_size
        p = ctx.Process(
            target=callable,
            kwargs=kwargs,
        )
        p.start()
        processes.append(p)
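    for p in processes:
        p.join()
        # Surface crashes in the worker processes instead of silently ignoring them.
        if p.exitcode != 0:
            raise RuntimeError(f"Worker process exited with code {p.exitcode}")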


def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    backend: str = "gloo",
    world_size: int = 2,
    compiled: bool = True,
) -> None:
    model = TestModel()
    inputs = [ModelInput.generate(
        batch_size=2400,
        world_size=world_size,
        num_float_features=model.dense_in_feature,
        tables=model.tables,
        weighted_tables=model.weighted_tables,
    )]

    run_multi_process_test(
        callable=sharding_single_rank_test,
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        backend=backend,
        compiled=compiled,
    )


if __name__ == "__main__":
    sharders = [create_test_sharder("embedding_bag_collection", "table_wise", "dense")]
    backend = "gloo"
    world_size = 2
    # Time the full run (process startup, planning, sharding, forward) with and without compile.
    for compiled in (True, False):
        start_time = time.time()
        main_test(
            sharders=sharders,
            backend=backend,
            world_size=world_size,
            compiled=compiled,
        )
        test_time = time.time() - start_time
        print("Test time: ", test_time)
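The model's forward concatenates the two KeyedTensor outputs along dim 1 and then slices each feature back out by key. A pure-Python sketch of the column ranges this produces, which also checks the input width of `self.over`. It assumes the embedding-bag outputs are keyed by feature name in table order and that `KeyedTensor` indexing narrows `values` at cumulative `length_per_key` offsets; `column_slices` is a hypothetical helper:

```python
from itertools import accumulate
from typing import Dict, List, Tuple

# Sizes copied from TestModel: dense_out_feature and the embedding_dim column
# of table_params / weighted_table_params.
DENSE_OUT = 930
UNWEIGHTED_DIMS = [912, 44, 992, 500]
WEIGHTED_DIMS = [804, 312, 492, 600]


def column_slices(keys: List[str], length_per_key: List[int]) -> Dict[str, Tuple[int, int]]:
    """Return the [start, end) column range of each key in the concatenated values."""
    if len(keys) != len(length_per_key):
        raise ValueError("keys and length_per_key must have the same length")
    offsets = [0, *accumulate(length_per_key)]  # offsets[i] is the start column of keys[i]
    return {key: (offsets[i], offsets[i + 1]) for i, key in enumerate(keys)}


# Assumed to match the order of sparse_r.keys() + sparse_weighted_r.keys().
keys = [f"feature_{i}" for i in range(4)] + [f"weighted_feature_{i}" for i in range(4)]
slices = column_slices(keys, UNWEIGHTED_DIMS + WEIGHTED_DIMS)

# Width of ret_concat: dense output plus every embedding slice.
over_in_features = DENSE_OUT + sum(UNWEIGHTED_DIMS) + sum(WEIGHTED_DIMS)

print(slices["weighted_feature_0"])  # (2448, 3252)
print(over_in_features)  # 5586
```

So `in_features_concat` evaluates to 930 + 2448 + 2208 = 5586.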

Log:

The full log is copied below.

Test time:  50.37300157546997
Test time:  4.749462366104126
torchrec/sparse/jagged_tensor.py:588: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:1993: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/utils/_pytree.py:223: UserWarning: to_str_fn and maybe_from_str_fn is deprecated. Please use to_dumpable_context and from_dumpable_context instead.
  warnings.warn(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:2199: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:588: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:1993: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/utils/_pytree.py:223: UserWarning: to_str_fn and maybe_from_str_fn is deprecated. Please use to_dumpable_context and from_dumpable_context instead.
  warnings.warn(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:2199: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:588: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:1993: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/utils/_pytree.py:223: UserWarning: to_str_fn and maybe_from_str_fn is deprecated. Please use to_dumpable_context and from_dumpable_context instead.
  warnings.warn(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:2199: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR] Error while creating guard:
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Source: local
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140166366385488)"]
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f7aa89225c0; to 'TwSparseFeaturesDist' at 0x7f7b067a5550>
[rank1]:[2024-04-17 17:26:36,236] [4/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a1bba13f0; to 'ABCMeta' at 0x89cf9b0 (TwSparseFeaturesDist)>
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR] Created at:
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank1]:[2024-04-17 17:26:36,238] [4/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR] Error while creating guard:
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Source: local
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140166234582928)"]
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f79f44a6480; to 'TwSparseFeaturesDist' at 0x7f7afe9f2f90>
[rank0]:[2024-04-17 17:26:36,239] [4/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a1bba14e0; to 'ABCMeta' at 0x89cf9b0 (TwSparseFeaturesDist)>
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR] Created at:
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank0]:[2024-04-17 17:26:36,241] [4/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR] Error while creating guard:
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Source: local
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140164790470480)"]
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f7aa8750770; to 'KJTAllToAll' at 0x7f7aa88bc750>
[rank1]:[2024-04-17 17:26:36,253] [5/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a1bb85120; to 'type' at 0x89b4410 (KJTAllToAll)>
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR] Created at:
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank1]:[2024-04-17 17:26:36,254] [5/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR] Error while creating guard:
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Source: local
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140161764738512)"]
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f79f4592020; to 'KJTAllToAll' at 0x7f79f432c5d0>
[rank0]:[2024-04-17 17:26:36,257] [5/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a7ccb4db0; to 'type' at 0x89b4410 (KJTAllToAll)>
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR] Created at:
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank0]:[2024-04-17 17:26:36,258] [5/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank1]:[2024-04-17 17:26:36,866] [6/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:26:36,966] [6/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:37,137] [7/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:26:37,137] [7/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1668, in permute
[rank1]:[2024-04-17 17:26:37,137] [7/0] torch._dynamo.exc: [WARNING]     return kjt
[rank1]:[2024-04-17 17:26:37,137] [7/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:26:37,252] [7/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:26:37,252] [7/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1668, in permute
[rank0]:[2024-04-17 17:26:37,252] [7/0] torch._dynamo.exc: [WARNING]     return kjt
[rank0]:[2024-04-17 17:26:37,252] [7/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:26:37,296] [7/0_1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:26:37,296] [7/0_1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1668, in permute
[rank1]:[2024-04-17 17:26:37,296] [7/0_1] torch._dynamo.exc: [WARNING]     return kjt
[rank1]:[2024-04-17 17:26:37,296] [7/0_1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:26:37,339] [8/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:26:37,339] [8/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 682, in _maybe_compute_length_per_key
[rank1]:[2024-04-17 17:26:37,339] [8/0] torch._dynamo.exc: [WARNING]     return length_per_key
[rank1]:[2024-04-17 17:26:37,339] [8/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:26:37,378] [8/0_1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:26:37,378] [8/0_1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 682, in _maybe_compute_length_per_key
[rank1]:[2024-04-17 17:26:37,378] [8/0_1] torch._dynamo.exc: [WARNING]     return length_per_key
[rank1]:[2024-04-17 17:26:37,378] [8/0_1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:26:37,402] [7/0_1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:26:37,402] [7/0_1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1668, in permute
[rank0]:[2024-04-17 17:26:37,402] [7/0_1] torch._dynamo.exc: [WARNING]     return kjt
[rank0]:[2024-04-17 17:26:37,402] [7/0_1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:26:37,446] [8/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:26:37,446] [8/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 682, in _maybe_compute_length_per_key
[rank0]:[2024-04-17 17:26:37,446] [8/0] torch._dynamo.exc: [WARNING]     return length_per_key
[rank0]:[2024-04-17 17:26:37,446] [8/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:26:37,500] [8/0_1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:26:37,500] [8/0_1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 682, in _maybe_compute_length_per_key
[rank0]:[2024-04-17 17:26:37,500] [8/0_1] torch._dynamo.exc: [WARNING]     return length_per_key
[rank0]:[2024-04-17 17:26:37,500] [8/0_1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:26:37,669] [15/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:37,790] [16/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:26:37,805] [15/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:26:37,934] [16/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:38,164] [17/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:26:38,345] [17/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:51,401] [18/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:51,425] [18/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:53,969] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:53,997] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:26:54,003] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,009] [18/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,163] [18/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,588] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:00,599] [23/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,617] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,622] [22/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:00,625] [23/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:00,639] [24/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,657] [23/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,690] [23/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:00,703] [24/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:02,552] [25/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:02,561] [25/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:02,677] [25/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:02,685] [25/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:02,689] [26/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:02,715] [26/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:02,773] [26/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:27:02,773] [26/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1957, in dist_init
[rank1]:[2024-04-17 17:27:02,773] [26/0] torch._dynamo.exc: [WARNING]     return kjt.sync()
[rank1]:[2024-04-17 17:27:02,773] [26/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:27:02,786] [26/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:02,814] [26/0] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:27:02,814] [26/0] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1957, in dist_init
[rank0]:[2024-04-17 17:27:02,814] [26/0] torch._dynamo.exc: [WARNING]     return kjt.sync()
[rank0]:[2024-04-17 17:27:02,814] [26/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:27:02,825] [26/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR] Error while creating guard:
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Source: local
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140164707244240)"]
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f7aa362a2f0; to 'PooledEmbeddingsAllToAll' at 0x7f7aa395d8d0>
[rank1]:[2024-04-17 17:27:05,691] [37/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a1bb853a0; to 'type' at 0x733d410 (PooledEmbeddingsAllToAll)>
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR] Created at:
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank1]:[2024-04-17 17:27:05,693] [37/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR] Error while creating guard:
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR] Name: "L['self']"
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Source: local
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Create Function: NN_MODULE
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Guard Types: ['ID_MATCH']
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Code List: ["___check_obj_id(L['self'], 140160646786576)"]
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Object Weakref: <weakref at 0x7f79b167d0d0; to 'PooledEmbeddingsAllToAll' at 0x7f79b1902e10>
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR]     Guarded Class Weakref: <weakref at 0x7f7a1bb81490; to 'type' at 0x733d410 (PooledEmbeddingsAllToAll)>
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR] Created at:
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 245, in __call__
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]     vt = self._wrap(value).clone(**self.options())
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 469, in _wrap
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]     return self.wrap_module(value)
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 937, in wrap_module
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]     return self.tx.output.register_attr_or_module(
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 736, in register_attr_or_module
[rank0]:[2024-04-17 17:27:05,790] [37/0_1] torch._guards: [ERROR]     install_guard(source.make_guard(GuardBuilder.NN_MODULE))
[rank1]:[2024-04-17 17:27:05,988] [44/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 224, in speculate_subgraph
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     args = validate_args_and_maybe_create_graph_inputs(
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 160, in validate_args_and_maybe_create_graph_inputs
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise unimplemented(
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]           ^^^^^^^^^^^^^^
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[rank1]:[2024-04-17 17:27:05,989] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank1]:[2024-04-17 17:27:06,038] [45/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:06,038] [44/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 224, in speculate_subgraph
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     args = validate_args_and_maybe_create_graph_inputs(
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 160, in validate_args_and_maybe_create_graph_inputs
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise unimplemented(
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]           ^^^^^^^^^^^^^^
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[rank0]:[2024-04-17 17:27:06,039] [44/0] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank1]:[2024-04-17 17:27:06,061] [45/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:06,266] [45/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:06,289] [45/0_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:06,967] [16/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:06,997] [16/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:07,190] [17/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:07,192] [17/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:09,172] [18/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:09,195] [18/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:09,217] [18/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:09,238] [18/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:11,255] [24/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:11,255] [24/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:11,380] [26/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:11,380] [26/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:11,806] [26/1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank0]:[2024-04-17 17:27:11,806] [26/1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1957, in dist_init
[rank0]:[2024-04-17 17:27:11,806] [26/1] torch._dynamo.exc: [WARNING]     return kjt.sync()
[rank0]:[2024-04-17 17:27:11,806] [26/1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:27:11,821] [26/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING] Backend compiler failed with a fake tensor exception at 
[rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py", line 1957, in dist_init
[rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING]     return kjt.sync()
[rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank1]:[2024-04-17 17:27:11,966] [26/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 224, in speculate_subgraph
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     args = validate_args_and_maybe_create_graph_inputs(
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 160, in validate_args_and_maybe_create_graph_inputs
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     raise unimplemented(
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]           ^^^^^^^^^^^^^^
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank1]:[2024-04-17 17:27:14,390] [44/1] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] Traceback (most recent call last):
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 224, in speculate_subgraph
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     args = validate_args_and_maybe_create_graph_inputs(
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 160, in validate_args_and_maybe_create_graph_inputs
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     raise unimplemented(
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]           ^^^^^^^^^^^^^^
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]   File "/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 193, in unimplemented
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR]     raise Unsupported(msg)
[rank1]:[2024-04-17 17:27:14,391] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
[rank0]:[2024-04-17 17:27:14,474] [45/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:14,536] [45/1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:[2024-04-17 17:27:14,558] [45/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank1]:[2024-04-17 17:27:14,583] [45/1_1] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:588: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:1993: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/utils/_pytree.py:223: UserWarning: to_str_fn and maybe_from_str_fn is deprecated. Please use to_dumpable_context and from_dumpable_context instead.
  warnings.warn(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:2199: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:588: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:1993: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torch/utils/_pytree.py:223: UserWarning: to_str_fn and maybe_from_str_fn is deprecated. Please use to_dumpable_context and from_dumpable_context instead.
  warnings.warn(
/root/miniconda3/envs/torchrec/lib/python3.11/site-packages/torchrec/sparse/jagged_tensor.py:2199: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
WARNING:torchrec.distributed.utils:Compute Kernel is dense, caching params will be ignored
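Each Dynamo line above carries a bracketed tag such as `[26/1]` or `[37/0_1]`. One reading of this tag, treated here as an assumption, is `[frame_id/compile_count]`, with an optional `_attempt` suffix when the trace was restarted. Under that reading, frames 16, 17, 18, 24, 26, 44 and 45 are compiled a second time during the run. Frames 26 and 44 also hit the same graph break (`kjt.sync()`) and the same `Unsupported` error (`ProcessGroup` passed into an `autograd.Function`) on both ranks, each time they are compiled.

The helper `summarize` below is hypothetical plain Python and is not part of PyTorch or TorchRec. It is a minimal sketch for tallying these events per rank from a saved log, and it assumes every Dynamo line follows the format shown above.

```python
import re

# Assumed line shape, copied from the log above:
# [rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING] Adding a graph break.
LINE_RE = re.compile(
    r"^\[rank(?P<rank>\d+)\]:\[[^\]]+\] "             # rank and timestamp
    r"\[(?P<frame>\d+)/(?P<compile>\d+)(?:_\d+)?\] "   # [frame/compile], optional _attempt
    r"(?P<logger>\S+): \[(?P<level>[A-Z]+)\] (?P<msg>.*)$"
)


def summarize(lines):
    """Hypothetical helper: per-rank counts of Dynamo events in a log."""
    stats = {}
    for line in lines:
        m = LINE_RE.match(line)
        if m is None:
            continue  # not a Dynamo line (UserWarnings, torchrec warnings, ...)
        s = stats.setdefault(
            int(m["rank"]),
            {"graph_breaks": 0, "recompiled_frames": set(),
             "unsupported": 0, "guard_errors": 0},
        )
        if int(m["compile"]) > 0:  # compile count > 0: this frame was compiled again
            s["recompiled_frames"].add(int(m["frame"]))
        msg = m["msg"]
        if msg == "Adding a graph break.":
            s["graph_breaks"] += 1
        elif msg.startswith("torch._dynamo.exc.Unsupported:"):
            s["unsupported"] += 1
        elif msg == "Error while creating guard:":
            s["guard_errors"] += 1
    for s in stats.values():
        s["recompiled_frames"] = sorted(s["recompiled_frames"])
    return stats


# A few lines taken verbatim from the log above, plus one non-Dynamo line.
sample = """\
[rank0]:[2024-04-17 17:27:02,814] [26/0] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:27:05,788] [37/0_1] torch._guards: [ERROR] Error while creating guard:
[rank0]:[2024-04-17 17:27:11,806] [26/1] torch._dynamo.exc: [WARNING] Adding a graph break.
[rank0]:[2024-04-17 17:27:14,361] [44/1] torch._dynamo.variables.higher_order_ops: [ERROR] torch._dynamo.exc.Unsupported: HigherOrderOperator with body that accepts non-Tensors as input. Got: <class 'torch.distributed.distributed_c10d.ProcessGroup'>
Compute Kernel is dense, caching params will be ignored
[rank1]:[2024-04-17 17:27:11,950] [26/1] torch._dynamo.exc: [WARNING] Adding a graph break.
""".splitlines()

print(summarize(sample))
# {0: {'graph_breaks': 2, 'recompiled_frames': [26, 44], 'unsupported': 1, 'guard_errors': 1}, 1: {'graph_breaks': 1, 'recompiled_frames': [26], 'unsupported': 0, 'guard_errors': 0}}
```

Lines that do not match the assumed format, such as the `UserWarning` and `Compute Kernel is dense` messages, are skipped rather than counted.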
@PaulZhang12
Contributor

I don't believe torch.compile and TorchRec/DMP are fully compatible yet, @IvanKobzarev
