72 lines
3.3 KiB
Python
72 lines
3.3 KiB
Python
"""Guardrails: pre/post checks for tool calls (block paths, sanitize inputs)."""
|
|
|
|
import posixpath
import re
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel, Field

from fusionagi._logger import logger
|
|
|
|
|
|
class PreCheckResult(BaseModel):
    """Outcome of a guardrails pre-check.

    ``allowed`` carries the verdict. ``sanitized_args`` optionally holds the
    arguments to use when the call is allowed (and possibly sanitized);
    ``error_message`` optionally explains why the call was denied.
    """

    # Verdict: whether the tool call may proceed (required).
    allowed: bool = Field(..., description="Whether the call is allowed")
    # Arguments to invoke the tool with; defaults to None.
    sanitized_args: dict[str, Any] | None = Field(None, description="Args to use if allowed and sanitized")
    # Human-readable denial reason; defaults to None.
    error_message: str | None = Field(None, description="Reason for denial if not allowed")
|
|
|
|
|
|
class Guardrails:
    """Pre/post checks for tool invocations.

    ``pre_check`` runs, in order:

    1. Built-in path checks on the ``"path"`` and ``"file_path"`` arguments
       (string values only) against the blocked prefixes and regexes. Both the
       raw path and its lexically normalized form are tested, so ``..``
       segments cannot walk around a blocked prefix.
    2. Custom checks registered via ``add_check``, in registration order. A
       check may deny the call or replace the arguments.
    3. If any custom check replaced the arguments, the built-in path checks
       run again on the final arguments, so a rewritten path is never passed
       to the tool unchecked.
    """

    # Argument keys whose string values are treated as file paths.
    _PATH_KEYS: tuple[str, ...] = ("path", "file_path")

    def __init__(self) -> None:
        # Prefixes with trailing "/" stripped; matched via str.startswith.
        self._blocked_paths: list[str] = []
        # Compiled regexes; matched via Pattern.search (anywhere in the path).
        self._blocked_patterns: list[re.Pattern[str]] = []
        # Callables (tool_name, args) -> (allowed, sanitized_args | error_message | None).
        self._custom_checks: list[Callable[[str, dict[str, Any]], tuple[bool, Any]]] = []

    def block_path_prefix(self, prefix: str) -> None:
        """Block any file path starting with this prefix.

        Trailing slashes are stripped, and the match is a plain string-prefix
        test, so ``"/etc"`` also blocks ``"/etcd"``.

        NOTE(review): ``"/"`` (or ``""``) strips to the empty string, which
        every path starts with, so it blocks all paths, including relative ones.
        """
        self._blocked_paths.append(prefix.rstrip("/"))

    def block_path_pattern(self, pattern: str) -> None:
        """Block paths matching this regex.

        The pattern is compiled immediately, so an invalid regex raises
        ``re.error`` here. It is matched with ``search``, i.e. anywhere in the
        path; anchor it with ``^``/``$`` for whole-path matches.
        """
        self._blocked_patterns.append(re.compile(pattern))

    def add_check(self, check: Callable[[str, dict[str, Any]], tuple[bool, Any]]) -> None:
        """Add a custom pre-check.

        ``check(tool_name, args)`` must return a 2-tuple ``(allowed, payload)``:

        - ``(True, dict)``: allow; the dict replaces the args for subsequent
          checks and for the invocation (and is re-validated by the path checks).
        - ``(True, None)``: allow with the args unchanged.
        - ``(False, str)``: deny with that message; a non-str payload is
          reported as ``"Check failed"``.

        Checks must not mutate the args they receive: ``pre_check`` only makes
        a shallow copy of the caller's dict, so nested values are shared.
        """
        self._custom_checks.append(check)

    def pre_check(self, tool_name: str, args: dict[str, Any]) -> PreCheckResult:
        """Run all pre-checks for a tool call.

        Args:
            tool_name: Name of the tool being invoked; logged on denial and
                passed to custom checks.
            args: The call's arguments. The caller's dict is not mutated.

        Returns:
            ``PreCheckResult(allowed=True, sanitized_args=...)`` carrying the
            final (possibly rewritten) arguments, or
            ``PreCheckResult(allowed=False, error_message=...)`` from the first
            check that denied the call.

        Exceptions raised by custom checks propagate to the caller.
        """
        args = dict(args)  # Copy to avoid mutating caller's args

        denied = self._check_paths(tool_name, args)
        if denied is not None:
            return denied

        replaced = False
        for check in self._custom_checks:
            allowed, result = check(tool_name, args)
            if not allowed:
                reason = result if isinstance(result, str) else "Check failed"
                return self._deny(tool_name, reason)
            if isinstance(result, dict):
                args = result
                replaced = True

        if replaced:
            # A sanitizer may have rewritten a path argument; re-validate so
            # the tool is never invoked on a blocked path.
            denied = self._check_paths(tool_name, args)
            if denied is not None:
                return denied

        return PreCheckResult(allowed=True, sanitized_args=args)

    def _check_paths(self, tool_name: str, args: dict[str, Any]) -> PreCheckResult | None:
        """Return a denial if any path argument is blocked, otherwise None."""
        for key in self._PATH_KEYS:
            path = args.get(key)
            if not isinstance(path, str):
                continue
            candidates = self._path_candidates(path)
            for prefix in self._blocked_paths:
                if any(candidate.startswith(prefix) for candidate in candidates):
                    return self._deny(tool_name, "Blocked path prefix: " + prefix)
            if any(pat.search(candidate) for pat in self._blocked_patterns for candidate in candidates):
                return self._deny(tool_name, "Blocked path pattern")
        return None

    @staticmethod
    def _path_candidates(path: str) -> tuple[str, ...]:
        """Return the raw path plus its lexically normalized (POSIX) form.

        Testing both catches ``"/tmp/../etc/passwd"`` with an ``"/etc"`` prefix
        while keeping everything the raw string alone would match blocked.
        Normalization is purely lexical: symlinks are not resolved.
        """
        if not path:
            # normpath("") is "."; don't invent a candidate for an empty path.
            return (path,)
        normalized = posixpath.normpath(path)
        return (path,) if normalized == path else (path, normalized)

    @staticmethod
    def _deny(tool_name: str, reason: str) -> PreCheckResult:
        """Log a blocked call and build the corresponding denial result."""
        logger.info("Guardrails pre_check blocked", extra={"tool_name": tool_name, "reason": reason})
        return PreCheckResult(allowed=False, error_message=reason)

    def post_check(self, tool_name: str, result: Any) -> tuple[bool, str]:
        """Optional post-check; return (True, "") or (False, error_message).

        Currently a pass-through hook: every result is accepted, and
        ``tool_name`` and ``result`` are unused.
        """
        return True, ""
|