From 6311ba664ca978ef5c5037da723c15b51ff3c202 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 16 Apr 2024 12:17:03 -0700 Subject: [PATCH 01/22] rewrite: implement reveal_type. PiperOrigin-RevId: 625414033 --- pytype/rewrite/load_abstract.py | 1 + pytype/rewrite/overlays/special_builtins.py | 28 ++++++++++++++++++- .../rewrite/overlays/special_builtins_test.py | 12 ++++++++ pytype/rewrite/tests/test_args.py | 1 - 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pytype/rewrite/load_abstract.py b/pytype/rewrite/load_abstract.py index 9006613ba..eca86093c 100644 --- a/pytype/rewrite/load_abstract.py +++ b/pytype/rewrite/load_abstract.py @@ -49,6 +49,7 @@ def __init__(self, ctx: abstract.ContextType, pytd_loader: load_pytd.Loader): self.consts = Constants(ctx) self._special_builtins = { 'assert_type': special_builtins.AssertType(self._ctx), + 'reveal_type': special_builtins.RevealType(self._ctx), } self._special_builtins['NoneType'] = self.consts[None] diff --git a/pytype/rewrite/overlays/special_builtins.py b/pytype/rewrite/overlays/special_builtins.py index 25cd48c8c..b8b942bcc 100644 --- a/pytype/rewrite/overlays/special_builtins.py +++ b/pytype/rewrite/overlays/special_builtins.py @@ -1,8 +1,16 @@ """Builtin values with special behavior.""" +from typing import Optional, Sequence + from pytype.rewrite.abstract import abstract +def _stack( + frame: Optional[abstract.FrameType] +) -> Optional[Sequence[abstract.FrameType]]: + return frame.stack if frame else None + + class AssertType(abstract.SimpleFunction[abstract.SimpleReturn]): """assert_type implementation.""" @@ -24,6 +32,24 @@ def call_with_mapped_args( except ValueError: expected = pp.print_type_of_instance(typ.get_atomic_value()) if actual != expected: - stack = frame.stack if (frame := mapped_args.frame) else None + stack = _stack(mapped_args.frame) self._ctx.errorlog.assert_type(stack, actual, expected) return abstract.SimpleReturn(self._ctx.consts[None]) + + +class RevealType(abstract.SimpleFunction[abstract.SimpleReturn]): + """reveal_type implementation.""" + + def __init__(self, ctx: abstract.ContextType): + signature = abstract.Signature( + ctx=ctx, name='reveal_type', param_names=('object',)) + super().__init__( + ctx=ctx, name='reveal_type', signatures=(signature,), module='builtins') + + def call_with_mapped_args( + self, mapped_args: abstract.MappedArgs[abstract.FrameType], + ) -> abstract.SimpleReturn: + obj = mapped_args.argdict['object'] + stack = _stack(mapped_args.frame) + self._ctx.errorlog.reveal_type(stack, node=None, var=obj) + return abstract.SimpleReturn(self._ctx.consts[None]) diff --git a/pytype/rewrite/overlays/special_builtins_test.py b/pytype/rewrite/overlays/special_builtins_test.py index 0c4228feb..a0d2571e0 100644 --- a/pytype/rewrite/overlays/special_builtins_test.py +++ b/pytype/rewrite/overlays/special_builtins_test.py @@ -14,6 +14,18 @@ def test_types_match(self): typ = abstract.SimpleClass(ctx, 'int', {}).to_variable() ret = assert_type_func.call(abstract.Args(posargs=(var, typ))) self.assertEqual(ret.get_return_value(), ctx.consts[None]) + self.assertEqual(len(ctx.errorlog), 0) # pylint: disable=g-generic-assert + + +class RevealTypeTest(unittest.TestCase): + + def test_basic(self): + ctx = context.Context() + reveal_type_func = special_builtins.RevealType(ctx) + var = ctx.consts[0].to_variable() + ret = reveal_type_func.call(abstract.Args(posargs=(var,))) + self.assertEqual(ret.get_return_value(), ctx.consts[None]) + self.assertEqual(len(ctx.errorlog), 1) if __name__ == '__main__': diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py index 91d9d6688..19ec1df79 100644 --- a/pytype/rewrite/tests/test_args.py +++ b/pytype/rewrite/tests/test_args.py @@ -29,7 +29,6 @@ def f(x, *, y): f(0, y=1) """) - @test_utils.skipBeforePy((3, 11), 'Relies on 3.11+ bytecode') def test_function_varargs(self): self.Check(""" def foo(x: str, *args): From 409ba01bf71a91d6a8ae80199193e56e2495acf3 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 16 Apr 2024 12:17:42 -0700 Subject: [PATCH 02/22] rewrite: skip test failing in lower Python versions. PiperOrigin-RevId: 625414202 --- pytype/rewrite/tests/test_args.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py index 19ec1df79..91d9d6688 100644 --- a/pytype/rewrite/tests/test_args.py +++ b/pytype/rewrite/tests/test_args.py @@ -29,6 +29,7 @@ def f(x, *, y): f(0, y=1) """) + @test_utils.skipBeforePy((3, 11), 'Relies on 3.11+ bytecode') def test_function_varargs(self): self.Check(""" def foo(x: str, *args): From 4382f5dcbbc20750b80863a747b69005175136c6 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 16 Apr 2024 14:20:04 -0700 Subject: [PATCH 03/22] rewrite: support simple class inheritance. Also removes a Python 2 TODO that I happened to spot while looking at our current implementation of __build_class__. PiperOrigin-RevId: 625451422 --- pytype/abstract/_classes.py | 2 -- pytype/rewrite/abstract/classes.py | 9 ++++++-- pytype/rewrite/convert.py | 17 ++++++++++++-- pytype/rewrite/convert_test.py | 8 +++++++ pytype/rewrite/frame.py | 37 +++++++++++++++++++++--------- pytype/rewrite/frame_test.py | 13 +++++++++++ pytype/rewrite/output.py | 3 ++- pytype/rewrite/tests/test_basic.py | 12 ++++++++++ 8 files changed, 83 insertions(+), 18 deletions(-) diff --git a/pytype/abstract/_classes.py b/pytype/abstract/_classes.py index 6dad073ca..fc0147ade 100644 --- a/pytype/abstract/_classes.py +++ b/pytype/abstract/_classes.py @@ -42,8 +42,6 @@ def call(self, node, func, args, alias_map=None): args = args.simplify(node, self.ctx) funcvar, name = args.posargs[0:2] kwargs = args.namedargs - # TODO(mdemello): Check if there are any changes between python2 and - # python3 in the final metaclass computation. # TODO(b/123450483): Any remaining kwargs need to be passed to the # metaclass. metaclass = kwargs.get("metaclass", None) diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py index e71511f9d..bb3f6430b 100644 --- a/pytype/rewrite/abstract/classes.py +++ b/pytype/rewrite/abstract/classes.py @@ -36,11 +36,13 @@ def __init__( ctx: base.ContextType, name: str, members: Dict[str, base.BaseValue], + bases: Sequence['SimpleClass'] = (), module: Optional[str] = None, ): super().__init__(ctx) self.name = name self.members = members + self.bases = bases self.module = module self._canonical_instance: Optional['FrozenInstance'] = None @@ -121,11 +123,14 @@ class InterpreterClass(SimpleClass): """Class defined in the current module.""" def __init__( - self, ctx: base.ContextType, name: str, + self, + ctx: base.ContextType, + name: str, members: Dict[str, base.BaseValue], + bases: Sequence[SimpleClass], functions: Sequence[functions_lib.InterpreterFunction], classes: Sequence['InterpreterClass']): - super().__init__(ctx, name, members) + super().__init__(ctx, name, members, bases) # Functions and classes defined in this class's body. Unlike 'members', # ignores the effects of post-definition transformations like decorators. self.functions = functions diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py index bf3635dfc..5760e099a 100644 --- a/pytype/rewrite/convert.py +++ b/pytype/rewrite/convert.py @@ -30,9 +30,10 @@ def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass: ctx=self._ctx, name=name, members=members, + bases=(), module=module or None) - # Cache the class early so that references to it in its members don't cause - # infinite recursion. + # Cache the class early so that references to it in its members and bases + # don't cause infinite recursion. self._cache.classes[cls] = abstract_class for method in cls.methods: abstract_class.members[method.name] = ( @@ -43,6 +44,18 @@ def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass: for nested_class in cls.classes: abstract_class.members[nested_class.name] = ( self.pytd_class_to_value(nested_class)) + bases = [] + for base in cls.bases: + if isinstance(base, pytd.GenericType): + # TODO(b/292160579): Handle generics. + base = base.base_type + if isinstance(base, pytd.ClassType): + base = base.cls + if isinstance(base, pytd.Class): + bases.append(self.pytd_class_to_value(base)) + else: + raise NotImplementedError(f"I can't handle this base class: {base}") + abstract_class.bases = tuple(bases) return abstract_class def pytd_function_to_value( diff --git a/pytype/rewrite/convert_test.py b/pytype/rewrite/convert_test.py index 35ce6ead9..bb6bff8fa 100644 --- a/pytype/rewrite/convert_test.py +++ b/pytype/rewrite/convert_test.py @@ -91,6 +91,14 @@ class D: ... self.assertIsInstance(nested_class, abstract.SimpleClass) self.assertEqual(nested_class.name, 'D') + def test_bases(self): + pytd_cls = self.build_pytd(""" + class C: ... + class D(C): ... + """, 'D') + cls = self.conv.pytd_class_to_value(pytd_cls) + self.assertEqual(cls.bases, (abstract.SimpleClass(self.ctx, 'C', {}),)) + class PytdAliasToValueTest(ConverterTestBase): diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py index ccb95bad4..5d7dd7292 100644 --- a/pytype/rewrite/frame.py +++ b/pytype/rewrite/frame.py @@ -310,6 +310,31 @@ def _merge_nonlocals_into(self, frame: Optional['Frame']) -> None: var = self._final_locals[name] frame.store_global(name, var) + def _build_class(self, args: abstract.Args) -> abstract.InterpreterClass: + builder = args.posargs[0].get_atomic_value(_FrameFunction) + name = abstract.get_atomic_constant(args.posargs[1], str) + + base_vars = args.posargs[2:] + bases = [] + for base_var in base_vars: + try: + base = base_var.get_atomic_value(abstract.SimpleClass) + except ValueError as e: + raise NotImplementedError('Unexpected base class') from e + bases.append(base) + + frame = builder.call(abstract.Args(frame=self)) + cls = abstract.InterpreterClass( + ctx=self._ctx, + name=name, + members=dict(frame.final_locals), + bases=bases, + functions=frame.functions, + classes=frame.classes, + ) + log.info('Created class: %s', cls.name) + return cls + def _call_function( self, func_var: _AbstractVariable, @@ -323,17 +348,7 @@ def _call_function( ret = func.call(args) ret_values.append(ret.get_return_value()) elif func is self._ctx.consts.singles['__build_class__']: - class_body, name = args.posargs - builder = class_body.get_atomic_value(_FrameFunction) - frame = builder.call(abstract.Args(frame=self)) - cls = abstract.InterpreterClass( - ctx=self._ctx, - name=abstract.get_atomic_constant(name, str), - members=dict(frame.final_locals), - functions=frame.functions, - classes=frame.classes, - ) - log.info('Created class: %s', cls.name) + cls = self._build_class(args) self._classes.append(cls) ret_values.append(cls) else: diff --git a/pytype/rewrite/frame_test.py b/pytype/rewrite/frame_test.py index 627bb51ef..6a2eb563a 100644 --- a/pytype/rewrite/frame_test.py +++ b/pytype/rewrite/frame_test.py @@ -475,6 +475,19 @@ def test_stack_ops(self): frame = frame_lib.Frame(self.ctx, 'test', code.Seal()) frame.run() # Should not crash + def test_class_bases(self): + frame = self._make_frame(""" + class C: + pass + class D(C): + pass + """) + frame.run() + c = _get(frame, 'C', abstract.InterpreterClass) + d = _get(frame, 'D', abstract.InterpreterClass) + self.assertFalse(c.bases) + self.assertEqual(d.bases, [c]) + class BuildConstantsTest(FrameTestBase): diff --git a/pytype/rewrite/output.py b/pytype/rewrite/output.py index 60a10c5fa..5a5a7429d 100644 --- a/pytype/rewrite/output.py +++ b/pytype/rewrite/output.py @@ -67,10 +67,11 @@ def _class_to_pytd_def(self, val: abstract.SimpleClass) -> pytd.Class: for member_name, member_val in instance.members.items(): member_type = self.to_pytd_type(member_val) constants.append(pytd.Constant(name=member_name, type=member_type)) + bases = tuple(self.to_pytd_type_of_instance(base) for base in val.bases) return pytd.Class( name=val.name, keywords=(), - bases=(), + bases=bases, methods=tuple(methods), constants=tuple(constants), classes=tuple(classes), diff --git a/pytype/rewrite/tests/test_basic.py b/pytype/rewrite/tests/test_basic.py index 594a0f8c5..a1517828b 100644 --- a/pytype/rewrite/tests/test_basic.py +++ b/pytype/rewrite/tests/test_basic.py @@ -81,6 +81,18 @@ def __init__(self) -> None: ... def f(self) -> int: ... """) + def test_inheritance(self): + ty = self.Infer(""" + class C: + pass + class D(C): + pass + """) + self.assertTypesMatchPytd(ty, """ + class C: ... + class D(C): ... + """) + class ImportsTest(RewriteTest): """Import tests.""" From b5e1e1ae3a792a4add98364dc8233187005d4c19 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 16 Apr 2024 16:25:15 -0700 Subject: [PATCH 04/22] rewrite: detect metaclasses. PiperOrigin-RevId: 625487997 --- pytype/rewrite/abstract/classes.py | 11 ++++++++++- pytype/rewrite/convert.py | 2 ++ pytype/rewrite/convert_test.py | 10 ++++++++++ pytype/rewrite/frame.py | 9 +++++++++ pytype/rewrite/frame_test.py | 12 ++++++++++++ pytype/rewrite/output.py | 4 +++- pytype/rewrite/output_test.py | 12 ++++++++++++ 7 files changed, 58 insertions(+), 2 deletions(-) diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py index bb3f6430b..2cbe60fe4 100644 --- a/pytype/rewrite/abstract/classes.py +++ b/pytype/rewrite/abstract/classes.py @@ -13,6 +13,8 @@ log = logging.getLogger(__name__) +_EMPTY_MAP = immutabledict.immutabledict() + class _HasMembers(Protocol): @@ -37,12 +39,14 @@ def __init__( name: str, members: Dict[str, base.BaseValue], bases: Sequence['SimpleClass'] = (), + keywords: Mapping[str, base.BaseValue] = _EMPTY_MAP, module: Optional[str] = None, ): super().__init__(ctx) self.name = name self.members = members self.bases = bases + self.keywords = keywords self.module = module self._canonical_instance: Optional['FrozenInstance'] = None @@ -78,6 +82,10 @@ def full_name(self): else: return self.name + @property + def metaclass(self) -> Optional[base.BaseValue]: + return self.keywords.get('metaclass') + def get_attribute(self, name: str) -> Optional[base.BaseValue]: return self.members.get(name) @@ -128,9 +136,10 @@ def __init__( name: str, members: Dict[str, base.BaseValue], bases: Sequence[SimpleClass], + keywords: Mapping[str, base.BaseValue], functions: Sequence[functions_lib.InterpreterFunction], classes: Sequence['InterpreterClass']): - super().__init__(ctx, name, members, bases) + super().__init__(ctx, name, members, bases, keywords) # Functions and classes defined in this class's body. Unlike 'members', # ignores the effects of post-definition transformations like decorators. self.functions = functions diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py index 5760e099a..989512881 100644 --- a/pytype/rewrite/convert.py +++ b/pytype/rewrite/convert.py @@ -26,11 +26,13 @@ def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass: # TODO(b/324464265): Handle keywords, bases, decorators, slots, template module, _, name = cls.name.rpartition('.') members = {} + keywords = {kw: self.pytd_type_to_value(val) for kw, val in cls.keywords} abstract_class = abstract.SimpleClass( ctx=self._ctx, name=name, members=members, bases=(), + keywords=keywords, module=module or None) # Cache the class early so that references to it in its members and bases # don't cause infinite recursion. diff --git a/pytype/rewrite/convert_test.py b/pytype/rewrite/convert_test.py index bb6bff8fa..7a3aea0e4 100644 --- a/pytype/rewrite/convert_test.py +++ b/pytype/rewrite/convert_test.py @@ -99,6 +99,16 @@ class D(C): ... cls = self.conv.pytd_class_to_value(pytd_cls) self.assertEqual(cls.bases, (abstract.SimpleClass(self.ctx, 'C', {}),)) + def test_metaclass(self): + pytd_cls = self.build_pytd(""" + class Meta(type): ... + class C(metaclass=Meta): ... + """, 'C') + cls = self.conv.pytd_class_to_value(pytd_cls) + metaclass = cls.metaclass + self.assertIsNotNone(metaclass) + self.assertEqual(metaclass.name, 'Meta') + class PytdAliasToValueTest(ConverterTestBase): diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py index 5d7dd7292..d5d7cddf7 100644 --- a/pytype/rewrite/frame.py +++ b/pytype/rewrite/frame.py @@ -323,12 +323,21 @@ def _build_class(self, args: abstract.Args) -> abstract.InterpreterClass: raise NotImplementedError('Unexpected base class') from e bases.append(base) + keywords = {} + for kw, var in args.kwargs.items(): + try: + val = var.get_atomic_value() + except ValueError as e: + raise NotImplementedError('Unexpected keyword value') from e + keywords[kw] = val + frame = builder.call(abstract.Args(frame=self)) cls = abstract.InterpreterClass( ctx=self._ctx, name=name, members=dict(frame.final_locals), bases=bases, + keywords=keywords, functions=frame.functions, classes=frame.classes, ) diff --git a/pytype/rewrite/frame_test.py b/pytype/rewrite/frame_test.py index 6a2eb563a..77b627bdd 100644 --- a/pytype/rewrite/frame_test.py +++ b/pytype/rewrite/frame_test.py @@ -488,6 +488,18 @@ class D(C): self.assertFalse(c.bases) self.assertEqual(d.bases, [c]) + def test_metaclass(self): + frame = self._make_frame(""" + class Meta(type): + pass + class C(metaclass=Meta): + pass + """) + frame.run() + meta = _get(frame, 'Meta', abstract.InterpreterClass) + c = _get(frame, 'C', abstract.InterpreterClass) + self.assertEqual(c.metaclass, meta) + class BuildConstantsTest(FrameTestBase): diff --git a/pytype/rewrite/output.py b/pytype/rewrite/output.py index 5a5a7429d..a9fa14639 100644 --- a/pytype/rewrite/output.py +++ b/pytype/rewrite/output.py @@ -67,10 +67,12 @@ def _class_to_pytd_def(self, val: abstract.SimpleClass) -> pytd.Class: for member_name, member_val in instance.members.items(): member_type = self.to_pytd_type(member_val) constants.append(pytd.Constant(name=member_name, type=member_type)) + keywords = tuple((k, self.to_pytd_type_of_instance(v)) + for k, v in val.keywords.items()) bases = tuple(self.to_pytd_type_of_instance(base) for base in val.bases) return pytd.Class( name=val.name, - keywords=(), + keywords=keywords, bases=bases, methods=tuple(methods), constants=tuple(constants), diff --git a/pytype/rewrite/output_test.py b/pytype/rewrite/output_test.py index 8d4fe71a4..274e2a88c 100644 --- a/pytype/rewrite/output_test.py +++ b/pytype/rewrite/output_test.py @@ -72,6 +72,18 @@ class C: def __init__(self) -> None: ... """) + def test_metaclass(self): + cls = self.make_value(""" + class Meta(type): + pass + class C(metaclass=Meta): + pass + """) + pytd_cls = self.ctx.pytd_converter.to_pytd_def(cls) + self.assertPytdEqual(pytd_cls, """ + class C(metaclass=Meta): ... + """) + class FunctionToPytdDefTest(OutputTestBase): From c29cb159e56be06333f7ac745d1aea35c31d4c67 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 16 Apr 2024 16:44:19 -0700 Subject: [PATCH 05/22] Add EMPTY_MAP to datatypes.py. PiperOrigin-RevId: 625493532 --- pytype/datatypes.py | 6 ++++++ pytype/rewrite/CMakeLists.txt | 1 + pytype/rewrite/abstract/CMakeLists.txt | 2 ++ pytype/rewrite/abstract/classes.py | 12 +++++------- pytype/rewrite/abstract/functions.py | 11 +++++------ pytype/rewrite/frame.py | 20 +++++++++----------- 6 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pytype/datatypes.py b/pytype/datatypes.py index 3750aa795..7178ddb97 100644 --- a/pytype/datatypes.py +++ b/pytype/datatypes.py @@ -5,9 +5,15 @@ import itertools from typing import Dict, Optional, TypeVar +import immutabledict + _K = TypeVar("_K") _V = TypeVar("_V") +# Public alias for immutabledict to save users the extra import. +immutabledict = immutabledict.immutabledict +EMPTY_MAP = immutabledict() + class UnionFind: r"""A disjoint-set data structure for `AliasingDict`. diff --git a/pytype/rewrite/CMakeLists.txt b/pytype/rewrite/CMakeLists.txt index b72ccbf1f..97cf52d64 100644 --- a/pytype/rewrite/CMakeLists.txt +++ b/pytype/rewrite/CMakeLists.txt @@ -79,6 +79,7 @@ py_library( DEPS .context .stack + pytype.utils pytype.blocks.blocks pytype.rewrite.abstract.abstract pytype.rewrite.flow.flow diff --git a/pytype/rewrite/abstract/CMakeLists.txt b/pytype/rewrite/abstract/CMakeLists.txt index 6da80abe3..e2ea6e95e 100644 --- a/pytype/rewrite/abstract/CMakeLists.txt +++ b/pytype/rewrite/abstract/CMakeLists.txt @@ -48,6 +48,7 @@ py_library( DEPS .base .functions + pytype.utils pytype.types.types ) @@ -93,6 +94,7 @@ py_library( functions.py DEPS .base + pytype.utils pytype.blocks.blocks pytype.pytd.pytd ) diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py index 2cbe60fe4..c756205ea 100644 --- a/pytype/rewrite/abstract/classes.py +++ b/pytype/rewrite/abstract/classes.py @@ -6,15 +6,13 @@ from typing import Dict, List, Mapping, Optional, Protocol, Sequence -import immutabledict +from pytype import datatypes from pytype.rewrite.abstract import base from pytype.rewrite.abstract import functions as functions_lib from pytype.types import types log = logging.getLogger(__name__) -_EMPTY_MAP = immutabledict.immutabledict() - class _HasMembers(Protocol): @@ -39,7 +37,7 @@ def __init__( name: str, members: Dict[str, base.BaseValue], bases: Sequence['SimpleClass'] = (), - keywords: Mapping[str, base.BaseValue] = _EMPTY_MAP, + keywords: Mapping[str, base.BaseValue] = datatypes.EMPTY_MAP, module: Optional[str] = None, ): super().__init__(ctx) @@ -150,7 +148,7 @@ def __repr__(self): @property def _attrs(self): - return (self.name, immutabledict.immutabledict(self.members)) + return (self.name, datatypes.immutabledict(self.members)) class BaseInstance(base.BaseValue): @@ -188,7 +186,7 @@ def __repr__(self): @property def _attrs(self): - return (self.cls, immutabledict.immutabledict(self.members)) + return (self.cls, datatypes.immutabledict(self.members)) def set_attribute(self, name: str, value: base.BaseValue) -> None: if name in self.members: @@ -209,7 +207,7 @@ class FrozenInstance(BaseInstance): def __init__(self, ctx: base.ContextType, instance: MutableInstance): super().__init__( - ctx, instance.cls, immutabledict.immutabledict(instance.members)) + ctx, instance.cls, datatypes.immutabledict(instance.members)) def __repr__(self): return f'FrozenInstance({self.cls.name})' diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py index cd07e03a0..4069d1a59 100644 --- a/pytype/rewrite/abstract/functions.py +++ b/pytype/rewrite/abstract/functions.py @@ -20,14 +20,13 @@ import logging from typing import Dict, Generic, Mapping, Optional, Protocol, Sequence, Tuple, TypeVar -import immutabledict +from pytype import datatypes from pytype.blocks import blocks from pytype.pytd import pytd from pytype.rewrite.abstract import base log = logging.getLogger(__name__) -_EMPTY_MAP = immutabledict.immutabledict() _ArgDict = Dict[str, base.AbstractVariableType] @@ -56,7 +55,7 @@ def get_return_value(self) -> base.BaseValue: ... class Args(Generic[_FrameT]): """Arguments to one function call.""" posargs: Tuple[base.AbstractVariableType, ...] = () - kwargs: Mapping[str, base.AbstractVariableType] = _EMPTY_MAP + kwargs: Mapping[str, base.AbstractVariableType] = datatypes.EMPTY_MAP starargs: Optional[base.AbstractVariableType] = None starstarargs: Optional[base.AbstractVariableType] = None frame: Optional[_FrameT] = None @@ -116,8 +115,8 @@ def __init__( varargs_name: Optional[str] = None, kwonly_params: Tuple[str, ...] = (), kwargs_name: Optional[str] = None, - defaults: Mapping[str, base.BaseValue] = _EMPTY_MAP, - annotations: Mapping[str, base.BaseValue] = _EMPTY_MAP, + defaults: Mapping[str, base.BaseValue] = datatypes.EMPTY_MAP, + annotations: Mapping[str, base.BaseValue] = datatypes.EMPTY_MAP, ): self._ctx = ctx self.name = name @@ -398,7 +397,7 @@ def call_with_mapped_args(self, mapped_args: MappedArgs[_FrameT]) -> _FrameT: else: # If the parent frame has finished running, then the context of this call # will not change, so we can cache the return value. - k = (parent_frame.name, immutabledict.immutabledict(mapped_args.argdict)) + k = (parent_frame.name, datatypes.immutabledict(mapped_args.argdict)) if k in self._call_cache: log.info('Reusing cached return value of function %s', self.name) return self._call_cache[k] diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py index d5d7cddf7..7b49103a4 100644 --- a/pytype/rewrite/frame.py +++ b/pytype/rewrite/frame.py @@ -3,8 +3,8 @@ import logging from typing import Any, FrozenSet, List, Mapping, Optional, Sequence, Set, Type -import immutabledict from pycnite import marshal as pyc_marshal +from pytype import datatypes from pytype.blocks import blocks from pytype.rewrite import context from pytype.rewrite import stack @@ -15,8 +15,6 @@ log = logging.getLogger(__name__) -_EMPTY_MAP = immutabledict.immutabledict() - # Type aliases _AbstractVariable = variables.Variable[abstract.BaseValue] _VarMap = Mapping[str, _AbstractVariable] @@ -67,9 +65,9 @@ def __init__( name: str, code: blocks.OrderedCode, *, - initial_locals: _VarMap = _EMPTY_MAP, - initial_enclosing: _VarMap = _EMPTY_MAP, - initial_globals: _VarMap = _EMPTY_MAP, + initial_locals: _VarMap = datatypes.EMPTY_MAP, + initial_enclosing: _VarMap = datatypes.EMPTY_MAP, + initial_globals: _VarMap = datatypes.EMPTY_MAP, f_back: Optional['Frame'] = None, ): super().__init__(code, initial_locals) @@ -154,7 +152,7 @@ def run(self) -> None: # Set the current state to None so that the load_* and store_* methods # cannot be used to modify finalized locals. self._current_state = None - self.final_locals = immutabledict.immutabledict({ + self.final_locals = datatypes.immutabledict({ name: abstract.join_values(self._ctx, var.values) for name, var in self._final_locals.items()}) @@ -604,10 +602,10 @@ def _make_function_args(self, args): n_kw = len(self._kw_names) posargs = tuple(args[:-n_kw]) kw_vals = args[-n_kw:] - kwargs = immutabledict.immutabledict(zip(self._kw_names, kw_vals)) + kwargs = datatypes.immutabledict(zip(self._kw_names, kw_vals)) else: posargs = tuple(args) - kwargs = _EMPTY_MAP + kwargs = datatypes.EMPTY_MAP self._kw_names = () return abstract.Args(posargs=posargs, kwargs=kwargs, frame=self) @@ -677,9 +675,9 @@ def byte_CALL_FUNCTION_EX(self, opcode): starstarargs = None else: # We have an indefinite dict, leave it in starstarargs - kwargs = _EMPTY_MAP + kwargs = datatypes.EMPTY_MAP else: - kwargs = _EMPTY_MAP + kwargs = datatypes.EMPTY_MAP starstarargs = None # Convert *args starargs = self._stack.pop() From fb55b1e3a7c70d20144695f74266fa854ba3d321 Mon Sep 17 00:00:00 2001 From: mdemello Date: Wed, 17 Apr 2024 17:43:29 -0700 Subject: [PATCH 06/22] rewrite: Unpack *args and **args when mapping against a signature. Largely ported from abstract.function.Args.simplify, with a few changes because we don't handle indefinite iterables well yet, and refactored a bit to be easier to read and type check. Not tested for now because it also needs changes to LIST_EXTEND and a few other places to handle unpacking of `Any` better, and this change was getting long enough already. Added a skipped test that works in the execution phase, but not the analyse-with-fake-args phase. PiperOrigin-RevId: 625859869 --- pytype/rewrite/abstract/CMakeLists.txt | 3 + pytype/rewrite/abstract/functions.py | 242 +++++++++++++++++++++++-- pytype/rewrite/abstract/internal.py | 8 + pytype/rewrite/flow/variables.py | 9 +- pytype/rewrite/tests/test_args.py | 16 ++ 5 files changed, 262 insertions(+), 16 deletions(-) diff --git a/pytype/rewrite/abstract/CMakeLists.txt b/pytype/rewrite/abstract/CMakeLists.txt index e2ea6e95e..9faef5fa6 100644 --- a/pytype/rewrite/abstract/CMakeLists.txt +++ b/pytype/rewrite/abstract/CMakeLists.txt @@ -94,6 +94,9 @@ py_library( functions.py DEPS .base + .containers + .internal + .utils pytype.utils pytype.blocks.blocks pytype.pytd.pytd diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py index 4069d1a59..c4d1f5ed5 100644 --- a/pytype/rewrite/abstract/functions.py +++ b/pytype/rewrite/abstract/functions.py @@ -16,18 +16,22 @@ """ import abc +import collections import dataclasses import logging -from typing import Dict, Generic, Mapping, Optional, Protocol, Sequence, Tuple, TypeVar +from typing import Any, Dict, Generic, List, Mapping, Optional, Protocol, Sequence, Tuple, TypeVar from pytype import datatypes from pytype.blocks import blocks from pytype.pytd import pytd from pytype.rewrite.abstract import base +from pytype.rewrite.abstract import containers +from pytype.rewrite.abstract import internal log = logging.getLogger(__name__) -_ArgDict = Dict[str, base.AbstractVariableType] +_Var = base.AbstractVariableType +_ArgDict = Dict[str, _Var] class FrameType(Protocol): @@ -40,7 +44,7 @@ class FrameType(Protocol): def make_child_frame( self, func: 'InterpreterFunction', - initial_locals: Mapping[str, base.AbstractVariableType], + initial_locals: Mapping[str, _Var], ) -> 'FrameType': ... def run(self) -> None: ... @@ -51,15 +55,231 @@ def get_return_value(self) -> base.BaseValue: ... _FrameT = TypeVar('_FrameT', bound=FrameType) +def _unpack_splats(elts): + """Unpack any concrete splats and splice them into the sequence.""" + ret = [] + for e in elts: + try: + splat = e.get_atomic_value(internal.Splat) + ret.extend(splat.get_concrete_iterable()) + except ValueError: + # Leave an indefinite splat intact + ret.append(e) + return tuple(ret) + + @dataclasses.dataclass class Args(Generic[_FrameT]): """Arguments to one function call.""" - posargs: Tuple[base.AbstractVariableType, ...] = () - kwargs: Mapping[str, base.AbstractVariableType] = datatypes.EMPTY_MAP - starargs: Optional[base.AbstractVariableType] = None - starstarargs: Optional[base.AbstractVariableType] = None + posargs: Tuple[_Var, ...] = () + kwargs: Mapping[str, _Var] = datatypes.EMPTY_MAP + starargs: Optional[_Var] = None + starstarargs: Optional[_Var] = None frame: Optional[_FrameT] = None + def get_concrete_starargs(self) -> Tuple[Any, ...]: + """Returns a concrete tuple from starargs or raises ValueError.""" + if self.starargs is None: + raise ValueError('No starargs to convert') + starargs = self.starargs.get_atomic_value(internal.FunctionArgTuple) # pytype: disable=attribute-error + return _unpack_splats(starargs.constant) + + def get_concrete_starstarargs(self) -> Mapping[str, Any]: + """Returns a concrete dict from starstarargs or raises ValueError.""" + if self.starstarargs is None: + raise ValueError('No starstarargs to convert') + starstarargs = self.starstarargs.get_atomic_value(internal.ConstKeyDict) # pytype: disable=attribute-error + return starstarargs.constant + + +class _ArgMapper: + """Map args into a signature.""" + + def __init__(self, ctx: base.ContextType, args: Args, sig: 'Signature'): + self._ctx = ctx + self.args = args + self.sig = sig + self.argdict: _ArgDict = {} + + def _expand_positional_args(self): + """Unpack concrete splats in posargs.""" + new_posargs = _unpack_splats(self.args.posargs) + self.args = dataclasses.replace(self.args, posargs=new_posargs) + + def _expand_typed_star(self, star, n) -> List[_Var]: + """Convert *xs: Sequence[T] -> [T, T, ...].""" + del star # not implemented yet + return [self._ctx.consts.Any.to_variable() for _ in range(n)] + + def _splats_to_any(self, seq) -> Tuple[_Var, ...]: + any_ = self._ctx.consts.Any + return tuple( + any_.to_variable() if v.is_atomic(internal.Splat) else v + for v in seq) + + def _partition_starargs_tuple( + self, starargs_tuple + ) -> Tuple[List[_Var], List[_Var], List[_Var]]: + """Partition a sequence like a, b, c, *middle, x, y, z.""" + pre = [] + post = [] + stars = collections.deque(starargs_tuple) + while stars and not stars[0].is_atomic(internal.Splat): + pre.append(stars.popleft()) + while stars and not stars[-1].is_atomic(internal.Splat): + post.append(stars.pop()) + post.reverse() + return pre, list(stars), post + + def _get_required_posarg_count(self) -> int: + """Find out how many params in sig need to be filled by arg.posargs.""" + # Iterate through param_names until we hit the first kwarg or default, + # since python does not let non-required posargs follow those. + required_posargs = 0 + for p in self.sig.param_names: + if p in self.args.kwargs or p in self.sig.defaults: + break + required_posargs += 1 + return required_posargs + + def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]: + """Adjust *args and posargs based on function signature.""" + posargs = self.args.posargs + indef_starargs = False + if self.args.starargs is None: + # There is nothing to unpack, but we might want to move unused posargs + # into sig.varargs_name + starargs_tuple = () + else: + try: + starargs_tuple = _unpack_splats(self.args.get_concrete_starargs()) + except ValueError: + # We don't have a concrete starargs. We still need to use this to fill + # in missing posargs or absorb extra ones. + starargs_tuple = () + indef_starargs = True + + # Attempt to adjust the starargs into the missing posargs. + pre, stars, post = self._partition_starargs_tuple(starargs_tuple) + n_matched = len(posargs) + len(pre) + len(post) + n_required_posargs = self._get_required_posarg_count() + posarg_delta = n_required_posargs - n_matched + + if stars and not post: + star = stars[-1] + if self.sig.varargs_name: + # If the invocation ends with `*args`, return it to match against *args + # in the function signature. For f(, *xs, ..., *ys), transform + # to f(, *ys) since ys is an indefinite tuple anyway and will + # match against all remaining posargs. + star = star.get_atomic_value(internal.Splat) + return posargs + tuple(pre), star.iterable.to_variable() + else: + # If we do not have a `*args` in self.sig, just expand the + # terminal splat to as many args as needed and then drop it. + mid = self._expand_typed_star(star, posarg_delta) + return posargs + tuple(pre + mid), None + elif posarg_delta <= len(stars): + # We have too many args; don't do *xs expansion. Go back to matching from + # the start and treat every entry in starargs_tuple as length 1. + n_params = len(self.sig.param_names) + all_args = posargs + starargs_tuple + if not self.sig.varargs_name: + # If the function sig has no *args, return everything in posargs + pos = self._splats_to_any(all_args) + return pos, None + # Don't unwrap splats here because f(*xs, y) is not the same as f(xs, y). + # TODO(mdemello): Ideally, since we are matching call f(*xs, y) against + # sig f(x, y) we should raise an error here. + pos = self._splats_to_any(all_args[:n_params]) + star = [] + for var in all_args[n_params:]: + if var.is_atomic(internal.Splat): + # TODO(rewrite): Fix this! + star.append(self._ctx.consts.Any.to_variable()) + else: + star.append(var) + if star: + return pos, containers.Tuple(self._ctx, tuple(star)).to_variable() + else: + return pos, None + elif stars: + if len(stars) == 1: + # Special case (
, *xs) and (*xs, ) to fill in the type of xs
+        # in every remaining arg.
+        mid = self._expand_typed_star(stars[0], posarg_delta)
+      else:
+        # If we have (*xs, , *ys) remaining, and more than k+2 params to
+        # match, don't try to match the intermediate params to any range, just
+        # match all k+2 to Any
+        mid = [self._ctx.consts.Any.to_variable() for _ in range(posarg_delta)]
+      return posargs + tuple(pre + mid + post), None
+    elif posarg_delta and indef_starargs:
+      # Fill in *required* posargs if needed; don't override the default posargs
+      # with indef starargs yet because we aren't capturing the type of *args
+      if posarg_delta > 0:
+        extra = self._expand_typed_star(self.args.starargs, posarg_delta)
+        return posargs + tuple(extra), None
+      elif self.sig.varargs_name:
+        posargs = posargs[:n_required_posargs]
+        return posargs, self.args.starargs
+      else:
+        # We have too many posargs *and* no *args in the sig to absorb them, so
+        # just do nothing and handle the error downstream.
+        return posargs, self.args.starargs
+
+    else:
+      # We have **kwargs but no *args in the invocation
+      return posargs + tuple(pre), None
+
+  def _map_posargs(self):
+    posargs, starargs = self._unpack_starargs()
+    argdict = dict(zip(self.sig.param_names, posargs))
+    self.argdict.update(argdict)
+    if self.sig.varargs_name:
+      # Make sure kwargs_name is bound to something
+      if starargs is None:
+        starargs = self._ctx.consts.Any.to_variable()
+      self.argdict[self.sig.varargs_name] = starargs
+
+  def _unpack_starstarargs(self):
+    """Adjust **args and kwargs based on function signature."""
+    if self.args.starstarargs is None:
+      # Nothing to unpack
+      return self.args.kwargs, None
+    try:
+      starstarargs_dict = self.args.get_concrete_starstarargs()
+    except ValueError:
+      # We have a non-concrete starstarargs
+      return self.args.kwargs, self.args.starstarargs
+    # Unpack **args into kwargs, overwriting named args for now
+    # TODO(mdemello): raise an error if we have a conflict
+    kwargs = {**self.args.kwargs}
+    starstarargs_dict = {**starstarargs_dict}
+    for k in self.sig.param_names:
+      if k in starstarargs_dict:
+        kwargs[k] = starstarargs_dict[k]
+        del starstarargs_dict[k]
+    # Pack the unused entries in starstarargs back into an abstract value
+    starstarargs = internal.ConstKeyDict(self._ctx, starstarargs_dict)
+    return kwargs, starstarargs.to_variable()
+
+  def _map_kwargs(self):
+    kwargs, starstarargs = self._unpack_starstarargs()
+    # Copy kwargs into argdict
+    self.argdict.update(kwargs)
+    # Make sure kwargs_name is bound to something
+    if self.sig.kwargs_name:
+      if starstarargs is None:
+        starstarargs = internal.ConstKeyDict(self._ctx, {}).to_variable()
+      self.argdict[self.sig.kwargs_name] = starstarargs
+
+  def map_args(self):
+    self._expand_positional_args()
+    self._map_kwargs()
+    self._map_posargs()
+    return self.argdict
+
 
 @dataclasses.dataclass
 class MappedArgs(Generic[_FrameT]):
@@ -233,13 +453,7 @@ def fmt(param_name):
 
   def map_args(self, args: Args[_FrameT]) -> MappedArgs[_FrameT]:
     # TODO(b/241479600): Implement this properly, with error detection.
-    argdict = dict(zip(self.param_names, args.posargs))
-    argdict.update(args.kwargs)
-    def add_arg(k, v):
-      if k:
-        argdict[k] = v or self._ctx.consts.Any.to_variable()
-    add_arg(self.varargs_name, args.starargs)
-    add_arg(self.kwargs_name, args.starstarargs)
+    argdict = _ArgMapper(self._ctx, args, self).map_args()
     return MappedArgs(signature=self, argdict=argdict, frame=args.frame)
 
   def make_fake_args(self) -> MappedArgs[FrameType]:
diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py
index 886b028f6..d170c8231 100644
--- a/pytype/rewrite/abstract/internal.py
+++ b/pytype/rewrite/abstract/internal.py
@@ -1,5 +1,6 @@
 """Abstract types used internally by pytype."""
 
+import collections
 from typing import Dict, Tuple
 
 import immutabledict
@@ -58,6 +59,13 @@ def __init__(self, ctx: base.ContextType, iterable: base.BaseValue):
     super().__init__(ctx)
     self.iterable = iterable
 
+  def get_concrete_iterable(self):
+    if (isinstance(self.iterable, base.PythonConstant) and
+        isinstance(self.iterable.constant, collections.abc.Iterable)):
+      return self.iterable.constant
+    else:
+      raise ValueError("Not a concrete iterable")
+
   def __repr__(self):
     return f"splat({self.iterable!r})"
 
diff --git a/pytype/rewrite/flow/variables.py b/pytype/rewrite/flow/variables.py
index 81b8777eb..21ffa9726 100644
--- a/pytype/rewrite/flow/variables.py
+++ b/pytype/rewrite/flow/variables.py
@@ -60,7 +60,7 @@ def get_atomic_value(self, typ: None = ...) -> _T: ...
 
   def get_atomic_value(self, typ=None):
     """Gets this variable's value if there's exactly one, errors otherwise."""
-    if len(self.bindings) != 1:
+    if not self.is_atomic():
       desc = 'many' if len(self.bindings) > 1 else 'few'
       raise ValueError(
           f'Too {desc} bindings for {self.display_name()}: {self.bindings}')
@@ -71,8 +71,13 @@ def get_atomic_value(self, typ=None):
           f'{runtime_type.__name__}, got {value.__class__.__name__}')
     return value
 
+  def is_atomic(self, typ: Optional[Type[_T]] = None) -> bool:
+    if len(self.bindings) != 1:
+      return False
+    return True if typ is None else isinstance(self.values[0], typ)
+
   def has_atomic_value(self, value: Any) -> bool:
-    return len(self.values) == 1 and self.values[0] == value
+    return self.is_atomic() and self.values[0] == value
 
   def with_condition(self, condition: conditions.Condition) -> 'Variable[_T]':
     """Adds a condition, 'and'-ing it with any existing."""
diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py
index 91d9d6688..d56ff7e60 100644
--- a/pytype/rewrite/tests/test_args.py
+++ b/pytype/rewrite/tests/test_args.py
@@ -68,6 +68,22 @@ def g(a, b, x, y):
       f(*a, **b)
     """)
 
+  @test_base.skip('Does not yet work with fake args.')
+  def test_unpack_posargs(self):
+    self.Check("""
+      def f(x, y, z):
+        g(*x, *y, *z)
+
+      def g(*args):
+        h(*args)
+
+      def h(p, q, r, s, t, u):
+        return u
+
+      ret = f((1, 2), (3, 4), (5, 6))
+      assert_type(ret, int)
+    """)
+
 
 if __name__ == '__main__':
   test_base.main()

From 8fb00065e05590464568019efe3fbc2e338b3f62 Mon Sep 17 00:00:00 2001
From: rechen 
Date: Wed, 17 Apr 2024 19:01:01 -0700
Subject: [PATCH 07/22] rewrite: add MRO computation and improve class
 attribute lookup.

PiperOrigin-RevId: 625877986
---
 pytype/rewrite/abstract/CMakeLists.txt  |  1 +
 pytype/rewrite/abstract/classes.py      | 23 ++++++++++++++++++++++-
 pytype/rewrite/abstract/classes_test.py | 13 +++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/pytype/rewrite/abstract/CMakeLists.txt b/pytype/rewrite/abstract/CMakeLists.txt
index 9faef5fa6..4f0d09de6 100644
--- a/pytype/rewrite/abstract/CMakeLists.txt
+++ b/pytype/rewrite/abstract/CMakeLists.txt
@@ -49,6 +49,7 @@ py_library(
     .base
     .functions
     pytype.utils
+    pytype.pytd.pytd
     pytype.types.types
 )
 
diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py
index c756205ea..4e951bfc5 100644
--- a/pytype/rewrite/abstract/classes.py
+++ b/pytype/rewrite/abstract/classes.py
@@ -7,6 +7,7 @@
 from typing import Dict, List, Mapping, Optional, Protocol, Sequence
 
 from pytype import datatypes
+from pytype.pytd import mro as mro_lib
 from pytype.rewrite.abstract import base
 from pytype.rewrite.abstract import functions as functions_lib
 from pytype.types import types
@@ -47,6 +48,7 @@ def __init__(
     self.keywords = keywords
     self.module = module
     self._canonical_instance: Optional['FrozenInstance'] = None
+    self._mro: Optional[Sequence['SimpleClass']] = None
 
     if isinstance((init := members.get('__init__')),
                   functions_lib.SimpleFunction):
@@ -85,7 +87,12 @@ def metaclass(self) -> Optional[base.BaseValue]:
     return self.keywords.get('metaclass')
 
   def get_attribute(self, name: str) -> Optional[base.BaseValue]:
-    return self.members.get(name)
+    if name in self.members:
+      return self.members[name]
+    mro = self.mro()
+    if len(mro) > 1:
+      return mro[1].get_attribute(name)
+    return None
 
   def set_attribute(self, name: str, value: base.BaseValue) -> None:
     # SimpleClass is used to model imported classes, which we treat as frozen.
@@ -124,6 +131,20 @@ def call(self, args: functions_lib.Args) -> ClassCallReturn:
         _ = initializer.bind_to(instance).call(args)
     return ClassCallReturn(instance)
 
+  def mro(self) -> Sequence['SimpleClass']:
+    if self._mro:
+      return self._mro
+    if self.full_name == 'builtins.object':
+      self._mro = mro = [self]
+      return mro
+    bases = list(self.bases)
+    obj_type = self._ctx.abstract_loader.load_raw_type(object)
+    if not bases or bases[-1] != obj_type:
+      bases.append(obj_type)
+    mro_bases = [[self]] + [list(base.mro()) for base in bases] + [bases]
+    self._mro = mro = mro_lib.MROMerge(mro_bases)
+    return mro
+
 
 class InterpreterClass(SimpleClass):
   """Class defined in the current module."""
diff --git a/pytype/rewrite/abstract/classes_test.py b/pytype/rewrite/abstract/classes_test.py
index 1f7e15997..bcbec0b11 100644
--- a/pytype/rewrite/abstract/classes_test.py
+++ b/pytype/rewrite/abstract/classes_test.py
@@ -16,6 +16,12 @@ def test_get_nonexistent_attribute(self):
     cls = classes.SimpleClass(self.ctx, 'X', {})
     self.assertIsNone(cls.get_attribute('x'))
 
+  def test_get_parent_attribute(self):
+    x = self.ctx.consts[5]
+    parent = classes.SimpleClass(self.ctx, 'Parent', {'x': x})
+    child = classes.SimpleClass(self.ctx, 'Child', {}, bases=[parent])
+    self.assertEqual(child.get_attribute('x'), x)
+
   def test_instantiate(self):
     cls = classes.SimpleClass(self.ctx, 'X', {})
     instance = cls.instantiate()
@@ -26,6 +32,13 @@ def test_call(self):
     instance = cls.call(functions.Args()).get_return_value()
     self.assertEqual(instance.cls, cls)
 
+  def test_mro(self):
+    parent = classes.SimpleClass(self.ctx, 'Parent', {})
+    child = classes.SimpleClass(self.ctx, 'Child', {}, bases=[parent])
+    self.assertEqual(
+        child.mro(),
+        [child, parent, self.ctx.abstract_loader.load_raw_type(object)])
+
 
 class MutableInstanceTest(test_utils.ContextfulTestBase):
 

From 9a164b27ab1db91308fa393c131952a4812c8276 Mon Sep 17 00:00:00 2001
From: mdemello 
Date: Thu, 18 Apr 2024 12:00:40 -0700
Subject: [PATCH 08/22] rewrite: Improve handling of indefinite iterables in
 posargs

Merges posargs and starargs before handling splats, and does not replace splats
with Any when matching args since we might want to forward them through several
function calls.

PiperOrigin-RevId: 626107538
---
 pytype/rewrite/abstract/functions.py | 38 +++++++++-------------------
 pytype/rewrite/tests/test_args.py    |  5 ++--
 2 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py
index c4d1f5ed5..92c0cccd7 100644
--- a/pytype/rewrite/abstract/functions.py
+++ b/pytype/rewrite/abstract/functions.py
@@ -111,13 +111,7 @@ def _expand_typed_star(self, star, n) -> List[_Var]:
     del star  # not implemented yet
     return [self._ctx.consts.Any.to_variable() for _ in range(n)]
 
-  def _splats_to_any(self, seq) -> Tuple[_Var, ...]:
-    any_ = self._ctx.consts.Any
-    return tuple(
-        any_.to_variable() if v.is_atomic(internal.Splat) else v
-        for v in seq)
-
-  def _partition_starargs_tuple(
+  def _partition_args_tuple(
       self, starargs_tuple
   ) -> Tuple[List[_Var], List[_Var], List[_Var]]:
     """Partition a sequence like a, b, c, *middle, x, y, z."""
@@ -160,8 +154,9 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
         indef_starargs = True
 
     # Attempt to adjust the starargs into the missing posargs.
-    pre, stars, post = self._partition_starargs_tuple(starargs_tuple)
-    n_matched = len(posargs) + len(pre) + len(post)
+    all_posargs = posargs + starargs_tuple
+    pre, stars, post = self._partition_args_tuple(all_posargs)
+    n_matched = len(pre) + len(post)
     n_required_posargs = self._get_required_posarg_count()
     posarg_delta = n_required_posargs - n_matched
 
@@ -173,32 +168,24 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
         # to f(, *ys) since ys is an indefinite tuple anyway and will
         # match against all remaining posargs.
         star = star.get_atomic_value(internal.Splat)
-        return posargs + tuple(pre), star.iterable.to_variable()
+        return tuple(pre), star.iterable.to_variable()
       else:
         # If we do not have a `*args` in self.sig, just expand the
         # terminal splat to as many args as needed and then drop it.
         mid = self._expand_typed_star(star, posarg_delta)
-        return posargs + tuple(pre + mid), None
+        return tuple(pre + mid), None
     elif posarg_delta <= len(stars):
       # We have too many args; don't do *xs expansion. Go back to matching from
       # the start and treat every entry in starargs_tuple as length 1.
       n_params = len(self.sig.param_names)
-      all_args = posargs + starargs_tuple
       if not self.sig.varargs_name:
         # If the function sig has no *args, return everything in posargs
-        pos = self._splats_to_any(all_args)
-        return pos, None
+        return all_posargs, None
       # Don't unwrap splats here because f(*xs, y) is not the same as f(xs, y).
       # TODO(mdemello): Ideally, since we are matching call f(*xs, y) against
       # sig f(x, y) we should raise an error here.
-      pos = self._splats_to_any(all_args[:n_params])
-      star = []
-      for var in all_args[n_params:]:
-        if var.is_atomic(internal.Splat):
-          # TODO(rewrite): Fix this!
-          star.append(self._ctx.consts.Any.to_variable())
-        else:
-          star.append(var)
+      pos = all_posargs[:n_params]
+      star = all_posargs[n_params:]
       if star:
         return pos, containers.Tuple(self._ctx, tuple(star)).to_variable()
       else:
@@ -213,7 +200,7 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
         # match, don't try to match the intermediate params to any range, just
         # match all k+2 to Any
         mid = [self._ctx.consts.Any.to_variable() for _ in range(posarg_delta)]
-      return posargs + tuple(pre + mid + post), None
+      return tuple(pre + mid + post), None
     elif posarg_delta and indef_starargs:
       # Fill in *required* posargs if needed; don't override the default posargs
       # with indef starargs yet because we aren't capturing the type of *args
@@ -221,8 +208,7 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
         extra = self._expand_typed_star(self.args.starargs, posarg_delta)
         return posargs + tuple(extra), None
       elif self.sig.varargs_name:
-        posargs = posargs[:n_required_posargs]
-        return posargs, self.args.starargs
+        return posargs[:n_required_posargs], self.args.starargs
       else:
         # We have too many posargs *and* no *args in the sig to absorb them, so
         # just do nothing and handle the error downstream.
@@ -230,7 +216,7 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
 
     else:
       # We have **kwargs but no *args in the invocation
-      return posargs + tuple(pre), None
+      return tuple(pre), None
 
   def _map_posargs(self):
     posargs, starargs = self._unpack_starargs()
diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py
index d56ff7e60..3e7bf1dbf 100644
--- a/pytype/rewrite/tests/test_args.py
+++ b/pytype/rewrite/tests/test_args.py
@@ -68,14 +68,13 @@ def g(a, b, x, y):
       f(*a, **b)
     """)
 
-  @test_base.skip('Does not yet work with fake args.')
   def test_unpack_posargs(self):
     self.Check("""
       def f(x, y, z):
-        g(*x, *y, *z)
+        return g(*x, *y, *z)
 
       def g(*args):
-        h(*args)
+        return h(*args)
 
       def h(p, q, r, s, t, u):
         return u

From e8da23e4a89d8cd996dd0a9b5950056c2d9dcdf4 Mon Sep 17 00:00:00 2001
From: mdemello 
Date: Fri, 19 Apr 2024 14:42:05 -0700
Subject: [PATCH 09/22] rewrite: Support indefinite **args and **args
 forwarding

- Replaces `internal.ConstKeyDict` with `internal.FunctionArgsDict`
- Adds better support for indefinite `containers.Dict` and conversion between
  `Dict` and `FunctionArgsDict`
- Fixes balancing of args between `kwargs` and `**args` in the arg mapper

PiperOrigin-RevId: 626473976
---
 pytype/rewrite/abstract/abstract.py        |  2 +-
 pytype/rewrite/abstract/base.py            |  1 +
 pytype/rewrite/abstract/containers.py      | 46 +++++++++++++++++----
 pytype/rewrite/abstract/containers_test.py | 16 ++++++--
 pytype/rewrite/abstract/functions.py       | 41 +++++++++++--------
 pytype/rewrite/abstract/internal.py        | 47 ++++++++++++++--------
 pytype/rewrite/abstract/internal_test.py   | 11 +++--
 pytype/rewrite/frame.py                    | 39 ++++++++----------
 pytype/rewrite/frame_test.py               |  8 ++--
 pytype/rewrite/output.py                   |  2 +
 pytype/rewrite/tests/test_args.py          | 38 +++++++++++++++++
 11 files changed, 177 insertions(+), 74 deletions(-)

diff --git a/pytype/rewrite/abstract/abstract.py b/pytype/rewrite/abstract/abstract.py
index c60c5871d..d5b163372 100644
--- a/pytype/rewrite/abstract/abstract.py
+++ b/pytype/rewrite/abstract/abstract.py
@@ -36,7 +36,7 @@
 Set = _containers.Set
 Tuple = _containers.Tuple
 
-ConstKeyDict = _internal.ConstKeyDict
+FunctionArgDict = _internal.FunctionArgDict
 FunctionArgTuple = _internal.FunctionArgTuple
 Splat = _internal.Splat
 
diff --git a/pytype/rewrite/abstract/base.py b/pytype/rewrite/abstract/base.py
index 7d626b3f1..5aa85a462 100644
--- a/pytype/rewrite/abstract/base.py
+++ b/pytype/rewrite/abstract/base.py
@@ -161,4 +161,5 @@ def _attrs(self):
   def instantiate(self):
     return Union(self._ctx, tuple(o.instantiate() for o in self.options))
 
+
 AbstractVariableType = variables.Variable[BaseValue]
diff --git a/pytype/rewrite/abstract/containers.py b/pytype/rewrite/abstract/containers.py
index 76ebb6c42..8ebbef103 100644
--- a/pytype/rewrite/abstract/containers.py
+++ b/pytype/rewrite/abstract/containers.py
@@ -45,25 +45,57 @@ class Dict(base.PythonConstant[_Dict[_Variable, _Variable]]):
   """Representation of a Python dict."""
 
   def __init__(
-      self, ctx: base.ContextType, constant: _Dict[_Variable, _Variable]
+      self, ctx: base.ContextType, constant: _Dict[_Variable, _Variable],
+      indefinite: bool = False
   ):
     assert isinstance(constant, dict), constant
     super().__init__(ctx, constant)
-    self.indefinite = False
+    self.indefinite = indefinite
 
   def __repr__(self):
-    return f'Dict({self.constant!r})'
+    indef = '+' if self.indefinite else ''
+    return f'Dict({indef}{self.constant!r})'
+
+  @classmethod
+  def any_dict(cls, ctx):
+    return cls(ctx, {}, indefinite=True)
 
   def setitem(self, key: _Variable, val: _Variable) -> 'Dict':
     return Dict(self._ctx, {**self.constant, key: val})
 
   def update(self, var: _Variable) -> base.BaseValue:
     try:
-      val = utils.get_atomic_constant(var, dict)
+      val = var.get_atomic_value()
     except ValueError:
-      # This dict has multiple possible values, so it is no longer a constant.
-      return self._ctx.abstract_loader.load_raw_type(dict).instantiate()
-    return Dict(self._ctx, {**self.constant, **val})
+      # The update var has multiple possible values, so we cannot merge it into
+      # the constant dict. We also don't know if items have been overwritten, so
+      # we need to discard self.constant
+      return Dict.any_dict(self._ctx)
+
+    if not hasattr(val, 'constant'):
+      # This is an object with no concrete python value
+      return Dict.any_dict(self._ctx)
+    elif isinstance(val, Dict):
+      new_items = val.constant
+    elif isinstance(val, internal.FunctionArgDict):
+      new_items = {
+          self._ctx.consts[k].to_variable(): v
+          for k, v in val.constant.items()
+      }
+    else:
+      raise ValueError('Unexpected dict update:', val)
+
+    return Dict(
+        self._ctx, {**self.constant, **new_items},
+        self.indefinite or val.indefinite
+    )
+
+  def to_function_arg_dict(self) -> internal.FunctionArgDict:
+    new_const = {
+        utils.get_atomic_constant(k, str): v
+        for k, v in self.constant.items()
+    }
+    return internal.FunctionArgDict(self._ctx, new_const, self.indefinite)
 
 
 class Set(base.PythonConstant[_Set[_Variable]]):
diff --git a/pytype/rewrite/abstract/containers_test.py b/pytype/rewrite/abstract/containers_test.py
index 2fcdc2b8b..fb63ce877 100644
--- a/pytype/rewrite/abstract/containers_test.py
+++ b/pytype/rewrite/abstract/containers_test.py
@@ -82,7 +82,9 @@ def test_update_indefinite(self):
     d1 = containers.Dict(self.ctx, {})
     indef = self.ctx.abstract_loader.load_raw_type(dict).instantiate()
     d2 = d1.update(indef.to_variable())
-    self.assertEqual(d2, indef)
+    self.assertIsInstance(d2, containers.Dict)
+    self.assertEqual(d2.constant, {})
+    self.assertTrue(d2.indefinite)
 
   def test_update_multiple_bindings(self):
     d1 = containers.Dict(self.ctx, {})
@@ -90,8 +92,16 @@ def test_update_multiple_bindings(self):
     d3 = containers.Dict(self.ctx, {self.const_var("c"): self.const_var("d")})
     var = variables.Variable((variables.Binding(d2), variables.Binding(d3)))
     d4 = d1.update(var)
-    self.assertEqual(
-        d4, self.ctx.abstract_loader.load_raw_type(dict).instantiate())
+    self.assertIsInstance(d4, containers.Dict)
+    self.assertEqual(d4.constant, {})
+    self.assertTrue(d4.indefinite)
+
+  def test_update_from_arg_dict(self):
+    d1 = containers.Dict(self.ctx, {})
+    d2 = internal.FunctionArgDict(self.ctx, {"a": self.const_var("b")})
+    d3 = d1.update(d2.to_variable())
+    self.assertIsInstance(d3, containers.Dict)
+    self.assertEqual(d3.constant, {self.const_var("a"): self.const_var("b")})
 
 
 class SetTest(BaseTest):
diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py
index 92c0cccd7..8c697f150 100644
--- a/pytype/rewrite/abstract/functions.py
+++ b/pytype/rewrite/abstract/functions.py
@@ -88,7 +88,7 @@ def get_concrete_starstarargs(self) -> Mapping[str, Any]:
     """Returns a concrete dict from starstarargs or raises ValueError."""
     if self.starstarargs is None:
       raise ValueError('No starstarargs to convert')
-    starstarargs = self.starstarargs.get_atomic_value(internal.ConstKeyDict)  # pytype: disable=attribute-error
+    starstarargs = self.starstarargs.get_atomic_value(internal.FunctionArgDict)  # pytype: disable=attribute-error
     return starstarargs.constant
 
 
@@ -230,34 +230,41 @@ def _map_posargs(self):
 
   def _unpack_starstarargs(self):
     """Adjust **args and kwargs based on function signature."""
-    if self.args.starstarargs is None:
-      # Nothing to unpack
-      return self.args.kwargs, None
-    try:
-      starstarargs_dict = self.args.get_concrete_starstarargs()
-    except ValueError:
-      # We have a non-concrete starstarargs
-      return self.args.kwargs, self.args.starstarargs
+    starstarargs_var = self.args.starstarargs
+    if starstarargs_var is None:
+      # There is nothing to unpack, but we might want to move unused kwargs into
+      # sig.kwargs_name
+      starstarargs = internal.FunctionArgDict(self._ctx, {})
+    else:
+      # Do not catch the error; this should always succeed
+      starstarargs = starstarargs_var.get_atomic_value(internal.FunctionArgDict)
     # Unpack **args into kwargs, overwriting named args for now
     # TODO(mdemello): raise an error if we have a conflict
-    kwargs = {**self.args.kwargs}
-    starstarargs_dict = {**starstarargs_dict}
+    kwargs_dict = {**self.args.kwargs}
+    starstarargs_dict = {**starstarargs.constant}
     for k in self.sig.param_names:
       if k in starstarargs_dict:
-        kwargs[k] = starstarargs_dict[k]
+        kwargs_dict[k] = starstarargs_dict[k]
         del starstarargs_dict[k]
+      elif starstarargs.indefinite:
+        kwargs_dict[k] = self._ctx.consts.Any.to_variable()
+    # Absorb extra kwargs into the sig's **args if present
+    if self.sig.kwargs_name:
+      extra = set(kwargs_dict) - set(self.sig.param_names)
+      for k in extra:
+        starstarargs_dict[k] = kwargs_dict[k]
+        del kwargs_dict[k]
     # Pack the unused entries in starstarargs back into an abstract value
-    starstarargs = internal.ConstKeyDict(self._ctx, starstarargs_dict)
-    return kwargs, starstarargs.to_variable()
+    new_starstarargs = internal.FunctionArgDict(
+        self._ctx, starstarargs_dict, starstarargs.indefinite)
+    return kwargs_dict, new_starstarargs.to_variable()
 
   def _map_kwargs(self):
     kwargs, starstarargs = self._unpack_starstarargs()
     # Copy kwargs into argdict
     self.argdict.update(kwargs)
-    # Make sure kwargs_name is bound to something
+    # Bind kwargs_name to remaining **args
     if self.sig.kwargs_name:
-      if starstarargs is None:
-        starstarargs = internal.ConstKeyDict(self._ctx, {}).to_variable()
       self.argdict[self.sig.kwargs_name] = starstarargs
 
   def map_args(self):
diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py
index d170c8231..627f13f58 100644
--- a/pytype/rewrite/abstract/internal.py
+++ b/pytype/rewrite/abstract/internal.py
@@ -12,39 +12,54 @@
 _Variable = base.AbstractVariableType
 
 
-class ConstKeyDict(base.BaseValue):
-  """Dictionary with constant literal keys.
-
-  Used by the python interpreter to construct function args.
-  """
+class FunctionArgTuple(base.BaseValue):
+  """Representation of a function arg tuple."""
 
-  def __init__(self, ctx: base.ContextType, constant: Dict[str, _Variable]):
+  def __init__(self, ctx: base.ContextType, constant: Tuple[_Variable, ...]):
     super().__init__(ctx)
-    assert isinstance(constant, dict), constant
+    assert isinstance(constant, tuple), constant
     self.constant = constant
 
   def __repr__(self):
-    return f"ConstKeyDict({self.constant!r})"
+    return f"FunctionArgTuple({self.constant!r})"
 
   @property
   def _attrs(self):
-    return (immutabledict.immutabledict(self.constant),)
+    return (self.constant,)
 
 
-class FunctionArgTuple(base.BaseValue):
-  """Representation of a function arg tuple."""
+class FunctionArgDict(base.BaseValue):
+  """Representation of a function kwarg dict."""
 
-  def __init__(self, ctx: base.ContextType, constant: Tuple[_Variable, ...]):
-    super().__init__(ctx)
-    assert isinstance(constant, tuple), constant
+  def __init__(
+      self,
+      ctx: base.ContextType,
+      constant: Dict[str, _Variable],
+      indefinite: bool = False
+  ):
+    self._ctx = ctx
+    self._check_keys(constant)
     self.constant = constant
+    self.indefinite = indefinite
+
+  @classmethod
+  def any_kwargs(cls, ctx):
+    """Return a new kwargs dict with only indefinite values."""
+    return cls(ctx, {}, indefinite=True)
+
+  def _check_keys(self, constant: Dict[str, _Variable]):
+    """Runtime check to ensure the invariant."""
+    assert isinstance(constant, dict), constant
+    if not all(isinstance(k, str) for k in constant):
+      raise ValueError("Passing a non-string key to a function arg dict")
 
   def __repr__(self):
-    return f"FunctionArgTuple({self.constant!r})"
+    indef = "+" if self.indefinite else ""
+    return f"FunctionArgDict({indef}{self.constant!r})"
 
   @property
   def _attrs(self):
-    return (self.constant,)
+    return (immutabledict.immutabledict(self.constant), self.indefinite)
 
 
 class Splat(base.BaseValue):
diff --git a/pytype/rewrite/abstract/internal_test.py b/pytype/rewrite/abstract/internal_test.py
index 5de4551cc..4bd2a9ad1 100644
--- a/pytype/rewrite/abstract/internal_test.py
+++ b/pytype/rewrite/abstract/internal_test.py
@@ -6,15 +6,20 @@
 import unittest
 
 
-class ConstKeyDictTest(test_utils.ContextfulTestBase):
+class FunctionArgDictTest(test_utils.ContextfulTestBase):
 
   def test_asserts_dict(self):
-    _ = internal.ConstKeyDict(self.ctx, {
+    _ = internal.FunctionArgDict(self.ctx, {
         'a': self.ctx.consts.Any.to_variable()
     })
     with self.assertRaises(AssertionError):
       x: Any = ['a', 'b']
-      _ = internal.ConstKeyDict(self.ctx, x)
+      _ = internal.FunctionArgDict(self.ctx, x)
+
+  def test_asserts_string_keys(self):
+    with self.assertRaises(ValueError):
+      x: Any = {1: 2}
+      _ = internal.FunctionArgDict(self.ctx, x)
 
 
 class SplatTest(test_utils.ContextfulTestBase):
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 7b49103a4..252b3296a 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -497,7 +497,7 @@ def byte_MAKE_FUNCTION(self, opcode):
       pos_defaults = pop_const(tuple)
     if arg & _Flags.MAKE_FUNCTION_HAS_KW_DEFAULTS:
       packed_kw_def = self._stack.pop()
-      kw_defaults = packed_kw_def.get_atomic_value(abstract.ConstKeyDict)
+      kw_defaults = packed_kw_def.get_atomic_value(abstract.Dict)
     # Make function
     del annot, pos_defaults, kw_defaults  # TODO(b/241479600): Use these
     func = abstract.InterpreterFunction(
@@ -643,22 +643,15 @@ def _unpack_starargs(self, starargs) -> abstract.BaseValue:
       assert False, f'unexpected posargs type: {posargs}: {type(posargs)}'
     return posargs
 
-  def _unpack_starstarargs(self, starstarargs) -> abstract.BaseValue:
+  def _unpack_starstarargs(self, starstarargs) -> abstract.FunctionArgDict:
     kwargs = starstarargs.get_atomic_value()
-    if isinstance(kwargs, abstract.ConstKeyDict):
+    if isinstance(kwargs, abstract.FunctionArgDict):
       # This has already been converted
       pass
-    elif isinstance(kwargs, abstract.FrozenInstance):
-      # This is indefinite; leave it as-is
-      pass
-    elif isinstance(kwargs, abstract.PythonConstant):
-      assert isinstance(kwargs.constant, dict)
-      kwargs = abstract.ConstKeyDict(self._ctx, {
-          abstract.get_atomic_constant(k, str): v
-          for k, v in kwargs.constant.items()
-      })
+    elif isinstance(kwargs, abstract.Dict):
+      kwargs = kwargs.to_function_arg_dict()
     elif abstract.is_any(kwargs):
-      kwargs = self._ctx.abstract_loader.load_raw_type(dict).instantiate()
+      kwargs = abstract.FunctionArgDict.any_kwargs(self._ctx)
     else:
       assert False, f'unexpected kwargs type: {kwargs}: {type(kwargs)}'
     return kwargs
@@ -668,14 +661,16 @@ def byte_CALL_FUNCTION_EX(self, opcode):
     if opcode.arg & _Flags.CALL_FUNCTION_EX_HAS_KWARGS:
       starstarargs = self._stack.pop()
       unpacked_starstarargs = self._unpack_starstarargs(starstarargs)
-      if isinstance(
-          unpacked_starstarargs, (abstract.Dict, abstract.ConstKeyDict)):
-        # We have a concrete dict we are unpacking; move it into kwargs
-        kwargs = unpacked_starstarargs.constant
-        starstarargs = None
+      # If we have a concrete dict we are unpacking; move it into kwargs (if
+      # not, .constant will be {} anyway, so we don't need to check here.)
+      kwargs = unpacked_starstarargs.constant
+      if unpacked_starstarargs.indefinite:
+        # We also have **args, apart from the concrete kv pairs we moved into
+        # kwargs, that need to be preserved.
+        starstarargs = (
+            abstract.FunctionArgDict.any_kwargs(self._ctx).to_variable())
       else:
-        # We have an indefinite dict, leave it in starstarargs
-        kwargs = datatypes.EMPTY_MAP
+        starstarargs = None
     else:
       kwargs = datatypes.EMPTY_MAP
       starstarargs = None
@@ -757,12 +752,10 @@ def byte_BUILD_CONST_KEY_MAP(self, opcode):
     # to abstract objects because they are used internally to construct function
     # call args.
     keys = abstract.get_atomic_constant(keys, tuple)
-    # Unpack the keys into raw strings.
-    keys = [abstract.get_atomic_constant(k, str) for k in keys]
     assert len(keys) == n_elts
     vals = self._stack.popn(n_elts)
     ret = dict(zip(keys, vals))
-    ret = abstract.ConstKeyDict(self._ctx, ret)
+    ret = abstract.Dict(self._ctx, ret)
     self._stack.push(ret.to_variable())
 
   def byte_LIST_APPEND(self, opcode):
diff --git a/pytype/rewrite/frame_test.py b/pytype/rewrite/frame_test.py
index 77b627bdd..0acdf86fa 100644
--- a/pytype/rewrite/frame_test.py
+++ b/pytype/rewrite/frame_test.py
@@ -566,11 +566,11 @@ def test_const_key_map(self):
       b = 2
       c = 3
       constant = {'a': a, 'b': b, 'c': c}
-    """, typ=abstract.ConstKeyDict)
+    """, typ=abstract.Dict)
     self.assertEqual(constant.constant, {
-        'a': self._const_var(1, 'a'),
-        'b': self._const_var(2, 'b'),
-        'c': self._const_var(3, 'c'),
+        self._const_var('a'): self._const_var(1, 'a'),
+        self._const_var('b'): self._const_var(2, 'b'),
+        self._const_var('c'): self._const_var(3, 'c'),
     })
 
 
diff --git a/pytype/rewrite/output.py b/pytype/rewrite/output.py
index a9fa14639..3de2dae1a 100644
--- a/pytype/rewrite/output.py
+++ b/pytype/rewrite/output.py
@@ -176,6 +176,8 @@ def to_pytd_type(self, val: abstract.BaseValue) -> pytd.Type:
       return pytd_utils.JoinTypes(self.to_pytd_type(v) for v in val.options)
     elif isinstance(val, abstract.PythonConstant):
       return pytd.NamedType(f'builtins.{val.constant.__class__.__name__}')
+    elif isinstance(val, abstract.FunctionArgDict):
+      return pytd.NamedType('builtins.dict')
     elif isinstance(val, abstract.SimpleClass):
       return pytd.GenericType(
           base_type=pytd.NamedType('builtins.type'),
diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py
index 3e7bf1dbf..8d78d946c 100644
--- a/pytype/rewrite/tests/test_args.py
+++ b/pytype/rewrite/tests/test_args.py
@@ -83,6 +83,44 @@ def h(p, q, r, s, t, u):
       assert_type(ret, int)
     """)
 
+  def test_indef_starstarargs(self):
+    self.Check("""
+      def f(**args):
+        return g(**args)
+
+      def g(x, y, z):
+        return z
+    """)
+
+  def test_forward_starstarargs(self):
+    self.Check("""
+      def f(**args):
+        return g(**args)
+
+      def g(**args):
+        return h(**args)
+
+      def h(p, q, r):
+        return r
+
+      args = {'p': 1, 'q': 2, 'r': 3, 's': 4}
+      ret = f(**args)
+      assert_type(ret, int)
+    """)
+
+  def test_capture_starstarargs(self):
+    self.Check("""
+      def f(**args):
+        return g(args)
+
+      def g(args):
+        return args
+
+      args = {'p': 1, 'q': 2, 'r': 3, 's': 4}
+      ret = f(**args)
+      assert_type(ret, dict)
+    """)
+
 
 if __name__ == '__main__':
   test_base.main()

From f26776350720d1526ce8f10652f6263a9b958eca Mon Sep 17 00:00:00 2001
From: rechen 
Date: Fri, 19 Apr 2024 16:02:40 -0700
Subject: [PATCH 10/22] Add a less annoying way to invoke
 ctx.abstract_loader.load_raw_type.

PiperOrigin-RevId: 626492628
---
 pytype/rewrite/abstract/base.py            |  1 +
 pytype/rewrite/abstract/classes.py         |  2 +-
 pytype/rewrite/abstract/classes_test.py    |  4 +---
 pytype/rewrite/abstract/containers.py      |  2 +-
 pytype/rewrite/abstract/containers_test.py |  7 +++----
 pytype/rewrite/abstract/functions_test.py  |  2 +-
 pytype/rewrite/abstract/internal_test.py   |  3 +--
 pytype/rewrite/context.py                  |  4 +++-
 pytype/rewrite/convert.py                  |  2 +-
 pytype/rewrite/frame.py                    |  2 +-
 pytype/rewrite/load_abstract.py            | 17 +++++++++++++++++
 11 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/pytype/rewrite/abstract/base.py b/pytype/rewrite/abstract/base.py
index 5aa85a462..dce928ef8 100644
--- a/pytype/rewrite/abstract/base.py
+++ b/pytype/rewrite/abstract/base.py
@@ -24,6 +24,7 @@ class ContextType(Protocol):
   abstract_loader: Any
   pytd_converter: Any
   consts: Any
+  types: Any
 
 
 class BaseValue(types.BaseValue, abc.ABC):
diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py
index 4e951bfc5..450cc43b6 100644
--- a/pytype/rewrite/abstract/classes.py
+++ b/pytype/rewrite/abstract/classes.py
@@ -138,7 +138,7 @@ def mro(self) -> Sequence['SimpleClass']:
       self._mro = mro = [self]
       return mro
     bases = list(self.bases)
-    obj_type = self._ctx.abstract_loader.load_raw_type(object)
+    obj_type = self._ctx.types[object]
     if not bases or bases[-1] != obj_type:
       bases.append(obj_type)
     mro_bases = [[self]] + [list(base.mro()) for base in bases] + [bases]
diff --git a/pytype/rewrite/abstract/classes_test.py b/pytype/rewrite/abstract/classes_test.py
index bcbec0b11..02a7c2acb 100644
--- a/pytype/rewrite/abstract/classes_test.py
+++ b/pytype/rewrite/abstract/classes_test.py
@@ -35,9 +35,7 @@ def test_call(self):
   def test_mro(self):
     parent = classes.SimpleClass(self.ctx, 'Parent', {})
     child = classes.SimpleClass(self.ctx, 'Child', {}, bases=[parent])
-    self.assertEqual(
-        child.mro(),
-        [child, parent, self.ctx.abstract_loader.load_raw_type(object)])
+    self.assertEqual(child.mro(), [child, parent, self.ctx.types[object]])
 
 
 class MutableInstanceTest(test_utils.ContextfulTestBase):
diff --git a/pytype/rewrite/abstract/containers.py b/pytype/rewrite/abstract/containers.py
index 8ebbef103..ab435ddec 100644
--- a/pytype/rewrite/abstract/containers.py
+++ b/pytype/rewrite/abstract/containers.py
@@ -32,7 +32,7 @@ def extend(self, var: _Variable) -> base.BaseValue:
       val = var.get_atomic_value()
     except ValueError:
       # This list has multiple possible values, so it is no longer a constant.
-      return self._ctx.abstract_loader.load_raw_type(list).instantiate()
+      return self._ctx.types[list].instantiate()
     if isinstance(val, List):
       new_constant = self.constant + val.constant
     else:
diff --git a/pytype/rewrite/abstract/containers_test.py b/pytype/rewrite/abstract/containers_test.py
index fb63ce877..bb2fa707b 100644
--- a/pytype/rewrite/abstract/containers_test.py
+++ b/pytype/rewrite/abstract/containers_test.py
@@ -41,7 +41,7 @@ def test_extend(self):
 
   def test_extend_splat(self):
     l1 = containers.List(self.ctx, [self.const_var("a")])
-    l2 = self.ctx.abstract_loader.load_raw_type(list).instantiate()
+    l2 = self.ctx.types[list].instantiate()
     l3 = l1.extend(l2.to_variable())
     self.assertIsInstance(l3, containers.List)
     self.assertEqual(
@@ -54,8 +54,7 @@ def test_extend_multiple_bindings(self):
     l3 = containers.List(self.ctx, [self.const_var("c")])
     var = variables.Variable((variables.Binding(l2), variables.Binding(l3)))
     l4 = l1.extend(var)
-    self.assertEqual(
-        l4, self.ctx.abstract_loader.load_raw_type(list).instantiate())
+    self.assertEqual(l4, self.ctx.types[list].instantiate())
 
 
 class DictTest(BaseTest):
@@ -80,7 +79,7 @@ def test_update(self):
 
   def test_update_indefinite(self):
     d1 = containers.Dict(self.ctx, {})
-    indef = self.ctx.abstract_loader.load_raw_type(dict).instantiate()
+    indef = self.ctx.types[dict].instantiate()
     d2 = d1.update(indef.to_variable())
     self.assertIsInstance(d2, containers.Dict)
     self.assertEqual(d2.constant, {})
diff --git a/pytype/rewrite/abstract/functions_test.py b/pytype/rewrite/abstract/functions_test.py
index 87f745546..11ca5d40f 100644
--- a/pytype/rewrite/abstract/functions_test.py
+++ b/pytype/rewrite/abstract/functions_test.py
@@ -56,7 +56,7 @@ def test_map_args(self):
     self.assertEqual(args.argdict, {'x': x, 'y': y})
 
   def test_fake_args(self):
-    annotations = {'x': self.ctx.abstract_loader.load_raw_type(int)}
+    annotations = {'x': self.ctx.types[int]}
     signature = functions.Signature(self.ctx, 'f', ('x', 'y'),
                                     annotations=annotations)
     args = signature.make_fake_args()
diff --git a/pytype/rewrite/abstract/internal_test.py b/pytype/rewrite/abstract/internal_test.py
index 4bd2a9ad1..19890965a 100644
--- a/pytype/rewrite/abstract/internal_test.py
+++ b/pytype/rewrite/abstract/internal_test.py
@@ -26,8 +26,7 @@ class SplatTest(test_utils.ContextfulTestBase):
 
   def test_basic(self):
     # Basic smoke test, remove when we have some real functionality to test.
-    cls = self.ctx.abstract_loader.load_raw_type(tuple)
-    seq = cls.instantiate()
+    seq = self.ctx.types[tuple].instantiate()
     x = internal.Splat(self.ctx, seq)
     self.assertEqual(x.iterable, seq)
 
diff --git a/pytype/rewrite/context.py b/pytype/rewrite/context.py
index 45545650c..5e6c75f15 100644
--- a/pytype/rewrite/context.py
+++ b/pytype/rewrite/context.py
@@ -34,6 +34,7 @@ class Context:
   abstract_loader: load_abstract.AbstractLoader
   pytd_converter: output.PytdConverter
   consts: load_abstract.Constants
+  types: load_abstract.Types
 
   def __init__(
       self,
@@ -48,5 +49,6 @@ def __init__(
     self.abstract_loader = load_abstract.AbstractLoader(self, self.pytd_loader)
     self.pytd_converter = output.PytdConverter(self)
 
-    # We access these all the time, so create a convenient alias.
+    # We access these all the time, so create convenient aliases.
     self.consts = self.abstract_loader.consts
+    self.types = self.abstract_loader.types
diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py
index 989512881..5063877f3 100644
--- a/pytype/rewrite/convert.py
+++ b/pytype/rewrite/convert.py
@@ -111,7 +111,7 @@ def _pytd_type_to_value(self, typ: pytd.Type) -> abstract.BaseValue:
     elif isinstance(typ, pytd.TypeParameter):
       return self._ctx.consts.Any
     elif isinstance(typ, pytd.Literal):
-      return self._ctx.abstract_loader.load_raw_type(type(typ.value))
+      return self._ctx.types[type(typ.value)]
     elif isinstance(typ, pytd.Annotated):
       # We discard the Annotated wrapper for now, but we will need to keep track
       # of it because Annotated is a special form that can be used in generic
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 252b3296a..6ddfdae1e 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -638,7 +638,7 @@ def _unpack_starargs(self, starargs) -> abstract.BaseValue:
     elif isinstance(posargs, tuple):
       posargs = abstract.FunctionArgTuple(self._ctx, posargs)
     elif abstract.is_any(posargs):
-      return self._ctx.abstract_loader.load_raw_type(tuple).instantiate()
+      return self._ctx.types[tuple].instantiate()
     else:
       assert False, f'unexpected posargs type: {posargs}: {type(posargs)}'
     return posargs
diff --git a/pytype/rewrite/load_abstract.py b/pytype/rewrite/load_abstract.py
index eca86093c..f1ad993d2 100644
--- a/pytype/rewrite/load_abstract.py
+++ b/pytype/rewrite/load_abstract.py
@@ -39,6 +39,20 @@ def __getitem__(self, const: _Any):
     return self._consts[const]
 
 
+class Types:
+  """Wrapper for AbstractLoader.load_raw_types.
+
+  We use this method all the time, so we provide a convenient wrapper for it.
+  For consistency, this wrapper has the same interface as Constants above.
+  """
+
+  def __init__(self, ctx: abstract.ContextType):
+    self._ctx = ctx
+
+  def __getitem__(self, raw_type: Type[_Any]) -> abstract.BaseValue:
+    return self._ctx.abstract_loader.load_raw_type(raw_type)
+
+
 class AbstractLoader:
   """Abstract loader."""
 
@@ -47,6 +61,7 @@ def __init__(self, ctx: abstract.ContextType, pytd_loader: load_pytd.Loader):
     self._pytd_loader = pytd_loader
 
     self.consts = Constants(ctx)
+    self.types = Types(ctx)
     self._special_builtins = {
         'assert_type': special_builtins.AssertType(self._ctx),
         'reveal_type': special_builtins.RevealType(self._ctx),
@@ -96,6 +111,8 @@ def get_module_globals(self) -> Dict[str, abstract.BaseValue]:
   def load_raw_type(self, typ: Type[_Any]) -> abstract.BaseValue:
     """Converts a raw type to an abstract value.
 
+    For convenience, this method can also be called via ctx.types[typ].
+
     Args:
       typ: The type.
 

From d9c1435e84a3896f12554b4e9f96a4950933f8c4 Mon Sep 17 00:00:00 2001
From: mdemello 
Date: Fri, 19 Apr 2024 18:24:48 -0700
Subject: [PATCH 11/22] rewrite: Move unpacking of DICT_UPDATE and LIST_EXTEND
 args into the vm

The `abstract.Dict` and `abstract.List` containers now only get their update
methods called with a well-formed Dict/List

PiperOrigin-RevId: 626520407
---
 pytype/rewrite/abstract/CMakeLists.txt     |  2 -
 pytype/rewrite/abstract/containers.py      | 48 +++++++--------------
 pytype/rewrite/abstract/containers_test.py | 48 +--------------------
 pytype/rewrite/abstract/internal.py        |  6 ++-
 pytype/rewrite/frame.py                    | 49 ++++++++++++++++++++--
 5 files changed, 66 insertions(+), 87 deletions(-)

diff --git a/pytype/rewrite/abstract/CMakeLists.txt b/pytype/rewrite/abstract/CMakeLists.txt
index 4f0d09de6..e1cc7b47e 100644
--- a/pytype/rewrite/abstract/CMakeLists.txt
+++ b/pytype/rewrite/abstract/CMakeLists.txt
@@ -83,8 +83,6 @@ py_test(
   DEPS
     .base
     .containers
-    .internal
-    pytype.rewrite.flow.flow
     pytype.rewrite.tests.test_utils
 )
 
diff --git a/pytype/rewrite/abstract/containers.py b/pytype/rewrite/abstract/containers.py
index ab435ddec..cd3d868c0 100644
--- a/pytype/rewrite/abstract/containers.py
+++ b/pytype/rewrite/abstract/containers.py
@@ -27,17 +27,8 @@ def __repr__(self):
   def append(self, var: _Variable) -> 'List':
     return List(self._ctx, self.constant + [var])
 
-  def extend(self, var: _Variable) -> base.BaseValue:
-    try:
-      val = var.get_atomic_value()
-    except ValueError:
-      # This list has multiple possible values, so it is no longer a constant.
-      return self._ctx.types[list].instantiate()
-    if isinstance(val, List):
-      new_constant = self.constant + val.constant
-    else:
-      splat = internal.Splat(self._ctx, val)
-      new_constant = self.constant + [splat.to_variable()]
+  def extend(self, val: 'List') -> 'List':
+    new_constant = self.constant + val.constant
     return List(self._ctx, new_constant)
 
 
@@ -60,33 +51,22 @@ def __repr__(self):
   def any_dict(cls, ctx):
     return cls(ctx, {}, indefinite=True)
 
+  @classmethod
+  def from_function_arg_dict(
+      cls, ctx: base.ContextType, val: internal.FunctionArgDict
+  ) -> 'Dict':
+    new_constant = {
+        ctx.consts[k].to_variable(): v
+        for k, v in val.constant.items()
+    }
+    return cls(ctx, new_constant, val.indefinite)
+
   def setitem(self, key: _Variable, val: _Variable) -> 'Dict':
     return Dict(self._ctx, {**self.constant, key: val})
 
-  def update(self, var: _Variable) -> base.BaseValue:
-    try:
-      val = var.get_atomic_value()
-    except ValueError:
-      # The update var has multiple possible values, so we cannot merge it into
-      # the constant dict. We also don't know if items have been overwritten, so
-      # we need to discard self.constant
-      return Dict.any_dict(self._ctx)
-
-    if not hasattr(val, 'constant'):
-      # This is an object with no concrete python value
-      return Dict.any_dict(self._ctx)
-    elif isinstance(val, Dict):
-      new_items = val.constant
-    elif isinstance(val, internal.FunctionArgDict):
-      new_items = {
-          self._ctx.consts[k].to_variable(): v
-          for k, v in val.constant.items()
-      }
-    else:
-      raise ValueError('Unexpected dict update:', val)
-
+  def update(self, val: 'Dict') -> base.BaseValue:
     return Dict(
-        self._ctx, {**self.constant, **new_items},
+        self._ctx, {**self.constant, **val.constant},
         self.indefinite or val.indefinite
     )
 
diff --git a/pytype/rewrite/abstract/containers_test.py b/pytype/rewrite/abstract/containers_test.py
index bb2fa707b..174126605 100644
--- a/pytype/rewrite/abstract/containers_test.py
+++ b/pytype/rewrite/abstract/containers_test.py
@@ -2,8 +2,6 @@
 
 from pytype.rewrite.abstract import base
 from pytype.rewrite.abstract import containers
-from pytype.rewrite.abstract import internal
-from pytype.rewrite.flow import variables
 from pytype.rewrite.tests import test_utils
 from typing_extensions import assert_type
 
@@ -35,27 +33,10 @@ def test_append(self):
   def test_extend(self):
     l1 = containers.List(self.ctx, [self.const_var("a")])
     l2 = containers.List(self.ctx, [self.const_var("b")])
-    l3 = l1.extend(l2.to_variable())
+    l3 = l1.extend(l2)
     self.assertIsInstance(l3, containers.List)
     self.assertEqual(l3.constant, [self.const_var("a"), self.const_var("b")])
 
-  def test_extend_splat(self):
-    l1 = containers.List(self.ctx, [self.const_var("a")])
-    l2 = self.ctx.types[list].instantiate()
-    l3 = l1.extend(l2.to_variable())
-    self.assertIsInstance(l3, containers.List)
-    self.assertEqual(
-        l3.constant,
-        [self.const_var("a"), internal.Splat(self.ctx, l2).to_variable()])
-
-  def test_extend_multiple_bindings(self):
-    l1 = containers.List(self.ctx, [self.const_var("a")])
-    l2 = containers.List(self.ctx, [self.const_var("b")])
-    l3 = containers.List(self.ctx, [self.const_var("c")])
-    var = variables.Variable((variables.Binding(l2), variables.Binding(l3)))
-    l4 = l1.extend(var)
-    self.assertEqual(l4, self.ctx.types[list].instantiate())
-
 
 class DictTest(BaseTest):
 
@@ -73,32 +54,7 @@ def test_setitem(self):
   def test_update(self):
     d1 = containers.Dict(self.ctx, {})
     d2 = containers.Dict(self.ctx, {self.const_var("a"): self.const_var("b")})
-    d3 = d1.update(d2.to_variable())
-    self.assertIsInstance(d3, containers.Dict)
-    self.assertEqual(d3.constant, {self.const_var("a"): self.const_var("b")})
-
-  def test_update_indefinite(self):
-    d1 = containers.Dict(self.ctx, {})
-    indef = self.ctx.types[dict].instantiate()
-    d2 = d1.update(indef.to_variable())
-    self.assertIsInstance(d2, containers.Dict)
-    self.assertEqual(d2.constant, {})
-    self.assertTrue(d2.indefinite)
-
-  def test_update_multiple_bindings(self):
-    d1 = containers.Dict(self.ctx, {})
-    d2 = containers.Dict(self.ctx, {self.const_var("a"): self.const_var("b")})
-    d3 = containers.Dict(self.ctx, {self.const_var("c"): self.const_var("d")})
-    var = variables.Variable((variables.Binding(d2), variables.Binding(d3)))
-    d4 = d1.update(var)
-    self.assertIsInstance(d4, containers.Dict)
-    self.assertEqual(d4.constant, {})
-    self.assertTrue(d4.indefinite)
-
-  def test_update_from_arg_dict(self):
-    d1 = containers.Dict(self.ctx, {})
-    d2 = internal.FunctionArgDict(self.ctx, {"a": self.const_var("b")})
-    d3 = d1.update(d2.to_variable())
+    d3 = d1.update(d2)
     self.assertIsInstance(d3, containers.Dict)
     self.assertEqual(d3.constant, {self.const_var("a"): self.const_var("b")})
 
diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py
index 627f13f58..78f74f7d0 100644
--- a/pytype/rewrite/abstract/internal.py
+++ b/pytype/rewrite/abstract/internal.py
@@ -43,7 +43,7 @@ def __init__(
     self.indefinite = indefinite
 
   @classmethod
-  def any_kwargs(cls, ctx):
+  def any_kwargs(cls, ctx: base.ContextType):
     """Return a new kwargs dict with only indefinite values."""
     return cls(ctx, {}, indefinite=True)
 
@@ -74,6 +74,10 @@ def __init__(self, ctx: base.ContextType, iterable: base.BaseValue):
     super().__init__(ctx)
     self.iterable = iterable
 
+  @classmethod
+  def any(cls, ctx: base.ContextType):
+    return cls(ctx, ctx.consts.Any)
+
   def get_concrete_iterable(self):
     if (isinstance(self.iterable, base.PythonConstant) and
         isinstance(self.iterable.constant, collections.abc.Iterable)):
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 6ddfdae1e..193daea91 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -786,12 +786,45 @@ def byte_MAP_ADD(self, opcode):
     target = target_var.get_atomic_value()
     self._replace_atomic_stack_value(count, target.setitem(key, val))
 
+  def _unpack_list_extension(self, var: _AbstractVariable) -> abstract.List:
+    try:
+      val = var.get_atomic_value()
+    except ValueError:
+      # This list has multiple possible values, so it is no longer a constant.
+      return abstract.List(
+          self._ctx, [abstract.Splat.any(self._ctx).to_variable()])
+    if isinstance(val, abstract.List):
+      return val
+    else:
+      return abstract.List(
+          self._ctx, [abstract.Splat(self._ctx, val).to_variable()])
+
   def byte_LIST_EXTEND(self, opcode):
     count = opcode.arg
-    val = self._stack.pop()
+    update_var = self._stack.pop()
+    update = self._unpack_list_extension(update_var)
     target_var = self._stack.peek(count)
     target = target_var.get_atomic_value()
-    self._replace_atomic_stack_value(count, target.extend(val))
+    self._replace_atomic_stack_value(count, target.extend(update))
+
+  def _unpack_dict_update(
+      self, var: _AbstractVariable
+  ) -> Optional[abstract.Dict]:
+    try:
+      val = var.get_atomic_value()
+    except ValueError:
+      return None
+    if isinstance(val, abstract.Dict):
+      return val
+    elif isinstance(val, abstract.FunctionArgDict):
+      return abstract.Dict.from_function_arg_dict(self._ctx, val)
+    elif abstract.is_any(val):
+      return None
+    elif isinstance(val, abstract.BaseInstance):
+      # This is an object with no concrete python value
+      return None
+    else:
+      raise ValueError('Unexpected dict update:', val)
 
   def byte_DICT_MERGE(self, opcode):
     # DICT_MERGE is like DICT_UPDATE but raises an exception for duplicate keys.
@@ -799,10 +832,18 @@ def byte_DICT_MERGE(self, opcode):
 
   def byte_DICT_UPDATE(self, opcode):
     count = opcode.arg
-    val = self._stack.pop()
+    update_var = self._stack.pop()
+    update = self._unpack_dict_update(update_var)
     target_var = self._stack.peek(count)
     target = target_var.get_atomic_value()
-    self._replace_atomic_stack_value(count, target.update(val))
+    if update is None:
+      # The update var has multiple possible values, or no constant, so we
+      # cannot merge it into the constant dict. We also don't know if existing
+      # items have been overwritten, so we need to return a new 'any' dict.
+      ret = abstract.Dict.any_dict(self._ctx)
+    else:
+      ret = target.update(update)
+    self._replace_atomic_stack_value(count, ret)
 
   def byte_LIST_TO_TUPLE(self, opcode):
     target_var = self._stack.pop()

From bf876f80cdc1bbba2814aa344642f1ce519cbc33 Mon Sep 17 00:00:00 2001
From: rechen 
Date: Fri, 19 Apr 2024 20:55:28 -0700
Subject: [PATCH 12/22] Add the metaclass-calling framework needed for
 EnumMeta.__new__.

Adds the code needed to detect and call a custom implementation of
type.__new__. An implementation of EnumMeta.__new__ will come in a later CL.

Also fixes a minor issue where convert.py was failing to add modules and class
prefixes to method names.

PiperOrigin-RevId: 626545872
---
 pytype/rewrite/abstract/base.py      |  3 ++
 pytype/rewrite/abstract/classes.py   |  2 +-
 pytype/rewrite/abstract/functions.py |  6 ++--
 pytype/rewrite/convert.py            | 16 ++++++---
 pytype/rewrite/convert_test.py       |  3 +-
 pytype/rewrite/frame.py              | 49 +++++++++++++++++++++-------
 pytype/rewrite/tests/test_basic.py   | 24 ++++++++++++++
 pytype/stubs/stdlib/enum.pytd        |  5 ++-
 8 files changed, 88 insertions(+), 20 deletions(-)

diff --git a/pytype/rewrite/abstract/base.py b/pytype/rewrite/abstract/base.py
index dce928ef8..20738a2d2 100644
--- a/pytype/rewrite/abstract/base.py
+++ b/pytype/rewrite/abstract/base.py
@@ -137,6 +137,9 @@ def _attrs(self):
   def instantiate(self) -> 'Singleton':
     return self
 
+  def get_attribute(self, name: str) -> 'Singleton':
+    return self
+
 
 class Union(BaseValue):
   """Union of values."""
diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py
index 450cc43b6..5c2ea348d 100644
--- a/pytype/rewrite/abstract/classes.py
+++ b/pytype/rewrite/abstract/classes.py
@@ -110,7 +110,7 @@ def instantiate(self) -> 'FrozenInstance':
       if isinstance(setup_method, functions_lib.InterpreterFunction):
         _ = setup_method.bind_to(self).analyze()
     constructor = self.get_attribute(self.constructor)
-    if constructor:
+    if constructor and constructor.full_name != 'builtins.object.__new__':
       log.error('Custom __new__ not yet implemented')
     instance = MutableInstance(self._ctx, self)
     for initializer_name in self.initializers:
diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py
index 8c697f150..b333bd84b 100644
--- a/pytype/rewrite/abstract/functions.py
+++ b/pytype/rewrite/abstract/functions.py
@@ -596,8 +596,8 @@ def _attrs(self):
     return (self.name, self.code)
 
   def call_with_mapped_args(self, mapped_args: MappedArgs[_FrameT]) -> _FrameT:
-    log.info('Calling function:\n  Sig:  %s\n  Args: %s',
-             mapped_args.signature, mapped_args.argdict)
+    log.info('Calling function %s:\n  Sig:  %s\n  Args: %s',
+             self.full_name, mapped_args.signature, mapped_args.argdict)
     parent_frame = mapped_args.frame or self._parent_frame
     if parent_frame.final_locals is None:
       k = None
@@ -622,6 +622,8 @@ class PytdFunction(SimpleFunction[SimpleReturn]):
 
   def call_with_mapped_args(
       self, mapped_args: MappedArgs[FrameType]) -> SimpleReturn:
+    log.info('Calling function %s:\n  Sig:  %s\n  Args: %s',
+             self.full_name, mapped_args.signature, mapped_args.argdict)
     ret = mapped_args.signature.annotations['return'].instantiate()
     return SimpleReturn(ret)
 
diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py
index 5063877f3..086fd0a2d 100644
--- a/pytype/rewrite/convert.py
+++ b/pytype/rewrite/convert.py
@@ -1,5 +1,7 @@
 """Conversion from pytd to abstract representations of Python values."""
 
+from typing import Optional, Tuple
+
 from pytype.pytd import pytd
 from pytype.rewrite.abstract import abstract
 
@@ -38,8 +40,10 @@ def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass:
     # don't cause infinite recursion.
     self._cache.classes[cls] = abstract_class
     for method in cls.methods:
-      abstract_class.members[method.name] = (
-          self.pytd_function_to_value(method))
+      # For consistency with InterpreterFunction, prepend the class name.
+      full_name = f'{name}.{method.name}'
+      method_value = self.pytd_function_to_value(method, (module, full_name))
+      abstract_class.members[method.name] = method_value
     for constant in cls.constants:
       constant_type = self.pytd_type_to_value(constant.type)
       abstract_class.members[constant.name] = constant_type.instantiate()
@@ -61,11 +65,15 @@ def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass:
     return abstract_class
 
   def pytd_function_to_value(
-      self, func: pytd.Function) -> abstract.PytdFunction:
+      self, func: pytd.Function, func_name: Optional[Tuple[str, str]] = None,
+  ) -> abstract.PytdFunction:
     """Converts a pytd function to an abstract function."""
     if func in self._cache.funcs:
       return self._cache.funcs[func]
-    module, _, name = func.name.rpartition('.')
+    if func_name:
+      module, name = func_name
+    else:
+      module, _, name = func.name.rpartition('.')
     signatures = tuple(
         abstract.Signature.from_pytd(self._ctx, name, pytd_sig)
         for pytd_sig in func.signatures)
diff --git a/pytype/rewrite/convert_test.py b/pytype/rewrite/convert_test.py
index 7a3aea0e4..5e96d62d5 100644
--- a/pytype/rewrite/convert_test.py
+++ b/pytype/rewrite/convert_test.py
@@ -65,7 +65,8 @@ def f(self, x) -> None: ...
     self.assertEqual(set(cls.members), {'f'})
     f = cls.members['f']
     self.assertIsInstance(f, abstract.PytdFunction)
-    self.assertEqual(repr(f.signatures[0]), 'def f(self: C, x: Any) -> None')
+    self.assertEqual(f.module, '')
+    self.assertEqual(repr(f.signatures[0]), 'def C.f(self: C, x: Any) -> None')
 
   def test_constant(self):
     pytd_cls = self.build_pytd("""
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 193daea91..f8652918b 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -1,5 +1,6 @@
 """A frame of an abstract VM for type analysis of python bytecode."""
 
+import itertools
 import logging
 from typing import Any, FrozenSet, List, Mapping, Optional, Sequence, Set, Type
 
@@ -310,7 +311,8 @@ def _merge_nonlocals_into(self, frame: Optional['Frame']) -> None:
 
   def _build_class(self, args: abstract.Args) -> abstract.InterpreterClass:
     builder = args.posargs[0].get_atomic_value(_FrameFunction)
-    name = abstract.get_atomic_constant(args.posargs[1], str)
+    name_var = args.posargs[1]
+    name = abstract.get_atomic_constant(name_var, str)
 
     base_vars = args.posargs[2:]
     bases = []
@@ -330,16 +332,41 @@ def _build_class(self, args: abstract.Args) -> abstract.InterpreterClass:
       keywords[kw] = val
 
     frame = builder.call(abstract.Args(frame=self))
-    cls = abstract.InterpreterClass(
-        ctx=self._ctx,
-        name=name,
-        members=dict(frame.final_locals),
-        bases=bases,
-        keywords=keywords,
-        functions=frame.functions,
-        classes=frame.classes,
-    )
-    log.info('Created class: %s', cls.name)
+    members = dict(frame.final_locals)
+    metaclass_instance = None
+    for metaclass in itertools.chain([keywords.get('metaclass')],
+                                     (base.metaclass for base in bases)):
+      if not metaclass:
+        continue
+      metaclass_new = metaclass.get_attribute('__new__')
+      if metaclass_new.full_name == 'builtins.type.__new__':
+        continue
+      # The metaclass has overridden type.__new__. Invoke the custom __new__
+      # method to construct the class.
+      metaclass_var = metaclass.to_variable()
+      bases_var = abstract.Tuple(self._ctx, tuple(base_vars)).to_variable()
+      members_var = abstract.Dict(
+          self._ctx, {self._ctx.consts[k].to_variable(): v.to_variable()
+                      for k, v in members.items()}
+      ).to_variable()
+      args = abstract.Args(
+          posargs=(metaclass_var, name_var, bases_var, members_var),
+          frame=self)
+      metaclass_instance = metaclass_new.call(args).get_return_value()
+      break
+    if metaclass_instance and metaclass_instance.full_name == name:
+      cls = metaclass_instance
+    else:
+      cls = abstract.InterpreterClass(
+          ctx=self._ctx,
+          name=name,
+          members=members,
+          bases=bases,
+          keywords=keywords,
+          functions=frame.functions,
+          classes=frame.classes,
+      )
+    log.info('Created class: %r', cls)
     return cls
 
   def _call_function(
diff --git a/pytype/rewrite/tests/test_basic.py b/pytype/rewrite/tests/test_basic.py
index a1517828b..4558d6d8c 100644
--- a/pytype/rewrite/tests/test_basic.py
+++ b/pytype/rewrite/tests/test_basic.py
@@ -142,5 +142,29 @@ def test_aliases(self):
     """)
 
 
+@test_base.skip('Under construction')
+class EnumTest(RewriteTest):
+  """Enum tests."""
+
+  def test_member(self):
+    self.Check("""
+      import enum
+      class E(enum.Enum):
+        X = 42
+      assert_type(E.X, E)
+    """)
+
+  def test_member_pyi(self):
+    with self.DepTree([('foo.pyi', """
+      import enum
+      class E(enum.Enum):
+        X = 42
+    """)]):
+      self.Check("""
+        import foo
+        assert_type(foo.E.X, foo.E)
+      """)
+
+
 if __name__ == '__main__':
   test_base.main()
diff --git a/pytype/stubs/stdlib/enum.pytd b/pytype/stubs/stdlib/enum.pytd
index bcff8d0e4..358ce8c94 100644
--- a/pytype/stubs/stdlib/enum.pytd
+++ b/pytype/stubs/stdlib/enum.pytd
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, Iterator, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Iterable, Iterator, Self, Tuple, Type, TypeVar, Union
 
 _T = TypeVar('_T')
 _EnumType = TypeVar('_EnumType', bound=Type[Enum])
@@ -8,6 +8,9 @@ class EnumMeta(type, Iterable):
   def __getitem__(cls: EnumMeta, name: str) -> Any: ...
   def __contains__(self, member: Enum) -> bool: ...
   def __len__(self) -> int: ...
+  def __new__(
+      metacls: type[Self], cls: str, bases: tuple[type, ...], classdict: dict[str, Any], **kwds: Any
+  ) -> Self: ...
 
 class Enum(metaclass=EnumMeta):
   __members__: collections.OrderedDict[str, Enum]

From 518415cbc0d0d9243a6ce528699777f7a791a032 Mon Sep 17 00:00:00 2001
From: rechen 
Date: Fri, 19 Apr 2024 22:40:40 -0700
Subject: [PATCH 13/22] rewrite: improve slightly suboptimal isinstance checks.

Both SimpleFunction and BoundFunction are subclasses of BaseFunction.

PiperOrigin-RevId: 626560114
---
 pytype/rewrite/frame.py  | 4 +---
 pytype/rewrite/output.py | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index f8652918b..31af3b979 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -376,9 +376,7 @@ def _call_function(
   ) -> None:
     ret_values = []
     for func in func_var.values:
-      if isinstance(func, (abstract.SimpleFunction,
-                           abstract.InterpreterClass,
-                           abstract.BoundFunction)):
+      if isinstance(func, (abstract.BaseFunction, abstract.InterpreterClass)):
         ret = func.call(args)
         ret_values.append(ret.get_return_value())
       elif func is self._ctx.consts.singles['__build_class__']:
diff --git a/pytype/rewrite/output.py b/pytype/rewrite/output.py
index 3de2dae1a..1a6351a54 100644
--- a/pytype/rewrite/output.py
+++ b/pytype/rewrite/output.py
@@ -32,7 +32,7 @@ def to_pytd_def(self, val: abstract.BaseValue) -> pytd.Node:
     """
     if isinstance(val, abstract.SimpleClass):
       return self._class_to_pytd_def(val)
-    elif isinstance(val, (abstract.SimpleFunction, abstract.BoundFunction)):
+    elif isinstance(val, abstract.BaseFunction):
       return self._function_to_pytd_def(val)
     else:
       raise NotImplementedError(

From 0322324bb7a554c4021ab37890d3ddc667a3972e Mon Sep 17 00:00:00 2001
From: rechen 
Date: Sat, 20 Apr 2024 12:58:59 -0700
Subject: [PATCH 14/22] rewrite: move most of frame.py's function calling code
 into a separate file.

frame.py is already sort of long and will only get longer as we add more
opcodes. This pulls most of the function calling code, which is fairly
self-contained, into a separate file.

PiperOrigin-RevId: 626664567
---
 pytype/rewrite/CMakeLists.txt               |  24 +++
 pytype/rewrite/frame.py                     | 148 +----------------
 pytype/rewrite/frame_test.py                |   5 +-
 pytype/rewrite/function_call_helper.py      | 174 ++++++++++++++++++++
 pytype/rewrite/function_call_helper_test.py |  73 ++++++++
 5 files changed, 282 insertions(+), 142 deletions(-)
 create mode 100644 pytype/rewrite/function_call_helper.py
 create mode 100644 pytype/rewrite/function_call_helper_test.py

diff --git a/pytype/rewrite/CMakeLists.txt b/pytype/rewrite/CMakeLists.txt
index 97cf52d64..e2afcbd3a 100644
--- a/pytype/rewrite/CMakeLists.txt
+++ b/pytype/rewrite/CMakeLists.txt
@@ -71,6 +71,29 @@ py_test(
     pytype.rewrite.tests.test_utils
 )
 
+py_library(
+  NAME
+    function_call_helper
+  SRCS
+    function_call_helper.py
+  DEPS
+    .context
+    pytype.utils
+    pytype.rewrite.abstract.abstract
+    pytype.rewrite.flow.flow
+)
+
+py_test(
+  NAME
+    function_call_helper_test
+  SRCS
+    function_call_helper_test.py
+  DEPS
+    .frame
+    pytype.rewrite.abstract.abstract
+    pytype.rewrite.tests.test_utils
+)
+
 py_library(
   NAME
     frame
@@ -78,6 +101,7 @@ py_library(
     frame.py
   DEPS
     .context
+    .function_call_helper
     .stack
     pytype.utils
     pytype.blocks.blocks
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 31af3b979..92942fe38 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -1,6 +1,5 @@
 """A frame of an abstract VM for type analysis of python bytecode."""
 
-import itertools
 import logging
 from typing import Any, FrozenSet, List, Mapping, Optional, Sequence, Set, Type
 
@@ -8,6 +7,7 @@
 from pytype import datatypes
 from pytype.blocks import blocks
 from pytype.rewrite import context
+from pytype.rewrite import function_call_helper
 from pytype.rewrite import stack
 from pytype.rewrite.abstract import abstract
 from pytype.rewrite.flow import conditions
@@ -94,8 +94,8 @@ def __init__(
     self._classes: List[abstract.InterpreterClass] = []
     # All variables returned via RETURN_VALUE
     self._returns: List[_AbstractVariable] = []
-    # Function kwnames are stored in the vm by KW_NAMES and retrieved by CALL
-    self._kw_names = ()
+    # Handler for function calls.
+    self._call_helper = function_call_helper.FunctionCallHelper(ctx, self)
 
   def __repr__(self):
     return f'Frame({self.name})'
@@ -309,70 +309,10 @@ def _merge_nonlocals_into(self, frame: Optional['Frame']) -> None:
       var = self._final_locals[name]
       frame.store_global(name, var)
 
-  def _build_class(self, args: abstract.Args) -> abstract.InterpreterClass:
-    builder = args.posargs[0].get_atomic_value(_FrameFunction)
-    name_var = args.posargs[1]
-    name = abstract.get_atomic_constant(name_var, str)
-
-    base_vars = args.posargs[2:]
-    bases = []
-    for base_var in base_vars:
-      try:
-        base = base_var.get_atomic_value(abstract.SimpleClass)
-      except ValueError as e:
-        raise NotImplementedError('Unexpected base class') from e
-      bases.append(base)
-
-    keywords = {}
-    for kw, var in args.kwargs.items():
-      try:
-        val = var.get_atomic_value()
-      except ValueError as e:
-        raise NotImplementedError('Unexpected keyword value') from e
-      keywords[kw] = val
-
-    frame = builder.call(abstract.Args(frame=self))
-    members = dict(frame.final_locals)
-    metaclass_instance = None
-    for metaclass in itertools.chain([keywords.get('metaclass')],
-                                     (base.metaclass for base in bases)):
-      if not metaclass:
-        continue
-      metaclass_new = metaclass.get_attribute('__new__')
-      if metaclass_new.full_name == 'builtins.type.__new__':
-        continue
-      # The metaclass has overridden type.__new__. Invoke the custom __new__
-      # method to construct the class.
-      metaclass_var = metaclass.to_variable()
-      bases_var = abstract.Tuple(self._ctx, tuple(base_vars)).to_variable()
-      members_var = abstract.Dict(
-          self._ctx, {self._ctx.consts[k].to_variable(): v.to_variable()
-                      for k, v in members.items()}
-      ).to_variable()
-      args = abstract.Args(
-          posargs=(metaclass_var, name_var, bases_var, members_var),
-          frame=self)
-      metaclass_instance = metaclass_new.call(args).get_return_value()
-      break
-    if metaclass_instance and metaclass_instance.full_name == name:
-      cls = metaclass_instance
-    else:
-      cls = abstract.InterpreterClass(
-          ctx=self._ctx,
-          name=name,
-          members=members,
-          bases=bases,
-          keywords=keywords,
-          functions=frame.functions,
-          classes=frame.classes,
-      )
-    log.info('Created class: %r', cls)
-    return cls
-
   def _call_function(
       self,
       func_var: _AbstractVariable,
-      args: abstract.Args,
+      args: abstract.Args['Frame'],
   ) -> None:
     ret_values = []
     for func in func_var.values:
@@ -380,7 +320,8 @@ def _call_function(
         ret = func.call(args)
         ret_values.append(ret.get_return_value())
       elif func is self._ctx.consts.singles['__build_class__']:
-        cls = self._build_class(args)
+        cls = self._call_helper.build_class(args)
+        log.info('Created class: %r', cls)
         self._classes.append(cls)
         ret_values.append(cls)
       else:
@@ -619,27 +560,14 @@ def byte_IMPORT_FROM(self, opcode):
 
   def byte_KW_NAMES(self, opcode):
     # Stores a list of kw names to be retrieved by CALL
-    self._kw_names = opcode.argval
-
-  def _make_function_args(self, args):
-    """Unpack args into posargs and kwargs (3.11+)."""
-    if self._kw_names:
-      n_kw = len(self._kw_names)
-      posargs = tuple(args[:-n_kw])
-      kw_vals = args[-n_kw:]
-      kwargs = datatypes.immutabledict(zip(self._kw_names, kw_vals))
-    else:
-      posargs = tuple(args)
-      kwargs = datatypes.EMPTY_MAP
-    self._kw_names = ()
-    return abstract.Args(posargs=posargs, kwargs=kwargs, frame=self)
+    self._call_helper.set_kw_names(opcode.argval)
 
   def byte_CALL(self, opcode):
     sentinel, *rest = self._stack.popn(opcode.arg + 2)
     if not sentinel.has_atomic_value(self._ctx.consts.singles['NULL']):
       raise NotImplementedError('CALL not fully implemented')
     func, *args = rest
-    callargs = self._make_function_args(args)
+    callargs = self._call_helper.make_function_args(args)
     self._call_function(func, callargs)
 
   def byte_CALL_FUNCTION(self, opcode):
@@ -648,76 +576,18 @@ def byte_CALL_FUNCTION(self, opcode):
     callargs = abstract.Args(posargs=tuple(args), frame=self)
     self._call_function(func, callargs)
 
-  def _unpack_starargs(self, starargs) -> abstract.BaseValue:
-    # TODO(b/331853896): This follows vm_utils.ensure_unpacked_starargs, but
-    # does not yet handle indefinite iterables.
-    posargs = starargs.get_atomic_value()
-    if isinstance(posargs, abstract.FunctionArgTuple):
-      # This has already been converted
-      pass
-    elif isinstance(posargs, abstract.FrozenInstance):
-      # This is indefinite; leave it as-is
-      pass
-    elif isinstance(posargs, abstract.Tuple):
-      posargs = abstract.FunctionArgTuple(self._ctx, posargs.constant)
-    elif isinstance(posargs, tuple):
-      posargs = abstract.FunctionArgTuple(self._ctx, posargs)
-    elif abstract.is_any(posargs):
-      return self._ctx.types[tuple].instantiate()
-    else:
-      assert False, f'unexpected posargs type: {posargs}: {type(posargs)}'
-    return posargs
-
-  def _unpack_starstarargs(self, starstarargs) -> abstract.FunctionArgDict:
-    kwargs = starstarargs.get_atomic_value()
-    if isinstance(kwargs, abstract.FunctionArgDict):
-      # This has already been converted
-      pass
-    elif isinstance(kwargs, abstract.Dict):
-      kwargs = kwargs.to_function_arg_dict()
-    elif abstract.is_any(kwargs):
-      kwargs = abstract.FunctionArgDict.any_kwargs(self._ctx)
-    else:
-      assert False, f'unexpected kwargs type: {kwargs}: {type(kwargs)}'
-    return kwargs
-
   def byte_CALL_FUNCTION_EX(self, opcode):
-    # Convert **kwargs
     if opcode.arg & _Flags.CALL_FUNCTION_EX_HAS_KWARGS:
       starstarargs = self._stack.pop()
-      unpacked_starstarargs = self._unpack_starstarargs(starstarargs)
-      # If we have a concrete dict we are unpacking; move it into kwargs (if
-      # not, .constant will be {} anyway, so we don't need to check here.)
-      kwargs = unpacked_starstarargs.constant
-      if unpacked_starstarargs.indefinite:
-        # We also have **args, apart from the concrete kv pairs we moved into
-        # kwargs, that need to be preserved.
-        starstarargs = (
-            abstract.FunctionArgDict.any_kwargs(self._ctx).to_variable())
-      else:
-        starstarargs = None
     else:
-      kwargs = datatypes.EMPTY_MAP
       starstarargs = None
-    # Convert *args
     starargs = self._stack.pop()
-    unpacked_starargs = self._unpack_starargs(starargs)
-    if isinstance(
-        unpacked_starargs, (abstract.Tuple, abstract.FunctionArgTuple)):
-      # We have a concrete tuple we are unpacking; move it into posargs
-      posargs = unpacked_starargs.constant
-      starargs = None
-    else:
-      # We have an indefinite tuple; leave it in starargs
-      posargs = ()
+    callargs = self._call_helper.make_function_args_ex(starargs, starstarargs)
     # Retrieve and call the function
     func = self._stack.pop()
     if self._code.python_version >= (3, 11):
       # the compiler puts a NULL on the stack before function calls
       self._stack.pop_and_discard()
-    callargs = abstract.Args(
-        posargs=posargs, kwargs=kwargs, starargs=starargs,
-        starstarargs=starstarargs, frame=self)
     self._call_function(func, callargs)
 
   def byte_CALL_METHOD(self, opcode):
diff --git a/pytype/rewrite/frame_test.py b/pytype/rewrite/frame_test.py
index 0acdf86fa..e206aab26 100644
--- a/pytype/rewrite/frame_test.py
+++ b/pytype/rewrite/frame_test.py
@@ -30,7 +30,6 @@ def _make_frame(self, src: str, name: str = '__main__') -> frame_lib.Frame:
           name: value.to_variable() for name, value in module_globals.items()}
     else:
       initial_locals = initial_globals = {}
-    self._kw_names = ()
     return frame_lib.Frame(self.ctx, name, code, initial_locals=initial_locals,
                            initial_globals=initial_globals)
 
@@ -655,10 +654,10 @@ def f(x, *, y):
         pass
       f(1, y=2)
     """)
-    self.assertEqual(frame._kw_names, ('y',))
+    self.assertEqual(frame._call_helper._kw_names, ('y',))
     oparg = frame.current_opcode.arg  # pytype: disable=attribute-error
     _, _, *args = frame._stack.popn(oparg + 2)
-    callargs = frame._make_function_args(args)
+    callargs = frame._call_helper.make_function_args(args)
     self.assertConstantVar(callargs.posargs[0], 1)
     self.assertConstantVar(callargs.kwargs['y'], 2)
 
diff --git a/pytype/rewrite/function_call_helper.py b/pytype/rewrite/function_call_helper.py
new file mode 100644
index 000000000..dd34ec3ea
--- /dev/null
+++ b/pytype/rewrite/function_call_helper.py
@@ -0,0 +1,174 @@
+"""Function call helper used by VM frames."""
+
+import itertools
+from typing import Generic, Optional, Sequence, TypeVar
+
+from pytype import datatypes
+from pytype.rewrite import context
+from pytype.rewrite.abstract import abstract
+from pytype.rewrite.flow import variables
+
+_AbstractVariable = variables.Variable[abstract.BaseValue]
+_FrameT = TypeVar('_FrameT')
+
+
+class FunctionCallHelper(Generic[_FrameT]):
+  """Helper for executing function calls."""
+
+  def __init__(self, ctx: context.Context, frame: _FrameT):
+    self._ctx = ctx
+    self._frame = frame
+    # Function kwnames are stored in the vm by KW_NAMES and retrieved by CALL
+    self._kw_names: Sequence[str] = ()
+
+  def set_kw_names(self, kw_names: Sequence[str]) -> None:
+    self._kw_names = kw_names
+
+  def make_function_args(
+      self, args: Sequence[_AbstractVariable],
+  ) -> abstract.Args[_FrameT]:
+    """Unpack args into posargs and kwargs (3.11+)."""
+    if self._kw_names:
+      n_kw = len(self._kw_names)
+      posargs = tuple(args[:-n_kw])
+      kw_vals = args[-n_kw:]
+      kwargs = datatypes.immutabledict(zip(self._kw_names, kw_vals))
+    else:
+      posargs = tuple(args)
+      kwargs = datatypes.EMPTY_MAP
+    self._kw_names = ()
+    return abstract.Args(posargs=posargs, kwargs=kwargs, frame=self._frame)
+
+  def _unpack_starargs(self, starargs: _AbstractVariable) -> abstract.BaseValue:
+    """Unpacks variable positional arguments."""
+    # TODO(b/331853896): This follows vm_utils.ensure_unpacked_starargs, but
+    # does not yet handle indefinite iterables.
+    posargs = starargs.get_atomic_value()
+    if isinstance(posargs, abstract.FunctionArgTuple):
+      # This has already been converted
+      pass
+    elif isinstance(posargs, abstract.FrozenInstance):
+      # This is indefinite; leave it as-is
+      pass
+    elif isinstance(posargs, abstract.Tuple):
+      posargs = abstract.FunctionArgTuple(self._ctx, posargs.constant)
+    elif isinstance(posargs, tuple):
+      posargs = abstract.FunctionArgTuple(self._ctx, posargs)
+    elif abstract.is_any(posargs):
+      return self._ctx.types[tuple].instantiate()
+    else:
+      assert False, f'unexpected posargs type: {posargs}: {type(posargs)}'
+    return posargs
+
+  def _unpack_starstarargs(
+      self, starstarargs: _AbstractVariable) -> abstract.FunctionArgDict:
+    """Unpacks variable keyword arguments."""
+    kwargs = starstarargs.get_atomic_value()
+    if isinstance(kwargs, abstract.FunctionArgDict):
+      # This has already been converted
+      pass
+    elif isinstance(kwargs, abstract.Dict):
+      kwargs = kwargs.to_function_arg_dict()
+    elif abstract.is_any(kwargs):
+      kwargs = abstract.FunctionArgDict.any_kwargs(self._ctx)
+    else:
+      assert False, f'unexpected kwargs type: {kwargs}: {type(kwargs)}'
+    return kwargs
+
+  def make_function_args_ex(
+      self,
+      starargs: _AbstractVariable,
+      starstarargs: Optional[_AbstractVariable],
+  ) -> abstract.Args[_FrameT]:
+    """Makes function args from variable positional and keyword arguments."""
+    # Convert *args
+    unpacked_starargs = self._unpack_starargs(starargs)
+    if isinstance(
+        unpacked_starargs, (abstract.Tuple, abstract.FunctionArgTuple)):
+      # We have a concrete tuple we are unpacking; move it into posargs
+      posargs = unpacked_starargs.constant
+      starargs = None
+    else:
+      # We have an indefinite tuple; leave it in starargs
+      posargs = ()
+    # Convert **kwargs
+    if starstarargs:
+      unpacked_starstarargs = self._unpack_starstarargs(starstarargs)
+      # If we have a concrete dict we are unpacking; move it into kwargs (if
+      # not, .constant will be {} anyway, so we don't need to check here.)
+      kwargs = unpacked_starstarargs.constant
+      if unpacked_starstarargs.indefinite:
+        # We also have **kwargs, apart from the concrete kv pairs we moved into
+        # kwargs, that need to be preserved.
+        starstarargs = (
+            abstract.FunctionArgDict.any_kwargs(self._ctx).to_variable())
+      else:
+        starstarargs = None
+    else:
+      kwargs = datatypes.EMPTY_MAP
+    return abstract.Args(
+        posargs=posargs, kwargs=kwargs, starargs=starargs,
+        starstarargs=starstarargs, frame=self._frame)
+
+  def build_class(
+      self, args: abstract.Args[_FrameT]) -> abstract.InterpreterClass:
+    """Builds a class."""
+    builder = args.posargs[0].get_atomic_value(
+        abstract.InterpreterFunction[_FrameT])
+    name_var = args.posargs[1]
+    name = abstract.get_atomic_constant(name_var, str)
+
+    base_vars = args.posargs[2:]
+    bases = []
+    for base_var in base_vars:
+      try:
+        base = base_var.get_atomic_value(abstract.SimpleClass)
+      except ValueError as e:
+        raise NotImplementedError('Unexpected base class') from e
+      bases.append(base)
+
+    keywords = {}
+    for kw, var in args.kwargs.items():
+      try:
+        val = var.get_atomic_value()
+      except ValueError as e:
+        raise NotImplementedError('Unexpected keyword value') from e
+      keywords[kw] = val
+
+    frame = builder.call(abstract.Args(frame=self._frame))
+    members = dict(frame.final_locals)
+    metaclass_instance = None
+    for metaclass in itertools.chain([keywords.get('metaclass')],
+                                     (base.metaclass for base in bases)):
+      if not metaclass:
+        continue
+      metaclass_new = metaclass.get_attribute('__new__')
+      if (not isinstance(metaclass_new, abstract.BaseFunction) or
+          metaclass_new.full_name == 'builtins.type.__new__'):
+        continue
+      # The metaclass has overridden type.__new__. Invoke the custom __new__
+      # method to construct the class.
+      metaclass_var = metaclass.to_variable()
+      bases_var = abstract.Tuple(self._ctx, tuple(base_vars)).to_variable()
+      members_var = abstract.Dict(
+          self._ctx, {self._ctx.consts[k].to_variable(): v.to_variable()
+                      for k, v in members.items()}
+      ).to_variable()
+      args = abstract.Args(
+          posargs=(metaclass_var, name_var, bases_var, members_var),
+          frame=self._frame)
+      metaclass_instance = metaclass_new.call(args).get_return_value()
+      break
+    if metaclass_instance and metaclass_instance.full_name == name:
+      cls = metaclass_instance
+    else:
+      cls = abstract.InterpreterClass(
+          ctx=self._ctx,
+          name=name,
+          members=members,
+          bases=bases,
+          keywords=keywords,
+          functions=frame.functions,
+          classes=frame.classes,
+      )
+    return cls
diff --git a/pytype/rewrite/function_call_helper_test.py b/pytype/rewrite/function_call_helper_test.py
new file mode 100644
index 000000000..1b970e26f
--- /dev/null
+++ b/pytype/rewrite/function_call_helper_test.py
@@ -0,0 +1,73 @@
+from pytype.rewrite import frame as frame_lib
+from pytype.rewrite.abstract import abstract
+from pytype.rewrite.tests import test_utils
+
+import unittest
+
+
+class TestBase(test_utils.ContextfulTestBase):
+  def setUp(self):
+    super().setUp()
+    frame = frame_lib.Frame(self.ctx, '__main__', test_utils.parse(''),
+                            initial_locals={}, initial_globals={})
+    self.helper = frame._call_helper
+
+
+class MakeFunctionArgsTest(TestBase):
+
+  def test_make_args_positional(self):
+    raw_args = [self.ctx.consts[0].to_variable(),
+                self.ctx.consts[1].to_variable()]
+    args = self.helper.make_function_args(raw_args)
+    self.assertEqual(
+        args, abstract.Args(posargs=tuple(raw_args), frame=self.helper._frame))
+
+  def test_make_args_positional_and_keyword(self):
+    raw_args = [self.ctx.consts[0].to_variable(),
+                self.ctx.consts[1].to_variable()]
+    self.helper.set_kw_names(('x',))
+    args = self.helper.make_function_args(raw_args)
+    expected_args = abstract.Args(posargs=(raw_args[0],),
+                                  kwargs={'x': raw_args[1]},
+                                  frame=self.helper._frame)
+    self.assertEqual(args, expected_args)
+
+  def test_make_args_varargs(self):
+    varargs = abstract.Tuple(self.ctx, (self.ctx.consts[0].to_variable(),))
+    args = self.helper.make_function_args_ex(varargs.to_variable(), None)
+    expected_args = abstract.Args(posargs=(self.ctx.consts[0].to_variable(),),
+                                  starstarargs=None,
+                                  frame=self.helper._frame)
+    self.assertEqual(args, expected_args)
+
+  def test_make_args_kwargs(self):
+    varargs = abstract.Tuple(self.ctx, ())
+    kwargs = abstract.Dict(self.ctx, {self.ctx.consts['k'].to_variable():
+                                      self.ctx.consts['v'].to_variable()})
+    args = self.helper.make_function_args_ex(varargs.to_variable(),
+                                             kwargs.to_variable())
+    expected_args = abstract.Args(
+        posargs=(),
+        kwargs={'k': self.ctx.consts['v'].to_variable()},
+        starargs=None,
+        frame=self.helper._frame)
+    self.assertEqual(args, expected_args)
+
+
+class BuildClassTest(TestBase):
+
+  def test_build(self):
+    code = test_utils.parse('def C(): pass').consts[0]
+    builder = abstract.InterpreterFunction(
+        ctx=self.ctx, name='C', code=code, enclosing_scope=(),
+        parent_frame=self.helper._frame)
+    args = abstract.Args(
+        posargs=(builder.to_variable(), self.ctx.consts['C'].to_variable()),
+        frame=self.helper._frame)
+    self.helper._frame.step()  # initialize frame state
+    cls = self.helper.build_class(args)
+    self.assertEqual(cls.name, 'C')
+
+
+if __name__ == '__main__':
+  unittest.main()

From 94bd0430cbe0bbadc0e33e588468f2dc8482652a Mon Sep 17 00:00:00 2001
From: mdemello 
Date: Mon, 22 Apr 2024 13:12:06 -0700
Subject: [PATCH 15/22] rewrite: Minimal implementation of type subscripting.

Replaces `A[X]` with `A`, mostly so code with subscripted type annotations can
work without crashing.

PiperOrigin-RevId: 627133846
---
 pytype/rewrite/abstract/classes.py |  6 ++++++
 pytype/rewrite/frame.py            | 10 ++++++++++
 pytype/rewrite/tests/test_basic.py |  9 +++++++++
 3 files changed, 25 insertions(+)

diff --git a/pytype/rewrite/abstract/classes.py b/pytype/rewrite/abstract/classes.py
index 5c2ea348d..a55b0c13f 100644
--- a/pytype/rewrite/abstract/classes.py
+++ b/pytype/rewrite/abstract/classes.py
@@ -145,6 +145,12 @@ def mro(self) -> Sequence['SimpleClass']:
     self._mro = mro = mro_lib.MROMerge(mro_bases)
     return mro
 
+  def set_type_parameters(self, params):
+    # A dummy implementation to let type annotations with parameters not crash.
+    del params  # not implemented yet
+    # We eventually want to return a new class with the type parameters set
+    return self
+
 
 class InterpreterClass(SimpleClass):
   """Class defined in the current module."""
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 92942fe38..95f9f32de 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -610,6 +610,16 @@ def byte_COPY_FREE_VARS(self, opcode):
   def byte_LOAD_BUILD_CLASS(self, opcode):
     self._stack.push(self._ctx.consts.singles['__build_class__'].to_variable())
 
+  def byte_BINARY_SUBSCR(self, opcode):
+    obj_var, subscr_var = self._stack.popn(2)
+    try:
+      obj = obj_var.get_atomic_value(abstract.SimpleClass)
+    except ValueError as e:
+      msg = 'BINARY_SUBSCR only implemented for type annotations.'
+      raise NotImplementedError(msg) from e
+    ret = obj.set_type_parameters(subscr_var)
+    self._stack.push(ret.to_variable())
+
   # ---------------------------------------------------------------
   # Build and extend collections
 
diff --git a/pytype/rewrite/tests/test_basic.py b/pytype/rewrite/tests/test_basic.py
index 4558d6d8c..efe513c02 100644
--- a/pytype/rewrite/tests/test_basic.py
+++ b/pytype/rewrite/tests/test_basic.py
@@ -141,6 +141,15 @@ def test_aliases(self):
       assert_type(path2, "module")
     """)
 
+  def test_type_subscript(self):
+    self.Check("""
+      IntList = list[int]
+      def f(xs: IntList) -> list[str]:
+        return ["hello world"]
+      a = f([1, 2, 3])
+      assert_type(a, list)
+    """)
+
 
 @test_base.skip('Under construction')
 class EnumTest(RewriteTest):

From 3328c34ab980bfe04a9c9b041612306d2b894f12 Mon Sep 17 00:00:00 2001
From: mdemello 
Date: Mon, 22 Apr 2024 16:49:33 -0700
Subject: [PATCH 16/22] rewrite: handle fstrings

PiperOrigin-RevId: 627195965
---
 pytype/rewrite/frame.py            | 15 +++++++++++++++
 pytype/rewrite/tests/test_basic.py | 12 ++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 95f9f32de..8d6fe208c 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -756,6 +756,21 @@ def byte_LIST_TO_TUPLE(self, opcode):
     ret = abstract.Tuple(self._ctx, tuple(target)).to_variable()
     self._stack.push(ret)
 
+  def byte_FORMAT_VALUE(self, opcode):
+    if opcode.arg & pyc_marshal.Flags.FVS_MASK:
+      self._stack.pop_and_discard()
+    # FORMAT_VALUE pops, formats and pushes back a string, so we just need to
+    # push a new string onto the stack.
+    self._stack.pop_and_discard()
+    ret = self._ctx.types[str].instantiate().to_variable()
+    self._stack.push(ret)
+
+  def byte_BUILD_STRING(self, opcode):
+    # Pop n arguments off the stack and build a string out of them
+    self._stack.popn(opcode.arg)
+    ret = self._ctx.types[str].instantiate().to_variable()
+    self._stack.push(ret)
+
   # ---------------------------------------------------------------
   # Branches and jumps
 
diff --git a/pytype/rewrite/tests/test_basic.py b/pytype/rewrite/tests/test_basic.py
index efe513c02..cf8d88435 100644
--- a/pytype/rewrite/tests/test_basic.py
+++ b/pytype/rewrite/tests/test_basic.py
@@ -150,6 +150,18 @@ def f(xs: IntList) -> list[str]:
       assert_type(a, list)
     """)
 
+  def test_fstrings(self):
+    self.Check("""
+      x = 1
+      y = 2
+      z = (
+        f'x = {x}'
+        ' and '
+        f'y = {y}'
+      )
+      assert_type(z, str)
+    """)
+
 
 @test_base.skip('Under construction')
 class EnumTest(RewriteTest):

From 98fa177b26ec70237db29e1732e4fbe2fbd32666 Mon Sep 17 00:00:00 2001
From: rechen 
Date: Tue, 23 Apr 2024 16:36:11 -0700
Subject: [PATCH 17/22] Add ability to register overlay functions with the
 abstract converter.

I started out by trying to implement overlays following our current design,
with an overlay module containing overlay members, but I realized that this
quickly becomes a pain if you want to replace a method of a class (e.g., to get
at EnumMeta.__new__, I need to create overlay objects for the enum module, the
EnumMeta class, and finally the EnumMeta.__new__ method).

Here's a potential alternative design that involves using decorators to
"register" individual objects as overlays and having convert.py do overlay
registry lookups. Let me know what you think.

PiperOrigin-RevId: 627541189
---
 pytype/rewrite/CMakeLists.txt                 |  2 +-
 pytype/rewrite/convert.py                     |  5 ++-
 pytype/rewrite/load_abstract.py               | 27 +++++++-------
 pytype/rewrite/overlays/CMakeLists.txt        | 25 +++++++++++--
 pytype/rewrite/overlays/enum_overlay.py       | 10 ++++++
 pytype/rewrite/overlays/overlays.py           | 33 +++++++++++++++++
 pytype/rewrite/overlays/special_builtins.py   | 25 +++++--------
 .../rewrite/overlays/special_builtins_test.py | 35 +++++++++++--------
 pytype/stubs/builtins/builtins.pytd           |  4 +--
 9 files changed, 114 insertions(+), 52 deletions(-)
 create mode 100644 pytype/rewrite/overlays/enum_overlay.py
 create mode 100644 pytype/rewrite/overlays/overlays.py

diff --git a/pytype/rewrite/CMakeLists.txt b/pytype/rewrite/CMakeLists.txt
index e2afcbd3a..1201fd4f0 100644
--- a/pytype/rewrite/CMakeLists.txt
+++ b/pytype/rewrite/CMakeLists.txt
@@ -58,6 +58,7 @@ py_library(
   DEPS
     pytype.pytd.pytd
     pytype.rewrite.abstract.abstract
+    pytype.rewrite.overlays.overlays
 )
 
 py_test(
@@ -131,7 +132,6 @@ py_library(
     pytype.load_pytd
     pytype.pytd.pytd
     pytype.rewrite.abstract.abstract
-    pytype.rewrite.overlays.overlays
 )
 
 py_test(
diff --git a/pytype/rewrite/convert.py b/pytype/rewrite/convert.py
index 086fd0a2d..7b33d9085 100644
--- a/pytype/rewrite/convert.py
+++ b/pytype/rewrite/convert.py
@@ -4,6 +4,7 @@
 
 from pytype.pytd import pytd
 from pytype.rewrite.abstract import abstract
+from pytype.rewrite.overlays import overlays
 
 
 class _Cache:
@@ -20,6 +21,7 @@ class AbstractConverter:
   def __init__(self, ctx: abstract.ContextType):
     self._ctx = ctx
     self._cache = _Cache()
+    overlays.initialize()
 
   def pytd_class_to_value(self, cls: pytd.Class) -> abstract.SimpleClass:
     """Converts a pytd class to an abstract class."""
@@ -77,7 +79,8 @@ def pytd_function_to_value(
     signatures = tuple(
         abstract.Signature.from_pytd(self._ctx, name, pytd_sig)
         for pytd_sig in func.signatures)
-    abstract_func = abstract.PytdFunction(
+    builder = overlays.FUNCTIONS.get((module, name), abstract.PytdFunction)
+    abstract_func = builder(
         ctx=self._ctx,
         name=name,
         signatures=signatures,
diff --git a/pytype/rewrite/load_abstract.py b/pytype/rewrite/load_abstract.py
index f1ad993d2..5371c5011 100644
--- a/pytype/rewrite/load_abstract.py
+++ b/pytype/rewrite/load_abstract.py
@@ -1,11 +1,10 @@
 """Loads abstract representations of imported objects."""
 
-from typing import Any as _Any, Dict, Tuple, Type
+from typing import Any, Dict, Tuple, Type
 
 from pytype import load_pytd
 from pytype.pytd import pytd
 from pytype.rewrite.abstract import abstract
-from pytype.rewrite.overlays import special_builtins
 
 
 class Constants:
@@ -23,7 +22,7 @@ class Constants:
 
   def __init__(self, ctx: abstract.ContextType):
     self._ctx = ctx
-    self._consts: Dict[_Any, abstract.PythonConstant] = {}
+    self._consts: Dict[Any, abstract.PythonConstant] = {}
     self.singles: Dict[str, abstract.Singleton] = {}
 
     for single in self._SINGLETONS:
@@ -32,13 +31,18 @@ def __init__(self, ctx: abstract.ContextType):
     # We use Any all the time, so alias it for convenience.
     self.Any = self.singles['Any']  # pylint: disable=invalid-name
 
-  def __getitem__(self, const: _Any):
+  def __getitem__(self, const: Any):
     if const not in self._consts:
       self._consts[const] = abstract.PythonConstant(
           self._ctx, const, allow_direct_instantiation=True)
     return self._consts[const]
 
 
+# This is a workaround for a weird pytype crash caused by the use of 'Any' as an
+# attribute name.
+Constants: Any
+
+
 class Types:
   """Wrapper for AbstractLoader.load_raw_types.
 
@@ -49,7 +53,7 @@ class Types:
   def __init__(self, ctx: abstract.ContextType):
     self._ctx = ctx
 
-  def __getitem__(self, raw_type: Type[_Any]) -> abstract.BaseValue:
+  def __getitem__(self, raw_type: Type[Any]) -> abstract.BaseValue:
     return self._ctx.abstract_loader.load_raw_type(raw_type)
 
 
@@ -62,11 +66,6 @@ def __init__(self, ctx: abstract.ContextType, pytd_loader: load_pytd.Loader):
 
     self.consts = Constants(ctx)
     self.types = Types(ctx)
-    self._special_builtins = {
-        'assert_type': special_builtins.AssertType(self._ctx),
-        'reveal_type': special_builtins.RevealType(self._ctx),
-    }
-    self._special_builtins['NoneType'] = self.consts[None]
 
   def _load_pytd_node(self, pytd_node: pytd.Node) -> abstract.BaseValue:
     if isinstance(pytd_node, pytd.Class):
@@ -82,8 +81,8 @@ def _load_pytd_node(self, pytd_node: pytd.Node) -> abstract.BaseValue:
       raise NotImplementedError(f'I do not know how to load {pytd_node}')
 
   def load_builtin(self, name: str) -> abstract.BaseValue:
-    if name in self._special_builtins:
-      return self._special_builtins[name]
+    if name == 'NoneType':
+      return self.consts[None]
     pytd_node = self._pytd_loader.lookup_pytd('builtins', name)
     if isinstance(pytd_node, pytd.Constant):
       # This usage of eval is safe, as we've already checked that this is the
@@ -108,7 +107,7 @@ def get_module_globals(self) -> Dict[str, abstract.BaseValue]:
         '__package__': self.consts[None],
     }
 
-  def load_raw_type(self, typ: Type[_Any]) -> abstract.BaseValue:
+  def load_raw_type(self, typ: Type[Any]) -> abstract.BaseValue:
     """Converts a raw type to an abstract value.
 
     For convenience, this method can also be called via ctx.types[typ].
@@ -125,7 +124,7 @@ def load_raw_type(self, typ: Type[_Any]) -> abstract.BaseValue:
     pytd_node = self._pytd_loader.lookup_pytd(typ.__module__, typ.__name__)
     return self._load_pytd_node(pytd_node)
 
-  def build_tuple(self, const: Tuple[_Any, ...]) -> abstract.Tuple:
+  def build_tuple(self, const: Tuple[Any, ...]) -> abstract.Tuple:
     """Convert a raw constant tuple to an abstract value."""
     ret = []
     for e in const:
diff --git a/pytype/rewrite/overlays/CMakeLists.txt b/pytype/rewrite/overlays/CMakeLists.txt
index 40b4b746d..baa860e3e 100644
--- a/pytype/rewrite/overlays/CMakeLists.txt
+++ b/pytype/rewrite/overlays/CMakeLists.txt
@@ -4,15 +4,37 @@ py_library(
   NAME
     overlays
   DEPS
+    ._overlays
+    .enum_overlay
     .special_builtins
 )
 
+py_library(
+  NAME
+    _overlays
+  SRCS
+    overlays.py
+  DEPS
+    pytype.rewrite.abstract.abstract
+)
+
+py_library(
+  NAME
+    enum_overlay
+  SRCS
+    enum_overlay.py
+  DEPS
+    ._overlays
+    pytype.rewrite.abstract.abstract
+)
+
 py_library(
   NAME
     special_builtins
   SRCS
     special_builtins.py
   DEPS
+    ._overlays
     pytype.rewrite.abstract.abstract
 )
 
@@ -22,7 +44,6 @@ py_test(
   SRCS
     special_builtins_test.py
   DEPS
-    .special_builtins
-    pytype.rewrite.context
     pytype.rewrite.abstract.abstract
+    pytype.rewrite.tests.test_utils
 )
diff --git a/pytype/rewrite/overlays/enum_overlay.py b/pytype/rewrite/overlays/enum_overlay.py
new file mode 100644
index 000000000..e1cb2f9cd
--- /dev/null
+++ b/pytype/rewrite/overlays/enum_overlay.py
@@ -0,0 +1,10 @@
+"""Enum overlay."""
+from pytype.rewrite.abstract import abstract
+from pytype.rewrite.overlays import overlays
+
+
+@overlays.register_function('enum', 'EnumMeta.__new__')
+class EnumMetaNew(abstract.PytdFunction):
+
+  def call_with_mapped_args(self, *args, **kwargs):
+    raise NotImplementedError()
diff --git a/pytype/rewrite/overlays/overlays.py b/pytype/rewrite/overlays/overlays.py
new file mode 100644
index 000000000..c15ba7baf
--- /dev/null
+++ b/pytype/rewrite/overlays/overlays.py
@@ -0,0 +1,33 @@
+"""Overlays on top of abstract values that provide extra typing information.
+
+An overlay generates extra typing information that cannot be expressed in a pyi
+file. For example, collections.namedtuple is a factory method that generates
+class definitions at runtime. An overlay is used to generate these classes.
+"""
+from typing import Callable, Dict, Tuple, Type, TypeVar
+
+from pytype.rewrite.abstract import abstract
+
+_FuncTypeType = Type[abstract.PytdFunction]
+_FuncTypeTypeT = TypeVar('_FuncTypeTypeT', bound=_FuncTypeType)
+
+FUNCTIONS: Dict[Tuple[str, str], _FuncTypeType] = {}
+
+
+def register_function(
+    module: str, name: str) -> Callable[[_FuncTypeTypeT], _FuncTypeTypeT]:
+  def register(func_builder: _FuncTypeTypeT) -> _FuncTypeTypeT:
+    FUNCTIONS[(module, name)] = func_builder
+    return func_builder
+  return register
+
+
+def initialize():
+  # Imports overlay implementations so that ther @register_* decorators execute
+  # and populate the overlay registry.
+  # pylint: disable=g-import-not-at-top,unused-import
+  # pytype: disable=import-error
+  from pytype.rewrite.overlays import enum_overlay
+  from pytype.rewrite.overlays import special_builtins
+  # pytype: enable=import-error
+  # pylint: enable=g-import-not-at-top,unused-import
diff --git a/pytype/rewrite/overlays/special_builtins.py b/pytype/rewrite/overlays/special_builtins.py
index b8b942bcc..0dae10faa 100644
--- a/pytype/rewrite/overlays/special_builtins.py
+++ b/pytype/rewrite/overlays/special_builtins.py
@@ -3,6 +3,7 @@
 from typing import Optional, Sequence
 
 from pytype.rewrite.abstract import abstract
+from pytype.rewrite.overlays import overlays
 
 
 def _stack(
@@ -11,20 +12,15 @@ def _stack(
   return frame.stack if frame else None
 
 
-class AssertType(abstract.SimpleFunction[abstract.SimpleReturn]):
+@overlays.register_function('builtins', 'assert_type')
+class AssertType(abstract.PytdFunction):
   """assert_type implementation."""
 
-  def __init__(self, ctx: abstract.ContextType):
-    signature = abstract.Signature(
-        ctx=ctx, name='assert_type', param_names=('variable', 'type'))
-    super().__init__(
-        ctx=ctx, name='assert_type', signatures=(signature,), module='builtins')
-
   def call_with_mapped_args(
       self, mapped_args: abstract.MappedArgs[abstract.FrameType],
   ) -> abstract.SimpleReturn:
-    var = mapped_args.argdict['variable']
-    typ = mapped_args.argdict['type']
+    var = mapped_args.argdict['val']
+    typ = mapped_args.argdict['typ']
     pp = self._ctx.errorlog.pretty_printer
     actual = pp.print_var_type(var, node=None)
     try:
@@ -37,19 +33,14 @@ def call_with_mapped_args(
     return abstract.SimpleReturn(self._ctx.consts[None])
 
 
-class RevealType(abstract.SimpleFunction[abstract.SimpleReturn]):
+@overlays.register_function('builtins', 'reveal_type')
+class RevealType(abstract.PytdFunction):
   """reveal_type implementation."""
 
-  def __init__(self, ctx: abstract.ContextType):
-    signature = abstract.Signature(
-        ctx=ctx, name='reveal_type', param_names=('object',))
-    super().__init__(
-        ctx=ctx, name='reveal_type', signatures=(signature,), module='builtins')
-
   def call_with_mapped_args(
       self, mapped_args: abstract.MappedArgs[abstract.FrameType],
   ) -> abstract.SimpleReturn:
-    obj = mapped_args.argdict['object']
+    obj = mapped_args.argdict['obj']
     stack = _stack(mapped_args.frame)
     self._ctx.errorlog.reveal_type(stack, node=None, var=obj)
     return abstract.SimpleReturn(self._ctx.consts[None])
diff --git a/pytype/rewrite/overlays/special_builtins_test.py b/pytype/rewrite/overlays/special_builtins_test.py
index a0d2571e0..413ce4336 100644
--- a/pytype/rewrite/overlays/special_builtins_test.py
+++ b/pytype/rewrite/overlays/special_builtins_test.py
@@ -1,31 +1,36 @@
-from pytype.rewrite import context
 from pytype.rewrite.abstract import abstract
-from pytype.rewrite.overlays import special_builtins
+from pytype.rewrite.tests import test_utils
 
 import unittest
 
 
-class AssertTypeTest(unittest.TestCase):
+class SpecialBuiltinsTest(test_utils.ContextfulTestBase):
+
+  def load_builtin_function(self, name: str) -> abstract.PytdFunction:
+    func = self.ctx.abstract_loader.load_builtin(name)
+    assert isinstance(func, abstract.PytdFunction)
+    return func
+
+
+class AssertTypeTest(SpecialBuiltinsTest):
 
   def test_types_match(self):
-    ctx = context.Context()
-    assert_type_func = special_builtins.AssertType(ctx)
-    var = ctx.consts[0].to_variable()
-    typ = abstract.SimpleClass(ctx, 'int', {}).to_variable()
+    assert_type_func = self.load_builtin_function('assert_type')
+    var = self.ctx.consts[0].to_variable()
+    typ = abstract.SimpleClass(self.ctx, 'int', {}).to_variable()
     ret = assert_type_func.call(abstract.Args(posargs=(var, typ)))
-    self.assertEqual(ret.get_return_value(), ctx.consts[None])
-    self.assertEqual(len(ctx.errorlog), 0)  # pylint: disable=g-generic-assert
+    self.assertEqual(ret.get_return_value(), self.ctx.consts[None])
+    self.assertEqual(len(self.ctx.errorlog), 0)  # pylint: disable=g-generic-assert
 
 
-class RevealTypeTest(unittest.TestCase):
+class RevealTypeTest(SpecialBuiltinsTest):
 
   def test_basic(self):
-    ctx = context.Context()
-    reveal_type_func = special_builtins.RevealType(ctx)
-    var = ctx.consts[0].to_variable()
+    reveal_type_func = self.load_builtin_function('reveal_type')
+    var = self.ctx.consts[0].to_variable()
     ret = reveal_type_func.call(abstract.Args(posargs=(var,)))
-    self.assertEqual(ret.get_return_value(), ctx.consts[None])
-    self.assertEqual(len(ctx.errorlog), 1)
+    self.assertEqual(ret.get_return_value(), self.ctx.consts[None])
+    self.assertEqual(len(self.ctx.errorlog), 1)
 
 
 if __name__ == '__main__':
diff --git a/pytype/stubs/builtins/builtins.pytd b/pytype/stubs/builtins/builtins.pytd
index 42ff7dac3..283ee2dda 100644
--- a/pytype/stubs/builtins/builtins.pytd
+++ b/pytype/stubs/builtins/builtins.pytd
@@ -45,7 +45,7 @@ def all(iterable) -> bool: ...
 def any(iterable) -> bool: ...
 def ascii(__obj: object) -> str: ...
 def apply(object: Callable, *args, **kwargs) -> NoneType: ...
-def assert_type(*args): ...
+def assert_type(val, typ, /): ...
 def bin(number: Union[int, float]) -> str: ...
 def breakpoint(*args, **kwargs) -> NoneType: ...
 def callable(obj) -> bool: ...
@@ -170,7 +170,7 @@ def reduce(function: Callable[..., _T], iterable: Iterable, initial) -> _T: ...
 # No reload() in Python3
 def reload(mod: module) -> module: ...
 def repr(x) -> str: ...
-def reveal_type(__obj: _T) -> _T: ...
+def reveal_type(obj: _T, /) -> _T: ...
 def round(number: Union[int, float, typing.SupportsRound]) -> int: ...
 def round(number: Union[int, float, typing.SupportsRound], *args, **kwargs) -> float: ...
 def setattr(object, name: str, value) -> NoneType: ...

From 592d4fa754ebc968e7ce92291619a3ebb526808f Mon Sep 17 00:00:00 2001
From: rechen 
Date: Tue, 23 Apr 2024 17:48:30 -0700
Subject: [PATCH 18/22] rewrite: simplify types related to function calls.

Removes the 'indefinite' attribute from abstract.Dict and adds it to
abstract.FunctionArgTuple. This way, FunctionArgTuple and FunctionArgDict have
a fully consistent typing story, but the general Tuple and Dict representations
are kept relatively simple.

I reran ./rewrite/tests/run.sh, and this change has no effect on our
pass/fail/error numbers.

PiperOrigin-RevId: 627558177
---
 pytype/rewrite/abstract/containers.py  | 19 ++++----------
 pytype/rewrite/abstract/functions.py   | 24 ++++++++----------
 pytype/rewrite/abstract/internal.py    | 25 ++++++++++---------
 pytype/rewrite/frame.py                |  4 ++-
 pytype/rewrite/function_call_helper.py | 34 +++++++++++++-------------
 5 files changed, 49 insertions(+), 57 deletions(-)

diff --git a/pytype/rewrite/abstract/containers.py b/pytype/rewrite/abstract/containers.py
index cd3d868c0..db4c25d53 100644
--- a/pytype/rewrite/abstract/containers.py
+++ b/pytype/rewrite/abstract/containers.py
@@ -37,45 +37,36 @@ class Dict(base.PythonConstant[_Dict[_Variable, _Variable]]):
 
   def __init__(
       self, ctx: base.ContextType, constant: _Dict[_Variable, _Variable],
-      indefinite: bool = False
   ):
     assert isinstance(constant, dict), constant
     super().__init__(ctx, constant)
-    self.indefinite = indefinite
 
   def __repr__(self):
-    indef = '+' if self.indefinite else ''
-    return f'Dict({indef}{self.constant!r})'
-
-  @classmethod
-  def any_dict(cls, ctx):
-    return cls(ctx, {}, indefinite=True)
+    return f'Dict({self.constant!r})'
 
   @classmethod
   def from_function_arg_dict(
       cls, ctx: base.ContextType, val: internal.FunctionArgDict
   ) -> 'Dict':
+    assert not val.indefinite
     new_constant = {
         ctx.consts[k].to_variable(): v
         for k, v in val.constant.items()
     }
-    return cls(ctx, new_constant, val.indefinite)
+    return cls(ctx, new_constant)
 
   def setitem(self, key: _Variable, val: _Variable) -> 'Dict':
     return Dict(self._ctx, {**self.constant, key: val})
 
   def update(self, val: 'Dict') -> base.BaseValue:
-    return Dict(
-        self._ctx, {**self.constant, **val.constant},
-        self.indefinite or val.indefinite
-    )
+    return Dict(self._ctx, {**self.constant, **val.constant})
 
   def to_function_arg_dict(self) -> internal.FunctionArgDict:
     new_const = {
         utils.get_atomic_constant(k, str): v
         for k, v in self.constant.items()
     }
-    return internal.FunctionArgDict(self._ctx, new_const, self.indefinite)
+    return internal.FunctionArgDict(self._ctx, new_const)
 
 
 class Set(base.PythonConstant[_Set[_Variable]]):
diff --git a/pytype/rewrite/abstract/functions.py b/pytype/rewrite/abstract/functions.py
index b333bd84b..d2fa26491 100644
--- a/pytype/rewrite/abstract/functions.py
+++ b/pytype/rewrite/abstract/functions.py
@@ -138,20 +138,16 @@ def _get_required_posarg_count(self) -> int:
 
   def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
     """Adjust *args and posargs based on function signature."""
+    starargs_var = self.args.starargs
     posargs = self.args.posargs
-    indef_starargs = False
-    if self.args.starargs is None:
+    if starargs_var is None:
       # There is nothing to unpack, but we might want to move unused posargs
       # into sig.varargs_name
-      starargs_tuple = ()
+      starargs = internal.FunctionArgTuple(self._ctx, ())
     else:
-      try:
-        starargs_tuple = _unpack_splats(self.args.get_concrete_starargs())
-      except ValueError:
-        # We don't have a concrete starargs. We still need to use this to fill
-        # in missing posargs or absorb extra ones.
-        starargs_tuple = ()
-        indef_starargs = True
+      # Do not catch the error; this should always succeed
+      starargs = starargs_var.get_atomic_value(internal.FunctionArgTuple)
+    starargs_tuple = starargs.constant
 
     # Attempt to adjust the starargs into the missing posargs.
     all_posargs = posargs + starargs_tuple
@@ -201,18 +197,18 @@ def _unpack_starargs(self) -> Tuple[Tuple[_Var, ...], Optional[_Var]]:
         # match all k+2 to Any
         mid = [self._ctx.consts.Any.to_variable() for _ in range(posarg_delta)]
       return tuple(pre + mid + post), None
-    elif posarg_delta and indef_starargs:
+    elif posarg_delta and starargs.indefinite:
       # Fill in *required* posargs if needed; don't override the default posargs
       # with indef starargs yet because we aren't capturing the type of *args
       if posarg_delta > 0:
-        extra = self._expand_typed_star(self.args.starargs, posarg_delta)
+        extra = self._expand_typed_star(starargs_var, posarg_delta)
         return posargs + tuple(extra), None
       elif self.sig.varargs_name:
-        return posargs[:n_required_posargs], self.args.starargs
+        return posargs[:n_required_posargs], starargs_var
       else:
         # We have too many posargs *and* no *args in the sig to absorb them, so
         # just do nothing and handle the error downstream.
-        return posargs, self.args.starargs
+        return posargs, starargs_var
 
     else:
       # We have **kwargs but no *args in the invocation
diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py
index 78f74f7d0..73bbd830a 100644
--- a/pytype/rewrite/abstract/internal.py
+++ b/pytype/rewrite/abstract/internal.py
@@ -1,7 +1,7 @@
 """Abstract types used internally by pytype."""
 
 import collections
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
 
 import immutabledict
 
@@ -15,17 +15,24 @@
 class FunctionArgTuple(base.BaseValue):
   """Representation of a function arg tuple."""
 
-  def __init__(self, ctx: base.ContextType, constant: Tuple[_Variable, ...]):
+  def __init__(
+      self,
+      ctx: base.ContextType,
+      constant: Tuple[_Variable, ...] = (),
+      indefinite: bool = False,
+  ):
     super().__init__(ctx)
     assert isinstance(constant, tuple), constant
     self.constant = constant
+    self.indefinite = indefinite
 
   def __repr__(self):
-    return f"FunctionArgTuple({self.constant!r})"
+    indef = "+" if self.indefinite else ""
+    return f"FunctionArgTuple({indef}{self.constant!r})"
 
   @property
   def _attrs(self):
-    return (self.constant,)
+    return (self.constant, self.indefinite)
 
 
 class FunctionArgDict(base.BaseValue):
@@ -34,19 +41,15 @@ class FunctionArgDict(base.BaseValue):
   def __init__(
       self,
       ctx: base.ContextType,
-      constant: Dict[str, _Variable],
-      indefinite: bool = False
+      constant: Optional[Dict[str, _Variable]] = None,
+      indefinite: bool = False,
   ):
     self._ctx = ctx
+    constant = constant or {}
     self._check_keys(constant)
     self.constant = constant
     self.indefinite = indefinite
 
-  @classmethod
-  def any_kwargs(cls, ctx: base.ContextType):
-    """Return a new kwargs dict with only indefinite values."""
-    return cls(ctx, {}, indefinite=True)
-
   def _check_keys(self, constant: Dict[str, _Variable]):
     """Runtime check to ensure the invariant."""
     assert isinstance(constant, dict), constant
diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index 8d6fe208c..d6f857056 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -722,6 +722,8 @@ def _unpack_dict_update(
     if isinstance(val, abstract.Dict):
       return val
     elif isinstance(val, abstract.FunctionArgDict):
+      if val.indefinite:
+        return None
       return abstract.Dict.from_function_arg_dict(self._ctx, val)
     elif abstract.is_any(val):
       return None
@@ -745,7 +747,7 @@ def byte_DICT_UPDATE(self, opcode):
       # The update var has multiple possible values, or no constant, so we
       # cannot merge it into the constant dict. We also don't know if existing
       # items have been overwritten, so we need to return a new 'any' dict.
-      ret = abstract.Dict.any_dict(self._ctx)
+      ret = self._ctx.types[dict].instantiate()
     else:
       ret = target.update(update)
     self._replace_atomic_stack_value(count, ret)
diff --git a/pytype/rewrite/function_call_helper.py b/pytype/rewrite/function_call_helper.py
index dd34ec3ea..14b525c0b 100644
--- a/pytype/rewrite/function_call_helper.py
+++ b/pytype/rewrite/function_call_helper.py
@@ -39,7 +39,8 @@ def make_function_args(
     self._kw_names = ()
     return abstract.Args(posargs=posargs, kwargs=kwargs, frame=self._frame)
 
-  def _unpack_starargs(self, starargs: _AbstractVariable) -> abstract.BaseValue:
+  def _unpack_starargs(
+      self, starargs: _AbstractVariable) -> abstract.FunctionArgTuple:
     """Unpacks variable positional arguments."""
     # TODO(b/331853896): This follows vm_utils.ensure_unpacked_starargs, but
     # does not yet handle indefinite iterables.
@@ -48,14 +49,14 @@ def _unpack_starargs(self, starargs: _AbstractVariable) -> abstract.BaseValue:
       # This has already been converted
       pass
     elif isinstance(posargs, abstract.FrozenInstance):
-      # This is indefinite; leave it as-is
-      pass
+      # This is indefinite.
+      posargs = abstract.FunctionArgTuple(self._ctx, indefinite=True)
     elif isinstance(posargs, abstract.Tuple):
       posargs = abstract.FunctionArgTuple(self._ctx, posargs.constant)
     elif isinstance(posargs, tuple):
       posargs = abstract.FunctionArgTuple(self._ctx, posargs)
     elif abstract.is_any(posargs):
-      return self._ctx.types[tuple].instantiate()
+      posargs = abstract.FunctionArgTuple(self._ctx, indefinite=True)
     else:
       assert False, f'unexpected posargs type: {posargs}: {type(posargs)}'
     return posargs
@@ -67,10 +68,13 @@ def _unpack_starstarargs(
     if isinstance(kwargs, abstract.FunctionArgDict):
       # This has already been converted
       pass
+    elif isinstance(kwargs, abstract.FrozenInstance):
+      # This is indefinite.
+      kwargs = abstract.FunctionArgDict(self._ctx, indefinite=True)
     elif isinstance(kwargs, abstract.Dict):
       kwargs = kwargs.to_function_arg_dict()
     elif abstract.is_any(kwargs):
-      kwargs = abstract.FunctionArgDict.any_kwargs(self._ctx)
+      kwargs = abstract.FunctionArgDict(self._ctx, indefinite=True)
     else:
       assert False, f'unexpected kwargs type: {kwargs}: {type(kwargs)}'
     return kwargs
@@ -83,26 +87,22 @@ def make_function_args_ex(
     """Makes function args from variable positional and keyword arguments."""
     # Convert *args
     unpacked_starargs = self._unpack_starargs(starargs)
-    if isinstance(
-        unpacked_starargs, (abstract.Tuple, abstract.FunctionArgTuple)):
+    if unpacked_starargs.indefinite:
+      # We have an indefinite tuple; leave it in starargs
+      posargs = ()
+      starargs = unpacked_starargs.to_variable()
+    else:
       # We have a concrete tuple we are unpacking; move it into posargs
       posargs = unpacked_starargs.constant
       starargs = None
-    else:
-      # We have an indefinite tuple; leave it in starargs
-      posargs = ()
     # Convert **kwargs
     if starstarargs:
       unpacked_starstarargs = self._unpack_starstarargs(starstarargs)
-      # If we have a concrete dict we are unpacking; move it into kwargs (if
-      # not, .constant will be {} anyway, so we don't need to check here.)
-      kwargs = unpacked_starstarargs.constant
       if unpacked_starstarargs.indefinite:
-        # We also have **kwargs, apart from the concrete kv pairs we moved into
-        # kwargs, that need to be preserved.
-        starstarargs = (
-            abstract.FunctionArgDict.any_kwargs(self._ctx).to_variable())
+        kwargs = datatypes.EMPTY_MAP
+        starstarargs = unpacked_starstarargs.to_variable()
       else:
+        kwargs = unpacked_starstarargs.constant
         starstarargs = None
     else:
       kwargs = datatypes.EMPTY_MAP

From 5a3306bd29e02801255b94420c5f9121d4752cab Mon Sep 17 00:00:00 2001
From: Rebecca Chen 
Date: Wed, 24 Apr 2024 10:03:36 -0700
Subject: [PATCH 19/22] Call super().__init__() in FunctionArgDict.

---
 pytype/rewrite/abstract/internal.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytype/rewrite/abstract/internal.py b/pytype/rewrite/abstract/internal.py
index 73bbd830a..ff0535fbb 100644
--- a/pytype/rewrite/abstract/internal.py
+++ b/pytype/rewrite/abstract/internal.py
@@ -44,7 +44,7 @@ def __init__(
       constant: Optional[Dict[str, _Variable]] = None,
       indefinite: bool = False,
   ):
-    self._ctx = ctx
+    super().__init__(ctx)
     constant = constant or {}
     self._check_keys(constant)
     self.constant = constant

From bd3b303097814ad3751a303555dfc3e6a331b780 Mon Sep 17 00:00:00 2001
From: Rebecca Chen 
Date: Wed, 24 Apr 2024 10:54:21 -0700
Subject: [PATCH 20/22] Implement byte_CALL_FUNCTION_KW.

---
 pytype/rewrite/frame.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/pytype/rewrite/frame.py b/pytype/rewrite/frame.py
index d6f857056..4fec990a1 100644
--- a/pytype/rewrite/frame.py
+++ b/pytype/rewrite/frame.py
@@ -573,7 +573,17 @@ def byte_CALL(self, opcode):
   def byte_CALL_FUNCTION(self, opcode):
     args = self._stack.popn(opcode.arg)
     func = self._stack.pop()
-    callargs = abstract.Args(posargs=tuple(args), frame=self)
+    callargs = self._call_helper.make_function_args(args)
+    self._call_function(func, callargs)
+
+  def byte_CALL_FUNCTION_KW(self, opcode):
+    kwnames_var = self._stack.pop()
+    args = self._stack.popn(opcode.arg)
+    func = self._stack.pop()
+    kwnames = [abstract.get_atomic_constant(key, str)
+               for key in abstract.get_atomic_constant(kwnames_var, tuple)]
+    self._call_helper.set_kw_names(kwnames)
+    callargs = self._call_helper.make_function_args(args)
     self._call_function(func, callargs)
 
   def byte_CALL_FUNCTION_EX(self, opcode):

From b5e788e6780edf33782f5d8d1ab9bd9bdff8956a Mon Sep 17 00:00:00 2001
From: Rebecca Chen 
Date: Wed, 24 Apr 2024 12:08:06 -0700
Subject: [PATCH 21/22] Skip test in 3.8 that uses BUILD_TUPLE_UNPACK_WITH_CALL
 opcode.

---
 pytype/rewrite/tests/test_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytype/rewrite/tests/test_args.py b/pytype/rewrite/tests/test_args.py
index 8d78d946c..c0726a401 100644
--- a/pytype/rewrite/tests/test_args.py
+++ b/pytype/rewrite/tests/test_args.py
@@ -68,6 +68,7 @@ def g(a, b, x, y):
       f(*a, **b)
     """)
 
+  @test_utils.skipBeforePy((3, 9), 'Relies on 3.9+ bytecode')
   def test_unpack_posargs(self):
     self.Check("""
       def f(x, y, z):

From a2dc5add7efc82e28783dc1f76acab62eb857803 Mon Sep 17 00:00:00 2001
From: Rebecca Chen 
Date: Wed, 24 Apr 2024 12:50:07 -0700
Subject: [PATCH 22/22] Use macos-13 to work around bug in macos-latest.

For https://github.com/google/pytype/issues/1621.
---
 .github/workflows/build.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 1c2976d07..e9b45f824 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -55,7 +55,7 @@ jobs:
     # used to build Python there.  It is the latter that determines
     # the wheel's platform tag.
     # https://github.com/actions/virtual-environments/issues/696
-    runs-on: macos-latest
+    runs-on: macos-13
     strategy:
       matrix:
         python_version: ['3.8', '3.9', '3.10', '3.11']