feat(#373): prediction integration for cost estimation
Some checks failed
ci/woodpecker/push/api Pipeline failed
Some checks failed
ci/woodpecker/push/api Pipeline failed
- Create PredictionService for pre-task cost/token estimates
- Refresh common predictions on startup
- Integrate predictions into LLM telemetry tracker
- Add GET /api/telemetry/estimate endpoint
- Graceful degradation when no prediction data available
- Add unit tests for prediction service

Refs #373

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -333,12 +333,13 @@ describe("LlmTelemetryTrackerService", () => {
|
|||||||
service.trackLlmCompletion(baseParams);
|
service.trackLlmCompletion(baseParams);
|
||||||
|
|
||||||
// claude-sonnet-4-5: 150 * 3 + 300 * 15 = 450 + 4500 = 4950
|
// claude-sonnet-4-5: 150 * 3 + 300 * 15 = 450 + 4500 = 4950
|
||||||
const expectedCost = 4950;
|
const expectedActualCost = 4950;
|
||||||
|
|
||||||
expect(mockTelemetryService.eventBuilder?.build).toHaveBeenCalledWith(
|
expect(mockTelemetryService.eventBuilder?.build).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
estimated_cost_usd_micros: expectedCost,
|
// Estimated values are 0 when no PredictionService is injected
|
||||||
actual_cost_usd_micros: expectedCost,
|
estimated_cost_usd_micros: 0,
|
||||||
|
actual_cost_usd_micros: expectedActualCost,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -437,8 +438,9 @@ describe("LlmTelemetryTrackerService", () => {
|
|||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
actual_input_tokens: 50,
|
actual_input_tokens: 50,
|
||||||
actual_output_tokens: 100,
|
actual_output_tokens: 100,
|
||||||
estimated_input_tokens: 50,
|
// Estimated values are 0 when no PredictionService is injected
|
||||||
estimated_output_tokens: 100,
|
estimated_input_tokens: 0,
|
||||||
|
estimated_output_tokens: 0,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { Injectable, Logger } from "@nestjs/common";
|
import { Injectable, Logger, Optional } from "@nestjs/common";
|
||||||
import { MosaicTelemetryService } from "../mosaic-telemetry/mosaic-telemetry.service";
|
import { MosaicTelemetryService } from "../mosaic-telemetry/mosaic-telemetry.service";
|
||||||
|
import { PredictionService } from "../mosaic-telemetry/prediction.service";
|
||||||
import { TaskType, Complexity, Harness, Provider, Outcome } from "@mosaicstack/telemetry-client";
|
import { TaskType, Complexity, Harness, Provider, Outcome } from "@mosaicstack/telemetry-client";
|
||||||
import type { LlmProviderType } from "./providers/llm-provider.interface";
|
import type { LlmProviderType } from "./providers/llm-provider.interface";
|
||||||
import { calculateCostMicrodollars } from "./llm-cost-table";
|
import { calculateCostMicrodollars } from "./llm-cost-table";
|
||||||
@@ -140,7 +141,10 @@ export function inferTaskType(
|
|||||||
export class LlmTelemetryTrackerService {
|
export class LlmTelemetryTrackerService {
|
||||||
private readonly logger = new Logger(LlmTelemetryTrackerService.name);
|
private readonly logger = new Logger(LlmTelemetryTrackerService.name);
|
||||||
|
|
||||||
constructor(private readonly telemetry: MosaicTelemetryService) {}
|
constructor(
|
||||||
|
private readonly telemetry: MosaicTelemetryService,
|
||||||
|
@Optional() private readonly predictionService?: PredictionService
|
||||||
|
) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Track an LLM completion event via Mosaic Telemetry.
|
* Track an LLM completion event via Mosaic Telemetry.
|
||||||
@@ -158,24 +162,47 @@ export class LlmTelemetryTrackerService {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const taskType = inferTaskType(params.operation, params.callingContext);
|
||||||
|
const provider = mapProviderType(params.providerType);
|
||||||
|
|
||||||
const costMicrodollars = calculateCostMicrodollars(
|
const costMicrodollars = calculateCostMicrodollars(
|
||||||
params.model,
|
params.model,
|
||||||
params.inputTokens,
|
params.inputTokens,
|
||||||
params.outputTokens
|
params.outputTokens
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Query predictions for estimated fields (graceful degradation)
|
||||||
|
let estimatedInputTokens = 0;
|
||||||
|
let estimatedOutputTokens = 0;
|
||||||
|
let estimatedCostMicros = 0;
|
||||||
|
|
||||||
|
if (this.predictionService) {
|
||||||
|
const prediction = this.predictionService.getEstimate(
|
||||||
|
taskType,
|
||||||
|
params.model,
|
||||||
|
provider,
|
||||||
|
Complexity.LOW
|
||||||
|
);
|
||||||
|
|
||||||
|
if (prediction?.prediction && prediction.metadata.confidence !== "none") {
|
||||||
|
estimatedInputTokens = prediction.prediction.input_tokens.median;
|
||||||
|
estimatedOutputTokens = prediction.prediction.output_tokens.median;
|
||||||
|
estimatedCostMicros = prediction.prediction.cost_usd_micros.median ?? 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const event = builder.build({
|
const event = builder.build({
|
||||||
task_duration_ms: params.durationMs,
|
task_duration_ms: params.durationMs,
|
||||||
task_type: inferTaskType(params.operation, params.callingContext),
|
task_type: taskType,
|
||||||
complexity: Complexity.LOW,
|
complexity: Complexity.LOW,
|
||||||
harness: mapHarness(params.providerType),
|
harness: mapHarness(params.providerType),
|
||||||
model: params.model,
|
model: params.model,
|
||||||
provider: mapProviderType(params.providerType),
|
provider,
|
||||||
estimated_input_tokens: params.inputTokens,
|
estimated_input_tokens: estimatedInputTokens,
|
||||||
estimated_output_tokens: params.outputTokens,
|
estimated_output_tokens: estimatedOutputTokens,
|
||||||
actual_input_tokens: params.inputTokens,
|
actual_input_tokens: params.inputTokens,
|
||||||
actual_output_tokens: params.outputTokens,
|
actual_output_tokens: params.outputTokens,
|
||||||
estimated_cost_usd_micros: costMicrodollars,
|
estimated_cost_usd_micros: estimatedCostMicros,
|
||||||
actual_cost_usd_micros: costMicrodollars,
|
actual_cost_usd_micros: costMicrodollars,
|
||||||
quality_gate_passed: true,
|
quality_gate_passed: true,
|
||||||
quality_gates_run: [],
|
quality_gates_run: [],
|
||||||
|
|||||||
92
apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts
Normal file
92
apps/api/src/mosaic-telemetry/mosaic-telemetry.controller.ts
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import { Controller, Get, Query, UseGuards, BadRequestException } from "@nestjs/common";
|
||||||
|
import { AuthGuard } from "../auth/guards/auth.guard";
|
||||||
|
import { PredictionService } from "./prediction.service";
|
||||||
|
import {
|
||||||
|
TaskType,
|
||||||
|
Complexity,
|
||||||
|
Provider,
|
||||||
|
type PredictionResponse,
|
||||||
|
} from "@mosaicstack/telemetry-client";
|
||||||
|
|
||||||
|
/**
 * Allow-lists used to validate the enum-valued query parameters.
 * Each set is built once, at module load, from the telemetry-client enum values.
 */
const VALID_TASK_TYPES = new Set<string>(Object.values(TaskType));
const VALID_COMPLEXITIES = new Set<string>(Object.values(Complexity));
const VALID_PROVIDERS = new Set<string>(Object.values(Provider));

/**
 * Response DTO for the estimate endpoint.
 * `data` is whatever PredictionService.getEstimate returned (possibly null).
 */
interface EstimateResponseDto {
  data: PredictionResponse | null;
}

/**
 * Mosaic Telemetry Controller
 *
 * Provides API endpoints for accessing telemetry prediction data.
 * All endpoints require authentication via AuthGuard (applied at class level).
 *
 * This controller is intentionally lightweight - it validates the query string,
 * delegates to PredictionService for the actual prediction lookup, and returns
 * the result wrapped in `{ data }`.
 */
@Controller("telemetry")
@UseGuards(AuthGuard)
export class MosaicTelemetryController {
  constructor(private readonly predictionService: PredictionService) {}

  /**
   * GET /api/telemetry/estimate
   *
   * Get a cost/token estimate for a given task configuration.
   *
   * @param taskType - Task type enum value (must be one of the TaskType values)
   * @param model - Model name (e.g. "claude-sonnet-4-5"); free-form, not allow-listed
   * @param provider - Provider enum value (must be one of the Provider values)
   * @param complexity - Complexity level (must be one of the Complexity values)
   * @returns `{ data }` where `data` is the prediction response, or null when
   *          PredictionService has nothing to return
   * @throws BadRequestException when a parameter is missing, is not a single
   *         string value, or is not a member of its enum
   */
  @Get("estimate")
  getEstimate(
    @Query("taskType") taskType: string,
    @Query("model") model: string,
    @Query("provider") provider: string,
    @Query("complexity") complexity: string
  ): EstimateResponseDto {
    if (!taskType || !model || !provider || !complexity) {
      throw new BadRequestException(
        "Missing query parameters. Required: taskType, model, provider, complexity"
      );
    }

    // The `string` annotations above are not enforced at runtime. A repeated key
    // (e.g. ?model=a&model=b) can be delivered as an array by the HTTP adapter's
    // query parser; it is truthy, and `model` has no allow-list to reject it, so
    // guard explicitly before the value reaches PredictionService.
    const rawValues: unknown[] = [taskType, model, provider, complexity];
    if (rawValues.some((value) => typeof value !== "string")) {
      throw new BadRequestException(
        "Query parameters taskType, model, provider and complexity must each be a single string value"
      );
    }

    if (!VALID_TASK_TYPES.has(taskType)) {
      throw new BadRequestException(
        `Invalid taskType "${taskType}". Valid values: ${[...VALID_TASK_TYPES].join(", ")}`
      );
    }

    if (!VALID_PROVIDERS.has(provider)) {
      throw new BadRequestException(
        `Invalid provider "${provider}". Valid values: ${[...VALID_PROVIDERS].join(", ")}`
      );
    }

    if (!VALID_COMPLEXITIES.has(complexity)) {
      throw new BadRequestException(
        `Invalid complexity "${complexity}". Valid values: ${[...VALID_COMPLEXITIES].join(", ")}`
      );
    }

    // The casts are safe here: membership in the allow-lists was checked above.
    const prediction = this.predictionService.getEstimate(
      taskType as TaskType,
      model,
      provider as Provider,
      complexity as Complexity
    );

    return { data: prediction };
  }
}
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
import { Module, Global } from "@nestjs/common";
|
import { Module, Global } from "@nestjs/common";
|
||||||
import { ConfigModule } from "@nestjs/config";
|
import { ConfigModule } from "@nestjs/config";
|
||||||
|
import { AuthModule } from "../auth/auth.module";
|
||||||
import { MosaicTelemetryService } from "./mosaic-telemetry.service";
|
import { MosaicTelemetryService } from "./mosaic-telemetry.service";
|
||||||
|
import { PredictionService } from "./prediction.service";
|
||||||
|
import { MosaicTelemetryController } from "./mosaic-telemetry.controller";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Global module providing Mosaic Telemetry integration via @mosaicstack/telemetry-client.
|
* Global module providing Mosaic Telemetry integration via @mosaicstack/telemetry-client.
|
||||||
@@ -30,8 +33,9 @@ import { MosaicTelemetryService } from "./mosaic-telemetry.service";
|
|||||||
*/
|
*/
|
||||||
@Global()
|
@Global()
|
||||||
@Module({
|
@Module({
|
||||||
imports: [ConfigModule],
|
imports: [ConfigModule, AuthModule],
|
||||||
providers: [MosaicTelemetryService],
|
controllers: [MosaicTelemetryController],
|
||||||
exports: [MosaicTelemetryService],
|
providers: [MosaicTelemetryService, PredictionService],
|
||||||
|
exports: [MosaicTelemetryService, PredictionService],
|
||||||
})
|
})
|
||||||
export class MosaicTelemetryModule {}
|
export class MosaicTelemetryModule {}
|
||||||
|
|||||||
320
apps/api/src/mosaic-telemetry/prediction.service.spec.ts
Normal file
320
apps/api/src/mosaic-telemetry/prediction.service.spec.ts
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||||
|
import { Test, TestingModule } from "@nestjs/testing";
|
||||||
|
import {
|
||||||
|
TaskType,
|
||||||
|
Complexity,
|
||||||
|
Provider,
|
||||||
|
} from "@mosaicstack/telemetry-client";
|
||||||
|
import type {
|
||||||
|
PredictionResponse,
|
||||||
|
PredictionQuery,
|
||||||
|
} from "@mosaicstack/telemetry-client";
|
||||||
|
import { MosaicTelemetryService } from "./mosaic-telemetry.service";
|
||||||
|
import { PredictionService } from "./prediction.service";
|
||||||
|
|
||||||
|
/**
 * Unit tests for PredictionService.
 *
 * MosaicTelemetryService is replaced by a plain mock object so each test can
 * control `isEnabled`, the value returned by `getPrediction`, and the outcome
 * of `refreshPredictions`.
 */
describe("PredictionService", () => {
  let service: PredictionService;
  let mockTelemetryService: {
    isEnabled: boolean;
    getPrediction: ReturnType<typeof vi.fn>;
    refreshPredictions: ReturnType<typeof vi.fn>;
  };

  // Fully populated prediction used as the default getPrediction result.
  const mockPredictionResponse: PredictionResponse = {
    prediction: {
      input_tokens: {
        p10: 50,
        p25: 80,
        median: 120,
        p75: 200,
        p90: 350,
      },
      output_tokens: {
        p10: 100,
        p25: 150,
        median: 250,
        p75: 400,
        p90: 600,
      },
      cost_usd_micros: {
        p10: 500,
        p25: 800,
        median: 1200,
        p75: 2000,
        p90: 3500,
      },
      duration_ms: {
        p10: 200,
        p25: 400,
        median: 800,
        p75: 1500,
        p90: 3000,
      },
      correction_factors: {
        input: 1.0,
        output: 1.0,
      },
      quality: {
        gate_pass_rate: 0.95,
        success_rate: 0.92,
      },
    },
    metadata: {
      sample_size: 150,
      fallback_level: 0,
      confidence: "high",
      last_updated: "2026-02-15T00:00:00Z",
      cache_hit: true,
    },
  };

  // Response shape used when there is no prediction data (confidence "none").
  const nullPredictionResponse: PredictionResponse = {
    prediction: null,
    metadata: {
      sample_size: 0,
      fallback_level: 3,
      confidence: "none",
      last_updated: null,
      cache_hit: false,
    },
  };

  /** Returns the query batch passed to the first refreshPredictions call. */
  const firstRefreshBatch = (): PredictionQuery[] =>
    mockTelemetryService.refreshPredictions.mock.calls[0][0];

  beforeEach(async () => {
    // Fresh mocks per test, so call counts never leak between tests.
    mockTelemetryService = {
      isEnabled: true,
      getPrediction: vi.fn().mockReturnValue(mockPredictionResponse),
      refreshPredictions: vi.fn().mockResolvedValue(undefined),
    };

    const module: TestingModule = await Test.createTestingModule({
      providers: [
        PredictionService,
        {
          provide: MosaicTelemetryService,
          useValue: mockTelemetryService,
        },
      ],
    }).compile();

    service = module.get<PredictionService>(PredictionService);
  });

  it("should be defined", () => {
    expect(service).toBeDefined();
  });

  // ---------- getEstimate ----------

  describe("getEstimate", () => {
    it("should return prediction response for valid query", () => {
      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toEqual(mockPredictionResponse);
      expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({
        task_type: TaskType.IMPLEMENTATION,
        model: "claude-sonnet-4-5",
        provider: Provider.ANTHROPIC,
        complexity: Complexity.LOW,
      });
    });

    it("should pass correct query parameters to telemetry service", () => {
      service.getEstimate(TaskType.CODE_REVIEW, "gpt-4o", Provider.OPENAI, Complexity.HIGH);

      expect(mockTelemetryService.getPrediction).toHaveBeenCalledWith({
        task_type: TaskType.CODE_REVIEW,
        model: "gpt-4o",
        provider: Provider.OPENAI,
        complexity: Complexity.HIGH,
      });
    });

    it("should return null when telemetry returns null", () => {
      mockTelemetryService.getPrediction.mockReturnValue(null);

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });

    it("should return null prediction response when confidence is none", () => {
      mockTelemetryService.getPrediction.mockReturnValue(nullPredictionResponse);

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "unknown-model",
        Provider.UNKNOWN,
        Complexity.LOW
      );

      // The service passes the "no data" response through unchanged.
      expect(result).toEqual(nullPredictionResponse);
      expect(result?.metadata.confidence).toBe("none");
    });

    it("should return null and not throw when getPrediction throws", () => {
      mockTelemetryService.getPrediction.mockImplementation(() => {
        throw new Error("Prediction fetch failed");
      });

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });

    it("should handle non-Error thrown objects gracefully", () => {
      // Deliberately throws a string to exercise the non-Error branch of the catch.
      mockTelemetryService.getPrediction.mockImplementation(() => {
        throw "string error";
      });

      const result = service.getEstimate(
        TaskType.IMPLEMENTATION,
        "claude-sonnet-4-5",
        Provider.ANTHROPIC,
        Complexity.LOW
      );

      expect(result).toBeNull();
    });
  });

  // ---------- refreshCommonPredictions ----------

  describe("refreshCommonPredictions", () => {
    it("should call refreshPredictions with multiple query combinations", async () => {
      await service.refreshCommonPredictions();

      // All combinations are sent in a single batch call.
      expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1);

      const queries = firstRefreshBatch();

      // Should have queries for cross-product of models, task types, and complexities
      expect(queries.length).toBeGreaterThan(0);

      // Verify all queries have valid structure
      for (const query of queries) {
        expect(query).toHaveProperty("task_type");
        expect(query).toHaveProperty("model");
        expect(query).toHaveProperty("provider");
        expect(query).toHaveProperty("complexity");
      }
    });

    it("should include Anthropic model predictions", async () => {
      await service.refreshCommonPredictions();

      const anthropicQueries = firstRefreshBatch().filter(
        (q: PredictionQuery) => q.provider === Provider.ANTHROPIC
      );
      expect(anthropicQueries.length).toBeGreaterThan(0);
    });

    it("should include OpenAI model predictions", async () => {
      await service.refreshCommonPredictions();

      const openaiQueries = firstRefreshBatch().filter(
        (q: PredictionQuery) => q.provider === Provider.OPENAI
      );
      expect(openaiQueries.length).toBeGreaterThan(0);
    });

    it("should not call refreshPredictions when telemetry is disabled", async () => {
      mockTelemetryService.isEnabled = false;

      await service.refreshCommonPredictions();

      expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled();
    });

    it("should not throw when refreshPredictions rejects", async () => {
      mockTelemetryService.refreshPredictions.mockRejectedValue(new Error("Server unreachable"));

      // The rejection is caught inside the service, so the returned promise
      // must resolve (to undefined) rather than reject.
      await expect(service.refreshCommonPredictions()).resolves.toBeUndefined();
    });

    it("should include common task types in queries", async () => {
      await service.refreshCommonPredictions();

      const taskTypes = new Set(firstRefreshBatch().map((q: PredictionQuery) => q.task_type));

      expect(taskTypes.has(TaskType.IMPLEMENTATION)).toBe(true);
      expect(taskTypes.has(TaskType.PLANNING)).toBe(true);
      expect(taskTypes.has(TaskType.CODE_REVIEW)).toBe(true);
    });

    it("should include common complexity levels in queries", async () => {
      await service.refreshCommonPredictions();

      const complexities = new Set(firstRefreshBatch().map((q: PredictionQuery) => q.complexity));

      expect(complexities.has(Complexity.LOW)).toBe(true);
      expect(complexities.has(Complexity.MEDIUM)).toBe(true);
    });
  });

  // ---------- onModuleInit ----------

  describe("onModuleInit", () => {
    it("should trigger refreshCommonPredictions on init when telemetry is enabled", () => {
      service.onModuleInit();

      // onModuleInit does not await the refresh, but refreshCommonPredictions
      // invokes refreshPredictions before its first await, so the call is
      // already recorded by the time onModuleInit returns.
      expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1);
    });

    it("should not refresh when telemetry is disabled", () => {
      mockTelemetryService.isEnabled = false;

      service.onModuleInit();

      // refreshPredictions should not be called since we returned early
      expect(mockTelemetryService.refreshPredictions).not.toHaveBeenCalled();
    });

    it("should not throw when refresh fails on init", () => {
      mockTelemetryService.refreshPredictions.mockRejectedValue(new Error("Connection refused"));

      // Should not throw
      expect(() => service.onModuleInit()).not.toThrow();

      // Confirm the rejecting path was actually exercised.
      expect(mockTelemetryService.refreshPredictions).toHaveBeenCalledTimes(1);
    });
  });
});
|
||||||
161
apps/api/src/mosaic-telemetry/prediction.service.ts
Normal file
161
apps/api/src/mosaic-telemetry/prediction.service.ts
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import { Injectable, Logger, OnModuleInit } from "@nestjs/common";
|
||||||
|
import {
|
||||||
|
TaskType,
|
||||||
|
Complexity,
|
||||||
|
Provider,
|
||||||
|
type PredictionQuery,
|
||||||
|
type PredictionResponse,
|
||||||
|
} from "@mosaicstack/telemetry-client";
|
||||||
|
import { MosaicTelemetryService } from "./mosaic-telemetry.service";
|
||||||
|
|
||||||
|
/**
 * Model/provider pairs whose predictions are pre-fetched by
 * `refreshCommonPredictions`.
 */
const COMMON_MODELS: { model: string; provider: Provider }[] = [
  { model: "claude-sonnet-4-5", provider: Provider.ANTHROPIC },
  { model: "claude-opus-4", provider: Provider.ANTHROPIC },
  { model: "claude-haiku-4-5", provider: Provider.ANTHROPIC },
  { model: "gpt-4o", provider: Provider.OPENAI },
  { model: "gpt-4o-mini", provider: Provider.OPENAI },
];

/** Task types included in the pre-fetch cross-product. */
const COMMON_TASK_TYPES: TaskType[] = [
  TaskType.IMPLEMENTATION,
  TaskType.PLANNING,
  TaskType.CODE_REVIEW,
];

/** Complexity levels included in the pre-fetch cross-product. */
const COMMON_COMPLEXITIES: Complexity[] = [Complexity.LOW, Complexity.MEDIUM];

/** Reduces an arbitrary thrown value to a loggable message string. */
function describeThrown(thrown: unknown): string {
  return thrown instanceof Error ? thrown.message : String(thrown);
}

/**
 * PredictionService
 *
 * Thin wrapper around MosaicTelemetryService for pre-task cost/token
 * estimates. It looks up a single prediction (`getEstimate`) and batch-refreshes
 * a fixed set of common configurations (`refreshCommonPredictions`, also
 * kicked off from `onModuleInit`).
 *
 * Failures never propagate to callers: a throwing lookup yields null and a
 * failing refresh is logged as a warning. Caching of predictions is left to
 * the underlying telemetry SDK (the original notes mention a 6-hour TTL; that
 * is not visible from this file).
 *
 * @example
 * ```typescript
 * const estimate = this.predictionService.getEstimate(
 *   TaskType.IMPLEMENTATION,
 *   "claude-sonnet-4-5",
 *   Provider.ANTHROPIC,
 *   Complexity.LOW,
 * );
 * if (estimate?.prediction) {
 *   console.log(`Estimated cost: ${estimate.prediction.cost_usd_micros.median}`);
 * }
 * ```
 */
@Injectable()
export class PredictionService implements OnModuleInit {
  private readonly logger = new Logger(PredictionService.name);

  constructor(private readonly telemetry: MosaicTelemetryService) {}

  /**
   * Starts a background refresh of the common predictions.
   *
   * The refresh promise is not awaited, so module initialization does not wait
   * for it. When telemetry is disabled this only logs and returns.
   */
  onModuleInit(): void {
    if (!this.telemetry.isEnabled) {
      this.logger.log("Telemetry disabled - skipping prediction refresh");
      return;
    }

    // Intentionally not awaited; any rejection is logged rather than surfaced.
    void this.refreshCommonPredictions().catch((reason: unknown) => {
      this.logger.warn(
        `Failed to refresh common predictions on startup: ${describeThrown(reason)}`
      );
    });
  }

  /**
   * Looks up a cost/token estimate for one task configuration.
   *
   * Builds a PredictionQuery from the arguments and returns whatever
   * `MosaicTelemetryService.getPrediction` returns for it. If that call throws,
   * the error is logged as a warning and null is returned instead.
   *
   * @param taskType - The type of task to estimate
   * @param model - The model name (e.g. "claude-sonnet-4-5")
   * @param provider - The provider enum value
   * @param complexity - The complexity level
   * @returns The telemetry service's prediction response, or null
   */
  getEstimate(
    taskType: TaskType,
    model: string,
    provider: Provider,
    complexity: Complexity
  ): PredictionResponse | null {
    try {
      return this.telemetry.getPrediction({
        task_type: taskType,
        model,
        provider,
        complexity,
      });
    } catch (caught: unknown) {
      this.logger.warn(`Failed to get prediction estimate: ${describeThrown(caught)}`);
      return null;
    }
  }

  /**
   * Batch-refreshes predictions for the common
   * (taskType, model, provider, complexity) combinations.
   *
   * The batch is the cross-product of COMMON_MODELS, COMMON_TASK_TYPES and
   * COMMON_COMPLEXITIES, sent in one `refreshPredictions` call. Does nothing
   * when telemetry is disabled; a failed refresh is logged and swallowed, so
   * the returned promise resolves either way.
   */
  async refreshCommonPredictions(): Promise<void> {
    if (!this.telemetry.isEnabled) {
      return;
    }

    // Ordering: models outermost, then task types, then complexities.
    const queries: PredictionQuery[] = COMMON_MODELS.flatMap(({ model, provider }) =>
      COMMON_TASK_TYPES.flatMap((taskType) =>
        COMMON_COMPLEXITIES.map(
          (complexity): PredictionQuery => ({
            task_type: taskType,
            model,
            provider,
            complexity,
          })
        )
      )
    );

    this.logger.log(`Refreshing ${String(queries.length)} common prediction queries...`);

    try {
      await this.telemetry.refreshPredictions(queries);
      this.logger.log(`Successfully refreshed ${String(queries.length)} predictions`);
    } catch (caught: unknown) {
      this.logger.warn(`Failed to refresh predictions: ${describeThrown(caught)}`);
    }
  }
}
|
||||||
Reference in New Issue
Block a user