72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
"""Test cases for annotating source code to highlight inefficiencies."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os.path
|
|
|
|
from mypy.errors import CompileError
|
|
from mypy.test.config import test_temp_dir
|
|
from mypy.test.data import DataDrivenTestCase
|
|
from mypyc.annotate import generate_annotations, get_max_prio
|
|
from mypyc.ir.pprint import format_func
|
|
from mypyc.test.testutil import (
|
|
ICODE_GEN_BUILTINS,
|
|
MypycDataSuite,
|
|
assert_test_output,
|
|
build_ir_for_single_file2,
|
|
infer_ir_build_options_from_test_name,
|
|
remove_comment_lines,
|
|
use_custom_builtins,
|
|
)
|
|
|
|
# Data-driven test files exercised by the TestReport suite below.
files = [
    "annotate-basic.test",
]
|
|
|
|
|
|
class TestReport(MypycDataSuite):
    """Data-driven tests for mypyc's source annotation report.

    Each test case is compiled to IR, annotations are generated for the
    resulting module, and the messages that ``get_max_prio`` selects for each
    source line are compared against the expected output. Expected
    annotations come from the test case's output section plus any inline
    ``# A: <message>`` comments in the test input.
    """

    files = files
    base_path = test_temp_dir
    optional_out = True

    def run_case(self, testcase: DataDrivenTestCase) -> None:
        """Run one annotation test case and check the reported annotations.

        If compilation fails with a CompileError, the error messages are
        compared against the expected output instead of annotations. If the
        comparison fails (or anything else is raised while checking), the
        generated IR for every function except ``__top_level__`` is printed
        to aid debugging, and the exception is re-raised.
        """
        options = infer_ir_build_options_from_test_name(testcase.name)
        if options is None:
            # Skipped test case (no build options for this test name).
            return
        with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase):
            expected_output = remove_comment_lines(testcase.output)

            # Parse "# A: <message>" comments into expected output lines.
            # enumerate() is 0-based, but reported lines are 1-based.
            for i, line in enumerate(testcase.input):
                if "# A:" in line:
                    msg = line.rpartition("# A:")[2].strip()
                    expected_output.append(f"main:{i + 1}: {msg}")

            # Stays None when compilation fails, so the failure handler
            # below knows there is no IR to print.
            ir = None
            try:
                ir, tree, type_map, mapper = build_ir_for_single_file2(testcase.input, options)
            except CompileError as e:
                actual = e.messages
            else:
                annotations = generate_annotations("native.py", tree, ir, type_map, mapper)
                actual = []
                # Emit one output line per annotated source line, in line order.
                for line_num, line_anns in sorted(
                    annotations.annotations.items(), key=lambda it: it[0]
                ):
                    anns = get_max_prio(line_anns)
                    str_anns = [a.message for a in anns]
                    s = " ".join(str_anns)
                    actual.append(f"main:{line_num}: {s}")

            try:
                assert_test_output(testcase, actual, "Invalid source code output", expected_output)
            except BaseException:
                # Deliberately broad: dump the IR on any failure, then re-raise
                # so the original error is still reported.
                if ir is not None:
                    print("Generated IR:\n")
                    for fn in ir.functions:
                        if fn.name == "__top_level__":
                            continue
                        for fn_line in format_func(fn):
                            print(fn_line)
                raise
|