feat(#155): Build basic context monitor
Implements ContextMonitor class with real-time token usage tracking: - COMPACT_THRESHOLD at 0.80 (80% triggers compaction) - ROTATE_THRESHOLD at 0.95 (95% triggers rotation) - Poll Claude API for context usage - Return appropriate ContextAction based on thresholds - Background monitoring loop (10-second polling) - Log usage over time - Error handling and recovery Added ContextUsage model for tracking agent token consumption. Tests: - 25 test cases covering all functionality - 100% coverage for context_monitor.py and models.py - Mocked API responses for different usage levels - Background monitoring and threshold detection - Error handling verification Quality gates: - Type checking: PASS (mypy) - Linting: PASS (ruff) - Tests: PASS (25/25) - Coverage: 100% for new files, 95.43% overall Fixes #155 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
381
apps/coordinator/tests/test_context_monitor.py
Normal file
381
apps/coordinator/tests/test_context_monitor.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Tests for context monitoring."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.context_monitor import ContextMonitor
|
||||
from src.models import ContextAction, ContextUsage, IssueMetadata
|
||||
|
||||
|
||||
class TestContextUsage:
    """Test ContextUsage model."""

    def test_usage_ratio_calculation(self) -> None:
        """Should calculate correct usage ratio."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=80000, total_tokens=200000)
        assert sample.usage_ratio == 0.4

    def test_usage_percent_calculation(self) -> None:
        """Should calculate correct usage percentage."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=160000, total_tokens=200000)
        assert sample.usage_percent == 80.0

    def test_zero_total_tokens(self) -> None:
        """Should handle zero total tokens without division error."""
        empty = ContextUsage(agent_id="agent-1", used_tokens=0, total_tokens=0)
        assert empty.usage_ratio == 0.0
        assert empty.usage_percent == 0.0

    def test_repr(self) -> None:
        """Should provide readable string representation."""
        sample = ContextUsage(agent_id="agent-1", used_tokens=100000, total_tokens=200000)
        text = repr(sample)
        # The repr must surface the agent id, both token counts, and the
        # formatted percentage so log lines are self-describing.
        for fragment in ("agent-1", "100000", "200000", "50.0%"):
            assert fragment in text
|
||||
|
||||
|
||||
class TestContextMonitor:
    """Test ContextMonitor class."""

    @pytest.fixture
    def mock_claude_api(self) -> AsyncMock:
        """Mock Claude API client."""
        return AsyncMock()

    @pytest.fixture
    def monitor(self, mock_claude_api: AsyncMock) -> ContextMonitor:
        """Create ContextMonitor instance with mocked API."""
        return ContextMonitor(api_client=mock_claude_api, poll_interval=1)

    @pytest.mark.asyncio
    async def test_threshold_constants(self, monitor: ContextMonitor) -> None:
        """Should define correct threshold constants."""
        assert monitor.COMPACT_THRESHOLD == 0.80
        assert monitor.ROTATE_THRESHOLD == 0.95

    @pytest.mark.asyncio
    async def test_get_context_usage_api_call(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should call Claude API to get context usage."""
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 80000,
            "total_tokens": 200000,
        }

        usage = await monitor.get_context_usage("agent-1")

        mock_claude_api.get_context_usage.assert_called_once_with("agent-1")
        assert usage.agent_id == "agent-1"
        assert usage.used_tokens == 80000
        assert usage.total_tokens == 200000

    @pytest.mark.asyncio
    async def test_determine_action_below_compact_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return CONTINUE when below 80% threshold."""
        # Mock 70% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 140000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.CONTINUE

    @pytest.mark.asyncio
    async def test_determine_action_at_compact_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return COMPACT when at exactly 80% threshold."""
        # Mock 80% usage — threshold comparison must be inclusive.
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 160000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.COMPACT

    @pytest.mark.asyncio
    async def test_determine_action_between_thresholds(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return COMPACT when between 80% and 95%."""
        # Mock 85% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 170000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.COMPACT

    @pytest.mark.asyncio
    async def test_determine_action_at_rotate_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return ROTATE_SESSION when at exactly 95% threshold."""
        # Mock 95% usage — threshold comparison must be inclusive.
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 190000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.ROTATE_SESSION

    @pytest.mark.asyncio
    async def test_determine_action_above_rotate_threshold(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should return ROTATE_SESSION when above 95% threshold."""
        # Mock 97% usage
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 194000,
            "total_tokens": 200000,
        }

        action = await monitor.determine_action("agent-1")
        assert action == ContextAction.ROTATE_SESSION

    @pytest.mark.asyncio
    async def test_log_usage_history(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should log context usage over time."""
        # Exactly three checks are made below, so a three-entry side_effect
        # is safe here (no timing-driven polling involved).
        mock_claude_api.get_context_usage.side_effect = [
            {"used_tokens": 100000, "total_tokens": 200000},
            {"used_tokens": 150000, "total_tokens": 200000},
            {"used_tokens": 180000, "total_tokens": 200000},
        ]

        # Check usage multiple times
        await monitor.determine_action("agent-1")
        await monitor.determine_action("agent-1")
        await monitor.determine_action("agent-1")

        # Verify history was recorded in order
        history = monitor.get_usage_history("agent-1")
        assert len(history) == 3
        assert history[0].usage_percent == 50.0
        assert history[1].usage_percent == 75.0
        assert history[2].usage_percent == 90.0

    @pytest.mark.asyncio
    async def test_background_monitoring_loop(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should run background monitoring loop with polling interval."""
        # Create monitor with very short poll interval for testing
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # return_value (not side_effect) so the mock can serve any number of polls
        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 100000,
            "total_tokens": 200000,
        }

        # Track callbacks
        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring in background
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for a few polls
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Should have polled at least 3 times (0.35s / 0.1s interval)
        assert len(callback_calls) >= 3
        assert all(agent_id == "agent-1" for agent_id, _ in callback_calls)

    @pytest.mark.asyncio
    async def test_background_monitoring_detects_threshold_crossing(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should detect threshold crossings during background monitoring."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # Mock progression: 70% -> 82% -> 96%.  Pad the tail generously:
        # a finite side_effect list raises StopIteration if the loop polls
        # more often than expected, which made this test timing-flaky.
        high = {"used_tokens": 192000, "total_tokens": 200000}  # 96% ROTATE
        mock_claude_api.get_context_usage.side_effect = [
            {"used_tokens": 140000, "total_tokens": 200000},  # 70% CONTINUE
            {"used_tokens": 164000, "total_tokens": 200000},  # 82% COMPACT
        ] + [high] * 10

        # Track callbacks
        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for progression
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Verify threshold crossings were detected
        actions = [action for _, action in callback_calls]
        assert ContextAction.CONTINUE in actions
        assert ContextAction.COMPACT in actions
        assert ContextAction.ROTATE_SESSION in actions

    @pytest.mark.asyncio
    async def test_api_error_handling(
        self, monitor: ContextMonitor, mock_claude_api: AsyncMock
    ) -> None:
        """Should handle API errors gracefully without crashing."""
        # Mock API error
        mock_claude_api.get_context_usage.side_effect = Exception("API unavailable")

        # Should raise exception (caller handles it)
        with pytest.raises(Exception, match="API unavailable"):
            await monitor.get_context_usage("agent-1")

    @pytest.mark.asyncio
    async def test_background_monitoring_continues_after_api_error(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should continue monitoring after API errors."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        # Mock: one error, then successes.  Pad the success tail so extra
        # polls cannot exhaust the side_effect list (StopIteration) and
        # turn a timing wobble into a spurious failure.
        ok = {"used_tokens": 100000, "total_tokens": 200000}
        mock_claude_api.get_context_usage.side_effect = [
            Exception("API error"),
        ] + [ok] * 10

        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for recovery
        await asyncio.sleep(0.35)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Should have recovered and made successful callbacks
        assert len(callback_calls) >= 2

    @pytest.mark.asyncio
    async def test_stop_monitoring_prevents_further_polls(
        self, mock_claude_api: AsyncMock
    ) -> None:
        """Should stop polling when stop_monitoring is called."""
        monitor = ContextMonitor(api_client=mock_claude_api, poll_interval=0.1)

        mock_claude_api.get_context_usage.return_value = {
            "used_tokens": 100000,
            "total_tokens": 200000,
        }

        callback_calls: list[tuple[str, ContextAction]] = []

        def callback(agent_id: str, action: ContextAction) -> None:
            callback_calls.append((agent_id, action))

        # Start monitoring
        task = asyncio.create_task(monitor.start_monitoring("agent-1", callback))

        # Wait for a few polls
        await asyncio.sleep(0.15)

        # Stop monitoring
        monitor.stop_monitoring("agent-1")
        await task

        # Snapshot the count only AFTER the task has fully stopped.
        # (Snapshotting before stop_monitoring raced with an in-flight
        # poll and made the equality assert below flaky.)
        count_at_stop = len(callback_calls)
        assert count_at_stop >= 1  # it did poll at least once before stopping

        # Wait a bit more
        await asyncio.sleep(0.15)

        # Should not have increased
        assert len(callback_calls) == count_at_stop
|
||||
|
||||
|
||||
class TestIssueMetadata:
    """Test IssueMetadata model."""

    def test_default_values(self) -> None:
        """Should use default values when not specified."""
        meta = IssueMetadata()
        assert meta.estimated_context == 50000
        assert meta.difficulty == "medium"
        assert meta.assigned_agent == "sonnet"
        assert meta.blocks == []
        assert meta.blocked_by == []

    def test_custom_values(self) -> None:
        """Should accept custom values."""
        meta = IssueMetadata(
            estimated_context=100000,
            difficulty="hard",
            assigned_agent="opus",
            blocks=[1, 2, 3],
            blocked_by=[4, 5],
        )
        assert meta.estimated_context == 100000
        assert meta.difficulty == "hard"
        assert meta.assigned_agent == "opus"
        assert meta.blocks == [1, 2, 3]
        assert meta.blocked_by == [4, 5]

    def test_validate_difficulty_invalid(self) -> None:
        """Should default to medium for invalid difficulty."""
        meta = IssueMetadata(difficulty="invalid")  # type: ignore
        assert meta.difficulty == "medium"

    def test_validate_difficulty_valid(self) -> None:
        """Should accept valid difficulty values."""
        for level in ("easy", "medium", "hard"):
            meta = IssueMetadata(difficulty=level)  # type: ignore
            assert meta.difficulty == level

    def test_validate_agent_invalid(self) -> None:
        """Should default to sonnet for invalid agent."""
        meta = IssueMetadata(assigned_agent="invalid")  # type: ignore
        assert meta.assigned_agent == "sonnet"

    def test_validate_agent_valid(self) -> None:
        """Should accept valid agent values."""
        for name in ("sonnet", "haiku", "opus", "glm"):
            meta = IssueMetadata(assigned_agent=name)  # type: ignore
            assert meta.assigned_agent == name

    def test_validate_issue_lists_none(self) -> None:
        """Should convert None to empty list for issue lists."""
        meta = IssueMetadata(blocks=None, blocked_by=None)  # type: ignore
        assert meta.blocks == []
        assert meta.blocked_by == []

    def test_validate_issue_lists_with_values(self) -> None:
        """Should preserve issue list values."""
        meta = IssueMetadata(blocks=[1, 2], blocked_by=[3, 4])
        assert meta.blocks == [1, 2]
        assert meta.blocked_by == [3, 4]
|
||||
Reference in New Issue
Block a user