896 lines
33 KiB
Python
896 lines
33 KiB
Python
"""Utilities for mypy.stubgen, mypy.stubgenc, and mypy.stubdoc modules."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os.path
|
|
import re
|
|
import sys
|
|
import traceback
|
|
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from collections.abc import Iterable, Iterator, Mapping
|
|
from contextlib import contextmanager
|
|
from typing import Final, overload
|
|
|
|
from mypy_extensions import mypyc_attr
|
|
|
|
import mypy.options
|
|
from mypy.modulefinder import ModuleNotFoundReason
|
|
from mypy.moduleinspect import InspectError, ModuleInspect
|
|
from mypy.nodes import PARAM_SPEC_KIND, TYPE_VAR_TUPLE_KIND, ClassDef, FuncDef, TypeAliasStmt
|
|
from mypy.stubdoc import ArgSig, FunctionSig
|
|
from mypy.types import (
|
|
AnyType,
|
|
NoneType,
|
|
Type,
|
|
TypeList,
|
|
TypeStrVisitor,
|
|
UnboundType,
|
|
UnionType,
|
|
UnpackType,
|
|
)
|
|
|
|
# Modules that may fail when imported, or that may have side effects (fully qualified).
# Currently empty; add fully qualified module names here to have stubgen skip
# them during runtime introspection.
NOT_IMPORTABLE_MODULES = ()

# Typing constructs to be replaced by their builtin equivalents.
# Maps fully qualified typing names to fully qualified builtin names; used when
# printing annotations so stubs prefer e.g. `list` over `typing.List`.
TYPING_BUILTIN_REPLACEMENTS: Final = {
    # From typing
    "typing.Text": "builtins.str",
    "typing.Tuple": "builtins.tuple",
    "typing.List": "builtins.list",
    "typing.Dict": "builtins.dict",
    "typing.Set": "builtins.set",
    "typing.FrozenSet": "builtins.frozenset",
    "typing.Type": "builtins.type",
    # From typing_extensions
    "typing_extensions.Text": "builtins.str",
    "typing_extensions.Tuple": "builtins.tuple",
    "typing_extensions.List": "builtins.list",
    "typing_extensions.Dict": "builtins.dict",
    "typing_extensions.Set": "builtins.set",
    "typing_extensions.FrozenSet": "builtins.frozenset",
    "typing_extensions.Type": "builtins.type",
}
|
|
|
|
|
|
class CantImport(Exception):
    """Raised when a module cannot be imported for runtime introspection."""

    def __init__(self, module: str, message: str) -> None:
        # Keep both pieces of context so callers can report a useful error.
        self.module = module
        self.message = message
|
|
|
|
|
|
def walk_packages(
    inspect: ModuleInspect, packages: list[str], verbose: bool = False
) -> Iterator[str]:
    """Iterate through all packages and sub-packages in the given list.

    This uses runtime imports (in another process) to find both Python and C modules.
    For Python packages we simply pass the __path__ attribute to pkgutil.walk_packages() to
    get the content of the package (all subpackages and modules). However, packages in C
    extensions do not have this attribute, so we have to roll out our own logic: recursively
    find all modules imported in the package that have matching names.
    """
    for pkg in packages:
        if pkg in NOT_IMPORTABLE_MODULES:
            print(f"{pkg}: Skipped (blacklisted)")
            continue
        if verbose:
            print(f"Trying to import {pkg!r} for runtime introspection")
        try:
            props = inspect.get_package_properties(pkg)
        except InspectError:
            if verbose:
                sys.stderr.write(traceback.format_exc())
            report_missing(pkg)
            continue
        yield props.name
        if props.is_c_module:
            # C extensions have no __path__; recursively iterate through the subpackages.
            yield from walk_packages(inspect, props.subpackages, verbose)
        else:
            yield from props.subpackages
|
|
|
|
|
|
def find_module_path_using_sys_path(module: str, sys_path: list[str]) -> str | None:
    """Locate a module's source file by scanning sys.path-style directories.

    Tries both a plain module file (foo/bar.py) and a package directory
    (foo/bar/__init__.py) under each search-path entry.  Returns the first
    existing file, or None if nothing matches.
    """
    stem = module.replace(".", "/")
    candidates = (stem + ".py", os.path.join(stem, "__init__.py"))
    for base_dir in sys_path:
        for candidate in candidates:
            full_path = os.path.join(base_dir, candidate)
            if os.path.isfile(full_path):
                return full_path
    return None
|
|
|
|
|
|
def find_module_path_and_all_py3(
    inspect: ModuleInspect, module: str, verbose: bool
) -> tuple[str | None, list[str] | None] | None:
    """Find module and determine __all__ for a Python 3 module.

    Return None if the module is a C or pyc-only module.
    Return (module_path, __all__) if it is a Python module.
    Raise CantImport if import failed.
    """
    if module in NOT_IMPORTABLE_MODULES:
        raise CantImport(module, "")

    # TODO: Support custom interpreters.
    if verbose:
        print(f"Trying to import {module!r} for runtime introspection")
    try:
        properties = inspect.get_package_properties(module)
    except InspectError as err:
        # Importing failed; fall back to locating the source via sys.path.
        fallback = find_module_path_using_sys_path(module, sys.path)
        if fallback is None:
            raise CantImport(module, str(err)) from err
        return fallback, None
    if properties.is_c_module:
        return None
    return properties.file, properties.all
|
|
|
|
|
|
@contextmanager
def generate_guarded(
    mod: str, target: str, ignore_errors: bool = True, verbose: bool = False
) -> Iterator[None]:
    """Ignore or report errors during stub generation.

    Optionally report success.
    """
    if verbose:
        print(f"Processing {mod}")
    try:
        yield
    except Exception as exc:
        if not ignore_errors:
            raise exc
        # --ignore-errors was passed
        print("Stub generation failed for", mod, file=sys.stderr)
    else:
        if verbose:
            print(f"Created {target}")
|
|
|
|
|
|
def report_missing(mod: str, message: str | None = "", traceback: str = "") -> None:
    """Print a notice that *mod* failed to import and is being skipped.

    *message*, when non-empty, is appended as the error detail.  Previously a
    ``None`` message (allowed by the annotation) was interpolated verbatim,
    printing a literal "skippingNone"; now falsy messages are uniformly
    treated as "no detail".
    The *traceback* argument is accepted for interface compatibility but is
    not currently used.
    """
    if message:
        message = " with error: " + message
    else:
        # Normalize None (and "") so we never print a literal "None".
        message = ""
    print(f"{mod}: Failed to import, skipping{message}")
|
|
|
|
|
|
def fail_missing(mod: str, reason: ModuleNotFoundReason) -> None:
    """Abort stub generation with an explanation of why *mod* wasn't found."""
    if reason is ModuleNotFoundReason.NOT_FOUND:
        hint = "(consider using --search-path)"
    elif reason is ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS:
        hint = "(module likely exists, but is not PEP 561 compatible)"
    else:
        hint = f"(unknown reason '{reason}')"
    raise SystemExit(f"Can't find module '{mod}' {hint}")
|
|
|
|
|
|
@overload
def remove_misplaced_type_comments(source: bytes) -> bytes: ...


@overload
def remove_misplaced_type_comments(source: str) -> str: ...


def remove_misplaced_type_comments(source: str | bytes) -> str | bytes:
    """Remove comments from source that could be understood as misplaced type comments.

    Normal comments may look like misplaced type comments, and since they cause blocking
    parse errors, we want to avoid them.
    """
    is_bytes = isinstance(source, bytes)
    # latin1 gives a 1-1 character code mapping, so decoding is roundtrippable.
    text = source.decode("latin1") if is_bytes else source

    # Remove something that looks like a variable type comment but that's by itself
    # on a line, as it will often generate a parse error (unless it's # type: ignore).
    text = re.sub(r'^[ \t]*# +type: +["\'a-zA-Z_].*$', "", text, flags=re.MULTILINE)

    # Remove something that looks like a function type comment after docstring,
    # which will result in a parse error.
    text = re.sub(r'""" *\n[ \t\n]*# +type: +\(.*$', '"""\n', text, flags=re.MULTILINE)
    text = re.sub(r"''' *\n[ \t\n]*# +type: +\(.*$", "'''\n", text, flags=re.MULTILINE)

    # Remove something that looks like a badly formed function type comment.
    text = re.sub(r"^[ \t]*# +type: +\([^()]+(\)[ \t]*)?$", "", text, flags=re.MULTILINE)

    return text.encode("latin1") if is_bytes else text
|
|
|
|
|
|
def common_dir_prefix(paths: list[str]) -> str:
    """Return the longest common directory prefix of *paths* ("." if none).

    Each path's dirname is walked upwards until it is a prefix of the running
    common prefix.  The explicit empty-dirname check below fixes a potential
    infinite loop: for relative paths with no common component (e.g. "a/x.py"
    vs "b/y.py") os.path.dirname eventually yields "" forever, and
    ("a" + os.sep).startswith("" + os.sep) is False, so the original loop
    never terminated.
    """
    if not paths:
        return "."
    cur = os.path.dirname(os.path.normpath(paths[0]))
    for path in paths[1:]:
        while True:
            path = os.path.dirname(os.path.normpath(path))
            # An empty dirname is a prefix of everything; testing it explicitly
            # also guarantees termination for paths with no shared component.
            if not path or (cur + os.sep).startswith(path + os.sep):
                cur = path
                break
    return cur or "."
|
|
|
|
|
|
class AnnotationPrinter(TypeStrVisitor):
    """Visitor used to print existing annotations in a file.

    The main difference from TypeStrVisitor is a better treatment of
    unbound types.

    Notes:
    * This visitor doesn't add imports necessary for annotations, this is done separately
      by ImportTracker.
    * It can print all kinds of types, but the generated strings may not be valid (notably
      callable types) since it prints the same string that reveal_type() does.
    * For Instance types it prints the fully qualified names.
    """

    # TODO: Generate valid string representation for callable types.
    # TODO: Use short names for Instances.
    def __init__(
        self,
        stubgen: BaseStubGenerator,
        known_modules: list[str] | None = None,
        local_modules: list[str] | None = None,
    ) -> None:
        super().__init__(options=mypy.options.Options())
        self.stubgen = stubgen
        # Modules being processed in this run; used to shorten qualified names.
        self.known_modules = known_modules
        # Modules whose members may be referenced without qualification.
        self.local_modules = local_modules or ["builtins"]

    def visit_any(self, t: AnyType) -> str:
        """Print Any and mark its (possibly aliased) name as required."""
        s = super().visit_any(t)
        self.stubgen.import_tracker.require_name(s)
        return s

    def visit_unbound_type(self, t: UnboundType) -> str:
        """Print an unbound type, resolving aliases and recording needed imports.

        Union/Optional are rewritten to PEP 604 `X | Y` form, typing aliases are
        replaced with their builtin equivalents, and names from known/local
        modules are shortened where possible.
        """
        s = t.name
        fullname = self.stubgen.resolve_name(s)
        if fullname == "typing.Union":
            return " | ".join([item.accept(self) for item in t.args])
        if fullname == "typing.Optional":
            if len(t.args) == 1:
                return f"{t.args[0].accept(self)} | None"
            # Malformed Optional (wrong arg count): fall back to Incomplete.
            return self.stubgen.add_name("_typeshed.Incomplete")
        if fullname in TYPING_BUILTIN_REPLACEMENTS:
            s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
        if self.known_modules is not None and "." in s:
            # see if this object is from any of the modules that we're currently processing.
            # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
            for module_name in self.local_modules + sorted(self.known_modules, reverse=True):
                if s.startswith(module_name + "."):
                    if module_name in self.local_modules:
                        # Strip the local module prefix; no import needed.
                        s = s[len(module_name) + 1 :]
                    arg_module = module_name
                    break
            else:
                arg_module = s[: s.rindex(".")]
            if arg_module not in self.local_modules:
                self.stubgen.import_tracker.add_import(arg_module, require=True)
        elif s == "NoneType":
            # when called without analysis all types are unbound, so this won't hit
            # visit_none_type().
            s = "None"
        else:
            self.stubgen.import_tracker.require_name(s)
        if t.args:
            s += f"[{self.args_str(t.args)}]"
        elif t.empty_tuple_index:
            s += "[()]"
        return s

    def visit_none_type(self, t: NoneType) -> str:
        """Print None (not NoneType) as is conventional in stubs."""
        return "None"

    def visit_type_list(self, t: TypeList) -> str:
        """Print a bracketed type list, e.g. for Callable argument lists."""
        return f"[{self.list_str(t.items)}]"

    def visit_union_type(self, t: UnionType) -> str:
        """Print unions in PEP 604 `X | Y` form."""
        return " | ".join([item.accept(self) for item in t.items])

    def visit_unpack_type(self, t: UnpackType) -> str:
        """Print `*Ts` syntax where the target Python version supports it."""
        if self.options.python_version >= (3, 11):
            return f"*{t.type.accept(self)}"
        return super().visit_unpack_type(t)

    def args_str(self, args: Iterable[Type]) -> str:
        """Convert an array of arguments to strings and join the results with commas.

        The main difference from list_str is the preservation of quotes for string
        arguments
        """
        types = ["builtins.bytes", "builtins.str"]
        res = []
        for arg in args:
            arg_str = arg.accept(self)
            if isinstance(arg, UnboundType) and arg.original_str_fallback in types:
                # Re-quote string literal types such as Literal arguments.
                res.append(f"'{arg_str}'")
            else:
                res.append(arg_str)
        return ", ".join(res)
|
|
|
|
|
|
class ClassInfo:
    """Plain data holder describing a class whose stubs are being generated."""

    def __init__(
        self,
        name: str,
        self_var: str,
        docstring: str | None = None,
        cls: type | None = None,
        parent: ClassInfo | None = None,
    ) -> None:
        # Class name and the conventional name of the self argument ("self"/"cls").
        self.name = name
        self.self_var = self_var
        # Optional extras: runtime docstring, the live class object (for C
        # extension introspection), and the enclosing class for nested classes.
        self.docstring = docstring
        self.cls = cls
        self.parent = parent
|
|
|
|
|
|
class FunctionContext:
    """Context for a function or method whose stub is being generated."""

    def __init__(
        self,
        module_name: str,
        name: str,
        docstring: str | None = None,
        is_abstract: bool = False,
        class_info: ClassInfo | None = None,
    ) -> None:
        self.module_name = module_name
        self.name = name
        self.docstring = docstring
        self.is_abstract = is_abstract
        # Enclosing class, or None for module-level functions.
        self.class_info = class_info
        # Lazily computed cache for the `fullname` property.
        self._fullname: str | None = None

    @property
    def fullname(self) -> str:
        """Fully qualified dotted name (module[.classes...].name), cached."""
        if self._fullname is None:
            chain: list[str] = []
            info = self.class_info
            # Collect enclosing class names innermost-first, then reverse.
            while info is not None:
                chain.append(info.name)
                info = info.parent
            namespace = ".".join(reversed(chain))
            if namespace:
                self._fullname = f"{self.module_name}.{namespace}.{self.name}"
            else:
                self._fullname = f"{self.module_name}.{self.name}"
        return self._fullname
|
|
|
|
|
|
def infer_method_ret_type(name: str) -> str | None:
    """Infer return types for known special methods"""
    if not (name.startswith("__") and name.endswith("__")):
        return None
    short = name[2:-2]
    if short in ("float", "bool", "bytes", "int", "complex", "str"):
        # Conversion dunders return the builtin type of the same name.
        return short
    # Note: __eq__ and co may return arbitrary types, but bool is good enough for stubgen.
    if short in ("eq", "ne", "lt", "le", "gt", "ge", "contains"):
        return "bool"
    if short in ("len", "length_hint", "index", "hash", "sizeof", "trunc", "floor", "ceil"):
        return "int"
    if short in ("format", "repr"):
        return "str"
    if short in ("init", "setitem", "del", "delitem"):
        return "None"
    return None
|
|
|
|
|
|
def infer_method_arg_types(
    name: str, self_var: str = "self", arg_names: list[str] | None = None
) -> list[ArgSig] | None:
    """Infer argument types for known special methods"""
    if not (name.startswith("__") and name.endswith("__")):
        return None
    # Drop an explicit leading "self"; it is re-added uniformly below.
    if arg_names and arg_names[0] == "self":
        arg_names = arg_names[1:]

    args: list[ArgSig] | None = None
    if name[2:-2] == "exit":
        if arg_names is None:
            arg_names = ["type", "value", "traceback"]
        if len(arg_names) == 3:
            exc_types = [
                "type[BaseException] | None",
                "BaseException | None",
                "types.TracebackType | None",
            ]
            args = [
                ArgSig(name=arg_name, type=arg_type)
                for arg_name, arg_type in zip(arg_names, exc_types)
            ]
    if args is not None:
        return [ArgSig(name=self_var)] + args
    return None
|
|
|
|
|
|
@mypyc_attr(allow_interpreted_subclasses=True)
class SignatureGenerator:
    """Abstract base class for extracting a list of FunctionSigs for each function."""

    def remove_self_type(
        self, inferred: list[FunctionSig] | None, self_var: str
    ) -> list[FunctionSig] | None:
        """Remove type annotation from self/cls argument"""
        for sig in inferred or []:
            if sig.args and sig.args[0].name == self_var:
                sig.args[0].type = None
        return inferred

    @abstractmethod
    def get_function_sig(
        self, default_sig: FunctionSig, ctx: FunctionContext
    ) -> list[FunctionSig] | None:
        """Return a list of signatures for the given function.

        If no signature can be found, return None. If all of the registered SignatureGenerators
        for the stub generator return None, then the default_sig will be used.
        """

    @abstractmethod
    def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> str | None:
        """Return the type of the given property"""
|
|
|
|
|
|
class ImportTracker:
    """Record necessary imports during stub generation."""

    def __init__(self) -> None:
        # module_for['foo'] has the module name where 'foo' was imported from, or None if
        # 'foo' is a module imported directly;
        # direct_imports['foo'] is the module path used when the name 'foo' was added to the
        # namespace.
        # reverse_alias['foo'] is the name that 'foo' had originally when imported with an
        # alias; examples
        #     'from pkg import mod'      ==> module_for['mod'] == 'pkg'
        #     'from pkg import mod as m' ==> module_for['m'] == 'pkg'
        #                                ==> reverse_alias['m'] == 'mod'
        #     'import pkg.mod as m'      ==> module_for['m'] == None
        #                                ==> reverse_alias['m'] == 'pkg.mod'
        #     'import pkg.mod'           ==> module_for['pkg'] == None
        #                                ==> module_for['pkg.mod'] == None
        #                                ==> direct_imports['pkg'] == 'pkg.mod'
        #                                ==> direct_imports['pkg.mod'] == 'pkg.mod'
        self.module_for: dict[str, str | None] = {}
        self.direct_imports: dict[str, str] = {}
        self.reverse_alias: dict[str, str] = {}

        # required_names is the set of names that are actually used in a type annotation
        self.required_names: set[str] = set()

        # Names that should be reexported if they come from another module
        self.reexports: set[str] = set()

    def add_import_from(
        self, module: str, names: list[tuple[str, str | None]], require: bool = False
    ) -> None:
        """Record 'from module import name [as alias]' for each (name, alias) pair.

        When require=True the imported names are immediately marked as used.
        """
        for name, alias in names:
            if alias:
                # 'from {module} import {name} as {alias}'
                self.module_for[alias] = module
                self.reverse_alias[alias] = name
            else:
                # 'from {module} import {name}'
                self.module_for[name] = module
                self.reverse_alias.pop(name, None)
            if require:
                self.require_name(alias or name)
            # A from-import shadows any earlier direct 'import name'.
            self.direct_imports.pop(alias or name, None)

    def add_import(self, module: str, alias: str | None = None, require: bool = False) -> None:
        """Record 'import module [as alias]'.

        When require=True the module (or alias) is immediately marked as used.
        """
        if alias:
            # 'import {module} as {alias}'
            assert "." not in alias  # invalid syntax
            self.module_for[alias] = None
            self.reverse_alias[alias] = module
            if require:
                self.required_names.add(alias)
        else:
            # 'import {module}'
            name = module
            if require:
                self.required_names.add(name)
            # add module and its parent packages
            while name:
                self.module_for[name] = None
                self.direct_imports[name] = module
                self.reverse_alias.pop(name, None)
                name = name.rpartition(".")[0]

    def require_name(self, name: str) -> None:
        """Mark a name as used, reducing dotted names to the imported root."""
        while name not in self.direct_imports and "." in name:
            name = name.rsplit(".", 1)[0]
        self.required_names.add(name)

    def reexport(self, name: str) -> None:
        """Mark a given non qualified name as needed in __all__.

        This means that in case it comes from a module, it should be
        imported with an alias even if the alias is the same as the name.
        """
        self.require_name(name)
        self.reexports.add(name)

    def import_lines(self) -> list[str]:
        """The list of required import lines (as strings with python code).

        In order for a module be included in this output, an identifier must be both
        'required' via require_name() and 'imported' via add_import_from()
        or add_import()
        """
        result = []

        # To summarize multiple names imported from a same module, we collect those
        # in the `module_map` dictionary, mapping a module path to the list of names that should
        # be imported from it. the names can also be alias in the form 'original as alias'
        module_map: Mapping[str, list[str]] = defaultdict(list)

        # Sort on the original (pre-alias) name so output order is stable.
        for name in sorted(
            self.required_names,
            key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
        ):
            # If we haven't seen this name in an import statement, ignore it
            if name not in self.module_for:
                continue

            m = self.module_for[name]
            if m is not None:
                # This name was found in a from ... import ...
                # Collect the name in the module_map
                if name in self.reverse_alias:
                    name = f"{self.reverse_alias[name]} as {name}"
                elif name in self.reexports:
                    # Re-exported names use the explicit 'name as name' form.
                    name = f"{name} as {name}"
                module_map[m].append(name)
            else:
                # This name was found in an import ...
                # We can already generate the import line
                if name in self.reverse_alias:
                    source = self.reverse_alias[name]
                    result.append(f"import {source} as {name}\n")
                elif name in self.reexports:
                    assert "." not in name  # Because reexports only has nonqualified names
                    result.append(f"import {name} as {name}\n")
                else:
                    result.append(f"import {name}\n")

        # Now generate all the from ... import ... lines collected in module_map
        for module, names in sorted(module_map.items()):
            result.append(f"from {module} import {', '.join(sorted(names))}\n")
        return result
|
|
|
|
|
|
@mypyc_attr(allow_interpreted_subclasses=True)
class BaseStubGenerator:
    """Shared state and helpers for stub generators.

    Tracks imports, the current output buffer, indentation, and the names
    defined/exported by the module being stubbed.
    """

    # These names should be omitted from generated stubs.
    IGNORED_DUNDERS: Final = {
        "__all__",
        "__author__",
        "__about__",
        "__copyright__",
        "__email__",
        "__license__",
        "__summary__",
        "__title__",
        "__uri__",
        "__str__",
        "__repr__",
        "__getstate__",
        "__setstate__",
        "__slots__",
        "__builtins__",
        "__cached__",
        "__file__",
        "__name__",
        "__package__",
        "__path__",
        "__spec__",
        "__loader__",
    }
    TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions")
    # Special-cased names that are implicitly exported from the stub (from m import y as y).
    EXTRA_EXPORTED: Final = {
        "pyasn1_modules.rfc2437.univ",
        "pyasn1_modules.rfc2459.char",
        "pyasn1_modules.rfc2459.univ",
    }

    def __init__(
        self,
        _all_: list[str] | None = None,
        include_private: bool = False,
        export_less: bool = False,
        include_docstrings: bool = False,
    ) -> None:
        # Best known value of __all__.
        self._all_ = _all_
        self._include_private = include_private
        self._include_docstrings = include_docstrings
        # Disable implicit exports of package-internal imports?
        self.export_less = export_less
        # Verbatim import lines added via add_import_line().
        self._import_lines: list[str] = []
        # Accumulated stub text fragments (joined in output()).
        self._output: list[str] = []
        # Current indent level (indent is hardcoded to 4 spaces).
        self._indent = ""
        # Names recorded at module top level (see record_name()).
        self._toplevel_names: list[str] = []
        self.import_tracker = ImportTracker()
        # Top-level members
        self.defined_names: set[str] = set()
        self.sig_generators = self.get_sig_generators()
        # populated by visit_mypy_file
        self.module_name: str = ""
        # These are "soft" imports for objects which might appear in annotations but not have
        # a corresponding import statement.
        self.known_imports = {
            "_typeshed": ["Incomplete"],
            "typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
            "collections.abc": ["Generator"],
            "typing_extensions": ["ParamSpec", "TypeVarTuple"],
        }

    def get_sig_generators(self) -> list[SignatureGenerator]:
        """Return signature generators to consult; subclasses may override."""
        return []

    def resolve_name(self, name: str) -> str:
        """Return the full name resolving imports and import aliases."""
        if "." not in name:
            real_module = self.import_tracker.module_for.get(name)
            real_short = self.import_tracker.reverse_alias.get(name, name)
            if real_module is None and real_short not in self.defined_names:
                real_module = "builtins"  # not imported and not defined, must be a builtin
        else:
            name_module, real_short = name.split(".", 1)
            real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
        resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
        return resolved_name

    def add_name(self, fullname: str, require: bool = True) -> str:
        """Add a name to be imported and return the name reference.

        The import will be internal to the stub (i.e don't reexport).
        """
        module, name = fullname.rsplit(".", 1)
        # Prefix with underscores until the alias doesn't clash with a defined name.
        alias = "_" + name if name in self.defined_names else None
        while alias in self.defined_names:
            alias = "_" + alias
        if module != "builtins" or alias:  # don't import from builtins unless needed
            self.import_tracker.add_import_from(module, [(name, alias)], require=require)
        return alias or name

    def add_import_line(self, line: str) -> None:
        """Add a line of text to the import section, unless it's already there."""
        if line not in self._import_lines:
            self._import_lines.append(line)

    def get_imports(self) -> str:
        """Return the import statements for the stub."""
        imports = ""
        if self._import_lines:
            imports += "".join(self._import_lines)
        imports += "".join(self.import_tracker.import_lines())
        return imports

    def output(self) -> str:
        """Return the text for the stub."""
        pieces: list[str] = []
        if imports := self.get_imports():
            pieces.append(imports)
        if dunder_all := self.get_dunder_all():
            pieces.append(dunder_all)
        if self._output:
            pieces.append("".join(self._output))
        return "\n".join(pieces)

    def get_dunder_all(self) -> str:
        """Return the __all__ list for the stub."""
        if self._all_:
            # Note we emit all names in the runtime __all__ here, even if they
            # don't actually exist. If that happens, the runtime has a bug, and
            # it's not obvious what the correct behavior should be. We choose
            # to reflect the runtime __all__ as closely as possible.
            return f"__all__ = {self._all_!r}\n"
        return ""

    def add(self, string: str) -> None:
        """Add text to generated stub."""
        self._output.append(string)

    def is_top_level(self) -> bool:
        """Are we processing the top level of a file?"""
        return self._indent == ""

    def indent(self) -> None:
        """Add one level of indentation."""
        self._indent += "    "

    def dedent(self) -> None:
        """Remove one level of indentation."""
        self._indent = self._indent[:-4]

    def record_name(self, name: str) -> None:
        """Mark a name as defined.

        This only does anything if at the top level of a module.
        """
        if self.is_top_level():
            self._toplevel_names.append(name)

    def is_recorded_name(self, name: str) -> bool:
        """Has this name been recorded previously?"""
        return self.is_top_level() and name in self._toplevel_names

    def set_defined_names(self, defined_names: set[str]) -> None:
        """Record the module's defined names and set up required/soft imports."""
        self.defined_names = defined_names
        # Names in __all__ are required
        for name in self._all_ or ():
            self.import_tracker.reexport(name)

        for pkg, imports in self.known_imports.items():
            for t in imports:
                # require=False means that the import won't be added unless require_name() is called
                # for the object during generation.
                self.add_name(f"{pkg}.{t}", require=False)

    def check_undefined_names(self) -> None:
        """Append a comment listing names in __all__ that were never defined."""
        undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names]
        if undefined_names:
            if self._output:
                self.add("\n")
            self.add("# Names in __all__ with no definition:\n")
            for name in sorted(undefined_names):
                self.add(f"#   {name}\n")

    def get_signatures(
        self,
        default_signature: FunctionSig,
        sig_generators: list[SignatureGenerator],
        func_ctx: FunctionContext,
    ) -> list[FunctionSig]:
        """Return the first non-empty inferred signature list, else the default."""
        for sig_gen in sig_generators:
            inferred = sig_gen.get_function_sig(default_signature, func_ctx)
            if inferred:
                return inferred

        return [default_signature]

    def get_property_type(
        self,
        default_type: str | None,
        sig_generators: list[SignatureGenerator],
        func_ctx: FunctionContext,
    ) -> str | None:
        """Return the first inferred property type, else the default."""
        for sig_gen in sig_generators:
            inferred = sig_gen.get_property_type(default_type, func_ctx)
            if inferred:
                return inferred

        return default_type

    def format_func_def(
        self,
        sigs: list[FunctionSig],
        is_coroutine: bool = False,
        decorators: list[str] | None = None,
        docstring: str | None = None,
    ) -> list[str]:
        """Render one formatted 'def' line per signature, with decorators."""
        lines: list[str] = []
        if decorators is None:
            decorators = []

        for signature in sigs:
            # dump decorators, just before "def ..."
            for deco in decorators:
                lines.append(f"{self._indent}{deco}")

            lines.append(
                signature.format_sig(
                    indent=self._indent,
                    is_async=is_coroutine,
                    docstring=docstring,
                    include_docstrings=self._include_docstrings,
                )
            )
        return lines

    def format_type_args(self, o: TypeAliasStmt | FuncDef | ClassDef) -> str:
        """Render PEP 695 type parameter lists, e.g. '[T, *Ts, **P]'."""
        if not o.type_args:
            return ""
        p = AnnotationPrinter(self)
        type_args_list: list[str] = []
        for type_arg in o.type_args:
            if type_arg.kind == PARAM_SPEC_KIND:
                prefix = "**"
            elif type_arg.kind == TYPE_VAR_TUPLE_KIND:
                prefix = "*"
            else:
                prefix = ""
            if type_arg.upper_bound:
                bound_or_values = f": {type_arg.upper_bound.accept(p)}"
            elif type_arg.values:
                bound_or_values = f": ({', '.join(v.accept(p) for v in type_arg.values)})"
            else:
                bound_or_values = ""
            if type_arg.default:
                default = f" = {type_arg.default.accept(p)}"
            else:
                default = ""
            type_args_list.append(f"{prefix}{type_arg.name}{bound_or_values}{default}")
        return "[" + ", ".join(type_args_list) + "]"

    def print_annotation(
        self,
        t: Type,
        known_modules: list[str] | None = None,
        local_modules: list[str] | None = None,
    ) -> str:
        """Render a type annotation using AnnotationPrinter."""
        printer = AnnotationPrinter(self, known_modules, local_modules)
        return t.accept(printer)

    def is_not_in_all(self, name: str) -> bool:
        """True if a public top-level name is absent from a known __all__."""
        if self.is_private_name(name):
            return False
        if self._all_:
            return self.is_top_level() and name not in self._all_
        return False

    def is_private_name(self, name: str, fullname: str | None = None) -> bool:
        """Decide whether a name should be omitted from the stub as private."""
        if "__mypy-" in name:
            return True  # Never include mypy generated symbols
        if self._include_private:
            return False
        if fullname in self.EXTRA_EXPORTED:
            return False
        if name == "_":
            return False
        if not name.startswith("_"):
            return False
        if self._all_ and name in self._all_:
            return False
        if name.startswith("__") and name.endswith("__"):
            return name in self.IGNORED_DUNDERS
        return True

    def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
        """Decide whether an imported name should be re-exported from the stub."""
        if (
            not name_is_alias
            and self.module_name
            and (self.module_name + "." + name) in self.EXTRA_EXPORTED
        ):
            # Special case certain names that should be exported, against our general rules.
            return True
        if name_is_alias:
            return False
        if self.export_less:
            return False
        if not self.module_name:
            return False
        is_private = self.is_private_name(name, full_module + "." + name)
        if is_private:
            return False
        top_level = full_module.split(".")[0]
        self_top_level = self.module_name.split(".", 1)[0]
        if top_level not in (self_top_level, "_" + self_top_level):
            # Export imports from the same package, since we can't reliably tell whether they
            # are part of the public API.
            return False
        if self._all_:
            return name in self._all_
        return True
|