# ---- apps/coordinator/src/context_monitor.py ----
"""Context monitoring for agent token usage tracking."""

import asyncio
import logging
from collections import defaultdict
from collections.abc import Callable
from typing import Any

from src.models import ContextAction, ContextUsage

logger = logging.getLogger(__name__)


class ContextMonitor:
    """Monitor agent context usage and trigger threshold-based actions.

    Tracks agent token usage in real-time by polling the Claude API.
    Triggers appropriate actions based on defined thresholds:
    - 80% (COMPACT_THRESHOLD): Trigger context compaction
    - 95% (ROTATE_THRESHOLD): Trigger session rotation
    """

    COMPACT_THRESHOLD = 0.80  # >= 80% usage triggers compaction
    ROTATE_THRESHOLD = 0.95  # >= 95% usage triggers session rotation

    def __init__(self, api_client: Any, poll_interval: float = 10.0) -> None:
        """Initialize context monitor.

        Args:
            api_client: Claude API client for fetching context usage
            poll_interval: Seconds between polls (default: 10s)
        """
        self.api_client = api_client
        self.poll_interval = poll_interval
        # Chronological per-agent usage samples; appended on every poll.
        self._usage_history: dict[str, list[ContextUsage]] = defaultdict(list)
        # agent_id -> "keep polling" flag; start_monitoring loops while True.
        self._monitoring_tasks: dict[str, bool] = {}

    async def get_context_usage(self, agent_id: str) -> ContextUsage:
        """Get current context usage for an agent.

        Args:
            agent_id: Unique identifier for the agent

        Returns:
            ContextUsage object with current token usage

        Raises:
            Exception: If the API call fails (propagated from the API client)
        """
        response = await self.api_client.get_context_usage(agent_id)
        usage = ContextUsage(
            agent_id=agent_id,
            used_tokens=response["used_tokens"],
            total_tokens=response["total_tokens"],
        )

        # Log usage to history
        self._usage_history[agent_id].append(usage)
        logger.debug(f"Context usage for {agent_id}: {usage.usage_percent:.1f}%")

        return usage

    async def determine_action(self, agent_id: str) -> ContextAction:
        """Determine appropriate action based on current context usage.

        Args:
            agent_id: Unique identifier for the agent

        Returns:
            ContextAction based on threshold crossings
        """
        usage = await self.get_context_usage(agent_id)

        # Check the higher threshold first so >=95% rotates rather than compacts.
        if usage.usage_ratio >= self.ROTATE_THRESHOLD:
            logger.warning(
                f"Agent {agent_id} hit ROTATE threshold: {usage.usage_percent:.1f}%"
            )
            return ContextAction.ROTATE_SESSION
        elif usage.usage_ratio >= self.COMPACT_THRESHOLD:
            logger.info(
                f"Agent {agent_id} hit COMPACT threshold: {usage.usage_percent:.1f}%"
            )
            return ContextAction.COMPACT
        else:
            logger.debug(f"Agent {agent_id} continuing: {usage.usage_percent:.1f}%")
            return ContextAction.CONTINUE

    def get_usage_history(self, agent_id: str) -> list[ContextUsage]:
        """Get historical context usage for an agent.

        Args:
            agent_id: Unique identifier for the agent

        Returns:
            List of ContextUsage objects in chronological order. A copy is
            returned; mutating it does not affect the monitor's history.
        """
        # BUG FIX: indexing the defaultdict inserted an empty history entry
        # for unknown agents (mutation on read) and handed callers the
        # internal list. Use .get() and return a copy instead.
        return list(self._usage_history.get(agent_id, []))

    async def start_monitoring(
        self, agent_id: str, callback: Callable[[str, ContextAction], None]
    ) -> None:
        """Start background monitoring loop for an agent.

        Polls context usage at regular intervals and calls callback with
        appropriate actions when thresholds are crossed.

        Args:
            agent_id: Unique identifier for the agent
            callback: Function to call with (agent_id, action) on each poll
        """
        self._monitoring_tasks[agent_id] = True
        logger.info(
            f"Started monitoring agent {agent_id} (poll interval: {self.poll_interval}s)"
        )

        while self._monitoring_tasks.get(agent_id, False):
            try:
                action = await self.determine_action(agent_id)
                callback(agent_id, action)
            except Exception as e:
                logger.error(f"Error monitoring agent {agent_id}: {e}")
                # Continue monitoring despite errors

            # Wait for next poll (or until stopped)
            try:
                await asyncio.sleep(self.poll_interval)
            except asyncio.CancelledError:
                break

        # BUG FIX: drop the stale flag on exit so stopped agents do not
        # accumulate in _monitoring_tasks forever.
        self._monitoring_tasks.pop(agent_id, None)
        logger.info(f"Stopped monitoring agent {agent_id}")

    def stop_monitoring(self, agent_id: str) -> None:
        """Stop background monitoring for an agent.

        The loop exits after its current poll/sleep cycle completes; the
        stop is not immediate.

        Args:
            agent_id: Unique identifier for the agent
        """
        self._monitoring_tasks[agent_id] = False
        logger.info(f"Requested stop for agent {agent_id} monitoring")


# ---- apps/coordinator/src/models.py ----
"""Data models for mosaic-coordinator."""

from enum import Enum
from typing import Literal

from pydantic import BaseModel, Field, field_validator


class ContextAction(str, Enum):
    """Actions to take based on context usage thresholds."""

    CONTINUE = "continue"  # Below compact threshold, keep working
    COMPACT = "compact"  # Hit 80% threshold, summarize and compact
    ROTATE_SESSION = "rotate_session"  # Hit 95% threshold, spawn new agent
# ---- apps/coordinator/src/models.py (continued) ----
class ContextUsage:
    """Point-in-time snapshot of an agent's context-window consumption."""

    def __init__(self, agent_id: str, used_tokens: int, total_tokens: int) -> None:
        """Record a usage sample.

        Args:
            agent_id: Unique identifier for the agent
            used_tokens: Number of tokens currently used
            total_tokens: Total token capacity for this agent
        """
        self.agent_id = agent_id
        self.used_tokens = used_tokens
        self.total_tokens = total_tokens

    @property
    def usage_ratio(self) -> float:
        """Fraction of capacity in use, 0.0-1.0 (0.0 when capacity is zero)."""
        if not self.total_tokens:
            # Avoid ZeroDivisionError for agents reporting no capacity.
            return 0.0
        return self.used_tokens / self.total_tokens

    @property
    def usage_percent(self) -> float:
        """Fraction of capacity in use, scaled to 0-100."""
        return 100 * self.usage_ratio

    def __repr__(self) -> str:
        """Debug-friendly representation of the sample."""
        return (
            f"ContextUsage(agent_id={self.agent_id!r}, "
            f"used={self.used_tokens}, total={self.total_tokens}, "
            f"usage={self.usage_percent:.1f}%)"
        )


class IssueMetadata(BaseModel):
    """Parsed metadata from issue body."""

    estimated_context: int = Field(
        default=50000,
        description="Estimated context size in tokens",
        ge=0
    )
    difficulty: Literal["easy", "medium", "hard"] = Field(
        default="medium",
        description="Issue difficulty level"
    )
    assigned_agent: Literal["sonnet", "haiku", "opus", "glm"] = Field(
        default="sonnet",
        description="Recommended AI agent for this issue"
    )
    blocks: list[int] = Field(
        default_factory=list,
        description="List of issue numbers this issue blocks"
    )
    blocked_by: list[int] = Field(
        default_factory=list,
        description="List of issue numbers blocking this issue"
    )

    @field_validator("difficulty", mode="before")
    @classmethod
    def validate_difficulty(cls, v: str) -> str:
        """Fall back to the "medium" default when difficulty is unrecognized."""
        return v if v in ("easy", "medium", "hard") else "medium"

    @field_validator("assigned_agent", mode="before")
    @classmethod
    def validate_agent(cls, v: str) -> str:
        """Fall back to the "sonnet" default when the agent is unrecognized."""
        return v if v in ("sonnet", "haiku", "opus", "glm") else "sonnet"

    @field_validator("blocks", "blocked_by", mode="before")
    @classmethod
    def validate_issue_lists(cls, v: list[int] | None) -> list[int]:
        """Normalize a missing (None) dependency list to an empty list."""
        return [] if v is None else v


# ---- apps/coordinator/tests/test_context_monitor.py ----
"""Tests for context monitoring."""

import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.context_monitor import ContextMonitor
from src.models import ContextAction, ContextUsage, IssueMetadata


class TestContextUsage:
    """Test ContextUsage model."""

    def test_usage_ratio_calculation(self) -> None:
        """Should calculate correct usage ratio."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=80000, total_tokens=200000)
        assert sample.usage_ratio == 0.4

    def test_usage_percent_calculation(self) -> None:
        """Should calculate correct usage percentage."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=160000, total_tokens=200000)
        assert sample.usage_percent == 80.0

    def test_zero_total_tokens(self) -> None:
        """Should handle zero total tokens without division error."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=0, total_tokens=0)
        assert sample.usage_ratio == 0.0
        assert sample.usage_percent == 0.0

    def test_repr(self) -> None:
        """Should provide readable string representation."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=100000, total_tokens=200000)
        text = repr(sample)
        for expected in ("agent-1", "100000", "200000", "50.0%"):
            assert expected in text


class TestContextMonitor:
    """Test ContextMonitor class."""

    @pytest.fixture
    def mock_claude_api(self) -> AsyncMock:
        """Mock Claude API client (fresh AsyncMock per test)."""
        mock = AsyncMock()
        return mock

    @pytest.fixture
    def monitor(self, mock_claude_api: AsyncMock) -> ContextMonitor:
        """Create ContextMonitor instance with mocked API."""
        return ContextMonitor(api_client=mock_claude_api, poll_interval=1)

    @pytest.mark.asyncio
    async def test_threshold_constants(self, monitor: ContextMonitor) -> None:
        """Should define correct threshold constants."""
        assert monitor.COMPACT_THRESHOLD == 0.80
        assert monitor.ROTATE_THRESHOLD == 0.95

    @pytest.mark.asyncio
    async def test_get_context_usage_api_call(self, monitor: ContextMonitor, mock_claude_api: AsyncMock) -> None:
        """Should call Claude API to get context usage."""
        # Mock API response
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 80000,
            "total_tokens": 200000,
        }

        usage = await monitor.get_context_usage("agent-1")

        mock_claude_api.get_context_usage.assert_called_once_with("agent-1")
        assert usage.agent_id == "agent-1"
        assert usage.used_tokens == 80000
        assert usage.total_tokens == 200000

    @pytest.mark.asyncio
    async def test_determine_action_below_compact_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return CONTINUE when below 80% threshold."""
        # Mock 70% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 140000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.CONTINUE

    @pytest.mark.asyncio
    async def test_determine_action_at_compact_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return COMPACT when at exactly 80% threshold (boundary is inclusive)."""
        # Mock 80% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 160000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.COMPACT

    @pytest.mark.asyncio
    async def test_determine_action_between_thresholds(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return COMPACT when between 80% and 95%."""
        # Mock 85% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 170000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.COMPACT

    @pytest.mark.asyncio
    async def test_determine_action_at_rotate_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return ROTATE_SESSION when at exactly 95% threshold (boundary is inclusive)."""
        # Mock 95% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 190000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.ROTATE_SESSION

    @pytest.mark.asyncio
    async def test_determine_action_above_rotate_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return ROTATE_SESSION when above 95% threshold."""
        # Mock 97% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 194000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.ROTATE_SESSION

    @pytest.mark.asyncio
    async def test_log_usage_history(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should log context usage over time."""
        # Mock responses for multiple checks
        mock_claude_api.get_context_usage.side_effect = [
            {"used_tokens": 100000, "total_tokens": 200000},
            {"used_tokens": 150000, "total_tokens": 200000},
            {"used_tokens": 180000, "total_tokens": 200000},
        ]

        # Check usage multiple times
        await monitor.determine_action("agent-1")
        await monitor.determine_action("agent-1")
        await monitor.determine_action("agent-1")

        # Verify history was recorded in chronological order
        history = monitor.get_usage_history("agent-1")
        assert len(history) == 3
        assert history[0].usage_percent == 50.0
        assert history[1].usage_percent == 75.0
        assert history[2].usage_percent == 90.0

    @pytest.mark.asyncio
    async def test_background_monitoring_loop(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should run background monitoring loop with polling interval."""
        # Create monitor with very short poll interval for testing
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # Mock API responses
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 100000,
            "total_tokens": 200000,
        }

        # Track callbacks
        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring in background
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for a few polls
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Should have polled at least 3 times (0.35s / 0.1s interval)
        assert len(callback_calls) >= 3
        assert all(agent_id == "agent-1" for agent_id, _ in callback_calls)

    @pytest.mark.asyncio
    async def test_background_monitoring_detects_threshold_crossing(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should detect threshold crossings during background monitoring."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # Mock progression: 70% -> 82% -> 96%
        mock_claude_api.get_context_usage.side_effect = [
            {"used_tokens": 140000, "total_tokens": 200000},  # 70% CONTINUE
            {"used_tokens": 164000, "total_tokens": 200000},  # 82% COMPACT
            {"used_tokens": 192000, "total_tokens": 200000},  # 96% ROTATE
            {"used_tokens": 192000, "total_tokens": 200000},  # Keep returning high
        ]

        # Track callbacks
        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for progression
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Verify threshold crossings were detected
        actions = [action for _, action in callback_calls]
        assert ContextAction.CONTINUE in actions
        assert ContextAction.COMPACT in actions
        assert ContextAction.ROTATE_SESSION in actions

    @pytest.mark.asyncio
    async def test_api_error_handling(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should propagate API errors from get_context_usage to the caller."""
        # Mock API error
        mock_claude_api.get_context_usage.side_effect = Exception("API unavailable")

        # Should raise exception (caller handles it)
        with pytest.raises(Exception, match="API unavailable"):
            await monitor.get_context_usage("agent-1")

    @pytest.mark.asyncio
    async def test_background_monitoring_continues_after_api_error(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should continue monitoring after API errors."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # Mock: error -> success -> success
        mock_claude_api.get_context_usage.side_effect = [
            Exception("API error"),
            {"used_tokens": 100000, "total_tokens": 200000},
            {"used_tokens": 100000, "total_tokens": 200000},
        ]

        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for recovery
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Should have recovered and made successful callbacks
        assert len(callback_calls) >= 2

    @pytest.mark.asyncio
    async def test_stop_monitoring_prevents_further_polls(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should stop polling when stop_monitoring is called."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 100000,
            "total_tokens": 200000,
        }

        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for a few polls
        await asyncio.sleep(0.15)
        initial_count = len(callback_calls)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Wait a bit more
        await asyncio.sleep(0.15)

        # Should not have increased
        assert len(callback_calls) == initial_count


class TestIssueMetadata:
    """Test IssueMetadata model."""

    def test_default_values(self) -> None:
        """Should use default values when not specified."""
        metadata = IssueMetadata()
        assert metadata.estimated_context == 50000
        assert metadata.difficulty == "medium"
        assert metadata.assigned_agent == "sonnet"
        assert metadata.blocks == []
        assert metadata.blocked_by == []

    def test_custom_values(self) -> None:
        """Should accept custom values."""
        metadata = IssueMetadata(
            estimated_context=100000,
            difficulty="hard",
            assigned_agent="opus",
            blocks=[1, 2, 3],
            blocked_by=[4, 5],
        )
        assert metadata.estimated_context == 100000
        assert metadata.difficulty == "hard"
        assert metadata.assigned_agent == "opus"
        assert metadata.blocks == [1, 2, 3]
        assert metadata.blocked_by == [4, 5]

    def test_validate_difficulty_invalid(self) -> None:
        """Should default to medium for invalid difficulty."""
        metadata = IssueMetadata(difficulty="invalid")  # type: ignore
        assert metadata.difficulty == "medium"

    def test_validate_difficulty_valid(self) -> None:
        """Should accept valid difficulty values."""
        for difficulty in ["easy", "medium", "hard"]:
            metadata = IssueMetadata(difficulty=difficulty)  # type: ignore
            assert metadata.difficulty == difficulty

    def test_validate_agent_invalid(self) -> None:
        """Should default to sonnet for invalid agent."""
        metadata = IssueMetadata(assigned_agent="invalid")  # type: ignore
        assert metadata.assigned_agent == "sonnet"

    def test_validate_agent_valid(self) -> None:
        """Should accept valid agent values."""
        for agent in ["sonnet", "haiku", "opus", "glm"]:
            metadata = IssueMetadata(assigned_agent=agent)  # type: ignore
            assert metadata.assigned_agent == agent

    def test_validate_issue_lists_none(self) -> None:
        """Should convert None to empty list for issue lists."""
        metadata = IssueMetadata(blocks=None, blocked_by=None)  # type: ignore
        assert metadata.blocks == []
        assert metadata.blocked_by == []

    def test_validate_issue_lists_with_values(self) -> None:
        """Should preserve issue list values."""
        metadata = IssueMetadata(blocks=[1, 2], blocked_by=[3, 4])
        assert metadata.blocks == [1, 2]
        assert metadata.blocked_by == [3, 4]