Co-authored-by: Jason Woltje <jason@diversecanvas.com> Co-committed-by: Jason Woltje <jason@diversecanvas.com>
165 lines
4.6 KiB
TypeScript
165 lines
4.6 KiB
TypeScript
import { Logger } from '@nestjs/common';
|
|
import {
|
|
WebSocketGateway,
|
|
WebSocketServer,
|
|
SubscribeMessage,
|
|
OnGatewayConnection,
|
|
OnGatewayDisconnect,
|
|
type OnGatewayInit,
|
|
ConnectedSocket,
|
|
MessageBody,
|
|
} from '@nestjs/websockets';
|
|
import { Server, Socket } from 'socket.io';
|
|
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
|
|
import { AgentService } from '../agent/agent.service.js';
|
|
import { v4 as uuid } from 'uuid';
|
|
|
|
/**
 * Payload of an inbound `message` event on the `/chat` namespace.
 *
 * NOTE(review): this shape is only a compile-time description; the socket
 * delivers whatever the client sent, so handlers must validate at runtime.
 */
interface ChatMessage {
  /** User text to forward to the agent. */
  content: string;
  /** Target conversation; when omitted the gateway starts a new one. */
  conversationId?: string;
}
|
|
|
|
/**
 * Socket.IO gateway for the `/chat` namespace.
 *
 * Forwards chat messages from connected clients to the agent service and
 * relays the resulting agent session events back to the sending client.
 * Each socket holds at most one agent-event subscription at a time: it is
 * replaced on every new message and released when the socket disconnects.
 */
@WebSocketGateway({
  // NOTE(review): a wildcard origin accepts connections from any site —
  // confirm this is intended outside of local development.
  cors: { origin: '*' },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /** Active agent-event subscription, keyed by socket id. */
  private readonly clientSessions = new Map<
    string,
    { conversationId: string; cleanup: () => void }
  >();

  constructor(private readonly agentService: AgentService) {}

  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  handleConnection(client: Socket): void {
    this.logger.log(`Client connected: ${client.id}`);
  }

  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);
    this.releaseSubscription(client.id);
  }

  /**
   * Handles an inbound chat message.
   *
   * Ensures an agent session exists for the conversation (creating a new
   * conversation id when none is given), subscribes the client to that
   * session's events, acknowledges the message, then prompts the agent.
   * Failures are reported to the client via an `error` event, not thrown.
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ChatMessage,
  ): Promise<void> {
    // The body arrives straight off the socket; the ChatMessage annotation is
    // not enforced at runtime, so reject malformed payloads up front.
    if (!this.isValidMessage(data)) {
      this.logger.warn(`Rejected malformed message from client=${client.id}`);
      client.emit('error', { error: 'Invalid message payload.' });
      return;
    }

    const conversationId = data.conversationId ?? uuid();
    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // Ensure an agent session exists for this conversation.
    try {
      if (!this.agentService.getSession(conversationId)) {
        await this.agentService.createSession(conversationId);
      }
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return;
    }

    // The client may have disconnected while the session was being created.
    // handleDisconnect has then already run, so subscribing now would leave a
    // listener in clientSessions that nothing ever releases.
    if (!client.connected) {
      this.logger.warn(
        `Client ${client.id} disconnected before subscription, conversation=${conversationId}`,
      );
      return;
    }

    // Replace any previous subscription so listeners do not accumulate.
    this.releaseSubscription(client.id);
    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });
    this.clientSessions.set(client.id, { conversationId, cleanup });

    client.emit('message:ack', { conversationId, messageId: uuid() });

    // Dispatch to the agent; output is streamed back through the subscription.
    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Runtime shape check for an inbound message.
   *
   * Requires a string `content`. `conversationId` may be absent, null, or a
   * string; null falls through to the `??` default in handleMessage.
   */
  private isValidMessage(data: unknown): data is ChatMessage {
    if (typeof data !== 'object' || data === null) {
      return false;
    }
    const { conversationId, content } = data as Record<string, unknown>;
    return (
      typeof content === 'string' &&
      (conversationId == null || typeof conversationId === 'string')
    );
  }

  /** Unsubscribes and forgets the client's current agent-event listener, if any. */
  private releaseSubscription(clientId: string): void {
    const session = this.clientSessions.get(clientId);
    if (!session) {
      return;
    }
    session.cleanup();
    this.clientSessions.delete(clientId);
  }

  /**
   * Translates an agent session event into a client-facing socket event.
   * Events are dropped (with a warning) if the client is no longer connected.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start':
        client.emit('agent:start', { conversationId });
        break;

      case 'agent_end':
        client.emit('agent:end', { conversationId });
        break;

      case 'message_update': {
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }

      case 'tool_execution_start':
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;

      case 'tool_execution_end':
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;

      default:
        // Other agent event types are intentionally not forwarded to clients.
        break;
    }
  }
}
|