Initial commit: add .gitignore and README
This commit is contained in:
332
tests/test_tools_runner.py
Normal file
332
tests/test_tools_runner.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Tests for tools runner and builtins."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from fusionagi.tools.registry import ToolDef, ToolRegistry
|
||||
from fusionagi.tools.runner import run_tool, validate_args, ToolValidationError
|
||||
from fusionagi.tools.builtins import (
|
||||
make_file_read_tool,
|
||||
make_file_write_tool,
|
||||
make_http_get_tool,
|
||||
_validate_url,
|
||||
SSRFProtectionError,
|
||||
)
|
||||
|
||||
|
||||
class TestToolRunner:
|
||||
"""Test tool runner functionality."""
|
||||
|
||||
def test_run_tool_success(self):
|
||||
"""Test successful tool execution."""
|
||||
def add(a: int, b: int) -> int:
|
||||
return a + b
|
||||
|
||||
tool = ToolDef(
|
||||
name="add",
|
||||
description="Add two numbers",
|
||||
fn=add,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer"},
|
||||
"b": {"type": "integer"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
)
|
||||
|
||||
result, log = run_tool(tool, {"a": 2, "b": 3})
|
||||
|
||||
assert result == 5
|
||||
assert log["result"] == 5
|
||||
assert log["error"] is None
|
||||
|
||||
def test_run_tool_timeout(self):
|
||||
"""Test tool timeout handling."""
|
||||
import time
|
||||
|
||||
def slow_fn() -> str:
|
||||
time.sleep(2)
|
||||
return "done"
|
||||
|
||||
tool = ToolDef(
|
||||
name="slow",
|
||||
description="Slow function",
|
||||
fn=slow_fn,
|
||||
timeout_seconds=0.1,
|
||||
)
|
||||
|
||||
result, log = run_tool(tool, {})
|
||||
|
||||
assert result is None
|
||||
assert "timed out" in log["error"]
|
||||
|
||||
def test_run_tool_exception(self):
|
||||
"""Test tool exception handling."""
|
||||
def failing_fn() -> None:
|
||||
raise ValueError("Something went wrong")
|
||||
|
||||
tool = ToolDef(
|
||||
name="fail",
|
||||
description="Failing function",
|
||||
fn=failing_fn,
|
||||
)
|
||||
|
||||
result, log = run_tool(tool, {})
|
||||
|
||||
assert result is None
|
||||
assert "Something went wrong" in log["error"]
|
||||
|
||||
|
||||
class TestArgValidation:
|
||||
"""Test argument validation."""
|
||||
|
||||
def test_validate_required_fields(self):
|
||||
"""Test validation of required fields."""
|
||||
tool = ToolDef(
|
||||
name="test",
|
||||
description="Test",
|
||||
fn=lambda: None,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
},
|
||||
"required": ["required_field"],
|
||||
},
|
||||
)
|
||||
|
||||
# Missing required field
|
||||
is_valid, error = validate_args(tool, {})
|
||||
assert not is_valid
|
||||
assert "required_field" in error
|
||||
|
||||
# With required field
|
||||
is_valid, error = validate_args(tool, {"required_field": "value"})
|
||||
assert is_valid
|
||||
|
||||
def test_validate_string_type(self):
|
||||
"""Test string type validation."""
|
||||
tool = ToolDef(
|
||||
name="test",
|
||||
description="Test",
|
||||
fn=lambda: None,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
is_valid, _ = validate_args(tool, {"name": "hello"})
|
||||
assert is_valid
|
||||
|
||||
is_valid, error = validate_args(tool, {"name": 123})
|
||||
assert not is_valid
|
||||
assert "string" in error
|
||||
|
||||
def test_validate_number_constraints(self):
|
||||
"""Test number constraint validation."""
|
||||
tool = ToolDef(
|
||||
name="test",
|
||||
description="Test",
|
||||
fn=lambda: None,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
is_valid, _ = validate_args(tool, {"score": 50})
|
||||
assert is_valid
|
||||
|
||||
is_valid, error = validate_args(tool, {"score": -1})
|
||||
assert not is_valid
|
||||
assert ">=" in error
|
||||
|
||||
is_valid, error = validate_args(tool, {"score": 101})
|
||||
assert not is_valid
|
||||
assert "<=" in error
|
||||
|
||||
def test_validate_enum(self):
|
||||
"""Test enum constraint validation."""
|
||||
tool = ToolDef(
|
||||
name="test",
|
||||
description="Test",
|
||||
fn=lambda: None,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "active", "done"],
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
is_valid, _ = validate_args(tool, {"status": "active"})
|
||||
assert is_valid
|
||||
|
||||
is_valid, error = validate_args(tool, {"status": "invalid"})
|
||||
assert not is_valid
|
||||
assert "one of" in error
|
||||
|
||||
def test_validate_with_tool_runner(self):
|
||||
"""Test validation integration with run_tool."""
|
||||
tool = ToolDef(
|
||||
name="test",
|
||||
description="Test",
|
||||
fn=lambda x: x,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer"},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
)
|
||||
|
||||
# Invalid args should fail validation
|
||||
result, log = run_tool(tool, {"x": "not an int"}, validate=True)
|
||||
assert result is None
|
||||
assert "Validation error" in log["error"]
|
||||
|
||||
# Skip validation
|
||||
result, log = run_tool(tool, {"x": "not an int"}, validate=False)
|
||||
# Execution may fail, but not due to validation
|
||||
assert "Validation error" not in (log.get("error") or "")
|
||||
|
||||
|
||||
class TestToolRegistry:
|
||||
"""Test tool registry functionality."""
|
||||
|
||||
def test_register_and_get(self):
|
||||
"""Test registering and retrieving tools."""
|
||||
registry = ToolRegistry()
|
||||
|
||||
tool = ToolDef(name="test", description="Test", fn=lambda: None)
|
||||
registry.register(tool)
|
||||
|
||||
retrieved = registry.get("test")
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "test"
|
||||
|
||||
def test_list_tools(self):
|
||||
"""Test listing all tools."""
|
||||
registry = ToolRegistry()
|
||||
|
||||
registry.register(ToolDef(name="t1", description="Tool 1", fn=lambda: None))
|
||||
registry.register(ToolDef(name="t2", description="Tool 2", fn=lambda: None))
|
||||
|
||||
tools = registry.list_tools()
|
||||
assert len(tools) == 2
|
||||
names = {t["name"] for t in tools}
|
||||
assert names == {"t1", "t2"}
|
||||
|
||||
def test_permission_check(self):
|
||||
"""Test permission checking."""
|
||||
registry = ToolRegistry()
|
||||
|
||||
tool = ToolDef(
|
||||
name="restricted",
|
||||
description="Restricted tool",
|
||||
fn=lambda: None,
|
||||
permission_scope=["admin", "write"],
|
||||
)
|
||||
registry.register(tool)
|
||||
|
||||
# Has matching permission
|
||||
assert registry.allowed_for("restricted", ["admin"])
|
||||
assert registry.allowed_for("restricted", ["write"])
|
||||
|
||||
# No matching permission
|
||||
assert not registry.allowed_for("restricted", ["read"])
|
||||
|
||||
# Wildcard permissions
|
||||
assert registry.allowed_for("restricted", ["*"])
|
||||
|
||||
|
||||
class TestSSRFProtection:
|
||||
"""Test SSRF protection in URL validation."""
|
||||
|
||||
def test_localhost_blocked(self):
|
||||
"""Test that localhost URLs are blocked."""
|
||||
with pytest.raises(SSRFProtectionError, match="Localhost"):
|
||||
_validate_url("http://localhost/path")
|
||||
|
||||
with pytest.raises(SSRFProtectionError, match="Localhost"):
|
||||
_validate_url("http://127.0.0.1/path")
|
||||
|
||||
def test_private_ip_blocked(self):
|
||||
"""Test that private IPs are blocked after DNS resolution."""
|
||||
# Note: This test may pass or fail depending on DNS resolution
|
||||
# Testing the concept with a known internal hostname pattern
|
||||
with pytest.raises(SSRFProtectionError):
|
||||
_validate_url("http://test.local/path")
|
||||
|
||||
def test_non_http_scheme_blocked(self):
|
||||
"""Test that non-HTTP schemes are blocked."""
|
||||
with pytest.raises(SSRFProtectionError, match="scheme"):
|
||||
_validate_url("file:///etc/passwd")
|
||||
|
||||
with pytest.raises(SSRFProtectionError, match="scheme"):
|
||||
_validate_url("ftp://example.com/file")
|
||||
|
||||
def test_valid_url_passes(self):
|
||||
"""Test that valid public URLs pass."""
|
||||
# This should not raise
|
||||
url = _validate_url("https://example.com/path")
|
||||
assert url == "https://example.com/path"
|
||||
|
||||
|
||||
class TestFileTools:
|
||||
"""Test file read/write tools."""
|
||||
|
||||
def test_file_read_in_scope(self):
|
||||
"""Test reading a file within scope."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a test file
|
||||
test_file = os.path.join(tmpdir, "test.txt")
|
||||
with open(test_file, "w") as f:
|
||||
f.write("Hello, World!")
|
||||
|
||||
tool = make_file_read_tool(scope=tmpdir)
|
||||
result, log = run_tool(tool, {"path": test_file})
|
||||
|
||||
assert result == "Hello, World!"
|
||||
assert log["error"] is None
|
||||
|
||||
def test_file_read_outside_scope(self):
|
||||
"""Test reading a file outside scope is blocked."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tool = make_file_read_tool(scope=tmpdir)
|
||||
|
||||
# Try to read file outside scope
|
||||
result, log = run_tool(tool, {"path": "/etc/passwd"})
|
||||
|
||||
assert result is None
|
||||
assert "not allowed" in log["error"].lower() or "permission" in log["error"].lower()
|
||||
|
||||
def test_file_write_in_scope(self):
|
||||
"""Test writing a file within scope."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tool = make_file_write_tool(scope=tmpdir)
|
||||
|
||||
test_file = os.path.join(tmpdir, "output.txt")
|
||||
result, log = run_tool(tool, {"path": test_file, "content": "Test content"})
|
||||
|
||||
assert log["error"] is None
|
||||
assert os.path.exists(test_file)
|
||||
|
||||
with open(test_file) as f:
|
||||
assert f.read() == "Test content"
|
||||
Reference in New Issue
Block a user