diff --git a/apps/gateway/src/agent/__tests__/provider-adapters.test.ts b/apps/gateway/src/agent/__tests__/provider-adapters.test.ts new file mode 100644 index 0000000..82a6e9b --- /dev/null +++ b/apps/gateway/src/agent/__tests__/provider-adapters.test.ts @@ -0,0 +1,770 @@ +/** + * Provider Adapter Integration Tests — M3-012 + * + * Verifies that all five provider adapters (Anthropic, OpenAI, OpenRouter, Z.ai, Ollama) + * are properly integrated: registration, model listing, graceful degradation without + * API keys, capability matrix correctness, and ProviderCredentialsService behaviour. + * + * These tests are designed to run in CI with no real API keys; they test graceful + * degradation and static configuration rather than live network calls. + */ + +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent'; +import { AnthropicAdapter } from '../adapters/anthropic.adapter.js'; +import { OpenAIAdapter } from '../adapters/openai.adapter.js'; +import { OpenRouterAdapter } from '../adapters/openrouter.adapter.js'; +import { ZaiAdapter } from '../adapters/zai.adapter.js'; +import { OllamaAdapter } from '../adapters/ollama.adapter.js'; +import { ProviderService } from '../provider.service.js'; +import { + getModelCapability, + MODEL_CAPABILITIES, + findModelsByCapability, +} from '../model-capabilities.js'; + +// --------------------------------------------------------------------------- +// Environment helpers +// --------------------------------------------------------------------------- + +const ALL_PROVIDER_KEYS = [ + 'ANTHROPIC_API_KEY', + 'OPENAI_API_KEY', + 'OPENROUTER_API_KEY', + 'ZAI_API_KEY', + 'ZAI_BASE_URL', + 'OLLAMA_BASE_URL', + 'OLLAMA_HOST', + 'OLLAMA_MODELS', + 'BETTER_AUTH_SECRET', +] as const; + +type EnvKey = (typeof ALL_PROVIDER_KEYS)[number]; + +function saveAndClearEnv(): Map { + const saved = new Map(); + for (const key of ALL_PROVIDER_KEYS) { + saved.set(key, process.env[key]); + delete process.env[key]; + } + return saved; +} + +function restoreEnv(saved: Map): void { + for (const key of ALL_PROVIDER_KEYS) { + const value = saved.get(key); + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + +function makeRegistry(): ModelRegistry { + return new ModelRegistry(AuthStorage.inMemory()); +} + +// --------------------------------------------------------------------------- +// 1. Adapter registration tests +// --------------------------------------------------------------------------- + +describe('AnthropicAdapter', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + it('skips registration gracefully when ANTHROPIC_API_KEY is missing', async () => { + const adapter = new AnthropicAdapter(makeRegistry()); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('registers and listModels returns expected models when ANTHROPIC_API_KEY is set', async () => { + process.env['ANTHROPIC_API_KEY'] = 'sk-ant-test'; + const adapter = new AnthropicAdapter(makeRegistry()); + await adapter.register(); + + const models = adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + + const ids = models.map((m) => m.id); + expect(ids).toContain('claude-opus-4-6'); + expect(ids).toContain('claude-sonnet-4-6'); + expect(ids).toContain('claude-haiku-4-5'); + + for (const model of models) { + expect(model.provider).toBe('anthropic'); + expect(model.contextWindow).toBe(200000); + } + }); + + it('healthCheck returns down with error when ANTHROPIC_API_KEY is missing', async () => { + const adapter = new AnthropicAdapter(makeRegistry()); + const health = await adapter.healthCheck(); + expect(health.status).toBe('down'); + expect(health.error).toMatch(/ANTHROPIC_API_KEY/); + expect(health.lastChecked).toBeTruthy(); + }); + + it('adapter name is "anthropic"', () => { + expect(new AnthropicAdapter(makeRegistry()).name).toBe('anthropic'); + }); +}); + +describe('OpenAIAdapter', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + it('skips registration gracefully when OPENAI_API_KEY is missing', async () => { + const adapter = new OpenAIAdapter(makeRegistry()); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('registers and listModels returns Codex model when OPENAI_API_KEY is set', async () => { + process.env['OPENAI_API_KEY'] = 'sk-openai-test'; + const adapter = new OpenAIAdapter(makeRegistry()); + await adapter.register(); + + const models = adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + + const ids = models.map((m) => m.id); + expect(ids).toContain(OpenAIAdapter.CODEX_MODEL_ID); + + const codex = models.find((m) => m.id === OpenAIAdapter.CODEX_MODEL_ID)!; + expect(codex.provider).toBe('openai'); + expect(codex.contextWindow).toBe(128_000); + expect(codex.maxTokens).toBe(16_384); + }); + + it('healthCheck returns down with error when OPENAI_API_KEY is missing', async () => { + const adapter = new OpenAIAdapter(makeRegistry()); + const health = await adapter.healthCheck(); + expect(health.status).toBe('down'); + expect(health.error).toMatch(/OPENAI_API_KEY/); + }); + + it('adapter name is "openai"', () => { + expect(new OpenAIAdapter(makeRegistry()).name).toBe('openai'); + }); +}); + +describe('OpenRouterAdapter', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + // Prevent real network calls during registration — stub global fetch + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + data: [ + { + id: 'openai/gpt-4o', + name: 'GPT-4o', + context_length: 128000, + top_provider: { max_completion_tokens: 4096 }, + pricing: { prompt: '0.000005', completion: '0.000015' }, + architecture: { input_modalities: ['text', 'image'] }, + }, + ], + }), + }), + ); + }); + + afterEach(() => { + restoreEnv(savedEnv); + vi.unstubAllGlobals(); + }); + + it('skips registration gracefully when OPENROUTER_API_KEY is missing', async () => { + vi.unstubAllGlobals(); // no fetch call expected + const adapter = new OpenRouterAdapter(); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('registers and listModels returns models when OPENROUTER_API_KEY is set', async () => { + process.env['OPENROUTER_API_KEY'] = 'sk-or-test'; + const adapter = new OpenRouterAdapter(); + await adapter.register(); + + const models = adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + + const first = models[0]!; + expect(first.provider).toBe('openrouter'); + expect(first.id).toBe('openai/gpt-4o'); + expect(first.inputTypes).toContain('image'); + }); + + it('healthCheck returns down with error when OPENROUTER_API_KEY is missing', async () => { + vi.unstubAllGlobals(); // no fetch call expected + const adapter = new OpenRouterAdapter(); + const health = await adapter.healthCheck(); + expect(health.status).toBe('down'); + expect(health.error).toMatch(/OPENROUTER_API_KEY/); + }); + + it('continues registration with empty model list when model fetch fails', async () => { + process.env['OPENROUTER_API_KEY'] = 'sk-or-test'; + vi.stubGlobal( + 'fetch', + vi.fn().mockResolvedValue({ + ok: false, + status: 500, + }), + ); + const adapter = new OpenRouterAdapter(); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('adapter name is "openrouter"', () => { + expect(new OpenRouterAdapter().name).toBe('openrouter'); + }); +}); + +describe('ZaiAdapter', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + it('skips registration gracefully when ZAI_API_KEY is missing', async () => { + const adapter = new ZaiAdapter(); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('registers and listModels returns glm-5 when ZAI_API_KEY is set', async () => { + process.env['ZAI_API_KEY'] = 'zai-test-key'; + const adapter = new ZaiAdapter(); + await adapter.register(); + + const models = adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + + const ids = models.map((m) => m.id); + expect(ids).toContain('glm-5'); + + const glm = models.find((m) => m.id === 'glm-5')!; + expect(glm.provider).toBe('zai'); + }); + + it('healthCheck returns down with error when ZAI_API_KEY is missing', async () => { + const adapter = new ZaiAdapter(); + const health = await adapter.healthCheck(); + expect(health.status).toBe('down'); + expect(health.error).toMatch(/ZAI_API_KEY/); + }); + + it('adapter name is "zai"', () => { + expect(new ZaiAdapter().name).toBe('zai'); + }); +}); + +describe('OllamaAdapter', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + it('skips registration gracefully when OLLAMA_BASE_URL is missing', async () => { + const adapter = new OllamaAdapter(makeRegistry()); + await expect(adapter.register()).resolves.toBeUndefined(); + expect(adapter.listModels()).toEqual([]); + }); + + it('registers via OLLAMA_HOST fallback when OLLAMA_BASE_URL is absent', async () => { + process.env['OLLAMA_HOST'] = 'http://localhost:11434'; + const adapter = new OllamaAdapter(makeRegistry()); + await adapter.register(); + const models = adapter.listModels(); + expect(models.length).toBeGreaterThan(0); + }); + + it('registers default models (llama3.2, codellama, mistral) + embedding models', async () => { + process.env['OLLAMA_BASE_URL'] = 'http://localhost:11434'; + const adapter = new OllamaAdapter(makeRegistry()); + await adapter.register(); + + const models = adapter.listModels(); + const ids = models.map((m) => m.id); + + // Default completion models + expect(ids).toContain('llama3.2'); + expect(ids).toContain('codellama'); + expect(ids).toContain('mistral'); + + // Embedding models + expect(ids).toContain('nomic-embed-text'); + expect(ids).toContain('mxbai-embed-large'); + + for (const model of models) { + expect(model.provider).toBe('ollama'); + } + }); + + it('registers custom OLLAMA_MODELS list', async () => { + process.env['OLLAMA_BASE_URL'] = 'http://localhost:11434'; + process.env['OLLAMA_MODELS'] = 'phi3,gemma2'; + const adapter = new OllamaAdapter(makeRegistry()); + await adapter.register(); + + const completionIds = adapter.listModels().map((m) => m.id); + expect(completionIds).toContain('phi3'); + expect(completionIds).toContain('gemma2'); + expect(completionIds).not.toContain('llama3.2'); + }); + + it('healthCheck returns down with error when OLLAMA_BASE_URL is missing', async () => { + const adapter = new OllamaAdapter(makeRegistry()); + const health = await adapter.healthCheck(); + expect(health.status).toBe('down'); + expect(health.error).toMatch(/OLLAMA_BASE_URL/); + }); + + it('adapter name is "ollama"', () => { + expect(new OllamaAdapter(makeRegistry()).name).toBe('ollama'); + }); +}); + +// --------------------------------------------------------------------------- +// 2. ProviderService integration +// --------------------------------------------------------------------------- + +describe('ProviderService — adapter array integration', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + it('contains all 5 adapters (ollama, anthropic, openai, openrouter, zai)', async () => { + const service = new ProviderService(null); + await service.onModuleInit(); + + // Exercise getAdapter for all five known provider names + const expectedProviders = ['ollama', 'anthropic', 'openai', 'openrouter', 'zai']; + for (const name of expectedProviders) { + const adapter = service.getAdapter(name); + expect(adapter, `Expected adapter "${name}" to be registered`).toBeDefined(); + expect(adapter!.name).toBe(name); + } + }); + + it('healthCheckAll runs without crashing and returns status for all 5 providers', async () => { + const service = new ProviderService(null); + await service.onModuleInit(); + + const results = await service.healthCheckAll(); + expect(typeof results).toBe('object'); + + const expectedProviders = ['ollama', 'anthropic', 'openai', 'openrouter', 'zai']; + for (const name of expectedProviders) { + const health = results[name]; + expect(health, `Expected health result for provider "${name}"`).toBeDefined(); + expect(['healthy', 'degraded', 'down']).toContain(health!.status); + expect(health!.lastChecked).toBeTruthy(); + } + }); + + it('healthCheckAll reports "down" for all providers when no keys are set', async () => { + const service = new ProviderService(null); + await service.onModuleInit(); + + const results = await service.healthCheckAll(); + // All unconfigured providers should be down (not healthy) + for (const [, health] of Object.entries(results)) { + expect(['down', 'degraded']).toContain(health.status); + } + }); + + it('getProvidersHealth returns entries for all 5 providers', async () => { + const service = new ProviderService(null); + await service.onModuleInit(); + + const healthList = service.getProvidersHealth(); + const names = healthList.map((h) => h.name); + + for (const expected of ['ollama', 'anthropic', 'openai', 'openrouter', 'zai']) { + expect(names).toContain(expected); + } + + for (const entry of healthList) { + expect(entry).toHaveProperty('name'); + expect(entry).toHaveProperty('status'); + expect(entry).toHaveProperty('lastChecked'); + expect(typeof entry.modelCount).toBe('number'); + } + }); + + it('service initialises without error when all env keys are absent', async () => { + const service = new ProviderService(null); + await expect(service.onModuleInit()).resolves.toBeUndefined(); + service.onModuleDestroy(); + }); +}); + +// --------------------------------------------------------------------------- +// 3. Model capability matrix +// --------------------------------------------------------------------------- + +describe('Model capability matrix', () => { + const expectedModels: Array<{ + id: string; + provider: string; + tier: string; + contextWindow: number; + reasoning?: boolean; + vision?: boolean; + embedding?: boolean; + }> = [ + { + id: 'claude-opus-4-6', + provider: 'anthropic', + tier: 'premium', + contextWindow: 200000, + reasoning: true, + vision: true, + }, + { + id: 'claude-sonnet-4-6', + provider: 'anthropic', + tier: 'standard', + contextWindow: 200000, + reasoning: true, + vision: true, + }, + { + id: 'claude-haiku-4-5', + provider: 'anthropic', + tier: 'cheap', + contextWindow: 200000, + reasoning: false, + vision: true, + }, + { + id: 'codex-gpt-5.4', + provider: 'openai', + tier: 'premium', + contextWindow: 128000, + }, + { + id: 'glm-5', + provider: 'zai', + tier: 'standard', + contextWindow: 128000, + }, + { + id: 'llama3.2', + provider: 'ollama', + tier: 'local', + contextWindow: 128000, + }, + { + id: 'codellama', + provider: 'ollama', + tier: 'local', + contextWindow: 16000, + }, + { + id: 'mistral', + provider: 'ollama', + tier: 'local', + contextWindow: 32000, + }, + { + id: 'nomic-embed-text', + provider: 'ollama', + tier: 'local', + contextWindow: 8192, + embedding: true, + }, + { + id: 'mxbai-embed-large', + provider: 'ollama', + tier: 'local', + contextWindow: 8192, + embedding: true, + }, + ]; + + it('MODEL_CAPABILITIES contains all expected model IDs', () => { + const allIds = MODEL_CAPABILITIES.map((m) => m.id); + for (const { id } of expectedModels) { + expect(allIds, `Expected model "${id}" in capability matrix`).toContain(id); + } + }); + + it('getModelCapability() returns correct tier and context window for each model', () => { + for (const expected of expectedModels) { + const cap = getModelCapability(expected.id); + expect(cap, `getModelCapability("${expected.id}") should be defined`).toBeDefined(); + expect(cap!.provider).toBe(expected.provider); + expect(cap!.tier).toBe(expected.tier); + expect(cap!.contextWindow).toBe(expected.contextWindow); + } + }); + + it('Anthropic models have correct capability flags (tools, streaming, vision, reasoning)', () => { + for (const expected of expectedModels.filter((m) => m.provider === 'anthropic')) { + const cap = getModelCapability(expected.id)!; + expect(cap.capabilities.tools).toBe(true); + expect(cap.capabilities.streaming).toBe(true); + if (expected.vision !== undefined) { + expect(cap.capabilities.vision).toBe(expected.vision); + } + if (expected.reasoning !== undefined) { + expect(cap.capabilities.reasoning).toBe(expected.reasoning); + } + } + }); + + it('Embedding models have embedding flag=true and other flags=false', () => { + for (const expected of expectedModels.filter((m) => m.embedding)) { + const cap = getModelCapability(expected.id)!; + expect(cap.capabilities.embedding).toBe(true); + expect(cap.capabilities.tools).toBe(false); + expect(cap.capabilities.streaming).toBe(false); + expect(cap.capabilities.reasoning).toBe(false); + } + }); + + it('findModelsByCapability filters by tier correctly', () => { + const premiumModels = findModelsByCapability({ tier: 'premium' }); + expect(premiumModels.length).toBeGreaterThan(0); + for (const m of premiumModels) { + expect(m.tier).toBe('premium'); + } + }); + + it('findModelsByCapability filters by provider correctly', () => { + const anthropicModels = findModelsByCapability({ provider: 'anthropic' }); + expect(anthropicModels.length).toBe(3); + for (const m of anthropicModels) { + expect(m.provider).toBe('anthropic'); + } + }); + + it('findModelsByCapability filters by capability flags correctly', () => { + const reasoningModels = findModelsByCapability({ capabilities: { reasoning: true } }); + expect(reasoningModels.length).toBeGreaterThan(0); + for (const m of reasoningModels) { + expect(m.capabilities.reasoning).toBe(true); + } + + const embeddingModels = findModelsByCapability({ capabilities: { embedding: true } }); + expect(embeddingModels.length).toBeGreaterThan(0); + for (const m of embeddingModels) { + expect(m.capabilities.embedding).toBe(true); + } + }); + + it('getModelCapability returns undefined for unknown model IDs', () => { + expect(getModelCapability('not-a-real-model')).toBeUndefined(); + expect(getModelCapability('')).toBeUndefined(); + }); + + it('all Anthropic models have maxOutputTokens > 0', () => { + const anthropicModels = MODEL_CAPABILITIES.filter((m) => m.provider === 'anthropic'); + for (const m of anthropicModels) { + expect(m.maxOutputTokens).toBeGreaterThan(0); + } + }); +}); + +// --------------------------------------------------------------------------- +// 4. ProviderCredentialsService — unit-level tests (encrypt/decrypt logic) +// --------------------------------------------------------------------------- + +describe('ProviderCredentialsService — encryption helpers', () => { + let savedEnv: Map; + + beforeEach(() => { + savedEnv = saveAndClearEnv(); + }); + + afterEach(() => { + restoreEnv(savedEnv); + }); + + /** + * The service uses module-level functions (encrypt/decrypt) that depend on + * BETTER_AUTH_SECRET. We test the behaviour through the service's public API + * using an in-memory mock DB so no real Postgres connection is needed. + */ + it('store/retrieve/remove work correctly with mock DB and BETTER_AUTH_SECRET set', async () => { + process.env['BETTER_AUTH_SECRET'] = 'test-secret-for-unit-tests-only'; + + // Build a minimal in-memory DB mock + const rows = new Map< + string, + { + encryptedValue: string; + credentialType: string; + expiresAt: Date | null; + metadata: null; + createdAt: Date; + updatedAt: Date; + } + >(); + + // We import the service but mock its DB dependency manually + // by testing the encrypt/decrypt indirectly — using the real module. + const { ProviderCredentialsService } = await import('../provider-credentials.service.js'); + + // Capture stored value from upsert call + let storedEncryptedValue = ''; + let storedCredentialType = ''; + const captureInsert = vi.fn().mockImplementation(() => ({ + values: vi + .fn() + .mockImplementation((data: { encryptedValue: string; credentialType: string }) => { + storedEncryptedValue = data.encryptedValue; + storedCredentialType = data.credentialType; + rows.set('user1:anthropic', { + encryptedValue: data.encryptedValue, + credentialType: data.credentialType, + expiresAt: null, + metadata: null, + createdAt: new Date(), + updatedAt: new Date(), + }); + return { + onConflictDoUpdate: vi.fn().mockResolvedValue(undefined), + }; + }), + })); + + const captureSelect = vi.fn().mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockImplementation(() => { + const row = rows.get('user1:anthropic'); + return Promise.resolve(row ? [row] : []); + }), + }), + }), + }); + + const captureDelete = vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), + }); + + const db = { + insert: captureInsert, + select: captureSelect, + delete: captureDelete, + }; + + const service = new ProviderCredentialsService(db as never); + + // store + await service.store('user1', 'anthropic', 'api_key', 'sk-ant-secret-value'); + + // verify encrypted value is not plain text + expect(storedEncryptedValue).not.toBe('sk-ant-secret-value'); + expect(storedEncryptedValue.length).toBeGreaterThan(0); + expect(storedCredentialType).toBe('api_key'); + + // retrieve + const retrieved = await service.retrieve('user1', 'anthropic'); + expect(retrieved).toBe('sk-ant-secret-value'); + + // remove (clears the row) + rows.delete('user1:anthropic'); + const afterRemove = await service.retrieve('user1', 'anthropic'); + expect(afterRemove).toBeNull(); + }); + + it('retrieve returns null when no credential is stored', async () => { + process.env['BETTER_AUTH_SECRET'] = 'test-secret-for-unit-tests-only'; + + const { ProviderCredentialsService } = await import('../provider-credentials.service.js'); + + const emptyDb = { + select: vi.fn().mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + }), + }; + + const service = new ProviderCredentialsService(emptyDb as never); + const result = await service.retrieve('user-nobody', 'anthropic'); + expect(result).toBeNull(); + }); + + it('listProviders returns only metadata, never decrypted values', async () => { + process.env['BETTER_AUTH_SECRET'] = 'test-secret-for-unit-tests-only'; + + const { ProviderCredentialsService } = await import('../provider-credentials.service.js'); + + const fakeRow = { + provider: 'anthropic', + credentialType: 'api_key', + expiresAt: null, + metadata: null, + createdAt: new Date(), + updatedAt: new Date(), + }; + + const listDb = { + select: vi.fn().mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue([fakeRow]), + }), + }), + }; + + const service = new ProviderCredentialsService(listDb as never); + const providers = await service.listProviders('user1'); + + expect(providers).toHaveLength(1); + expect(providers[0]!.provider).toBe('anthropic'); + expect(providers[0]!.credentialType).toBe('api_key'); + expect(providers[0]!.exists).toBe(true); + + // Critically: no encrypted or plain-text value is exposed + expect(providers[0]).not.toHaveProperty('encryptedValue'); + expect(providers[0]).not.toHaveProperty('value'); + expect(providers[0]).not.toHaveProperty('apiKey'); + }); +}); diff --git a/apps/gateway/src/agent/routing/routing-engine.service.ts b/apps/gateway/src/agent/routing/routing-engine.service.ts new file mode 100644 index 0000000..57ec793 --- /dev/null +++ b/apps/gateway/src/agent/routing/routing-engine.service.ts @@ -0,0 +1,216 @@ +import { Inject, Injectable, Logger } from '@nestjs/common'; +import { routingRules, type Db, and, asc, eq, or } from '@mosaic/db'; +import { DB } from '../../database/database.module.js'; +import { ProviderService } from '../provider.service.js'; +import { classifyTask } from './task-classifier.js'; +import type { + RoutingCondition, + RoutingRule, + RoutingDecision, + TaskClassification, +} from './routing.types.js'; + +// ─── Injection tokens ──────────────────────────────────────────────────────── + +export const PROVIDER_SERVICE = Symbol('ProviderService'); + +// ─── Fallback chain ────────────────────────────────────────────────────────── + +/** + * Ordered fallback providers tried when no rule matches or all matched + * providers are unhealthy. + */ +const FALLBACK_CHAIN: Array<{ provider: string; model: string }> = [ + { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + { provider: 'anthropic', model: 'claude-haiku-4-5' }, + { provider: 'ollama', model: 'llama3.2' }, +]; + +// ─── Service ───────────────────────────────────────────────────────────────── + +@Injectable() +export class RoutingEngineService { + private readonly logger = new Logger(RoutingEngineService.name); + + constructor( + @Inject(DB) private readonly db: Db, + @Inject(ProviderService) private readonly providerService: ProviderService, + ) {} + + /** + * Classify the message, evaluate routing rules in priority order, and return + * the best routing decision. + * + * @param message - Raw user message text used for classification. + * @param userId - Optional user ID for loading user-scoped rules. + * @param availableProviders - Optional pre-fetched provider health map to + * avoid redundant health checks inside tight loops. + */ + async resolve( + message: string, + userId?: string, + availableProviders?: Record, + ): Promise { + const classification = classifyTask(message); + this.logger.debug( + `Classification: taskType=${classification.taskType} complexity=${classification.complexity} domain=${classification.domain}`, + ); + + // Load health data once (re-use caller-supplied map if provided) + const health = availableProviders ?? (await this.providerService.healthCheckAll()); + + // Load all applicable rules ordered by priority + const rules = await this.loadRules(userId); + + // Evaluate rules in priority order + for (const rule of rules) { + if (!rule.enabled) continue; + + if (!this.matchConditions(rule, classification)) continue; + + const providerStatus = health[rule.action.provider]?.status; + const isHealthy = providerStatus === 'up' || providerStatus === 'ok'; + + if (!isHealthy) { + this.logger.debug( + `Rule "${rule.name}" matched but provider "${rule.action.provider}" is unhealthy (status: ${providerStatus ?? 'unknown'})`, + ); + continue; + } + + this.logger.debug( + `Rule matched: "${rule.name}" → ${rule.action.provider}/${rule.action.model}`, + ); + + return { + provider: rule.action.provider, + model: rule.action.model, + agentConfigId: rule.action.agentConfigId, + ruleName: rule.name, + reason: `Matched routing rule "${rule.name}"`, + }; + } + + // No rule matched (or all matched providers were unhealthy) — apply fallback chain + this.logger.debug('No rule matched; applying fallback chain'); + return this.applyFallbackChain(health); + } + + /** + * Check whether all conditions of a rule match the given task classification. + * An empty conditions array always matches (catch-all / fallback rule). + */ + matchConditions( + rule: Pick, + classification: TaskClassification, + ): boolean { + if (rule.conditions.length === 0) return true; + + return rule.conditions.every((condition) => this.evaluateCondition(condition, classification)); + } + + // ─── Private helpers ─────────────────────────────────────────────────────── + + private evaluateCondition( + condition: RoutingCondition, + classification: TaskClassification, + ): boolean { + // `costTier` is a valid condition field but is not part of TaskClassification + // (it is supplied via userOverrides / request context). Treat unknown fields as + // undefined so conditions referencing them simply do not match. + const fieldValue = (classification as unknown as Record)[condition.field]; + + switch (condition.operator) { + case 'eq': { + // Scalar equality: field value must equal condition value (string) + if (typeof condition.value !== 'string') return false; + return fieldValue === condition.value; + } + + case 'in': { + // Set membership: condition value (array) contains field value + if (!Array.isArray(condition.value)) return false; + return condition.value.includes(fieldValue as string); + } + + case 'includes': { + // Array containment: field value (array) includes condition value (string) + if (!Array.isArray(fieldValue)) return false; + if (typeof condition.value !== 'string') return false; + return (fieldValue as string[]).includes(condition.value); + } + + default: + return false; + } + } + + /** + * Load routing rules from the database. + * System rules + user-scoped rules (when userId is provided) are returned, + * ordered by priority ascending. + */ + private async loadRules(userId?: string): Promise { + const whereClause = userId + ? or( + eq(routingRules.scope, 'system'), + and(eq(routingRules.scope, 'user'), eq(routingRules.userId, userId)), + ) + : eq(routingRules.scope, 'system'); + + const rows = await this.db + .select() + .from(routingRules) + .where(whereClause) + .orderBy(asc(routingRules.priority)); + + return rows.map((row) => ({ + id: row.id, + name: row.name, + priority: row.priority, + scope: row.scope as 'system' | 'user', + userId: row.userId ?? undefined, + conditions: (row.conditions as unknown as RoutingCondition[]) ?? [], + action: row.action as unknown as { + provider: string; + model: string; + agentConfigId?: string; + systemPromptOverride?: string; + toolAllowlist?: string[]; + }, + enabled: row.enabled, + })); + } + + /** + * Walk the fallback chain and return the first healthy provider/model pair. + * If none are healthy, return the first entry unconditionally (last resort). + */ + private applyFallbackChain(health: Record): RoutingDecision { + for (const candidate of FALLBACK_CHAIN) { + const providerStatus = health[candidate.provider]?.status; + const isHealthy = providerStatus === 'up' || providerStatus === 'ok'; + if (isHealthy) { + this.logger.debug(`Fallback resolved: ${candidate.provider}/${candidate.model}`); + return { + provider: candidate.provider, + model: candidate.model, + ruleName: 'fallback', + reason: `Fallback chain — no matching rule; selected ${candidate.provider}/${candidate.model}`, + }; + } + } + + // All providers in the fallback chain are unhealthy — use the first entry + const lastResort = FALLBACK_CHAIN[0]!; + this.logger.warn( + `All fallback providers unhealthy; using last resort: ${lastResort.provider}/${lastResort.model}`, + ); + return { + provider: lastResort.provider, + model: lastResort.model, + ruleName: 'fallback', + reason: `Fallback chain exhausted (all providers unhealthy); using ${lastResort.provider}/${lastResort.model}`, + }; + } +} diff --git a/apps/gateway/src/agent/routing/routing-engine.test.ts b/apps/gateway/src/agent/routing/routing-engine.test.ts new file mode 100644 index 0000000..645d079 --- /dev/null +++ b/apps/gateway/src/agent/routing/routing-engine.test.ts @@ -0,0 +1,460 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { RoutingEngineService } from './routing-engine.service.js'; +import type { RoutingRule, TaskClassification } from './routing.types.js'; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +function makeRule( + overrides: Partial & + Pick, +): RoutingRule { + return { + id: overrides.id ?? crypto.randomUUID(), + scope: 'system', + enabled: true, + ...overrides, + }; +} + +function makeClassification(overrides: Partial = {}): TaskClassification { + return { + taskType: 'conversation', + complexity: 'simple', + domain: 'general', + requiredCapabilities: [], + ...overrides, + }; +} + +/** Build a minimal RoutingEngineService with mocked DB and ProviderService. */ +function makeService( + rules: RoutingRule[] = [], + healthMap: Record = {}, +): RoutingEngineService { + const mockDb = { + select: vi.fn().mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + orderBy: vi.fn().mockResolvedValue( + rules.map((r) => ({ + id: r.id, + name: r.name, + priority: r.priority, + scope: r.scope, + userId: r.userId ?? null, + conditions: r.conditions, + action: r.action, + enabled: r.enabled, + createdAt: new Date(), + updatedAt: new Date(), + })), + ), + }), + }), + }), + }; + + const mockProviderService = { + healthCheckAll: vi.fn().mockResolvedValue(healthMap), + }; + + // Inject mocked dependencies directly (bypass NestJS DI for unit tests) + const service = new (RoutingEngineService as unknown as new ( + db: unknown, + ps: unknown, + ) => RoutingEngineService)(mockDb, mockProviderService); + + return service; +} + +// ─── matchConditions ────────────────────────────────────────────────────────── + +describe('RoutingEngineService.matchConditions', () => { + let service: RoutingEngineService; + + beforeEach(() => { + service = makeService(); + }); + + it('returns true for empty conditions array (catch-all rule)', () => { + const rule = makeRule({ + name: 'fallback', + priority: 99, + conditions: [], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }); + expect(service.matchConditions(rule, makeClassification())).toBe(true); + }); + + it('matches eq operator on scalar field', () => { + const rule = makeRule({ + name: 'coding', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }); + expect(service.matchConditions(rule, makeClassification({ taskType: 'coding' }))).toBe(true); + expect(service.matchConditions(rule, makeClassification({ taskType: 'conversation' }))).toBe( + false, + ); + }); + + it('matches in operator: field value is in the condition array', () => { + const rule = makeRule({ + name: 'simple or moderate', + priority: 2, + conditions: [{ field: 'complexity', operator: 'in', value: ['simple', 'moderate'] }], + action: { provider: 'anthropic', model: 'claude-haiku-4-5' }, + }); + expect(service.matchConditions(rule, makeClassification({ complexity: 'simple' }))).toBe(true); + expect(service.matchConditions(rule, makeClassification({ complexity: 'moderate' }))).toBe( + true, + ); + expect(service.matchConditions(rule, makeClassification({ complexity: 'complex' }))).toBe( + false, + ); + }); + + it('matches includes operator: field array includes the condition value', () => { + const rule = makeRule({ + name: 'reasoning required', + priority: 3, + conditions: [{ field: 'requiredCapabilities', operator: 'includes', value: 'reasoning' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }); + expect( + service.matchConditions(rule, makeClassification({ requiredCapabilities: ['reasoning'] })), + ).toBe(true); + expect( + service.matchConditions( + rule, + makeClassification({ requiredCapabilities: ['tools', 'reasoning'] }), + ), + ).toBe(true); + expect( + service.matchConditions(rule, makeClassification({ requiredCapabilities: ['tools'] })), + ).toBe(false); + expect(service.matchConditions(rule, makeClassification({ requiredCapabilities: [] }))).toBe( + false, + ); + }); + + it('requires ALL conditions to match (AND logic)', () => { + const rule = makeRule({ + name: 'complex coding', + priority: 1, + conditions: [ + { field: 'taskType', operator: 'eq', value: 'coding' }, + { field: 'complexity', operator: 'eq', value: 'complex' }, + ], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }); + + // Both match + expect( + service.matchConditions( + rule, + makeClassification({ taskType: 'coding', complexity: 'complex' }), + ), + ).toBe(true); + + // Only one matches + expect( + service.matchConditions( + rule, + makeClassification({ taskType: 'coding', complexity: 'simple' }), + ), + ).toBe(false); + + // Neither matches + expect( + service.matchConditions( + rule, + makeClassification({ taskType: 'conversation', complexity: 'simple' }), + ), + ).toBe(false); + }); + + it('returns false for eq when condition value is an array (type mismatch)', () => { + const rule = makeRule({ + name: 'bad eq', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: ['coding', 'research'] }], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }); + expect(service.matchConditions(rule, makeClassification({ taskType: 'coding' }))).toBe(false); + }); + + it('returns false for includes when field is not an array', () => { + const rule = makeRule({ + name: 'bad includes', + priority: 1, + conditions: [{ field: 'taskType', operator: 'includes', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }); + // taskType is a string, not an array — should be false + expect(service.matchConditions(rule, makeClassification({ taskType: 'coding' }))).toBe(false); + }); +}); + +// ─── resolve — priority ordering ───────────────────────────────────────────── + +describe('RoutingEngineService.resolve — priority ordering', () => { + it('selects the highest-priority matching rule', async () => { + // Rules are supplied in priority-ascending order, as the DB would return them. + const rules = [ + makeRule({ + name: 'high priority', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }), + makeRule({ + name: 'low priority', + priority: 10, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'openai', model: 'gpt-4o' }, + }), + ]; + + const service = makeService(rules, { anthropic: { status: 'up' }, openai: { status: 'up' } }); + + const decision = await service.resolve('implement a function'); + expect(decision.ruleName).toBe('high priority'); + expect(decision.provider).toBe('anthropic'); + expect(decision.model).toBe('claude-opus-4-6'); + }); + + it('skips non-matching rules and picks first match', async () => { + const rules = [ + makeRule({ + name: 'research rule', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'research' }], + action: { provider: 'openai', model: 'gpt-4o' }, + }), + makeRule({ + name: 'coding rule', + priority: 2, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }), + ]; + + const service = makeService(rules, { anthropic: { status: 'up' }, openai: { status: 'up' } }); + + const decision = await service.resolve('implement a function'); + expect(decision.ruleName).toBe('coding rule'); + expect(decision.provider).toBe('anthropic'); + }); +}); + +// ─── resolve — unhealthy provider fallback ──────────────────────────────────── + +describe('RoutingEngineService.resolve — unhealthy provider handling', () => { + it('skips matched rule when provider is unhealthy, tries next rule', async () => { + const rules = [ + makeRule({ + name: 'primary rule', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }), + makeRule({ + name: 'secondary rule', + priority: 2, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'openai', model: 'gpt-4o' }, + }), + ]; + + const service = makeService(rules, { + anthropic: { status: 'down' }, // primary is unhealthy + openai: { status: 'up' }, + }); + + const decision = await service.resolve('implement a function'); + expect(decision.ruleName).toBe('secondary rule'); + expect(decision.provider).toBe('openai'); + }); + + it('falls back to Sonnet when all rules have unhealthy providers', async () => { + // Override the rule's provider to something unhealthy but keep anthropic up for fallback + const unhealthyRules = [ + makeRule({ + name: 'only rule', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'openai', model: 'gpt-4o' }, // openai is unhealthy + }), + ]; + + const service2 = makeService(unhealthyRules, { + anthropic: { status: 'up' }, + openai: { status: 'down' }, + }); + + const decision = await service2.resolve('implement a function'); + // Should fall through to Sonnet fallback on anthropic + expect(decision.provider).toBe('anthropic'); + expect(decision.model).toBe('claude-sonnet-4-6'); + expect(decision.ruleName).toBe('fallback'); + }); + + it('falls back to Haiku when Sonnet provider is also down', async () => { + const rules: RoutingRule[] = []; // no rules + + const service = makeService(rules, { + anthropic: { status: 'down' }, // Sonnet is on anthropic — down + ollama: { status: 'up' }, // Haiku is also on anthropic — use Ollama as next + }); + + const decision = await service.resolve('hello there'); + // Sonnet (anthropic) is down, Haiku (anthropic) is down, Ollama is up + expect(decision.provider).toBe('ollama'); + expect(decision.model).toBe('llama3.2'); + expect(decision.ruleName).toBe('fallback'); + }); + + it('uses last resort (Sonnet) when all fallback providers are unhealthy', async () => { + const rules: RoutingRule[] = []; + + const service = makeService(rules, { + anthropic: { status: 'down' }, + ollama: { status: 'down' }, + }); + + const decision = await service.resolve('hello'); + // All unhealthy — still returns first fallback entry as last resort + expect(decision.provider).toBe('anthropic'); + expect(decision.model).toBe('claude-sonnet-4-6'); + expect(decision.ruleName).toBe('fallback'); + }); +}); + +// ─── resolve — empty conditions (catch-all rule) ────────────────────────────── + +describe('RoutingEngineService.resolve — empty conditions (fallback rule)', () => { + it('matches catch-all rule for any message', async () => { + const rules = [ + makeRule({ + name: 'catch-all', + priority: 99, + conditions: [], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }), + ]; + + const service = makeService(rules, { anthropic: { status: 'up' } }); + + const decision = await service.resolve('completely unrelated message xyz'); + expect(decision.ruleName).toBe('catch-all'); + expect(decision.provider).toBe('anthropic'); + expect(decision.model).toBe('claude-sonnet-4-6'); + }); + + it('catch-all is overridden by a higher-priority specific rule', async () => { + const rules = [ + makeRule({ + name: 'specific coding rule', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }), + makeRule({ + name: 'catch-all', + priority: 99, + conditions: [], + action: { provider: 'anthropic', model: 'claude-haiku-4-5' }, + }), + ]; + + const service = makeService(rules, { anthropic: { status: 'up' } }); + + const codingDecision = await service.resolve('implement a function'); + expect(codingDecision.ruleName).toBe('specific coding rule'); + expect(codingDecision.model).toBe('claude-opus-4-6'); + + const conversationDecision = await service.resolve('hello how are you'); + expect(conversationDecision.ruleName).toBe('catch-all'); + expect(conversationDecision.model).toBe('claude-haiku-4-5'); + }); +}); + +// ─── resolve — disabled rules ───────────────────────────────────────────────── + +describe('RoutingEngineService.resolve — disabled rules', () => { + it('skips disabled rules', async () => { + const rules = [ + makeRule({ + name: 'disabled rule', + priority: 1, + enabled: false, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }), + makeRule({ + name: 'enabled fallback', + priority: 99, + conditions: [], + action: { provider: 'anthropic', model: 'claude-sonnet-4-6' }, + }), + ]; + + const service = makeService(rules, { anthropic: { status: 'up' } }); + + const decision = await service.resolve('implement a function'); + expect(decision.ruleName).toBe('enabled fallback'); + expect(decision.model).toBe('claude-sonnet-4-6'); + }); +}); + +// ─── resolve — pre-fetched health map ──────────────────────────────────────── + +describe('RoutingEngineService.resolve — availableProviders override', () => { + it('uses the provided health map instead of calling healthCheckAll', async () => { + const rules = [ + makeRule({ + name: 'coding rule', + priority: 1, + conditions: [{ field: 'taskType', operator: 'eq', value: 'coding' }], + action: { provider: 'anthropic', model: 'claude-opus-4-6' }, + }), + ]; + + const mockHealthCheckAll = vi.fn().mockResolvedValue({}); + const mockDb = { + select: vi.fn().mockReturnValue({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + orderBy: vi.fn().mockResolvedValue( + rules.map((r) => ({ + id: r.id, + name: r.name, + priority: r.priority, + scope: r.scope, + userId: r.userId ?? null, + conditions: r.conditions, + action: r.action, + enabled: r.enabled, + createdAt: new Date(), + updatedAt: new Date(), + })), + ), + }), + }), + }), + }; + const mockProviderService = { healthCheckAll: mockHealthCheckAll }; + + const service = new (RoutingEngineService as unknown as new ( + db: unknown, + ps: unknown, + ) => RoutingEngineService)(mockDb, mockProviderService); + + const preSupplied = { anthropic: { status: 'up' } }; + await service.resolve('implement a function', undefined, preSupplied); + + expect(mockHealthCheckAll).not.toHaveBeenCalled(); + }); +});