feat: multi-provider support — Anthropic + Ollama (P2-002) (#74)
Co-authored-by: Jason Woltje <jason@diversecanvas.com> Co-committed-by: Jason Woltje <jason@diversecanvas.com>
This commit was merged in pull request #74.
This commit is contained in:
@@ -12,10 +12,12 @@
|
|||||||
"test": "vitest run --passWithNoTests"
|
"test": "vitest run --passWithNoTests"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@mariozechner/pi-ai": "~0.57.1",
|
||||||
"@mariozechner/pi-coding-agent": "~0.57.1",
|
"@mariozechner/pi-coding-agent": "~0.57.1",
|
||||||
"@mosaic/auth": "workspace:^",
|
"@mosaic/auth": "workspace:^",
|
||||||
"@mosaic/brain": "workspace:^",
|
"@mosaic/brain": "workspace:^",
|
||||||
"@mosaic/db": "workspace:^",
|
"@mosaic/db": "workspace:^",
|
||||||
|
"@mosaic/types": "workspace:^",
|
||||||
"@nestjs/common": "^11.0.0",
|
"@nestjs/common": "^11.0.0",
|
||||||
"@nestjs/core": "^11.0.0",
|
"@nestjs/core": "^11.0.0",
|
||||||
"@nestjs/platform-fastify": "^11.0.0",
|
"@nestjs/platform-fastify": "^11.0.0",
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import { Global, Module } from '@nestjs/common';
|
import { Global, Module } from '@nestjs/common';
|
||||||
import { AgentService } from './agent.service.js';
|
import { AgentService } from './agent.service.js';
|
||||||
|
import { ProviderService } from './provider.service.js';
|
||||||
|
import { ProvidersController } from './providers.controller.js';
|
||||||
|
|
||||||
@Global()
|
@Global()
|
||||||
@Module({
|
@Module({
|
||||||
providers: [AgentService],
|
providers: [ProviderService, AgentService],
|
||||||
exports: [AgentService],
|
controllers: [ProvidersController],
|
||||||
|
exports: [AgentService, ProviderService],
|
||||||
})
|
})
|
||||||
export class AgentModule {}
|
export class AgentModule {}
|
||||||
|
|||||||
@@ -5,9 +5,17 @@ import {
|
|||||||
type AgentSession as PiAgentSession,
|
type AgentSession as PiAgentSession,
|
||||||
type AgentSessionEvent,
|
type AgentSessionEvent,
|
||||||
} from '@mariozechner/pi-coding-agent';
|
} from '@mariozechner/pi-coding-agent';
|
||||||
|
import { ProviderService } from './provider.service.js';
|
||||||
|
|
||||||
|
export interface AgentSessionOptions {
|
||||||
|
provider?: string;
|
||||||
|
modelId?: string;
|
||||||
|
}
|
||||||
|
|
||||||
export interface AgentSession {
|
export interface AgentSession {
|
||||||
id: string;
|
id: string;
|
||||||
|
provider: string;
|
||||||
|
modelId: string;
|
||||||
piSession: PiAgentSession;
|
piSession: PiAgentSession;
|
||||||
listeners: Set<(event: AgentSessionEvent) => void>;
|
listeners: Set<(event: AgentSessionEvent) => void>;
|
||||||
unsubscribe: () => void;
|
unsubscribe: () => void;
|
||||||
@@ -19,33 +27,46 @@ export class AgentService implements OnModuleDestroy {
|
|||||||
private readonly sessions = new Map<string, AgentSession>();
|
private readonly sessions = new Map<string, AgentSession>();
|
||||||
private readonly creating = new Map<string, Promise<AgentSession>>();
|
private readonly creating = new Map<string, Promise<AgentSession>>();
|
||||||
|
|
||||||
async createSession(sessionId: string): Promise<AgentSession> {
|
constructor(private readonly providerService: ProviderService) {}
|
||||||
|
|
||||||
|
async createSession(sessionId: string, options?: AgentSessionOptions): Promise<AgentSession> {
|
||||||
const existing = this.sessions.get(sessionId);
|
const existing = this.sessions.get(sessionId);
|
||||||
if (existing) return existing;
|
if (existing) return existing;
|
||||||
|
|
||||||
const inflight = this.creating.get(sessionId);
|
const inflight = this.creating.get(sessionId);
|
||||||
if (inflight) return inflight;
|
if (inflight) return inflight;
|
||||||
|
|
||||||
const promise = this.doCreateSession(sessionId).finally(() => {
|
const promise = this.doCreateSession(sessionId, options).finally(() => {
|
||||||
this.creating.delete(sessionId);
|
this.creating.delete(sessionId);
|
||||||
});
|
});
|
||||||
this.creating.set(sessionId, promise);
|
this.creating.set(sessionId, promise);
|
||||||
return promise;
|
return promise;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async doCreateSession(sessionId: string): Promise<AgentSession> {
|
private async doCreateSession(
|
||||||
this.logger.log(`Creating agent session: ${sessionId}`);
|
sessionId: string,
|
||||||
|
options?: AgentSessionOptions,
|
||||||
|
): Promise<AgentSession> {
|
||||||
|
const model = this.resolveModel(options);
|
||||||
|
const providerName = model?.provider ?? 'default';
|
||||||
|
const modelId = model?.id ?? 'default';
|
||||||
|
|
||||||
|
this.logger.log(
|
||||||
|
`Creating agent session: ${sessionId} (provider=${providerName}, model=${modelId})`,
|
||||||
|
);
|
||||||
|
|
||||||
let piSession: PiAgentSession;
|
let piSession: PiAgentSession;
|
||||||
try {
|
try {
|
||||||
const result = await createAgentSession({
|
const result = await createAgentSession({
|
||||||
sessionManager: SessionManager.inMemory(),
|
sessionManager: SessionManager.inMemory(),
|
||||||
|
modelRegistry: this.providerService.getRegistry(),
|
||||||
|
model: model ?? undefined,
|
||||||
tools: [],
|
tools: [],
|
||||||
});
|
});
|
||||||
piSession = result.session;
|
piSession = result.session;
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.logger.error(
|
this.logger.error(
|
||||||
`Failed to create Pi SDK session for ${sessionId}`,
|
`Failed to create agent session for ${sessionId}`,
|
||||||
err instanceof Error ? err.stack : String(err),
|
err instanceof Error ? err.stack : String(err),
|
||||||
);
|
);
|
||||||
throw new Error(`Agent session creation failed for ${sessionId}: ${String(err)}`);
|
throw new Error(`Agent session creation failed for ${sessionId}: ${String(err)}`);
|
||||||
@@ -65,17 +86,43 @@ export class AgentService implements OnModuleDestroy {
|
|||||||
|
|
||||||
const session: AgentSession = {
|
const session: AgentSession = {
|
||||||
id: sessionId,
|
id: sessionId,
|
||||||
|
provider: providerName,
|
||||||
|
modelId,
|
||||||
piSession,
|
piSession,
|
||||||
listeners,
|
listeners,
|
||||||
unsubscribe,
|
unsubscribe,
|
||||||
};
|
};
|
||||||
|
|
||||||
this.sessions.set(sessionId, session);
|
this.sessions.set(sessionId, session);
|
||||||
this.logger.log(`Agent session ${sessionId} ready`);
|
this.logger.log(`Agent session ${sessionId} ready (${providerName}/${modelId})`);
|
||||||
|
|
||||||
return session;
|
return session;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private resolveModel(options?: AgentSessionOptions) {
|
||||||
|
if (!options?.provider && !options?.modelId) {
|
||||||
|
return this.providerService.getDefaultModel() ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.provider && options.modelId) {
|
||||||
|
const model = this.providerService.findModel(options.provider, options.modelId);
|
||||||
|
if (!model) {
|
||||||
|
throw new Error(`Model not found: ${options.provider}/${options.modelId}`);
|
||||||
|
}
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.modelId) {
|
||||||
|
const available = this.providerService.listAvailableModels();
|
||||||
|
const match = available.find((m) => m.id === options.modelId);
|
||||||
|
if (match) {
|
||||||
|
return this.providerService.findModel(match.provider, match.id) ?? null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.providerService.getDefaultModel() ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
getSession(sessionId: string): AgentSession | undefined {
|
getSession(sessionId: string): AgentSession | undefined {
|
||||||
return this.sessions.get(sessionId);
|
return this.sessions.get(sessionId);
|
||||||
}
|
}
|
||||||
@@ -89,7 +136,7 @@ export class AgentService implements OnModuleDestroy {
|
|||||||
await session.piSession.prompt(message);
|
await session.piSession.prompt(message);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.logger.error(
|
this.logger.error(
|
||||||
`Pi SDK prompt failed for session=${sessionId}, messageLength=${message.length}`,
|
`Prompt failed for session=${sessionId}, messageLength=${message.length}`,
|
||||||
err instanceof Error ? err.stack : String(err),
|
err instanceof Error ? err.stack : String(err),
|
||||||
);
|
);
|
||||||
throw err;
|
throw err;
|
||||||
@@ -112,7 +159,7 @@ export class AgentService implements OnModuleDestroy {
|
|||||||
try {
|
try {
|
||||||
session.unsubscribe();
|
session.unsubscribe();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.logger.error(`Failed to unsubscribe Pi session ${sessionId}`, String(err));
|
this.logger.error(`Failed to unsubscribe session ${sessionId}`, String(err));
|
||||||
}
|
}
|
||||||
session.listeners.clear();
|
session.listeners.clear();
|
||||||
this.sessions.delete(sessionId);
|
this.sessions.delete(sessionId);
|
||||||
|
|||||||
139
apps/gateway/src/agent/provider.service.ts
Normal file
139
apps/gateway/src/agent/provider.service.ts
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import { Injectable, Logger, type OnModuleInit } from '@nestjs/common';
|
||||||
|
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
|
||||||
|
import type { Model, Api } from '@mariozechner/pi-ai';
|
||||||
|
import type { ModelInfo, ProviderInfo, CustomProviderConfig } from '@mosaic/types';
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class ProviderService implements OnModuleInit {
|
||||||
|
private readonly logger = new Logger(ProviderService.name);
|
||||||
|
private registry!: ModelRegistry;
|
||||||
|
|
||||||
|
async onModuleInit(): Promise<void> {
|
||||||
|
const authStorage = AuthStorage.create();
|
||||||
|
this.registry = new ModelRegistry(authStorage);
|
||||||
|
|
||||||
|
this.registerOllamaProvider();
|
||||||
|
this.registerCustomProviders();
|
||||||
|
|
||||||
|
const available = this.registry.getAvailable();
|
||||||
|
this.logger.log(`Providers initialized: ${available.length} models available`);
|
||||||
|
}
|
||||||
|
|
||||||
|
getRegistry(): ModelRegistry {
|
||||||
|
return this.registry;
|
||||||
|
}
|
||||||
|
|
||||||
|
findModel(provider: string, modelId: string): Model<Api> | undefined {
|
||||||
|
return this.registry.find(provider, modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
getDefaultModel(): Model<Api> | undefined {
|
||||||
|
const available = this.registry.getAvailable();
|
||||||
|
return available[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
listProviders(): ProviderInfo[] {
|
||||||
|
const allModels = this.registry.getAll();
|
||||||
|
const availableModels = this.registry.getAvailable();
|
||||||
|
const availableIds = new Set(availableModels.map((m) => `${m.provider}:${m.id}`));
|
||||||
|
|
||||||
|
const providerMap = new Map<string, ProviderInfo>();
|
||||||
|
|
||||||
|
for (const model of allModels) {
|
||||||
|
let info = providerMap.get(model.provider);
|
||||||
|
if (!info) {
|
||||||
|
info = {
|
||||||
|
id: model.provider,
|
||||||
|
name: model.provider,
|
||||||
|
available: false,
|
||||||
|
models: [],
|
||||||
|
};
|
||||||
|
providerMap.set(model.provider, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
const isAvailable = availableIds.has(`${model.provider}:${model.id}`);
|
||||||
|
if (isAvailable) info.available = true;
|
||||||
|
|
||||||
|
info.models.push(this.toModelInfo(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Array.from(providerMap.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
listAvailableModels(): ModelInfo[] {
|
||||||
|
return this.registry.getAvailable().map((m) => this.toModelInfo(m));
|
||||||
|
}
|
||||||
|
|
||||||
|
registerCustomProvider(config: CustomProviderConfig): void {
|
||||||
|
this.registry.registerProvider(config.id, {
|
||||||
|
baseUrl: config.baseUrl,
|
||||||
|
apiKey: config.apiKey,
|
||||||
|
models: config.models.map((m) => ({
|
||||||
|
id: m.id,
|
||||||
|
name: m.name,
|
||||||
|
reasoning: m.reasoning ?? false,
|
||||||
|
input: ['text'] as ('text' | 'image')[],
|
||||||
|
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||||
|
contextWindow: m.contextWindow ?? 4096,
|
||||||
|
maxTokens: m.maxTokens ?? 4096,
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
|
this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
|
||||||
|
}
|
||||||
|
|
||||||
|
private registerOllamaProvider(): void {
|
||||||
|
const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
|
||||||
|
if (!ollamaUrl) return;
|
||||||
|
|
||||||
|
const modelsEnv = process.env['OLLAMA_MODELS'] ?? 'llama3.2,codellama,mistral';
|
||||||
|
const modelIds = modelsEnv
|
||||||
|
.split(',')
|
||||||
|
.map((m) => m.trim())
|
||||||
|
.filter(Boolean);
|
||||||
|
|
||||||
|
this.registerCustomProvider({
|
||||||
|
id: 'ollama',
|
||||||
|
name: 'Ollama',
|
||||||
|
baseUrl: `${ollamaUrl}/v1`,
|
||||||
|
models: modelIds.map((id) => ({
|
||||||
|
id,
|
||||||
|
name: id,
|
||||||
|
reasoning: false,
|
||||||
|
contextWindow: 8192,
|
||||||
|
maxTokens: 4096,
|
||||||
|
})),
|
||||||
|
});
|
||||||
|
|
||||||
|
this.logger.log(
|
||||||
|
`Ollama provider registered at ${ollamaUrl} with models: ${modelIds.join(', ')}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private registerCustomProviders(): void {
|
||||||
|
const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
|
||||||
|
if (!customJson) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const configs = JSON.parse(customJson) as CustomProviderConfig[];
|
||||||
|
for (const config of configs) {
|
||||||
|
this.registerCustomProvider(config);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
this.logger.error('Failed to parse MOSAIC_CUSTOM_PROVIDERS', String(err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private toModelInfo(model: Model<Api>): ModelInfo {
|
||||||
|
return {
|
||||||
|
id: model.id,
|
||||||
|
provider: model.provider,
|
||||||
|
name: model.name,
|
||||||
|
reasoning: model.reasoning,
|
||||||
|
contextWindow: model.contextWindow,
|
||||||
|
maxTokens: model.maxTokens,
|
||||||
|
inputTypes: model.input,
|
||||||
|
cost: model.cost,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
19
apps/gateway/src/agent/providers.controller.ts
Normal file
19
apps/gateway/src/agent/providers.controller.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { Controller, Get, UseGuards } from '@nestjs/common';
|
||||||
|
import { AuthGuard } from '../auth/auth.guard.js';
|
||||||
|
import { ProviderService } from './provider.service.js';
|
||||||
|
|
||||||
|
@Controller('api/providers')
|
||||||
|
@UseGuards(AuthGuard)
|
||||||
|
export class ProvidersController {
|
||||||
|
constructor(private readonly providerService: ProviderService) {}
|
||||||
|
|
||||||
|
@Get()
|
||||||
|
list() {
|
||||||
|
return this.providerService.listProviders();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Get('models')
|
||||||
|
listModels() {
|
||||||
|
return this.providerService.listAvailableModels();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,8 @@ import { v4 as uuid } from 'uuid';
|
|||||||
interface ChatMessage {
|
interface ChatMessage {
|
||||||
conversationId?: string;
|
conversationId?: string;
|
||||||
content: string;
|
content: string;
|
||||||
|
provider?: string;
|
||||||
|
modelId?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@WebSocketGateway({
|
@WebSocketGateway({
|
||||||
@@ -65,7 +67,10 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
|||||||
try {
|
try {
|
||||||
let agentSession = this.agentService.getSession(conversationId);
|
let agentSession = this.agentService.getSession(conversationId);
|
||||||
if (!agentSession) {
|
if (!agentSession) {
|
||||||
agentSession = await this.agentService.createSession(conversationId);
|
agentSession = await this.agentService.createSession(conversationId, {
|
||||||
|
provider: data.provider,
|
||||||
|
modelId: data.modelId,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.logger.error(
|
this.logger.error(
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ export const VERSION = '0.0.0';
|
|||||||
|
|
||||||
export * from './chat/index.js';
|
export * from './chat/index.js';
|
||||||
export * from './agent/index.js';
|
export * from './agent/index.js';
|
||||||
|
export * from './provider/index.js';
|
||||||
|
|||||||
54
packages/types/src/provider/index.ts
Normal file
54
packages/types/src/provider/index.ts
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
/** Known built-in LLM provider identifiers */
|
||||||
|
export type KnownProvider =
|
||||||
|
| 'anthropic'
|
||||||
|
| 'openai'
|
||||||
|
| 'google'
|
||||||
|
| 'ollama'
|
||||||
|
| 'xai'
|
||||||
|
| 'groq'
|
||||||
|
| 'openrouter'
|
||||||
|
| 'zai'
|
||||||
|
| 'mistral';
|
||||||
|
|
||||||
|
/** Provider identifier — known providers or custom string */
|
||||||
|
export type ProviderId = KnownProvider | string;
|
||||||
|
|
||||||
|
/** Describes an available LLM model */
|
||||||
|
export interface ModelInfo {
|
||||||
|
id: string;
|
||||||
|
provider: ProviderId;
|
||||||
|
name: string;
|
||||||
|
reasoning: boolean;
|
||||||
|
contextWindow: number;
|
||||||
|
maxTokens: number;
|
||||||
|
inputTypes: ('text' | 'image')[];
|
||||||
|
cost: {
|
||||||
|
input: number;
|
||||||
|
output: number;
|
||||||
|
cacheRead: number;
|
||||||
|
cacheWrite: number;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Describes an available provider */
|
||||||
|
export interface ProviderInfo {
|
||||||
|
id: ProviderId;
|
||||||
|
name: string;
|
||||||
|
available: boolean;
|
||||||
|
models: ModelInfo[];
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Configuration for a custom (non-built-in) provider */
|
||||||
|
export interface CustomProviderConfig {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
baseUrl: string;
|
||||||
|
apiKey?: string;
|
||||||
|
models: Array<{
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
reasoning?: boolean;
|
||||||
|
contextWindow?: number;
|
||||||
|
maxTokens?: number;
|
||||||
|
}>;
|
||||||
|
}
|
||||||
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
@@ -41,6 +41,9 @@ importers:
|
|||||||
|
|
||||||
apps/gateway:
|
apps/gateway:
|
||||||
dependencies:
|
dependencies:
|
||||||
|
'@mariozechner/pi-ai':
|
||||||
|
specifier: ~0.57.1
|
||||||
|
version: 0.57.1(ws@8.19.0)(zod@4.3.6)
|
||||||
'@mariozechner/pi-coding-agent':
|
'@mariozechner/pi-coding-agent':
|
||||||
specifier: ~0.57.1
|
specifier: ~0.57.1
|
||||||
version: 0.57.1(ws@8.19.0)(zod@4.3.6)
|
version: 0.57.1(ws@8.19.0)(zod@4.3.6)
|
||||||
@@ -53,6 +56,9 @@ importers:
|
|||||||
'@mosaic/db':
|
'@mosaic/db':
|
||||||
specifier: workspace:^
|
specifier: workspace:^
|
||||||
version: link:../../packages/db
|
version: link:../../packages/db
|
||||||
|
'@mosaic/types':
|
||||||
|
specifier: workspace:^
|
||||||
|
version: link:../../packages/types
|
||||||
'@nestjs/common':
|
'@nestjs/common':
|
||||||
specifier: ^11.0.0
|
specifier: ^11.0.0
|
||||||
version: 11.1.16(class-transformer@0.5.1)(class-validator@0.15.1)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
version: 11.1.16(class-transformer@0.5.1)(class-validator@0.15.1)(reflect-metadata@0.2.2)(rxjs@7.8.2)
|
||||||
|
|||||||
Reference in New Issue
Block a user