107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
|
|
"""Async background task queue for long-running operations."""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import time
|
||
|
|
import uuid
|
||
|
|
from enum import Enum
|
||
|
|
from typing import Any, Callable, Coroutine
|
||
|
|
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
|
||
|
|
class TaskStatus(str, Enum):
    """Lifecycle states of a background task.

    The enum mixes in ``str``, so every member is itself a string and
    compares equal to its lowercase value (``TaskStatus.PENDING == "pending"``).
    """

    # Not yet executing.
    PENDING = "pending"
    # Currently executing.
    RUNNING = "running"
    # Finished and produced a result.
    COMPLETED = "completed"
    # Finished by raising an error.
    FAILED = "failed"
    # Stopped on request before finishing.
    CANCELLED = "cancelled"
|
||
|
|
|
||
|
|
|
||
|
|
class TaskResult(BaseModel):
    """State and outcome record for a single background task.

    Only ``task_id`` and ``status`` are required. ``created_at`` defaults to
    the wall-clock time (``time.time()``) at which the record is built; every
    other field defaults to ``None``.
    """

    # Identifier of the task this record describes.
    task_id: str
    # Current lifecycle state of the task.
    status: TaskStatus
    # Value produced by the task, if any.
    result: Any = Field(default=None)
    # Error text, if any.
    error: str | None = Field(default=None)
    # Epoch timestamp (seconds) taken when the record is instantiated.
    created_at: float = Field(default_factory=time.time)
    # Epoch timestamp (seconds) at which the task finished, if it has.
    completed_at: float | None = Field(default=None)
    # Execution time in milliseconds, if measured.
    duration_ms: float | None = Field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class BackgroundTaskQueue:
    """Async task queue for offloading long-running work.

    Submitted coroutine functions are wrapped in ``asyncio`` tasks. At most
    ``max_concurrent`` of them execute at once (gated by a semaphore); the
    rest wait their turn. One ``TaskResult`` per task is kept in memory and
    is queryable by task id until ``cleanup_expired`` evicts it.
    """

    def __init__(self, max_concurrent: int = 5, result_ttl: float = 3600.0) -> None:
        """Create an empty queue.

        Args:
            max_concurrent: Maximum number of tasks allowed to execute
                simultaneously. Must be at least 1.
            result_ttl: Seconds a finished task's record is retained before
                ``cleanup_expired`` may remove it.

        Raises:
            ValueError: If ``max_concurrent`` is less than 1.
        """
        # A zero-permit semaphore would leave every task waiting forever.
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be at least 1")
        self._semaphore = asyncio.Semaphore(max_concurrent)
        self._results: dict[str, TaskResult] = {}
        self._tasks: dict[str, asyncio.Task[None]] = {}
        self._result_ttl = result_ttl

    def submit(
        self,
        fn: Callable[..., Coroutine[Any, Any, Any]],
        *args: Any,
        task_id: str | None = None,
        **kwargs: Any,
    ) -> str:
        """Schedule ``fn(*args, **kwargs)`` to run in the background.

        Args:
            fn: Coroutine function to execute.
            *args: Positional arguments forwarded to ``fn``.
            task_id: Identifier to track the task under. A random UUID4
                string is generated when omitted or empty.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Returns:
            The id under which the task is tracked.

        Raises:
            ValueError: If ``task_id`` belongs to a task that has not
                finished yet.
        """
        tid = task_id or str(uuid.uuid4())

        # Replacing a live entry would orphan the earlier task: it could no
        # longer be cancelled or queried through this queue.
        previous = self._tasks.get(tid)
        if previous is not None and not previous.done():
            raise ValueError(f"task_id {tid!r} is already in use by an active task")

        # The runner keeps its own reference to the record instead of looking
        # it up by id, so it can never write into a different task's record.
        record = TaskResult(task_id=tid, status=TaskStatus.PENDING)

        async def _runner() -> None:
            async with self._semaphore:
                record.status = TaskStatus.RUNNING
                # Monotonic clock: the duration must not be affected by
                # wall-clock adjustments.
                started = time.monotonic()
                try:
                    record.result = await fn(*args, **kwargs)
                    record.status = TaskStatus.COMPLETED
                except Exception as exc:
                    # Some exceptions stringify to "", which would look like
                    # "no error"; fall back to the exception class name.
                    record.error = str(exc) or type(exc).__name__
                    record.status = TaskStatus.FAILED
                finally:
                    record.completed_at = time.time()
                    record.duration_ms = (time.monotonic() - started) * 1000

        def _on_done(finished: asyncio.Task[None]) -> None:
            # A task cancelled while waiting for the semaphore (or before it
            # ever started) never reaches the ``finally`` above. Finalize the
            # record here so it is marked CANCELLED and can expire.
            if finished.cancelled():
                record.status = TaskStatus.CANCELLED
                if record.completed_at is None:
                    record.completed_at = time.time()

        loop = asyncio.get_event_loop()
        task = loop.create_task(_runner())
        task.add_done_callback(_on_done)

        # Register only after scheduling succeeded, so a failure above leaves
        # no stale PENDING entry behind.
        self._results[tid] = record
        self._tasks[tid] = task
        return tid

    def get_status(self, task_id: str) -> TaskResult | None:
        """Return the record tracked under ``task_id``, or ``None`` if unknown."""
        return self._results.get(task_id)

    def cancel(self, task_id: str) -> bool:
        """Request cancellation of a pending or running task.

        Returns:
            ``True`` if cancellation was requested; ``False`` if the id is
            unknown or the task has already finished.
        """
        task = self._tasks.get(task_id)
        if task is None or task.done():
            return False

        task.cancel()
        record = self._results.get(task_id)
        if record is not None:
            # Reflect the request immediately; timestamps are filled in once
            # the task actually stops.
            record.status = TaskStatus.CANCELLED
        return True

    def list_tasks(self, status: TaskStatus | None = None) -> list[TaskResult]:
        """Return all tracked records, optionally only those in ``status``."""
        if status is None:
            return list(self._results.values())
        return [record for record in self._results.values() if record.status == status]

    def cleanup_expired(self) -> int:
        """Drop records of tasks that finished more than ``result_ttl`` seconds ago.

        A record qualifies once it has a ``completed_at`` timestamp, whatever
        its final status.

        Returns:
            The number of records removed.
        """
        now = time.time()
        expired = [
            tid
            for tid, record in self._results.items()
            if record.completed_at is not None
            and now - record.completed_at > self._result_ttl
        ]
        for tid in expired:
            del self._results[tid]
            self._tasks.pop(tid, None)
        return len(expired)
|