"""Policy engine: hard constraints independent of LLM for AGI."""

from typing import Any

from fusionagi.schemas.policy import PolicyEffect, PolicyRule
from fusionagi._logger import logger


class PolicyEngine:
    """Evaluates policy rules; higher priority first; first match wins (allow/deny)."""

    def __init__(self) -> None:
        # Kept sorted by descending priority so ``check`` can stop at the first match.
        self._rules: list[PolicyRule] = []

    def add_rule(self, rule: PolicyRule) -> None:
        """Register ``rule`` and re-sort rules by descending priority.

        Duplicate ``rule_id`` values are not rejected. Because ``list.sort`` is
        stable, rules with equal priority keep their insertion order.
        """
        self._rules.append(rule)
        self._rules.sort(key=lambda r: -r.priority)
        logger.debug("PolicyEngine: rule added", extra={"rule_id": rule.rule_id})

    def get_rules(self) -> list[PolicyRule]:
        """Return all rules (copy)."""
        # Shallow copy: the list is new, but the PolicyRule objects are shared.
        return list(self._rules)

    def get_rule(self, rule_id: str) -> PolicyRule | None:
        """Return rule by id or None.

        If several rules share ``rule_id``, the highest-priority one is returned.
        """
        for r in self._rules:
            if r.rule_id == rule_id:
                return r
        return None

    def update_rule(self, rule_id: str, updates: dict[str, Any]) -> bool:
        """
        Update an existing rule by id.

        Updates can include condition, effect, reason, priority. Any other key
        (including ``rule_id``) is silently ignored. The merged data is
        re-validated through ``PolicyRule.model_validate``. If validation raises,
        the exception propagates and the stored rule is left unchanged.

        Returns True if updated, False if rule_id not found.
        """
        allowed = {"condition", "effect", "reason", "priority"}
        for i, r in enumerate(self._rules):
            if r.rule_id == rule_id:
                data = r.model_dump()
                for k, v in updates.items():
                    if k in allowed:
                        data[k] = v
                self._rules[i] = PolicyRule.model_validate(data)
                # Priority may have changed, so restore descending order.
                self._rules.sort(key=lambda x: -x.priority)
                logger.debug("PolicyEngine: rule updated", extra={"rule_id": rule_id})
                return True
        return False

    def remove_rule(self, rule_id: str) -> bool:
        """Remove every rule whose id is ``rule_id``.

        ``add_rule`` accepts duplicate ids. If only the first copy were removed,
        a remaining duplicate would keep influencing ``check`` after the caller
        was told the rule had been removed. All copies are therefore dropped.
        Relative order of the remaining rules is preserved.

        Returns True if at least one rule was removed, False otherwise.
        """
        kept = [r for r in self._rules if r.rule_id != rule_id]
        if len(kept) == len(self._rules):
            return False
        # Slice assignment keeps the same list object while dropping the matches.
        self._rules[:] = kept
        logger.debug("PolicyEngine: rule removed", extra={"rule_id": rule_id})
        return True

    def check(self, action: str, context: dict[str, Any]) -> tuple[bool, str]:
        """
        Returns (allowed, reason).

        Context has e.g. tool_name, domain, data_class, agent_id. Rules are
        scanned in descending priority, and the first rule whose condition
        matches decides the result:

        - a DENY effect yields ``(False, reason or "Policy denied")``;
        - any other effect yields ``(True, reason or "Policy allowed")``.

        If no rule matches, the request is allowed with an empty reason.

        NOTE(review): ``action`` is not consulted by the matching logic. Rules
        can only match on keys present in ``context``. Confirm whether callers
        are expected to copy the action into ``context``.
        """
        for rule in self._rules:
            if self._match(rule.condition, context):
                if rule.effect == PolicyEffect.DENY:
                    return False, rule.reason or "Policy denied"
                return True, rule.reason or "Policy allowed"
        return True, ""

    def _match(self, condition: dict[str, Any], context: dict[str, Any]) -> bool:
        """Return True if every key/value in ``condition`` equals ``context``'s value.

        An empty condition matches every context. Because ``dict.get`` returns
        None for absent keys, a condition value of None also matches a context
        that lacks that key.
        """
        for k, v in condition.items():
            if context.get(k) != v:
                return False
        return True