import { Inject, Injectable, Logger, type OnModuleDestroy } from '@nestjs/common';
import {
  createAgentSession,
  SessionManager,
  type AgentSession as PiAgentSession,
  type AgentSessionEvent,
  type ToolDefinition,
} from '@mariozechner/pi-coding-agent';
import type { Brain } from '@mosaic/brain';
import type { Memory } from '@mosaic/memory';
import { BRAIN } from '../brain/brain.tokens.js';
import { MEMORY } from '../memory/memory.tokens.js';
import { EmbeddingService } from '../memory/embedding.service.js';
import { CoordService } from '../coord/coord.service.js';
import { ProviderService } from './provider.service.js';
import { createBrainTools } from './tools/brain-tools.js';
import { createCoordTools } from './tools/coord-tools.js';
import { createMemoryTools } from './tools/memory-tools.js';
import { createFileTools } from './tools/file-tools.js';
import { createGitTools } from './tools/git-tools.js';
import { createShellTools } from './tools/shell-tools.js';
import { createWebTools } from './tools/web-tools.js';
import type { SessionInfoDto } from './session.dto.js';

/** Optional model selection for a new agent session. */
export interface AgentSessionOptions {
  /**
   * Provider name. Only used together with `modelId`; a provider given on its
   * own falls back to the default model.
   */
  provider?: string;
  /**
   * Model identifier. Without `provider`, the first available model with this
   * id is used; if none matches, the default model is used.
   */
  modelId?: string;
}

/** A live agent session tracked by {@link AgentService}. */
export interface AgentSession {
  id: string;
  /** Provider name of the resolved model, or `'default'` when none was resolved. */
  provider: string;
  /** Id of the resolved model, or `'default'` when none was resolved. */
  modelId: string;
  piSession: PiAgentSession;
  /** Fan-out targets for every event emitted by `piSession`. */
  listeners: Set<(event: AgentSessionEvent) => void>;
  /** Detaches the fan-out subscription from `piSession`. */
  unsubscribe: () => void;
  /** Creation time in epoch milliseconds (`Date.now()`). */
  createdAt: number;
  /** Number of prompts attempted (incremented before each prompt runs). */
  promptCount: number;
  /** Channel names attached via `addChannel`. */
  channels: Set<string>;
}

/** Render an unknown thrown value as a string for the Nest logger's trace argument. */
function formatError(err: unknown): string {
  if (err instanceof Error) return err.stack ?? err.message;
  return String(err);
}

/**
 * Owns the lifecycle of in-memory pi agent sessions: creation (deduplicated
 * per session id), prompting, event fan-out, channel bookkeeping and teardown.
 */
@Injectable()
export class AgentService implements OnModuleDestroy {
  private readonly logger = new Logger(AgentService.name);
  private readonly sessions = new Map<string, AgentSession>();
  /** In-flight creations, so concurrent `createSession` calls for one id share a single promise. */
  private readonly creating = new Map<string, Promise<AgentSession>>();
  private readonly customTools: ToolDefinition[];

  constructor(
    @Inject(ProviderService) private readonly providerService: ProviderService,
    @Inject(BRAIN) private readonly brain: Brain,
    @Inject(MEMORY) private readonly memory: Memory,
    @Inject(EmbeddingService) private readonly embeddingService: EmbeddingService,
    @Inject(CoordService) private readonly coordService: CoordService,
  ) {
    // Each filesystem-facing tool family reads its working directory from an
    // env var and falls back to the process cwd.
    const fileBaseDir = process.env['AGENT_FILE_SANDBOX_DIR'] ?? process.cwd();
    const gitDefaultCwd = process.env['AGENT_GIT_CWD'] ?? process.cwd();
    const shellDefaultCwd = process.env['AGENT_SHELL_CWD'] ?? process.cwd();

    this.customTools = [
      ...createBrainTools(brain),
      ...createCoordTools(coordService),
      // The embedding service is only handed over when it reports itself available.
      ...createMemoryTools(memory, embeddingService.available ? embeddingService : null),
      ...createFileTools(fileBaseDir),
      ...createGitTools(gitDefaultCwd),
      ...createShellTools(shellDefaultCwd),
      ...createWebTools(),
    ];
    this.logger.log(`Registered ${this.customTools.length} custom tools`);
  }

  /**
   * Return the session for `sessionId`, creating it if needed.
   *
   * If the session already exists (or is being created), the existing/in-flight
   * session is returned and `options` is ignored.
   *
   * @throws Error when the requested model cannot be found or the underlying
   *   agent session fails to initialise.
   */
  async createSession(sessionId: string, options?: AgentSessionOptions): Promise<AgentSession> {
    const existing = this.sessions.get(sessionId);
    if (existing) return existing;

    const inflight = this.creating.get(sessionId);
    if (inflight) return inflight;

    const promise = this.doCreateSession(sessionId, options).finally(() => {
      this.creating.delete(sessionId);
    });
    this.creating.set(sessionId, promise);
    return promise;
  }

  /** Build a pi session, wire up event fan-out, and register it under `sessionId`. */
  private async doCreateSession(
    sessionId: string,
    options?: AgentSessionOptions,
  ): Promise<AgentSession> {
    const model = this.resolveModel(options);
    const providerName = model?.provider ?? 'default';
    const modelId = model?.id ?? 'default';

    this.logger.log(
      `Creating agent session: ${sessionId} (provider=${providerName}, model=${modelId})`,
    );

    let piSession: PiAgentSession;
    try {
      const result = await createAgentSession({
        sessionManager: SessionManager.inMemory(),
        modelRegistry: this.providerService.getRegistry(),
        model: model ?? undefined,
        // No built-in tools; only the service-registered custom tools are exposed.
        tools: [],
        customTools: this.customTools,
      });
      piSession = result.session;
    } catch (err) {
      this.logger.error(`Failed to create agent session for ${sessionId}`, formatError(err));
      throw new Error(`Agent session creation failed for ${sessionId}: ${String(err)}`);
    }

    const listeners = new Set<(event: AgentSessionEvent) => void>();
    // One subscription per pi session, fanned out to all registered listeners;
    // a throwing listener is logged and does not stop delivery to the others.
    const unsubscribe = piSession.subscribe((event) => {
      for (const listener of listeners) {
        try {
          listener(event);
        } catch (err) {
          this.logger.error(`Event listener error in session ${sessionId}`, formatError(err));
        }
      }
    });

    const session: AgentSession = {
      id: sessionId,
      provider: providerName,
      modelId,
      piSession,
      listeners,
      unsubscribe,
      createdAt: Date.now(),
      promptCount: 0,
      channels: new Set<string>(),
    };

    this.sessions.set(sessionId, session);
    this.logger.log(`Agent session ${sessionId} ready (${providerName}/${modelId})`);
    return session;
  }

  /**
   * Pick the model for a new session.
   *
   * - neither field set: default model
   * - provider + modelId: exact lookup, throws if missing
   * - modelId only: first available model with that id, else default model
   * - provider only: default model (provider alone is not used for lookup)
   */
  private resolveModel(options?: AgentSessionOptions) {
    const provider = options?.provider;
    const modelId = options?.modelId;

    if (!provider && !modelId) {
      return this.providerService.getDefaultModel() ?? null;
    }

    if (provider && modelId) {
      const model = this.providerService.findModel(provider, modelId);
      if (!model) {
        throw new Error(`Model not found: ${provider}/${modelId}`);
      }
      return model;
    }

    if (modelId) {
      const match = this.providerService
        .listAvailableModels()
        .find((m) => m.id === modelId);
      if (match) {
        return this.providerService.findModel(match.provider, match.id) ?? null;
      }
      this.logger.warn(`Model ${modelId} not available from any provider; using default model`);
    } else {
      this.logger.warn(`Provider ${provider} given without modelId; using default model`);
    }

    return this.providerService.getDefaultModel() ?? null;
  }

  getSession(sessionId: string): AgentSession | undefined {
    return this.sessions.get(sessionId);
  }

  /** Snapshot of every live session; all durations are measured against one timestamp. */
  listSessions(): SessionInfoDto[] {
    const now = Date.now();
    return Array.from(this.sessions.values(), (s) => this.toSessionInfo(s, now));
  }

  getSessionInfo(sessionId: string): SessionInfoDto | undefined {
    const s = this.sessions.get(sessionId);
    return s ? this.toSessionInfo(s) : undefined;
  }

  /** Map an internal session to its public DTO, with `durationMs` measured against `now`. */
  private toSessionInfo(s: AgentSession, now = Date.now()): SessionInfoDto {
    return {
      id: s.id,
      provider: s.provider,
      modelId: s.modelId,
      createdAt: new Date(s.createdAt).toISOString(),
      promptCount: s.promptCount,
      channels: Array.from(s.channels),
      durationMs: now - s.createdAt,
    };
  }

  /** Attach a channel name to a session; no-op for unknown sessions. */
  addChannel(sessionId: string, channel: string): void {
    this.sessions.get(sessionId)?.channels.add(channel);
  }

  /** Detach a channel name from a session; no-op for unknown sessions. */
  removeChannel(sessionId: string, channel: string): void {
    this.sessions.get(sessionId)?.channels.delete(channel);
  }

  /**
   * Send a prompt to the session. `promptCount` is incremented before the
   * call, so failed prompts are counted too.
   *
   * @throws Error if the session does not exist; rethrows any prompt failure.
   */
  async prompt(sessionId: string, message: string): Promise<void> {
    const session = this.sessions.get(sessionId);
    if (!session) {
      throw new Error(`No agent session found: ${sessionId}`);
    }

    session.promptCount += 1;

    try {
      await session.piSession.prompt(message);
    } catch (err) {
      this.logger.error(
        `Prompt failed for session=${sessionId}, messageLength=${message.length}`,
        formatError(err),
      );
      throw err;
    }
  }

  /**
   * Register an event listener on a session.
   *
   * @returns a function that removes the listener.
   * @throws Error if the session does not exist.
   */
  onEvent(sessionId: string, listener: (event: AgentSessionEvent) => void): () => void {
    const session = this.sessions.get(sessionId);
    if (!session) {
      throw new Error(`No agent session found: ${sessionId}`);
    }
    session.listeners.add(listener);
    return () => {
      session.listeners.delete(listener);
    };
  }

  /**
   * Tear down a session best-effort: unsubscribe and dispose failures are
   * logged, and the session is always removed. No-op for unknown sessions.
   */
  async destroySession(sessionId: string): Promise<void> {
    const session = this.sessions.get(sessionId);
    if (!session) return;
    this.logger.log(`Destroying agent session ${sessionId}`);

    try {
      session.unsubscribe();
    } catch (err) {
      this.logger.error(`Failed to unsubscribe session ${sessionId}`, formatError(err));
    }
    try {
      session.piSession.dispose();
    } catch (err) {
      this.logger.error(`Failed to dispose piSession for ${sessionId}`, formatError(err));
    }

    session.listeners.clear();
    session.channels.clear();
    this.sessions.delete(sessionId);
  }

  /** Nest shutdown hook: destroy every registered session, logging any failure. */
  async onModuleDestroy(): Promise<void> {
    this.logger.log('Shutting down all agent sessions');
    // NOTE(review): sessions still in `this.creating` are not awaited here and
    // will register after shutdown; confirm whether that matters.
    const stops = Array.from(this.sessions.keys(), (id) => this.destroySession(id));
    const results = await Promise.allSettled(stops);
    for (const result of results) {
      if (result.status === 'rejected') {
        this.logger.error('Session shutdown failure', formatError(result.reason));
      }
    }
  }
}