Files
telemetry-client-py/tests/test_prediction_cache.py

148 lines
4.5 KiB
Python

"""Tests for PredictionCache."""
from __future__ import annotations
import threading
import time
from mosaicstack_telemetry.prediction_cache import PredictionCache
from mosaicstack_telemetry.types.events import Complexity, Provider, TaskType
from mosaicstack_telemetry.types.predictions import (
PredictionMetadata,
PredictionQuery,
PredictionResponse,
)
def _make_query(
    task_type: TaskType = TaskType.IMPLEMENTATION,
    model: str = "claude-sonnet-4-20250514",
) -> PredictionQuery:
    """Build a PredictionQuery fixture; defaults cover the common test case."""
    fields = {
        "task_type": task_type,
        "model": model,
        "provider": Provider.ANTHROPIC,
        "complexity": Complexity.MEDIUM,
    }
    return PredictionQuery(**fields)
def _make_response(sample_size: int = 100) -> PredictionResponse:
    """Build a PredictionResponse fixture with high-confidence metadata."""
    metadata = PredictionMetadata(
        sample_size=sample_size,
        fallback_level=0,
        confidence="high",
    )
    return PredictionResponse(prediction=None, metadata=metadata)
class TestPredictionCache:
    """Tests for the TTL-based prediction cache.

    Covers hit/miss, TTL expiry, key separation, clearing, overwrite,
    and concurrent access from multiple threads.
    """

    def test_cache_hit(self) -> None:
        """Cached predictions are returned on hit."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        response = _make_response()
        cache.put(query, response)
        result = cache.get(query)
        assert result is not None
        assert result.metadata.sample_size == 100

    def test_cache_miss(self) -> None:
        """Missing keys return None."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        result = cache.get(query)
        assert result is None

    def test_cache_expiry(self) -> None:
        """Expired entries return None."""
        cache = PredictionCache(ttl_seconds=0.05)
        query = _make_query()
        response = _make_response()
        cache.put(query, response)
        # Sleep comfortably past the 50 ms TTL so the entry is stale.
        time.sleep(0.1)
        result = cache.get(query)
        assert result is None

    def test_different_queries_different_keys(self) -> None:
        """Different queries map to different cache entries."""
        cache = PredictionCache(ttl_seconds=60.0)
        query1 = _make_query(task_type=TaskType.IMPLEMENTATION)
        query2 = _make_query(task_type=TaskType.DEBUGGING)
        cache.put(query1, _make_response(sample_size=100))
        cache.put(query2, _make_response(sample_size=200))
        result1 = cache.get(query1)
        result2 = cache.get(query2)
        assert result1 is not None
        assert result2 is not None
        assert result1.metadata.sample_size == 100
        assert result2.metadata.sample_size == 200

    def test_cache_clear(self) -> None:
        """Clear removes all entries."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        cache.put(query, _make_response())
        assert cache.size == 1
        cache.clear()
        assert cache.size == 0
        assert cache.get(query) is None

    def test_cache_overwrite(self) -> None:
        """Putting a new value for the same key overwrites."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        cache.put(query, _make_response(sample_size=100))
        cache.put(query, _make_response(sample_size=200))
        result = cache.get(query)
        assert result is not None
        assert result.metadata.sample_size == 200

    def test_thread_safety(self) -> None:
        """Cache handles concurrent access from multiple threads."""
        cache = PredictionCache(ttl_seconds=60.0)
        errors: list[Exception] = []
        iterations = 100

        def writer(thread_id: int) -> None:
            try:
                for i in range(iterations):
                    query = _make_query(model=f"model-{thread_id}-{i}")
                    cache.put(query, _make_response(sample_size=i))
            except Exception as e:
                errors.append(e)

        def reader(thread_id: int) -> None:
            try:
                for i in range(iterations):
                    query = _make_query(model=f"model-{thread_id}-{i}")
                    cache.get(query)  # May or may not hit
            except Exception as e:
                errors.append(e)

        threads: list[threading.Thread] = []
        for tid in range(4):
            threads.append(threading.Thread(target=writer, args=(tid,)))
            threads.append(threading.Thread(target=reader, args=(tid,)))
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)
            # BUG FIX: join(timeout=...) returns silently on timeout, so a
            # deadlocked cache would previously let the test pass. Verify
            # each thread actually finished.
            assert not t.is_alive(), "thread did not finish within timeout"
        assert not errors, f"Thread errors: {errors}"