import {
  Inject,
  Injectable,
  Logger,
  Optional,
  type OnModuleDestroy,
  type OnModuleInit,
} from '@nestjs/common';
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
import type {
  CustomProviderConfig,
  IProviderAdapter,
  ModelInfo,
  ProviderHealth,
  ProviderInfo,
} from '@mosaicstack/types';
import {
  AnthropicAdapter,
  OllamaAdapter,
  OpenAIAdapter,
  OpenRouterAdapter,
  ZaiAdapter,
} from './adapters/index.js';
import type { TestConnectionResultDto } from './provider.dto.js';
import { ProviderCredentialsService } from './provider-credentials.service.js';

/** Default health check interval in seconds */
const DEFAULT_HEALTH_INTERVAL_SECS = 60;

/**
 * Largest health check interval in seconds. Node timers reset any delay above
 * 2^31 - 1 ms to 1 ms, so larger configured values are clamped to this.
 */
const MAX_HEALTH_INTERVAL_SECS = Math.floor(2147483647 / 1000);

/** Timeout for the HTTP probe performed by `testConnection`, in milliseconds. */
const TEST_CONNECTION_TIMEOUT_MS = 5000;

/** DI injection token for the provider adapter array. */
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');

/** Environment variable names for well-known providers */
const PROVIDER_ENV_KEYS: Readonly<Record<string, string>> = {
  anthropic: 'ANTHROPIC_API_KEY',
  openai: 'OPENAI_API_KEY',
  openrouter: 'OPENROUTER_API_KEY',
  zai: 'ZAI_API_KEY',
};

/** Health snapshot stored per adapter by the scheduler, plus the adapter's model count. */
type CachedProviderHealth = ProviderHealth & { modelCount: number };

/** Human-readable message for an unknown thrown value. */
function describeError(err: unknown): string {
  return err instanceof Error ? err.message : String(err);
}

/**
 * Parse a strictly positive integer.
 * Returns `fallback` when the input is missing, non-numeric, zero, or negative.
 */
function parsePositiveInt(raw: string | undefined, fallback: number): number {
  const parsed = Number.parseInt(raw ?? '', 10);
  return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}

/** Strip trailing slashes so a path segment can be appended without producing `//`. */
function trimTrailingSlashes(url: string): string {
  return url.replace(/\/+$/, '');
}

/**
 * Extract model identifiers from a model-listing response body.
 *
 * Accepts both `{ models: [...] }` and the OpenAI-compatible `{ data: [...] }`
 * shape returned by `/v1/models` endpoints. Each entry contributes its `id`,
 * falling back to `name`. Empty identifiers are dropped.
 *
 * @returns the identifiers, or `undefined` when the body has neither list.
 */
function extractModelIds(body: unknown): string[] | undefined {
  if (typeof body !== 'object' || body === null) return undefined;
  const { models, data } = body as { models?: unknown; data?: unknown };
  const entries: unknown[] | undefined = Array.isArray(models)
    ? models
    : Array.isArray(data)
      ? data
      : undefined;
  if (!entries) return undefined;

  return entries
    .map((entry: unknown): string => {
      if (typeof entry !== 'object' || entry === null) return '';
      const { id, name } = entry as { id?: unknown; name?: unknown };
      if (typeof id === 'string') return id;
      if (typeof name === 'string') return name;
      return '';
    })
    .filter(Boolean);
}

@Injectable()
export class ProviderService implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(ProviderService.name);

  /** Model registry shared with adapters; created in onModuleInit. */
  private registry!: ModelRegistry;

  /**
   * Adapters registered with this service.
   * onModuleInit installs the built-in set (Ollama, Anthropic, OpenAI,
   * OpenRouter, Z.ai). Additional adapters may be supplied via the
   * PROVIDER_ADAPTERS injection token in the future.
   */
  private adapters: IProviderAdapter[] = [];

  /** Cached health status per provider, updated by the health check scheduler. */
  private readonly healthCache = new Map<string, CachedProviderHealth>();

  /** Timer handle for the periodic health check scheduler */
  private healthCheckTimer: ReturnType<typeof setInterval> | null = null;

  /** True while a scheduled sweep is running; prevents overlapping sweeps. */
  private healthCheckInFlight = false;

  constructor(
    @Optional()
    @Inject(ProviderCredentialsService)
    private readonly credentialsService: ProviderCredentialsService | null,
  ) {}

  /**
   * Create the in-memory registry and install the built-in adapters.
   * Then register custom providers from MOSAIC_CUSTOM_PROVIDERS and start the
   * health check scheduler.
   */
  async onModuleInit(): Promise<void> {
    const authStorage = AuthStorage.inMemory();
    this.registry = ModelRegistry.inMemory(authStorage);

    // Ollama, Anthropic and OpenAI adapters are constructed with the shared registry.
    this.adapters = [
      new OllamaAdapter(this.registry),
      new AnthropicAdapter(this.registry),
      new OpenAIAdapter(this.registry),
      new OpenRouterAdapter(),
      new ZaiAdapter(),
    ];

    // Run all adapter registrations first (Ollama, Anthropic, OpenAI, OpenRouter, Z.ai)
    await this.registerAll();

    // Register API-key providers directly (custom)
    this.registerCustomProviders();

    const available = this.registry.getAvailable();
    this.logger.log(`Providers initialized: ${available.length} models available`);

    // Kick off the health check scheduler
    this.startHealthCheckScheduler();
  }

  /** Stop the health check scheduler. */
  onModuleDestroy(): void {
    if (this.healthCheckTimer !== null) {
      clearInterval(this.healthCheckTimer);
      this.healthCheckTimer = null;
    }
  }

  // ---------------------------------------------------------------------------
  // Health check scheduler
  // ---------------------------------------------------------------------------

  /**
   * Start periodic health checks on all adapters.
   *
   * The interval is read from PROVIDER_HEALTH_INTERVAL (seconds, default 60).
   * Non-positive or non-numeric values fall back to the default. Values too
   * large for a Node timer are clamped, because Node would otherwise reset
   * them to 1 ms.
   */
  private startHealthCheckScheduler(): void {
    // Guard against a second start leaking the previous timer.
    if (this.healthCheckTimer !== null) {
      clearInterval(this.healthCheckTimer);
    }

    const intervalSecs = Math.min(
      parsePositiveInt(process.env['PROVIDER_HEALTH_INTERVAL'], DEFAULT_HEALTH_INTERVAL_SECS),
      MAX_HEALTH_INTERVAL_SECS,
    );
    const intervalMs = intervalSecs * 1000;

    // Run an initial check immediately (non-blocking)
    void this.runScheduledHealthChecks();

    this.healthCheckTimer = setInterval(() => {
      void this.runScheduledHealthChecks();
    }, intervalMs);

    this.logger.log(`Provider health check scheduler started (interval: ${intervalSecs}s)`);
  }

  /**
   * Refresh the health cache for every adapter in parallel.
   * Skips the run if the previous sweep has not finished yet. Never rejects,
   * because callers invoke it fire-and-forget.
   */
  private async runScheduledHealthChecks(): Promise<void> {
    if (this.healthCheckInFlight) {
      this.logger.debug('Skipping provider health check: previous run still in progress');
      return;
    }
    this.healthCheckInFlight = true;
    try {
      await Promise.all(this.adapters.map((adapter) => this.refreshAdapterHealth(adapter)));
    } finally {
      this.healthCheckInFlight = false;
    }
  }

  /**
   * Run one adapter's healthCheck() and store the result (with model count) in
   * the cache. A thrown error is recorded as status 'down'.
   */
  private async refreshAdapterHealth(adapter: IProviderAdapter): Promise<void> {
    let health: ProviderHealth;
    try {
      health = await adapter.healthCheck();
      this.logger.debug(
        `Health check [${adapter.name}]: ${health.status} (${health.latencyMs ?? 'n/a'}ms)`,
      );
    } catch (err) {
      health = {
        status: 'down',
        lastChecked: new Date().toISOString(),
        error: describeError(err),
      };
    }
    this.healthCache.set(adapter.name, { ...health, modelCount: this.safeModelCount(adapter) });
  }

  /** Number of models an adapter lists; a throwing listModels() is logged and counted as 0. */
  private safeModelCount(adapter: IProviderAdapter): number {
    try {
      return adapter.listModels().length;
    } catch (err) {
      this.logger.warn(`Adapter "${adapter.name}" listModels() failed: ${describeError(err)}`);
      return 0;
    }
  }

  /**
   * Return the cached health status for all adapters.
   * Format: array of { name, status, latencyMs, lastChecked, modelCount, error }.
   * Adapters not yet checked are reported with status 'unknown' and the current
   * time as lastChecked.
   */
  getProvidersHealth(): Array<{
    name: string;
    status: string;
    latencyMs?: number;
    lastChecked: string;
    modelCount: number;
    error?: string;
  }> {
    return this.adapters.map((adapter) => {
      const cached = this.healthCache.get(adapter.name);
      if (cached) {
        return {
          name: adapter.name,
          status: cached.status,
          latencyMs: cached.latencyMs,
          lastChecked: cached.lastChecked,
          modelCount: cached.modelCount,
          error: cached.error,
        };
      }

      // Not yet checked — return a pending placeholder
      return {
        name: adapter.name,
        status: 'unknown',
        lastChecked: new Date().toISOString(),
        modelCount: this.safeModelCount(adapter),
      };
    });
  }

  // ---------------------------------------------------------------------------
  // Adapter-pattern API
  // ---------------------------------------------------------------------------

  /**
   * Call register() on each adapter in order.
   * Errors from individual adapters are logged and do not abort the others.
   */
  async registerAll(): Promise<void> {
    for (const adapter of this.adapters) {
      try {
        await adapter.register();
      } catch (err) {
        this.logger.error(
          `Adapter "${adapter.name}" registration failed`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }
  }

  /**
   * Return the adapter registered under the given provider name, or undefined.
   */
  getAdapter(providerName: string): IProviderAdapter | undefined {
    return this.adapters.find((a) => a.name === providerName);
  }

  /**
   * Run healthCheck() on all adapters in parallel and return results keyed by
   * provider name. A thrown error is reported as status 'down'.
   */
  async healthCheckAll(): Promise<Record<string, ProviderHealth>> {
    const results: Record<string, ProviderHealth> = {};
    await Promise.all(
      this.adapters.map(async (adapter) => {
        try {
          results[adapter.name] = await adapter.healthCheck();
        } catch (err) {
          results[adapter.name] = {
            status: 'down',
            lastChecked: new Date().toISOString(),
            error: describeError(err),
          };
        }
      }),
    );
    return results;
  }

  // ---------------------------------------------------------------------------
  // Legacy / Pi-SDK-facing API (preserved for AgentService and RoutingService)
  // ---------------------------------------------------------------------------

  /** The underlying model registry (assigned in onModuleInit). */
  getRegistry(): ModelRegistry {
    return this.registry;
  }

  /** Look up a model by provider and model id. */
  findModel(provider: string, modelId: string): Model<Api> | undefined {
    return this.registry.find(provider, modelId);
  }

  /** First entry of the registry's available models, or undefined when none are available. */
  getDefaultModel(): Model<Api> | undefined {
    const available = this.registry.getAvailable();
    return available[0];
  }

  /**
   * Group every registered model by provider.
   * A provider is marked `available` when at least one of its models appears in
   * the registry's available list.
   */
  listProviders(): ProviderInfo[] {
    const allModels = this.registry.getAll();
    const availableModels = this.registry.getAvailable();
    const availableIds = new Set(availableModels.map((m) => `${m.provider}:${m.id}`));

    const providerMap = new Map<string, ProviderInfo>();
    for (const model of allModels) {
      let info = providerMap.get(model.provider);
      if (!info) {
        info = {
          id: model.provider,
          name: model.provider,
          available: false,
          models: [],
        };
        providerMap.set(model.provider, info);
      }

      if (availableIds.has(`${model.provider}:${model.id}`)) info.available = true;
      info.models.push(this.toModelInfo(model));
    }

    return Array.from(providerMap.values());
  }

  /** Available models converted to ModelInfo. */
  listAvailableModels(): ModelInfo[] {
    return this.registry.getAvailable().map((m) => this.toModelInfo(m));
  }

  /**
   * Check whether a provider is reachable.
   *
   * - With a matching adapter and no `baseUrl`, delegates to the adapter's
   *   healthCheck(). An adapter error is reported as unreachable rather than thrown.
   * - With `baseUrl`, probes `<baseUrl>/models`.
   * - Otherwise the provider must exist in the registry. For 'ollama' the URL
   *   comes from OLLAMA_BASE_URL / OLLAMA_HOST (`<url>/v1/models`). Any other
   *   provider is reported reachable with its registered model ids, without a
   *   network call.
   */
  async testConnection(providerId: string, baseUrl?: string): Promise<TestConnectionResultDto> {
    // Delegate to the adapter when one exists and no URL override is given
    const adapter = this.getAdapter(providerId);
    if (adapter && !baseUrl) {
      try {
        const health = await adapter.healthCheck();
        return {
          providerId,
          reachable: health.status !== 'down',
          latencyMs: health.latencyMs,
          error: health.error,
        };
      } catch (err) {
        return { providerId, reachable: false, error: describeError(err) };
      }
    }

    // Resolve the models URL: explicit override > registered provider > ollama env
    let modelsUrl: string;
    if (baseUrl) {
      modelsUrl = `${trimTrailingSlashes(baseUrl)}/models`;
    } else {
      const providerModels = this.registry.getAll().filter((m) => m.provider === providerId);
      if (providerModels.length === 0) {
        return { providerId, reachable: false, error: `Provider '${providerId}' not found` };
      }

      if (providerId !== 'ollama') {
        // For other providers, we can only do a basic check
        return { providerId, reachable: true, discoveredModels: providerModels.map((m) => m.id) };
      }

      // For Ollama, derive the base URL from environment
      const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
      if (!ollamaUrl) {
        return { providerId, reachable: false, error: 'OLLAMA_BASE_URL not configured' };
      }
      modelsUrl = `${trimTrailingSlashes(ollamaUrl)}/v1/models`;
    }

    return this.probeModelsEndpoint(providerId, modelsUrl);
  }

  /**
   * GET a model-listing endpoint and report reachability, latency, and any
   * model ids found in the JSON body. Network errors and timeouts are reported
   * as unreachable; a non-JSON body still counts as reachable.
   */
  private async probeModelsEndpoint(
    providerId: string,
    url: string,
  ): Promise<TestConnectionResultDto> {
    const start = Date.now();
    try {
      const res = await fetch(url, {
        method: 'GET',
        headers: { Accept: 'application/json' },
        signal: AbortSignal.timeout(TEST_CONNECTION_TIMEOUT_MS),
      });
      const latencyMs = Date.now() - start;

      if (!res.ok) {
        // Release the unread body so the underlying connection is not held open.
        await res.body?.cancel().catch(() => undefined);
        return { providerId, reachable: false, latencyMs, error: `HTTP ${res.status}` };
      }

      let discoveredModels: string[] | undefined;
      try {
        discoveredModels = extractModelIds(await res.json());
      } catch {
        // ignore parse errors — endpoint was reachable
      }

      return { providerId, reachable: true, latencyMs, discoveredModels };
    } catch (err) {
      const latencyMs = Date.now() - start;
      return { providerId, reachable: false, latencyMs, error: describeError(err) };
    }
  }

  /**
   * Register an OpenAI-style custom provider and its models in the registry.
   * Missing model fields default to: reasoning=false, contextWindow=4096,
   * maxTokens=4096. Input is text only and cost is zero.
   *
   * @throws if `config.models` is not an array (via `.map`) or if the registry rejects it.
   */
  registerCustomProvider(config: CustomProviderConfig): void {
    this.registry.registerProvider(config.id, {
      baseUrl: config.baseUrl,
      apiKey: config.apiKey,
      models: config.models.map((m) => ({
        id: m.id,
        name: m.name,
        reasoning: m.reasoning ?? false,
        input: ['text'] as ('text' | 'image')[],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: m.contextWindow ?? 4096,
        maxTokens: m.maxTokens ?? 4096,
      })),
    });
    this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
  }

  // ---------------------------------------------------------------------------
  // Private helpers
  // ---------------------------------------------------------------------------

  /**
   * Register custom providers from the MOSAIC_CUSTOM_PROVIDERS env var.
   * The value must be a JSON array of CustomProviderConfig. Each entry is
   * registered independently, so one bad entry does not block the others.
   */
  private registerCustomProviders(): void {
    const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
    if (!customJson) return;

    let parsed: unknown;
    try {
      parsed = JSON.parse(customJson);
    } catch (err) {
      this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
      return;
    }

    if (!Array.isArray(parsed)) {
      this.logger.error('MOSAIC_CUSTOM_PROVIDERS must be a JSON array of provider configs');
      return;
    }

    // NOTE(review): entries are not schema-validated; a malformed entry fails
    // inside registerCustomProvider and is logged below.
    (parsed as CustomProviderConfig[]).forEach((config, index) => {
      try {
        this.registerCustomProvider(config);
      } catch (err) {
        this.logger.error(
          `Failed to register custom provider at index ${index} of MOSAIC_CUSTOM_PROVIDERS`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    });
  }

  /**
   * Resolve an API key for a provider, scoped to a specific user.
   * User-stored credentials take precedence over environment variables.
   * Returns null if no key is available from either source (an empty env var
   * counts as unavailable).
   */
  async resolveApiKey(userId: string, provider: string): Promise<string | null> {
    if (this.credentialsService) {
      const userKey = await this.credentialsService.retrieve(userId, provider);
      if (userKey) {
        this.logger.debug(`Using user-scoped credential for user=${userId} provider=${provider}`);
        return userKey;
      }
    }

    // Fall back to environment variable. Use an own-property check so names like
    // 'constructor' do not resolve through the object prototype.
    const envVar = Object.prototype.hasOwnProperty.call(PROVIDER_ENV_KEYS, provider)
      ? PROVIDER_ENV_KEYS[provider]
      : undefined;
    const envKey = envVar ? process.env[envVar] : undefined;
    if (envKey) {
      this.logger.debug(`Using env-var credential for provider=${provider}`);
      return envKey;
    }
    return null;
  }

  /**
   * Copy a built-in pi-ai model definition, applying `overrides` on top.
   * NOTE(review): not referenced elsewhere in this class — confirm whether still needed.
   *
   * @throws Error when the built-in model does not exist.
   */
  private cloneBuiltInModel(
    provider: string,
    modelId: string,
    overrides: Partial<Model<Api>> = {},
  ): Model<Api> {
    const model = getModel(provider as never, modelId as never) as Model<Api> | undefined;
    if (!model) {
      throw new Error(`Built-in model not found: ${provider}:${modelId}`);
    }
    return { ...model, ...overrides };
  }

  /** Project a registry model onto the public ModelInfo shape. */
  private toModelInfo(model: Model<Api>): ModelInfo {
    return {
      id: model.id,
      provider: model.provider,
      name: model.name,
      reasoning: model.reasoning,
      contextWindow: model.contextWindow,
      maxTokens: model.maxTokens,
      inputTypes: model.input,
      cost: model.cost,
    };
  }
}