- Updated all package.json name fields and dependency references - Updated all TypeScript/JavaScript imports - Updated .woodpecker/publish.yml filters and registry paths - Updated tools/install.sh scope default - Updated .npmrc registry paths (worktree + host) - Enhanced update-checker.ts with checkForAllUpdates() multi-package support - Updated CLI update command to show table of all packages - Added KNOWN_PACKAGES, formatAllPackagesTable, getInstallAllCommand - Marked checkForUpdate() with @deprecated JSDoc Closes #391
431 lines
14 KiB
TypeScript
431 lines
14 KiB
TypeScript
import {
|
|
Inject,
|
|
Injectable,
|
|
Logger,
|
|
Optional,
|
|
type OnModuleDestroy,
|
|
type OnModuleInit,
|
|
} from '@nestjs/common';
|
|
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
|
|
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
|
|
import type {
|
|
CustomProviderConfig,
|
|
IProviderAdapter,
|
|
ModelInfo,
|
|
ProviderHealth,
|
|
ProviderInfo,
|
|
} from '@mosaicstack/types';
|
|
import {
|
|
AnthropicAdapter,
|
|
OllamaAdapter,
|
|
OpenAIAdapter,
|
|
OpenRouterAdapter,
|
|
ZaiAdapter,
|
|
} from './adapters/index.js';
|
|
import type { TestConnectionResultDto } from './provider.dto.js';
|
|
import { ProviderCredentialsService } from './provider-credentials.service.js';
|
|
|
|
/**
 * Fallback period, in seconds, between provider health-check sweeps when
 * PROVIDER_HEALTH_INTERVAL does not supply a usable value.
 */
const DEFAULT_HEALTH_INTERVAL_SECS = 60;

/**
 * Dependency-injection token under which an array of provider adapters can be
 * registered.
 */
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');

/**
 * Well-known provider id → name of the environment variable holding its API key.
 * Consulted when no user-scoped credential is available for a provider.
 */
const PROVIDER_ENV_KEYS: Record<string, string> = Object.fromEntries([
  ['anthropic', 'ANTHROPIC_API_KEY'],
  ['openai', 'OPENAI_API_KEY'],
  ['openrouter', 'OPENROUTER_API_KEY'],
  ['zai', 'ZAI_API_KEY'],
]);
|
|
|
|
@Injectable()
export class ProviderService implements OnModuleInit, OnModuleDestroy {
  /**
   * Largest health-check interval (in seconds) that Node timers can honour.
   * Node replaces delays above 2^31-1 ms with 1 ms, which would turn the
   * scheduler into a busy loop.
   */
  private static readonly MAX_HEALTH_INTERVAL_SECS = Math.floor(2147483647 / 1000);

  private readonly logger = new Logger(ProviderService.name);

  /** Pi SDK model registry; created in onModuleInit(). */
  private registry!: ModelRegistry;

  constructor(
    @Optional()
    @Inject(ProviderCredentialsService)
    private readonly credentialsService: ProviderCredentialsService | null,
  ) {}

  /**
   * Adapters registered with this service.
   * onModuleInit() populates this with the built-in adapters (Ollama, Anthropic,
   * OpenAI, OpenRouter, Z.ai). Additional adapters may be supplied via the
   * PROVIDER_ADAPTERS injection token in the future (not wired up yet).
   */
  private adapters: IProviderAdapter[] = [];

  /**
   * Cached health status per provider, updated by the health check scheduler.
   */
  private healthCache: Map<string, ProviderHealth & { modelCount: number }> = new Map();

  /** Timer handle for the periodic health check scheduler */
  private healthCheckTimer: ReturnType<typeof setInterval> | null = null;

  /**
   * True while a scheduled health-check sweep is running. A timer tick that
   * fires during a sweep is skipped rather than starting an overlapping sweep.
   */
  private healthCheckInFlight = false;

  /**
   * Create the in-memory model registry, register all adapters and any custom
   * providers from MOSAIC_CUSTOM_PROVIDERS, then start the health check scheduler.
   */
  async onModuleInit(): Promise<void> {
    const authStorage = AuthStorage.inMemory();
    this.registry = ModelRegistry.inMemory(authStorage);

    // Build the default set of adapters that rely on the registry
    this.adapters = [
      new OllamaAdapter(this.registry),
      new AnthropicAdapter(this.registry),
      new OpenAIAdapter(this.registry),
      new OpenRouterAdapter(),
      new ZaiAdapter(),
    ];

    // Run all adapter registrations first (Ollama, Anthropic, OpenAI, OpenRouter, Z.ai)
    await this.registerAll();

    // Register API-key providers directly (custom)
    this.registerCustomProviders();

    const available = this.registry.getAvailable();
    this.logger.log(`Providers initialized: ${available.length} models available`);

    // Kick off the health check scheduler
    this.startHealthCheckScheduler();
  }

  /** Stop the periodic health check timer, if one is running. */
  onModuleDestroy(): void {
    if (this.healthCheckTimer !== null) {
      clearInterval(this.healthCheckTimer);
      this.healthCheckTimer = null;
    }
  }

  // ---------------------------------------------------------------------------
  // Health check scheduler
  // ---------------------------------------------------------------------------

  /**
   * Start periodic health checks on all adapters.
   * Interval is configurable via PROVIDER_HEALTH_INTERVAL env (seconds, default 60).
   * An initial sweep is started immediately without awaiting it.
   */
  private startHealthCheckScheduler(): void {
    const intervalSecs = this.resolveHealthIntervalSecs();
    const intervalMs = intervalSecs * 1000;

    // Avoid leaking a previous timer if the scheduler is started more than once.
    if (this.healthCheckTimer !== null) {
      clearInterval(this.healthCheckTimer);
      this.healthCheckTimer = null;
    }

    // Run an initial check immediately (non-blocking)
    void this.runScheduledHealthChecks();

    this.healthCheckTimer = setInterval(() => {
      void this.runScheduledHealthChecks();
    }, intervalMs);

    this.logger.log(`Provider health check scheduler started (interval: ${intervalSecs}s)`);
  }

  /**
   * Read PROVIDER_HEALTH_INTERVAL (seconds).
   *
   * @returns The configured interval when it parses to a positive integer,
   *   clamped to what Node timers can represent. Otherwise
   *   DEFAULT_HEALTH_INTERVAL_SECS (missing, empty, non-numeric, zero or
   *   negative values).
   */
  private resolveHealthIntervalSecs(): number {
    const raw = process.env['PROVIDER_HEALTH_INTERVAL'];
    if (raw === undefined || raw.trim() === '') {
      return DEFAULT_HEALTH_INTERVAL_SECS;
    }

    const parsed = parseInt(raw, 10);
    if (!Number.isFinite(parsed) || parsed <= 0) {
      // A negative/zero delay would be coerced to 1 ms by Node: a busy loop.
      this.logger.warn(
        `Invalid PROVIDER_HEALTH_INTERVAL "${raw}"; using default ${DEFAULT_HEALTH_INTERVAL_SECS}s`,
      );
      return DEFAULT_HEALTH_INTERVAL_SECS;
    }

    if (parsed > ProviderService.MAX_HEALTH_INTERVAL_SECS) {
      this.logger.warn(
        `PROVIDER_HEALTH_INTERVAL "${raw}" exceeds the timer limit; clamping to ${ProviderService.MAX_HEALTH_INTERVAL_SECS}s`,
      );
      return ProviderService.MAX_HEALTH_INTERVAL_SECS;
    }

    return parsed;
  }

  /**
   * Check every adapter in sequence and refresh healthCache.
   * Never rejects: adapter failures are recorded as status 'down'. The method
   * is invoked fire-and-forget, so a rejection would be an unhandled rejection.
   */
  private async runScheduledHealthChecks(): Promise<void> {
    if (this.healthCheckInFlight) {
      this.logger.debug('Skipping provider health check: previous run still in progress');
      return;
    }

    this.healthCheckInFlight = true;
    try {
      for (const adapter of this.adapters) {
        try {
          const health = await adapter.healthCheck();
          this.healthCache.set(adapter.name, { ...health, modelCount: this.safeModelCount(adapter) });
          this.logger.debug(
            `Health check [${adapter.name}]: ${health.status} (${health.latencyMs ?? 'n/a'}ms)`,
          );
        } catch (err) {
          this.healthCache.set(adapter.name, {
            status: 'down',
            lastChecked: new Date().toISOString(),
            error: err instanceof Error ? err.message : String(err),
            modelCount: this.safeModelCount(adapter),
          });
        }
      }
    } finally {
      this.healthCheckInFlight = false;
    }
  }

  /**
   * Number of models the adapter reports.
   * Returns 0 (and logs at debug level) if listModels() throws.
   */
  private safeModelCount(adapter: IProviderAdapter): number {
    try {
      return adapter.listModels().length;
    } catch (err) {
      this.logger.debug(
        `listModels() failed for adapter "${adapter.name}": ${err instanceof Error ? err.message : String(err)}`,
      );
      return 0;
    }
  }

  /**
   * Return the cached health status for all adapters.
   * Format: array of { name, status, latencyMs, lastChecked, modelCount, error }
   * Adapters not yet checked are reported with status 'unknown'.
   */
  getProvidersHealth(): Array<{
    name: string;
    status: string;
    latencyMs?: number;
    lastChecked: string;
    modelCount: number;
    error?: string;
  }> {
    return this.adapters.map((adapter) => {
      const cached = this.healthCache.get(adapter.name);
      if (cached) {
        return {
          name: adapter.name,
          status: cached.status,
          latencyMs: cached.latencyMs,
          lastChecked: cached.lastChecked,
          modelCount: cached.modelCount,
          error: cached.error,
        };
      }

      // Not yet checked — return a pending placeholder
      return {
        name: adapter.name,
        status: 'unknown',
        lastChecked: new Date().toISOString(),
        modelCount: this.safeModelCount(adapter),
      };
    });
  }

  // ---------------------------------------------------------------------------
  // Adapter-pattern API
  // ---------------------------------------------------------------------------

  /**
   * Call register() on each adapter in order.
   * Errors from individual adapters are logged and do not abort the others.
   */
  async registerAll(): Promise<void> {
    for (const adapter of this.adapters) {
      try {
        await adapter.register();
      } catch (err) {
        this.logger.error(
          `Adapter "${adapter.name}" registration failed`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }
  }

  /**
   * Return the adapter registered under the given provider name, or undefined.
   */
  getAdapter(providerName: string): IProviderAdapter | undefined {
    return this.adapters.find((a) => a.name === providerName);
  }

  /**
   * Run healthCheck() on all adapters concurrently and return results keyed by
   * provider name. An adapter that throws is reported as status 'down'.
   */
  async healthCheckAll(): Promise<Record<string, ProviderHealth>> {
    const results: Record<string, ProviderHealth> = {};
    await Promise.all(
      this.adapters.map(async (adapter) => {
        try {
          results[adapter.name] = await adapter.healthCheck();
        } catch (err) {
          results[adapter.name] = {
            status: 'down',
            lastChecked: new Date().toISOString(),
            error: err instanceof Error ? err.message : String(err),
          };
        }
      }),
    );
    return results;
  }

  // ---------------------------------------------------------------------------
  // Legacy / Pi-SDK-facing API (preserved for AgentService and RoutingService)
  // ---------------------------------------------------------------------------

  /** The underlying Pi SDK model registry. */
  getRegistry(): ModelRegistry {
    return this.registry;
  }

  /** Look up a model by provider and model id in the registry. */
  findModel(provider: string, modelId: string): Model<Api> | undefined {
    return this.registry.find(provider, modelId);
  }

  /** First model reported as available by the registry, if any. */
  getDefaultModel(): Model<Api> | undefined {
    const available = this.registry.getAvailable();
    return available[0];
  }

  /**
   * Group every registered model by provider.
   * A provider is marked available if at least one of its models is available.
   */
  listProviders(): ProviderInfo[] {
    const allModels = this.registry.getAll();
    const availableModels = this.registry.getAvailable();
    const availableIds = new Set(availableModels.map((m) => `${m.provider}:${m.id}`));

    const providerMap = new Map<string, ProviderInfo>();

    for (const model of allModels) {
      let info = providerMap.get(model.provider);
      if (!info) {
        info = {
          id: model.provider,
          name: model.provider,
          available: false,
          models: [],
        };
        providerMap.set(model.provider, info);
      }

      const isAvailable = availableIds.has(`${model.provider}:${model.id}`);
      if (isAvailable) info.available = true;

      info.models.push(this.toModelInfo(model));
    }

    return Array.from(providerMap.values());
  }

  /** All models the registry reports as available, as ModelInfo DTOs. */
  listAvailableModels(): ModelInfo[] {
    return this.registry.getAvailable().map((m) => this.toModelInfo(m));
  }

  /**
   * Test whether a provider is reachable.
   *
   * - With a registered adapter and no `baseUrl`, delegates to the adapter's
   *   healthCheck(). If it throws, the provider is reported unreachable.
   * - With `baseUrl`, GETs `<baseUrl>/models`.
   * - Otherwise, for 'ollama' GETs `<OLLAMA_BASE_URL|OLLAMA_HOST>/v1/models`; for
   *   any other registered provider only reports its registered model ids.
   *
   * Never throws; failures are returned as `reachable: false` with `error`.
   */
  async testConnection(providerId: string, baseUrl?: string): Promise<TestConnectionResultDto> {
    // Delegate to the adapter when one exists and no URL override is given
    const adapter = this.getAdapter(providerId);
    if (adapter && !baseUrl) {
      try {
        const health = await adapter.healthCheck();
        return {
          providerId,
          reachable: health.status !== 'down',
          latencyMs: health.latencyMs,
          error: health.error,
        };
      } catch (err) {
        // Same treatment as healthCheckAll(): a throwing adapter counts as unreachable.
        return {
          providerId,
          reachable: false,
          error: err instanceof Error ? err.message : String(err),
        };
      }
    }

    // Resolve the models endpoint: explicit override > registered provider > ollama env
    let resolvedUrl: string;
    if (baseUrl) {
      resolvedUrl = `${ProviderService.stripTrailingSlashes(baseUrl)}/models`;
    } else {
      const providerModels = this.registry.getAll().filter((m) => m.provider === providerId);
      if (providerModels.length === 0) {
        return { providerId, reachable: false, error: `Provider '${providerId}' not found` };
      }

      if (providerId !== 'ollama') {
        // For other providers, we can only do a basic check
        return { providerId, reachable: true, discoveredModels: providerModels.map((m) => m.id) };
      }

      // For Ollama, derive the base URL from environment
      const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
      if (!ollamaUrl) {
        return { providerId, reachable: false, error: 'OLLAMA_BASE_URL not configured' };
      }
      resolvedUrl = `${ProviderService.stripTrailingSlashes(ollamaUrl)}/v1/models`;
    }

    const start = Date.now();
    try {
      const res = await fetch(resolvedUrl, {
        method: 'GET',
        headers: { Accept: 'application/json' },
        signal: AbortSignal.timeout(5000),
      });

      const latencyMs = Date.now() - start;

      if (!res.ok) {
        return { providerId, reachable: false, latencyMs, error: `HTTP ${res.status}` };
      }

      let discoveredModels: string[] | undefined;
      try {
        discoveredModels = ProviderService.extractModelIds(await res.json());
      } catch {
        // ignore parse errors — endpoint was reachable
      }

      return { providerId, reachable: true, latencyMs, discoveredModels };
    } catch (err) {
      const latencyMs = Date.now() - start;
      const message = err instanceof Error ? err.message : String(err);
      return { providerId, reachable: false, latencyMs, error: message };
    }
  }

  /**
   * Register an API-key based provider and its models in the registry.
   * Missing model limits default to 4096 tokens, input is text-only and cost is zero.
   */
  registerCustomProvider(config: CustomProviderConfig): void {
    this.registry.registerProvider(config.id, {
      baseUrl: config.baseUrl,
      apiKey: config.apiKey,
      models: config.models.map((m) => ({
        id: m.id,
        name: m.name,
        reasoning: m.reasoning ?? false,
        input: ['text'] as ('text' | 'image')[],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: m.contextWindow ?? 4096,
        maxTokens: m.maxTokens ?? 4096,
      })),
    });

    this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
  }

  // ---------------------------------------------------------------------------
  // Private helpers
  // ---------------------------------------------------------------------------

  /**
   * Register custom providers from MOSAIC_CUSTOM_PROVIDERS, which must hold a
   * JSON array of CustomProviderConfig objects.
   * Invalid JSON or a non-array value is logged and ignored. Each entry is
   * registered independently, so one malformed entry does not prevent the
   * others from loading.
   */
  private registerCustomProviders(): void {
    const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
    if (!customJson) return;

    let parsed: unknown;
    try {
      parsed = JSON.parse(customJson);
    } catch (err) {
      this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
      return;
    }

    if (!Array.isArray(parsed)) {
      this.logger.error(
        'Failed to parse MOSAIC_CUSTOM_PROVIDERS',
        'Expected a JSON array of provider configs',
      );
      return;
    }

    // NOTE(review): entries are not schema-validated; a malformed entry is
    // expected to fail inside registerCustomProvider and is logged below.
    for (const entry of parsed as unknown[]) {
      try {
        this.registerCustomProvider(entry as CustomProviderConfig);
      } catch (err) {
        const id =
          typeof entry === 'object' && entry !== null && 'id' in entry
            ? String((entry as { id: unknown }).id)
            : '<unknown>';
        this.logger.error(
          `Failed to register custom provider "${id}" from MOSAIC_CUSTOM_PROVIDERS`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }
  }

  /**
   * Resolve an API key for a provider, scoped to a specific user.
   * User-stored credentials take precedence over environment variables.
   * Returns null if no key is available from either source (an empty
   * environment variable counts as "not available").
   */
  async resolveApiKey(userId: string, provider: string): Promise<string | null> {
    if (this.credentialsService) {
      const userKey = await this.credentialsService.retrieve(userId, provider);
      if (userKey) {
        this.logger.debug(`Using user-scoped credential for user=${userId} provider=${provider}`);
        return userKey;
      }
    }

    // Fall back to environment variable. Own-property check so names such as
    // 'constructor' do not resolve to inherited Object.prototype members.
    const envVar = Object.prototype.hasOwnProperty.call(PROVIDER_ENV_KEYS, provider)
      ? PROVIDER_ENV_KEYS[provider]
      : undefined;
    const envKey = envVar ? process.env[envVar] : undefined;
    if (!envKey) {
      return null;
    }

    this.logger.debug(`Using env-var credential for provider=${provider}`);
    return envKey;
  }

  /**
   * Copy a built-in pi-ai model definition, applying the given overrides.
   * @throws Error if getModel() yields no model for provider/modelId.
   */
  private cloneBuiltInModel(
    provider: string,
    modelId: string,
    overrides: Partial<Model<Api>> = {},
  ): Model<Api> {
    const model = getModel(provider as never, modelId as never) as Model<Api> | undefined;
    if (!model) {
      throw new Error(`Built-in model not found: ${provider}:${modelId}`);
    }

    return { ...model, ...overrides };
  }

  /** Map a Pi SDK model to the public ModelInfo DTO. */
  private toModelInfo(model: Model<Api>): ModelInfo {
    return {
      id: model.id,
      provider: model.provider,
      name: model.name,
      reasoning: model.reasoning,
      contextWindow: model.contextWindow,
      maxTokens: model.maxTokens,
      inputTypes: model.input,
      cost: model.cost,
    };
  }

  /**
   * Remove any trailing '/' characters from a URL.
   * Uses a loop rather than a regex to stay linear on pathological input.
   */
  private static stripTrailingSlashes(url: string): string {
    let end = url.length;
    while (end > 0 && url.charCodeAt(end - 1) === 47 /* '/' */) {
      end--;
    }
    return url.slice(0, end);
  }

  /**
   * Extract model identifiers from a model-listing response body.
   * Accepts a `{ models: [...] }` body and the OpenAI-compatible
   * `{ data: [...] }` body (the shape returned by `/models` and Ollama's
   * `/v1/models`). For each entry, `id` is preferred over `name`; empty or
   * missing identifiers are dropped.
   *
   * @returns The identifiers, or undefined when neither list is present.
   */
  private static extractModelIds(body: unknown): string[] | undefined {
    if (typeof body !== 'object' || body === null) return undefined;

    const { models, data } = body as { models?: unknown; data?: unknown };
    const entries: unknown[] | undefined = Array.isArray(models)
      ? models
      : Array.isArray(data)
        ? data
        : undefined;
    if (!entries) return undefined;

    const ids: string[] = [];
    for (const entry of entries) {
      if (typeof entry !== 'object' || entry === null) continue;
      const { id, name } = entry as { id?: unknown; name?: unknown };
      const candidate = typeof id === 'string' ? id : typeof name === 'string' ? name : '';
      if (candidate) ids.push(candidate);
    }
    return ids;
  }
}
|