import { Inject, Logger } from '@nestjs/common';
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  type OnGatewayInit,
  ConnectedSocket,
  MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import type { Auth } from '@mosaicstack/auth';
import type { Brain } from '@mosaicstack/brain';
import type {
  SetThinkingPayload,
  SlashCommandPayload,
  SystemReloadPayload,
  RoutingDecisionInfo,
  AbortPayload,
} from '@mosaicstack/types';
import { AgentService, type ConversationHistoryMessage } from '../agent/agent.service.js';
import { AUTH } from '../auth/auth.tokens.js';
import { BRAIN } from '../brain/brain.tokens.js';
import { CommandRegistryService } from '../commands/command-registry.service.js';
import { CommandExecutorService } from '../commands/command-executor.service.js';
import { RoutingEngineService } from '../agent/routing/routing-engine.service.js';
import { v4 as uuid } from 'uuid';
import { ChatSocketMessageDto } from './chat.dto.js';
import { validateSocketSession } from './chat.gateway-auth.js';

/** Per-client state tracking streaming accumulation for persistence. */
interface ClientSession {
  conversationId: string;
  /** Unsubscribes the agent-event listener registered for this client. */
  cleanup: () => void;
  /** Accumulated assistant response text for the current turn. */
  assistantText: string;
  /** Tool calls observed during the current turn. */
  toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
  /** Tool calls in-flight (started but not ended yet), keyed by toolCallId. */
  pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
  /** Last routing decision made for this session (M4-008) */
  lastRoutingDecision?: RoutingDecisionInfo;
}

/**
 * Per-conversation model overrides set via /model command (M4-007).
 * Keyed by conversationId, value is the model name to use.
 */
const modelOverrides = new Map<string, string>();

/**
 * WebSocket gateway for the `/chat` namespace.
 *
 * Authenticates sockets on connect, creates (or reuses) an agent session per
 * conversation, relays agent events to the originating socket, and persists
 * user/assistant messages through the Brain conversation store.
 */
@WebSocketGateway({
  cors: {
    origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
  },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /** Per-socket streaming state, keyed by socket id. */
  private readonly clientSessions = new Map<string, ClientSession>();

  constructor(
    @Inject(AgentService) private readonly agentService: AgentService,
    @Inject(AUTH) private readonly auth: Auth,
    @Inject(BRAIN) private readonly brain: Brain,
    @Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
    @Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
    @Inject(RoutingEngineService) private readonly routingEngine: RoutingEngineService,
  ) {}

  /** Lifecycle hook: logs that the gateway has been initialized. */
  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  /**
   * Lifecycle hook: authenticates a newly connected socket.
   *
   * The socket is disconnected when no session can be validated. Validation
   * errors are treated the same way (fail closed) so that a thrown exception
   * can never leave an unauthenticated socket connected.
   * On success the user/session are stored on `client.data` and the command
   * manifest is emitted to the client.
   */
  async handleConnection(client: Socket): Promise<void> {
    let session: Awaited<ReturnType<typeof validateSocketSession>>;
    try {
      session = await validateSocketSession(client.handshake.headers, this.auth);
    } catch (err) {
      this.logger.error(
        `Session validation failed for WebSocket client: ${client.id}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.disconnect();
      return;
    }

    if (!session) {
      this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
      client.disconnect();
      return;
    }

    client.data.user = session.user;
    client.data.session = session.session;
    this.logger.log(`Client connected: ${client.id}`);

    // Broadcast command manifest to the newly connected client
    client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
  }

  /**
   * Lifecycle hook: releases the event listener and channel registration held
   * for the disconnecting socket, if any.
   */
  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);
    const session = this.clientSessions.get(client.id);
    if (session) {
      session.cleanup();
      this.agentService.removeChannel(session.conversationId, `websocket:${client.id}`);
      this.clientSessions.delete(client.id);
    }
  }

  /**
   * Handles an incoming chat message.
   *
   * Steps, in order: ensure an agent session exists (resolving provider/model
   * via /model override, explicit client choice, or the routing engine),
   * ensure the conversation record exists and persist the user message,
   * (re)subscribe this socket to agent events, emit `session:info` and
   * `message:ack`, then dispatch the prompt to the agent.
   *
   * Failures to create the session or to prompt the agent are reported to the
   * client via an `error` event; persistence failures are only logged.
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ChatSocketMessageDto,
  ): Promise<void> {
    const conversationId = data.conversationId ?? uuid();
    const userId = (client.data.user as { id: string } | undefined)?.id;

    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // Ensure agent session exists for this conversation
    let sessionRoutingDecision: RoutingDecisionInfo | undefined;
    try {
      let agentSession = this.agentService.getSession(conversationId);
      if (!agentSession) {
        // When resuming an existing conversation, load prior messages to inject as context (M1-004)
        const conversationHistory = await this.loadConversationHistory(conversationId, userId);

        // M5-004: Check if there's an existing sessionId bound to this conversation
        let existingSessionId: string | undefined;
        if (userId) {
          existingSessionId = await this.getConversationSessionId(conversationId, userId);
          if (existingSessionId) {
            this.logger.log(
              `Resuming existing sessionId=${existingSessionId} for conversation=${conversationId}`,
            );
          }
        }

        // Determine provider/model via routing engine or per-session /model override (M4-012 / M4-007)
        let resolvedProvider = data.provider;
        let resolvedModelId = data.modelId;

        const modelOverride = modelOverrides.get(conversationId);
        if (modelOverride) {
          // /model override bypasses routing engine (M4-007)
          resolvedModelId = modelOverride;
          this.logger.log(
            `Using /model override "${modelOverride}" for conversation=${conversationId}`,
          );
        } else if (!resolvedProvider && !resolvedModelId) {
          // No explicit provider/model from client — use routing engine (M4-012)
          try {
            const routingDecision = await this.routingEngine.resolve(data.content, userId);
            resolvedProvider = routingDecision.provider;
            resolvedModelId = routingDecision.model;
            sessionRoutingDecision = {
              model: routingDecision.model,
              provider: routingDecision.provider,
              ruleName: routingDecision.ruleName,
              reason: routingDecision.reason,
            };
            this.logger.log(
              `Routing decision for conversation=${conversationId}: ${routingDecision.provider}/${routingDecision.model} (rule="${routingDecision.ruleName}")`,
            );
          } catch (routingErr) {
            // Best effort: fall through with undefined provider/model
            this.logger.warn(
              `Routing engine failed for conversation=${conversationId}, using defaults`,
              routingErr instanceof Error ? routingErr.message : String(routingErr),
            );
          }
        }

        // M5-004: Use existingSessionId as sessionId when available (session reuse)
        // NOTE(review): the session is created under `existingSessionId ?? conversationId`
        // but every later lookup in this method uses conversationId. The two coincide as
        // long as the bound sessionId equals the conversationId (which is what this method
        // binds below) — confirm no other writer stores a different sessionId.
        const sessionIdToCreate = existingSessionId ?? conversationId;
        agentSession = await this.agentService.createSession(sessionIdToCreate, {
          provider: resolvedProvider,
          modelId: resolvedModelId,
          agentConfigId: data.agentId,
          userId,
          conversationHistory: conversationHistory.length > 0 ? conversationHistory : undefined,
        });

        if (conversationHistory.length > 0) {
          this.logger.log(
            `Loaded ${conversationHistory.length} prior messages for conversation=${conversationId}`,
          );
        }
      }
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return;
    }

    // Ensure conversation record exists in the DB before persisting messages
    // M5-004: Also bind the sessionId to the conversation record
    if (userId) {
      await this.ensureConversation(conversationId, userId);
      await this.bindSessionToConversation(conversationId, userId, conversationId);
    }

    // M5-007: Count the user message
    this.agentService.recordMessage(conversationId);

    // Persist the user message
    if (userId) {
      try {
        await this.brain.conversations.addMessage(
          {
            conversationId,
            role: 'user',
            content: data.content,
            metadata: {
              timestamp: new Date().toISOString(),
            },
          },
          userId,
        );
      } catch (err) {
        this.logger.error(
          `Failed to persist user message for conversation=${conversationId}`,
          err instanceof Error ? err.stack : String(err),
        );
      }
    }

    // Always clean up previous listener to prevent leak
    const previous = this.clientSessions.get(client.id);
    const sameConversation = previous?.conversationId === conversationId;
    if (previous) {
      previous.cleanup();
      if (!sameConversation) {
        // The client switched conversations: release the channel registered for the
        // old conversation, otherwise it would never be removed (handleDisconnect
        // only knows about the client's current conversation).
        this.agentService.removeChannel(previous.conversationId, `websocket:${client.id}`);
      }
    }

    // Subscribe to agent events and relay to client
    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });

    // Preserve the routing decision from the existing client session if we didn't get a
    // new one — but only when it belongs to this same conversation.
    const routingDecisionToStore =
      sessionRoutingDecision ?? (sameConversation ? previous?.lastRoutingDecision : undefined);

    this.clientSessions.set(client.id, {
      conversationId,
      cleanup,
      assistantText: '',
      toolCalls: [],
      pendingToolCalls: new Map(),
      lastRoutingDecision: routingDecisionToStore,
    });

    // Track channel connection
    this.agentService.addChannel(conversationId, `websocket:${client.id}`);

    // Send session info so the client knows the model/provider (M4-008: include routing decision)
    // Include agentName when a named agent config is active (M5-001)
    const activeSession = this.agentService.getSession(conversationId);
    if (activeSession) {
      const piSession = activeSession.piSession;
      client.emit('session:info', {
        conversationId,
        provider: activeSession.provider,
        modelId: activeSession.modelId,
        thinkingLevel: piSession.thinkingLevel,
        availableThinkingLevels: piSession.getAvailableThinkingLevels(),
        ...(activeSession.agentName ? { agentName: activeSession.agentName } : {}),
        ...(routingDecisionToStore ? { routingDecision: routingDecisionToStore } : {}),
      });
    }

    // Send acknowledgment
    client.emit('message:ack', { conversationId, messageId: uuid() });

    // Dispatch to agent
    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Changes the thinking level of an active session.
   *
   * Emits `error` when the conversation has no active session or the requested
   * level is not among the session's available levels; otherwise applies the
   * level and emits an updated `session:info` to the requesting client.
   */
  @SubscribeMessage('set:thinking')
  handleSetThinking(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: SetThinkingPayload,
  ): void {
    const session = this.agentService.getSession(data.conversationId);
    if (!session) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: 'No active session for this conversation.',
      });
      return;
    }

    const validLevels = session.piSession.getAvailableThinkingLevels();
    if (!validLevels.includes(data.level as never)) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: `Invalid thinking level "${data.level}". Available: ${validLevels.join(', ')}`,
      });
      return;
    }

    session.piSession.setThinkingLevel(data.level as never);
    this.logger.log(
      `Thinking level set to "${data.level}" for conversation ${data.conversationId}`,
    );

    client.emit('session:info', {
      conversationId: data.conversationId,
      provider: session.provider,
      modelId: session.modelId,
      thinkingLevel: session.piSession.thinkingLevel,
      availableThinkingLevels: session.piSession.getAvailableThinkingLevels(),
      ...(session.agentName ? { agentName: session.agentName } : {}),
    });
  }

  /**
   * Aborts the in-progress agent operation for a conversation.
   * Emits `error` to the client when there is no active session or the abort call fails.
   */
  @SubscribeMessage('abort')
  async handleAbort(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: AbortPayload,
  ): Promise<void> {
    const conversationId = data.conversationId;
    this.logger.log(`Abort requested by ${client.id} for conversation ${conversationId}`);

    const session = this.agentService.getSession(conversationId);
    if (!session) {
      client.emit('error', {
        conversationId,
        error: 'No active session to abort.',
      });
      return;
    }

    try {
      await session.piSession.abort();
      this.logger.log(`Agent session ${conversationId} aborted successfully`);
    } catch (err) {
      this.logger.error(
        `Failed to abort session ${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to abort the agent operation.',
      });
    }
  }

  /**
   * Executes a slash command on behalf of the connected user and emits the
   * result as `command:result`. Falls back to the user id `'unknown'` when the
   * socket carries no user.
   */
  @SubscribeMessage('command:execute')
  async handleCommandExecute(
    @ConnectedSocket() client: Socket,
    @MessageBody() payload: SlashCommandPayload,
  ): Promise<void> {
    const userId = (client.data.user as { id: string } | undefined)?.id ?? 'unknown';
    const result = await this.commandExecutor.execute(payload, userId);
    client.emit('command:result', result);
  }

  /** Emits `system:reload` with the given payload through the gateway server. */
  broadcastReload(payload: SystemReloadPayload): void {
    this.server.emit('system:reload', payload);
    this.logger.log('Broadcasted system:reload to all connected clients');
  }

  /**
   * Set a per-conversation model override (M4-007 / M5-002).
   * When set, the routing engine is bypassed and the specified model is used.
   * Pass null to clear the override and resume automatic routing.
   * M5-005: Emits session:info to clients subscribed to this conversation when a model is set.
   * M5-007: Records a model switch in session metrics.
   */
  setModelOverride(conversationId: string, modelName: string | null): void {
    if (modelName) {
      modelOverrides.set(conversationId, modelName);
      this.logger.log(`Model override set: conversation=${conversationId} model="${modelName}"`);

      // M5-002: Update the live session's modelId so session:info reflects the new model immediately
      this.agentService.updateSessionModel(conversationId, modelName);

      // M5-005: Broadcast session:info to all clients subscribed to this conversation
      this.broadcastSessionInfo(conversationId);
    } else {
      modelOverrides.delete(conversationId);
      this.logger.log(`Model override cleared: conversation=${conversationId}`);
    }
  }

  /**
   * Return the active model override for a conversation, or undefined if none.
   */
  getModelOverride(conversationId: string): string | undefined {
    return modelOverrides.get(conversationId);
  }

  /**
   * M5-005: Broadcast session:info to all clients currently subscribed to a conversation.
   * Called on model or agent switch to ensure the TUI TopBar updates immediately.
   * Does nothing when the conversation has no active agent session.
   */
  broadcastSessionInfo(
    conversationId: string,
    extra?: { agentName?: string; routingDecision?: RoutingDecisionInfo },
  ): void {
    const agentSession = this.agentService.getSession(conversationId);
    if (!agentSession) return;

    const piSession = agentSession.piSession;
    const resolvedAgentName = extra?.agentName ?? agentSession.agentName;
    const payload = {
      conversationId,
      provider: agentSession.provider,
      modelId: agentSession.modelId,
      thinkingLevel: piSession.thinkingLevel,
      availableThinkingLevels: piSession.getAvailableThinkingLevels(),
      ...(resolvedAgentName ? { agentName: resolvedAgentName } : {}),
      ...(extra?.routingDecision ? { routingDecision: extra.routingDecision } : {}),
    };

    // Emit to all clients currently subscribed to this conversation
    for (const [clientId, session] of this.clientSessions) {
      if (session.conversationId === conversationId) {
        const socket = this.server.sockets.sockets.get(clientId);
        if (socket?.connected) {
          socket.emit('session:info', payload);
        }
      }
    }
  }

  /**
   * Ensure a conversation record exists in the DB.
   * Creates it if absent — safe to call concurrently since a duplicate insert
   * would fail on the PK constraint and be caught here.
   */
  private async ensureConversation(conversationId: string, userId: string): Promise<void> {
    try {
      const existing = await this.brain.conversations.findById(conversationId, userId);
      if (!existing) {
        await this.brain.conversations.create({
          id: conversationId,
          userId,
        });
      }
    } catch (err) {
      this.logger.error(
        `Failed to ensure conversation record for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  }

  /**
   * M5-004: Bind the agent sessionId to the conversation record in the DB.
   * Updates the sessionId column so future resumes can reuse the session.
   * Failures are logged and not rethrown.
   */
  private async bindSessionToConversation(
    conversationId: string,
    userId: string,
    sessionId: string,
  ): Promise<void> {
    try {
      await this.brain.conversations.update(conversationId, userId, { sessionId });
    } catch (err) {
      this.logger.error(
        `Failed to bind sessionId=${sessionId} to conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  }

  /**
   * M5-004: Retrieve the sessionId bound to a conversation, if any.
   * Returns undefined when the conversation does not exist or has no bound session,
   * and also when the lookup fails (the failure is logged).
   */
  private async getConversationSessionId(
    conversationId: string,
    userId: string,
  ): Promise<string | undefined> {
    try {
      const conv = await this.brain.conversations.findById(conversationId, userId);
      return conv?.sessionId ?? undefined;
    } catch (err) {
      this.logger.error(
        `Failed to get sessionId for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      return undefined;
    }
  }

  /**
   * Load prior conversation messages from DB for context injection on session resume (M1-004).
   * Returns an empty array when no history exists, the conversation is not owned by the user,
   * or userId is not provided. Lookup failures are logged and also yield an empty array.
   */
  private async loadConversationHistory(
    conversationId: string,
    userId: string | undefined,
  ): Promise<ConversationHistoryMessage[]> {
    if (!userId) return [];

    try {
      const messages = await this.brain.conversations.findMessages(conversationId, userId);
      if (messages.length === 0) return [];

      return messages.map((msg) => ({
        role: msg.role as 'user' | 'assistant' | 'system',
        content: msg.content,
        createdAt: msg.createdAt,
      }));
    } catch (err) {
      this.logger.error(
        `Failed to load conversation history for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      return [];
    }
  }

  /**
   * Translates an agent session event into the socket events the client expects,
   * while accumulating assistant text and tool calls for persistence at turn end.
   * Events for a socket that is no longer connected are dropped with a warning.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start': {
        // Reset accumulation buffers for the new turn
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          cs.assistantText = '';
          cs.toolCalls = [];
          cs.pendingToolCalls.clear();
        }
        client.emit('agent:start', { conversationId });
        break;
      }

      case 'agent_end': {
        // Gather usage stats from the Pi session
        const agentSession = this.agentService.getSession(conversationId);
        const piSession = agentSession?.piSession;
        const stats = piSession?.getSessionStats();
        const contextUsage = piSession?.getContextUsage();

        const usagePayload = stats
          ? {
              provider: agentSession?.provider ?? 'unknown',
              modelId: agentSession?.modelId ?? 'unknown',
              thinkingLevel: piSession?.thinkingLevel ?? 'off',
              tokens: stats.tokens,
              cost: stats.cost,
              context: {
                percent: contextUsage?.percent ?? null,
                window: contextUsage?.contextWindow ?? 0,
              },
            }
          : undefined;

        client.emit('agent:end', {
          conversationId,
          usage: usagePayload,
        });

        // M5-007: Accumulate token usage in session metrics
        if (stats?.tokens) {
          this.agentService.recordTokenUsage(conversationId, {
            input: stats.tokens.input ?? 0,
            output: stats.tokens.output ?? 0,
            cacheRead: stats.tokens.cacheRead ?? 0,
            cacheWrite: stats.tokens.cacheWrite ?? 0,
            total: stats.tokens.total ?? 0,
          });
        }

        // Persist the assistant message with metadata
        const cs = this.clientSessions.get(client.id);
        const userId = (client.data.user as { id: string } | undefined)?.id;
        if (cs && userId && cs.assistantText.trim().length > 0) {
          const metadata: Record<string, unknown> = {
            timestamp: new Date().toISOString(),
            model: agentSession?.modelId ?? 'unknown',
            provider: agentSession?.provider ?? 'unknown',
            toolCalls: cs.toolCalls,
          };

          if (stats?.tokens) {
            metadata['tokenUsage'] = {
              input: stats.tokens.input,
              output: stats.tokens.output,
              cacheRead: stats.tokens.cacheRead,
              cacheWrite: stats.tokens.cacheWrite,
              total: stats.tokens.total,
            };
          }

          // Fire-and-forget: persistence must not delay event relaying
          this.brain.conversations
            .addMessage(
              {
                conversationId,
                role: 'assistant',
                content: cs.assistantText,
                metadata,
              },
              userId,
            )
            .catch((err: unknown) => {
              this.logger.error(
                `Failed to persist assistant message for conversation=${conversationId}`,
                err instanceof Error ? err.stack : String(err),
              );
            });

          // Reset accumulation (toolCalls is reassigned, so the array captured in
          // `metadata` above is left untouched)
          cs.assistantText = '';
          cs.toolCalls = [];
          cs.pendingToolCalls.clear();
        }

        break;
      }

      case 'message_update': {
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          // Accumulate assistant text for persistence
          const cs = this.clientSessions.get(client.id);
          if (cs) {
            cs.assistantText += assistantEvent.delta;
          }
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }

      case 'tool_execution_start': {
        // Track pending tool call for later recording
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          cs.pendingToolCalls.set(event.toolCallId, {
            toolName: event.toolName,
            args: event.args,
          });
        }
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;
      }

      case 'tool_execution_end': {
        // Finalise tool call record
        const cs = this.clientSessions.get(client.id);
        if (cs) {
          const pending = cs.pendingToolCalls.get(event.toolCallId);
          cs.toolCalls.push({
            toolCallId: event.toolCallId,
            toolName: event.toolName,
            args: pending?.args ?? null,
            isError: event.isError,
          });
          cs.pendingToolCalls.delete(event.toolCallId);
        }
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;
      }
    }
  }
}