diff --git a/apps/gateway/src/agent/agent.module.ts b/apps/gateway/src/agent/agent.module.ts index a659ff4..9c28cc4 100644 --- a/apps/gateway/src/agent/agent.module.ts +++ b/apps/gateway/src/agent/agent.module.ts @@ -1,12 +1,13 @@ import { Global, Module } from '@nestjs/common'; import { AgentService } from './agent.service.js'; import { ProviderService } from './provider.service.js'; +import { RoutingService } from './routing.service.js'; import { ProvidersController } from './providers.controller.js'; @Global() @Module({ - providers: [ProviderService, AgentService], + providers: [ProviderService, RoutingService, AgentService], controllers: [ProvidersController], - exports: [AgentService, ProviderService], + exports: [AgentService, ProviderService, RoutingService], }) export class AgentModule {} diff --git a/apps/gateway/src/agent/providers.controller.ts b/apps/gateway/src/agent/providers.controller.ts index d74f33a..1c8a004 100644 --- a/apps/gateway/src/agent/providers.controller.ts +++ b/apps/gateway/src/agent/providers.controller.ts @@ -1,11 +1,16 @@ -import { Controller, Get, UseGuards } from '@nestjs/common'; +import { Body, Controller, Get, Post, UseGuards } from '@nestjs/common'; +import type { RoutingCriteria } from '@mosaic/types'; import { AuthGuard } from '../auth/auth.guard.js'; import { ProviderService } from './provider.service.js'; +import { RoutingService } from './routing.service.js'; @Controller('api/providers') @UseGuards(AuthGuard) export class ProvidersController { - constructor(private readonly providerService: ProviderService) {} + constructor( + private readonly providerService: ProviderService, + private readonly routingService: RoutingService, + ) {} @Get() list() { @@ -16,4 +21,14 @@ export class ProvidersController { listModels() { return this.providerService.listAvailableModels(); } + + @Post('route') + route(@Body() criteria: RoutingCriteria) { + return this.routingService.route(criteria); + } + + @Post('rank') + rank(@Body() criteria: RoutingCriteria) { + return this.routingService.rank(criteria); + } } diff --git a/apps/gateway/src/agent/routing.service.ts b/apps/gateway/src/agent/routing.service.ts new file mode 100644 index 0000000..5902f71 --- /dev/null +++ b/apps/gateway/src/agent/routing.service.ts @@ -0,0 +1,162 @@ +import { Injectable, Logger } from '@nestjs/common'; +import type { ModelInfo } from '@mosaic/types'; +import type { RoutingCriteria, RoutingResult, CostTier } from '@mosaic/types'; +import { ProviderService } from './provider.service.js'; + +/** Per-million-token cost thresholds for tier classification */ +const COST_TIER_THRESHOLDS: Record = { + cheap: { maxInput: 1 }, + standard: { maxInput: 10 }, + premium: { maxInput: Infinity }, +}; + +@Injectable() +export class RoutingService { + private readonly logger = new Logger(RoutingService.name); + + constructor(private readonly providerService: ProviderService) {} + + /** + * Select the best available model for the given criteria. + * Returns null if no model matches the requirements. + */ + route(criteria: RoutingCriteria = {}): RoutingResult | null { + const available = this.providerService.listAvailableModels(); + if (available.length === 0) { + this.logger.warn('No available models for routing'); + return null; + } + + // If a specific model is preferred, try it first + if (criteria.preferredProvider && criteria.preferredModel) { + const match = available.find( + (m) => m.provider === criteria.preferredProvider && m.id === criteria.preferredModel, + ); + if (match) { + return { + provider: match.provider, + modelId: match.id, + modelName: match.name, + score: 100, + reasoning: 'Preferred model selected', + }; + } + } + + // Score and rank candidates + const scored = available + .map((model) => this.scoreModel(model, criteria)) + .filter((s) => s.score > 0) + .sort((a, b) => b.score - a.score); + + if (scored.length === 0) { + this.logger.warn('No models matched routing criteria', criteria); + return null; + } + + const best = scored[0] as RoutingResult; + this.logger.debug( + `Routed to ${best.provider}/${best.modelId} (score=${best.score}): ${best.reasoning}`, + ); + return best; + } + + /** + * List all available models ranked by suitability for the given criteria. + */ + rank(criteria: RoutingCriteria = {}): RoutingResult[] { + const available = this.providerService.listAvailableModels(); + return available + .map((model) => this.scoreModel(model, criteria)) + .filter((s) => s.score > 0) + .sort((a, b) => b.score - a.score); + } + + private scoreModel(model: ModelInfo, criteria: RoutingCriteria): RoutingResult { + let score = 50; // Base score + const reasons: string[] = []; + + // Hard requirements — disqualify if not met + if (criteria.requireReasoning && !model.reasoning) { + return this.disqualified(model, 'reasoning required but not supported'); + } + + if (criteria.requireImageInput && !model.inputTypes.includes('image')) { + return this.disqualified(model, 'image input required but not supported'); + } + + if (criteria.minContextWindow && model.contextWindow < criteria.minContextWindow) { + return this.disqualified( + model, + `context window ${model.contextWindow} < required ${criteria.minContextWindow}`, + ); + } + + // Cost tier matching + if (criteria.costTier) { + const tier = this.classifyTier(model); + if (tier === criteria.costTier) { + score += 20; + reasons.push(`cost tier match (${tier})`); + } else if ( + (criteria.costTier === 'cheap' && tier === 'standard') || + (criteria.costTier === 'standard' && tier === 'premium') + ) { + score += 5; + reasons.push(`adjacent cost tier (wanted ${criteria.costTier}, got ${tier})`); + } + } + + // Prefer cheaper models when no cost tier specified + if (!criteria.costTier) { + const costPerMillion = model.cost.input; + if (costPerMillion <= 1) score += 10; + else if (costPerMillion <= 5) score += 5; + } + + // Provider preference + if (criteria.preferredProvider && model.provider === criteria.preferredProvider) { + score += 15; + reasons.push('preferred provider'); + } + + // Reasoning bonus for complex tasks + if (model.reasoning) { + if (criteria.taskType === 'coding' || criteria.taskType === 'analysis') { + score += 10; + reasons.push('reasoning model for complex task'); + } + } + + // Large context bonus for analysis tasks + if (criteria.taskType === 'analysis' && model.contextWindow >= 128_000) { + score += 5; + reasons.push('large context window'); + } + + return { + provider: model.provider, + modelId: model.id, + modelName: model.name, + score, + reasoning: reasons.length > 0 ? reasons.join('; ') : 'base score', + }; + } + + private classifyTier(model: ModelInfo): CostTier { + const cost = model.cost.input; + if (cost <= COST_TIER_THRESHOLDS.cheap.maxInput) return 'cheap'; + if (cost <= COST_TIER_THRESHOLDS.standard.maxInput) return 'standard'; + return 'premium'; + } + + private disqualified(model: ModelInfo, reason: string): RoutingResult { + return { + provider: model.provider, + modelId: model.id, + modelName: model.name, + score: 0, + reasoning: `disqualified: ${reason}`, + }; + } +} diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index 3d52f83..3634278 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -3,3 +3,4 @@ export const VERSION = '0.0.0'; export * from './chat/index.js'; export * from './agent/index.js'; export * from './provider/index.js'; +export * from './routing/index.js'; diff --git a/packages/types/src/routing/index.ts b/packages/types/src/routing/index.ts new file mode 100644 index 0000000..b26cd67 --- /dev/null +++ b/packages/types/src/routing/index.ts @@ -0,0 +1,25 @@ +/** Cost tier for model selection */ +export type CostTier = 'cheap' | 'standard' | 'premium'; + +/** Task type hint for routing */ +export type TaskType = 'chat' | 'coding' | 'analysis' | 'summarization' | 'general'; + +/** Routing criteria for model selection */ +export interface RoutingCriteria { + taskType?: TaskType; + costTier?: CostTier; + requireReasoning?: boolean; + requireImageInput?: boolean; + minContextWindow?: number; + preferredProvider?: string; + preferredModel?: string; +} + +/** Result of a routing decision */ +export interface RoutingResult { + provider: string; + modelId: string; + modelName: string; + score: number; + reasoning: string; +}