Initial project structure
This commit is contained in:
147
tests/test_prediction_cache.py
Normal file
147
tests/test_prediction_cache.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Tests for PredictionCache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from mosaicstack_telemetry.prediction_cache import PredictionCache
|
||||
from mosaicstack_telemetry.types.events import Complexity, Provider, TaskType
|
||||
from mosaicstack_telemetry.types.predictions import (
|
||||
PredictionMetadata,
|
||||
PredictionQuery,
|
||||
PredictionResponse,
|
||||
)
|
||||
|
||||
|
||||
def _make_query(
|
||||
task_type: TaskType = TaskType.IMPLEMENTATION,
|
||||
model: str = "claude-sonnet-4-20250514",
|
||||
) -> PredictionQuery:
|
||||
return PredictionQuery(
|
||||
task_type=task_type,
|
||||
model=model,
|
||||
provider=Provider.ANTHROPIC,
|
||||
complexity=Complexity.MEDIUM,
|
||||
)
|
||||
|
||||
|
||||
def _make_response(sample_size: int = 100) -> PredictionResponse:
|
||||
return PredictionResponse(
|
||||
prediction=None,
|
||||
metadata=PredictionMetadata(
|
||||
sample_size=sample_size,
|
||||
fallback_level=0,
|
||||
confidence="high",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPredictionCache:
|
||||
"""Tests for the TTL-based prediction cache."""
|
||||
|
||||
def test_cache_hit(self) -> None:
|
||||
"""Cached predictions are returned on hit."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
query = _make_query()
|
||||
response = _make_response()
|
||||
|
||||
cache.put(query, response)
|
||||
result = cache.get(query)
|
||||
|
||||
assert result is not None
|
||||
assert result.metadata.sample_size == 100
|
||||
|
||||
def test_cache_miss(self) -> None:
|
||||
"""Missing keys return None."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
query = _make_query()
|
||||
|
||||
result = cache.get(query)
|
||||
assert result is None
|
||||
|
||||
def test_cache_expiry(self) -> None:
|
||||
"""Expired entries return None."""
|
||||
cache = PredictionCache(ttl_seconds=0.05)
|
||||
query = _make_query()
|
||||
response = _make_response()
|
||||
|
||||
cache.put(query, response)
|
||||
time.sleep(0.1)
|
||||
result = cache.get(query)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_different_queries_different_keys(self) -> None:
|
||||
"""Different queries map to different cache entries."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
query1 = _make_query(task_type=TaskType.IMPLEMENTATION)
|
||||
query2 = _make_query(task_type=TaskType.DEBUGGING)
|
||||
|
||||
cache.put(query1, _make_response(sample_size=100))
|
||||
cache.put(query2, _make_response(sample_size=200))
|
||||
|
||||
result1 = cache.get(query1)
|
||||
result2 = cache.get(query2)
|
||||
|
||||
assert result1 is not None
|
||||
assert result2 is not None
|
||||
assert result1.metadata.sample_size == 100
|
||||
assert result2.metadata.sample_size == 200
|
||||
|
||||
def test_cache_clear(self) -> None:
|
||||
"""Clear removes all entries."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
query = _make_query()
|
||||
cache.put(query, _make_response())
|
||||
|
||||
assert cache.size == 1
|
||||
cache.clear()
|
||||
assert cache.size == 0
|
||||
assert cache.get(query) is None
|
||||
|
||||
def test_cache_overwrite(self) -> None:
|
||||
"""Putting a new value for the same key overwrites."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
query = _make_query()
|
||||
|
||||
cache.put(query, _make_response(sample_size=100))
|
||||
cache.put(query, _make_response(sample_size=200))
|
||||
|
||||
result = cache.get(query)
|
||||
assert result is not None
|
||||
assert result.metadata.sample_size == 200
|
||||
|
||||
def test_thread_safety(self) -> None:
|
||||
"""Cache handles concurrent access from multiple threads."""
|
||||
cache = PredictionCache(ttl_seconds=60.0)
|
||||
errors: list[Exception] = []
|
||||
iterations = 100
|
||||
|
||||
def writer(thread_id: int) -> None:
|
||||
try:
|
||||
for i in range(iterations):
|
||||
query = _make_query(model=f"model-{thread_id}-{i}")
|
||||
cache.put(query, _make_response(sample_size=i))
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def reader(thread_id: int) -> None:
|
||||
try:
|
||||
for i in range(iterations):
|
||||
query = _make_query(model=f"model-{thread_id}-{i}")
|
||||
cache.get(query) # May or may not hit
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads: list[threading.Thread] = []
|
||||
for tid in range(4):
|
||||
threads.append(threading.Thread(target=writer, args=(tid,)))
|
||||
threads.append(threading.Thread(target=reader, args=(tid,)))
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
Reference in New Issue
Block a user