Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace __dir__ on TFF modules to only return those symbols explicitly imported. #4574

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
123 changes: 121 additions & 2 deletions tensorflow_federated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.
"""The TensorFlow Federated library."""

import sys
import ast as _ast
import inspect as _inspect
import sys as _sys

from absl import logging as _logging

# pylint: disable=g-importing-member
from tensorflow_federated.python import aggregators
Expand Down Expand Up @@ -69,7 +73,7 @@
from tensorflow_federated.version import __version__
# pylint: enable=g-importing-member

if sys.version_info < (3, 9):
if _sys.version_info < (3, 9):
raise RuntimeError('TFF only supports Python versions 3.9 or later.')

# TODO: b/305743962 - Remove deprecated API.
Expand All @@ -90,3 +94,118 @@
# these to locals().
del python # pylint: disable=undefined-variable
del proto # pylint: disable=undefined-variable

# Replace `__dir__` on every TFF module so that autocompletion tools that rely
# on it (such as JEDI, IPython, and Colab) only display the public API symbols
# shown on tensorflow.org/federated.
# `type(<any module>)` is the module type, so no extra import is needed.
_ModuleType = type(_sys)
# The module object for this package; the root of the traversal below.
_self = _sys.modules[__name__]


def _get_imported_symbols(module: _ModuleType) -> tuple[str, ...]:
  """Gets the public names bound by `module`'s module-scope import statements.

  The source of `module` is parsed with `ast` and every `import` / `from ...
  import` statement that executes at module scope (including those nested in
  `if` / `try` blocks) is inspected. A name is skipped when it is a wildcard
  (`*`), contains a dot (e.g. `import a.b`), or starts with an underscore.
  Imports inside function or class bodies are ignored because they never
  become attributes of the module.

  Args:
    module: The module whose source should be inspected.

  Returns:
    A sorted tuple of unique imported names, or an empty tuple when the source
    of `module` is not available.
  """

  class ImportNodeVisitor(_ast.NodeVisitor):
    """An `ast.NodeVisitor` that collects the names of imported symbols."""

    def __init__(self):
      # A set, so that a name imported more than once is reported only once.
      self.imported_symbols = set()

    def _add_imported_symbol(self, node):
      for alias in node.names:
        name = alias.asname or alias.name
        if name == '*' or '.' in name or name.startswith('_'):
          continue
        self.imported_symbols.add(name)

    def _skip_nested_scope(self, node):
      # Deliberately do not descend: imports in here bind local (or class)
      # names, not module attributes.
      del node

    # pylint: disable=invalid-name
    visit_Import = _add_imported_symbol
    visit_ImportFrom = _add_imported_symbol
    visit_FunctionDef = _skip_nested_scope
    visit_AsyncFunctionDef = _skip_nested_scope
    visit_ClassDef = _skip_nested_scope
    # pylint: enable=invalid-name

  try:
    source = _inspect.getsource(module)
  except (OSError, TypeError):
    # `inspect.getsource` raises `OSError` when the source cannot be retrieved
    # and `TypeError` when the module has no `__file__` (e.g. built-ins).
    _logging.debug('Failed to get source code for: %s, skipping...', module)
    return ()

  visitor = ImportNodeVisitor()
  visitor.visit(_ast.parse(source))
  return tuple(sorted(visitor.imported_symbols))


def _update_dir_method(
    module: _ModuleType, seen_modules: set[_ModuleType]
) -> None:
  """Overwrites `__dir__` to only return the explicit public API.

  The "public API" is defined as:
    - modules, functions, and classes imported by packages (in __init__.py
      files)
    - functions and classes imported by modules (but not modules imported by
      modules)

  This definition matches the documentation generated at
  http://www.tensorflow.org/federated/api_docs/python/tff.

  To improve JEDI, IPython, and Colab autocomplete consistency with
  tensorflow.org/federated public API documentation, this method traverses
  the modules on import and replaces `__dir__` (the source of autocompletions)
  with only those symbols that were explicitly imported.

  Otherwise, Python imports will implicitly import any submodule in the package,
  exposing it via `__dir__`, which is undesirable.

  Args:
    module: A module to bind a new `__dir__` method to.
    seen_modules: A set of modules that have already been operated on. It is
      updated in place (including with `module` itself) so that every module
      is processed at most once.
  """
  # Mark the module *before* recursing: modules that reference each other (or
  # a submodule that references the root package) would otherwise recurse
  # without bound.
  seen_modules.add(module)

  def _is_tff_submodule(attribute):
    if not _inspect.ismodule(attribute):
      return False
    # `__file__` may be missing or `None` (e.g. for namespace packages);
    # normalize it to a string before the substring test.
    module_file = getattr(attribute, '__file__', None) or ''
    return 'tensorflow_federated' in module_file

  # Process TFF submodules reachable through public attributes first.
  for attr in dir(module):
    if attr.startswith('_'):
      continue
    attribute = getattr(module, attr, None)
    if _is_tff_submodule(attribute) and attribute not in seen_modules:
      _update_dir_method(attribute, seen_modules)

  # Filter out imported modules from modules that are not themselves packages.
  is_package = hasattr(module, '__path__')

  def is_module_imported_by_module(symbol_name: str) -> bool:
    return not is_package and _inspect.ismodule(
        getattr(module, symbol_name, None)
    )

  imported_symbols = [
      symbol_name
      for symbol_name in _get_imported_symbols(module)
      if not is_module_imported_by_module(symbol_name)
  ]
  _logging.debug('Module %s had imported symbols %s', module, imported_symbols)
  module.__dir__ = lambda: imported_symbols


# Walk the whole package once, starting from the root module, with a fresh
# (empty) set of already-visited modules.
_update_dir_method(_self, seen_modules=set())