[Feature]: Update Incorrect type-hints #2303

Open
danphenderson opened this issue Feb 15, 2024 · 2 comments

Comments

Contributor

danphenderson commented Feb 15, 2024

🚀 Feature Request

There is both inconsistent and incorrect use of type-hints throughout the package. This could be classified as a bug report.

Inconsistent

The 'typing' module is accessed in the following ways:

  1. import typing
  2. from typing import ...

Ideally, the package should standardize on the second access pattern, which improves readability.

Similarly, when a Union type may also be None, we should consistently use either Optional[Union[<Type>, ...]] or Union[<Type>, ..., None]. Both forms currently appear in the package.
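To illustrate why this is purely a consistency question, here is a minimal sketch showing that the two spellings are interchangeable at runtime. The variable names are made up for the example and do not come from the package.

```python
from typing import Optional, Union

# Two spellings of "int, str, or None" (names are illustrative only).
explicit_none = Union[int, str, None]
wrapped = Optional[Union[int, str]]

# typing flattens nested unions, so both spellings compare equal.
assert explicit_none == wrapped
```

Since the two forms are equivalent, picking one is a matter of style rather than behavior.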

Incorrect

The incorrect use of optional type-hints was first documented here. Essentially, params/attributes of the form some_param: SomeType = None need to be updated to some_param: Optional[SomeType] = None.
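As a small sketch of the change being proposed, the two functions below differ only in their annotation. The function names and the `timeout` parameter are hypothetical and are not taken from the package.

```python
from typing import Optional, get_type_hints


def connect_implicit(timeout: float = None):  # annotation says float, default is None
    return timeout


def connect_explicit(timeout: Optional[float] = None):  # annotation admits None
    return timeout


# The explicit form records that None is an accepted value.
hints = get_type_hints(connect_explicit)
assert hints["timeout"] == Optional[float]
```

Both functions behave identically at runtime. The difference only matters to static analysis, where strict type checkers may reject the implicit form.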

Motivation

Improve readability and enable better static analysis.

Contributor Author

danphenderson commented Feb 22, 2024

Below is a script (with test coverage) that can solve the problem at hand. It uses libCST to parse source code into a concrete syntax tree and apply the modifications. Note that the typing imports would need to be updated manually after the script is run (the existing pre-commit hooks catch this).

I would be more than happy to execute the script and submit a PR, assuming that the community finds this to be a valuable contribution. I would appreciate any/all feedback.

from difflib import unified_diff
from typing import Dict, List, Optional, Sequence, Union
from pathlib import Path
from libcst import (
    AnnAssign,
    Annotation,
    Attribute,
    CSTNode,
    Index,
    Name,
    Param,
    Subscript,
    SubscriptElement,
    ensure_type,
    parse_expression,
    parse_module,
)
from libcst.codemod import CodemodContext as _CodemodContext
from libcst.codemod import ContextAwareTransformer as _Codemod
from libcst.matchers import Name as mName
from libcst.matchers import matches


class CodemodContext(_CodemodContext):
    def __init__(self):
        super().__init__()
        self.code_modifications: List = []
        self.made_changes: bool = False


class Codemod(_Codemod):

    def __init__(self, context: Optional[CodemodContext] = None) -> None:
        context = context or CodemodContext()
        super().__init__(context)

    @property
    def code_modifications(self) -> List:
        return getattr(self.context, "code_modifications", [])

    @property
    def made_changes(self) -> bool:
        return getattr(self.context, "made_changes", False)

    def report_changes(self, original_node: CSTNode, updated_node: CSTNode, *, print_changes: bool = False) -> None:
        if original_node.deep_equals(updated_node):
            return
        original_code = getattr(original_node, "code", "")
        updated_code = getattr(updated_node, "code", "")
        # unified_diff returns a lazy generator; store the lines so the diff can be read more than once
        code_diff = list(unified_diff(original_code.splitlines(), updated_code.splitlines(), lineterm=""))
        self.code_modifications.append(code_diff)


class EnforceOptionallNoneTypes(Codemod):
    """
    Enforce the use of 'Optional' in all 'NoneType' annotated assignments.


    The transformation will remove None from the Union and wrap the remaining type(s)
    with Optional. If there's only one other type besides None, it will be Optional[Type].
    If there are multiple types besides None, it will be Optional[Union[Type1, Type2, ...]].
    """

    def leave_Subscript(self, original_node: Subscript, updated_node: Subscript) -> Subscript:
        # Check if it's a Union type
        if matches(updated_node.value, mName("Union")):
            union_element = updated_node.slice

            # Extract the types in the Union
            union_types = self._extract_union_types(union_element)

            # Check if None is one of the types in the Union
            if "None" in union_types:
                # Remove 'None' and handle single or multiple remaining types
                remaining_types = [t for t in union_types if t != "None"]
                if len(remaining_types) == 1:
                    # Single type + None becomes Optional[SingleType]
                    new_node = parse_expression(f"Optional[{remaining_types[0]}]")
                else:
                    # Multiple types + None becomes Optional[Union[Type1, Type2, ...]]
                    new_node = parse_expression(f"Optional[Union[{', '.join(remaining_types)}]]")
                setattr(self.context, "made_changes", True)
                return new_node  # type: ignore

        return updated_node

    def _extract_union_types(self, subscript_slice: Sequence[SubscriptElement]):
        types = []
        for element in subscript_slice:
            element_index = ensure_type(element.slice, Index)
            types.append(self._node_to_string(element_index.value))
        return types

    def _node_to_string(self, node: CSTNode) -> str:
        """
        Convert a CSTNode to its string representation, handling different node types.

        Performs recursive depth-first search to handle nested nodes.
        """
        if isinstance(node, Name):
            return node.value
        elif isinstance(node, Subscript):
            value = self._node_to_string(node.value)
            # Handle subscript slices (e.g., List[int])
            slice_parts = [self._node_to_string(s.slice.value) for s in node.slice]  # type: ignore
            return f"{value}[{', '.join(slice_parts)}]"
        elif isinstance(node, Attribute):
            value = self._node_to_string(node.value)
            attr = self._node_to_string(node.attr)
            return f"{value}.{attr}"
        else:
            # This might need to be extended to handle other node types as necessary
            raise ValueError(f"Unsupported node type: {type(node)}")


class InferOptionalNoneTypes(Codemod):
    """
    Infer that a type is 'Optional' in annotated assignments to None.

    This transformer will wrap the type annotation with Optional if the variable is assigned to None
    or if a function parameter has a default value of None. Note, if the annotation is already Optional, or Any,
    it will remain unchanged.
    """

    def leave_AnnAssign(self, original_node: AnnAssign, updated_node: AnnAssign) -> AnnAssign:
        if updated_node.value is not None and matches(updated_node.value, mName("None")):
            if not self._is_optional_annotation(updated_node.annotation):
                new_annotation = self._wrap_with_optional(updated_node.annotation)
                new_node = updated_node.with_changes(annotation=new_annotation)
                setattr(self.context, "made_changes", True)
                return new_node
        return updated_node

    def leave_Param(self, original_node: Param, updated_node: Param) -> Param:
        if updated_node.default is not None and matches(updated_node.default, mName("None")):
            if updated_node.annotation is not None and not self._is_optional_annotation(updated_node.annotation):
                new_annotation = self._wrap_with_optional(updated_node.annotation)
                new_node = updated_node.with_changes(annotation=new_annotation)
                setattr(self.context, "made_changes", True)
                return new_node
        return updated_node

    def _is_optional_annotation(self, annotation: Annotation) -> bool:
        if isinstance(annotation.annotation, Subscript) and matches(annotation.annotation.value, mName("Optional")):
            return True
        elif isinstance(annotation.annotation, Name) and matches(annotation.annotation, mName("Any")):
            return True
        return False

    def _wrap_with_optional(self, annotation: Annotation) -> Annotation:
        optional_annotation = Annotation(
            annotation=Subscript(value=Name(value="Optional"), slice=[SubscriptElement(slice=Index(value=annotation.annotation))])
        )
        return optional_annotation


def _parse_context(context: Optional[Union[CodemodContext, Dict[str, Union[bool, List]]]]) -> CodemodContext:
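    if isinstance(context, dict):
        # CodemodContext.__init__ takes no arguments, so copy the dict entries onto a fresh instance
        parsed = CodemodContext()
        for key, value in context.items():
            setattr(parsed, key, value)
        return parsed
    return context or CodemodContext()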


def apply(code: str, codemod: Union[EnforceOptionallNoneTypes, InferOptionalNoneTypes]) -> str:
    module = parse_module(code)  # Parse the entire code as a module
    modified_tree = module.visit(codemod)
    return modified_tree.code


def apply_all(code: str, context: Optional[Union[CodemodContext, Dict[str, Union[bool, List]]]] = None) -> str:
    context = _parse_context(context)
    code = apply(code, EnforceOptionallNoneTypes(context))
    code = apply(code, InferOptionalNoneTypes(context))
    return code


def process_files_in_directory(directory_path: Path):
    for file in directory_path.glob('**/*.py'):
        with open(file, 'r', encoding='utf-8') as f:
            original_content = f.read()

        transformed_content = apply_all(original_content)

        if original_content != transformed_content:
            with open(file, 'w', encoding='utf-8') as f:
                f.write(transformed_content)
            print(f"Transformed {file}")



def main() -> None:
    # Assuming the script is run from the root of the repository
    directory = Path.cwd() / Path("playwright")
    if not directory.exists():
        raise FileNotFoundError("Directory 'playwright' not found")

    process_files_in_directory(
        Path.cwd() / Path("playwright"),
    )

if __name__ == "__main__":
    main()
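Because the script above leaves the typing imports untouched, a file can end up referencing Optional without importing it. The sketch below is one possible way to flag such files using only the standard-library ast module. The helper name is hypothetical and not part of the script, and it deliberately ignores star imports and the `typing.Optional` attribute form.

```python
import ast


def uses_optional_without_import(source: str) -> bool:
    """Return True if `Optional` is referenced but never imported from typing."""
    tree = ast.parse(source)
    nodes = list(ast.walk(tree))
    # Look for `from typing import Optional` (without an alias).
    imported = any(
        isinstance(node, ast.ImportFrom)
        and node.module == "typing"
        and any(alias.name == "Optional" and alias.asname is None for alias in node.names)
        for node in nodes
    )
    # Look for any bare use of the name `Optional`, including inside annotations.
    referenced = any(isinstance(node, ast.Name) and node.id == "Optional" for node in nodes)
    return referenced and not imported


print(uses_optional_without_import("x: Optional[int] = None"))
```

A check like this could be run over the transformed files before committing, although the existing pre-commit hooks mentioned above already catch the same problem.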

Test coverage:

import pytest

@pytest.mark.parametrize(
   "source_code, expected_code",
   [
       ("a: Union[int, None]", "a: Optional[int]"),
       ("a: Union[Dict[str, int], None]", "a: Optional[Dict[str, int]]"),
       ("a: Union[str, int, None]", "a: Optional[Union[str, int]]"),
       ("a: Union[Dict[str, int], List[int], None]", "a: Optional[Union[Dict[str, int], List[int]]]"),
       ("a: Union[str, int]", "a: Union[str, int]"),
       ("a: Union[str, int, float]", "a: Union[str, int, float]"),
       ("a: Union[str, Union[int, None]]", "a: Union[str, Optional[int]]"),
       ("a: Union[str, Union[int, float]]", "a: Union[str, Union[int, float]]"),
       ("a: Union[str, Union[int, Union[Dict[str, int], None]]]", "a: Union[str, Union[int, Optional[Dict[str, int]]]]"),
       ("a: Union[str, Union[int, Union[float, bool]]]", "a: Union[str, Union[int, Union[float, bool]]]"),
       ("def func(a: Union[str, None]) -> Union[int, None]: pass", "def func(a: Optional[str]) -> Optional[int]: pass"),
       ("async def func(a: Union[str, None]) -> Union[int, None]: pass", "async def func(a: Optional[str]) -> Optional[int]: pass"),
       ("var: Union[str, None] = 'hello'", "var: Optional[str] = 'hello'"),
       ("class A: a: Union[str, None] = 'hello'", "class A: a: Optional[str] = 'hello'"),
       ("def func(): var: Union[str, None] = 'hello'", "def func(): var: Optional[str] = 'hello'"),
   ],
)
def test_union_to_optional_transform(source_code, expected_code):
   assert apply(source_code, EnforceOptionallNoneTypes()) == expected_code


@pytest.mark.parametrize(
   "source_code, expected_code",
   [
       ("var: int = None", "var: Optional[int] = None"),
       ("var: Optional[int] = None", "var: Optional[int] = None"),
       ("var: Any = None", "var: Any = None"),
       ("var: Optional[Dict[str, int]] = None", "var: Optional[Dict[str, int]] = None"),
       ("var = None", "var = None"),
       ("var: Dict[str, List[int]] = None", "var: Optional[Dict[str, List[int]]] = None"),
       ("def func(var: int = None): pass", "def func(var: Optional[int] = None): pass"),
       ("class A: var: int = None", "class A: var: Optional[int] = None"),
       ("async def func(var: int = None): pass", "async def func(var: Optional[int] = None): pass"),
   ],
)
def test_enforce_optional_transform(source_code, expected_code):
   assert apply(source_code, InferOptionalNoneTypes()) == expected_code

Thanks,
Daniel H


dangotbanned commented Apr 9, 2024

Not sure if this warrants a separate issue, but it also relates to annotating optional parameters.

The return type of cookies (python) / cookies (protocol) / NetworkCookie (protocol) isn't represented correctly in the TypedDict Cookie.

This leads to the following errors on access:

Could not access item in TypedDict "name" is not a required key in "Cookie", so access may result in runtime exception
Could not access item in TypedDict "value" is not a required key in "Cookie", so access may result in runtime exception
Could not access item in TypedDict "expires" is not a required key in "Cookie", so access may result in runtime exception

From my understanding, none of the items should be optional, so Cookie should be updated to:

class Cookie(TypedDict):
    name: str
    value: str
    domain: str
    path: str
    expires: float
    httpOnly: bool
    secure: bool
    sameSite: Literal["Lax", "None", "Strict"]
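For context, here is a minimal sketch of how required and non-required keys differ in a TypedDict. Both classes are hypothetical stand-ins with only two keys, not Playwright's actual definitions, and it is an assumption that the reported errors come from keys being declared non-required (for example via total=False).

```python
from typing import TypedDict


class LooseCookie(TypedDict, total=False):
    # total=False marks every key as not required.
    name: str
    value: str


class StrictCookie(TypedDict):
    # The default (total=True) marks every key as required.
    name: str
    value: str


# __required_keys__ is available on Python 3.9+.
assert LooseCookie.__required_keys__ == frozenset()
assert StrictCookie.__required_keys__ == frozenset({"name", "value"})
```

With the strict form, a type checker can treat `cookie["name"]` as always present, which is what the proposed Cookie definition above would express.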
