Files
FusionAGI/tests/test_core_enhanced.py
defiQUG c052b07662
Some checks failed
Tests / test (3.10) (push) Has been cancelled
Tests / test (3.11) (push) Has been cancelled
Tests / test (3.12) (push) Has been cancelled
Tests / lint (push) Has been cancelled
Tests / docker (push) Has been cancelled
Initial commit: add .gitignore and README
2026-02-09 21:51:42 -08:00

247 lines
8.4 KiB
Python

"""Tests for enhanced core module functionality."""
import pytest
from fusionagi.core import (
EventBus,
StateManager,
Orchestrator,
InvalidStateTransitionError,
VALID_STATE_TRANSITIONS,
JsonFileBackend,
)
from fusionagi.schemas.task import Task, TaskState, TaskPriority
class TestStateManagerWithBackend:
    """Test StateManager's task, task-state, trace, listing and counting operations.

    Note: despite the class name, every test here constructs a bare
    ``StateManager()`` without passing any persistence backend. Persistence to
    disk is exercised separately in ``TestJsonFileBackend``.
    """

    def test_state_manager_basic_operations(self):
        """A task stored with set_task is returned intact by get_task."""
        sm = StateManager()
        sm.set_task(Task(task_id="test-1", goal="Test goal"))

        retrieved = sm.get_task("test-1")
        assert retrieved is not None
        assert retrieved.task_id == "test-1"
        assert retrieved.goal == "Test goal"

    def test_state_manager_task_state(self):
        """Task state starts at PENDING and reflects set_task_state updates."""
        sm = StateManager()
        # No explicit state is passed, so this expects Task's default of PENDING.
        sm.set_task(Task(task_id="test-2", goal="Test"))
        assert sm.get_task_state("test-2") == TaskState.PENDING

        sm.set_task_state("test-2", TaskState.ACTIVE)
        assert sm.get_task_state("test-2") == TaskState.ACTIVE

    def test_state_manager_trace(self):
        """Trace entries are returned by get_trace in the order they were appended."""
        sm = StateManager()
        sm.set_task(Task(task_id="test-3", goal="Test"))
        sm.append_trace("test-3", {"step": "step1", "result": "ok"})
        sm.append_trace("test-3", {"step": "step2", "result": "ok"})

        trace = sm.get_trace("test-3")
        assert len(trace) == 2
        assert trace[0]["step"] == "step1"
        assert trace[1]["step"] == "step2"

    def test_state_manager_list_tasks(self):
        """list_tasks returns every task, or only those matching a state filter."""
        sm = StateManager()
        sm.set_task(Task(task_id="t1", goal="Goal 1", state=TaskState.PENDING))
        sm.set_task(Task(task_id="t2", goal="Goal 2", state=TaskState.ACTIVE))
        sm.set_task(Task(task_id="t3", goal="Goal 3", state=TaskState.ACTIVE))

        all_tasks = sm.list_tasks()
        assert len(all_tasks) == 3

        # Only t2 and t3 were created ACTIVE.
        active_tasks = sm.list_tasks(state=TaskState.ACTIVE)
        assert len(active_tasks) == 2

    def test_state_manager_task_count(self):
        """task_count is zero for a new manager and grows with each stored task."""
        sm = StateManager()
        assert sm.task_count() == 0

        sm.set_task(Task(task_id="t1", goal="Goal 1"))
        sm.set_task(Task(task_id="t2", goal="Goal 2"))
        assert sm.task_count() == 2
class TestJsonFileBackend:
    """Check that JsonFileBackend writes state to a JSON file and a new instance reloads it."""

    def test_json_file_backend_roundtrip(self, tmp_path):
        """A task and its trace written by one backend are visible to a fresh one."""
        state_file = tmp_path / "state.json"

        writer = JsonFileBackend(state_file)
        writer.set_task(Task(task_id="tb1", goal="Backend goal", state=TaskState.ACTIVE))
        writer.append_trace("tb1", {"step": "s1", "result": "ok"})
        assert state_file.exists()

        # Re-open the same file with a second instance to prove the data was persisted.
        reader = JsonFileBackend(state_file)

        restored = reader.get_task("tb1")
        assert restored is not None
        assert restored.goal == "Backend goal"
        assert restored.state == TaskState.ACTIVE

        entries = reader.get_trace("tb1")
        assert len(entries) == 1
        assert entries[0]["step"] == "s1"

    def test_json_file_backend_set_task_state(self, tmp_path):
        """A state change made via set_task_state survives reloading from disk."""
        state_file = tmp_path / "state.json"

        writer = JsonFileBackend(state_file)
        writer.set_task(Task(task_id="tb2", goal="Goal", state=TaskState.PENDING))
        writer.set_task_state("tb2", TaskState.COMPLETED)

        reader = JsonFileBackend(state_file)
        assert reader.get_task_state("tb2") == TaskState.COMPLETED
class TestOrchestratorStateTransitions:
    """Test Orchestrator state transition validation."""

    @staticmethod
    def _orchestrator_with_task():
        """Return ``(orch, task_id)``: a fresh Orchestrator with one submitted task.

        Each Orchestrator is wired to its own new EventBus and StateManager so
        tests never share state. The leading underscore keeps pytest from
        collecting this helper as a test.
        """
        orch = Orchestrator(event_bus=EventBus(), state_manager=StateManager())
        task_id = orch.submit_task(goal="Test task")
        return orch, task_id

    def test_valid_transitions(self):
        """Test valid state transitions."""
        orch, task_id = self._orchestrator_with_task()
        # PENDING -> ACTIVE is valid
        orch.set_task_state(task_id, TaskState.ACTIVE)
        assert orch.get_task_state(task_id) == TaskState.ACTIVE
        # ACTIVE -> COMPLETED is valid
        orch.set_task_state(task_id, TaskState.COMPLETED)
        assert orch.get_task_state(task_id) == TaskState.COMPLETED

    def test_invalid_transition_raises(self):
        """Test that invalid transitions raise an error carrying the transition details."""
        orch, task_id = self._orchestrator_with_task()
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)
        # COMPLETED -> ACTIVE is invalid (terminal state)
        with pytest.raises(InvalidStateTransitionError) as exc_info:
            orch.set_task_state(task_id, TaskState.ACTIVE)
        assert exc_info.value.task_id == task_id
        assert exc_info.value.from_state == TaskState.COMPLETED
        assert exc_info.value.to_state == TaskState.ACTIVE

    def test_can_transition(self):
        """Test can_transition helper method."""
        orch, task_id = self._orchestrator_with_task()
        assert orch.can_transition(task_id, TaskState.ACTIVE) is True
        assert orch.can_transition(task_id, TaskState.CANCELLED) is True
        assert orch.can_transition(task_id, TaskState.COMPLETED) is False  # Can't skip ACTIVE

    def test_force_transition(self):
        """Test force=True bypasses validation."""
        orch, task_id = self._orchestrator_with_task()
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.COMPLETED)
        # Without force, COMPLETED -> PENDING is rejected. Checking this first
        # proves the forced call below really bypasses validation, rather than
        # passing because the transition happened to be allowed anyway.
        with pytest.raises(InvalidStateTransitionError):
            orch.set_task_state(task_id, TaskState.PENDING)
        # Force allows invalid transition
        orch.set_task_state(task_id, TaskState.PENDING, force=True)
        assert orch.get_task_state(task_id) == TaskState.PENDING

    def test_failed_to_pending_retry(self):
        """Test that FAILED can transition to PENDING for retry."""
        orch, task_id = self._orchestrator_with_task()
        orch.set_task_state(task_id, TaskState.ACTIVE)
        orch.set_task_state(task_id, TaskState.FAILED)
        # FAILED -> PENDING is valid (retry)
        orch.set_task_state(task_id, TaskState.PENDING)
        assert orch.get_task_state(task_id) == TaskState.PENDING
class TestEventBus:
    """Exercise EventBus subscribe, publish, unsubscribe and clear."""

    def test_publish_subscribe(self):
        """A subscribed handler is called with the published payload."""
        bus = EventBus()
        seen = []

        def record(kind, data):
            seen.append({"type": kind, "payload": data})

        bus.subscribe("test_event", record)
        bus.publish("test_event", {"data": "value"})

        assert len(seen) == 1
        first = seen[0]
        assert first["payload"]["data"] == "value"

    def test_multiple_subscribers(self):
        """Every handler subscribed to an event type gets its own delivery."""
        bus = EventBus()
        inbox_a = []
        inbox_b = []

        bus.subscribe("test", lambda _event, payload: inbox_a.append(payload))
        bus.subscribe("test", lambda _event, payload: inbox_b.append(payload))
        bus.publish("test", {"n": 1})

        assert len(inbox_a) == 1
        assert len(inbox_b) == 1

    def test_unsubscribe(self):
        """After unsubscribe, the handler receives no further events."""
        bus = EventBus()
        delivered = []

        def on_test(_event, payload):
            delivered.append(payload)

        bus.subscribe("test", on_test)
        bus.publish("test", {})
        assert len(delivered) == 1

        bus.unsubscribe("test", on_test)
        bus.publish("test", {})
        # Count is unchanged: the second publish reached nobody.
        assert len(delivered) == 1

    def test_clear(self):
        """clear() drops every subscription, so a later publish delivers nothing."""
        bus = EventBus()
        delivered = []

        bus.subscribe("test", lambda _event, payload: delivered.append(payload))
        bus.clear()
        bus.publish("test", {})

        assert not delivered