Skip to content

Commit

Permalink
Merge pull request #1612 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 committed Apr 9, 2024
2 parents 256fbf6 + e83e546 commit d1c01f3
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 36 deletions.
14 changes: 0 additions & 14 deletions pytype/abstract/_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,6 @@ def __init__(self, ctx, ret_var, node):
self.merge_instance_type_parameter(
node, abstract_utils.V, ret_var.AssignToNewVariable(node))

@classmethod
def make(cls, ctx, func, node):
"""Get return type of coroutine function."""
assert func.signature.has_return_annotation
ret_val = func.signature.annotations["return"]
if func.code.has_coroutine():
ret_var = ret_val.instantiate(node)
elif func.code.has_iterable_coroutine():
ret_var = ret_val.get_formal_type_parameter(
abstract_utils.V).instantiate(node)
else:
assert False, f"Function {func.name} is not a coroutine"
return cls(ctx, ret_var, node)


class Iterator(_instance_base.Instance, mixin.HasSlots):
"""A representation of instances of iterators."""
Expand Down
20 changes: 18 additions & 2 deletions pytype/abstract/_interpreter_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ def __init__(self, name, def_opcode, code, f_locals, f_globals, defaults,
self.nonstararg_count = self.code.argcount + self.code.kwonlyargcount
signature = self._build_signature(name, annotations)
super().__init__(signature, ctx)
if not self.code.has_coroutine():
# Sanity check: has_iterable_coroutine() is set by the types.coroutine
# decorator, so it should always be False at function creation time.
assert not self.code.has_iterable_coroutine()
elif signature.has_return_annotation:
params = {
abstract_utils.T: ctx.convert.unsolvable,
abstract_utils.T2: ctx.convert.unsolvable,
abstract_utils.V: signature.annotations["return"],
}
coroutine_type = _classes.ParameterizedClass(
ctx.convert.coroutine_type, params, ctx)
signature.annotations["return"] = coroutine_type
self._check_signature()
self._update_signature_scope_from_closure()
self.last_frame = None # for BuildClass
Expand Down Expand Up @@ -496,7 +509,7 @@ def call(self, node, func, args, alias_map=None, new_locals=False,
if "return" not in annotations:
return node, self.ctx.new_unsolvable(node)
ret = self.ctx.vm.init_class(node, annotations["return"])
if self.is_coroutine():
if self.is_unannotated_coroutine():
ret = _instances.Coroutine(self.ctx, ret, node).to_variable(node)
return node, ret

Expand Down Expand Up @@ -630,7 +643,7 @@ def call(self, node, func, args, alias_map=None, new_locals=False,
log.info("%s Start running frame for %r", indent, self.name)
node2, ret = self.ctx.vm.run_frame(frame, node, annotated_locals)
log.info("%s Finished running frame for %r", indent, self.name)
if self.is_coroutine():
if self.is_unannotated_coroutine():
ret = _instances.Coroutine(self.ctx, ret, node2).to_variable(node2)
node_after_call = node2
self._inner_cls_check(frame)
Expand Down Expand Up @@ -732,6 +745,9 @@ def property_get(self, callself, is_class=False):
def is_coroutine(self):
return self.code.has_coroutine() or self.code.has_iterable_coroutine()

def is_unannotated_coroutine(self):
return self.is_coroutine() and not self.signature.has_return_annotation

def has_empty_body(self):
# TODO(mdemello): Optimise this.
ops = list(self.code.code_iter)
Expand Down
5 changes: 1 addition & 4 deletions pytype/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,7 @@ def annotations_to_instance_types(self, node, annots):
def _function_call_to_return_type(self, node, v, seen_return, num_returns):
"""Get a function call's pytd return type."""
if v.signature.has_return_annotation:
if v.is_coroutine():
ret = abstract.Coroutine.make(self.ctx, v, node).to_pytd_type(node)
else:
ret = v.signature.annotations["return"].to_pytd_type_of_instance(node)
ret = v.signature.annotations["return"].to_pytd_type_of_instance(node)
else:
ret = seen_return.data.to_pytd_type(node)
if isinstance(ret, pytd.NothingType) and num_returns == 1:
Expand Down
10 changes: 10 additions & 0 deletions pytype/overlays/asyncio_types_overlay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of special members of types and asyncio module."""

from pytype.abstract import abstract
from pytype.abstract import abstract_utils
from pytype.overlays import overlay


Expand Down Expand Up @@ -44,4 +45,13 @@ def call(self, node, func, args, alias_map=None):
(self.module == "asyncio" or
self.module == "types" and code.has_generator())):
code.set_iterable_coroutine()
if funcv.signature.has_return_annotation:
ret = funcv.signature.annotations["return"]
params = {
param: ret.get_formal_type_parameter(param)
for param in (abstract_utils.T, abstract_utils.T2, abstract_utils.V)
}
coroutine_type = abstract.ParameterizedClass(
self.ctx.convert.coroutine_type, params, self.ctx)
funcv.signature.annotations["return"] = coroutine_type
return node, func_var
10 changes: 2 additions & 8 deletions pytype/overriding_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,15 +504,9 @@ def check_overriding_members(cls, bases, members, matcher, ctx):

class_signature_map = {}
for method_name, method in class_method_map.items():
if method.is_coroutine():
if method.is_unannotated_coroutine():
annotations = dict(method.signature.annotations)
coroutine_params = {
abstract_utils.T: ctx.convert.unsolvable,
abstract_utils.T2: ctx.convert.unsolvable,
abstract_utils.V: annotations.get("return", ctx.convert.unsolvable),
}
annotations["return"] = abstract.ParameterizedClass(
ctx.convert.coroutine_type, coroutine_params, ctx)
annotations["return"] = ctx.convert.coroutine_type
signature = method.signature._replace(annotations=annotations)
else:
signature = method.signature
Expand Down
4 changes: 2 additions & 2 deletions pytype/rewrite/abstract/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def set_attribute(self, name: str, value: base.BaseValue) -> None:

def instantiate(self) -> 'FrozenInstance':
"""Creates an instance of this class."""
log.info('Instantiating class %s', self.full_name)
if self._canonical_instance:
log.info('Reusing cached instance of class %s', self.full_name)
log.info('Getting cached instance of class %s', self.full_name)
return self._canonical_instance
log.info('Instantiating class %s', self.full_name)
for setup_method_name in self.setup_methods:
setup_method = self.get_attribute(setup_method_name)
if isinstance(setup_method, functions_lib.InterpreterFunction):
Expand Down
4 changes: 4 additions & 0 deletions pytype/rewrite/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def run(self) -> None:
while True:
try:
self.step()
self._log_stack()
except frame_base.FrameConsumedError:
break
assert not self._stack
Expand All @@ -157,6 +158,9 @@ def run(self) -> None:
name: abstract.join_values(self._ctx, var.values)
for name, var in self._final_locals.items()})

def _log_stack(self):
log.debug('stack: %r', self._stack)

def store_local(self, name: str, var: _AbstractVariable) -> None:
self._current_state.store_local(name, var)

Expand Down
2 changes: 1 addition & 1 deletion pytype/rewrite/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ py_test(
SRCS
test_basic.py
DEPS
pytype.rewrite.tests.test_utils
.test_utils
pytype.tests.test_base
)

Expand Down
26 changes: 26 additions & 0 deletions pytype/tests/test_async_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,31 @@ def f4() -> Coroutine[Any, Any, int]: ...
def func() -> Coroutine[Any, Any, str]: ...
""")

def test_callable(self):
self.Check("""
from typing import Awaitable, Callable
async def f1(a: str) -> str:
return a
async def f2(fun: Callable[[str], Awaitable[str]]) -> str:
return await fun('a')
async def f3() -> None:
await f2(f1)
""")

def test_callable_with_imported_func(self):
with self.DepTree([("foo.py", """
async def f1(a: str) -> str:
return a
""")]):
self.Check("""
import foo
from typing import Awaitable, Callable
async def f2(fun: Callable[[str], Awaitable[str]]) -> str:
return await fun('a')
async def f3() -> None:
await f2(foo.f1)
""")


if __name__ == "__main__":
test_base.main()
13 changes: 12 additions & 1 deletion pytype/tests/test_coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,17 @@ async def caller():
self.assertErrorRegexes(
errors, {"e1": r"Awaitable.*int", "e2": r"y: str.*y: int"})

def test_generator_based_coroutine_bad_annotation(self):
self.CheckWithErrors("""
import types
from typing import Generator
@types.coroutine
def f() -> Generator[str, None, str]:
yield 1 # bad-return-type
return 123 # bad-return-type
""")

def test_awaitable_pyi(self):
ty = self.Infer("""
from typing import Awaitable, Generator
Expand Down Expand Up @@ -226,7 +237,7 @@ class SubAwaitable(BaseAwaitable):
def c1() -> Coroutine[Any, Any, int]: ...
def c2() -> Coroutine[Any, Any, int]: ...
def c2() -> Coroutine[int, None, int]: ...
def f1() -> Coroutine[Any, Any, None]: ...
def f2(x: Awaitable[int]) -> Coroutine[Any, Any, int]: ...
def f3() -> Coroutine[Any, Any, None]: ...
Expand Down
13 changes: 9 additions & 4 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2733,13 +2733,18 @@ def byte_RETURN_VALUE(self, state, op):
"""Get and check the return value."""
state, var = state.pop()
if self.frame.check_return:
if self.frame.f_code.has_generator():
if (self.frame.f_code.has_generator() or
self.frame.f_code.has_coroutine() or
self.frame.f_code.has_iterable_coroutine()):
ret_type = self.frame.allowed_returns
assert ret_type is not None
self._check_return(state.node, var,
ret_type.get_formal_type_parameter(abstract_utils.V))
allowed_return = ret_type.get_formal_type_parameter(abstract_utils.V)
elif not self.frame.f_code.has_async_generator():
self._check_return(state.node, var, self.frame.allowed_returns)
allowed_return = self.frame.allowed_returns
else:
allowed_return = None
if allowed_return:
self._check_return(state.node, var, allowed_return)
if (self.ctx.options.no_return_any and
any(d == self.ctx.convert.unsolvable for d in var.data)):
self.ctx.errorlog.any_return_type(self.frames)
Expand Down

0 comments on commit d1c01f3

Please sign in to comment.