Security and Code Quality Remediation (M6-Fixes) #343
299
apps/coordinator/src/circuit_breaker.py
Normal file
299
apps/coordinator/src/circuit_breaker.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Circuit breaker pattern for preventing infinite retry loops.
|
||||
|
||||
This module provides a CircuitBreaker class that implements the circuit breaker
|
||||
pattern to protect against cascading failures in coordinator loops.
|
||||
|
||||
Circuit breaker states:
|
||||
- CLOSED: Normal operation, requests pass through
|
||||
- OPEN: After N consecutive failures, all requests are blocked
|
||||
- HALF_OPEN: After cooldown, allow one request to test recovery
|
||||
|
||||
Reference: SEC-ORCH-7 from security review
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)


class CircuitState(str, Enum):
    """States for the circuit breaker."""

    CLOSED = "closed"  # Normal operation, requests pass through
    OPEN = "open"  # Blocking requests after repeated failures
    HALF_OPEN = "half_open"  # Cooldown elapsed; allow a test request


class CircuitBreakerError(Exception):
    """Exception raised when circuit is open and blocking requests."""

    def __init__(self, state: CircuitState, time_until_retry: float) -> None:
        """Initialize CircuitBreakerError.

        Args:
            state: Current circuit state
            time_until_retry: Seconds until circuit may close
        """
        self.state = state
        self.time_until_retry = time_until_retry
        super().__init__(
            f"Circuit breaker is {state.value}. "
            f"Retry in {time_until_retry:.1f} seconds."
        )


class CircuitBreaker:
    """Circuit breaker for protecting against cascading failures.

    The circuit breaker tracks consecutive failures and opens the circuit
    after a threshold is reached, preventing further requests until a
    cooldown period has elapsed.

    All interval timing uses ``time.monotonic()`` so that wall-clock
    adjustments (NTP corrections, DST, manual clock changes) cannot shorten
    or extend the cooldown window.

    Attributes:
        name: Identifier for this circuit breaker (for logging)
        failure_threshold: Number of consecutive failures before opening
        cooldown_seconds: Seconds to wait before allowing retry
        state: Current circuit state
        failure_count: Current consecutive failure count
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        cooldown_seconds: float = 30.0,
    ) -> None:
        """Initialize CircuitBreaker.

        Args:
            name: Identifier for this circuit breaker
            failure_threshold: Consecutive failures before opening (default: 5)
            cooldown_seconds: Seconds to wait before half-open (default: 30)
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.cooldown_seconds = cooldown_seconds

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        # Monotonic timestamp of the most recent failure (None until one occurs).
        self._last_failure_time: float | None = None
        self._total_failures = 0
        self._total_successes = 0
        self._state_transitions = 0

    @property
    def state(self) -> CircuitState:
        """Get the current circuit state.

        Reading this property also performs the automatic OPEN -> HALF_OPEN
        transition once the cooldown period has elapsed (the transition is
        lazy: it happens on access, not on a timer).

        Returns:
            Current CircuitState
        """
        if self._state == CircuitState.OPEN and self._last_failure_time is not None:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self.cooldown_seconds:
                self._transition_to(CircuitState.HALF_OPEN)
        return self._state

    @property
    def failure_count(self) -> int:
        """Get current consecutive failure count.

        Returns:
            Number of consecutive failures
        """
        return self._failure_count

    @property
    def total_failures(self) -> int:
        """Get total failure count (all-time).

        Returns:
            Total number of failures
        """
        return self._total_failures

    @property
    def total_successes(self) -> int:
        """Get total success count (all-time).

        Returns:
            Total number of successes
        """
        return self._total_successes

    @property
    def state_transitions(self) -> int:
        """Get total state transition count.

        Returns:
            Number of state transitions
        """
        return self._state_transitions

    @property
    def time_until_retry(self) -> float:
        """Get time remaining until retry is allowed.

        Returns:
            Seconds until circuit may transition to half-open, or 0.0 when
            the circuit is not open (or has never recorded a failure)
        """
        if self._state != CircuitState.OPEN or self._last_failure_time is None:
            return 0.0

        elapsed = time.monotonic() - self._last_failure_time
        return max(0.0, self.cooldown_seconds - elapsed)

    def can_execute(self) -> bool:
        """Check if a request can be executed.

        Returns:
            True if request can proceed (CLOSED, or HALF_OPEN to allow a
            recovery probe), False while the circuit is OPEN
        """
        # Reading the property performs the OPEN -> HALF_OPEN cooldown check.
        current_state = self.state
        # NOTE: HALF_OPEN does not serialize concurrent probes; callers in
        # this codebase poll from a single loop per breaker.
        return current_state != CircuitState.OPEN

    def record_success(self) -> None:
        """Record a successful operation.

        Resets the consecutive failure count and, if the circuit was
        HALF_OPEN (i.e. this success was the recovery probe), closes it.
        """
        self._total_successes += 1

        if self._state == CircuitState.HALF_OPEN:
            logger.info(
                "Circuit breaker '%s': Recovery confirmed, closing circuit",
                self.name,
            )
            self._transition_to(CircuitState.CLOSED)

        # Reset failure count on any success
        self._failure_count = 0
        logger.debug(
            "Circuit breaker '%s': Success recorded, failure count reset", self.name
        )

    def record_failure(self) -> None:
        """Record a failed operation.

        Increments the consecutive failure count and may open the circuit:
        a failure while HALF_OPEN reopens immediately; otherwise the circuit
        opens once failure_threshold consecutive failures accumulate.
        """
        self._failure_count += 1
        self._total_failures += 1
        self._last_failure_time = time.monotonic()

        logger.warning(
            "Circuit breaker '%s': Failure recorded (%d/%d)",
            self.name,
            self._failure_count,
            self.failure_threshold,
        )

        if self._state == CircuitState.HALF_OPEN:
            # Failed during the recovery probe: go straight back to open,
            # restarting the cooldown from this failure's timestamp.
            logger.warning(
                "Circuit breaker '%s': Test request failed, reopening circuit",
                self.name,
            )
            self._transition_to(CircuitState.OPEN)
        elif self._failure_count >= self.failure_threshold:
            logger.error(
                "Circuit breaker '%s': Failure threshold reached, opening circuit",
                self.name,
            )
            self._transition_to(CircuitState.OPEN)

    def reset(self) -> None:
        """Reset the circuit breaker to its initial (closed) state.

        Clears the consecutive failure count and cooldown timing; cumulative
        totals (total_failures/total_successes) are preserved. Intended for
        testing or manual intervention only. Bypasses _transition_to, so a
        manual reset does not count as a state transition.
        """
        old_state = self._state
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time = None

        logger.info(
            "Circuit breaker '%s': Manual reset (was %s, now closed)",
            self.name,
            old_state.value,
        )

    def _transition_to(self, new_state: CircuitState) -> None:
        """Transition to a new state and count the transition.

        Args:
            new_state: The state to transition to
        """
        old_state = self._state
        self._state = new_state
        self._state_transitions += 1

        logger.info(
            "Circuit breaker '%s': State transition %s -> %s",
            self.name,
            old_state.value,
            new_state.value,
        )

    def get_stats(self) -> dict[str, Any]:
        """Get circuit breaker statistics.

        Returns:
            Dictionary with current stats
        """
        return {
            "name": self.name,
            # Use the property so an elapsed cooldown is reported as half_open.
            "state": self.state.value,
            "failure_count": self._failure_count,
            "failure_threshold": self.failure_threshold,
            "cooldown_seconds": self.cooldown_seconds,
            "time_until_retry": self.time_until_retry,
            "total_failures": self._total_failures,
            "total_successes": self._total_successes,
            "state_transitions": self._state_transitions,
        }

    async def execute(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute a function with circuit breaker protection.

        This is a convenience method that wraps async function execution
        with automatic success/failure recording.

        Args:
            func: Async function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            Result of the function execution

        Raises:
            CircuitBreakerError: If circuit is open
            Exception: If function raises and circuit is closed/half-open
        """
        if not self.can_execute():
            raise CircuitBreakerError(self.state, self.time_until_retry)

        try:
            result = await func(*args, **kwargs)
        except Exception:
            self.record_failure()
            raise
        self.record_success()
        return result
||||
@@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from src.circuit_breaker import CircuitBreaker
|
||||
from src.context_compaction import CompactionResult, ContextCompactor, SessionRotation
|
||||
from src.models import ContextAction, ContextUsage
|
||||
|
||||
@@ -19,17 +20,29 @@ class ContextMonitor:
|
||||
Triggers appropriate actions based on defined thresholds:
|
||||
- 80% (COMPACT_THRESHOLD): Trigger context compaction
|
||||
- 95% (ROTATE_THRESHOLD): Trigger session rotation
|
||||
|
||||
Circuit Breaker (SEC-ORCH-7):
|
||||
- Per-agent circuit breakers prevent infinite retry loops on API failures
|
||||
- After failure_threshold consecutive failures, backs off for cooldown_seconds
|
||||
"""
|
||||
|
||||
COMPACT_THRESHOLD = 0.80 # 80% triggers compaction
|
||||
ROTATE_THRESHOLD = 0.95 # 95% triggers rotation
|
||||
|
||||
def __init__(self, api_client: Any, poll_interval: float = 10.0) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
api_client: Any,
|
||||
poll_interval: float = 10.0,
|
||||
circuit_breaker_threshold: int = 3,
|
||||
circuit_breaker_cooldown: float = 60.0,
|
||||
) -> None:
|
||||
"""Initialize context monitor.
|
||||
|
||||
Args:
|
||||
api_client: Claude API client for fetching context usage
|
||||
poll_interval: Seconds between polls (default: 10s)
|
||||
circuit_breaker_threshold: Consecutive failures before opening circuit (default: 3)
|
||||
circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 60)
|
||||
"""
|
||||
self.api_client = api_client
|
||||
self.poll_interval = poll_interval
|
||||
@@ -37,6 +50,11 @@ class ContextMonitor:
|
||||
self._monitoring_tasks: dict[str, bool] = {}
|
||||
self._compactor = ContextCompactor(api_client=api_client)
|
||||
|
||||
# Circuit breaker settings for per-agent monitoring loops (SEC-ORCH-7)
|
||||
self._circuit_breaker_threshold = circuit_breaker_threshold
|
||||
self._circuit_breaker_cooldown = circuit_breaker_cooldown
|
||||
self._circuit_breakers: dict[str, CircuitBreaker] = {}
|
||||
|
||||
async def get_context_usage(self, agent_id: str) -> ContextUsage:
|
||||
"""Get current context usage for an agent.
|
||||
|
||||
@@ -98,6 +116,36 @@ class ContextMonitor:
|
||||
"""
|
||||
return self._usage_history[agent_id]
|
||||
|
||||
def _get_circuit_breaker(self, agent_id: str) -> CircuitBreaker:
    """Return the agent's circuit breaker, creating one on first use.

    Args:
        agent_id: Unique identifier for the agent

    Returns:
        CircuitBreaker instance for this agent
    """
    breaker = self._circuit_breakers.get(agent_id)
    if breaker is None:
        # Lazily create one breaker per agent with the monitor-wide settings.
        breaker = CircuitBreaker(
            name=f"context_monitor_{agent_id}",
            failure_threshold=self._circuit_breaker_threshold,
            cooldown_seconds=self._circuit_breaker_cooldown,
        )
        self._circuit_breakers[agent_id] = breaker
    return breaker
|
||||
|
||||
def get_circuit_breaker_stats(self, agent_id: str) -> dict[str, Any]:
    """Get circuit breaker statistics for an agent.

    Args:
        agent_id: Unique identifier for the agent

    Returns:
        Dictionary with circuit breaker stats, or empty dict if no breaker exists
    """
    breaker = self._circuit_breakers.get(agent_id)
    return breaker.get_stats() if breaker is not None else {}
|
||||
|
||||
async def start_monitoring(
|
||||
self, agent_id: str, callback: Callable[[str, ContextAction], None]
|
||||
) -> None:
|
||||
@@ -106,22 +154,46 @@ class ContextMonitor:
|
||||
Polls context usage at regular intervals and calls callback with
|
||||
appropriate actions when thresholds are crossed.
|
||||
|
||||
Uses circuit breaker to prevent infinite retry loops on repeated failures.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the agent
|
||||
callback: Function to call with (agent_id, action) on each poll
|
||||
"""
|
||||
self._monitoring_tasks[agent_id] = True
|
||||
circuit_breaker = self._get_circuit_breaker(agent_id)
|
||||
|
||||
logger.info(
|
||||
f"Started monitoring agent {agent_id} (poll interval: {self.poll_interval}s)"
|
||||
)
|
||||
|
||||
while self._monitoring_tasks.get(agent_id, False):
|
||||
# Check circuit breaker state before polling
|
||||
if not circuit_breaker.can_execute():
|
||||
wait_time = circuit_breaker.time_until_retry
|
||||
logger.warning(
|
||||
f"Circuit breaker OPEN for agent {agent_id} - "
|
||||
f"backing off for {wait_time:.1f}s"
|
||||
)
|
||||
try:
|
||||
await asyncio.sleep(wait_time)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
continue
|
||||
|
||||
try:
|
||||
action = await self.determine_action(agent_id)
|
||||
callback(agent_id, action)
|
||||
# Successful poll - record success
|
||||
circuit_breaker.record_success()
|
||||
except Exception as e:
|
||||
logger.error(f"Error monitoring agent {agent_id}: {e}")
|
||||
# Continue monitoring despite errors
|
||||
# Record failure in circuit breaker
|
||||
circuit_breaker.record_failure()
|
||||
logger.error(
|
||||
f"Error monitoring agent {agent_id}: {e} "
|
||||
f"(circuit breaker: {circuit_breaker.state.value}, "
|
||||
f"failures: {circuit_breaker.failure_count}/{circuit_breaker.failure_threshold})"
|
||||
)
|
||||
|
||||
# Wait for next poll (or until stopped)
|
||||
try:
|
||||
@@ -129,7 +201,15 @@ class ContextMonitor:
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
logger.info(f"Stopped monitoring agent {agent_id}")
|
||||
# Clean up circuit breaker when monitoring stops
|
||||
if agent_id in self._circuit_breakers:
|
||||
stats = self._circuit_breakers[agent_id].get_stats()
|
||||
del self._circuit_breakers[agent_id]
|
||||
logger.info(
|
||||
f"Stopped monitoring agent {agent_id} (circuit breaker stats: {stats})"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Stopped monitoring agent {agent_id}")
|
||||
|
||||
def stop_monitoring(self, agent_id: str) -> None:
|
||||
"""Stop background monitoring for an agent.
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.circuit_breaker import CircuitBreaker, CircuitBreakerError, CircuitState
|
||||
from src.context_monitor import ContextMonitor
|
||||
from src.forced_continuation import ForcedContinuationService
|
||||
from src.models import ContextAction
|
||||
@@ -24,20 +25,30 @@ class Coordinator:
|
||||
- Monitoring the queue for ready items
|
||||
- Spawning agents to process issues (stub implementation for Phase 0)
|
||||
- Marking items as complete when processing finishes
|
||||
- Handling errors gracefully
|
||||
- Handling errors gracefully with circuit breaker protection
|
||||
- Supporting graceful shutdown
|
||||
|
||||
Circuit Breaker (SEC-ORCH-7):
|
||||
- Tracks consecutive failures in the main loop
|
||||
- After failure_threshold consecutive failures, enters OPEN state
|
||||
- In OPEN state, backs off for cooldown_seconds before retrying
|
||||
- Prevents infinite retry loops on repeated failures
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queue_manager: QueueManager,
|
||||
poll_interval: float = 5.0,
|
||||
circuit_breaker_threshold: int = 5,
|
||||
circuit_breaker_cooldown: float = 30.0,
|
||||
) -> None:
|
||||
"""Initialize the Coordinator.
|
||||
|
||||
Args:
|
||||
queue_manager: QueueManager instance for queue operations
|
||||
poll_interval: Seconds between queue polls (default: 5.0)
|
||||
circuit_breaker_threshold: Consecutive failures before opening circuit (default: 5)
|
||||
circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 30)
|
||||
"""
|
||||
self.queue_manager = queue_manager
|
||||
self.poll_interval = poll_interval
|
||||
@@ -45,6 +56,13 @@ class Coordinator:
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._active_agents: dict[int, dict[str, Any]] = {}
|
||||
|
||||
# Circuit breaker for preventing infinite retry loops (SEC-ORCH-7)
|
||||
self._circuit_breaker = CircuitBreaker(
|
||||
name="coordinator_loop",
|
||||
failure_threshold=circuit_breaker_threshold,
|
||||
cooldown_seconds=circuit_breaker_cooldown,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the coordinator is currently running.
|
||||
@@ -71,10 +89,28 @@ class Coordinator:
|
||||
"""
|
||||
return len(self._active_agents)
|
||||
|
||||
@property
|
||||
def circuit_breaker(self) -> CircuitBreaker:
|
||||
"""Get the circuit breaker instance.
|
||||
|
||||
Returns:
|
||||
CircuitBreaker instance for this coordinator
|
||||
"""
|
||||
return self._circuit_breaker
|
||||
|
||||
def get_circuit_breaker_stats(self) -> dict[str, Any]:
|
||||
"""Get circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with circuit breaker stats
|
||||
"""
|
||||
return self._circuit_breaker.get_stats()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the orchestration loop.
|
||||
|
||||
Continuously processes the queue until stop() is called.
|
||||
Uses circuit breaker to prevent infinite retry loops on repeated failures.
|
||||
"""
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
@@ -82,11 +118,32 @@ class Coordinator:
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
# Check circuit breaker state before processing
|
||||
if not self._circuit_breaker.can_execute():
|
||||
# Circuit is open - wait for cooldown
|
||||
wait_time = self._circuit_breaker.time_until_retry
|
||||
logger.warning(
|
||||
f"Circuit breaker OPEN - backing off for {wait_time:.1f}s "
|
||||
f"(failures: {self._circuit_breaker.failure_count})"
|
||||
)
|
||||
await self._wait_for_cooldown_or_stop(wait_time)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.process_queue()
|
||||
# Successful processing - record success
|
||||
self._circuit_breaker.record_success()
|
||||
except CircuitBreakerError as e:
|
||||
# Circuit breaker blocked the request
|
||||
logger.warning(f"Circuit breaker blocked request: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_queue: {e}")
|
||||
# Continue running despite errors
|
||||
# Record failure in circuit breaker
|
||||
self._circuit_breaker.record_failure()
|
||||
logger.error(
|
||||
f"Error in process_queue: {e} "
|
||||
f"(circuit breaker: {self._circuit_breaker.state.value}, "
|
||||
f"failures: {self._circuit_breaker.failure_count}/{self._circuit_breaker.failure_threshold})"
|
||||
)
|
||||
|
||||
# Wait for poll interval or stop signal
|
||||
try:
|
||||
@@ -102,7 +159,26 @@ class Coordinator:
|
||||
|
||||
finally:
|
||||
self._running = False
|
||||
logger.info("Coordinator stopped")
|
||||
logger.info(
|
||||
f"Coordinator stopped "
|
||||
f"(circuit breaker stats: {self._circuit_breaker.get_stats()})"
|
||||
)
|
||||
|
||||
async def _wait_for_cooldown_or_stop(self, cooldown: float) -> None:
|
||||
"""Wait for cooldown period or stop signal, whichever comes first.
|
||||
|
||||
Args:
|
||||
cooldown: Seconds to wait for cooldown
|
||||
"""
|
||||
if self._stop_event is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._stop_event.wait(), timeout=cooldown)
|
||||
# Stop was requested during cooldown
|
||||
except TimeoutError:
|
||||
# Cooldown completed, continue
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the orchestration loop gracefully.
|
||||
@@ -200,6 +276,12 @@ class OrchestrationLoop:
|
||||
- Quality gate verification on completion claims
|
||||
- Rejection handling with forced continuation prompts
|
||||
- Context monitoring during agent execution
|
||||
|
||||
Circuit Breaker (SEC-ORCH-7):
|
||||
- Tracks consecutive failures in the main loop
|
||||
- After failure_threshold consecutive failures, enters OPEN state
|
||||
- In OPEN state, backs off for cooldown_seconds before retrying
|
||||
- Prevents infinite retry loops on repeated failures
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -209,6 +291,8 @@ class OrchestrationLoop:
|
||||
continuation_service: ForcedContinuationService,
|
||||
context_monitor: ContextMonitor,
|
||||
poll_interval: float = 5.0,
|
||||
circuit_breaker_threshold: int = 5,
|
||||
circuit_breaker_cooldown: float = 30.0,
|
||||
) -> None:
|
||||
"""Initialize the OrchestrationLoop.
|
||||
|
||||
@@ -218,6 +302,8 @@ class OrchestrationLoop:
|
||||
continuation_service: ForcedContinuationService for rejection prompts
|
||||
context_monitor: ContextMonitor for tracking agent context usage
|
||||
poll_interval: Seconds between queue polls (default: 5.0)
|
||||
circuit_breaker_threshold: Consecutive failures before opening circuit (default: 5)
|
||||
circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 30)
|
||||
"""
|
||||
self.queue_manager = queue_manager
|
||||
self.quality_orchestrator = quality_orchestrator
|
||||
@@ -233,6 +319,13 @@ class OrchestrationLoop:
|
||||
self._success_count = 0
|
||||
self._rejection_count = 0
|
||||
|
||||
# Circuit breaker for preventing infinite retry loops (SEC-ORCH-7)
|
||||
self._circuit_breaker = CircuitBreaker(
|
||||
name="orchestration_loop",
|
||||
failure_threshold=circuit_breaker_threshold,
|
||||
cooldown_seconds=circuit_breaker_cooldown,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the orchestration loop is currently running.
|
||||
@@ -286,10 +379,28 @@ class OrchestrationLoop:
|
||||
"""
|
||||
return len(self._active_agents)
|
||||
|
||||
@property
|
||||
def circuit_breaker(self) -> CircuitBreaker:
|
||||
"""Get the circuit breaker instance.
|
||||
|
||||
Returns:
|
||||
CircuitBreaker instance for this orchestration loop
|
||||
"""
|
||||
return self._circuit_breaker
|
||||
|
||||
def get_circuit_breaker_stats(self) -> dict[str, Any]:
|
||||
"""Get circuit breaker statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with circuit breaker stats
|
||||
"""
|
||||
return self._circuit_breaker.get_stats()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the orchestration loop.
|
||||
|
||||
Continuously processes the queue until stop() is called.
|
||||
Uses circuit breaker to prevent infinite retry loops on repeated failures.
|
||||
"""
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
@@ -297,11 +408,32 @@ class OrchestrationLoop:
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
# Check circuit breaker state before processing
|
||||
if not self._circuit_breaker.can_execute():
|
||||
# Circuit is open - wait for cooldown
|
||||
wait_time = self._circuit_breaker.time_until_retry
|
||||
logger.warning(
|
||||
f"Circuit breaker OPEN - backing off for {wait_time:.1f}s "
|
||||
f"(failures: {self._circuit_breaker.failure_count})"
|
||||
)
|
||||
await self._wait_for_cooldown_or_stop(wait_time)
|
||||
continue
|
||||
|
||||
try:
|
||||
await self.process_next_issue()
|
||||
# Successful processing - record success
|
||||
self._circuit_breaker.record_success()
|
||||
except CircuitBreakerError as e:
|
||||
# Circuit breaker blocked the request
|
||||
logger.warning(f"Circuit breaker blocked request: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in process_next_issue: {e}")
|
||||
# Continue running despite errors
|
||||
# Record failure in circuit breaker
|
||||
self._circuit_breaker.record_failure()
|
||||
logger.error(
|
||||
f"Error in process_next_issue: {e} "
|
||||
f"(circuit breaker: {self._circuit_breaker.state.value}, "
|
||||
f"failures: {self._circuit_breaker.failure_count}/{self._circuit_breaker.failure_threshold})"
|
||||
)
|
||||
|
||||
# Wait for poll interval or stop signal
|
||||
try:
|
||||
@@ -317,7 +449,26 @@ class OrchestrationLoop:
|
||||
|
||||
finally:
|
||||
self._running = False
|
||||
logger.info("OrchestrationLoop stopped")
|
||||
logger.info(
|
||||
f"OrchestrationLoop stopped "
|
||||
f"(circuit breaker stats: {self._circuit_breaker.get_stats()})"
|
||||
)
|
||||
|
||||
async def _wait_for_cooldown_or_stop(self, cooldown: float) -> None:
|
||||
"""Wait for cooldown period or stop signal, whichever comes first.
|
||||
|
||||
Args:
|
||||
cooldown: Seconds to wait for cooldown
|
||||
"""
|
||||
if self._stop_event is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self._stop_event.wait(), timeout=cooldown)
|
||||
# Stop was requested during cooldown
|
||||
except TimeoutError:
|
||||
# Cooldown completed, continue
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the orchestration loop gracefully.
|
||||
|
||||
495
apps/coordinator/tests/test_circuit_breaker.py
Normal file
495
apps/coordinator/tests/test_circuit_breaker.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""Tests for circuit breaker pattern implementation.
|
||||
|
||||
These tests verify the circuit breaker behavior:
|
||||
- State transitions (closed -> open -> half_open -> closed)
|
||||
- Failure counting and threshold detection
|
||||
- Cooldown timing
|
||||
- Success/failure recording
|
||||
- Execute wrapper method
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.circuit_breaker import CircuitBreaker, CircuitBreakerError, CircuitState
|
||||
|
||||
|
||||
class TestCircuitBreakerInitialization:
    """Behavior of a freshly constructed CircuitBreaker."""

    def test_default_initialization(self) -> None:
        """Defaults: threshold 5, cooldown 30s, closed, zero failures."""
        breaker = CircuitBreaker("test")

        assert breaker.name == "test"
        assert breaker.failure_threshold == 5
        assert breaker.cooldown_seconds == 30.0
        assert breaker.state == CircuitState.CLOSED
        assert breaker.failure_count == 0

    def test_custom_initialization(self) -> None:
        """Constructor arguments are stored as given."""
        breaker = CircuitBreaker(
            name="custom",
            failure_threshold=3,
            cooldown_seconds=10.0,
        )

        assert breaker.name == "custom"
        assert breaker.failure_threshold == 3
        assert breaker.cooldown_seconds == 10.0

    def test_initial_state_is_closed(self) -> None:
        """A brand-new breaker starts in the closed state."""
        assert CircuitBreaker("test").state == CircuitState.CLOSED

    def test_initial_can_execute_is_true(self) -> None:
        """A brand-new breaker lets requests through."""
        assert CircuitBreaker("test").can_execute() is True
|
||||
|
||||
|
||||
class TestCircuitBreakerFailureTracking:
    """Consecutive and cumulative failure/success accounting."""

    def test_failure_increments_count(self) -> None:
        """Each failure bumps the consecutive-failure counter by one."""
        breaker = CircuitBreaker("test", failure_threshold=5)

        for expected in (1, 2):
            breaker.record_failure()
            assert breaker.failure_count == expected

    def test_success_resets_failure_count(self) -> None:
        """A single success wipes out the consecutive-failure streak."""
        breaker = CircuitBreaker("test", failure_threshold=5)

        breaker.record_failure()
        breaker.record_failure()
        assert breaker.failure_count == 2

        breaker.record_success()
        assert breaker.failure_count == 0

    def test_total_failures_tracked(self) -> None:
        """total_failures counts every failure, surviving streak resets."""
        breaker = CircuitBreaker("test", failure_threshold=5)

        breaker.record_failure()
        breaker.record_failure()
        breaker.record_success()  # clears only the consecutive streak
        breaker.record_failure()

        assert breaker.failure_count == 1  # consecutive
        assert breaker.total_failures == 3  # all-time

    def test_total_successes_tracked(self) -> None:
        """total_successes counts every success regardless of failures."""
        breaker = CircuitBreaker("test")

        breaker.record_success()
        breaker.record_success()
        breaker.record_failure()
        breaker.record_success()

        assert breaker.total_successes == 3
|
||||
|
||||
|
||||
class TestCircuitBreakerStateTransitions:
|
||||
"""Tests for state transition behavior."""
|
||||
|
||||
def test_reaches_threshold_opens_circuit(self) -> None:
|
||||
"""Test circuit opens when failure threshold is reached."""
|
||||
cb = CircuitBreaker("test", failure_threshold=3)
|
||||
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
def test_open_circuit_blocks_execution(self) -> None:
|
||||
"""Test that open circuit blocks can_execute."""
|
||||
cb = CircuitBreaker("test", failure_threshold=2)
|
||||
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
|
||||
assert cb.state == CircuitState.OPEN
|
||||
assert cb.can_execute() is False
|
||||
|
||||
def test_cooldown_transitions_to_half_open(self) -> None:
|
||||
"""Test that cooldown period transitions circuit to half-open."""
|
||||
cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
|
||||
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
assert cb.state == CircuitState.OPEN
|
||||
|
||||
# Wait for cooldown
|
||||
time.sleep(0.15)
|
||||
|
||||
# Accessing state triggers transition
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
|
||||
def test_half_open_allows_one_request(self) -> None:
|
||||
"""Test that half-open state allows test request."""
|
||||
cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
|
||||
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
|
||||
time.sleep(0.15)
|
||||
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
assert cb.can_execute() is True
|
||||
|
||||
def test_half_open_success_closes_circuit(self) -> None:
|
||||
"""Test that success in half-open state closes circuit."""
|
||||
cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
|
||||
|
||||
cb.record_failure()
|
||||
cb.record_failure()
|
||||
|
||||
time.sleep(0.15)
|
||||
assert cb.state == CircuitState.HALF_OPEN
|
||||
|
||||
cb.record_success()
|
||||
assert cb.state == CircuitState.CLOSED
|
||||
|
||||
def test_half_open_failure_reopens_circuit(self) -> None:
    """A failed probe in the half-open state trips the circuit again."""
    breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)

    for _ in range(2):
        breaker.record_failure()

    time.sleep(0.15)
    assert breaker.state == CircuitState.HALF_OPEN

    # The probe failed, so the breaker should open immediately.
    breaker.record_failure()
    assert breaker.state == CircuitState.OPEN
||||
def test_state_transitions_counted(self) -> None:
    """Every state change increments the transition counter."""
    breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)

    assert breaker.state_transitions == 0

    breaker.record_failure()
    breaker.record_failure()  # CLOSED -> OPEN
    assert breaker.state_transitions == 1

    time.sleep(0.15)
    _ = breaker.state  # reading state performs OPEN -> HALF_OPEN
    assert breaker.state_transitions == 2

    breaker.record_success()  # HALF_OPEN -> CLOSED
    assert breaker.state_transitions == 3
|
||||
class TestCircuitBreakerCooldown:
    """Tests for cooldown timing behavior."""

    def test_time_until_retry_when_open(self) -> None:
        """An open circuit reports roughly the full cooldown remaining."""
        breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=1.0)

        breaker.record_failure()
        breaker.record_failure()

        # Just opened, so nearly the whole 1-second cooldown should remain.
        assert 0.9 <= breaker.time_until_retry <= 1.0

    def test_time_until_retry_decreases(self) -> None:
        """The remaining cooldown shrinks as wall-clock time passes."""
        breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=1.0)

        breaker.record_failure()
        breaker.record_failure()

        before = breaker.time_until_retry
        time.sleep(0.2)
        assert breaker.time_until_retry < before

    def test_time_until_retry_zero_when_closed(self) -> None:
        """A closed circuit has no retry delay."""
        breaker = CircuitBreaker("test")
        assert breaker.time_until_retry == 0.0

    def test_time_until_retry_zero_when_half_open(self) -> None:
        """A half-open circuit has no retry delay."""
        breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)

        breaker.record_failure()
        breaker.record_failure()
        time.sleep(0.15)

        assert breaker.state == CircuitState.HALF_OPEN
        assert breaker.time_until_retry == 0.0
||||
class TestCircuitBreakerReset:
    """Tests for manual reset behavior."""

    def test_reset_closes_circuit(self) -> None:
        """reset() forces an open circuit back to closed."""
        breaker = CircuitBreaker("test", failure_threshold=2)

        breaker.record_failure()
        breaker.record_failure()
        assert breaker.state == CircuitState.OPEN

        breaker.reset()
        assert breaker.state == CircuitState.CLOSED

    def test_reset_clears_failure_count(self) -> None:
        """reset() zeroes the consecutive-failure counter."""
        breaker = CircuitBreaker("test", failure_threshold=5)

        breaker.record_failure()
        breaker.record_failure()
        assert breaker.failure_count == 2

        breaker.reset()
        assert breaker.failure_count == 0

    def test_reset_from_half_open(self) -> None:
        """reset() also works from the half-open state."""
        breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)

        breaker.record_failure()
        breaker.record_failure()
        time.sleep(0.15)
        assert breaker.state == CircuitState.HALF_OPEN

        breaker.reset()
        assert breaker.state == CircuitState.CLOSED
||||
class TestCircuitBreakerStats:
    """Tests for statistics reporting."""

    def test_get_stats_returns_all_fields(self) -> None:
        """A fresh breaker reports every stats field with its initial value."""
        breaker = CircuitBreaker("test", failure_threshold=3, cooldown_seconds=15.0)

        stats = breaker.get_stats()

        expected = {
            "name": "test",
            "state": "closed",
            "failure_count": 0,
            "failure_threshold": 3,
            "cooldown_seconds": 15.0,
            "time_until_retry": 0.0,
            "total_failures": 0,
            "total_successes": 0,
            "state_transitions": 0,
        }
        for key, value in expected.items():
            assert stats[key] == value

    def test_stats_update_after_operations(self) -> None:
        """Counters and state reflect a mixed history of calls."""
        breaker = CircuitBreaker("test", failure_threshold=3)

        breaker.record_failure()
        breaker.record_success()  # resets the consecutive-failure count
        breaker.record_failure()
        breaker.record_failure()
        breaker.record_failure()  # third consecutive failure opens the circuit

        stats = breaker.get_stats()

        assert stats["state"] == "open"
        assert stats["failure_count"] == 3
        assert stats["total_failures"] == 4
        assert stats["total_successes"] == 1
        assert stats["state_transitions"] == 1
||||
class TestCircuitBreakerError:
    """Tests for CircuitBreakerError exception."""

    def test_error_contains_state(self) -> None:
        """The exception exposes the circuit state it was raised with."""
        err = CircuitBreakerError(CircuitState.OPEN, 10.0)
        assert err.state == CircuitState.OPEN

    def test_error_contains_retry_time(self) -> None:
        """The exception exposes the seconds-until-retry value."""
        err = CircuitBreakerError(CircuitState.OPEN, 10.5)
        assert err.time_until_retry == 10.5

    def test_error_message_formatting(self) -> None:
        """The message mentions both the state and the retry delay."""
        err = CircuitBreakerError(CircuitState.OPEN, 15.3)
        message = str(err)
        assert "open" in message
        assert "15.3" in message
||||
class TestCircuitBreakerExecute:
    """Tests for the execute wrapper method."""

    @pytest.mark.asyncio
    async def test_execute_calls_function(self) -> None:
        """execute() forwards args/kwargs and returns the function's result."""
        breaker = CircuitBreaker("test")
        wrapped = AsyncMock(return_value="success")

        outcome = await breaker.execute(wrapped, "arg1", kwarg="value")

        wrapped.assert_called_once_with("arg1", kwarg="value")
        assert outcome == "success"

    @pytest.mark.asyncio
    async def test_execute_records_success(self) -> None:
        """A successful call bumps the success counter."""
        breaker = CircuitBreaker("test")

        await breaker.execute(AsyncMock(return_value="ok"))

        assert breaker.total_successes == 1

    @pytest.mark.asyncio
    async def test_execute_records_failure(self) -> None:
        """A raising call propagates its exception and counts as a failure."""
        breaker = CircuitBreaker("test")
        wrapped = AsyncMock(side_effect=RuntimeError("test error"))

        with pytest.raises(RuntimeError):
            await breaker.execute(wrapped)

        assert breaker.failure_count == 1

    @pytest.mark.asyncio
    async def test_execute_raises_when_open(self) -> None:
        """Once open, execute() fails fast with CircuitBreakerError."""
        breaker = CircuitBreaker("test", failure_threshold=2)
        wrapped = AsyncMock(side_effect=RuntimeError("fail"))

        # Two failing calls trip the breaker.
        for _ in range(2):
            with pytest.raises(RuntimeError):
                await breaker.execute(wrapped)

        assert breaker.state == CircuitState.OPEN

        # Further calls are rejected before reaching the wrapped function.
        with pytest.raises(CircuitBreakerError) as exc_info:
            await breaker.execute(wrapped)

        assert exc_info.value.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_execute_allows_half_open_test(self) -> None:
        """After cooldown, one probe call may run and can close the circuit."""
        breaker = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
        wrapped = AsyncMock(side_effect=RuntimeError("fail"))

        for _ in range(2):
            with pytest.raises(RuntimeError):
                await breaker.execute(wrapped)

        # Wait out the cooldown so the breaker goes half-open.
        await asyncio.sleep(0.15)
        assert breaker.state == CircuitState.HALF_OPEN

        # The probe succeeds, which should close the circuit again.
        wrapped.side_effect = None
        wrapped.return_value = "recovered"

        outcome = await breaker.execute(wrapped)
        assert outcome == "recovered"
        assert breaker.state == CircuitState.CLOSED
||||
class TestCircuitBreakerConcurrency:
    """Tests for thread safety and concurrent access."""

    @pytest.mark.asyncio
    async def test_concurrent_failures(self) -> None:
        """Failures recorded from many tasks are all counted."""
        breaker = CircuitBreaker("test", failure_threshold=10)

        async def fail_once() -> None:
            breaker.record_failure()

        # Fire 10 failure recordings concurrently.
        await asyncio.gather(*(fail_once() for _ in range(10)))

        assert breaker.failure_count >= 10
        assert breaker.state == CircuitState.OPEN

    @pytest.mark.asyncio
    async def test_concurrent_mixed_operations(self) -> None:
        """Interleaved successes and failures are both recorded."""
        breaker = CircuitBreaker("test", failure_threshold=100)

        async def succeed_once() -> None:
            breaker.record_success()

        async def fail_once() -> None:
            breaker.record_failure()

        # Interleave failures and successes.
        jobs = [fail_once() for _ in range(5)]
        jobs += [succeed_once() for _ in range(3)]
        jobs += [fail_once() for _ in range(5)]

        await asyncio.gather(*jobs)

        # At least some of each kind must have been recorded.
        assert breaker.total_failures >= 5
        assert breaker.total_successes >= 1
||||
class TestCircuitBreakerLogging:
    """Tests for logging behavior."""

    def test_logs_state_transitions(self) -> None:
        """Opening the circuit emits an info log naming the transition."""
        breaker = CircuitBreaker("test", failure_threshold=2)

        with patch("src.circuit_breaker.logger") as mock_logger:
            breaker.record_failure()
            breaker.record_failure()

        mock_logger.info.assert_called()
        logged = [str(c) for c in mock_logger.info.call_args_list]
        assert any("closed -> open" in entry for entry in logged)

    def test_logs_failure_warnings(self) -> None:
        """Each recorded failure is logged at warning level."""
        breaker = CircuitBreaker("test", failure_threshold=5)

        with patch("src.circuit_breaker.logger") as mock_logger:
            breaker.record_failure()

        mock_logger.warning.assert_called()

    def test_logs_threshold_reached_as_error(self) -> None:
        """Hitting the failure threshold is logged at error level."""
        breaker = CircuitBreaker("test", failure_threshold=2)

        with patch("src.circuit_breaker.logger") as mock_logger:
            breaker.record_failure()
            breaker.record_failure()

        mock_logger.error.assert_called()
        logged = [str(c) for c in mock_logger.error.call_args_list]
        assert any("threshold reached" in entry for entry in logged)
||||
@@ -744,3 +744,186 @@ class TestCoordinatorActiveAgents:
|
||||
await coordinator.process_queue()
|
||||
|
||||
assert coordinator.get_active_agent_count() == 3
|
||||
|
||||
|
||||
class TestCoordinatorCircuitBreaker:
    """Tests for Coordinator circuit breaker integration (SEC-ORCH-7)."""

    @pytest.fixture
    def temp_queue_file(self) -> Generator[Path, None, None]:
        """Yield a temporary file path for queue persistence; delete it after."""
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f:
            temp_path = Path(f.name)
        yield temp_path
        if temp_path.exists():
            temp_path.unlink()

    @pytest.fixture
    def queue_manager(self, temp_queue_file: Path) -> QueueManager:
        """Create a queue manager backed by the temporary file."""
        return QueueManager(queue_file=temp_queue_file)

    @staticmethod
    async def _cancel_task(task: "asyncio.Task[None]") -> None:
        """Cancel a background coordinator task and await its termination.

        Shared teardown helper: the CancelledError raised by awaiting a
        cancelled task is expected and swallowed.
        """
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def test_circuit_breaker_initialized(self, queue_manager: QueueManager) -> None:
        """Test that circuit breaker is initialized with Coordinator."""
        from src.coordinator import Coordinator

        coordinator = Coordinator(queue_manager=queue_manager)

        assert coordinator.circuit_breaker is not None
        assert coordinator.circuit_breaker.name == "coordinator_loop"

    def test_circuit_breaker_custom_settings(self, queue_manager: QueueManager) -> None:
        """Test circuit breaker with custom threshold and cooldown."""
        from src.coordinator import Coordinator

        coordinator = Coordinator(
            queue_manager=queue_manager,
            circuit_breaker_threshold=3,
            circuit_breaker_cooldown=15.0,
        )

        assert coordinator.circuit_breaker.failure_threshold == 3
        assert coordinator.circuit_breaker.cooldown_seconds == 15.0

    def test_get_circuit_breaker_stats(self, queue_manager: QueueManager) -> None:
        """Test getting circuit breaker statistics."""
        from src.coordinator import Coordinator

        coordinator = Coordinator(queue_manager=queue_manager)

        stats = coordinator.get_circuit_breaker_stats()

        for key in ("name", "state", "failure_count", "total_failures"):
            assert key in stats
        assert stats["name"] == "coordinator_loop"
        assert stats["state"] == "closed"

    @pytest.mark.asyncio
    async def test_circuit_breaker_opens_on_repeated_failures(
        self, queue_manager: QueueManager
    ) -> None:
        """Test that circuit breaker opens after repeated failures."""
        from src.circuit_breaker import CircuitState
        from src.coordinator import Coordinator

        coordinator = Coordinator(
            queue_manager=queue_manager,
            poll_interval=0.02,
            circuit_breaker_threshold=3,
            circuit_breaker_cooldown=0.2,
        )

        failure_count = 0

        async def failing_process_queue() -> None:
            nonlocal failure_count
            failure_count += 1
            raise RuntimeError("Simulated failure")

        # Replace the real queue processing with a deterministic failure.
        coordinator.process_queue = failing_process_queue  # type: ignore[method-assign]

        task = asyncio.create_task(coordinator.start())
        await asyncio.sleep(0.15)  # Allow time for several failing iterations

        # Circuit should be open after 3 consecutive failures.
        assert coordinator.circuit_breaker.state == CircuitState.OPEN
        assert coordinator.circuit_breaker.failure_count >= 3

        await coordinator.stop()
        await self._cancel_task(task)

    @pytest.mark.asyncio
    async def test_circuit_breaker_backs_off_when_open(
        self, queue_manager: QueueManager
    ) -> None:
        """Test that coordinator backs off when circuit breaker is open."""
        from src.coordinator import Coordinator

        coordinator = Coordinator(
            queue_manager=queue_manager,
            poll_interval=0.02,
            circuit_breaker_threshold=2,
            circuit_breaker_cooldown=0.3,
        )

        call_timestamps: list[float] = []

        async def failing_process_queue() -> None:
            # get_running_loop() is the non-deprecated way to reach the loop
            # from inside a coroutine (get_event_loop() is soft-deprecated).
            call_timestamps.append(asyncio.get_running_loop().time())
            raise RuntimeError("Simulated failure")

        coordinator.process_queue = failing_process_queue  # type: ignore[method-assign]

        task = asyncio.create_task(coordinator.start())
        await asyncio.sleep(0.5)  # Allow time for failures and backoff
        await coordinator.stop()
        await self._cancel_task(task)

        # Should have at least 2 calls (to trigger open), then back off.
        assert len(call_timestamps) >= 2

        # After the circuit opens there should be a cooldown-sized gap.
        if len(call_timestamps) >= 3:
            first_gap = call_timestamps[1] - call_timestamps[0]
            later_gap = call_timestamps[2] - call_timestamps[1]
            # The later gap should be noticeably larger due to the cooldown.
            assert later_gap > first_gap * 2

    @pytest.mark.asyncio
    async def test_circuit_breaker_resets_on_success(
        self, queue_manager: QueueManager
    ) -> None:
        """Test that circuit breaker resets after successful operation."""
        from src.circuit_breaker import CircuitState
        from src.coordinator import Coordinator

        coordinator = Coordinator(
            queue_manager=queue_manager,
            poll_interval=0.02,
            circuit_breaker_threshold=3,
        )

        # Two failures below the threshold, then a success.
        coordinator.circuit_breaker.record_failure()
        coordinator.circuit_breaker.record_failure()
        assert coordinator.circuit_breaker.failure_count == 2

        coordinator.circuit_breaker.record_success()
        assert coordinator.circuit_breaker.failure_count == 0
        assert coordinator.circuit_breaker.state == CircuitState.CLOSED

    @pytest.mark.asyncio
    async def test_circuit_breaker_stats_logged_on_stop(
        self, queue_manager: QueueManager
    ) -> None:
        """Test that circuit breaker stats are logged when coordinator stops."""
        from src.coordinator import Coordinator

        coordinator = Coordinator(queue_manager=queue_manager, poll_interval=0.05)

        with patch("src.coordinator.logger") as mock_logger:
            task = asyncio.create_task(coordinator.start())
            await asyncio.sleep(0.1)
            await coordinator.stop()
            await self._cancel_task(task)

        # Should log circuit breaker stats on stop.
        info_calls = [str(call) for call in mock_logger.info.call_args_list]
        assert any("circuit breaker" in call.lower() for call in info_calls)
||||
Reference in New Issue
Block a user