"""Tests for PredictionCache."""

from __future__ import annotations

import threading
import time

from mosaicstack_telemetry.prediction_cache import PredictionCache
from mosaicstack_telemetry.types.events import Complexity, Provider, TaskType
from mosaicstack_telemetry.types.predictions import (
    PredictionMetadata,
    PredictionQuery,
    PredictionResponse,
)

def _make_query(
    task_type: TaskType = TaskType.IMPLEMENTATION,
    model: str = "claude-sonnet-4-20250514",
) -> PredictionQuery:
    """Build a PredictionQuery with test-friendly defaults.

    Only *task_type* and *model* vary across tests; provider and
    complexity are fixed so distinct calls differ only where intended.
    """
    fixed_fields = {
        "provider": Provider.ANTHROPIC,
        "complexity": Complexity.MEDIUM,
    }
    return PredictionQuery(task_type=task_type, model=model, **fixed_fields)
|
|
|
|
|
|
def _make_response(sample_size: int = 100) -> PredictionResponse:
    """Build a PredictionResponse whose metadata carries *sample_size*.

    The prediction payload itself is irrelevant to cache behavior, so it
    is left as ``None``; tests distinguish entries via the sample size.
    """
    metadata = PredictionMetadata(
        sample_size=sample_size,
        fallback_level=0,
        confidence="high",
    )
    return PredictionResponse(prediction=None, metadata=metadata)
|
|
|
|
|
|
class TestPredictionCache:
    """Tests for the TTL-based prediction cache.

    Covers the core contract: hit, miss, TTL expiry, key derivation from
    query fields, clearing, overwrite-on-put, and concurrent access.
    """

    def test_cache_hit(self) -> None:
        """Cached predictions are returned on hit."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        response = _make_response()

        cache.put(query, response)
        result = cache.get(query)

        assert result is not None
        assert result.metadata.sample_size == 100

    def test_cache_miss(self) -> None:
        """Missing keys return None."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()

        result = cache.get(query)
        assert result is None

    def test_cache_expiry(self) -> None:
        """Expired entries return None."""
        cache = PredictionCache(ttl_seconds=0.05)
        query = _make_query()
        response = _make_response()

        cache.put(query, response)
        # Sleep past the TTL (0.1s > 0.05s) so the entry must be expired.
        time.sleep(0.1)
        result = cache.get(query)

        assert result is None

    def test_different_queries_different_keys(self) -> None:
        """Different queries map to different cache entries."""
        cache = PredictionCache(ttl_seconds=60.0)
        # Queries differ only in task_type; the cache key must include it.
        query1 = _make_query(task_type=TaskType.IMPLEMENTATION)
        query2 = _make_query(task_type=TaskType.DEBUGGING)

        cache.put(query1, _make_response(sample_size=100))
        cache.put(query2, _make_response(sample_size=200))

        result1 = cache.get(query1)
        result2 = cache.get(query2)

        assert result1 is not None
        assert result2 is not None
        # Distinct sample sizes prove the entries were not conflated.
        assert result1.metadata.sample_size == 100
        assert result2.metadata.sample_size == 200

    def test_cache_clear(self) -> None:
        """Clear removes all entries."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()
        cache.put(query, _make_response())

        assert cache.size == 1
        cache.clear()
        assert cache.size == 0
        assert cache.get(query) is None

    def test_cache_overwrite(self) -> None:
        """Putting a new value for the same key overwrites."""
        cache = PredictionCache(ttl_seconds=60.0)
        query = _make_query()

        cache.put(query, _make_response(sample_size=100))
        cache.put(query, _make_response(sample_size=200))

        result = cache.get(query)
        assert result is not None
        # The second put must win.
        assert result.metadata.sample_size == 200

    def test_thread_safety(self) -> None:
        """Cache handles concurrent access from multiple threads."""
        cache = PredictionCache(ttl_seconds=60.0)
        errors: list[Exception] = []
        iterations = 100

        def writer(thread_id: int) -> None:
            # Collect exceptions rather than raising: an exception inside a
            # worker thread would otherwise be swallowed and the test would
            # pass vacuously.
            try:
                for i in range(iterations):
                    query = _make_query(model=f"model-{thread_id}-{i}")
                    cache.put(query, _make_response(sample_size=i))
            except Exception as e:
                errors.append(e)

        def reader(thread_id: int) -> None:
            try:
                for i in range(iterations):
                    query = _make_query(model=f"model-{thread_id}-{i}")
                    cache.get(query)  # May or may not hit
            except Exception as e:
                errors.append(e)

        threads: list[threading.Thread] = []
        for tid in range(4):
            threads.append(threading.Thread(target=writer, args=(tid,)))
            threads.append(threading.Thread(target=reader, args=(tid,)))

        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)
            # BUG FIX: join(timeout=...) returns silently on timeout, leaving
            # a hung (possibly deadlocked) thread while the test still passes.
            # Fail loudly if any worker did not finish in time.
            assert not t.is_alive(), "worker thread did not finish within 5s"

        assert not errors, f"Thread errors: {errors}"
|