Files
stack/apps/gateway/src/agent/provider.service.ts
Jason Woltje 10761f3e47
Some checks failed
ci/woodpecker/push/ci Pipeline failed
ci/woodpecker/pr/ci Pipeline failed
feat(providers): OpenRouter adapter + Ollama embedding support — M3-004/006 (#311)
Co-authored-by: Jason Woltje <jason@diversecanvas.com>
Co-committed-by: Jason Woltje <jason@diversecanvas.com>
2026-03-21 21:38:09 +00:00

407 lines
13 KiB
TypeScript

import { Injectable, Logger, type OnModuleDestroy, type OnModuleInit } from '@nestjs/common';
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
import type {
CustomProviderConfig,
IProviderAdapter,
ModelInfo,
ProviderHealth,
ProviderInfo,
} from '@mosaic/types';
import {
AnthropicAdapter,
OllamaAdapter,
OpenAIAdapter,
OpenRouterAdapter,
} from './adapters/index.js';
import type { TestConnectionResultDto } from './provider.dto.js';
/**
 * Fallback health check interval, in seconds, used when the
 * PROVIDER_HEALTH_INTERVAL environment variable does not supply a usable value.
 */
const DEFAULT_HEALTH_INTERVAL_SECS = 60;
/**
 * DI injection token for the provider adapter array.
 * NOTE(review): nothing in this file injects this token yet — ProviderService
 * builds its adapter list itself in onModuleInit().
 */
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');
@Injectable()
export class ProviderService implements OnModuleInit, OnModuleDestroy {
private readonly logger = new Logger(ProviderService.name);
private registry!: ModelRegistry;
/**
* Adapters registered with this service.
* onModuleInit() populates this with the built-in Ollama, Anthropic, OpenAI and
* OpenRouter adapters; additional adapters may be supplied via the
* PROVIDER_ADAPTERS injection token in the future.
*/
private adapters: IProviderAdapter[] = [];
/**
* Cached health status per provider, updated by the health check scheduler.
*/
private healthCache: Map<string, ProviderHealth & { modelCount: number }> = new Map();
/** Timer handle for the periodic health check scheduler */
private healthCheckTimer: ReturnType<typeof setInterval> | null = null;
/**
 * Nest lifecycle hook: creates the model registry, registers every provider,
 * logs how many models ended up available and starts the health check timer.
 */
async onModuleInit(): Promise<void> {
  // The registry is backed by AuthStorage.inMemory().
  this.registry = new ModelRegistry(AuthStorage.inMemory());

  // Built-in adapters; all except OpenRouter are handed the shared registry.
  const { registry } = this;
  this.adapters = [
    new OllamaAdapter(registry),
    new AnthropicAdapter(registry),
    new OpenAIAdapter(registry),
    new OpenRouterAdapter(),
  ];

  // Adapter-based providers register first.
  await this.registerAll();

  // Providers without an adapter (Z.ai, custom) go straight into the registry.
  // OpenAI has a dedicated adapter (M3-003) and is not handled here.
  this.registerZaiProvider();
  this.registerCustomProviders();

  const modelCount = this.registry.getAvailable().length;
  this.logger.log(`Providers initialized: ${modelCount} models available`);

  // Begin periodic health checks.
  this.startHealthCheckScheduler();
}
/** Nest lifecycle hook: stops the periodic health check timer if one is running. */
onModuleDestroy(): void {
  const timer = this.healthCheckTimer;
  if (timer === null) {
    return;
  }
  clearInterval(timer);
  this.healthCheckTimer = null;
}
// ---------------------------------------------------------------------------
// Health check scheduler
// ---------------------------------------------------------------------------
/**
 * Start periodic health checks on all adapters.
 *
 * The interval is read from the PROVIDER_HEALTH_INTERVAL environment variable
 * (whole seconds). Missing, non-numeric, zero or negative values fall back to
 * DEFAULT_HEALTH_INTERVAL_SECS; values too large for a timer delay are capped.
 * One check is started immediately without being awaited. Calling this method
 * again replaces the previous timer instead of leaking it.
 */
private startHealthCheckScheduler(): void {
  // setInterval() coerces delays below 1 ms or above 2^31 - 1 ms to 1 ms, so an
  // out-of-range setting would otherwise become a tight polling loop.
  const maxIntervalSecs = Math.floor((2 ** 31 - 1) / 1000);

  const raw = process.env['PROVIDER_HEALTH_INTERVAL'];
  const parsed = Number.parseInt(raw ?? '', 10);

  let intervalSecs = DEFAULT_HEALTH_INTERVAL_SECS;
  if (Number.isFinite(parsed) && parsed > 0) {
    intervalSecs = Math.min(parsed, maxIntervalSecs);
  } else if (raw !== undefined && raw.trim() !== '') {
    this.logger.warn(
      `Ignoring invalid PROVIDER_HEALTH_INTERVAL "${raw}"; using ${DEFAULT_HEALTH_INTERVAL_SECS}s`,
    );
  }
  const intervalMs = intervalSecs * 1000;

  // Fire-and-forget wrapper: a rejected run is logged rather than surfacing as
  // an unhandled promise rejection.
  const runChecks = (): void => {
    this.runScheduledHealthChecks().catch((err: unknown) => {
      this.logger.error(
        'Scheduled provider health check failed',
        err instanceof Error ? err.stack : String(err),
      );
    });
  };

  // Never keep two timers alive at once.
  if (this.healthCheckTimer !== null) {
    clearInterval(this.healthCheckTimer);
  }

  // Run an initial check immediately (non-blocking)
  runChecks();
  this.healthCheckTimer = setInterval(runChecks, intervalMs);
  this.logger.log(`Provider health check scheduler started (interval: ${intervalSecs}s)`);
}
/**
 * Probe every adapter once and store the outcome in healthCache.
 *
 * Adapters are probed concurrently (as healthCheckAll() does) so that one slow
 * provider does not delay the others. A probe that throws is cached as
 * status 'down' with the error message. A failing listModels() call is
 * recorded as a model count of 0 instead of rejecting the whole run.
 */
private async runScheduledHealthChecks(): Promise<void> {
  // listModels() is adapter code and may throw; that must neither hide the
  // health result nor escape from this fire-and-forget task.
  const countModels = (adapter: IProviderAdapter): number => {
    try {
      return adapter.listModels().length;
    } catch {
      return 0;
    }
  };

  await Promise.all(
    this.adapters.map(async (adapter) => {
      try {
        const health = await adapter.healthCheck();
        this.healthCache.set(adapter.name, { ...health, modelCount: countModels(adapter) });
        this.logger.debug(
          `Health check [${adapter.name}]: ${health.status} (${health.latencyMs ?? 'n/a'}ms)`,
        );
      } catch (err) {
        const message = err instanceof Error ? err.message : String(err);
        this.healthCache.set(adapter.name, {
          status: 'down',
          lastChecked: new Date().toISOString(),
          error: message,
          modelCount: countModels(adapter),
        });
        this.logger.debug(`Health check [${adapter.name}]: down (${message})`);
      }
    }),
  );
}
/**
 * Report the most recently cached health of every adapter, in adapter order.
 * Adapters that have not been probed yet are reported with status 'unknown',
 * the current time as lastChecked and a live model count.
 */
getProvidersHealth(): Array<{
  name: string;
  status: string;
  latencyMs?: number;
  lastChecked: string;
  modelCount: number;
  error?: string;
}> {
  return this.adapters.map((adapter) => {
    const entry = this.healthCache.get(adapter.name);
    if (!entry) {
      // No scheduled check has completed for this adapter yet.
      return {
        name: adapter.name,
        status: 'unknown',
        lastChecked: new Date().toISOString(),
        modelCount: adapter.listModels().length,
      };
    }
    const { status, latencyMs, lastChecked, modelCount, error } = entry;
    return { name: adapter.name, status, latencyMs, lastChecked, modelCount, error };
  });
}
// ---------------------------------------------------------------------------
// Adapter-pattern API
// ---------------------------------------------------------------------------
/**
 * Register the adapters one after another, in list order.
 * A failing adapter is logged and skipped; the remaining adapters still run.
 */
async registerAll(): Promise<void> {
  for (const adapter of this.adapters) {
    try {
      await adapter.register();
    } catch (err) {
      const detail = err instanceof Error ? err.stack : String(err);
      this.logger.error(`Adapter "${adapter.name}" registration failed`, detail);
    }
  }
}
/**
 * Look up an adapter by provider name.
 * @returns the first adapter whose name equals `providerName`, or undefined.
 */
getAdapter(providerName: string): IProviderAdapter | undefined {
  for (const candidate of this.adapters) {
    if (candidate.name === providerName) {
      return candidate;
    }
  }
  return undefined;
}
/**
 * Probe all adapters concurrently and collect the results by provider name.
 * An adapter whose healthCheck() throws is reported as 'down' with the error
 * message rather than failing the whole call.
 */
async healthCheckAll(): Promise<Record<string, ProviderHealth>> {
  const results: Record<string, ProviderHealth> = {};

  const probe = async (adapter: IProviderAdapter): Promise<void> => {
    let outcome: ProviderHealth;
    try {
      outcome = await adapter.healthCheck();
    } catch (err) {
      const message = err instanceof Error ? err.message : String(err);
      outcome = { status: 'down', lastChecked: new Date().toISOString(), error: message };
    }
    results[adapter.name] = outcome;
  };

  await Promise.all(this.adapters.map((adapter) => probe(adapter)));
  return results;
}
// ---------------------------------------------------------------------------
// Legacy / Pi-SDK-facing API (preserved for AgentService and RoutingService)
// ---------------------------------------------------------------------------
/** Expose the underlying model registry instance. */
getRegistry(): ModelRegistry {
  const { registry } = this;
  return registry;
}
/** Look up a model in the registry by provider and model id. */
findModel(provider: string, modelId: string): Model<Api> | undefined {
  const match = this.registry.find(provider, modelId);
  return match;
}
/** Return the first model the registry reports as available, or undefined if there is none. */
getDefaultModel(): Model<Api> | undefined {
  const [first] = this.registry.getAvailable();
  return first;
}
/**
 * Group every registered model by provider.
 * A provider is flagged available as soon as one of its models appears in the
 * registry's available set; providers are returned in first-seen order.
 */
listProviders(): ProviderInfo[] {
  const everyModel = this.registry.getAll();

  // "provider:id" keys of the models that are currently usable.
  const usableKeys = new Set<string>();
  for (const m of this.registry.getAvailable()) {
    usableKeys.add(`${m.provider}:${m.id}`);
  }

  const byProvider = new Map<string, ProviderInfo>();
  for (const model of everyModel) {
    const existing = byProvider.get(model.provider);
    const info: ProviderInfo = existing ?? {
      id: model.provider,
      name: model.provider,
      available: false,
      models: [],
    };
    if (!existing) {
      byProvider.set(model.provider, info);
    }
    if (usableKeys.has(`${model.provider}:${model.id}`)) {
      info.available = true;
    }
    info.models.push(this.toModelInfo(model));
  }
  return Array.from(byProvider.values());
}
/** Return the registry's available models converted to ModelInfo. */
listAvailableModels(): ModelInfo[] {
  const infos: ModelInfo[] = [];
  for (const model of this.registry.getAvailable()) {
    infos.push(this.toModelInfo(model));
  }
  return infos;
}
/**
 * Check whether a provider endpoint is reachable.
 *
 * - No `baseUrl` and a registered adapter: the adapter's healthCheck() result
 *   is reported.
 * - `baseUrl` given: `<baseUrl>/models` is fetched with a 5 second timeout.
 * - Neither: the provider must have models in the registry. 'ollama' is probed
 *   at `<OLLAMA_BASE_URL or OLLAMA_HOST>/v1/models`; any other provider is
 *   reported reachable with its registry model ids and no network request.
 *
 * Adapter and network failures are reported through `reachable: false` and
 * `error` instead of being thrown.
 *
 * @param providerId Provider name as used by the adapters and the registry.
 * @param baseUrl Optional API root to probe instead of the configured one.
 * @returns Reachability, latency when a request was made, and the model ids
 *   found in the response body (either a `models` or a `data` array).
 */
async testConnection(providerId: string, baseUrl?: string): Promise<TestConnectionResultDto> {
  const describeError = (err: unknown): string =>
    err instanceof Error ? err.message : String(err);

  // Drop every trailing '/' so appending a path never yields '//'.
  const trimTrailingSlashes = (url: string): string => {
    let end = url.length;
    while (end > 0 && url.charAt(end - 1) === '/') {
      end -= 1;
    }
    return url.slice(0, end);
  };

  // Accept both `{ models: [...] }` and the OpenAI-style `{ data: [...] }`
  // listing; entries are identified by `id`, falling back to `name`.
  const extractModelIds = (payload: unknown): string[] | undefined => {
    if (typeof payload !== 'object' || payload === null) {
      return undefined;
    }
    const { models, data } = payload as { models?: unknown; data?: unknown };
    const entries: unknown[] | undefined = Array.isArray(models)
      ? models
      : Array.isArray(data)
        ? data
        : undefined;
    if (!entries) {
      return undefined;
    }
    const ids: string[] = [];
    for (const entry of entries) {
      if (typeof entry !== 'object' || entry === null) {
        continue;
      }
      const { id, name } = entry as { id?: unknown; name?: unknown };
      const label = typeof id === 'string' ? id : typeof name === 'string' ? name : '';
      if (label) {
        ids.push(label);
      }
    }
    return ids;
  };

  // Delegate to the adapter when one exists and no URL override is given.
  const adapter = this.getAdapter(providerId);
  if (adapter && !baseUrl) {
    try {
      const health = await adapter.healthCheck();
      return {
        providerId,
        reachable: health.status !== 'down',
        latencyMs: health.latencyMs,
        error: health.error,
      };
    } catch (err) {
      // A throwing adapter is reported as unreachable, matching healthCheckAll().
      return { providerId, reachable: false, error: describeError(err) };
    }
  }

  // Resolve the URL to probe: explicit override > ollama env.
  let resolvedUrl: string;
  if (baseUrl) {
    resolvedUrl = `${trimTrailingSlashes(baseUrl)}/models`;
  } else {
    const providerModels = this.registry.getAll().filter((m) => m.provider === providerId);
    if (providerModels.length === 0) {
      return { providerId, reachable: false, error: `Provider '${providerId}' not found` };
    }
    if (providerId !== 'ollama') {
      // For other providers only a registry-based check is possible.
      return { providerId, reachable: true, discoveredModels: providerModels.map((m) => m.id) };
    }
    const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
    if (!ollamaUrl) {
      return { providerId, reachable: false, error: 'OLLAMA_BASE_URL not configured' };
    }
    resolvedUrl = `${trimTrailingSlashes(ollamaUrl)}/v1/models`;
  }

  // NOTE(review): `baseUrl` is fetched as given. If it can come from untrusted
  // callers, consider restricting the allowed hosts/schemes (SSRF).
  const start = Date.now();
  try {
    const res = await fetch(resolvedUrl, {
      method: 'GET',
      headers: { Accept: 'application/json' },
      signal: AbortSignal.timeout(5000),
    });
    const latencyMs = Date.now() - start;
    if (!res.ok) {
      return { providerId, reachable: false, latencyMs, error: `HTTP ${res.status}` };
    }
    let discoveredModels: string[] | undefined;
    try {
      discoveredModels = extractModelIds(await res.json());
    } catch {
      // Body was not JSON — the endpoint still counts as reachable.
    }
    return { providerId, reachable: true, latencyMs, discoveredModels };
  } catch (err) {
    const latencyMs = Date.now() - start;
    return { providerId, reachable: false, latencyMs, error: describeError(err) };
  }
}
/**
 * Add a user-defined provider to the registry.
 * Each model is registered as text-only with zero cost; `reasoning` defaults
 * to false and `contextWindow` / `maxTokens` default to 4096 when omitted.
 */
registerCustomProvider(config: CustomProviderConfig): void {
  const { id: providerId, baseUrl, apiKey, models } = config;
  this.registry.registerProvider(providerId, {
    baseUrl,
    apiKey,
    models: models.map(({ id, name, reasoning, contextWindow, maxTokens }) => ({
      id,
      name,
      reasoning: reasoning ?? false,
      input: ['text'] as ('text' | 'image')[],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: contextWindow ?? 4096,
      maxTokens: maxTokens ?? 4096,
    })),
  });
  this.logger.log(`Registered custom provider: ${providerId} (${models.length} models)`);
}
// ---------------------------------------------------------------------------
// Private helpers — direct registry registration for providers without adapters yet
// (Z.ai will move to an adapter in M3-005)
// ---------------------------------------------------------------------------
/**
 * Register the Z.ai provider directly in the registry when ZAI_API_KEY is set.
 *
 * Registration is best-effort, like adapter registration: a built-in model
 * that cannot be resolved is logged and skipped instead of aborting module
 * initialisation. If none can be resolved the provider is not registered.
 */
private registerZaiProvider(): void {
  const apiKey = process.env['ZAI_API_KEY'];
  if (!apiKey) {
    this.logger.debug('Skipping Z.ai provider registration: ZAI_API_KEY not set');
    return;
  }

  const models: Model<Api>[] = [];
  for (const id of ['glm-4.5', 'glm-4.5-air', 'glm-4.5-flash']) {
    try {
      models.push(this.cloneBuiltInModel('zai', id));
    } catch (err) {
      const reason = err instanceof Error ? err.message : String(err);
      this.logger.warn(`Skipping Z.ai model "${id}": ${reason}`);
    }
  }
  if (models.length === 0) {
    this.logger.warn('Skipping Z.ai provider registration: no built-in Z.ai models resolved');
    return;
  }

  this.registry.registerProvider('zai', {
    apiKey,
    baseUrl: 'https://open.bigmodel.cn/api/paas/v4',
    models,
  });
  // Report the real count rather than a hard-coded number.
  this.logger.log(`Z.ai provider registered with ${models.length} models`);
}
/**
 * Register the custom providers listed in the MOSAIC_CUSTOM_PROVIDERS
 * environment variable, which must hold a JSON array of provider configs.
 *
 * Parse errors, a non-array payload, malformed entries and per-provider
 * registration failures are each logged separately; one bad entry does not
 * prevent the remaining entries from being registered.
 */
private registerCustomProviders(): void {
  const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
  if (!customJson) return;

  let parsed: unknown;
  try {
    parsed = JSON.parse(customJson);
  } catch (err) {
    this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
    return;
  }

  if (!Array.isArray(parsed)) {
    this.logger.error('MOSAIC_CUSTOM_PROVIDERS must be a JSON array of provider configs');
    return;
  }

  parsed.forEach((entry: unknown, index: number) => {
    // Only the fields registerCustomProvider() dereferences unconditionally
    // are checked here; the rest of the shape is trusted.
    const candidate = entry as { id?: unknown; models?: unknown } | null;
    if (
      typeof candidate !== 'object' ||
      candidate === null ||
      typeof candidate.id !== 'string' ||
      !Array.isArray(candidate.models)
    ) {
      this.logger.warn(
        `Skipping MOSAIC_CUSTOM_PROVIDERS entry ${index}: expected an object with a string "id" and a "models" array`,
      );
      return;
    }

    try {
      this.registerCustomProvider(entry as CustomProviderConfig);
    } catch (err) {
      this.logger.error(
        `Failed to register custom provider "${candidate.id}"`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  });
}
/**
 * Produce a shallow copy of a model returned by getModel(), with `overrides`
 * applied on top.
 * @throws Error when getModel() returns nothing for the provider / model id.
 */
private cloneBuiltInModel(
  provider: string,
  modelId: string,
  overrides: Partial<Model<Api>> = {},
): Model<Api> {
  const builtIn = getModel(provider as never, modelId as never) as Model<Api> | undefined;
  if (builtIn) {
    return { ...builtIn, ...overrides };
  }
  throw new Error(`Built-in model not found: ${provider}:${modelId}`);
}
/** Map a registry model onto the ModelInfo shape (`input` is exposed as `inputTypes`). */
private toModelInfo(model: Model<Api>): ModelInfo {
  const { id, provider, name, reasoning, contextWindow, maxTokens, input, cost } = model;
  return {
    id,
    provider,
    name,
    reasoning,
    contextWindow,
    maxTokens,
    inputTypes: input,
    cost,
  };
}
}