import { Logger } from '@nestjs/common';
import {
  WebSocketGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
  type OnGatewayInit,
  ConnectedSocket,
  MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import { AgentService } from '../agent/agent.service.js';
import { v4 as uuid } from 'uuid';

/** Payload of an inbound `message` event. */
interface ChatMessage {
  /** Conversation to continue; when absent or null, a new conversation id is generated. */
  conversationId?: string | null;
  /** User text forwarded to the agent. */
  content: string;
}

/**
 * Runtime guard for inbound `message` payloads.
 *
 * Socket.IO delivers whatever the client sent, so the static `ChatMessage` type is not a
 * guarantee. This check must pass before any field of the payload is read.
 */
function isChatMessage(value: unknown): value is ChatMessage {
  if (typeof value !== 'object' || value === null) {
    return false;
  }
  const { conversationId, content } = value as Record<string, unknown>;
  return (
    typeof content === 'string' &&
    (conversationId == null || typeof conversationId === 'string')
  );
}

/**
 * Socket.IO gateway on the `/chat` namespace.
 *
 * It forwards user messages to the agent and relays the agent's session events back to the
 * client that sent the message. Each socket has at most one live agent-event subscription,
 * tracked in `clientSessions`. That subscription is replaced on every new message and torn
 * down when the socket disconnects.
 */
@WebSocketGateway({
  // NOTE(review): any origin may connect; confirm this is intended outside local development.
  cors: { origin: '*' },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /** socket id -> the conversation it is subscribed to and the function that unsubscribes it. */
  private readonly clientSessions = new Map<
    string,
    { conversationId: string; cleanup: () => void }
  >();

  constructor(private readonly agentService: AgentService) {}

  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  handleConnection(client: Socket): void {
    this.logger.log(`Client connected: ${client.id}`);
  }

  /** Removes the disconnecting client's agent-event subscription, if it has one. */
  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);
    this.unsubscribe(client.id);
  }

  /**
   * Handles a `message` event.
   *
   * Steps:
   * 1. Validate the payload.
   * 2. Ensure an agent session exists for the conversation.
   * 3. Subscribe the client to that session's events.
   * 4. Acknowledge with `message:ack`.
   * 5. Prompt the agent.
   *
   * Failures are reported to the client as an `error` event and are never thrown.
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: unknown,
  ): Promise<void> {
    if (!isChatMessage(data)) {
      this.logger.warn(`Rejected malformed message from client=${client.id}`);
      client.emit('error', { error: 'Invalid message payload.' });
      return;
    }

    const conversationId = data.conversationId ?? uuid();
    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // NOTE(review): a client-supplied conversationId is trusted as-is. Any client that knows an
    // id can subscribe to that conversation's events. Confirm ownership is enforced elsewhere.
    if (!(await this.ensureSession(client, conversationId))) {
      return;
    }

    // handleDisconnect may already have run while the session was being created. Subscribing
    // now would register a listener that nothing ever removes. No ack has been sent yet, so
    // the client has not been told this message was accepted.
    if (!client.connected) {
      this.logger.warn(
        `Client ${client.id} disconnected before conversation ${conversationId} was ready; not subscribing`,
      );
      return;
    }

    this.subscribe(client, conversationId);

    client.emit('message:ack', { conversationId, messageId: uuid() });

    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Creates an agent session for `conversationId` unless `agentService.getSession` already
   * returns one.
   *
   * On failure it logs the error, emits an `error` event to the client, and resolves to
   * `false`. On success it resolves to `true`.
   */
  private async ensureSession(client: Socket, conversationId: string): Promise<boolean> {
    try {
      if (!this.agentService.getSession(conversationId)) {
        await this.agentService.createSession(conversationId);
      }
      return true;
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return false;
    }
  }

  /**
   * Points the client's single subscription at `conversationId`.
   *
   * Any listener left from the client's previous message is removed first, so it cannot leak.
   */
  private subscribe(client: Socket, conversationId: string): void {
    this.unsubscribe(client.id);
    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });
    this.clientSessions.set(client.id, { conversationId, cleanup });
  }

  /** Runs and forgets the stored cleanup for `clientId`; does nothing if there is none. */
  private unsubscribe(clientId: string): void {
    const session = this.clientSessions.get(clientId);
    if (session) {
      session.cleanup();
      this.clientSessions.delete(clientId);
    }
  }

  /**
   * Translates one agent session event into the client-facing `agent:*` socket events.
   *
   * Events arriving for a disconnected client are logged and dropped. Event types not
   * listed in the switch are not forwarded.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start':
        client.emit('agent:start', { conversationId });
        break;
      case 'agent_end':
        client.emit('agent:end', { conversationId });
        break;
      case 'message_update': {
        // Only text and thinking deltas are streamed; other assistant sub-events are ignored.
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }
      case 'tool_execution_start':
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;
      case 'tool_execution_end':
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;
    }
  }
}