import { Inject, Injectable, Logger } from '@nestjs/common';
import type { ModelInfo } from '@mosaicstack/types';
import type { RoutingCriteria, RoutingResult, CostTier } from '@mosaicstack/types';
import { ProviderService } from './provider.service.js';

/** Per-million-token cost thresholds for tier classification */
const COST_TIER_THRESHOLDS: Record<CostTier, { maxInput: number }> = {
  cheap: { maxInput: 1 },
  standard: { maxInput: 10 },
  premium: { maxInput: Infinity },
  // local = self-hosted; treat as cheapest tier for cost scoring purposes
  local: { maxInput: 0 },
};

/**
 * Picks a model from the provider registry by scoring every available model
 * against a set of routing criteria.
 *
 * Scoring starts from a base of 50 and adds bonuses for cost tier, preferred
 * provider, reasoning support and context size. A model that fails a hard
 * requirement is given a score of 0 and is dropped from the results.
 */
@Injectable()
export class RoutingService {
  private readonly logger = new Logger(RoutingService.name);

  constructor(@Inject(ProviderService) private readonly providerService: ProviderService) {}

  /**
   * Select the best available model for the given criteria.
   *
   * @param criteria - Routing preferences and hard requirements; defaults to none.
   * @returns The highest-scoring candidate, or `null` when no model is
   *   available or none scores above 0.
   */
  route(criteria: RoutingCriteria = {}): RoutingResult | null {
    const available = this.providerService.listAvailableModels();
    if (available.length === 0) {
      this.logger.warn('No available models for routing');
      return null;
    }

    // If a specific model is preferred, try it first.
    // NOTE(review): this shortcut returns before scoring, so the hard
    // requirements (reasoning / image input / context window) are not checked
    // for an explicitly preferred model — confirm that this is intended.
    if (criteria.preferredProvider && criteria.preferredModel) {
      const match = available.find(
        (m) => m.provider === criteria.preferredProvider && m.id === criteria.preferredModel,
      );
      if (match) {
        return {
          provider: match.provider,
          modelId: match.id,
          modelName: match.name,
          score: 100,
          reasoning: 'Preferred model selected',
        };
      }
    }

    // Score and rank candidates; the first entry (if any) is the winner.
    const [best] = this.rankModels(available, criteria);
    if (!best) {
      this.logger.warn('No models matched routing criteria', criteria);
      return null;
    }

    this.logger.debug(
      `Routed to ${best.provider}/${best.modelId} (score=${best.score}): ${best.reasoning}`,
    );
    return best;
  }

  /**
   * List all available models ranked by suitability for the given criteria.
   *
   * @param criteria - Routing preferences and hard requirements; defaults to none.
   * @returns Candidates with a score above 0, highest score first.
   */
  rank(criteria: RoutingCriteria = {}): RoutingResult[] {
    return this.rankModels(this.providerService.listAvailableModels(), criteria);
  }

  /** Scores each model, drops those scoring 0 or less, and sorts by descending score. */
  private rankModels(models: readonly ModelInfo[], criteria: RoutingCriteria): RoutingResult[] {
    return models
      .map((model) => this.scoreModel(model, criteria))
      .filter((candidate) => candidate.score > 0)
      .sort((a, b) => b.score - a.score);
  }

  /**
   * Compute the suitability score of a single model.
   *
   * @param model - The model to evaluate.
   * @param criteria - Routing preferences and hard requirements.
   * @returns A result whose `score` is 0 when a hard requirement is not met,
   *   otherwise 50 plus any applicable bonuses; `reasoning` lists the bonuses
   *   that were recorded, or `'base score'` when none were.
   */
  private scoreModel(model: ModelInfo, criteria: RoutingCriteria): RoutingResult {
    // Hard requirements — disqualify if not met
    if (criteria.requireReasoning && !model.reasoning) {
      return this.disqualified(model, 'reasoning required but not supported');
    }
    if (criteria.requireImageInput && !model.inputTypes.includes('image')) {
      return this.disqualified(model, 'image input required but not supported');
    }
    // A missing or zero minContextWindow means "no requirement".
    if (criteria.minContextWindow && model.contextWindow < criteria.minContextWindow) {
      return this.disqualified(
        model,
        `context window ${model.contextWindow} < required ${criteria.minContextWindow}`,
      );
    }

    let score = 50; // Base score
    const reasons: string[] = [];

    if (criteria.costTier) {
      // Cost tier matching: full bonus for an exact match, a small bonus when
      // the model is exactly one tier more expensive than requested.
      const tier = this.classifyTier(model);
      if (tier === criteria.costTier) {
        score += 20;
        reasons.push(`cost tier match (${tier})`);
      } else if (
        (criteria.costTier === 'cheap' && tier === 'standard') ||
        (criteria.costTier === 'standard' && tier === 'premium')
      ) {
        score += 5;
        reasons.push(`adjacent cost tier (wanted ${criteria.costTier}, got ${tier})`);
      }
    } else {
      // Prefer cheaper models when no cost tier specified
      const costPerMillion = model.cost.input;
      if (costPerMillion <= 1) score += 10;
      else if (costPerMillion <= 5) score += 5;
    }

    // Provider preference
    if (criteria.preferredProvider && model.provider === criteria.preferredProvider) {
      score += 15;
      reasons.push('preferred provider');
    }

    // Reasoning bonus for complex tasks
    if (model.reasoning && (criteria.taskType === 'coding' || criteria.taskType === 'analysis')) {
      score += 10;
      reasons.push('reasoning model for complex task');
    }

    // Large context bonus for analysis tasks
    if (criteria.taskType === 'analysis' && model.contextWindow >= 128_000) {
      score += 5;
      reasons.push('large context window');
    }

    return {
      provider: model.provider,
      modelId: model.id,
      modelName: model.name,
      score,
      reasoning: reasons.length > 0 ? reasons.join('; ') : 'base score',
    };
  }

  /**
   * Classify a model into a cost tier from its input cost.
   *
   * NOTE(review): this never returns `'local'` and the `local` threshold is
   * not read here, so a request with `costTier: 'local'` cannot get a
   * tier-match bonus — confirm whether that is intended.
   *
   * @param model - The model whose `cost.input` is compared to the thresholds.
   * @returns `'cheap'`, `'standard'` or `'premium'`.
   */
  private classifyTier(model: ModelInfo): CostTier {
    const cost = model.cost.input;
    if (cost <= COST_TIER_THRESHOLDS.cheap.maxInput) return 'cheap';
    if (cost <= COST_TIER_THRESHOLDS.standard.maxInput) return 'standard';
    return 'premium';
  }

  /**
   * Build the zero-score result used for a model that fails a hard requirement.
   *
   * @param model - The rejected model.
   * @param reason - Why it was rejected; embedded in the `reasoning` text.
   */
  private disqualified(model: ModelInfo, reason: string): RoutingResult {
    return {
      provider: model.provider,
      modelId: model.id,
      modelName: model.name,
      score: 0,
      reasoning: `disqualified: ${reason}`,
    };
  }
}