import { Injectable, Logger, OnModuleInit } from "@nestjs/common";
import {
  TaskType,
  Complexity,
  Provider,
  type PredictionQuery,
  type PredictionResponse,
} from "@mosaicstack/telemetry-client";
import { MosaicTelemetryService } from "./mosaic-telemetry.service";

/**
 * Common model-provider combinations used for pre-fetching predictions.
 * These represent the most frequently used LLM configurations.
 */
const COMMON_MODELS: readonly { model: string; provider: Provider }[] = [
  { model: "claude-sonnet-4-5", provider: Provider.ANTHROPIC },
  { model: "claude-opus-4", provider: Provider.ANTHROPIC },
  { model: "claude-haiku-4-5", provider: Provider.ANTHROPIC },
  { model: "gpt-4o", provider: Provider.OPENAI },
  { model: "gpt-4o-mini", provider: Provider.OPENAI },
];

/**
 * Common task types to pre-fetch predictions for.
 */
const COMMON_TASK_TYPES: readonly TaskType[] = [
  TaskType.IMPLEMENTATION,
  TaskType.PLANNING,
  TaskType.CODE_REVIEW,
];

/**
 * Common complexity levels to pre-fetch predictions for.
 */
const COMMON_COMPLEXITIES: readonly Complexity[] = [Complexity.LOW, Complexity.MEDIUM];

/**
 * Turn an unknown thrown value into a loggable message.
 *
 * `Error` instances contribute their `message`; anything else (strings,
 * plain objects, etc.) is stringified so it still shows up in the log.
 */
function describeError(error: unknown): string {
  return error instanceof Error ? error.message : String(error);
}

/**
 * Build every (model, provider, taskType, complexity) query from the common
 * configuration lists above.
 *
 * Iteration order is models outermost, then task types, then complexities
 * innermost, so the result has
 * `COMMON_MODELS.length * COMMON_TASK_TYPES.length * COMMON_COMPLEXITIES.length`
 * entries.
 */
function buildCommonQueries(): PredictionQuery[] {
  const queries: PredictionQuery[] = [];
  for (const { model, provider } of COMMON_MODELS) {
    for (const taskType of COMMON_TASK_TYPES) {
      for (const complexity of COMMON_COMPLEXITIES) {
        queries.push({
          task_type: taskType,
          model,
          provider,
          complexity,
        });
      }
    }
  }
  return queries;
}

/**
 * PredictionService
 *
 * Provides pre-task cost and token estimates using crowd-sourced prediction data
 * from the Mosaic Telemetry server. Predictions are cached by the underlying SDK
 * with a 6-hour TTL.
 *
 * This service is intentionally non-blocking: if predictions are unavailable
 * (telemetry disabled, server unreachable, no data), all methods return null
 * without throwing errors. Task execution should never be blocked by prediction
 * failures.
 *
 * @example
 * ```typescript
 * const estimate = this.predictionService.getEstimate(
 *   TaskType.IMPLEMENTATION,
 *   "claude-sonnet-4-5",
 *   Provider.ANTHROPIC,
 *   Complexity.LOW,
 * );
 * if (estimate?.prediction) {
 *   console.log(`Estimated cost: ${estimate.prediction.cost_usd_micros}`);
 * }
 * ```
 */
@Injectable()
export class PredictionService implements OnModuleInit {
  private readonly logger = new Logger(PredictionService.name);

  constructor(private readonly telemetry: MosaicTelemetryService) {}

  /**
   * Refresh common predictions on startup.
   *
   * The refresh is started but not awaited, so it never blocks module
   * initialization. Any rejection is logged as a warning.
   */
  onModuleInit(): void {
    if (!this.telemetry.isEnabled) {
      this.logger.log("Telemetry disabled - skipping prediction refresh");
      return;
    }

    // Fire-and-forget: `void` marks the promise as intentionally not awaited;
    // the `.catch` ensures no unhandled rejection escapes.
    void this.refreshCommonPredictions().catch((error: unknown) => {
      this.logger.warn(`Failed to refresh common predictions on startup: ${describeError(error)}`);
    });
  }

  /**
   * Get a cost/token estimate for a given task configuration.
   *
   * This is synchronous: it reads whatever the telemetry SDK currently has
   * cached and never contacts the server itself.
   *
   * Returns null if:
   * - Telemetry is disabled (checked here, before consulting the SDK)
   * - No prediction data exists for this combination
   * - The prediction has expired
   * - The underlying lookup throws (the error is logged as a warning)
   *
   * @param taskType - The type of task to estimate
   * @param model - The model name (e.g. "claude-sonnet-4-5")
   * @param provider - The provider enum value
   * @param complexity - The complexity level
   * @returns Prediction response with estimates and confidence, or null
   */
  getEstimate(
    taskType: TaskType,
    model: string,
    provider: Provider,
    complexity: Complexity
  ): PredictionResponse | null {
    // Enforce the documented "null when disabled" contract locally,
    // consistent with onModuleInit and refreshCommonPredictions.
    if (!this.telemetry.isEnabled) {
      return null;
    }

    try {
      const query: PredictionQuery = {
        task_type: taskType,
        model,
        provider,
        complexity,
      };
      return this.telemetry.getPrediction(query);
    } catch (error: unknown) {
      this.logger.warn(`Failed to get prediction estimate: ${describeError(error)}`);
      return null;
    }
  }

  /**
   * Refresh predictions for commonly used (taskType, model, provider, complexity) combinations.
   *
   * Builds the cross-product of common models, task types, and complexities,
   * then batch-refreshes them from the telemetry server. The SDK caches the
   * results with a 6-hour TTL.
   *
   * This method is safe to call at any time. If telemetry is disabled it
   * returns immediately. If the refresh call fails, the error is logged as a
   * warning and the returned promise still resolves.
   */
  async refreshCommonPredictions(): Promise<void> {
    if (!this.telemetry.isEnabled) {
      return;
    }

    const queries = buildCommonQueries();

    this.logger.log(`Refreshing ${String(queries.length)} common prediction queries...`);

    try {
      await this.telemetry.refreshPredictions(queries);
      this.logger.log(`Successfully refreshed ${String(queries.length)} predictions`);
    } catch (error: unknown) {
      this.logger.warn(`Failed to refresh predictions: ${describeError(error)}`);
    }
  }
}