# Code-viewer listing metadata: 396 lines, 15 KiB, Python.
"""Plugin for supporting the functools standard library module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Final, NamedTuple
|
|
|
|
import mypy.checker
|
|
import mypy.plugin
|
|
import mypy.semanal
|
|
from mypy.argmap import map_actuals_to_formals
|
|
from mypy.erasetype import erase_typevars
|
|
from mypy.nodes import (
|
|
ARG_POS,
|
|
ARG_STAR2,
|
|
SYMBOL_FUNCBASE_TYPES,
|
|
ArgKind,
|
|
Argument,
|
|
CallExpr,
|
|
NameExpr,
|
|
Var,
|
|
)
|
|
from mypy.plugins.common import add_method_to_class
|
|
from mypy.typeops import get_all_type_vars
|
|
from mypy.types import (
|
|
AnyType,
|
|
CallableType,
|
|
Instance,
|
|
Overloaded,
|
|
ParamSpecFlavor,
|
|
ParamSpecType,
|
|
Type,
|
|
TypeOfAny,
|
|
TypeVarType,
|
|
UnboundType,
|
|
UnionType,
|
|
get_proper_type,
|
|
)
|
|
|
|
# Fully qualified name(s) of the functools.total_ordering decorator.
# NOTE(review): not referenced elsewhere in this view; presumably consumed by
# the plugin registry to route decorated classes to the callback below.
functools_total_ordering_makers: Final = {"functools.total_ordering"}

# The rich-comparison dunders that the total_ordering support looks for on a
# class and synthesizes when they are missing.
_ORDERING_METHODS: Final = {
    "__lt__",
    "__le__",
    "__gt__",
    "__ge__",
}

# Fully qualified name of functools.partial; used to build ``partial[...]``
# instances and to recognise them again in the ``__call__`` hook.
PARTIAL: Final = "functools.partial"
|
|
|
|
|
|
class _MethodInfo(NamedTuple):
|
|
is_static: bool
|
|
type: CallableType
|
|
|
|
|
|
def functools_total_ordering_maker_callback(
    ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False
) -> bool:
    """Add dunder methods to classes decorated with functools.total_ordering.

    One already-defined ordering method (the "root") is used as a template:
    every ordering method that is missing, or whose signature could not be
    analysed, is added with the root's ``other`` argument type. The return
    type is ``bool`` when the root returns ``bool`` and ``Any`` otherwise.

    Args:
        ctx: Context of the decorated class definition.
        auto_attribs_default: Not used by this function; kept so the
            signature stays compatible with existing callers.

    Returns:
        Always ``True``.
    """
    comparison_methods = _analyze_class(ctx)
    if not comparison_methods:
        ctx.api.fail(
            'No ordering operation defined when using "functools.total_ordering": < > <= >=',
            ctx.reason,
        )
        return True

    # Prefer a method whose signature could be analysed; among those, prefer
    # __lt__ to __le__ to __gt__ to __ge__ (descending lexicographic order).
    # The first key component must rank analysable entries *higher*, otherwise
    # a single unanalysable method would always win and nothing would be added.
    root = max(comparison_methods, key=lambda k: (comparison_methods[k] is not None, k))
    root_method = comparison_methods[root]
    if root_method is None:
        # None of the defined comparison methods can be analysed
        return True

    other_type = _find_other_type(root_method)
    bool_type = ctx.api.named_type("builtins.bool")
    ret_type: Type = bool_type
    if root_method.type.ret_type != bool_type:
        proper_ret_type = get_proper_type(root_method.type.ret_type)
        # An unbound type whose last dotted component is "bool" is still
        # accepted as bool; anything else degrades to Any.
        if not (
            isinstance(proper_ret_type, UnboundType)
            and proper_ret_type.name.split(".")[-1] == "bool"
        ):
            ret_type = AnyType(TypeOfAny.implementation_artifact)

    # Sorted so that methods are added in a deterministic order regardless of
    # set iteration order.
    for additional_op in sorted(_ORDERING_METHODS):
        # Either the method is not implemented
        # or has an unknown signature that we can now extrapolate.
        if comparison_methods.get(additional_op) is None:
            args = [Argument(Var("other", other_type), other_type, None, ARG_POS)]
            add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)

    return True
|
|
|
|
|
|
def _find_other_type(method: _MethodInfo) -> Type:
|
|
"""Find the type of the ``other`` argument in a comparison method."""
|
|
first_arg_pos = 0 if method.is_static else 1
|
|
cur_pos_arg = 0
|
|
other_arg = None
|
|
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
|
|
if arg_kind.is_positional():
|
|
if cur_pos_arg == first_arg_pos:
|
|
other_arg = arg_type
|
|
break
|
|
|
|
cur_pos_arg += 1
|
|
elif arg_kind != ARG_STAR2:
|
|
other_arg = arg_type
|
|
break
|
|
|
|
if other_arg is None:
|
|
return AnyType(TypeOfAny.implementation_artifact)
|
|
|
|
return other_arg
|
|
|
|
|
|
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | None]:
    """Collect the ordering methods defined on the class and its parents.

    The MRO is walked front to back and the first definition of each ordering
    method wins. A name maps to ``None`` when it is defined but its signature
    cannot be analysed as a callable; names that are never defined are absent.
    """
    found: dict[str, _MethodInfo | None] = {}
    # Skip object because total_ordering does not use methods from object
    for klass in ctx.cls.info.mro[:-1]:
        for op_name in _ORDERING_METHODS:
            if op_name in found or op_name not in klass.names:
                continue

            symbol = klass.names[op_name].node
            info: _MethodInfo | None = None
            if isinstance(symbol, SYMBOL_FUNCBASE_TYPES) and isinstance(
                symbol.type, CallableType
            ):
                info = _MethodInfo(symbol.is_static, symbol.type)
            elif isinstance(symbol, Var):
                # A variable counts when its declared type is a callable.
                var_type = get_proper_type(symbol.type)
                if isinstance(var_type, CallableType):
                    info = _MethodInfo(symbol.is_staticmethod, var_type)
            found[op_name] = info

    return found
|
|
|
|
|
|
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
    """Infer a more precise return type for functools.partial"""
    fallback = ctx.default_return_type
    if not isinstance(ctx.api, mypy.checker.TypeChecker):  # use internals
        return fallback

    # Expect exactly the formals fn, *args, **kwargs with one actual for fn.
    if len(ctx.arg_types) != 3 or len(ctx.arg_types[0]) != 1:
        return fallback

    callee = ctx.arg_types[0][0]
    if isinstance(get_proper_type(callee), Overloaded):
        # TODO: handle overloads, just fall back to whatever the non-plugin code does
        return fallback

    return handle_partial_with_callee(ctx, callee=callee)
|
|
|
|
|
|
def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type:
    """Compute the type of ``functools.partial(callee, ...)`` for one callee type.

    The supplied arguments are checked against a copy of the callee in which
    every required parameter is made optional. The signature that remains
    after the application is attached to the resulting ``partial[...]``
    instance under the extra attribute ``"__mypy_partial"``.

    Args:
        ctx: Function context of the ``functools.partial(...)`` call.
        callee: Type of the wrapped callable. A union is handled item by item
            and the results are joined into a union.

    Returns:
        A ``functools.partial`` instance carrying the partially applied
        signature, or ``ctx.default_return_type`` whenever a precise type
        cannot be derived.
    """
    if not isinstance(ctx.api, mypy.checker.TypeChecker):  # use internals
        return ctx.default_return_type

    if isinstance(callee_proper := get_proper_type(callee), UnionType):
        return UnionType.make_union(
            [handle_partial_with_callee(ctx, item) for item in callee_proper.items]
        )

    fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type)
    if fn_type is None:
        return ctx.default_return_type

    # We must normalize from the start to have coherent view together with TypeChecker.
    fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()

    if not fn_type.is_type_obj():
        # We wrap the return type to get use of a possible type context provided by caller.
        # We cannot do this in case of class objects, since otherwise the plugin may get
        # falsely triggered when evaluating the constructed call itself.
        ret_type: Type = ctx.api.named_generic_type(PARTIAL, [fn_type.ret_type])
        wrapped_return = True
    else:
        # Instead, for class objects we ignore any type context to avoid spurious errors,
        # since the type context will be partial[X] etc., not X. The context is
        # cleared right around check_call() below.
        ret_type = fn_type.ret_type
        wrapped_return = False

    # Flatten actual to formal mapping, since this is what check_call() expects.
    actual_args = []
    actual_arg_kinds = []
    actual_arg_names = []
    actual_types = []
    seen_args = set()
    for i, param in enumerate(ctx.args[1:], start=1):
        for j, a in enumerate(param):
            if a in seen_args:
                # Same actual arg can map to multiple formals, but we need to include
                # each one only once.
                continue
            # Here we rely on the fact that expressions are essentially immutable, so
            # they can be compared by identity.
            seen_args.add(a)
            actual_args.append(a)
            actual_arg_kinds.append(ctx.arg_kinds[i][j])
            actual_arg_names.append(ctx.arg_names[i][j])
            actual_types.append(ctx.arg_types[i][j])

    formal_to_actual = map_actuals_to_formals(
        actual_kinds=actual_arg_kinds,
        actual_names=actual_arg_names,
        formal_kinds=fn_type.arg_kinds,
        formal_names=fn_type.arg_names,
        actual_arg_type=lambda i: actual_types[i],
    )

    # We need to remove any type variables that appear only in formals that have
    # no actuals, to avoid eagerly binding them in check_call() below.
    can_infer_ids = set()
    for i, arg_type in enumerate(fn_type.arg_types):
        if not formal_to_actual[i]:
            continue
        can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

    # special_sig="partial" allows omission of args/kwargs typed with ParamSpec
    defaulted = fn_type.copy_modified(
        arg_kinds=[
            (
                ArgKind.ARG_OPT
                if k == ArgKind.ARG_POS
                else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
            )
            for k in fn_type.arg_kinds
        ],
        ret_type=ret_type,
        variables=[
            tv
            for tv in fn_type.variables
            # Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
            if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
        ],
        special_sig="partial",
    )
    if defaulted.line < 0:
        # Make up a line number if we don't have one
        defaulted.set_line(ctx.default_return_type)

    # Create a valid context for various ad-hoc inspections in check_call().
    call_expr = CallExpr(
        callee=ctx.args[0][0],
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        analyzed=ctx.context.analyzed if isinstance(ctx.context, CallExpr) else None,
    )
    call_expr.set_line(ctx.context)

    last_context = ctx.api.type_context[-1]
    if not wrapped_return:
        ctx.api.type_context[-1] = None
    try:
        _, bound = ctx.api.expr_checker.check_call(
            callee=defaulted,
            args=actual_args,
            arg_kinds=actual_arg_kinds,
            arg_names=actual_arg_names,
            context=call_expr,
        )
    finally:
        if not wrapped_return:
            # Restore previously ignored context, even if check_call() raised,
            # so the checker's context stack is never left modified.
            ctx.api.type_context[-1] = last_context

    bound = get_proper_type(bound)
    if not isinstance(bound, CallableType):
        return ctx.default_return_type

    if wrapped_return:
        # Reverse the wrapping we did above.
        ret_type = get_proper_type(bound.ret_type)
        if not isinstance(ret_type, Instance) or ret_type.type.fullname != PARTIAL:
            return ctx.default_return_type
        bound = bound.copy_modified(ret_type=ret_type.args[0])

    partial_kinds = []
    partial_types = []
    partial_names = []
    # TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
    # true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
    same_arity = len(bound.arg_types) == len(fn_type.arg_types)
    # We need to fully apply any positional arguments (they cannot be respecified)
    # However, keyword arguments can be respecified, so just give them a default
    for i, actuals in enumerate(formal_to_actual):
        if same_arity:
            arg_type = bound.arg_types[i]
            if not mypy.checker.is_valid_inferred_type(arg_type, ctx.api.options):
                arg_type = fn_type.arg_types[i]  # bit of a hack
        else:
            arg_type = fn_type.arg_types[i]

        if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
            # Untouched formals and variadic formals keep their original kind.
            partial_kinds.append(fn_type.arg_kinds[i])
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])
        else:
            assert actuals
            if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
                # Don't add params for arguments passed positionally
                continue
            # Add defaulted params for arguments passed via keyword
            kind = actual_arg_kinds[actuals[0]]
            if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
                kind = ArgKind.ARG_NAMED_OPT
            partial_kinds.append(kind)
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])

    ret_type = bound.ret_type
    if not mypy.checker.is_valid_inferred_type(ret_type, ctx.api.options):
        ret_type = fn_type.ret_type  # same kind of hack as above

    partially_applied = fn_type.copy_modified(
        arg_types=partial_types,
        arg_kinds=partial_kinds,
        arg_names=partial_names,
        ret_type=ret_type,
        special_sig="partial",
    )

    # Do not leak typevars from generic functions - they cannot be usable.
    # Keep them in the wrapped callable, but avoid `partial[SomeStrayTypeVar]`
    erased_ret_type = erase_typevars(ret_type, [tv.id for tv in fn_type.variables])

    ret = ctx.api.named_generic_type(PARTIAL, [erased_ret_type])
    ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
    if partially_applied.param_spec():
        assert ret.extra_attrs is not None  # copy_with_extra_attr above ensures this
        # Record whether *args / **kwargs were already supplied here, so the
        # __call__ hook can tell which ParamSpec parts are still owed.
        attrs = ret.extra_attrs.copy()
        if ArgKind.ARG_STAR in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_args_bound")
        if ArgKind.ARG_STAR2 in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
        ret.extra_attrs = attrs
    return ret
|
|
|
|
|
|
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
    """Infer a more precise return type for functools.partial.__call__."""
    receiver = ctx.type
    if (
        not isinstance(ctx.api, mypy.checker.TypeChecker)  # use internals
        or not isinstance(receiver, Instance)
        or receiver.type.fullname != PARTIAL
        or not receiver.extra_attrs
        or "__mypy_partial" not in receiver.extra_attrs.attrs
    ):
        return ctx.default_return_type

    extras = receiver.extra_attrs
    wrapped_callable = get_proper_type(extras.attrs["__mypy_partial"])
    if len(ctx.arg_types) != 2:  # *args, **kwargs
        return ctx.default_return_type

    # Flatten the per-formal actuals into single lists for check_call(),
    # keeping each expression only once even if it maps to several formals.
    flat_args = []
    flat_kinds = []
    flat_names = []
    visited = set()
    for formal_index, actuals in enumerate(ctx.args):
        for actual_index, expr in enumerate(actuals):
            if expr in visited:
                continue
            visited.add(expr)
            flat_args.append(expr)
            flat_kinds.append(ctx.arg_kinds[formal_index][actual_index])
            flat_names.append(ctx.arg_names[formal_index][actual_index])

    expr_checker = ctx.api.expr_checker
    result, _ = expr_checker.check_call(
        callee=wrapped_callable,
        args=flat_args,
        arg_kinds=flat_kinds,
        arg_names=flat_names,
        context=ctx.context,
    )
    if not isinstance(wrapped_callable, CallableType) or wrapped_callable.param_spec() is None:
        return result

    # Flavors of the ParamSpec-typed variables passed directly by name.
    passed_flavors = {
        arg.node.type.flavor
        for arg in flat_args
        if isinstance(arg, NameExpr)
        and isinstance(arg.node, Var)
        and isinstance(arg.node.type, ParamSpecType)
    }
    args_bound = "__mypy_partial_paramspec_args_bound" in extras.immutable
    kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extras.immutable

    # ensure *args: P.args
    args_passed = ParamSpecFlavor.ARGS in passed_flavors
    if not (args_bound or args_passed):
        expr_checker.msg.too_few_arguments(wrapped_callable, ctx.context, flat_names)
    elif args_bound and args_passed:
        expr_checker.msg.too_many_arguments(wrapped_callable, ctx.context)

    # ensure **kwargs: P.kwargs
    kwargs_passed = ParamSpecFlavor.KWARGS in passed_flavors
    if not (kwargs_bound or kwargs_passed):
        expr_checker.msg.too_few_arguments(wrapped_callable, ctx.context, flat_names)

    return result
|