"""Tests for PredictionCache.""" from __future__ import annotations import threading import time from mosaicstack_telemetry.prediction_cache import PredictionCache from mosaicstack_telemetry.types.events import Complexity, Provider, TaskType from mosaicstack_telemetry.types.predictions import ( PredictionMetadata, PredictionQuery, PredictionResponse, ) def _make_query( task_type: TaskType = TaskType.IMPLEMENTATION, model: str = "claude-sonnet-4-20250514", ) -> PredictionQuery: return PredictionQuery( task_type=task_type, model=model, provider=Provider.ANTHROPIC, complexity=Complexity.MEDIUM, ) def _make_response(sample_size: int = 100) -> PredictionResponse: return PredictionResponse( prediction=None, metadata=PredictionMetadata( sample_size=sample_size, fallback_level=0, confidence="high", ), ) class TestPredictionCache: """Tests for the TTL-based prediction cache.""" def test_cache_hit(self) -> None: """Cached predictions are returned on hit.""" cache = PredictionCache(ttl_seconds=60.0) query = _make_query() response = _make_response() cache.put(query, response) result = cache.get(query) assert result is not None assert result.metadata.sample_size == 100 def test_cache_miss(self) -> None: """Missing keys return None.""" cache = PredictionCache(ttl_seconds=60.0) query = _make_query() result = cache.get(query) assert result is None def test_cache_expiry(self) -> None: """Expired entries return None.""" cache = PredictionCache(ttl_seconds=0.05) query = _make_query() response = _make_response() cache.put(query, response) time.sleep(0.1) result = cache.get(query) assert result is None def test_different_queries_different_keys(self) -> None: """Different queries map to different cache entries.""" cache = PredictionCache(ttl_seconds=60.0) query1 = _make_query(task_type=TaskType.IMPLEMENTATION) query2 = _make_query(task_type=TaskType.DEBUGGING) cache.put(query1, _make_response(sample_size=100)) cache.put(query2, _make_response(sample_size=200)) result1 = cache.get(query1) result2 = cache.get(query2) assert result1 is not None assert result2 is not None assert result1.metadata.sample_size == 100 assert result2.metadata.sample_size == 200 def test_cache_clear(self) -> None: """Clear removes all entries.""" cache = PredictionCache(ttl_seconds=60.0) query = _make_query() cache.put(query, _make_response()) assert cache.size == 1 cache.clear() assert cache.size == 0 assert cache.get(query) is None def test_cache_overwrite(self) -> None: """Putting a new value for the same key overwrites.""" cache = PredictionCache(ttl_seconds=60.0) query = _make_query() cache.put(query, _make_response(sample_size=100)) cache.put(query, _make_response(sample_size=200)) result = cache.get(query) assert result is not None assert result.metadata.sample_size == 200 def test_thread_safety(self) -> None: """Cache handles concurrent access from multiple threads.""" cache = PredictionCache(ttl_seconds=60.0) errors: list[Exception] = [] iterations = 100 def writer(thread_id: int) -> None: try: for i in range(iterations): query = _make_query(model=f"model-{thread_id}-{i}") cache.put(query, _make_response(sample_size=i)) except Exception as e: errors.append(e) def reader(thread_id: int) -> None: try: for i in range(iterations): query = _make_query(model=f"model-{thread_id}-{i}") cache.get(query) # May or may not hit except Exception as e: errors.append(e) threads: list[threading.Thread] = [] for tid in range(4): threads.append(threading.Thread(target=writer, args=(tid,))) threads.append(threading.Thread(target=reader, args=(tid,))) for t in threads: t.start() for t in threads: t.join(timeout=5) assert not errors, f"Thread errors: {errors}"