import { describe, it, expect, beforeEach, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
import { TaskType, Complexity, Provider } from "@mosaicstack/telemetry-client";
import type { PredictionResponse, PredictionQuery } from "@mosaicstack/telemetry-client";
import { MosaicTelemetryService } from "./mosaic-telemetry.service";
import { PredictionService } from "./prediction.service";

/**
 * Unit tests for PredictionService.
 *
 * MosaicTelemetryService is replaced by a plain mock object so that the
 * enabled flag and every call into the telemetry client can be controlled
 * and inspected per test.
 */
describe("PredictionService", () => {
  let service: PredictionService;
  let mockTelemetryService: {
    isEnabled: boolean;
    getPrediction: ReturnType<typeof vi.fn>;
    refreshPredictions: ReturnType<typeof vi.fn>;
  };

  /** A fully populated, high-confidence prediction used as the default mock result. */
  const mockPredictionResponse: PredictionResponse = {
    prediction: {
      input_tokens: { p10: 50, p25: 80, median: 120, p75: 200, p90: 350 },
      output_tokens: { p10: 100, p25: 150, median: 250, p75: 400, p90: 600 },
      cost_usd_micros: { p10: 500, p25: 800, median: 1200, p75: 2000, p90: 3500 },
      duration_ms: { p10: 200, p25: 400, median: 800, p75: 1500, p90: 3000 },
      correction_factors: { input: 1.0, output: 1.0 },
      quality: { gate_pass_rate: 0.95, success_rate: 0.92 },
    },
    metadata: {
      sample_size: 150,
      fallback_level: 0,
      confidence: "high",
      last_updated: "2026-02-15T00:00:00Z",
      cache_hit: true,
    },
  };

  /** A response carrying no prediction data (confidence "none"). */
  const nullPredictionResponse: PredictionResponse = {
    prediction: null,
    metadata: {
      sample_size: 0,
      fallback_level: 3,
      confidence: "none",
      last_updated: null,
      cache_hit: false,
    },
  };

  /**
   * Returns the query list passed to refreshPredictions.
   * Fails the current test unless refreshPredictions was called exactly once.
   */
  const getRefreshedQueries = (): PredictionQuery[] => {
    expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1);
    // `?? []` keeps the index access safe under noUncheckedIndexedAccess;
    // the call-count assertion above already guarantees the call exists.
    const [queries] = mockTelemetryService.refreshPredictions.mock.calls[0] ?? [];
    return queries;
  };

  /** Lets already-queued promise callbacks and zero-delay timers run. */
  const flushPromises = (): Promise<void> =>
    new Promise((resolve) => setTimeout(resolve, 0));

  beforeEach(async () => {
    mockTelemetryService = {
      isEnabled: true,
      getPrediction: vi.fn().mockReturnValue(mockPredictionResponse),
      refreshPredictions: vi.fn().mockResolvedValue(undefined),
    };

    const module: TestingModule = await Test.createTestingModule({
      providers: [
        PredictionService,
        {
          provide: MosaicTelemetryService,
          useValue: mockTelemetryService,
        },
      ],
    }).compile();

    service = module.get(PredictionService);
  });

  it("should be defined", () => {
    expect(service).toBeDefined();
  });

  // ---------- getEstimate ----------

  describe("getEstimate", () => {
    it("should return prediction response for valid query", () => {
      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toEqual(mockPredictionResponse);
      expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({
        task_type: TaskType.IMPLEMENTATION,
        model: "claude-sonnet-4-5",
        provider: Provider.ANTHROPIC,
        complexity: Complexity.LOW,
      });
    });

    it("should pass correct query parameters to telemetry service", () => {
      service.getEstimate(TaskType.CODE_REVIEW, "gpt-4o", Provider.OPENAI, Complexity.HIGH);

      expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({
        task_type: TaskType.CODE_REVIEW,
        model: "gpt-4o",
        provider: Provider.OPENAI,
        complexity: Complexity.HIGH,
      });
    });

    it("should return null when telemetry returns null", () => {
      mockTelemetryService.getPrediction.mockReturnValue(null);

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });

    it("should pass through a response with a null prediction when confidence is none", () => {
      mockTelemetryService.getPrediction.mockReturnValue(nullPredictionResponse);

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "unknown-model",
        Provider.UNKNOWN,
        Complexity.LOW
      );

      // The response object itself is returned; only its `prediction` is null.
      expect(result).toEqual(nullPredictionResponse);
      expect(result?.prediction).toBeNull();
      expect(result?.metadata.confidence).toBe("none");
    });

    it("should return null and not throw when getPrediction throws", () => {
      mockTelemetryService.getPrediction.mockImplementation(() => {
        throw new Error("Prediction fetch failed");
      });

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });

    it("should handle non-Error thrown objects gracefully", () => {
      mockTelemetryService.getPrediction.mockImplementation(() => {
        // Deliberately throws a non-Error value to exercise the service's catch path.
        throw "string error";
      });

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });
  });

  // ---------- refreshCommonPredictions ----------

  describe("refreshCommonPredictions", () => {
    it("should call refreshPredictions with multiple query combinations", async () => {
      await service.refreshCommonPredictions();

      const queries = getRefreshedQueries();

      // Should have queries for cross-product of models, task types, and complexities
      expect(queries.length).toBeGreaterThan(0);

      // Verify all queries have valid structure
      for (const query of queries) {
        expect(query).toHaveProperty("task_type");
        expect(query).toHaveProperty("model");
        expect(query).toHaveProperty("provider");
        expect(query).toHaveProperty("complexity");
      }
    });

    it("should include Anthropic model predictions", async () => {
      await service.refreshCommonPredictions();

      const anthropicQueries = getRefreshedQueries().filter(
        (q: PredictionQuery) => q.provider === Provider.ANTHROPIC
      );
      expect(anthropicQueries.length).toBeGreaterThan(0);
    });

    it("should include OpenAI model predictions", async () => {
      await service.refreshCommonPredictions();

      const openaiQueries = getRefreshedQueries().filter(
        (q: PredictionQuery) => q.provider === Provider.OPENAI
      );
      expect(openaiQueries.length).toBeGreaterThan(0);
    });

    it("should not call refreshPredictions when telemetry is disabled", async () => {
      mockTelemetryService.isEnabled = false;

      await service.refreshCommonPredictions();

      expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled();
    });

    it("should not throw when refreshPredictions rejects", async () => {
      mockTelemetryService.refreshPredictions.mockRejectedValue(new Error("Server unreachable"));

      // Awaiting directly: a rejection escaping the service would fail this test.
      await service.refreshCommonPredictions();

      // Confirms the rejecting path was actually exercised.
      expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1);
    });

    it("should include common task types in queries", async () => {
      await service.refreshCommonPredictions();

      const taskTypes = new Set(getRefreshedQueries().map((q: PredictionQuery) => q.task_type));
      expect(taskTypes.has(TaskType.IMPLEMENTATION)).toBe(true);
      expect(taskTypes.has(TaskType.PLANNING)).toBe(true);
      expect(taskTypes.has(TaskType.CODE_REVIEW)).toBe(true);
    });

    it("should include common complexity levels in queries", async () => {
      await service.refreshCommonPredictions();

      const complexities = new Set(getRefreshedQueries().map((q: PredictionQuery) => q.complexity));
      expect(complexities.has(Complexity.LOW)).toBe(true);
      expect(complexities.has(Complexity.MEDIUM)).toBe(true);
    });
  });

  // ---------- onModuleInit ----------

  describe("onModuleInit", () => {
    it("should trigger refreshCommonPredictions on init when telemetry is enabled", async () => {
      // onModuleInit fires the refresh without awaiting it, so poll until the
      // background call lands instead of asserting synchronously.
      service.onModuleInit();

      await vi.waitFor(() => {
        expect(mockTelemetryService.refreshPredictions).toHaveBeenCalled();
      });
    });

    it("should not refresh when telemetry is disabled", async () => {
      mockTelemetryService.isEnabled = false;

      service.onModuleInit();
      // Give any deferred work a chance to run so a late call would be caught.
      await flushPromises();

      expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled();
    });

    it("should not throw when refresh fails on init", async () => {
      mockTelemetryService.refreshPredictions.mockRejectedValue(new Error("Connection refused"));

      expect(() => service.onModuleInit()).not.toThrow();

      // Wait for the failing refresh to be issued and settle, so that an
      // unhandled rejection, if any, is raised while this test is still running.
      await vi.waitFor(() => {
        expect(mockTelemetryService.refreshPredictions).toHaveBeenCalled();
      });
      await flushPromises();
    });
  });
});