import { Injectable, Logger, type OnModuleInit } from '@nestjs/common';
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
import type {
  CustomProviderConfig,
  IProviderAdapter,
  ModelInfo,
  ProviderHealth,
  ProviderInfo,
} from '@mosaic/types';
import { OllamaAdapter, OpenAIAdapter } from './adapters/index.js';
import type { TestConnectionResultDto } from './provider.dto.js';

/** DI injection token for the provider adapter array. */
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');

/** Timeout for the ad-hoc HTTP probe performed by ProviderService.testConnection(). */
const TEST_CONNECTION_TIMEOUT_MS = 5000;

/** Remove every trailing `/` so a path segment can be appended without producing `//`. */
function trimTrailingSlashes(url: string): string {
  return url.replace(/\/+$/, '');
}

/**
 * Extract model identifiers from a model-listing response body.
 *
 * Two shapes are accepted:
 *  - OpenAI-compatible `{ data: [{ id }] }`, which is what `/models` and `/v1/models`
 *    endpoints return (including Ollama's OpenAI-compatible API);
 *  - `{ models: [{ id | name }] }`.
 *
 * @returns The non-empty ids/names found, or undefined when the body contains neither list.
 */
function extractModelIds(body: unknown): string[] | undefined {
  if (typeof body !== 'object' || body === null) return undefined;

  const { data, models } = body as { data?: unknown; models?: unknown };
  const entries = Array.isArray(data) ? data : Array.isArray(models) ? models : undefined;
  if (!entries) return undefined;

  const ids: string[] = [];
  for (const entry of entries) {
    if (typeof entry !== 'object' || entry === null) continue;
    const { id, name } = entry as { id?: unknown; name?: unknown };
    const candidate =
      typeof id === 'string' && id.length > 0 ? id : typeof name === 'string' ? name : '';
    if (candidate) ids.push(candidate);
  }
  return ids;
}

@Injectable()
export class ProviderService implements OnModuleInit {
  private readonly logger = new Logger(ProviderService.name);
  private registry!: ModelRegistry;

  /**
   * Adapters registered with this service.
   * The built-in adapters (Ollama, OpenAI) are created in onModuleInit(); additional
   * adapters can be supplied via the PROVIDER_ADAPTERS injection token in the future.
   */
  private adapters: IProviderAdapter[] = [];

  /**
   * Create the model registry, register every adapter, then register the providers
   * that do not have adapters yet (Anthropic, Z.ai, custom providers from the environment).
   */
  async onModuleInit(): Promise<void> {
    const authStorage = AuthStorage.inMemory();
    this.registry = new ModelRegistry(authStorage);

    // Build the default set of adapters that rely on the registry.
    this.adapters = [new OllamaAdapter(this.registry), new OpenAIAdapter(this.registry)];

    // Run all adapter registrations first (Ollama, OpenAI, and any future adapters).
    await this.registerAll();

    // Register API-key providers directly (Anthropic, Z.ai, custom).
    // OpenAI now has a dedicated adapter class (M3-003).
    // Each step is isolated so that one failure is logged instead of aborting module init,
    // matching the policy applied to adapters in registerAll(). Example failure:
    // cloneBuiltInModel() throwing for an unknown built-in model id.
    this.runDirectRegistration('anthropic', () => this.registerAnthropicProvider());
    this.runDirectRegistration('zai', () => this.registerZaiProvider());
    this.runDirectRegistration('custom', () => this.registerCustomProviders());

    const available = this.registry.getAvailable();
    this.logger.log(`Providers initialized: ${available.length} models available`);
  }

  // ---------------------------------------------------------------------------
  // Adapter-pattern API
  // ---------------------------------------------------------------------------

  /**
   * Call register() on each adapter in order.
   * Errors from individual adapters are logged and do not abort the others.
   */
  async registerAll(): Promise<void> {
    for (const adapter of this.adapters) {
      try {
        await adapter.register();
      } catch (err) {
        this.logger.error(
          `Adapter "${adapter.name}" registration failed`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }
  }

  /**
   * Return the adapter registered under the given provider name, or undefined.
   */
  getAdapter(providerName: string): IProviderAdapter | undefined {
    return this.adapters.find((a) => a.name === providerName);
  }

  /**
   * Run healthCheck() on all adapters concurrently and return the results keyed by
   * provider name. An adapter whose healthCheck() throws is reported as `down`.
   */
  async healthCheckAll(): Promise<Record<string, ProviderHealth>> {
    const results: Record<string, ProviderHealth> = {};
    await Promise.all(
      this.adapters.map(async (adapter) => {
        try {
          results[adapter.name] = await adapter.healthCheck();
        } catch (err) {
          results[adapter.name] = {
            status: 'down',
            lastChecked: new Date().toISOString(),
            error: err instanceof Error ? err.message : String(err),
          };
        }
      }),
    );
    return results;
  }

  // ---------------------------------------------------------------------------
  // Legacy / Pi-SDK-facing API (preserved for AgentService and RoutingService)
  // ---------------------------------------------------------------------------

  getRegistry(): ModelRegistry {
    return this.registry;
  }

  findModel(provider: string, modelId: string): Model<Api> | undefined {
    return this.registry.find(provider, modelId);
  }

  /** Return the first model the registry reports as available, or undefined if none are. */
  getDefaultModel(): Model<Api> | undefined {
    const available = this.registry.getAvailable();
    return available[0];
  }

  /**
   * Group every registered model by provider.
   * A provider is marked `available` when at least one of its models appears in the
   * registry's available list.
   */
  listProviders(): ProviderInfo[] {
    const allModels = this.registry.getAll();
    const availableModels = this.registry.getAvailable();
    const availableIds = new Set(availableModels.map((m) => `${m.provider}:${m.id}`));

    const providerMap = new Map<string, ProviderInfo>();
    for (const model of allModels) {
      let info = providerMap.get(model.provider);
      if (!info) {
        info = {
          id: model.provider,
          name: model.provider,
          available: false,
          models: [],
        };
        providerMap.set(model.provider, info);
      }

      const isAvailable = availableIds.has(`${model.provider}:${model.id}`);
      if (isAvailable) info.available = true;

      info.models.push(this.toModelInfo(model));
    }

    return Array.from(providerMap.values());
  }

  listAvailableModels(): ModelInfo[] {
    return this.registry.getAvailable().map((m) => this.toModelInfo(m));
  }

  /**
   * Check whether a provider is reachable.
   *
   * Resolution order:
   *  1. An adapter exists and no `baseUrl` override is given: delegate to the
   *     adapter's healthCheck(). If it throws, report the provider as unreachable.
   *  2. A `baseUrl` override is given: probe `<baseUrl>/models`.
   *  3. The provider is `ollama`: probe `<OLLAMA_BASE_URL | OLLAMA_HOST>/v1/models`.
   *  4. Any other registered provider: report it reachable with its registered models,
   *     without making a network call.
   *
   * @param providerId - Provider name as registered in the model registry.
   * @param baseUrl - Optional base URL to probe instead of the registered configuration.
   * @returns Reachability, latency, and the discovered model ids when the endpoint lists them.
   *   This method never rejects; failures are reported in the result.
   */
  async testConnection(providerId: string, baseUrl?: string): Promise<TestConnectionResultDto> {
    // Delegate to the adapter when one exists and no URL override is given.
    const adapter = this.getAdapter(providerId);
    if (adapter && !baseUrl) {
      try {
        const health = await adapter.healthCheck();
        return {
          providerId,
          reachable: health.status !== 'down',
          latencyMs: health.latencyMs,
          error: health.error,
        };
      } catch (err) {
        return {
          providerId,
          reachable: false,
          error: err instanceof Error ? err.message : String(err),
        };
      }
    }

    // Resolve the URL: explicit override > Ollama env > static answer for other providers.
    let resolvedUrl: string;
    if (baseUrl) {
      resolvedUrl = `${trimTrailingSlashes(baseUrl)}/models`;
    } else {
      const providerModels = this.registry.getAll().filter((m) => m.provider === providerId);
      if (providerModels.length === 0) {
        return { providerId, reachable: false, error: `Provider '${providerId}' not found` };
      }

      if (providerId !== 'ollama') {
        // For other providers, we can only do a basic check.
        return { providerId, reachable: true, discoveredModels: providerModels.map((m) => m.id) };
      }

      // NOTE(review): OLLAMA_HOST is sometimes set without a scheme (e.g. "127.0.0.1:11434"),
      // which fetch() rejects; such a value surfaces as an unreachable result. Confirm the
      // expected format with deployment config.
      const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
      if (!ollamaUrl) {
        return { providerId, reachable: false, error: 'OLLAMA_BASE_URL not configured' };
      }
      resolvedUrl = `${trimTrailingSlashes(ollamaUrl)}/v1/models`;
    }

    return this.probeModelsEndpoint(providerId, resolvedUrl);
  }

  /**
   * Register an OpenAI-style provider from a user-supplied config.
   * Model costs are zeroed, input is text-only, and the context window and max tokens
   * default to 4096 when the config omits them.
   */
  registerCustomProvider(config: CustomProviderConfig): void {
    this.registry.registerProvider(config.id, {
      baseUrl: config.baseUrl,
      apiKey: config.apiKey,
      models: config.models.map((m) => ({
        id: m.id,
        name: m.name,
        reasoning: m.reasoning ?? false,
        input: ['text'] as ('text' | 'image')[],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: m.contextWindow ?? 4096,
        maxTokens: m.maxTokens ?? 4096,
      })),
    });

    this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
  }

  // ---------------------------------------------------------------------------
  // Private helpers — direct registry registration for providers without adapters yet
  // (Anthropic, Z.ai will move to adapters in M3-002, M3-005)
  // ---------------------------------------------------------------------------

  /** Run one direct registration step, logging (not propagating) any error it throws. */
  private runDirectRegistration(label: string, register: () => void): void {
    try {
      register();
    } catch (err) {
      this.logger.error(
        `Direct provider registration "${label}" failed`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  }

  /**
   * GET a model-listing endpoint, bounded by TEST_CONNECTION_TIMEOUT_MS.
   * Any HTTP response counts as reachable only when its status is 2xx. A body that cannot
   * be parsed still counts as reachable, just without discovered models.
   */
  private async probeModelsEndpoint(
    providerId: string,
    url: string,
  ): Promise<TestConnectionResultDto> {
    const start = Date.now();
    try {
      const res = await fetch(url, {
        method: 'GET',
        headers: { Accept: 'application/json' },
        signal: AbortSignal.timeout(TEST_CONNECTION_TIMEOUT_MS),
      });
      const latencyMs = Date.now() - start;

      if (!res.ok) {
        return { providerId, reachable: false, latencyMs, error: `HTTP ${res.status}` };
      }

      let discoveredModels: string[] | undefined;
      try {
        discoveredModels = extractModelIds(await res.json());
      } catch {
        // Ignore parse errors — the endpoint was reachable.
      }

      return { providerId, reachable: true, latencyMs, discoveredModels };
    } catch (err) {
      const latencyMs = Date.now() - start;
      const message = err instanceof Error ? err.message : String(err);
      return { providerId, reachable: false, latencyMs, error: message };
    }
  }

  /** Register the Anthropic models when ANTHROPIC_API_KEY is set; otherwise skip. */
  private registerAnthropicProvider(): void {
    const apiKey = process.env['ANTHROPIC_API_KEY'];
    if (!apiKey) {
      this.logger.debug('Skipping Anthropic provider registration: ANTHROPIC_API_KEY not set');
      return;
    }

    const models = ['claude-sonnet-4-6', 'claude-opus-4-6', 'claude-haiku-4-5'].map((id) =>
      this.cloneBuiltInModel('anthropic', id, { maxTokens: 8192 }),
    );

    this.registry.registerProvider('anthropic', {
      apiKey,
      baseUrl: 'https://api.anthropic.com',
      models,
    });

    this.logger.log(`Anthropic provider registered with ${models.length} models`);
  }

  /** Register the Z.ai GLM models when ZAI_API_KEY is set; otherwise skip. */
  private registerZaiProvider(): void {
    const apiKey = process.env['ZAI_API_KEY'];
    if (!apiKey) {
      this.logger.debug('Skipping Z.ai provider registration: ZAI_API_KEY not set');
      return;
    }

    const models = ['glm-4.5', 'glm-4.5-air', 'glm-4.5-flash'].map((id) =>
      this.cloneBuiltInModel('zai', id),
    );

    this.registry.registerProvider('zai', {
      apiKey,
      baseUrl: 'https://open.bigmodel.cn/api/paas/v4',
      models,
    });

    this.logger.log(`Z.ai provider registered with ${models.length} models`);
  }

  /**
   * Register every provider listed in MOSAIC_CUSTOM_PROVIDERS, which should hold a JSON
   * array of CustomProviderConfig.
   * Invalid JSON or a non-array value is logged and skipped. Each entry is registered
   * independently, so one malformed entry does not prevent the others from registering.
   */
  private registerCustomProviders(): void {
    const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
    if (!customJson) return;

    let parsed: unknown;
    try {
      parsed = JSON.parse(customJson);
    } catch (err) {
      this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
      return;
    }

    if (!Array.isArray(parsed)) {
      this.logger.error(
        'Failed to parse MOSAIC_CUSTOM_PROVIDERS',
        'expected a JSON array of provider configs',
      );
      return;
    }

    parsed.forEach((entry: unknown, index) => {
      try {
        // Entries are not schema-validated here; registerCustomProvider throws on
        // structurally broken configs (e.g. missing `models`), which is caught below.
        this.registerCustomProvider(entry as CustomProviderConfig);
      } catch (err) {
        const id =
          typeof entry === 'object' && entry !== null && 'id' in entry
            ? String((entry as { id: unknown }).id)
            : `#${index}`;
        this.logger.error(
          `Failed to register custom provider "${id}" from MOSAIC_CUSTOM_PROVIDERS`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    });
  }

  /**
   * Copy a model definition from pi-ai's built-in catalog, applying `overrides`.
   *
   * @throws Error when `provider:modelId` is not in the built-in catalog.
   */
  private cloneBuiltInModel(
    provider: string,
    modelId: string,
    overrides: Partial<Model<Api>> = {},
  ): Model<Api> {
    // getModel() is typed over the catalog's literal provider/model ids; the casts allow
    // runtime strings, and the undefined check below covers ids that are not in the catalog.
    const model = getModel(provider as never, modelId as never) as Model<Api> | undefined;
    if (!model) {
      throw new Error(`Built-in model not found: ${provider}:${modelId}`);
    }
    return { ...model, ...overrides };
  }

  /** Project a registry model onto the public ModelInfo DTO. */
  private toModelInfo(model: Model<Api>): ModelInfo {
    return {
      id: model.id,
      provider: model.provider,
      name: model.name,
      reasoning: model.reasoning,
      contextWindow: model.contextWindow,
      maxTokens: model.maxTokens,
      inputTypes: model.input,
      cost: model.cost,
    };
  }
}