import { Inject, Logger } from '@nestjs/common';
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  type OnGatewayInit,
  ConnectedSocket,
  MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import type { Auth } from '@mosaic/auth';
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
import { AgentService } from '../agent/agent.service.js';
import { AUTH } from '../auth/auth.tokens.js';
import { CommandRegistryService } from '../commands/command-registry.service.js';
import { CommandExecutorService } from '../commands/command-executor.service.js';
import { v4 as uuid } from 'uuid';
import { ChatSocketMessageDto } from './chat.dto.js';
import { validateSocketSession } from './chat.gateway-auth.js';

/**
 * Socket.IO gateway for the `/chat` namespace.
 *
 * Authenticates each connecting socket, forwards chat messages to {@link AgentService},
 * and relays agent session events back to the socket that is attached to the conversation.
 * The CORS origin is taken from `GATEWAY_CORS_ORIGIN`, defaulting to `http://localhost:3000`.
 */
@WebSocketGateway({
  cors: {
    origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
  },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /**
   * Per-socket subscription state, keyed by socket id: the conversation the socket is
   * currently attached to, and the function that unsubscribes its agent-event listener.
   */
  private readonly clientSessions = new Map<
    string,
    { conversationId: string; cleanup: () => void }
  >();

  constructor(
    @Inject(AgentService) private readonly agentService: AgentService,
    @Inject(AUTH) private readonly auth: Auth,
    @Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
    @Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
  ) {}

  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  /**
   * Authenticates a newly connected socket from its handshake headers.
   *
   * Sockets without a valid session are disconnected. So are sockets whose validation throws.
   * On success the user and session are stored on `client.data`, and the command manifest is
   * sent to the client.
   */
  async handleConnection(client: Socket): Promise<void> {
    let session: Awaited<ReturnType<typeof validateSocketSession>>;
    try {
      session = await validateSocketSession(client.handshake.headers, this.auth);
    } catch (err) {
      // Treat a validation failure like rejected credentials. Otherwise nothing in this
      // method would disconnect the socket, and it would stay open without `client.data.user`.
      this.logger.error(
        `Session validation failed for WebSocket client: ${client.id}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.disconnect();
      return;
    }

    if (!session) {
      this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
      client.disconnect();
      return;
    }

    client.data.user = session.user;
    client.data.session = session.session;
    this.logger.log(`Client connected: ${client.id}`);

    // Broadcast command manifest to the newly connected client
    client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
  }

  /** Unsubscribes the socket's listener and releases its channel on its current conversation. */
  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);
    const session = this.clientSessions.get(client.id);
    if (session) {
      session.cleanup();
      this.agentService.removeChannel(session.conversationId, this.channelId(client));
      this.clientSessions.delete(client.id);
    }
  }

  /**
   * Handles a chat message.
   *
   * Ensures an agent session exists for the conversation (a new conversation id is generated
   * when none is supplied), attaches the socket to that conversation's events, emits
   * `session:info` and `message:ack`, then prompts the agent. Failures are reported to the
   * client as `error` events.
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ChatSocketMessageDto,
  ): Promise<void> {
    const conversationId = data.conversationId ?? uuid();
    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // Ensure agent session exists for this conversation
    try {
      if (!this.agentService.getSession(conversationId)) {
        await this.agentService.createSession(conversationId, {
          provider: data.provider,
          modelId: data.modelId,
          agentConfigId: data.agentId,
          userId: this.getUserId(client),
        });
      }
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return;
    }

    // The socket may have disconnected while the session was being created. In that case
    // handleDisconnect has already run for it, so a listener/channel registered now would
    // never be removed.
    if (client.connected) {
      this.attachClientToConversation(client, conversationId);
    } else {
      this.logger.warn(
        `Client ${client.id} disconnected while conversation ${conversationId} was starting; not subscribing it to agent events`,
      );
    }

    // Send session info so the client knows the model/provider
    this.emitSessionInfo(client, conversationId);

    // Send acknowledgment
    client.emit('message:ack', { conversationId, messageId: uuid() });

    // Dispatch to agent
    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Changes the thinking level of an existing conversation's session.
   *
   * The level is accepted only if it appears in the session's available levels. On success the
   * updated `session:info` is sent back. An unknown conversation or level results in an `error` event.
   */
  @SubscribeMessage('set:thinking')
  handleSetThinking(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: SetThinkingPayload,
  ): void {
    const session = this.agentService.getSession(data.conversationId);
    if (!session) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: 'No active session for this conversation.',
      });
      return;
    }

    const validLevels = session.piSession.getAvailableThinkingLevels();
    // NOTE(review): `as never` bypasses the check between the payload's level type and the
    // piSession's level type; the runtime `includes` check is what actually validates it.
    if (!validLevels.includes(data.level as never)) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: `Invalid thinking level "${data.level}". Available: ${validLevels.join(', ')}`,
      });
      return;
    }

    session.piSession.setThinkingLevel(data.level as never);
    this.logger.log(
      `Thinking level set to "${data.level}" for conversation ${data.conversationId}`,
    );

    this.emitSessionInfo(client, data.conversationId);
  }

  /** Executes a slash command on behalf of the socket's user and emits `command:result`. */
  @SubscribeMessage('command:execute')
  async handleCommandExecute(
    @ConnectedSocket() client: Socket,
    @MessageBody() payload: SlashCommandPayload,
  ): Promise<void> {
    const userId = this.getUserId(client) ?? 'unknown';
    const result = await this.commandExecutor.execute(payload, userId);
    client.emit('command:result', result);
  }

  /** Emits `system:reload` to every client connected to this gateway's server. */
  broadcastReload(payload: SystemReloadPayload): void {
    this.server.emit('system:reload', payload);
    this.logger.log('Broadcasted system:reload to all connected clients');
  }

  /** Id of the socket as stored on an agent conversation's channel list. */
  private channelId(client: Socket): string {
    return `websocket:${client.id}`;
  }

  /** The authenticated user's id stored by handleConnection, if any. */
  private getUserId(client: Socket): string | undefined {
    return (client.data.user as { id: string } | undefined)?.id;
  }

  /**
   * Subscribes the socket to a conversation's agent events and registers its channel there.
   *
   * Any previous listener for this socket is always removed, so repeated messages don't stack
   * subscriptions. If the socket was attached to a different conversation, its channel on that
   * conversation is released too. handleDisconnect only releases the channel of the socket's
   * current conversation, so the old channel would otherwise never be removed.
   */
  private attachClientToConversation(client: Socket, conversationId: string): void {
    const previous = this.clientSessions.get(client.id);
    if (previous) {
      previous.cleanup();
      if (previous.conversationId !== conversationId) {
        this.agentService.removeChannel(previous.conversationId, this.channelId(client));
      }
    }

    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });
    this.clientSessions.set(client.id, { conversationId, cleanup });

    // Track channel connection
    this.agentService.addChannel(conversationId, this.channelId(client));
  }

  /** Emits `session:info` for the conversation's session; does nothing if there is none. */
  private emitSessionInfo(client: Socket, conversationId: string): void {
    const agentSession = this.agentService.getSession(conversationId);
    if (!agentSession) {
      return;
    }
    const piSession = agentSession.piSession;
    client.emit('session:info', {
      conversationId,
      provider: agentSession.provider,
      modelId: agentSession.modelId,
      thinkingLevel: piSession.thinkingLevel,
      availableThinkingLevels: piSession.getAvailableThinkingLevels(),
    });
  }

  /**
   * Translates an agent session event into the corresponding socket event.
   *
   * Events for disconnected sockets are dropped with a warning. Event types not listed in the
   * switch are not relayed.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start':
        client.emit('agent:start', { conversationId });
        break;

      case 'agent_end': {
        // Gather usage stats from the Pi session
        const agentSession = this.agentService.getSession(conversationId);
        const piSession = agentSession?.piSession;
        const stats = piSession?.getSessionStats();
        const contextUsage = piSession?.getContextUsage();

        client.emit('agent:end', {
          conversationId,
          usage: stats
            ? {
                provider: agentSession?.provider ?? 'unknown',
                modelId: agentSession?.modelId ?? 'unknown',
                thinkingLevel: piSession?.thinkingLevel ?? 'off',
                tokens: stats.tokens,
                cost: stats.cost,
                context: {
                  percent: contextUsage?.percent ?? null,
                  window: contextUsage?.contextWindow ?? 0,
                },
              }
            : undefined,
        });
        break;
      }

      case 'message_update': {
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }

      case 'tool_execution_start':
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;

      case 'tool_execution_end':
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;
    }
  }
}