Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dynamo] Fix test #125107

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 26 additions & 11 deletions test/dynamo/test_misc.py
Expand Up @@ -10144,21 +10144,36 @@ def test_linear_module_free(self):
def test_outside_linear_module_free(self):
# Compared to test_linear_module_free, the linear
# layer is not the code object that is directly compiled.
def model_inp_ctr():
fc = torch.nn.Linear(100, 100)

class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc_ref = fc
# This test does not use _test_compile_model_free because of difficulty
# in handling variable fc.

def forward(self, x):
return fc(x[0])
fc = torch.nn.Linear(100, 100)

class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc_ref = fc

def forward(self, x):
return fc(x[0])

# return fc to keep it alive in _test_compile_model_free
return Mod(), (torch.randn(100, 100), fc)
cleared = False

def finalize():
nonlocal cleared
cleared = True

self._test_compile_model_free(model_inp_ctr, lambda mod: mod.fc_ref)
def run():
mod = Mod()
inp = torch.randn(100, 100)
weakref.finalize(mod.fc_ref, finalize)
torch.compile(mod, backend="eager")(inp)

run()
del fc # This should delete all the references
gc.collect()
self.assertTrue(cleared)

@unittest.skipIf(sys.version_info >= (3, 12), "leaks in 3.12+")
def test_parameter_free(self):
Expand Down