RGCNConv has multiple graph breaks #8467

Open
sharlinu opened this issue Nov 28, 2023 · 5 comments · May be fixed by #8783

@sharlinu

sharlinu commented Nov 28, 2023

🚀 RGCNConv barely gets any speed-up with torch_geometric.compile()

I recently applied torch_geometric.compile to my nn.Module, which mostly consists of the torch_geometric.nn.conv.rgcn_conv module. Analysing the graph breaks with torch._dynamo.explain(), I found that this module causes 6 graph breaks:

Graph Count: 7
Graph Break Count: 6
Op Count: 18
Break Reasons:
  Break Reason 1:
    Reason: call_function BuiltinVariable(int) [TensorVariable()] {}
    User Stack:
      <FrameSummary file /tmp/utke_s_pyg/tmpx31ev0wx.py, line 292 in forward>
  Break Reason 2:
    Reason: call_function BuiltinVariable(bool) [TensorVariable()] {}
    User Stack:
      <FrameSummary file /home/utke_s@WMGDS.WMG.WARWICK.AC.UK/miniconda3/envs/MADRL_py39/lib/python3.9/site-packages/torch_geometric/backend.py, line 53 in use_segment_matmul_heuristic>
Ops per Graph:
  Ops 1:
    <built-in method zeros of type object at 0x7fccf6ba7a40>
    <built-in method ones_like of type object at 0x7fccf6ba7a40>
  Ops 2:
    <built-in method tensor of type object at 0x7fccf6ba7a40>
    <built-in method tensor of type object at 0x7fccf6ba7a40>
    <built-in method tensor of type object at 0x7fccf6ba7a40>
    <built-in method tensor of type object at 0x7fccf6ba7a40>
    <built-in function sub>
    <built-in function truediv>
    <built-in function matmul>
    <built-in function ge>
  Ops 3:
    <built-in method zeros of type object at 0x7fccf6ba7a40>
    <built-in function getitem>
    <built-in function getitem>
    <built-in function truediv>
  Ops 4:
    <built-in method zeros of type object at 0x7fccf6ba7a40>
    <built-in function getitem>
    <built-in function getitem>
    <built-in function truediv>
  Ops 5:
  Ops 6:
  Ops 7:

Would it be possible to have a version of this module with fewer graph breaks? Unfortunately, I do not know enough about compilation and graph breaks to understand what causes them, so any help on this would also be greatly appreciated!

Alternatives

No response

Additional context

No response

@akihironitta
Member

Would you mind sharing your code and env details to reproduce?

@akihironitta akihironitta self-assigned this Nov 28, 2023
@sharlinu
Author

Hi, thanks for the quick response, I hope this gives more details:

Installs:
torch==2.1.1
torch-scatter==2.1.2
torch_geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1697446937113/work
torchaudio==2.1.1
torchvision==0.16.1

import torch
from torch import nn
from torch_geometric.nn import RGCNConv, pool

class TestCritic(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 3 input channels, 128 output channels, 6 relation types
        self.gnn_layers = RGCNConv(3, 128, 6)

    def forward(self, geometric_batch):
        # RGCNConv's third argument is edge_type; edge_attr is passed in that position
        embedds = self.gnn_layers(geometric_batch.x, geometric_batch.edge_index, geometric_batch.edge_attr)
        embedds = torch.relu(embedds)
        x = pool.global_max_pool(embedds, batch=geometric_batch.batch, size=1024)
        return x

gd = batch_to_gd(binary_batch, device) # this creates a batch of dim 1024 torch_geometric.data.Data objects of different sizes and 6 relation types
testcritic = TestCritic()
explain_output = torch._dynamo.explain(testcritic)(geometric_batch=gd)

@akihironitta
Member

@sharlinu Thank you, I'll have a look at this tomorrow.

Analysing the graph breaks with torch._dynamo.explain() I have found that this module causes 6 graph breaks:

Just FYI, if you use torch._dynamo.explain, you'll need to manually disable these flags in your script:

torch_geometric.typing.WITH_INDEX_SORT = False
torch_geometric.typing.WITH_TORCH_SCATTER = False

because they're not compatible with torch.compile. If you use torch_geometric.compile(...) instead, these flags are disabled automatically here:

# Disable the usage of external extension packages:
# TODO (matthias) Disable only temporarily
prev_state = {
    'WITH_INDEX_SORT': torch_geometric.typing.WITH_INDEX_SORT,
    'WITH_TORCH_SCATTER': torch_geometric.typing.WITH_TORCH_SCATTER,
}
warnings.filterwarnings('ignore', ".*the 'torch-scatter' package.*")
for key in prev_state.keys():
    setattr(torch_geometric.typing, key, False)
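
The snippet above saves the previous flag values in `prev_state`, and its TODO mentions disabling them only temporarily. One way such a save-and-restore pattern could look is sketched below. This is an illustration only, not PyG's implementation: `temporarily_disabled` is a hypothetical helper, and `fake_typing` is a stand-in object used so that the sketch runs without PyG installed.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_disabled(namespace, *flag_names):
    """Set each named attribute to False and restore the old values on exit."""
    prev_state = {name: getattr(namespace, name) for name in flag_names}
    try:
        for name in flag_names:
            setattr(namespace, name, False)
        yield
    finally:
        # Restore the saved values even if the body raised an exception.
        for name, value in prev_state.items():
            setattr(namespace, name, value)


# Stand-in for torch_geometric.typing (assumption made for illustration).
fake_typing = SimpleNamespace(WITH_INDEX_SORT=True, WITH_TORCH_SCATTER=True)

with temporarily_disabled(fake_typing, "WITH_INDEX_SORT", "WITH_TORCH_SCATTER"):
    inside = (fake_typing.WITH_INDEX_SORT, fake_typing.WITH_TORCH_SCATTER)

after = (fake_typing.WITH_INDEX_SORT, fake_typing.WITH_TORCH_SCATTER)
print(inside, after)
```

With the real module, the same helper would presumably be called as `temporarily_disabled(torch_geometric.typing, 'WITH_INDEX_SORT', 'WITH_TORCH_SCATTER')` around the `torch._dynamo.explain` call.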

@akihironitta
Member

For the record, users will no longer need to disable the flags manually as of the next release, 2.5.0, thanks to #8698.

@akihironitta akihironitta linked a pull request Jan 17, 2024 that will close this issue
@akihironitta
Member

A quick fix with torch-geometric<=2.4.0 is to run this at the top of the file:

import torch_geometric

torch_geometric.backend.use_segment_matmul = False

In 2.5.0 and later versions, you should see no graph breaks in RGCNConv.

@sharlinu Please let me know if this works for you.
