Initial commit: add .gitignore and README
This commit is contained in:
16
fusionagi/tools/__init__.py
Normal file
16
fusionagi/tools/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Tool registry, safe execution, connectors (docs, DB, code runner)."""
|
||||
|
||||
from fusionagi.tools.registry import ToolRegistry, ToolDef
|
||||
from fusionagi.tools.runner import run_tool, run_tool_with_audit
|
||||
from fusionagi.tools.connectors import BaseConnector, DocsConnector, DBConnector, CodeRunnerConnector
|
||||
|
||||
__all__ = [
|
||||
"ToolRegistry",
|
||||
"ToolDef",
|
||||
"run_tool",
|
||||
"run_tool_with_audit",
|
||||
"BaseConnector",
|
||||
"DocsConnector",
|
||||
"DBConnector",
|
||||
"CodeRunnerConnector",
|
||||
]
|
||||
291
fusionagi/tools/builtins.py
Normal file
291
fusionagi/tools/builtins.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Built-in tools: file read (scoped), HTTP GET (with SSRF protection), query state."""
|
||||
|
||||
import ipaddress
|
||||
import os
|
||||
import socket
|
||||
from typing import Any, Callable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
from fusionagi._logger import logger
|
||||
|
||||
# Default allowed path prefix for file tools. Deployers should pass an explicit scope (e.g. from config/env)
|
||||
# and not rely on cwd in production.
|
||||
DEFAULT_FILE_SCOPE = os.path.abspath(os.getcwd())
|
||||
|
||||
# Maximum file size for read/write operations (10MB)
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
|
||||
class SSRFProtectionError(Exception):
|
||||
"""Raised when a URL is blocked for SSRF protection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class FileSizeError(Exception):
|
||||
"""Raised when file size exceeds limit."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _normalize_path(path: str, scope: str) -> str:
|
||||
"""
|
||||
Normalize and validate a file path against scope.
|
||||
|
||||
Resolves symlinks and prevents path traversal attacks.
|
||||
"""
|
||||
# Resolve to absolute path
|
||||
abs_path = os.path.abspath(path)
|
||||
|
||||
# Resolve symlinks to get the real path
|
||||
try:
|
||||
real_path = os.path.realpath(abs_path)
|
||||
except OSError:
|
||||
real_path = abs_path
|
||||
|
||||
# Normalize scope too
|
||||
real_scope = os.path.realpath(os.path.abspath(scope))
|
||||
|
||||
# Check if path is under scope
|
||||
if not real_path.startswith(real_scope + os.sep) and real_path != real_scope:
|
||||
raise PermissionError(f"Path not allowed: {path} resolves outside {scope}")
|
||||
|
||||
return real_path
|
||||
|
||||
|
||||
def _file_read(path: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
|
||||
"""
|
||||
Read file content; path must be under scope.
|
||||
|
||||
Args:
|
||||
path: File path to read.
|
||||
scope: Allowed directory scope.
|
||||
max_size: Maximum file size in bytes.
|
||||
|
||||
Returns:
|
||||
File contents as string.
|
||||
|
||||
Raises:
|
||||
PermissionError: If path is outside scope.
|
||||
FileSizeError: If file exceeds max_size.
|
||||
"""
|
||||
real_path = _normalize_path(path, scope)
|
||||
|
||||
# Check file size before reading
|
||||
try:
|
||||
file_size = os.path.getsize(real_path)
|
||||
if file_size > max_size:
|
||||
raise FileSizeError(f"File too large: {file_size} bytes (max {max_size})")
|
||||
except OSError as e:
|
||||
raise PermissionError(f"Cannot access file: {e}")
|
||||
|
||||
with open(real_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _file_write(path: str, content: str, scope: str = DEFAULT_FILE_SCOPE, max_size: int = MAX_FILE_SIZE) -> str:
|
||||
"""
|
||||
Write content to file; path must be under scope.
|
||||
|
||||
Args:
|
||||
path: File path to write.
|
||||
content: Content to write.
|
||||
scope: Allowed directory scope.
|
||||
max_size: Maximum content size in bytes.
|
||||
|
||||
Returns:
|
||||
Success message with byte count.
|
||||
|
||||
Raises:
|
||||
PermissionError: If path is outside scope.
|
||||
FileSizeError: If content exceeds max_size.
|
||||
"""
|
||||
# Check content size before writing
|
||||
content_bytes = len(content.encode("utf-8"))
|
||||
if content_bytes > max_size:
|
||||
raise FileSizeError(f"Content too large: {content_bytes} bytes (max {max_size})")
|
||||
|
||||
real_path = _normalize_path(path, scope)
|
||||
|
||||
# Ensure parent directory exists
|
||||
parent_dir = os.path.dirname(real_path)
|
||||
if parent_dir and not os.path.exists(parent_dir):
|
||||
# Check if parent would be under scope
|
||||
_normalize_path(parent_dir, scope)
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
with open(real_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return f"Wrote {content_bytes} bytes to {real_path}"
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
|
||||
"""Check if an IP address is private, loopback, or otherwise unsafe."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip)
|
||||
return (
|
||||
addr.is_private
|
||||
or addr.is_loopback
|
||||
or addr.is_link_local
|
||||
or addr.is_multicast
|
||||
or addr.is_reserved
|
||||
or addr.is_unspecified
|
||||
# Block IPv6 mapped IPv4 addresses
|
||||
or (isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None)
|
||||
)
|
||||
except ValueError:
|
||||
return True # Invalid IP is treated as unsafe
|
||||
|
||||
|
||||
def _validate_url(url: str, allow_private: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL for SSRF protection.
|
||||
|
||||
Args:
|
||||
url: URL to validate.
|
||||
allow_private: If True, allow private/internal IPs (default False).
|
||||
|
||||
Returns:
|
||||
The validated URL.
|
||||
|
||||
Raises:
|
||||
SSRFProtectionError: If URL is blocked for security reasons.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except Exception as e:
|
||||
raise SSRFProtectionError(f"Invalid URL: {e}")
|
||||
|
||||
# Only allow HTTP and HTTPS
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFProtectionError(f"URL scheme not allowed: {parsed.scheme}")
|
||||
|
||||
# Must have a hostname
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise SSRFProtectionError("URL must have a hostname")
|
||||
|
||||
# Block localhost variants
|
||||
localhost_patterns = ["localhost", "127.0.0.1", "::1", "0.0.0.0"]
|
||||
if hostname.lower() in localhost_patterns:
|
||||
raise SSRFProtectionError(f"Localhost URLs not allowed: {hostname}")
|
||||
|
||||
# Block common internal hostnames
|
||||
internal_patterns = [".local", ".internal", ".corp", ".lan", ".home"]
|
||||
for pattern in internal_patterns:
|
||||
if hostname.lower().endswith(pattern):
|
||||
raise SSRFProtectionError(f"Internal hostname not allowed: {hostname}")
|
||||
|
||||
if not allow_private:
|
||||
# Resolve hostname and check if IP is private
|
||||
try:
|
||||
# Get all IP addresses for the hostname
|
||||
ips = socket.getaddrinfo(hostname, parsed.port or (443 if parsed.scheme == "https" else 80))
|
||||
for family, socktype, proto, canonname, sockaddr in ips:
|
||||
ip = sockaddr[0]
|
||||
if _is_private_ip(ip):
|
||||
raise SSRFProtectionError(f"URL resolves to private IP: {ip}")
|
||||
except socket.gaierror as e:
|
||||
# DNS resolution failed - could be a security issue
|
||||
logger.warning(f"DNS resolution failed for {hostname}: {e}")
|
||||
raise SSRFProtectionError(f"Cannot resolve hostname: {hostname}")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def _http_get(url: str, allow_private: bool = False) -> str:
|
||||
"""
|
||||
Simple HTTP GET with SSRF protection.
|
||||
|
||||
Args:
|
||||
url: URL to fetch.
|
||||
allow_private: If True, allow private/internal IPs (default False).
|
||||
|
||||
Returns:
|
||||
Response text. On failure returns a string starting with 'Error: '.
|
||||
"""
|
||||
try:
|
||||
validated_url = _validate_url(url, allow_private=allow_private)
|
||||
except SSRFProtectionError as e:
|
||||
return f"Error: SSRF protection: {e}"
|
||||
|
||||
try:
|
||||
import urllib.request
|
||||
with urllib.request.urlopen(validated_url, timeout=10) as resp:
|
||||
return resp.read().decode("utf-8", errors="replace")
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
def make_file_read_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
|
||||
"""File read tool with path scope."""
|
||||
def fn(path: str) -> str:
|
||||
return _file_read(path, scope=scope)
|
||||
return ToolDef(
|
||||
name="file_read",
|
||||
description="Read file content; path must be under allowed scope",
|
||||
fn=fn,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string", "description": "File path"}},
|
||||
"required": ["path"],
|
||||
},
|
||||
permission_scope=["file"],
|
||||
timeout_seconds=5.0,
|
||||
)
|
||||
|
||||
|
||||
def make_file_write_tool(scope: str = DEFAULT_FILE_SCOPE) -> ToolDef:
|
||||
"""File write tool with path scope."""
|
||||
def fn(path: str, content: str) -> str:
|
||||
return _file_write(path, content, scope=scope)
|
||||
return ToolDef(
|
||||
name="file_write",
|
||||
description="Write content to file; path must be under allowed scope",
|
||||
fn=fn,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path"},
|
||||
"content": {"type": "string", "description": "Content to write"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
},
|
||||
permission_scope=["file"],
|
||||
timeout_seconds=5.0,
|
||||
)
|
||||
|
||||
|
||||
def make_http_get_tool() -> ToolDef:
|
||||
"""HTTP GET tool."""
|
||||
return ToolDef(
|
||||
name="http_get",
|
||||
description="Perform HTTP GET request and return response body",
|
||||
fn=_http_get,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string", "description": "URL to fetch"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
permission_scope=["network"],
|
||||
timeout_seconds=15.0,
|
||||
)
|
||||
|
||||
|
||||
def make_query_state_tool(get_state_fn: Callable[[str], Any]) -> ToolDef:
|
||||
"""Internal tool: query task state (injected get_state_fn(task_id) -> state/trace)."""
|
||||
def fn(task_id: str) -> Any:
|
||||
return get_state_fn(task_id)
|
||||
return ToolDef(
|
||||
name="query_state",
|
||||
description="Query task state and trace (internal)",
|
||||
fn=fn,
|
||||
parameters_schema={
|
||||
"type": "object",
|
||||
"properties": {"task_id": {"type": "string"}},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
permission_scope=["internal"],
|
||||
timeout_seconds=2.0,
|
||||
)
|
||||
5
fusionagi/tools/connectors/__init__.py
Normal file
5
fusionagi/tools/connectors/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from fusionagi.tools.connectors.base import BaseConnector
|
||||
from fusionagi.tools.connectors.docs import DocsConnector
|
||||
from fusionagi.tools.connectors.db import DBConnector
|
||||
from fusionagi.tools.connectors.code_runner import CodeRunnerConnector
|
||||
__all__ = ["BaseConnector", "DocsConnector", "DBConnector", "CodeRunnerConnector"]
|
||||
9
fusionagi/tools/connectors/base.py
Normal file
9
fusionagi/tools/connectors/base.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
class BaseConnector(ABC):
|
||||
name = "base"
|
||||
@abstractmethod
|
||||
def invoke(self, action: str, params: dict) -> Any: ...
|
||||
def schema(self) -> dict:
|
||||
return {"name": self.name, "actions": [], "parameters": {}}
|
||||
20
fusionagi/tools/connectors/code_runner.py
Normal file
20
fusionagi/tools/connectors/code_runner.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Code runner connector: run code in sandbox (stub; extend with safe executor)."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.tools.connectors.base import BaseConnector
|
||||
|
||||
|
||||
class CodeRunnerConnector(BaseConnector):
|
||||
name = "code_runner"
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def invoke(self, action: str, params: dict[str, Any]) -> Any:
|
||||
if action == "run":
|
||||
return {"stdout": "", "stderr": "", "error": "CodeRunnerConnector stub: implement run"}
|
||||
return {"error": f"Unknown action: {action}"}
|
||||
|
||||
def schema(self) -> dict[str, Any]:
|
||||
return {"name": self.name, "actions": ["run"], "parameters": {"code": "string", "language": "string"}}
|
||||
20
fusionagi/tools/connectors/db.py
Normal file
20
fusionagi/tools/connectors/db.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""DB connector: query database (stub; extend with SQL driver)."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.tools.connectors.base import BaseConnector
|
||||
|
||||
|
||||
class DBConnector(BaseConnector):
|
||||
name = "db"
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def invoke(self, action: str, params: dict[str, Any]) -> Any:
|
||||
if action == "query":
|
||||
return {"rows": [], "error": "DBConnector stub: implement query"}
|
||||
return {"error": f"Unknown action: {action}"}
|
||||
|
||||
def schema(self) -> dict[str, Any]:
|
||||
return {"name": self.name, "actions": ["query"], "parameters": {"query": "string"}}
|
||||
21
fusionagi/tools/connectors/docs.py
Normal file
21
fusionagi/tools/connectors/docs.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Docs connector: read documents (stub; extend with PDF/Office)."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.tools.connectors.base import BaseConnector
|
||||
|
||||
|
||||
class DocsConnector(BaseConnector):
|
||||
name = "docs"
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def invoke(self, action: str, params: dict[str, Any]) -> Any:
|
||||
if action == "read":
|
||||
path = params.get("path", "")
|
||||
return {"content": "", "path": path, "error": "DocsConnector stub: implement read"}
|
||||
return {"error": f"Unknown action: {action}"}
|
||||
|
||||
def schema(self) -> dict[str, Any]:
|
||||
return {"name": self.name, "actions": ["read"], "parameters": {"path": "string"}}
|
||||
69
fusionagi/tools/registry.py
Normal file
69
fusionagi/tools/registry.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tool registry: register tools by name; resolve by name and check permissions."""
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
# Tool schema: name, description, parameters (JSON Schema), permission_scope
|
||||
# Invocation: (args: dict) -> result
|
||||
ToolFn = Callable[..., Any]
|
||||
|
||||
|
||||
class ToolDef:
|
||||
"""Tool definition: name, description, parameters schema, permission scope, timeout."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
fn: ToolFn,
|
||||
parameters_schema: dict[str, Any] | None = None,
|
||||
permission_scope: str | list[str] = "*",
|
||||
timeout_seconds: float = 30.0,
|
||||
manufacturing: bool = False,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.fn = fn
|
||||
self.parameters_schema = parameters_schema or {"type": "object", "properties": {}}
|
||||
self.permission_scope = permission_scope if isinstance(permission_scope, list) else [permission_scope]
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.manufacturing = manufacturing
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""JSON Schema for this tool."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters_schema,
|
||||
"permission_scope": self.permission_scope,
|
||||
"timeout_seconds": self.timeout_seconds,
|
||||
"manufacturing": self.manufacturing,
|
||||
}
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""Register and resolve tools by name; check agent permissions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tools: dict[str, ToolDef] = {}
|
||||
|
||||
def register(self, tool: ToolDef) -> None:
|
||||
"""Register a tool by name."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> ToolDef | None:
|
||||
"""Return tool definition by name or None."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""Return list of tool schemas."""
|
||||
return [t.to_schema() for t in self._tools.values()]
|
||||
|
||||
def allowed_for(self, tool_name: str, agent_permissions: list[str] | str) -> bool:
|
||||
"""Return True if agent is allowed to use this tool (* or matching scope)."""
|
||||
tool = self._tools.get(tool_name)
|
||||
if not tool:
|
||||
return False
|
||||
perms = agent_permissions if isinstance(agent_permissions, list) else [agent_permissions]
|
||||
if "*" in tool.permission_scope or "*" in perms:
|
||||
return True
|
||||
return bool(set(tool.permission_scope) & set(perms))
|
||||
221
fusionagi/tools/runner.py
Normal file
221
fusionagi/tools/runner.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Safe runner: invoke tool with timeout, input validation, and failure handling; log for replay."""
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fusionagi.governance.audit_log import AuditLog
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
||||
from typing import Any
|
||||
|
||||
from fusionagi.tools.registry import ToolDef
|
||||
from fusionagi._logger import logger
|
||||
|
||||
|
||||
class ToolValidationError(Exception):
|
||||
"""Raised when tool arguments fail validation."""
|
||||
|
||||
def __init__(self, tool_name: str, message: str, details: dict[str, Any] | None = None):
|
||||
self.tool_name = tool_name
|
||||
self.details = details or {}
|
||||
super().__init__(f"Tool {tool_name}: {message}")
|
||||
|
||||
|
||||
def validate_args(tool: ToolDef, args: dict[str, Any]) -> tuple[bool, str]:
|
||||
"""
|
||||
Validate arguments against tool's JSON schema.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message). error_message is empty if valid.
|
||||
"""
|
||||
schema = tool.parameters_schema
|
||||
if not schema:
|
||||
return True, ""
|
||||
|
||||
# Basic JSON schema validation (without external dependency)
|
||||
schema_type = schema.get("type", "object")
|
||||
if schema_type != "object":
|
||||
return True, "" # Only validate object schemas
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
# Check required fields
|
||||
for field in required:
|
||||
if field not in args:
|
||||
return False, f"Missing required argument: {field}"
|
||||
|
||||
# Check types of provided fields
|
||||
for field, value in args.items():
|
||||
if field not in properties:
|
||||
# Allow extra fields by default (additionalProperties: true is common)
|
||||
continue
|
||||
|
||||
prop_schema = properties[field]
|
||||
prop_type = prop_schema.get("type")
|
||||
|
||||
if prop_type is None:
|
||||
continue
|
||||
|
||||
# Type checking
|
||||
type_valid = True
|
||||
if prop_type == "string":
|
||||
type_valid = isinstance(value, str)
|
||||
elif prop_type == "integer":
|
||||
type_valid = isinstance(value, int) and not isinstance(value, bool)
|
||||
elif prop_type == "number":
|
||||
type_valid = isinstance(value, (int, float)) and not isinstance(value, bool)
|
||||
elif prop_type == "boolean":
|
||||
type_valid = isinstance(value, bool)
|
||||
elif prop_type == "array":
|
||||
type_valid = isinstance(value, list)
|
||||
elif prop_type == "object":
|
||||
type_valid = isinstance(value, dict)
|
||||
elif prop_type == "null":
|
||||
type_valid = value is None
|
||||
|
||||
if not type_valid:
|
||||
return False, f"Argument '{field}' must be of type {prop_type}, got {type(value).__name__}"
|
||||
|
||||
# String constraints
|
||||
if prop_type == "string" and isinstance(value, str):
|
||||
min_len = prop_schema.get("minLength")
|
||||
max_len = prop_schema.get("maxLength")
|
||||
pattern = prop_schema.get("pattern")
|
||||
|
||||
if min_len is not None and len(value) < min_len:
|
||||
return False, f"Argument '{field}' must be at least {min_len} characters"
|
||||
if max_len is not None and len(value) > max_len:
|
||||
return False, f"Argument '{field}' must be at most {max_len} characters"
|
||||
if pattern:
|
||||
import re
|
||||
if not re.match(pattern, value):
|
||||
return False, f"Argument '{field}' does not match pattern: {pattern}"
|
||||
|
||||
# Number constraints
|
||||
if prop_type in ("integer", "number") and isinstance(value, (int, float)):
|
||||
minimum = prop_schema.get("minimum")
|
||||
maximum = prop_schema.get("maximum")
|
||||
exclusive_min = prop_schema.get("exclusiveMinimum")
|
||||
exclusive_max = prop_schema.get("exclusiveMaximum")
|
||||
|
||||
if minimum is not None and value < minimum:
|
||||
return False, f"Argument '{field}' must be >= {minimum}"
|
||||
if maximum is not None and value > maximum:
|
||||
return False, f"Argument '{field}' must be <= {maximum}"
|
||||
if exclusive_min is not None and value <= exclusive_min:
|
||||
return False, f"Argument '{field}' must be > {exclusive_min}"
|
||||
if exclusive_max is not None and value >= exclusive_max:
|
||||
return False, f"Argument '{field}' must be < {exclusive_max}"
|
||||
|
||||
# Enum constraint
|
||||
enum = prop_schema.get("enum")
|
||||
if enum is not None and value not in enum:
|
||||
return False, f"Argument '{field}' must be one of: {enum}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def run_tool(
|
||||
tool: ToolDef,
|
||||
args: dict[str, Any],
|
||||
timeout_seconds: float | None = None,
|
||||
validate: bool = True,
|
||||
) -> tuple[Any, dict[str, Any]]:
|
||||
"""
|
||||
Invoke tool.fn(args) with optional validation and timeout.
|
||||
|
||||
Args:
|
||||
tool: The tool definition to execute.
|
||||
args: Arguments to pass to the tool function.
|
||||
timeout_seconds: Override timeout (uses tool.timeout_seconds if None).
|
||||
validate: Whether to validate args against tool's schema (default True).
|
||||
|
||||
Returns:
|
||||
Tuple of (result, log_entry). On error, result is None and log_entry contains error.
|
||||
"""
|
||||
timeout = timeout_seconds if timeout_seconds is not None else tool.timeout_seconds
|
||||
start = time.monotonic()
|
||||
log_entry: dict[str, Any] = {
|
||||
"tool": tool.name,
|
||||
"args": args,
|
||||
"result": None,
|
||||
"error": None,
|
||||
"duration_seconds": None,
|
||||
"validated": validate,
|
||||
}
|
||||
|
||||
# Validate arguments before execution
|
||||
if validate:
|
||||
is_valid, error_msg = validate_args(tool, args)
|
||||
if not is_valid:
|
||||
log_entry["error"] = f"Validation error: {error_msg}"
|
||||
log_entry["duration_seconds"] = round(time.monotonic() - start, 3)
|
||||
logger.warning(
|
||||
"Tool validation failed",
|
||||
extra={"tool": tool.name, "error": error_msg},
|
||||
)
|
||||
return None, log_entry
|
||||
|
||||
def _invoke() -> Any:
|
||||
return tool.fn(**args)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as ex:
|
||||
fut = ex.submit(_invoke)
|
||||
result = fut.result(timeout=timeout if timeout and timeout > 0 else None)
|
||||
log_entry["result"] = result
|
||||
logger.debug(
|
||||
"Tool executed successfully",
|
||||
extra={"tool": tool.name, "duration": log_entry.get("duration_seconds")},
|
||||
)
|
||||
return result, log_entry
|
||||
except FuturesTimeoutError:
|
||||
log_entry["error"] = f"Tool {tool.name} timed out after {timeout}s"
|
||||
logger.warning(
|
||||
"Tool timed out",
|
||||
extra={"tool": tool.name, "timeout": timeout},
|
||||
)
|
||||
return None, log_entry
|
||||
except Exception as e:
|
||||
log_entry["error"] = str(e)
|
||||
logger.error(
|
||||
"Tool execution failed",
|
||||
extra={"tool": tool.name, "error": str(e), "error_type": type(e).__name__},
|
||||
)
|
||||
return None, log_entry
|
||||
finally:
|
||||
log_entry["duration_seconds"] = round(time.monotonic() - start, 3)
|
||||
|
||||
|
||||
def run_tool_with_audit(
|
||||
tool: ToolDef,
|
||||
args: dict[str, Any],
|
||||
audit_log: "AuditLog",
|
||||
actor: str = "system",
|
||||
task_id: str | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
validate: bool = True,
|
||||
) -> tuple[Any, dict[str, Any]]:
|
||||
"""
|
||||
Invoke tool and log to AuditLog.
|
||||
Sanitizes args in log (e.g. truncate long values).
|
||||
"""
|
||||
from fusionagi.schemas.audit import AuditEventType
|
||||
|
||||
sanitized = {}
|
||||
for k, v in args.items():
|
||||
if isinstance(v, str) and len(v) > 200:
|
||||
sanitized[k] = v[:200] + "..."
|
||||
else:
|
||||
sanitized[k] = v
|
||||
result, log_entry = run_tool(tool, args, timeout_seconds, validate)
|
||||
audit_log.append(
|
||||
AuditEventType.TOOL_CALL,
|
||||
actor,
|
||||
action=f"tool:{tool.name}",
|
||||
task_id=task_id,
|
||||
payload={"tool": tool.name, "args": sanitized, "error": log_entry.get("error")},
|
||||
outcome="success" if result is not None else "failure",
|
||||
)
|
||||
return result, log_entry
|
||||
Reference in New Issue
Block a user