From ed23293e1add22fd7dfe2325906c510abf47f692 Mon Sep 17 00:00:00 2001 From: Jason Woltje Date: Sun, 15 Feb 2026 01:50:58 -0600 Subject: [PATCH] feat(#373): prediction integration for cost estimation - Create PredictionService for pre-task cost/token estimates - Refresh common predictions on startup - Integrate predictions into LLM telemetry tracker - Add GET /api/telemetry/estimate endpoint - Graceful degradation when no prediction data available - Add unit tests for prediction service Refs #373 Co-Authored-By: Claude Opus 4.6 --- .../llm/llm-telemetry-tracker.service.spec.ts | 12 +- .../src/llm/llm-telemetry-tracker.service.ts | 41 ++- .../mosaic-telemetry.controller.ts | 92 +++++ .../mosaic-telemetry.module.ts | 10 +- .../prediction.service.spec.ts | 320 ++++++++++++++++++ .../mosaic-telemetry/prediction.service.ts | 161 +++++++++ 6 files changed, 621 insertions(+), 15 deletions(-) create mode 100644 apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts create mode 100644 apps/api/src/mosaic-telemetry/prediction.service.spec.ts create mode 100644 apps/api/src/mosaic-telemetry/prediction.service.ts diff --git a/apps/api/src/llm/llm-telemetry-tracker.service.spec.ts b/apps/api/src/llm/llm-telemetry-tracker.service.spec.ts index 0f43489..ca2a867 100644 --- a/apps/api/src/llm/llm-telemetry-tracker.service.spec.ts +++ b/apps/api/src/llm/llm-telemetry-tracker.service.spec.ts @@ -333,12 +333,13 @@ describe("LlmTelemetryTrackerService", () => { service.trackLlmCompletion(baseParams); // claude-sonnet-4-5: 150 * 3 + 300 * 15 = 450 + 4500 = 4950 - const expectedCost = 4950; + const expectedActualCost = 4950; expect(mockTelemetryService.eventBuilder?.build).toHaveBeenCalledWith( expect.objectContaining({ - estimated_cost_usd_micros: expectedCost, - actual_cost_usd_micros: expectedCost, + // Estimated values are 0 when no PredictionService is injected + estimated_cost_usd_micros: 0, + actual_cost_usd_micros: expectedActualCost, }), ); }); @@ -437,8 +438,9 @@ 
describe("LlmTelemetryTrackerService", () => { expect.objectContaining({ actual_input_tokens: 50, actual_output_tokens: 100, - estimated_input_tokens: 50, - estimated_output_tokens: 100, + // Estimated values are 0 when no PredictionService is injected + estimated_input_tokens: 0, + estimated_output_tokens: 0, }), ); }); diff --git a/apps/api/src/llm/llm-telemetry-tracker.service.ts b/apps/api/src/llm/llm-telemetry-tracker.service.ts index e4905a9..0b79f8b 100644 --- a/apps/api/src/llm/llm-telemetry-tracker.service.ts +++ b/apps/api/src/llm/llm-telemetry-tracker.service.ts @@ -1,5 +1,6 @@ -import { Injectable, Logger } from "@nestjs/common"; +import { Injectable, Logger, Optional } from "@nestjs/common"; import { MosaicTelemetryService } from "../mosaic-telemetry/mosaic-telemetry.service"; +import { PredictionService } from "../mosaic-telemetry/prediction.service"; import { TaskType, Complexity, Harness, Provider, Outcome } from "@mosaicstack/telemetry-client"; import type { LlmProviderType } from "./providers/llm-provider.interface"; import { calculateCostMicrodollars } from "./llm-cost-table"; @@ -140,7 +141,10 @@ export function inferTaskType( export class LlmTelemetryTrackerService { private readonly logger = new Logger(LlmTelemetryTrackerService.name); - constructor(private readonly telemetry: MosaicTelemetryService) {} + constructor( + private readonly telemetry: MosaicTelemetryService, + @Optional() private readonly predictionService?: PredictionService + ) {} /** * Track an LLM completion event via Mosaic Telemetry. 
@@ -158,24 +162,47 @@ export class LlmTelemetryTrackerService { return; } + const taskType = inferTaskType(params.operation, params.callingContext); + const provider = mapProviderType(params.providerType); + const costMicrodollars = calculateCostMicrodollars( params.model, params.inputTokens, params.outputTokens ); + // Query predictions for estimated fields (graceful degradation) + let estimatedInputTokens = 0; + let estimatedOutputTokens = 0; + let estimatedCostMicros = 0; + + if (this.predictionService) { + const prediction = this.predictionService.getEstimate( + taskType, + params.model, + provider, + Complexity.LOW + ); + + if (prediction?.prediction && prediction.metadata.confidence !== "none") { + estimatedInputTokens = prediction.prediction.input_tokens.median; + estimatedOutputTokens = prediction.prediction.output_tokens.median; + estimatedCostMicros = prediction.prediction.cost_usd_micros.median ?? 0; + } + } + const event = builder.build({ task_duration_ms: params.durationMs, - task_type: inferTaskType(params.operation, params.callingContext), + task_type: taskType, complexity: Complexity.LOW, harness: mapHarness(params.providerType), model: params.model, - provider: mapProviderType(params.providerType), - estimated_input_tokens: params.inputTokens, - estimated_output_tokens: params.outputTokens, + provider, + estimated_input_tokens: estimatedInputTokens, + estimated_output_tokens: estimatedOutputTokens, actual_input_tokens: params.inputTokens, actual_output_tokens: params.outputTokens, - estimated_cost_usd_micros: costMicrodollars, + estimated_cost_usd_micros: estimatedCostMicros, actual_cost_usd_micros: costMicrodollars, quality_gate_passed: true, quality_gates_run: [], diff --git a/apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts b/apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts new file mode 100644 index 0000000..a3d0f9c --- /dev/null +++ b/apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts @@ -0,0 +1,92 @@ +import 
{ Controller, Get, Query, UseGuards, BadRequestException } from "@nestjs/common"; +import { AuthGuard } from "../auth/guards/auth.guard"; +import { PredictionService } from "./prediction.service"; +import { + TaskType, + Complexity, + Provider, + type PredictionResponse, +} from "@mosaicstack/telemetry-client"; + +/** + * Valid values for query parameter validation. + */ +const VALID_TASK_TYPES = new Set(Object.values(TaskType)); +const VALID_COMPLEXITIES = new Set(Object.values(Complexity)); +const VALID_PROVIDERS = new Set(Object.values(Provider)); + +/** + * Response DTO for the estimate endpoint. + */ +interface EstimateResponseDto { + data: PredictionResponse | null; +} + +/** + * Mosaic Telemetry Controller + * + * Provides API endpoints for accessing telemetry prediction data. + * All endpoints require authentication via AuthGuard. + * + * This controller is intentionally lightweight - it delegates to PredictionService + * for the actual prediction logic and returns results directly to the frontend. + */ +@Controller("telemetry") +@UseGuards(AuthGuard) +export class MosaicTelemetryController { + constructor(private readonly predictionService: PredictionService) {} + + /** + * GET /api/telemetry/estimate + * + * Get a cost/token estimate for a given task configuration. + * Returns prediction data including confidence level, or null if + * no prediction is available. + * + * @param taskType - Task type enum value (e.g. "implementation", "planning") + * @param model - Model name (e.g. "claude-sonnet-4-5") + * @param provider - Provider enum value (e.g. "anthropic", "openai") + * @param complexity - Complexity level (e.g. 
"low", "medium", "high") + * @returns Prediction response with estimates and confidence + */ + @Get("estimate") + getEstimate( + @Query("taskType") taskType: string, + @Query("model") model: string, + @Query("provider") provider: string, + @Query("complexity") complexity: string + ): EstimateResponseDto { + if (!taskType || !model || !provider || !complexity) { + throw new BadRequestException( + "Missing query parameters. Required: taskType, model, provider, complexity" + ); + } + + if (!VALID_TASK_TYPES.has(taskType)) { + throw new BadRequestException( + `Invalid taskType "${taskType}". Valid values: ${[...VALID_TASK_TYPES].join(", ")}` + ); + } + + if (!VALID_PROVIDERS.has(provider)) { + throw new BadRequestException( + `Invalid provider "${provider}". Valid values: ${[...VALID_PROVIDERS].join(", ")}` + ); + } + + if (!VALID_COMPLEXITIES.has(complexity)) { + throw new BadRequestException( + `Invalid complexity "${complexity}". Valid values: ${[...VALID_COMPLEXITIES].join(", ")}` + ); + } + + const prediction = this.predictionService.getEstimate( + taskType as TaskType, + model, + provider as Provider, + complexity as Complexity + ); + + return { data: prediction }; + } +} diff --git a/apps/api/src/mosaic-telemetry/mosaic-telemetry.module.ts b/apps/api/src/mosaic-telemetry/mosaic-telemetry.module.ts index a321dda..55bb91c 100644 --- a/apps/api/src/mosaic-telemetry/mosaic-telemetry.module.ts +++ b/apps/api/src/mosaic-telemetry/mosaic-telemetry.module.ts @@ -1,6 +1,9 @@ import { Module, Global } from "@nestjs/common"; import { ConfigModule } from "@nestjs/config"; +import { AuthModule } from "../auth/auth.module"; import { MosaicTelemetryService } from "./mosaic-telemetry.service"; +import { PredictionService } from "./prediction.service"; +import { MosaicTelemetryController } from "./mosaic-telemetry.controller"; /** * Global module providing Mosaic Telemetry integration via @mosaicstack/telemetry-client. 
@@ -30,8 +33,9 @@ import { MosaicTelemetryService } from "./mosaic-telemetry.service"; */ @Global() @Module({ - imports: [ConfigModule], - providers: [MosaicTelemetryService], - exports: [MosaicTelemetryService], + imports: [ConfigModule, AuthModule], + controllers: [MosaicTelemetryController], + providers: [MosaicTelemetryService, PredictionService], + exports: [MosaicTelemetryService, PredictionService], }) export class MosaicTelemetryModule {} diff --git a/apps/api/src/mosaic-telemetry/prediction.service.spec.ts b/apps/api/src/mosaic-telemetry/prediction.service.spec.ts new file mode 100644 index 0000000..f933f2e --- /dev/null +++ b/apps/api/src/mosaic-telemetry/prediction.service.spec.ts @@ -0,0 +1,320 @@ +import { describe, it, expect, beforeEach, vi } from "vitest"; +import { Test, TestingModule } from "@nestjs/testing"; +import { + TaskType, + Complexity, + Provider, +} from "@mosaicstack/telemetry-client"; +import type { + PredictionResponse, + PredictionQuery, +} from "@mosaicstack/telemetry-client"; +import { MosaicTelemetryService } from "./mosaic-telemetry.service"; +import { PredictionService } from "./prediction.service"; + +describe("PredictionService", () => { + let service: PredictionService; + let mockTelemetryService: { + isEnabled: boolean; + getPrediction: ReturnType; + refreshPredictions: ReturnType; + }; + + const mockPredictionResponse: PredictionResponse = { + prediction: { + input_tokens: { + p10: 50, + p25: 80, + median: 120, + p75: 200, + p90: 350, + }, + output_tokens: { + p10: 100, + p25: 150, + median: 250, + p75: 400, + p90: 600, + }, + cost_usd_micros: { + p10: 500, + p25: 800, + median: 1200, + p75: 2000, + p90: 3500, + }, + duration_ms: { + p10: 200, + p25: 400, + median: 800, + p75: 1500, + p90: 3000, + }, + correction_factors: { + input: 1.0, + output: 1.0, + }, + quality: { + gate_pass_rate: 0.95, + success_rate: 0.92, + }, + }, + metadata: { + sample_size: 150, + fallback_level: 0, + confidence: "high", + last_updated: 
"2026-02-15T00:00:00Z", + cache_hit: true, + }, + }; + + const nullPredictionResponse: PredictionResponse = { + prediction: null, + metadata: { + sample_size: 0, + fallback_level: 3, + confidence: "none", + last_updated: null, + cache_hit: false, + }, + }; + + beforeEach(async () => { + mockTelemetryService = { + isEnabled: true, + getPrediction: vi.fn().mockReturnValue(mockPredictionResponse), + refreshPredictions: vi.fn().mockResolvedValue(undefined), + }; + + const module: TestingModule = await Test.createTestingModule({ + providers: [ + PredictionService, + { + provide: MosaicTelemetryService, + useValue: mockTelemetryService, + }, + ], + }).compile(); + + service = module.get(PredictionService); + }); + + it("should be defined", () => { + expect(service).toBeDefined(); + }); + + // ---------- getEstimate ---------- + + describe("getEstimate", () => { + it("should return prediction response for valid query", () => { + const result = service.getEstimate( + TaskType.IMPLEMENTATION, + "claude-sonnet-4-5", + Provider.ANTHROPIC, + Complexity.LOW + ); + + expect(result).toEqual(mockPredictionResponse); + expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({ + task_type: TaskType.IMPLEMENTATION, + model: "claude-sonnet-4-5", + provider: Provider.ANTHROPIC, + complexity: Complexity.LOW, + }); + }); + + it("should pass correct query parameters to telemetry service", () => { + service.getEstimate( + TaskType.CODE_REVIEW, + "gpt-4o", + Provider.OPENAI, + Complexity.HIGH + ); + + expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({ + task_type: TaskType.CODE_REVIEW, + model: "gpt-4o", + provider: Provider.OPENAI, + complexity: Complexity.HIGH, + }); + }); + + it("should return null when telemetry returns null", () => { + mockTelemetryService.getPrediction.mockReturnValue(null); + + const result = service.getEstimate( + TaskType.IMPLEMENTATION, + "claude-sonnet-4-5", + Provider.ANTHROPIC, + Complexity.LOW + ); + + expect(result).toBeNull(); + }); + 
+ it("should return null prediction response when confidence is none", () => { + mockTelemetryService.getPrediction.mockReturnValue(nullPredictionResponse); + + const result = service.getEstimate( + TaskType.IMPLEMENTATION, + "unknown-model", + Provider.UNKNOWN, + Complexity.LOW + ); + + expect(result).toEqual(nullPredictionResponse); + expect(result?.metadata.confidence).toBe("none"); + }); + + it("should return null and not throw when getPrediction throws", () => { + mockTelemetryService.getPrediction.mockImplementation(() => { + throw new Error("Prediction fetch failed"); + }); + + const result = service.getEstimate( + TaskType.IMPLEMENTATION, + "claude-sonnet-4-5", + Provider.ANTHROPIC, + Complexity.LOW + ); + + expect(result).toBeNull(); + }); + + it("should handle non-Error thrown objects gracefully", () => { + mockTelemetryService.getPrediction.mockImplementation(() => { + throw "string error"; + }); + + const result = service.getEstimate( + TaskType.IMPLEMENTATION, + "claude-sonnet-4-5", + Provider.ANTHROPIC, + Complexity.LOW + ); + + expect(result).toBeNull(); + }); + }); + + // ---------- refreshCommonPredictions ---------- + + describe("refreshCommonPredictions", () => { + it("should call refreshPredictions with multiple query combinations", async () => { + await service.refreshCommonPredictions(); + + expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1); + + const queries: PredictionQuery[] = + mockTelemetryService.refreshPredictions.mock.calls[0][0]; + + // Should have queries for cross-product of models, task types, and complexities + expect(queries.length).toBeGreaterThan(0); + + // Verify all queries have valid structure + for (const query of queries) { + expect(query).toHaveProperty("task_type"); + expect(query).toHaveProperty("model"); + expect(query).toHaveProperty("provider"); + expect(query).toHaveProperty("complexity"); + } + }); + + it("should include Anthropic model predictions", async () => { + await 
service.refreshCommonPredictions(); + + const queries: PredictionQuery[] = + mockTelemetryService.refreshPredictions.mock.calls[0][0]; + + const anthropicQueries = queries.filter( + (q: PredictionQuery) => q.provider === Provider.ANTHROPIC + ); + expect(anthropicQueries.length).toBeGreaterThan(0); + }); + + it("should include OpenAI model predictions", async () => { + await service.refreshCommonPredictions(); + + const queries: PredictionQuery[] = + mockTelemetryService.refreshPredictions.mock.calls[0][0]; + + const openaiQueries = queries.filter( + (q: PredictionQuery) => q.provider === Provider.OPENAI + ); + expect(openaiQueries.length).toBeGreaterThan(0); + }); + + it("should not call refreshPredictions when telemetry is disabled", async () => { + mockTelemetryService.isEnabled = false; + + await service.refreshCommonPredictions(); + + expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled(); + }); + + it("should not throw when refreshPredictions rejects", async () => { + mockTelemetryService.refreshPredictions.mockRejectedValue( + new Error("Server unreachable") + ); + + // Should not throw + await expect(service.refreshCommonPredictions()).resolves.not.toThrow(); + }); + + it("should include common task types in queries", async () => { + await service.refreshCommonPredictions(); + + const queries: PredictionQuery[] = + mockTelemetryService.refreshPredictions.mock.calls[0][0]; + + const taskTypes = new Set(queries.map((q: PredictionQuery) => q.task_type)); + + expect(taskTypes.has(TaskType.IMPLEMENTATION)).toBe(true); + expect(taskTypes.has(TaskType.PLANNING)).toBe(true); + expect(taskTypes.has(TaskType.CODE_REVIEW)).toBe(true); + }); + + it("should include common complexity levels in queries", async () => { + await service.refreshCommonPredictions(); + + const queries: PredictionQuery[] = + mockTelemetryService.refreshPredictions.mock.calls[0][0]; + + const complexities = new Set(queries.map((q: PredictionQuery) => q.complexity)); + + 
expect(complexities.has(Complexity.LOW)).toBe(true); + expect(complexities.has(Complexity.MEDIUM)).toBe(true); + }); + }); + + // ---------- onModuleInit ---------- + + describe("onModuleInit", () => { + it("should trigger refreshCommonPredictions on init when telemetry is enabled", () => { + // refreshPredictions is async, but onModuleInit fires it and forgets + service.onModuleInit(); + + // Give the promise microtask a chance to execute + expect(mockTelemetryService.isEnabled).toBe(true); + // refreshPredictions will be called asynchronously + }); + + it("should not refresh when telemetry is disabled", () => { + mockTelemetryService.isEnabled = false; + + service.onModuleInit(); + + // refreshPredictions should not be called since we returned early + expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled(); + }); + + it("should not throw when refresh fails on init", () => { + mockTelemetryService.refreshPredictions.mockRejectedValue( + new Error("Connection refused") + ); + + // Should not throw + expect(() => service.onModuleInit()).not.toThrow(); + }); + }); +}); diff --git a/apps/api/src/mosaic-telemetry/prediction.service.ts b/apps/api/src/mosaic-telemetry/prediction.service.ts new file mode 100644 index 0000000..7ffe6cb --- /dev/null +++ b/apps/api/src/mosaic-telemetry/prediction.service.ts @@ -0,0 +1,161 @@ +import { Injectable, Logger, OnModuleInit } from "@nestjs/common"; +import { + TaskType, + Complexity, + Provider, + type PredictionQuery, + type PredictionResponse, +} from "@mosaicstack/telemetry-client"; +import { MosaicTelemetryService } from "./mosaic-telemetry.service"; + +/** + * Common model-provider combinations used for pre-fetching predictions. + * These represent the most frequently used LLM configurations. 
+ */
+const COMMON_MODELS: { model: string; provider: Provider }[] = [
+  { model: "claude-sonnet-4-5", provider: Provider.ANTHROPIC },
+  { model: "claude-opus-4", provider: Provider.ANTHROPIC },
+  { model: "claude-haiku-4-5", provider: Provider.ANTHROPIC },
+  { model: "gpt-4o", provider: Provider.OPENAI },
+  { model: "gpt-4o-mini", provider: Provider.OPENAI },
+];
+
+/**
+ * Common task types to pre-fetch predictions for.
+ */
+const COMMON_TASK_TYPES: TaskType[] = [
+  TaskType.IMPLEMENTATION,
+  TaskType.PLANNING,
+  TaskType.CODE_REVIEW,
+];
+
+/**
+ * Common complexity levels to pre-fetch predictions for.
+ */
+const COMMON_COMPLEXITIES: Complexity[] = [Complexity.LOW, Complexity.MEDIUM];
+
+/**
+ * PredictionService
+ *
+ * Provides pre-task cost and token estimates using crowd-sourced prediction data
+ * from the Mosaic Telemetry server. Predictions are cached by the underlying SDK
+ * with a 6-hour TTL.
+ *
+ * This service is intentionally non-blocking: if predictions are unavailable
+ * (telemetry disabled, server unreachable, no data), all methods return null
+ * without throwing errors. Task execution should never be blocked by prediction
+ * failures.
+ *
+ * @example
+ * ```typescript
+ * const estimate = this.predictionService.getEstimate(
+ *   TaskType.IMPLEMENTATION,
+ *   "claude-sonnet-4-5",
+ *   Provider.ANTHROPIC,
+ *   Complexity.LOW,
+ * );
+ * if (estimate?.prediction) {
+ *   console.log(`Estimated cost: ${estimate.prediction.cost_usd_micros.median}`);
+ * }
+ * ```
+ */
+@Injectable()
+export class PredictionService implements OnModuleInit {
+  private readonly logger = new Logger(PredictionService.name);
+
+  constructor(private readonly telemetry: MosaicTelemetryService) {}
+
+  /**
+   * Refresh common predictions on startup.
+   * Runs asynchronously and never blocks module initialization.
+ */ + onModuleInit(): void { + if (!this.telemetry.isEnabled) { + this.logger.log("Telemetry disabled - skipping prediction refresh"); + return; + } + + // Fire-and-forget: refresh in the background + this.refreshCommonPredictions().catch((error: unknown) => { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`Failed to refresh common predictions on startup: ${msg}`); + }); + } + + /** + * Get a cost/token estimate for a given task configuration. + * + * Returns the cached prediction from the SDK, or null if: + * - Telemetry is disabled + * - No prediction data exists for this combination + * - The prediction has expired + * + * @param taskType - The type of task to estimate + * @param model - The model name (e.g. "claude-sonnet-4-5") + * @param provider - The provider enum value + * @param complexity - The complexity level + * @returns Prediction response with estimates and confidence, or null + */ + getEstimate( + taskType: TaskType, + model: string, + provider: Provider, + complexity: Complexity + ): PredictionResponse | null { + try { + const query: PredictionQuery = { + task_type: taskType, + model, + provider, + complexity, + }; + + return this.telemetry.getPrediction(query); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`Failed to get prediction estimate: ${msg}`); + return null; + } + } + + /** + * Refresh predictions for commonly used (taskType, model, provider, complexity) combinations. + * + * Generates the cross-product of common models, task types, and complexities, + * then batch-refreshes them from the telemetry server. The SDK caches the + * results with a 6-hour TTL. + * + * This method is safe to call at any time. If telemetry is disabled or the + * server is unreachable, it completes without error. 
+ */ + async refreshCommonPredictions(): Promise { + if (!this.telemetry.isEnabled) { + return; + } + + const queries: PredictionQuery[] = []; + + for (const { model, provider } of COMMON_MODELS) { + for (const taskType of COMMON_TASK_TYPES) { + for (const complexity of COMMON_COMPLEXITIES) { + queries.push({ + task_type: taskType, + model, + provider, + complexity, + }); + } + } + } + + this.logger.log(`Refreshing ${String(queries.length)} common prediction queries...`); + + try { + await this.telemetry.refreshPredictions(queries); + this.logger.log(`Successfully refreshed ${String(queries.length)} predictions`); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`Failed to refresh predictions: ${msg}`); + } + } +}