114 lines
4.7 KiB
Python
114 lines
4.7 KiB
Python
"""Convert tagged int primitive ops to lower-level ops."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import NamedTuple
|
|
|
|
from mypyc.ir.ops import Assign, BasicBlock, Branch, ComparisonOp, Register, Value
|
|
from mypyc.ir.rtypes import bool_rprimitive, is_short_int_rprimitive
|
|
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
|
|
from mypyc.lower.registry import lower_primitive_op
|
|
from mypyc.primitives.int_ops import int_equal_, int_less_than_
|
|
from mypyc.primitives.registry import CFunctionDescription
|
|
|
|
|
|
# Description for building int comparison ops
|
|
#
|
|
# Fields:
|
|
# binary_op_variant: identify which IntOp to use when operands are short integers
|
|
# c_func_description: the C function to call when operands are tagged integers
|
|
# c_func_negated: whether to negate the C function call's result
|
|
# c_func_swap_operands: whether to swap lhs and rhs when call the function
|
|
class IntComparisonOpDescription(NamedTuple):
|
|
binary_op_variant: int
|
|
c_func_description: CFunctionDescription
|
|
c_func_negated: bool
|
|
c_func_swap_operands: bool
|
|
|
|
|
|
# Provide mapping from textual op to short int's op variant and boxed int's description.
|
|
# Note that these are not complete implementations and require extra IR.
|
|
int_comparison_op_mapping: dict[str, IntComparisonOpDescription] = {
|
|
"==": IntComparisonOpDescription(ComparisonOp.EQ, int_equal_, False, False),
|
|
"!=": IntComparisonOpDescription(ComparisonOp.NEQ, int_equal_, True, False),
|
|
"<": IntComparisonOpDescription(ComparisonOp.SLT, int_less_than_, False, False),
|
|
"<=": IntComparisonOpDescription(ComparisonOp.SLE, int_less_than_, True, True),
|
|
">": IntComparisonOpDescription(ComparisonOp.SGT, int_less_than_, False, True),
|
|
">=": IntComparisonOpDescription(ComparisonOp.SGE, int_less_than_, True, False),
|
|
}
|
|
|
|
|
|
def compare_tagged(self: LowLevelIRBuilder, lhs: Value, rhs: Value, op: str, line: int) -> Value:
|
|
"""Compare two tagged integers using given operator (value context)."""
|
|
# generate fast binary logic ops on short ints
|
|
if (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) and op in (
|
|
"==",
|
|
"!=",
|
|
):
|
|
quick = True
|
|
else:
|
|
quick = is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)
|
|
if quick:
|
|
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
|
|
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
|
|
result = Register(bool_rprimitive)
|
|
short_int_block, int_block, out = BasicBlock(), BasicBlock(), BasicBlock()
|
|
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
|
|
if op in ("==", "!="):
|
|
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
|
|
else:
|
|
# for non-equality logical ops (less/greater than, etc.), need to check both sides
|
|
short_lhs = BasicBlock()
|
|
self.add(Branch(check_lhs, int_block, short_lhs, Branch.BOOL))
|
|
self.activate_block(short_lhs)
|
|
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
|
|
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
|
|
self.activate_block(int_block)
|
|
if swap_op:
|
|
args = [rhs, lhs]
|
|
else:
|
|
args = [lhs, rhs]
|
|
call = self.call_c(c_func_desc, args, line)
|
|
if negate_result:
|
|
# TODO: introduce UnaryIntOp?
|
|
call_result = self.unary_op(call, "not", line)
|
|
else:
|
|
call_result = call
|
|
self.add(Assign(result, call_result, line))
|
|
self.goto(out)
|
|
self.activate_block(short_int_block)
|
|
eq = self.comparison_op(lhs, rhs, op_type, line)
|
|
self.add(Assign(result, eq, line))
|
|
self.goto_and_activate(out)
|
|
return result
|
|
|
|
|
|
@lower_primitive_op("int_eq")
|
|
def lower_int_eq(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], "==", line)
|
|
|
|
|
|
@lower_primitive_op("int_ne")
|
|
def lower_int_ne(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], "!=", line)
|
|
|
|
|
|
@lower_primitive_op("int_lt")
|
|
def lower_int_lt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], "<", line)
|
|
|
|
|
|
@lower_primitive_op("int_le")
|
|
def lower_int_le(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], "<=", line)
|
|
|
|
|
|
@lower_primitive_op("int_gt")
|
|
def lower_int_gt(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], ">", line)
|
|
|
|
|
|
@lower_primitive_op("int_ge")
|
|
def lower_int_ge(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
|
|
return compare_tagged(builder, args[0], args[1], ">=", line)
|