import { Inject, Logger } from '@nestjs/common';
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  type OnGatewayInit,
  ConnectedSocket,
  MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import type { Auth } from '@mosaic/auth';
import type { Brain } from '@mosaic/brain';
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
import { AgentService, type ConversationHistoryMessage } from '../agent/agent.service.js';
import { AUTH } from '../auth/auth.tokens.js';
import { BRAIN } from '../brain/brain.tokens.js';
import { CommandRegistryService } from '../commands/command-registry.service.js';
import { CommandExecutorService } from '../commands/command-executor.service.js';
import { v4 as uuid } from 'uuid';
import { ChatSocketMessageDto } from './chat.dto.js';
import { validateSocketSession } from './chat.gateway-auth.js';

/** Per-client state tracking streaming accumulation for persistence. */
interface ClientSession {
  /** Conversation this client is currently subscribed to. */
  conversationId: string;
  /** Unsubscribes the agent-event listener registered for this client. */
  cleanup: () => void;
  /** Accumulated assistant response text for the current turn. */
  assistantText: string;
  /** Tool calls observed during the current turn. */
  toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
  /** Tool calls in-flight (started but not ended yet), keyed by tool call id. */
  pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
}

/**
 * Socket.IO gateway for the `/chat` namespace.
 *
 * Validates the session on connect, forwards user messages to the agent
 * service, relays agent events back to the originating socket, and persists
 * both sides of the conversation through the brain service.
 */
@WebSocketGateway({
  cors: {
    origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
  },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /** Live per-socket state, keyed by socket id. */
  private readonly clientSessions = new Map<string, ClientSession>();

  constructor(
    @Inject(AgentService) private readonly agentService: AgentService,
    @Inject(AUTH) private readonly auth: Auth,
    @Inject(BRAIN) private readonly brain: Brain,
    @Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
    @Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
  ) {}

  /** Logs that the gateway has been initialized. */
  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  /**
   * Validates the connecting socket's session and, on success, stores the user
   * and session on `client.data` and sends the command manifest.
   *
   * Fails closed: a missing session *or* an error thrown by the validator
   * disconnects the client, so a socket never stays connected without
   * `client.data.user` being set.
   */
  async handleConnection(client: Socket): Promise<void> {
    let session: Awaited<ReturnType<typeof validateSocketSession>>;
    try {
      session = await validateSocketSession(client.handshake.headers, this.auth);
    } catch (err) {
      this.logger.error(
        `Session validation failed for WebSocket client: ${client.id}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.disconnect();
      return;
    }

    if (!session) {
      this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
      client.disconnect();
      return;
    }

    client.data.user = session.user;
    client.data.session = session.session;
    this.logger.log(`Client connected: ${client.id}`);

    // Broadcast command manifest to the newly connected client
    client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
  }

  /**
   * Releases everything registered for the socket: the agent-event listener,
   * the conversation channel and the per-client state entry.
   */
  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);
    const session = this.clientSessions.get(client.id);
    if (session) {
      session.cleanup();
      this.agentService.removeChannel(session.conversationId, `websocket:${client.id}`);
      this.clientSessions.delete(client.id);
    }
  }

  /**
   * Handles a user chat message.
   *
   * Ensures an agent session exists (creating it with prior history when
   * needed), persists the user message, (re)subscribes the socket to the
   * conversation's agent events, acknowledges the message and dispatches the
   * prompt to the agent. Failures are reported to the client via an `error`
   * event rather than thrown.
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ChatSocketMessageDto,
  ): Promise<void> {
    const conversationId = data.conversationId ?? uuid();
    const userId = this.userIdOf(client);
    const channelId = `websocket:${client.id}`;

    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // Ensure agent session exists for this conversation
    try {
      let agentSession = this.agentService.getSession(conversationId);
      if (!agentSession) {
        // When resuming an existing conversation, load prior messages to inject as context (M1-004)
        const conversationHistory = await this.loadConversationHistory(conversationId, userId);
        agentSession = await this.agentService.createSession(conversationId, {
          provider: data.provider,
          modelId: data.modelId,
          agentConfigId: data.agentId,
          userId,
          conversationHistory: conversationHistory.length > 0 ? conversationHistory : undefined,
        });
        if (conversationHistory.length > 0) {
          this.logger.log(
            `Loaded ${conversationHistory.length} prior messages for conversation=${conversationId}`,
          );
        }
      }
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return;
    }

    if (userId) {
      // Ensure conversation record exists in the DB before persisting messages
      await this.ensureConversation(conversationId, userId);

      // Persist the user message; a failure is logged but does not abort the turn
      try {
        await this.brain.conversations.addMessage(
          {
            conversationId,
            role: 'user',
            content: data.content,
            metadata: {
              timestamp: new Date().toISOString(),
            },
          },
          userId,
        );
      } catch (err) {
        this.logger.error(
          `Failed to persist user message for conversation=${conversationId}`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }

    // The socket may have gone away while we were awaiting above. In that case
    // handleDisconnect has already run, so anything registered from here on
    // would never be cleaned up.
    if (client.disconnected) {
      this.logger.warn(
        `Client ${client.id} disconnected before conversation ${conversationId} was dispatched; skipping`,
      );
      return;
    }

    // Always clean up previous listener to prevent leak
    const existing = this.clientSessions.get(client.id);
    if (existing) {
      existing.cleanup();
      // When the client switched conversations, release its channel on the old
      // one too; handleDisconnect only knows about the most recent conversation.
      if (existing.conversationId !== conversationId) {
        this.agentService.removeChannel(existing.conversationId, channelId);
      }
    }

    // Subscribe to agent events and relay to client
    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });

    this.clientSessions.set(client.id, {
      conversationId,
      cleanup,
      assistantText: '',
      toolCalls: [],
      pendingToolCalls: new Map(),
    });

    // Track channel connection
    this.agentService.addChannel(conversationId, channelId);

    // Send session info so the client knows the model/provider
    const activeSession = this.agentService.getSession(conversationId);
    if (activeSession) {
      this.emitSessionInfo(client, conversationId, activeSession);
    }

    // Send acknowledgment
    client.emit('message:ack', { conversationId, messageId: uuid() });

    // Dispatch to agent
    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Changes the thinking level of an active conversation session.
   *
   * Emits `error` when the conversation has no active session or the requested
   * level is not among the session's available levels; otherwise applies the
   * level and emits refreshed `session:info`.
   */
  @SubscribeMessage('set:thinking')
  handleSetThinking(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: SetThinkingPayload,
  ): void {
    const session = this.agentService.getSession(data.conversationId);
    if (!session) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: 'No active session for this conversation.',
      });
      return;
    }

    const validLevels = session.piSession.getAvailableThinkingLevels();
    if (!validLevels.includes(data.level as never)) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: `Invalid thinking level "${data.level}". Available: ${validLevels.join(', ')}`,
      });
      return;
    }

    session.piSession.setThinkingLevel(data.level as never);
    this.logger.log(
      `Thinking level set to "${data.level}" for conversation ${data.conversationId}`,
    );

    this.emitSessionInfo(client, data.conversationId, session);
  }

  /**
   * Executes a slash command on behalf of the connected user and returns the
   * executor's result to the caller via `command:result`.
   */
  @SubscribeMessage('command:execute')
  async handleCommandExecute(
    @ConnectedSocket() client: Socket,
    @MessageBody() payload: SlashCommandPayload,
  ): Promise<void> {
    const userId = this.userIdOf(client) ?? 'unknown';
    const result = await this.commandExecutor.execute(payload, userId);
    client.emit('command:result', result);
  }

  /** Emits `system:reload` with the given payload through the gateway server. */
  broadcastReload(payload: SystemReloadPayload): void {
    this.server.emit('system:reload', payload);
    this.logger.log('Broadcasted system:reload to all connected clients');
  }

  /** Reads the user id stored on the socket by `handleConnection`, if any. */
  private userIdOf(client: Socket): string | undefined {
    return (client.data.user as { id: string } | undefined)?.id;
  }

  /** Sends the session's provider, model and thinking-level details to the client. */
  private emitSessionInfo(
    client: Socket,
    conversationId: string,
    agentSession: NonNullable<ReturnType<AgentService['getSession']>>,
  ): void {
    const piSession = agentSession.piSession;
    client.emit('session:info', {
      conversationId,
      provider: agentSession.provider,
      modelId: agentSession.modelId,
      thinkingLevel: piSession.thinkingLevel,
      availableThinkingLevels: piSession.getAvailableThinkingLevels(),
    });
  }

  /**
   * Ensure a conversation record exists in the DB.
   * Creates it if absent — safe to call concurrently since a duplicate insert
   * would fail on the PK constraint and be caught here.
   */
  private async ensureConversation(conversationId: string, userId: string): Promise<void> {
    try {
      const existing = await this.brain.conversations.findById(conversationId, userId);
      if (!existing) {
        await this.brain.conversations.create({
          id: conversationId,
          userId,
        });
      }
    } catch (err) {
      this.logger.error(
        `Failed to ensure conversation record for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  }

  /**
   * Load prior conversation messages from DB for context injection on session resume (M1-004).
   * Returns an empty array when no messages are found, userId is not provided,
   * or the lookup fails (the failure is logged).
   */
  private async loadConversationHistory(
    conversationId: string,
    userId: string | undefined,
  ): Promise<ConversationHistoryMessage[]> {
    if (!userId) return [];

    try {
      const messages = await this.brain.conversations.findMessages(conversationId, userId);
      if (messages.length === 0) return [];

      return messages.map((msg) => ({
        role: msg.role as 'user' | 'assistant' | 'system',
        content: msg.content,
        createdAt: msg.createdAt,
      }));
    } catch (err) {
      this.logger.error(
        `Failed to load conversation history for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      return [];
    }
  }

  /**
   * Translates an agent session event into the client-facing socket event and
   * maintains the per-client accumulation buffers used to persist the
   * assistant's reply when the turn ends. Events for a socket that is no
   * longer connected are dropped with a warning.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start': {
        // Reset accumulation buffers for the new turn
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          cs.assistantText = '';
          cs.toolCalls = [];
          cs.pendingToolCalls.clear();
        }
        client.emit('agent:start', { conversationId });
        break;
      }

      case 'agent_end': {
        // Gather usage stats from the Pi session
        const agentSession = this.agentService.getSession(conversationId);
        const piSession = agentSession?.piSession;
        const stats = piSession?.getSessionStats();
        const contextUsage = piSession?.getContextUsage();

        const usagePayload = stats
          ? {
              provider: agentSession?.provider ?? 'unknown',
              modelId: agentSession?.modelId ?? 'unknown',
              thinkingLevel: piSession?.thinkingLevel ?? 'off',
              tokens: stats.tokens,
              cost: stats.cost,
              context: {
                percent: contextUsage?.percent ?? null,
                window: contextUsage?.contextWindow ?? 0,
              },
            }
          : undefined;

        client.emit('agent:end', {
          conversationId,
          usage: usagePayload,
        });

        // Persist the assistant message with metadata
        const cs = this.clientSessions.get(client.id);
        const userId = this.userIdOf(client);
        if (cs && userId && cs.assistantText.trim().length > 0) {
          const metadata: Record<string, unknown> = {
            timestamp: new Date().toISOString(),
            model: agentSession?.modelId ?? 'unknown',
            provider: agentSession?.provider ?? 'unknown',
            toolCalls: cs.toolCalls,
          };

          if (stats?.tokens) {
            metadata['tokenUsage'] = {
              input: stats.tokens.input,
              output: stats.tokens.output,
              cacheRead: stats.tokens.cacheRead,
              cacheWrite: stats.tokens.cacheWrite,
              total: stats.tokens.total,
            };
          }

          // Fire-and-forget: persistence must not delay relaying further events
          this.brain.conversations
            .addMessage(
              {
                conversationId,
                role: 'assistant',
                content: cs.assistantText,
                metadata,
              },
              userId,
            )
            .catch((err: unknown) => {
              this.logger.error(
                `Failed to persist assistant message for conversation=${conversationId}`,
                err instanceof Error ? err.stack : String(err),
              );
            });

          // Reset accumulation
          cs.assistantText = '';
          cs.toolCalls = [];
          cs.pendingToolCalls.clear();
        }

        break;
      }

      case 'message_update': {
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          // Accumulate assistant text for persistence
          const cs = this.clientSessions.get(client.id);
          if (cs) {
            cs.assistantText += assistantEvent.delta;
          }
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }

      case 'tool_execution_start': {
        // Track pending tool call for later recording
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          cs.pendingToolCalls.set(event.toolCallId, {
            toolName: event.toolName,
            args: event.args,
          });
        }
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;
      }

      case 'tool_execution_end': {
        // Finalise tool call record; args come from the matching start event when one was seen
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          const pending = cs.pendingToolCalls.get(event.toolCallId);
          cs.toolCalls.push({
            toolCallId: event.toolCallId,
            toolName: event.toolName,
            args: pending?.args ?? null,
            isError: event.isError,
          });
          cs.pendingToolCalls.delete(event.toolCallId);
        }
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;
      }
    }
  }
}