65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
|
|
"""Request tracing middleware for structured logging with correlation IDs."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import contextvars
|
||
|
|
import uuid
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
# Holds the trace ID of the request being handled in the current execution
# context; the empty string means "no trace ID has been set here".
trace_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
    "trace_id",
    default="",
)
|
||
|
|
|
||
|
|
|
||
|
|
def get_trace_id() -> str:
    """Return the trace ID bound to the current context.

    Returns:
        The stored trace ID, or ``""`` when nothing (or a falsy value) has
        been stored in this context.
    """
    current = trace_id_var.get()
    if not current:
        # Normalise any falsy stored value to the empty string.
        return ""
    return current
|
||
|
|
|
||
|
|
|
||
|
|
def set_trace_id(trace_id: str) -> None:
    """Bind a trace ID to the current execution context.

    Args:
        trace_id: The ID that ``get_trace_id()`` should report for the rest
            of this context.
    """
    trace_id_var.set(trace_id)
|
||
|
|
|
||
|
|
|
||
|
|
def generate_trace_id() -> str:
    """Create a fresh random trace ID.

    Returns:
        The leading 8 hex digits (lowercase) of a random version-4 UUID,
        e.g. ``"3f2a9c1e"``.
    """
    # NOTE(review): 8 hex digits carry only 32 random bits, so collisions
    # become likely once tens of thousands of IDs are in play; consider a
    # longer ID if logs are correlated across high request volumes.
    return uuid.uuid4().hex[:8]
|
||
|
|
|
||
|
|
|
||
|
|
class TracingMiddleware:
    """ASGI middleware that sets/propagates request trace IDs.

    Extracts the trace ID from the configured request header (``X-Request-ID``
    by default), or generates a new one when that header is missing or blank.
    Stores the ID in the logging context via ``set_trace_id()`` and echoes it
    in HTTP responses as both ``x-request-id`` and ``x-trace-id``.

    Only ``http`` and ``websocket`` scopes are traced; any other scope type is
    passed straight through to the wrapped app. Response headers are added
    only to ``http.response.start`` messages.
    """

    def __init__(self, app: Any, header_name: str = "X-Request-ID") -> None:
        """Wrap an ASGI application.

        Args:
            app: The downstream ASGI application to call.
            header_name: Request header to read an incoming trace ID from;
                matched case-insensitively.
        """
        self.app = app
        # Lowercased once so each request can compare case-insensitively.
        self.header_name = header_name.lower()

    def _extract_trace_id(self, scope: dict[str, Any]) -> str:
        """Return the incoming trace ID from the scope's headers, or "" if absent/blank."""
        # dict() keeps the last value when the same header name repeats.
        headers = dict(scope.get("headers", []))
        for name, value in headers.items():
            if isinstance(name, bytes) and name.decode("latin-1").lower() == self.header_name:
                raw = value.decode("latin-1") if isinstance(value, bytes) else str(value)
                # A whitespace-only value carries no usable ID; treat it as missing.
                return raw.strip()
        return ""

    async def __call__(self, scope: dict[str, Any], receive: Any, send: Any) -> None:
        """ASGI entrypoint."""
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        trace_id = self._extract_trace_id(scope) or generate_trace_id()
        # NOTE(review): the context var is never reset after the app returns;
        # this relies on each request running in its own context (e.g. its own
        # task) so the ID does not leak into unrelated work — confirm.
        set_trace_id(trace_id)

        # Encode with latin-1 so a client-supplied ID (decoded with latin-1
        # above) is echoed back byte-for-byte; UTF-8 would turn every byte
        # >= 0x80 into two bytes. The UTF-8 fallback only covers values latin-1
        # cannot represent (e.g. a non-bytes header value), matching the
        # previous behaviour for those.
        try:
            trace_bytes = trace_id.encode("latin-1")
        except UnicodeEncodeError:
            trace_bytes = trace_id.encode("utf-8")

        async def send_with_trace(message: dict[str, Any]) -> None:
            if message["type"] == "http.response.start":
                headers_list = list(message.get("headers", []))
                headers_list.append((b"x-request-id", trace_bytes))
                headers_list.append((b"x-trace-id", trace_bytes))
                message["headers"] = headers_list
            await send(message)

        await self.app(scope, receive, send_with_trace)
|