diff --git a/apps/coordinator/src/circuit_breaker.py b/apps/coordinator/src/circuit_breaker.py
new file mode 100644
index 0000000..aa3c217
--- /dev/null
+++ b/apps/coordinator/src/circuit_breaker.py
@@ -0,0 +1,299 @@
+"""Circuit breaker pattern for preventing infinite retry loops.
+
+This module provides a CircuitBreaker class that implements the circuit breaker
+pattern to protect against cascading failures in coordinator loops.
+
+Circuit breaker states:
+- CLOSED: Normal operation, requests pass through
+- OPEN: After N consecutive failures, all requests are blocked
+- HALF_OPEN: After cooldown, allow one request to test recovery
+
+Reference: SEC-ORCH-7 from security review
+"""
+
+import logging
+import time
+from enum import Enum
+from typing import Any, Callable
+
+logger = logging.getLogger(__name__)
+
+
+class CircuitState(str, Enum):
+    """States for the circuit breaker."""
+
+    CLOSED = "closed"  # Normal operation
+    OPEN = "open"  # Blocking requests after failures
+    HALF_OPEN = "half_open"  # Testing if service recovered
+
+
+class CircuitBreakerError(Exception):
+    """Exception raised when circuit is open and blocking requests."""
+
+    def __init__(self, state: CircuitState, time_until_retry: float) -> None:
+        """Initialize CircuitBreakerError.
+
+        Args:
+            state: Current circuit state
+            time_until_retry: Seconds until circuit may close
+        """
+        self.state = state
+        self.time_until_retry = time_until_retry
+        super().__init__(
+            f"Circuit breaker is {state.value}. "
+            f"Retry in {time_until_retry:.1f} seconds."
+        )
+
+
+class CircuitBreaker:
+    """Circuit breaker for protecting against cascading failures.
+
+    The circuit breaker tracks consecutive failures and opens the circuit
+    after a threshold is reached, preventing further requests until a
+    cooldown period has elapsed.
+
+    Attributes:
+        name: Identifier for this circuit breaker (for logging)
+        failure_threshold: Number of consecutive failures before opening
+        cooldown_seconds: Seconds to wait before allowing retry
+        state: Current circuit state
+        failure_count: Current consecutive failure count
+    """
+
+    def __init__(
+        self,
+        name: str,
+        failure_threshold: int = 5,
+        cooldown_seconds: float = 30.0,
+    ) -> None:
+        """Initialize CircuitBreaker.
+
+        Args:
+            name: Identifier for this circuit breaker
+            failure_threshold: Consecutive failures before opening (default: 5)
+            cooldown_seconds: Seconds to wait before half-open (default: 30)
+        """
+        self.name = name
+        self.failure_threshold = failure_threshold
+        self.cooldown_seconds = cooldown_seconds
+
+        self._state = CircuitState.CLOSED
+        self._failure_count = 0
+        self._last_failure_time: float | None = None
+        self._total_failures = 0
+        self._total_successes = 0
+        self._state_transitions = 0
+
+    @property
+    def state(self) -> CircuitState:
+        """Get the current circuit state.
+
+        This also handles automatic state transitions based on cooldown.
+
+        Returns:
+            Current CircuitState
+        """
+        if self._state == CircuitState.OPEN:
+            # Check if cooldown has elapsed
+            if self._last_failure_time is not None:
+                elapsed = time.time() - self._last_failure_time
+                if elapsed >= self.cooldown_seconds:
+                    self._transition_to(CircuitState.HALF_OPEN)
+        return self._state
+
+    @property
+    def failure_count(self) -> int:
+        """Get current consecutive failure count.
+
+        Returns:
+            Number of consecutive failures
+        """
+        return self._failure_count
+
+    @property
+    def total_failures(self) -> int:
+        """Get total failure count (all-time).
+
+        Returns:
+            Total number of failures
+        """
+        return self._total_failures
+
+    @property
+    def total_successes(self) -> int:
+        """Get total success count (all-time).
+
+        Returns:
+            Total number of successes
+        """
+        return self._total_successes
+
+    @property
+    def state_transitions(self) -> int:
+        """Get total state transition count.
+
+        Returns:
+            Number of state transitions
+        """
+        return self._state_transitions
+
+    @property
+    def time_until_retry(self) -> float:
+        """Get time remaining until retry is allowed.
+
+        Returns:
+            Seconds until circuit may transition to half-open, or 0 if not open
+        """
+        if self._state != CircuitState.OPEN or self._last_failure_time is None:
+            return 0.0
+
+        elapsed = time.time() - self._last_failure_time
+        remaining = self.cooldown_seconds - elapsed
+        return max(0.0, remaining)
+
+    def can_execute(self) -> bool:
+        """Check if a request can be executed.
+
+        This method checks the current state and determines if a request
+        should be allowed through.
+
+        Returns:
+            True if request can proceed, False otherwise
+        """
+        current_state = self.state  # This handles cooldown transitions
+
+        if current_state == CircuitState.CLOSED:
+            return True
+        elif current_state == CircuitState.HALF_OPEN:
+            # Allow one test request
+            return True
+        else:  # OPEN
+            return False
+
+    def record_success(self) -> None:
+        """Record a successful operation.
+
+        This resets the failure count and closes the circuit if it was
+        in half-open state.
+        """
+        self._total_successes += 1
+
+        if self._state == CircuitState.HALF_OPEN:
+            logger.info(
+                f"Circuit breaker '{self.name}': Recovery confirmed, closing circuit"
+            )
+            self._transition_to(CircuitState.CLOSED)
+
+        # Reset failure count on any success
+        self._failure_count = 0
+        logger.debug(f"Circuit breaker '{self.name}': Success recorded, failure count reset")
+
+    def record_failure(self) -> None:
+        """Record a failed operation.
+
+        This increments the failure count and may open the circuit if
+        the threshold is reached.
+        """
+        self._failure_count += 1
+        self._total_failures += 1
+        self._last_failure_time = time.time()
+
+        logger.warning(
+            f"Circuit breaker '{self.name}': Failure recorded "
+            f"({self._failure_count}/{self.failure_threshold})"
+        )
+
+        if self._state == CircuitState.HALF_OPEN:
+            # Failed during test request, go back to open
+            logger.warning(
+                f"Circuit breaker '{self.name}': Test request failed, reopening circuit"
+            )
+            self._transition_to(CircuitState.OPEN)
+        elif self._failure_count >= self.failure_threshold:
+            logger.error(
+                f"Circuit breaker '{self.name}': Failure threshold reached, opening circuit"
+            )
+            self._transition_to(CircuitState.OPEN)
+
+    def reset(self) -> None:
+        """Reset the circuit breaker to initial state.
+
+        This should be used carefully, typically only for testing or
+        manual intervention.
+        """
+        old_state = self._state
+        self._state = CircuitState.CLOSED
+        self._failure_count = 0
+        self._last_failure_time = None
+
+        logger.info(
+            f"Circuit breaker '{self.name}': Manual reset "
+            f"(was {old_state.value}, now closed)"
+        )
+
+    def _transition_to(self, new_state: CircuitState) -> None:
+        """Transition to a new state.
+
+        Args:
+            new_state: The state to transition to
+        """
+        old_state = self._state
+        self._state = new_state
+        self._state_transitions += 1
+
+        logger.info(
+            f"Circuit breaker '{self.name}': State transition "
+            f"{old_state.value} -> {new_state.value}"
+        )
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get circuit breaker statistics.
+
+        Returns:
+            Dictionary with current stats
+        """
+        return {
+            "name": self.name,
+            "state": self.state.value,
+            "failure_count": self._failure_count,
+            "failure_threshold": self.failure_threshold,
+            "cooldown_seconds": self.cooldown_seconds,
+            "time_until_retry": self.time_until_retry,
+            "total_failures": self._total_failures,
+            "total_successes": self._total_successes,
+            "state_transitions": self._state_transitions,
+        }
+
+    async def execute(
+        self,
+        func: Callable[..., Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Execute a function with circuit breaker protection.
+
+        This is a convenience method that wraps async function execution
+        with automatic success/failure recording.
+
+        Args:
+            func: Async function to execute
+            *args: Positional arguments for the function
+            **kwargs: Keyword arguments for the function
+
+        Returns:
+            Result of the function execution
+
+        Raises:
+            CircuitBreakerError: If circuit is open
+            Exception: If function raises and circuit is closed/half-open
+        """
+        if not self.can_execute():
+            raise CircuitBreakerError(self.state, self.time_until_retry)
+
+        try:
+            result = await func(*args, **kwargs)
+            self.record_success()
+            return result
+        except Exception:
+            self.record_failure()
+            raise
diff --git a/apps/coordinator/src/context_monitor.py b/apps/coordinator/src/context_monitor.py
index 9c58c28..07d7d28 100644
--- a/apps/coordinator/src/context_monitor.py
+++ b/apps/coordinator/src/context_monitor.py
@@ -6,6 +6,7 @@
 from collections import defaultdict
 from collections.abc import Callable
 from typing import Any
+from src.circuit_breaker import CircuitBreaker
 from src.context_compaction import CompactionResult, ContextCompactor, SessionRotation
 from src.models import ContextAction, ContextUsage
@@ -19,17 +20,29 @@ class ContextMonitor:
     Triggers appropriate actions based on defined thresholds:
     - 80% (COMPACT_THRESHOLD): Trigger context compaction
     - 95% (ROTATE_THRESHOLD): Trigger session rotation
+
+    Circuit Breaker (SEC-ORCH-7):
+    - 
Per-agent circuit breakers prevent infinite retry loops on API failures
+    - After failure_threshold consecutive failures, backs off for cooldown_seconds
     """
 
     COMPACT_THRESHOLD = 0.80  # 80% triggers compaction
     ROTATE_THRESHOLD = 0.95  # 95% triggers rotation
 
-    def __init__(self, api_client: Any, poll_interval: float = 10.0) -> None:
+    def __init__(
+        self,
+        api_client: Any,
+        poll_interval: float = 10.0,
+        circuit_breaker_threshold: int = 3,
+        circuit_breaker_cooldown: float = 60.0,
+    ) -> None:
         """Initialize context monitor.
 
         Args:
             api_client: Claude API client for fetching context usage
             poll_interval: Seconds between polls (default: 10s)
+            circuit_breaker_threshold: Consecutive failures before opening circuit (default: 3)
+            circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 60)
         """
         self.api_client = api_client
         self.poll_interval = poll_interval
@@ -37,6 +50,11 @@ class ContextMonitor:
         self._monitoring_tasks: dict[str, bool] = {}
         self._compactor = ContextCompactor(api_client=api_client)
 
+        # Circuit breaker settings for per-agent monitoring loops (SEC-ORCH-7)
+        self._circuit_breaker_threshold = circuit_breaker_threshold
+        self._circuit_breaker_cooldown = circuit_breaker_cooldown
+        self._circuit_breakers: dict[str, CircuitBreaker] = {}
+
     async def get_context_usage(self, agent_id: str) -> ContextUsage:
         """Get current context usage for an agent.
 
@@ -98,6 +116,36 @@ class ContextMonitor:
         """
         return self._usage_history[agent_id]
 
+    def _get_circuit_breaker(self, agent_id: str) -> CircuitBreaker:
+        """Get or create circuit breaker for an agent.
+
+        Args:
+            agent_id: Unique identifier for the agent
+
+        Returns:
+            CircuitBreaker instance for this agent
+        """
+        if agent_id not in self._circuit_breakers:
+            self._circuit_breakers[agent_id] = CircuitBreaker(
+                name=f"context_monitor_{agent_id}",
+                failure_threshold=self._circuit_breaker_threshold,
+                cooldown_seconds=self._circuit_breaker_cooldown,
+            )
+        return self._circuit_breakers[agent_id]
+
+    def get_circuit_breaker_stats(self, agent_id: str) -> dict[str, Any]:
+        """Get circuit breaker statistics for an agent.
+
+        Args:
+            agent_id: Unique identifier for the agent
+
+        Returns:
+            Dictionary with circuit breaker stats, or empty dict if no breaker exists
+        """
+        if agent_id in self._circuit_breakers:
+            return self._circuit_breakers[agent_id].get_stats()
+        return {}
+
     async def start_monitoring(
         self, agent_id: str, callback: Callable[[str, ContextAction], None]
     ) -> None:
@@ -106,22 +154,46 @@ class ContextMonitor:
 
         Polls context usage at regular intervals and calls callback with
         appropriate actions when thresholds are crossed.
 
+        Uses circuit breaker to prevent infinite retry loops on repeated failures.
+
         Args:
             agent_id: Unique identifier for the agent
             callback: Function to call with (agent_id, action) on each poll
         """
         self._monitoring_tasks[agent_id] = True
 
+        circuit_breaker = self._get_circuit_breaker(agent_id)
+
         logger.info(
             f"Started monitoring agent {agent_id} (poll interval: {self.poll_interval}s)"
         )
 
         while self._monitoring_tasks.get(agent_id, False):
+            # Check circuit breaker state before polling
+            if not circuit_breaker.can_execute():
+                wait_time = circuit_breaker.time_until_retry
+                logger.warning(
+                    f"Circuit breaker OPEN for agent {agent_id} - "
+                    f"backing off for {wait_time:.1f}s"
+                )
+                try:
+                    await asyncio.sleep(wait_time)
+                except asyncio.CancelledError:
+                    break
+                continue
+
             try:
                 action = await self.determine_action(agent_id)
                 callback(agent_id, action)
+                # Successful poll - record success
+                circuit_breaker.record_success()
             except Exception as e:
-                logger.error(f"Error monitoring agent {agent_id}: {e}")
-                # Continue monitoring despite errors
+                # Record failure in circuit breaker
+                circuit_breaker.record_failure()
+                logger.error(
+                    f"Error monitoring agent {agent_id}: {e} "
+                    f"(circuit breaker: {circuit_breaker.state.value}, "
+                    f"failures: {circuit_breaker.failure_count}/{circuit_breaker.failure_threshold})"
+                )
 
             # Wait for next poll (or until stopped)
             try:
@@ -129,7 +201,15 @@ class ContextMonitor:
         except asyncio.CancelledError:
             break
 
-        logger.info(f"Stopped monitoring agent {agent_id}")
+        # Clean up circuit breaker when monitoring stops
+        if agent_id in self._circuit_breakers:
+            stats = self._circuit_breakers[agent_id].get_stats()
+            del self._circuit_breakers[agent_id]
+            logger.info(
+                f"Stopped monitoring agent {agent_id} (circuit breaker stats: {stats})"
+            )
+        else:
+            logger.info(f"Stopped monitoring agent {agent_id}")
 
     def stop_monitoring(self, agent_id: str) -> None:
         """Stop background monitoring for an agent.
diff --git a/apps/coordinator/src/coordinator.py b/apps/coordinator/src/coordinator.py
index 790b2f3..85ff078 100644
--- a/apps/coordinator/src/coordinator.py
+++ b/apps/coordinator/src/coordinator.py
@@ -4,6 +4,7 @@
 import asyncio
 import logging
 from typing import TYPE_CHECKING, Any
 
+from src.circuit_breaker import CircuitBreaker, CircuitBreakerError, CircuitState
 from src.context_monitor import ContextMonitor
 from src.forced_continuation import ForcedContinuationService
 from src.models import ContextAction
@@ -24,20 +25,30 @@ class Coordinator:
     - Monitoring the queue for ready items
     - Spawning agents to process issues (stub implementation for Phase 0)
     - Marking items as complete when processing finishes
-    - Handling errors gracefully
+    - Handling errors gracefully with circuit breaker protection
     - Supporting graceful shutdown
+
+    Circuit Breaker (SEC-ORCH-7):
+    - Tracks consecutive failures in the main loop
+    - After failure_threshold consecutive failures, enters OPEN state
+    - In OPEN state, backs off for cooldown_seconds before retrying
+    - Prevents infinite retry loops on repeated failures
     """
 
     def __init__(
         self,
         queue_manager: QueueManager,
         poll_interval: float = 5.0,
+        circuit_breaker_threshold: int = 5,
+        circuit_breaker_cooldown: float = 30.0,
     ) -> None:
         """Initialize the Coordinator.
 
         Args:
             queue_manager: QueueManager instance for queue operations
             poll_interval: Seconds between queue polls (default: 5.0)
+            circuit_breaker_threshold: Consecutive failures before opening circuit (default: 5)
+            circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 30)
         """
         self.queue_manager = queue_manager
         self.poll_interval = poll_interval
@@ -45,6 +56,13 @@ class Coordinator:
         self._stop_event: asyncio.Event | None = None
         self._active_agents: dict[int, dict[str, Any]] = {}
 
+        # Circuit breaker for preventing infinite retry loops (SEC-ORCH-7)
+        self._circuit_breaker = CircuitBreaker(
+            name="coordinator_loop",
+            failure_threshold=circuit_breaker_threshold,
+            cooldown_seconds=circuit_breaker_cooldown,
+        )
+
     @property
     def is_running(self) -> bool:
         """Check if the coordinator is currently running.
@@ -71,10 +89,28 @@ class Coordinator:
         """
         return len(self._active_agents)
 
+    @property
+    def circuit_breaker(self) -> CircuitBreaker:
+        """Get the circuit breaker instance.
+
+        Returns:
+            CircuitBreaker instance for this coordinator
+        """
+        return self._circuit_breaker
+
+    def get_circuit_breaker_stats(self) -> dict[str, Any]:
+        """Get circuit breaker statistics.
+
+        Returns:
+            Dictionary with circuit breaker stats
+        """
+        return self._circuit_breaker.get_stats()
+
     async def start(self) -> None:
         """Start the orchestration loop.
 
         Continuously processes the queue until stop() is called.
+        Uses circuit breaker to prevent infinite retry loops on repeated failures.
         """
         self._running = True
         self._stop_event = asyncio.Event()
@@ -82,11 +118,32 @@ class Coordinator:
         try:
             while self._running:
+                # Check circuit breaker state before processing
+                if not self._circuit_breaker.can_execute():
+                    # Circuit is open - wait for cooldown
+                    wait_time = self._circuit_breaker.time_until_retry
+                    logger.warning(
+                        f"Circuit breaker OPEN - backing off for {wait_time:.1f}s "
+                        f"(failures: {self._circuit_breaker.failure_count})"
+                    )
+                    await self._wait_for_cooldown_or_stop(wait_time)
+                    continue
+
                 try:
                     await self.process_queue()
+                    # Successful processing - record success
+                    self._circuit_breaker.record_success()
+                except CircuitBreakerError as e:
+                    # Circuit breaker blocked the request
+                    logger.warning(f"Circuit breaker blocked request: {e}")
                 except Exception as e:
-                    logger.error(f"Error in process_queue: {e}")
-                    # Continue running despite errors
+                    # Record failure in circuit breaker
+                    self._circuit_breaker.record_failure()
+                    logger.error(
+                        f"Error in process_queue: {e} "
+                        f"(circuit breaker: {self._circuit_breaker.state.value}, "
+                        f"failures: {self._circuit_breaker.failure_count}/{self._circuit_breaker.failure_threshold})"
+                    )
 
                 # Wait for poll interval or stop signal
                 try:
@@ -102,7 +159,26 @@ class Coordinator:
         finally:
             self._running = False
-            logger.info("Coordinator stopped")
+            logger.info(
+                f"Coordinator stopped "
+                f"(circuit breaker stats: {self._circuit_breaker.get_stats()})"
+            )
+
+    async def _wait_for_cooldown_or_stop(self, cooldown: float) -> None:
+        """Wait for cooldown period or stop signal, whichever comes first.
+
+        Args:
+            cooldown: Seconds to wait for cooldown
+        """
+        if self._stop_event is None:
+            return
+
+        try:
+            await asyncio.wait_for(self._stop_event.wait(), timeout=cooldown)
+            # Stop was requested during cooldown
+        except TimeoutError:
+            # Cooldown completed, continue
+            pass
 
     async def stop(self) -> None:
         """Stop the orchestration loop gracefully.
@@ -200,6 +276,12 @@ class OrchestrationLoop:
     - Quality gate verification on completion claims
     - Rejection handling with forced continuation prompts
     - Context monitoring during agent execution
+
+    Circuit Breaker (SEC-ORCH-7):
+    - Tracks consecutive failures in the main loop
+    - After failure_threshold consecutive failures, enters OPEN state
+    - In OPEN state, backs off for cooldown_seconds before retrying
+    - Prevents infinite retry loops on repeated failures
     """
 
     def __init__(
@@ -209,6 +291,8 @@
         continuation_service: ForcedContinuationService,
         context_monitor: ContextMonitor,
         poll_interval: float = 5.0,
+        circuit_breaker_threshold: int = 5,
+        circuit_breaker_cooldown: float = 30.0,
     ) -> None:
         """Initialize the OrchestrationLoop.
@@ -218,6 +302,8 @@
         continuation_service: ForcedContinuationService for rejection prompts
         context_monitor: ContextMonitor for tracking agent context usage
         poll_interval: Seconds between queue polls (default: 5.0)
+        circuit_breaker_threshold: Consecutive failures before opening circuit (default: 5)
+        circuit_breaker_cooldown: Seconds to wait before retry after circuit opens (default: 30)
         """
         self.queue_manager = queue_manager
         self.quality_orchestrator = quality_orchestrator
@@ -233,6 +319,13 @@
         self._success_count = 0
         self._rejection_count = 0
 
+        # Circuit breaker for preventing infinite retry loops (SEC-ORCH-7)
+        self._circuit_breaker = CircuitBreaker(
+            name="orchestration_loop",
+            failure_threshold=circuit_breaker_threshold,
+            cooldown_seconds=circuit_breaker_cooldown,
+        )
+
     @property
     def is_running(self) -> bool:
         """Check if the orchestration loop is currently running.
@@ -286,10 +379,28 @@
         """
         return len(self._active_agents)
 
+    @property
+    def circuit_breaker(self) -> CircuitBreaker:
+        """Get the circuit breaker instance.
+
+        Returns:
+            CircuitBreaker instance for this orchestration loop
+        """
+        return self._circuit_breaker
+
+    def get_circuit_breaker_stats(self) -> dict[str, Any]:
+        """Get circuit breaker statistics.
+
+        Returns:
+            Dictionary with circuit breaker stats
+        """
+        return self._circuit_breaker.get_stats()
+
     async def start(self) -> None:
         """Start the orchestration loop.
 
         Continuously processes the queue until stop() is called.
+        Uses circuit breaker to prevent infinite retry loops on repeated failures.
         """
         self._running = True
         self._stop_event = asyncio.Event()
@@ -297,11 +408,32 @@
         try:
             while self._running:
+                # Check circuit breaker state before processing
+                if not self._circuit_breaker.can_execute():
+                    # Circuit is open - wait for cooldown
+                    wait_time = self._circuit_breaker.time_until_retry
+                    logger.warning(
+                        f"Circuit breaker OPEN - backing off for {wait_time:.1f}s "
+                        f"(failures: {self._circuit_breaker.failure_count})"
+                    )
+                    await self._wait_for_cooldown_or_stop(wait_time)
+                    continue
+
                 try:
                     await self.process_next_issue()
+                    # Successful processing - record success
+                    self._circuit_breaker.record_success()
+                except CircuitBreakerError as e:
+                    # Circuit breaker blocked the request
+                    logger.warning(f"Circuit breaker blocked request: {e}")
                 except Exception as e:
-                    logger.error(f"Error in process_next_issue: {e}")
-                    # Continue running despite errors
+                    # Record failure in circuit breaker
+                    self._circuit_breaker.record_failure()
+                    logger.error(
+                        f"Error in process_next_issue: {e} "
+                        f"(circuit breaker: {self._circuit_breaker.state.value}, "
+                        f"failures: {self._circuit_breaker.failure_count}/{self._circuit_breaker.failure_threshold})"
+                    )
 
                 # Wait for poll interval or stop signal
                 try:
@@ -317,7 +449,26 @@
         finally:
             self._running = False
-            logger.info("OrchestrationLoop stopped")
+            logger.info(
+                f"OrchestrationLoop stopped "
+                f"(circuit breaker stats: {self._circuit_breaker.get_stats()})"
+            )
+
+    async def _wait_for_cooldown_or_stop(self, cooldown: float) -> None:
+        """Wait for cooldown period or stop signal, whichever comes first.
+
+        Args:
+            cooldown: Seconds to wait for cooldown
+        """
+        if self._stop_event is None:
+            return
+
+        try:
+            await asyncio.wait_for(self._stop_event.wait(), timeout=cooldown)
+            # Stop was requested during cooldown
+        except TimeoutError:
+            # Cooldown completed, continue
+            pass
 
     async def stop(self) -> None:
         """Stop the orchestration loop gracefully.
diff --git a/apps/coordinator/tests/test_circuit_breaker.py b/apps/coordinator/tests/test_circuit_breaker.py
new file mode 100644
index 0000000..eda7b00
--- /dev/null
+++ b/apps/coordinator/tests/test_circuit_breaker.py
@@ -0,0 +1,495 @@
+"""Tests for circuit breaker pattern implementation.
+
+These tests verify the circuit breaker behavior:
+- State transitions (closed -> open -> half_open -> closed)
+- Failure counting and threshold detection
+- Cooldown timing
+- Success/failure recording
+- Execute wrapper method
+"""
+
+import asyncio
+import time
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from src.circuit_breaker import CircuitBreaker, CircuitBreakerError, CircuitState
+
+
+class TestCircuitBreakerInitialization:
+    """Tests for CircuitBreaker initialization."""
+
+    def test_default_initialization(self) -> None:
+        """Test circuit breaker initializes with default values."""
+        cb = CircuitBreaker("test")
+
+        assert cb.name == "test"
+        assert cb.failure_threshold == 5
+        assert cb.cooldown_seconds == 30.0
+        assert cb.state == CircuitState.CLOSED
+        assert cb.failure_count == 0
+
+    def test_custom_initialization(self) -> None:
+        """Test circuit breaker with custom values."""
+        cb = CircuitBreaker(
+            name="custom",
+            failure_threshold=3,
+            cooldown_seconds=10.0,
+        )
+
+        assert cb.name == "custom"
+        assert cb.failure_threshold == 3
+        assert cb.cooldown_seconds == 10.0
+
+    def test_initial_state_is_closed(self) -> None:
+        """Test circuit starts in closed state."""
+        cb = 
CircuitBreaker("test")
+        assert cb.state == CircuitState.CLOSED
+
+    def test_initial_can_execute_is_true(self) -> None:
+        """Test can_execute returns True initially."""
+        cb = CircuitBreaker("test")
+        assert cb.can_execute() is True
+
+
+class TestCircuitBreakerFailureTracking:
+    """Tests for failure tracking behavior."""
+
+    def test_failure_increments_count(self) -> None:
+        """Test that recording failure increments failure count."""
+        cb = CircuitBreaker("test", failure_threshold=5)
+
+        cb.record_failure()
+        assert cb.failure_count == 1
+
+        cb.record_failure()
+        assert cb.failure_count == 2
+
+    def test_success_resets_failure_count(self) -> None:
+        """Test that recording success resets failure count."""
+        cb = CircuitBreaker("test", failure_threshold=5)
+
+        cb.record_failure()
+        cb.record_failure()
+        assert cb.failure_count == 2
+
+        cb.record_success()
+        assert cb.failure_count == 0
+
+    def test_total_failures_tracked(self) -> None:
+        """Test that total failures are tracked separately."""
+        cb = CircuitBreaker("test", failure_threshold=5)
+
+        cb.record_failure()
+        cb.record_failure()
+        cb.record_success()  # Resets consecutive count
+        cb.record_failure()
+
+        assert cb.failure_count == 1  # Consecutive
+        assert cb.total_failures == 3  # Total
+
+    def test_total_successes_tracked(self) -> None:
+        """Test that total successes are tracked."""
+        cb = CircuitBreaker("test")
+
+        cb.record_success()
+        cb.record_success()
+        cb.record_failure()
+        cb.record_success()
+
+        assert cb.total_successes == 3
+
+
+class TestCircuitBreakerStateTransitions:
+    """Tests for state transition behavior."""
+
+    def test_reaches_threshold_opens_circuit(self) -> None:
+        """Test circuit opens when failure threshold is reached."""
+        cb = CircuitBreaker("test", failure_threshold=3)
+
+        cb.record_failure()
+        assert cb.state == CircuitState.CLOSED
+
+        cb.record_failure()
+        assert cb.state == CircuitState.CLOSED
+
+        cb.record_failure()
+        assert cb.state == CircuitState.OPEN
+
+    def test_open_circuit_blocks_execution(self) -> None:
+        """Test that open circuit blocks can_execute."""
+        cb = CircuitBreaker("test", failure_threshold=2)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        assert cb.state == CircuitState.OPEN
+        assert cb.can_execute() is False
+
+    def test_cooldown_transitions_to_half_open(self) -> None:
+        """Test that cooldown period transitions circuit to half-open."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+        assert cb.state == CircuitState.OPEN
+
+        # Wait for cooldown
+        time.sleep(0.15)
+
+        # Accessing state triggers transition
+        assert cb.state == CircuitState.HALF_OPEN
+
+    def test_half_open_allows_one_request(self) -> None:
+        """Test that half-open state allows test request."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        time.sleep(0.15)
+
+        assert cb.state == CircuitState.HALF_OPEN
+        assert cb.can_execute() is True
+
+    def test_half_open_success_closes_circuit(self) -> None:
+        """Test that success in half-open state closes circuit."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        time.sleep(0.15)
+        assert cb.state == CircuitState.HALF_OPEN
+
+        cb.record_success()
+        assert cb.state == CircuitState.CLOSED
+
+    def test_half_open_failure_reopens_circuit(self) -> None:
+        """Test that failure in half-open state reopens circuit."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        time.sleep(0.15)
+        assert cb.state == CircuitState.HALF_OPEN
+
+        cb.record_failure()
+        assert cb.state == CircuitState.OPEN
+
+    def test_state_transitions_counted(self) -> None:
+        """Test that state transitions are counted."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        assert cb.state_transitions == 0
+
+        cb.record_failure()
+        cb.record_failure()  # -> OPEN
+        assert cb.state_transitions == 1
+
+        time.sleep(0.15)
+        _ = cb.state  # -> HALF_OPEN
+        assert cb.state_transitions == 2
+
+        cb.record_success()  # -> CLOSED
+        assert cb.state_transitions == 3
+
+
+class TestCircuitBreakerCooldown:
+    """Tests for cooldown timing behavior."""
+
+    def test_time_until_retry_when_open(self) -> None:
+        """Test time_until_retry reports correct value when open."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=1.0)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        # Should be approximately 1 second
+        assert 0.9 <= cb.time_until_retry <= 1.0
+
+    def test_time_until_retry_decreases(self) -> None:
+        """Test time_until_retry decreases over time."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=1.0)
+
+        cb.record_failure()
+        cb.record_failure()
+
+        initial = cb.time_until_retry
+        time.sleep(0.2)
+        after = cb.time_until_retry
+
+        assert after < initial
+
+    def test_time_until_retry_zero_when_closed(self) -> None:
+        """Test time_until_retry is 0 when circuit is closed."""
+        cb = CircuitBreaker("test")
+        assert cb.time_until_retry == 0.0
+
+    def test_time_until_retry_zero_when_half_open(self) -> None:
+        """Test time_until_retry is 0 when circuit is half-open."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+        time.sleep(0.15)
+
+        assert cb.state == CircuitState.HALF_OPEN
+        assert cb.time_until_retry == 0.0
+
+
+class TestCircuitBreakerReset:
+    """Tests for manual reset behavior."""
+
+    def test_reset_closes_circuit(self) -> None:
+        """Test that reset closes an open circuit."""
+        cb = CircuitBreaker("test", failure_threshold=2)
+
+        cb.record_failure()
+        cb.record_failure()
+        assert cb.state == CircuitState.OPEN
+
+        cb.reset()
+        assert cb.state == CircuitState.CLOSED
+
+    def test_reset_clears_failure_count(self) -> None:
+        """Test that reset clears failure count."""
+        cb = CircuitBreaker("test", failure_threshold=5)
+
+        cb.record_failure()
+        cb.record_failure()
+        assert cb.failure_count == 2
+
+        cb.reset()
+        assert cb.failure_count == 0
+
+    def test_reset_from_half_open(self) -> None:
+        """Test reset from half-open state."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        cb.record_failure()
+        cb.record_failure()
+        time.sleep(0.15)
+        assert cb.state == CircuitState.HALF_OPEN
+
+        cb.reset()
+        assert cb.state == CircuitState.CLOSED
+
+
+class TestCircuitBreakerStats:
+    """Tests for statistics reporting."""
+
+    def test_get_stats_returns_all_fields(self) -> None:
+        """Test get_stats returns complete statistics."""
+        cb = CircuitBreaker("test", failure_threshold=3, cooldown_seconds=15.0)
+
+        stats = cb.get_stats()
+
+        assert stats["name"] == "test"
+        assert stats["state"] == "closed"
+        assert stats["failure_count"] == 0
+        assert stats["failure_threshold"] == 3
+        assert stats["cooldown_seconds"] == 15.0
+        assert stats["time_until_retry"] == 0.0
+        assert stats["total_failures"] == 0
+        assert stats["total_successes"] == 0
+        assert stats["state_transitions"] == 0
+
+    def test_stats_update_after_operations(self) -> None:
+        """Test stats update correctly after operations."""
+        cb = CircuitBreaker("test", failure_threshold=3)
+
+        cb.record_failure()
+        cb.record_success()
+        cb.record_failure()
+        cb.record_failure()
+        cb.record_failure()  # Opens circuit
+
+        stats = cb.get_stats()
+
+        assert stats["state"] == "open"
+        assert stats["failure_count"] == 3
+        assert stats["total_failures"] == 4
+        assert stats["total_successes"] == 1
+        assert stats["state_transitions"] == 1
+
+
+class TestCircuitBreakerError:
+    """Tests for CircuitBreakerError exception."""
+
+    def test_error_contains_state(self) -> None:
+        """Test error contains circuit state."""
+        error = CircuitBreakerError(CircuitState.OPEN, 10.0)
+        assert error.state == CircuitState.OPEN
+
+    def test_error_contains_retry_time(self) -> None:
+        """Test error contains time until retry."""
+        error = CircuitBreakerError(CircuitState.OPEN, 10.5)
+        assert error.time_until_retry == 10.5
+
+    def test_error_message_formatting(self) -> None:
+        """Test error message is properly formatted."""
+        error = CircuitBreakerError(CircuitState.OPEN, 15.3)
+        assert "open" in str(error)
+        assert "15.3" in str(error)
+
+
+class TestCircuitBreakerExecute:
+    """Tests for the execute wrapper method."""
+
+    @pytest.mark.asyncio
+    async def test_execute_calls_function(self) -> None:
+        """Test execute calls the provided function."""
+        cb = CircuitBreaker("test")
+        mock_func = AsyncMock(return_value="success")
+
+        result = await cb.execute(mock_func, "arg1", kwarg="value")
+
+        mock_func.assert_called_once_with("arg1", kwarg="value")
+        assert result == "success"
+
+    @pytest.mark.asyncio
+    async def test_execute_records_success(self) -> None:
+        """Test execute records success on successful call."""
+        cb = CircuitBreaker("test")
+        mock_func = AsyncMock(return_value="ok")
+
+        await cb.execute(mock_func)
+
+        assert cb.total_successes == 1
+
+    @pytest.mark.asyncio
+    async def test_execute_records_failure(self) -> None:
+        """Test execute records failure when function raises."""
+        cb = CircuitBreaker("test")
+        mock_func = AsyncMock(side_effect=RuntimeError("test error"))
+
+        with pytest.raises(RuntimeError):
+            await cb.execute(mock_func)
+
+        assert cb.failure_count == 1
+
+    @pytest.mark.asyncio
+    async def test_execute_raises_when_open(self) -> None:
+        """Test execute raises CircuitBreakerError when circuit is open."""
+        cb = CircuitBreaker("test", failure_threshold=2)
+
+        mock_func = AsyncMock(side_effect=RuntimeError("fail"))
+
+        with pytest.raises(RuntimeError):
+            await cb.execute(mock_func)
+        with pytest.raises(RuntimeError):
+            await cb.execute(mock_func)
+
+        # Circuit should now be open
+        assert cb.state == CircuitState.OPEN
+
+        # Next call should raise CircuitBreakerError
+        with pytest.raises(CircuitBreakerError) as exc_info:
+            await cb.execute(mock_func)
+
+        assert exc_info.value.state == CircuitState.OPEN
+
+    @pytest.mark.asyncio
+    async def test_execute_allows_half_open_test(self) -> None:
+        """Test execute allows test request in half-open state."""
+        cb = CircuitBreaker("test", failure_threshold=2, cooldown_seconds=0.1)
+
+        mock_func = AsyncMock(side_effect=RuntimeError("fail"))
+
+        with pytest.raises(RuntimeError):
+            await cb.execute(mock_func)
+        with pytest.raises(RuntimeError):
+            await cb.execute(mock_func)
+
+        # Wait for cooldown
+        await asyncio.sleep(0.15)
+        assert cb.state == CircuitState.HALF_OPEN
+
+        # Should allow test request
+        mock_func.side_effect = None
+        mock_func.return_value = "recovered"
+
+        result = await cb.execute(mock_func)
+        assert result == "recovered"
+        assert cb.state == CircuitState.CLOSED
+
+
+class TestCircuitBreakerConcurrency:
+    """Tests for thread safety and concurrent access."""
+
+    @pytest.mark.asyncio
+    async def test_concurrent_failures(self) -> None:
+        """Test concurrent failures are handled correctly."""
+        cb = CircuitBreaker("test", failure_threshold=10)
+
+        async def record_failure() -> None:
+            cb.record_failure()
+
+        # Record 10 concurrent failures
+        await asyncio.gather(*[record_failure() for _ in range(10)])
+
+        assert cb.failure_count >= 10
+        assert cb.state == CircuitState.OPEN
+
+    @pytest.mark.asyncio
+    async def test_concurrent_mixed_operations(self) -> None:
+        """Test concurrent mixed success/failure operations."""
+        cb = CircuitBreaker("test", failure_threshold=100)
+
+        async def record_success() -> None:
+            cb.record_success()
+
+        async def record_failure() -> None:
+            cb.record_failure()
+
+        # Mix of operations
+        tasks = [record_failure() for _ in range(5)]
+        tasks.extend([record_success() for _ in range(3)])
+        tasks.extend([record_failure() for _ in range(5)])
+
+        await asyncio.gather(*tasks)
+
+        # At least some of each should have been recorded
+        assert cb.total_failures >= 5
+        assert cb.total_successes >= 1
+
+
+class TestCircuitBreakerLogging:
+    """Tests for
logging behavior.""" + + def test_logs_state_transitions(self) -> None: + """Test that state transitions are logged.""" + cb = CircuitBreaker("test", failure_threshold=2) + + with patch("src.circuit_breaker.logger") as mock_logger: + cb.record_failure() + cb.record_failure() + + # Should have logged the transition to OPEN + mock_logger.info.assert_called() + calls = [str(c) for c in mock_logger.info.call_args_list] + assert any("closed -> open" in c for c in calls) + + def test_logs_failure_warnings(self) -> None: + """Test that failures are logged as warnings.""" + cb = CircuitBreaker("test", failure_threshold=5) + + with patch("src.circuit_breaker.logger") as mock_logger: + cb.record_failure() + + mock_logger.warning.assert_called() + + def test_logs_threshold_reached_as_error(self) -> None: + """Test that reaching threshold is logged as error.""" + cb = CircuitBreaker("test", failure_threshold=2) + + with patch("src.circuit_breaker.logger") as mock_logger: + cb.record_failure() + cb.record_failure() + + mock_logger.error.assert_called() + calls = [str(c) for c in mock_logger.error.call_args_list] + assert any("threshold reached" in c for c in calls) diff --git a/apps/coordinator/tests/test_coordinator.py b/apps/coordinator/tests/test_coordinator.py index 8c4de4d..8835218 100644 --- a/apps/coordinator/tests/test_coordinator.py +++ b/apps/coordinator/tests/test_coordinator.py @@ -744,3 +744,186 @@ class TestCoordinatorActiveAgents: await coordinator.process_queue() assert coordinator.get_active_agent_count() == 3 + + +class TestCoordinatorCircuitBreaker: + """Tests for Coordinator circuit breaker integration (SEC-ORCH-7).""" + + @pytest.fixture + def temp_queue_file(self) -> Generator[Path, None, None]: + """Create a temporary file for queue persistence.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as f: + temp_path = Path(f.name) + yield temp_path + if temp_path.exists(): + temp_path.unlink() + + @pytest.fixture + def 
queue_manager(self, temp_queue_file: Path) -> QueueManager: + """Create a queue manager with temporary storage.""" + return QueueManager(queue_file=temp_queue_file) + + def test_circuit_breaker_initialized(self, queue_manager: QueueManager) -> None: + """Test that circuit breaker is initialized with Coordinator.""" + from src.coordinator import Coordinator + + coordinator = Coordinator(queue_manager=queue_manager) + + assert coordinator.circuit_breaker is not None + assert coordinator.circuit_breaker.name == "coordinator_loop" + + def test_circuit_breaker_custom_settings(self, queue_manager: QueueManager) -> None: + """Test circuit breaker with custom threshold and cooldown.""" + from src.coordinator import Coordinator + + coordinator = Coordinator( + queue_manager=queue_manager, + circuit_breaker_threshold=3, + circuit_breaker_cooldown=15.0, + ) + + assert coordinator.circuit_breaker.failure_threshold == 3 + assert coordinator.circuit_breaker.cooldown_seconds == 15.0 + + def test_get_circuit_breaker_stats(self, queue_manager: QueueManager) -> None: + """Test getting circuit breaker statistics.""" + from src.coordinator import Coordinator + + coordinator = Coordinator(queue_manager=queue_manager) + + stats = coordinator.get_circuit_breaker_stats() + + assert "name" in stats + assert "state" in stats + assert "failure_count" in stats + assert "total_failures" in stats + assert stats["name"] == "coordinator_loop" + assert stats["state"] == "closed" + + @pytest.mark.asyncio + async def test_circuit_breaker_opens_on_repeated_failures( + self, queue_manager: QueueManager + ) -> None: + """Test that circuit breaker opens after repeated failures.""" + from src.circuit_breaker import CircuitState + from src.coordinator import Coordinator + + coordinator = Coordinator( + queue_manager=queue_manager, + poll_interval=0.02, + circuit_breaker_threshold=3, + circuit_breaker_cooldown=0.2, + ) + + failure_count = 0 + + async def failing_process_queue() -> None: + nonlocal 
failure_count + failure_count += 1 + raise RuntimeError("Simulated failure") + + coordinator.process_queue = failing_process_queue # type: ignore[method-assign] + + task = asyncio.create_task(coordinator.start()) + await asyncio.sleep(0.15) # Allow time for failures + + # Circuit should be open after 3 failures + assert coordinator.circuit_breaker.state == CircuitState.OPEN + assert coordinator.circuit_breaker.failure_count >= 3 + + await coordinator.stop() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_circuit_breaker_backs_off_when_open( + self, queue_manager: QueueManager + ) -> None: + """Test that coordinator backs off when circuit breaker is open.""" + from src.coordinator import Coordinator + + coordinator = Coordinator( + queue_manager=queue_manager, + poll_interval=0.02, + circuit_breaker_threshold=2, + circuit_breaker_cooldown=0.3, + ) + + call_timestamps: list[float] = [] + + async def failing_process_queue() -> None: + call_timestamps.append(asyncio.get_event_loop().time()) + raise RuntimeError("Simulated failure") + + coordinator.process_queue = failing_process_queue # type: ignore[method-assign] + + task = asyncio.create_task(coordinator.start()) + await asyncio.sleep(0.5) # Allow time for failures and backoff + await coordinator.stop() + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Should have at least 2 calls (to trigger open), then back off + assert len(call_timestamps) >= 2 + + # After circuit opens, there should be a gap (cooldown) + if len(call_timestamps) >= 3: + # Check there's a larger gap after the first 2 calls + first_gap = call_timestamps[1] - call_timestamps[0] + later_gap = call_timestamps[2] - call_timestamps[1] + # Later gap should be larger due to cooldown + assert later_gap > first_gap * 2 + + @pytest.mark.asyncio + async def test_circuit_breaker_resets_on_success( + self, queue_manager: QueueManager + ) -> None: + """Test 
that circuit breaker resets after successful operation.""" + from src.circuit_breaker import CircuitState + from src.coordinator import Coordinator + + coordinator = Coordinator( + queue_manager=queue_manager, + poll_interval=0.02, + circuit_breaker_threshold=3, + ) + + # Record failures then success + coordinator.circuit_breaker.record_failure() + coordinator.circuit_breaker.record_failure() + assert coordinator.circuit_breaker.failure_count == 2 + + coordinator.circuit_breaker.record_success() + assert coordinator.circuit_breaker.failure_count == 0 + assert coordinator.circuit_breaker.state == CircuitState.CLOSED + + @pytest.mark.asyncio + async def test_circuit_breaker_stats_logged_on_stop( + self, queue_manager: QueueManager + ) -> None: + """Test that circuit breaker stats are logged when coordinator stops.""" + from src.coordinator import Coordinator + + coordinator = Coordinator(queue_manager=queue_manager, poll_interval=0.05) + + with patch("src.coordinator.logger") as mock_logger: + task = asyncio.create_task(coordinator.start()) + await asyncio.sleep(0.1) + await coordinator.stop() + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Should log circuit breaker stats on stop + info_calls = [str(call) for call in mock_logger.info.call_args_list] + assert any("circuit breaker" in call.lower() for call in info_calls)