Skip to content

Commit

Permalink
[dynamo] Bug fix for LOAD_GLOBAL and STORE_GLOBAL (pytorch#125002)
Browse files Browse the repository at this point in the history
Earlier globals of inlined functions from other files were not handled correctly. We were not tracking mutations on them. They were colliding with the same global name in the parent function etc. This PR overrides the LOAD/STORE_GLOBAL for inline tx and tracks mutation on them separately.

Pull Request resolved: pytorch#125002
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#125097, pytorch#125107
  • Loading branch information
anijain2305 authored and andoorve committed May 1, 2024
1 parent b34e033 commit 36f0f85
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 19 deletions.
10 changes: 10 additions & 0 deletions test/dynamo/test_functions.py
Expand Up @@ -54,6 +54,16 @@ def constant3(a, b):
return a - b + (1.0 + 2)


_variable = 0


def update_global(x):
global _variable
_variable += 1
# Check that updated global variable value is picked up
return x * _variable


def func_with_default(a, b, some_default_arg=True):
if some_default_arg:
return a - b
Expand Down
26 changes: 26 additions & 0 deletions test/dynamo/test_modules.py
Expand Up @@ -28,6 +28,16 @@
import test_functions


_variable = 0
_variable1 = 0


def update_global():
global _variable, _variable1
_variable += 1
_variable1 += 1


class BasicModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -2435,6 +2445,22 @@ def forward(self, inp):

self.assertEqual(model.x, compiled_model.x)

def test_globals_change_in_other_file(self):
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
update_global()
a = test_functions.update_global(x)
# Ensure that the updated global values are read
return x * a * (_variable + _variable1 + test_functions._variable)

res = fn(torch.ones(10))
self.assertEqual(_variable, 1)
self.assertEqual(_variable1, 1)
# Ensure that the reconstructed bytecode updates the global value in the
# other file.
self.assertEqual(test_functions._variable, 1)
self.assertEqual(res, 3 * torch.ones(10))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/side_effects.py
Expand Up @@ -13,7 +13,7 @@
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import LocalSource, Source
from .source import GlobalSource, LocalSource, Source
from .utils import nn_module_new, object_new
from .variables.base import (
is_side_effect_safe,
Expand Down Expand Up @@ -485,6 +485,7 @@ def codegen_update_mutated(self, cg: PyCodegen):
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
assert isinstance(var.mutable_local.source, GlobalSource) # type: ignore[attr-defined]
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
Expand Down
77 changes: 59 additions & 18 deletions torch/_dynamo/symbolic_convert.py
Expand Up @@ -969,22 +969,6 @@ def _load_const(self, inst):
def LOAD_CONST(self, inst):
self.push(self._load_const(inst))

def get_global_source(self, name):
source: Source
if self.output.global_scope is self.f_globals:
source = GlobalSource(name)
else:
if "__name__" in self.f_globals:
source = AttrSource(
self.import_source(self.f_globals["__name__"]), name
)
else:
mangled_name = self.output.install_global_by_id(
"___unnamed_scope", self.f_globals
)
source = GetItemSource(GlobalSource(mangled_name), name)
return source

def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11):
if inst.arg % 2:
Expand Down Expand Up @@ -1012,13 +996,13 @@ def LOAD_GLOBAL(self, inst):
except KeyError:
return self.load_builtin(inst)

source = self.get_global_source(name)
source = GlobalSource(name)
self.push(VariableBuilder(self, source)(value))

def STORE_GLOBAL(self, inst):
value = self.pop()
name = inst.argval
source = self.get_global_source(name)
source = GlobalSource(name)
if name not in self.symbolic_globals:
self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object
variable = self.output.side_effects.track_global_existing(
Expand Down Expand Up @@ -2692,6 +2676,63 @@ def RETURN_CONST(self, inst):
self.instruction_pointer = None
raise ReturnValueOp

def get_globals_source_and_value(self, name):
if "__name__" in self.f_globals:
module_name = self.f_globals["__name__"]
module_source = self.import_source(module_name)
if "torch_package" in module_name:
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
else:
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
global_source = AttrSource(module_source, name)
else:
globals_name = self.output.install_global_by_id(
"___unnamed_scope", self.f_globals
)
globals_source = GlobalSource(globals_name)
fglobals_value = self.f_globals # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
return fglobals_value, fglobals_vt, global_source

def LOAD_GLOBAL(self, inst):
if self.output.global_scope is self.f_globals:
super().LOAD_GLOBAL(inst)
else:
if sys.version_info >= (3, 11):
if inst.arg % 2:
self.PUSH_NULL(inst)

name = inst.argval
if inst.argval == "AssertionError":
unimplemented("assert with non-string message")

_, fglobals_vt, global_source = self.get_globals_source_and_value(name)
if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
self.push(self.output.side_effects.load_attr(fglobals_vt, name))
else:
try:
value = self.f_globals[name]
except KeyError:
return self.load_builtin(inst)

self.push(VariableBuilder(self, global_source)(value))

def STORE_GLOBAL(self, inst):
if self.f_globals is self.parent.f_globals:
super().STORE_GLOBAL(inst)
else:
value = self.pop()
if isinstance(value, RemovableHandleVariable):
unimplemented("Storing handles in globals - NYI")
name = inst.argval
fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
fglobals_vt = self.output.side_effects.track_object_existing(
fglobals_value, fglobals_vt
)
self.output.side_effects.store_attr(fglobals_vt, name, value)


class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
generated_items: List[VariableTracker]
Expand Down

0 comments on commit 36f0f85

Please sign in to comment.