From 94d6624c016530229276153a2b27b86c45b63db4 Mon Sep 17 00:00:00 2001 From: Jason Woltje Date: Thu, 12 Mar 2026 22:10:18 -0500 Subject: [PATCH] =?UTF-8?q?feat:=20multi-provider=20support=20=E2=80=94=20?= =?UTF-8?q?Anthropic=20+=20Ollama=20(P2-002)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add ProviderService wrapping Pi SDK's ModelRegistry for multi-provider LLM support. Built-in providers (Anthropic, OpenAI, Google, xAI, etc.) auto-discovered; Ollama registered via OLLAMA_BASE_URL env var; custom providers via MOSAIC_CUSTOM_PROVIDERS JSON env var. - ProviderService: wraps ModelRegistry, manages provider lifecycle - ProvidersController: GET /api/providers, GET /api/providers/models - AgentService: accepts provider/model params on session creation - ChatGateway: passes optional provider/modelId from chat messages - @mosaic/types: new provider/model type definitions Closes #20 Co-Authored-By: Claude Opus 4.6 --- apps/gateway/package.json | 2 + apps/gateway/src/agent/agent.module.ts | 7 +- apps/gateway/src/agent/agent.service.ts | 63 +++++++- apps/gateway/src/agent/provider.service.ts | 139 ++++++++++++++++++ .../gateway/src/agent/providers.controller.ts | 19 +++ apps/gateway/src/chat/chat.gateway.ts | 7 +- packages/types/src/index.ts | 1 + packages/types/src/provider/index.ts | 54 +++++++ pnpm-lock.yaml | 6 + 9 files changed, 287 insertions(+), 11 deletions(-) create mode 100644 apps/gateway/src/agent/provider.service.ts create mode 100644 apps/gateway/src/agent/providers.controller.ts create mode 100644 packages/types/src/provider/index.ts diff --git a/apps/gateway/package.json b/apps/gateway/package.json index 3498fa4..88674c8 100644 --- a/apps/gateway/package.json +++ b/apps/gateway/package.json @@ -12,10 +12,12 @@ "test": "vitest run --passWithNoTests" }, "dependencies": { + "@mariozechner/pi-ai": "~0.57.1", "@mariozechner/pi-coding-agent": "~0.57.1", "@mosaic/auth": "workspace:^", "@mosaic/brain": "workspace:^", 
"@mosaic/db": "workspace:^",
+    "@mosaic/types": "workspace:^",
     "@nestjs/common": "^11.0.0",
     "@nestjs/core": "^11.0.0",
     "@nestjs/platform-fastify": "^11.0.0",
diff --git a/apps/gateway/src/agent/agent.module.ts b/apps/gateway/src/agent/agent.module.ts
index 4de50c4..a659ff4 100644
--- a/apps/gateway/src/agent/agent.module.ts
+++ b/apps/gateway/src/agent/agent.module.ts
@@ -1,9 +1,12 @@
 import { Global, Module } from '@nestjs/common';
 import { AgentService } from './agent.service.js';
+import { ProviderService } from './provider.service.js';
+import { ProvidersController } from './providers.controller.js';
 
 @Global()
 @Module({
-  providers: [AgentService],
-  exports: [AgentService],
+  providers: [ProviderService, AgentService],
+  controllers: [ProvidersController],
+  exports: [AgentService, ProviderService],
 })
 export class AgentModule {}
diff --git a/apps/gateway/src/agent/agent.service.ts b/apps/gateway/src/agent/agent.service.ts
index 9da4dc1..08d87fa 100644
--- a/apps/gateway/src/agent/agent.service.ts
+++ b/apps/gateway/src/agent/agent.service.ts
@@ -5,9 +5,17 @@ import {
   type AgentSession as PiAgentSession,
   type AgentSessionEvent,
 } from '@mariozechner/pi-coding-agent';
+import { ProviderService } from './provider.service.js';
+
+export interface AgentSessionOptions {
+  provider?: string;
+  modelId?: string;
+}
 
 export interface AgentSession {
   id: string;
+  provider: string;
+  modelId: string;
   piSession: PiAgentSession;
   listeners: Set<(event: AgentSessionEvent) => void>;
   unsubscribe: () => void;
@@ -19,33 +27,46 @@ export class AgentService implements OnModuleDestroy {
   private readonly sessions = new Map<string, AgentSession>();
   private readonly creating = new Map<string, Promise<AgentSession>>();
 
-  async createSession(sessionId: string): Promise<AgentSession> {
+  constructor(private readonly providerService: ProviderService) {}
+
+  async createSession(sessionId: string, options?: AgentSessionOptions): Promise<AgentSession> {
     const existing = this.sessions.get(sessionId);
     if (existing) return existing;
 
     const inflight = 
this.creating.get(sessionId);
     if (inflight) return inflight;
 
-    const promise = this.doCreateSession(sessionId).finally(() => {
+    const promise = this.doCreateSession(sessionId, options).finally(() => {
       this.creating.delete(sessionId);
     });
     this.creating.set(sessionId, promise);
     return promise;
   }
 
-  private async doCreateSession(sessionId: string): Promise<AgentSession> {
-    this.logger.log(`Creating agent session: ${sessionId}`);
+  private async doCreateSession(
+    sessionId: string,
+    options?: AgentSessionOptions,
+  ): Promise<AgentSession> {
+    const model = this.resolveModel(options);
+    const providerName = model?.provider ?? 'default';
+    const modelId = model?.id ?? 'default';
+
+    this.logger.log(
+      `Creating agent session: ${sessionId} (provider=${providerName}, model=${modelId})`,
+    );
 
     let piSession: PiAgentSession;
     try {
       const result = await createAgentSession({
         sessionManager: SessionManager.inMemory(),
+        modelRegistry: this.providerService.getRegistry(),
+        model: model ?? undefined,
         tools: [],
       });
       piSession = result.session;
     } catch (err) {
       this.logger.error(
-        `Failed to create Pi SDK session for ${sessionId}`,
+        `Failed to create agent session for ${sessionId}`,
         err instanceof Error ? err.stack : String(err),
       );
       throw new Error(`Agent session creation failed for ${sessionId}: ${String(err)}`);
@@ -65,17 +86,62 @@ export class AgentService implements OnModuleDestroy {
 
     const session: AgentSession = {
       id: sessionId,
+      provider: providerName,
+      modelId,
       piSession,
       listeners,
       unsubscribe,
     };
 
     this.sessions.set(sessionId, session);
-    this.logger.log(`Agent session ${sessionId} ready`);
+    this.logger.log(`Agent session ${sessionId} ready (${providerName}/${modelId})`);
 
     return session;
   }
 
+  /**
+   * Chooses the model for a new session from the caller's options.
+   *
+   * - neither field set: the provider service's default model, or `null`
+   *   when none is available;
+   * - both set: exact lookup, throwing when the pair is unknown;
+   * - only `modelId`: the first available model with that id;
+   * - only `provider`: the first available model of that provider.
+   *
+   * A partial request that matches nothing still falls back to the default
+   * model, but a warning is logged so the substitution is visible.
+   *
+   * @throws Error when an explicit provider/modelId pair is not found.
+   */
+  private resolveModel(options?: AgentSessionOptions) {
+    const provider = options?.provider;
+    const modelId = options?.modelId;
+
+    if (!provider && !modelId) {
+      return this.providerService.getDefaultModel() ?? null;
+    }
+
+    if (provider && modelId) {
+      const model = this.providerService.findModel(provider, modelId);
+      if (!model) {
+        throw new Error(`Model not found: ${provider}/${modelId}`);
+      }
+      return model;
+    }
+
+    // Exactly one of provider/modelId is set: match on whichever was given.
+    const match = this.providerService
+      .listAvailableModels()
+      .find((m) => (modelId ? m.id === modelId : m.provider === provider));
+    const resolved = match && this.providerService.findModel(match.provider, match.id);
+    if (resolved) return resolved;
+
+    this.logger.warn(
+      `No available model matches provider=${provider ?? '*'}, model=${modelId ?? '*'}; using default`,
+    );
+    return this.providerService.getDefaultModel() ?? null;
+  }
+
   getSession(sessionId: string): AgentSession | undefined {
     return this.sessions.get(sessionId);
   }
@@ -89,7 +155,7 @@ export class AgentService implements OnModuleDestroy {
       await session.piSession.prompt(message);
     } catch (err) {
       this.logger.error(
-        `Pi SDK prompt failed for session=${sessionId}, messageLength=${message.length}`,
+        `Prompt failed for session=${sessionId}, messageLength=${message.length}`,
         err instanceof Error ? 
err.stack : String(err),
       );
       throw err;
@@ -112,7 +159,7 @@ export class AgentService implements OnModuleDestroy {
     try {
       session.unsubscribe();
     } catch (err) {
-      this.logger.error(`Failed to unsubscribe Pi session ${sessionId}`, String(err));
+      this.logger.error(`Failed to unsubscribe session ${sessionId}`, String(err));
     }
     session.listeners.clear();
     this.sessions.delete(sessionId);
diff --git a/apps/gateway/src/agent/provider.service.ts b/apps/gateway/src/agent/provider.service.ts
new file mode 100644
index 0000000..64a4ff2
--- /dev/null
+++ b/apps/gateway/src/agent/provider.service.ts
@@ -0,0 +1,241 @@
+import { Injectable, Logger, type OnModuleInit } from '@nestjs/common';
+import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
+import type { Model, Api } from '@mariozechner/pi-ai';
+import type { ModelInfo, ProviderInfo, CustomProviderConfig } from '@mosaic/types';
+
+/** Model ids registered for Ollama when OLLAMA_MODELS is unset or empty. */
+const DEFAULT_OLLAMA_MODELS = 'llama3.2,codellama,mistral';
+
+/**
+ * Owns the Pi SDK `ModelRegistry` used by the gateway.
+ *
+ * On module init it creates the registry, registers Ollama when
+ * OLLAMA_BASE_URL / OLLAMA_HOST is set, and registers any providers listed as
+ * a JSON array in MOSAIC_CUSTOM_PROVIDERS.
+ */
+@Injectable()
+export class ProviderService implements OnModuleInit {
+  private readonly logger = new Logger(ProviderService.name);
+  private registry!: ModelRegistry;
+
+  /** Creates the registry and registers the environment-configured providers. */
+  async onModuleInit(): Promise<void> {
+    const authStorage = AuthStorage.create();
+    this.registry = new ModelRegistry(authStorage);
+
+    this.registerOllamaProvider();
+    this.registerCustomProviders();
+
+    const available = this.registry.getAvailable();
+    this.logger.log(`Providers initialized: ${available.length} models available`);
+  }
+
+  /** Returns the underlying registry; it is only assigned in `onModuleInit`. */
+  getRegistry(): ModelRegistry {
+    return this.registry;
+  }
+
+  /** Looks up a single model by provider and model id. */
+  findModel(provider: string, modelId: string): Model<Api> | undefined {
+    return this.registry.find(provider, modelId);
+  }
+
+  /** Returns the first model the registry reports as available, if any. */
+  getDefaultModel(): Model<Api> | undefined {
+    const available = this.registry.getAvailable();
+    return available[0];
+  }
+
+  /**
+   * Groups every registered model by provider. A provider is marked
+   * `available` when at least one of its models is in the available set.
+   */
+  listProviders(): ProviderInfo[] {
+    const allModels = this.registry.getAll();
+    const availableModels = this.registry.getAvailable();
+    const availableIds = new Set(availableModels.map((m) => `${m.provider}:${m.id}`));
+
+    const providerMap = new Map<string, ProviderInfo>();
+
+    for (const model of allModels) {
+      let info = providerMap.get(model.provider);
+      if (!info) {
+        info = {
+          id: model.provider,
+          name: model.provider,
+          available: false,
+          models: [],
+        };
+        providerMap.set(model.provider, info);
+      }
+
+      const isAvailable = availableIds.has(`${model.provider}:${model.id}`);
+      if (isAvailable) info.available = true;
+
+      info.models.push(this.toModelInfo(model));
+    }
+
+    return Array.from(providerMap.values());
+  }
+
+  /** Lists the models the registry currently reports as available. */
+  listAvailableModels(): ModelInfo[] {
+    return this.registry.getAvailable().map((m) => this.toModelInfo(m));
+  }
+
+  /**
+   * Registers a provider and its models with the registry. Models are
+   * registered as text-only with zero cost; `contextWindow` and `maxTokens`
+   * default to 4096 when omitted.
+   */
+  registerCustomProvider(config: CustomProviderConfig): void {
+    this.registry.registerProvider(config.id, {
+      baseUrl: config.baseUrl,
+      apiKey: config.apiKey,
+      // NOTE(review): the SDK appears to reject model definitions that name no
+      // `api`; custom endpoints here are OpenAI-compatible — confirm.
+      api: 'openai-completions',
+      models: config.models.map((m) => ({
+        id: m.id,
+        name: m.name,
+        reasoning: m.reasoning ?? false,
+        input: ['text'] as ('text' | 'image')[],
+        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+        contextWindow: m.contextWindow ?? 4096,
+        maxTokens: m.maxTokens ?? 4096,
+      })),
+    });
+
+    this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
+  }
+
+  /**
+   * Registers the `ollama` provider when OLLAMA_BASE_URL (or OLLAMA_HOST) is
+   * set. Model ids come from the comma-separated OLLAMA_MODELS.
+   */
+  private registerOllamaProvider(): void {
+    // `||` instead of `??`: an empty variable must not mask the fallback.
+    const ollamaUrl = process.env['OLLAMA_BASE_URL'] || process.env['OLLAMA_HOST'];
+    if (!ollamaUrl) return;
+
+    const modelsEnv = process.env['OLLAMA_MODELS'] || DEFAULT_OLLAMA_MODELS;
+    const modelIds = modelsEnv
+      .split(',')
+      .map((m) => m.trim())
+      .filter(Boolean);
+    if (modelIds.length === 0) {
+      this.logger.warn('OLLAMA_MODELS lists no model ids; Ollama provider not registered');
+      return;
+    }
+
+    const baseUrl = this.toOllamaBaseUrl(ollamaUrl);
+    try {
+      this.registerCustomProvider({
+        id: 'ollama',
+        name: 'Ollama',
+        baseUrl,
+        // Placeholder key. NOTE(review): the SDK appears to require a key when
+        // models are defined, and Ollama itself should ignore it — confirm.
+        apiKey: 'ollama',
+        models: modelIds.map((id) => ({
+          id,
+          name: id,
+          reasoning: false,
+          contextWindow: 8192,
+          maxTokens: 4096,
+        })),
+      });
+    } catch (err) {
+      // An optional provider must not stop the gateway from booting.
+      this.logger.error('Failed to register Ollama provider', String(err));
+      return;
+    }
+
+    this.logger.log(
+      `Ollama provider registered at ${baseUrl} with models: ${modelIds.join(', ')}`,
+    );
+  }
+
+  /**
+   * Turns an Ollama host value into the `/v1` endpoint URL: adds `http://`
+   * when no scheme is present, drops trailing slashes, and appends `/v1`
+   * unless the value already ends with it.
+   */
+  private toOllamaBaseUrl(raw: string): string {
+    const withScheme = /^https?:\/\//i.test(raw) ? raw : `http://${raw}`;
+    const trimmed = withScheme.replace(/\/+$/, '');
+    return trimmed.endsWith('/v1') ? trimmed : `${trimmed}/v1`;
+  }
+
+  /**
+   * Registers the providers listed in MOSAIC_CUSTOM_PROVIDERS, which must be a
+   * JSON array of `CustomProviderConfig`. Invalid entries are logged and
+   * skipped so that one bad entry does not drop the others.
+   */
+  private registerCustomProviders(): void {
+    const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
+    if (!customJson) return;
+
+    let parsed: unknown;
+    try {
+      parsed = JSON.parse(customJson);
+    } catch (err) {
+      this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
+      return;
+    }
+
+    if (!Array.isArray(parsed)) {
+      this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', 'expected a JSON array');
+      return;
+    }
+
+    for (const [index, entry] of (parsed as unknown[]).entries()) {
+      if (!this.isCustomProviderConfig(entry)) {
+        this.logger.error(`MOSAIC_CUSTOM_PROVIDERS[${index}] is not a valid provider config`);
+        continue;
+      }
+
+      try {
+        this.registerCustomProvider(entry);
+      } catch (err) {
+        this.logger.error(`Failed to register custom provider: ${entry.id}`, String(err));
+      }
+    }
+  }
+
+  /** Structural check for one entry parsed from MOSAIC_CUSTOM_PROVIDERS. */
+  private isCustomProviderConfig(value: unknown): value is CustomProviderConfig {
+    if (typeof value !== 'object' || value === null) return false;
+
+    const { id, baseUrl, models } = value as Record<string, unknown>;
+    return (
+      typeof id === 'string' &&
+      typeof baseUrl === 'string' &&
+      Array.isArray(models) &&
+      models.every((m: unknown) => this.hasStringId(m))
+    );
+  }
+
+  /** True when `value` is an object whose `id` is a string. */
+  private hasStringId(value: unknown): boolean {
+    return (
+      typeof value === 'object' &&
+      value !== null &&
+      typeof (value as Record<string, unknown>)['id'] === 'string'
+    );
+  }
+
+  /** Copies the fields of an SDK model into a `ModelInfo`. */
+  private toModelInfo(model: Model<Api>): ModelInfo {
+    return {
+      id: model.id,
+      provider: model.provider,
+      name: model.name,
+      reasoning: model.reasoning,
+      contextWindow: model.contextWindow,
+      maxTokens: model.maxTokens,
+      inputTypes: model.input,
+      cost: model.cost,
+    };
+  }
+}
diff --git a/apps/gateway/src/agent/providers.controller.ts b/apps/gateway/src/agent/providers.controller.ts
new file mode 100644
index 0000000..d74f33a
--- /dev/null
+++ b/apps/gateway/src/agent/providers.controller.ts
@@ -0,0 +1,23 @@
+import { Controller, Get, UseGuards } from '@nestjs/common';
+import type { ModelInfo, ProviderInfo } from '@mosaic/types';
+import { AuthGuard } from '../auth/auth.guard.js';
+import { ProviderService } from './provider.service.js';
+
+/** Read-only endpoints describing the configured LLM providers; guarded by AuthGuard. */
+@Controller('api/providers')
+@UseGuards(AuthGuard)
+export class ProvidersController {
+  constructor(private readonly providerService: ProviderService) {}
+
+  /** GET /api/providers — every registered provider with its models. */
+  @Get()
+  list(): ProviderInfo[] {
+    return this.providerService.listProviders();
+  }
+
+  /** GET /api/providers/models — only the models currently available. */
+  @Get('models')
+  listModels(): ModelInfo[] {
+    return this.providerService.listAvailableModels();
+  }
+}
diff --git a/apps/gateway/src/chat/chat.gateway.ts b/apps/gateway/src/chat/chat.gateway.ts
index 43a7a3c..34a61cb 100644
--- 
a/apps/gateway/src/chat/chat.gateway.ts
+++ b/apps/gateway/src/chat/chat.gateway.ts
@@ -17,6 +17,10 @@ import { v4 as uuid } from 'uuid';
 interface ChatMessage {
   conversationId?: string;
   content: string;
+  /** Only used when this message creates the agent session; ignored afterwards. */
+  provider?: string;
+  /** Only used when this message creates the agent session; ignored afterwards. */
+  modelId?: string;
 }
 
 @WebSocketGateway({
@@ -65,7 +69,10 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
     try {
       let agentSession = this.agentService.getSession(conversationId);
       if (!agentSession) {
-        agentSession = await this.agentService.createSession(conversationId);
+        agentSession = await this.agentService.createSession(conversationId, {
+          provider: data.provider,
+          modelId: data.modelId,
+        });
       }
     } catch (err) {
       this.logger.error(
diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts
index eb36664..3d52f83 100644
--- a/packages/types/src/index.ts
+++ b/packages/types/src/index.ts
@@ -2,3 +2,4 @@ export const VERSION = '0.0.0';
 
 export * from './chat/index.js';
 export * from './agent/index.js';
+export * from './provider/index.js';
diff --git a/packages/types/src/provider/index.ts b/packages/types/src/provider/index.ts
new file mode 100644
index 0000000..1305391
--- /dev/null
+++ b/packages/types/src/provider/index.ts
@@ -0,0 +1,64 @@
+/** Known built-in LLM provider identifiers */
+export type KnownProvider =
+  | 'anthropic'
+  | 'openai'
+  | 'google'
+  | 'ollama'
+  | 'xai'
+  | 'groq'
+  | 'openrouter'
+  | 'zai'
+  | 'mistral';
+
+/**
+ * Provider identifier — known providers or custom string.
+ *
+ * `string & {}` still accepts any string, but keeps the compiler from
+ * collapsing the union to plain `string`, so the known ids stay visible to
+ * editor completion.
+ */
+export type ProviderId = KnownProvider | (string & {});
+
+/** Describes an available LLM model */
+export interface ModelInfo {
+  id: string;
+  provider: ProviderId;
+  name: string;
+  reasoning: boolean;
+  contextWindow: number;
+  maxTokens: number;
+  /** Input kinds the model accepts */
+  inputTypes: ('text' | 'image')[];
+  /** Pricing figures copied from the model registry — units not defined here (TODO confirm) */
+  cost: {
+    input: number;
+    output: number;
+    cacheRead: number;
+    cacheWrite: number;
+  };
+}
+
+/** Describes an available provider */
+export interface ProviderInfo {
+  id: ProviderId;
+  name: string;
+  /** True when at least one of the provider's models is currently available */
+  available: boolean;
+  models: ModelInfo[];
+}
+
+/** Configuration for a custom (non-built-in) provider */
+export interface CustomProviderConfig {
+  id: string;
+  name: string;
+  baseUrl: string;
+  apiKey?: string;
+  models: Array<{
+    id: string;
+    name: string;
+    /** Optional fields: the gateway substitutes defaults when they are omitted */
+    reasoning?: boolean;
+    contextWindow?: number;
+    maxTokens?: number;
+  }>;
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 8e3b27c..ace0241 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -41,6 +41,9 @@ importers:
 
   apps/gateway:
     dependencies:
+      '@mariozechner/pi-ai':
+        specifier: ~0.57.1
+        version: 0.57.1(ws@8.19.0)(zod@4.3.6)
       '@mariozechner/pi-coding-agent':
         specifier: ~0.57.1
         version: 0.57.1(ws@8.19.0)(zod@4.3.6)
@@ -53,6 +56,9 @@ importers:
       '@mosaic/db':
         specifier: workspace:^
         version: link:../../packages/db
+      '@mosaic/types':
+        specifier: workspace:^
+        version: link:../../packages/types
       '@nestjs/common':
         specifier: ^11.0.0
         version: 11.1.16(class-transformer@0.5.1)(class-validator@0.15.1)(reflect-metadata@0.2.2)(rxjs@7.8.2)
-- 
2.49.1