// Source-viewer metadata captured with this file (not part of the program):
//   All checks were successful — ci/woodpecker/push/ci: Pipeline was successful
//   Co-authored-by: Jason Woltje <jason@diversecanvas.com>
//   Co-committed-by: Jason Woltje <jason@diversecanvas.com>
//   285 lines · 9.3 KiB · TypeScript
import { Inject, Logger } from '@nestjs/common';
|
|
import {
|
|
WebSocketGateway,
|
|
WebSocketServer,
|
|
SubscribeMessage,
|
|
OnGatewayConnection,
|
|
OnGatewayDisconnect,
|
|
type OnGatewayInit,
|
|
ConnectedSocket,
|
|
MessageBody,
|
|
} from '@nestjs/websockets';
|
|
import { Server, Socket } from 'socket.io';
|
|
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
|
|
import type { Auth } from '@mosaic/auth';
|
|
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
|
|
import { AgentService } from '../agent/agent.service.js';
|
|
import { AUTH } from '../auth/auth.tokens.js';
|
|
import { CommandRegistryService } from '../commands/command-registry.service.js';
|
|
import { CommandExecutorService } from '../commands/command-executor.service.js';
|
|
import { v4 as uuid } from 'uuid';
|
|
import { ChatSocketMessageDto } from './chat.dto.js';
|
|
import { validateSocketSession } from './chat.gateway-auth.js';
|
|
|
|
/**
 * Socket.IO gateway for the `/chat` namespace.
 *
 * Responsibilities visible in this class:
 * - validates each incoming socket with `validateSocketSession` and stores the
 *   resulting user/session on `client.data`;
 * - on `message`, ensures an agent session exists for the conversation,
 *   subscribes the socket to that session's events and forwards the prompt;
 * - on `set:thinking`, changes the thinking level of an existing session;
 * - on `command:execute`, runs a slash command through `CommandExecutorService`;
 * - relays a subset of `AgentSessionEvent`s to the client as `agent:*` events.
 *
 * The allowed CORS origin is read from `GATEWAY_CORS_ORIGIN`, defaulting to
 * `http://localhost:3000`.
 */
@WebSocketGateway({
  cors: {
    origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
  },
  namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(ChatGateway.name);

  /**
   * Per-socket bookkeeping, keyed by `client.id`: the conversation the socket
   * is currently subscribed to and the function that removes its event listener.
   */
  private readonly clientSessions = new Map<
    string,
    { conversationId: string; cleanup: () => void }
  >();

  constructor(
    @Inject(AgentService) private readonly agentService: AgentService,
    @Inject(AUTH) private readonly auth: Auth,
    @Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
    @Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
  ) {}

  /** Logs that the gateway has been initialized. */
  afterInit(): void {
    this.logger.log('Chat WebSocket gateway initialized');
  }

  /**
   * Authenticates a newly connected socket.
   *
   * The socket is disconnected when validation yields no session or when
   * validation itself throws. On success the user and session are stored on
   * `client.data` and the command manifest is sent to the client.
   *
   * @param client - The socket that just connected.
   */
  async handleConnection(client: Socket): Promise<void> {
    let session: Awaited<ReturnType<typeof validateSocketSession>>;
    try {
      session = await validateSocketSession(client.handshake.headers, this.auth);
    } catch (err) {
      // A failure inside the validator must not leave the socket connected
      // without `client.data.user`; treat it the same as a rejected session.
      this.logger.error(
        `Session validation failed for WebSocket client: ${client.id}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.disconnect();
      return;
    }

    if (!session) {
      this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
      client.disconnect();
      return;
    }

    client.data.user = session.user;
    client.data.session = session.session;
    this.logger.log(`Client connected: ${client.id}`);

    // Send the command manifest to the newly connected client only.
    client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
  }

  /**
   * Releases everything tracked for a socket that has gone away: its event
   * listener, its channel registration and its `clientSessions` entry.
   *
   * @param client - The socket that disconnected.
   */
  handleDisconnect(client: Socket): void {
    this.logger.log(`Client disconnected: ${client.id}`);

    const session = this.clientSessions.get(client.id);
    if (session) {
      session.cleanup();
      this.agentService.removeChannel(session.conversationId, `websocket:${client.id}`);
      this.clientSessions.delete(client.id);
    }
  }

  /**
   * Handles a chat message from the client.
   *
   * Steps, in order: resolve the conversation id (a new UUID when the payload
   * has none), create the agent session if it does not exist yet, replace the
   * socket's event subscription, register the socket as a channel, emit
   * `session:info` and `message:ack`, then forward the content to the agent.
   * Failures while creating the session or prompting the agent are logged and
   * reported to the client through an `error` event.
   *
   * @param client - The socket that sent the message.
   * @param data - Message payload (content plus optional conversation, provider,
   *   model and agent identifiers).
   */
  @SubscribeMessage('message')
  async handleMessage(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: ChatSocketMessageDto,
  ): Promise<void> {
    const conversationId = data.conversationId ?? uuid();
    const channelId = `websocket:${client.id}`;

    this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

    // Ensure agent session exists for this conversation
    try {
      let agentSession = this.agentService.getSession(conversationId);
      if (!agentSession) {
        const userId = (client.data.user as { id: string } | undefined)?.id;
        agentSession = await this.agentService.createSession(conversationId, {
          provider: data.provider,
          modelId: data.modelId,
          agentConfigId: data.agentId,
          userId,
        });
      }
    } catch (err) {
      this.logger.error(
        `Session creation failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'Failed to start agent session. Please try again.',
      });
      return;
    }

    // Always clean up the previous listener to prevent a leak.
    const existing = this.clientSessions.get(client.id);
    if (existing) {
      existing.cleanup();
      // When the socket moves to another conversation, also drop its channel
      // registration on the old one: handleDisconnect only knows about the most
      // recent conversation, so the old registration would never be removed.
      if (existing.conversationId !== conversationId) {
        this.agentService.removeChannel(existing.conversationId, channelId);
      }
    }

    // Subscribe to agent events and relay them to the client.
    const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
      this.relayEvent(client, conversationId, event);
    });
    this.clientSessions.set(client.id, { conversationId, cleanup });

    // Track channel connection
    this.agentService.addChannel(conversationId, channelId);

    // Send session info so the client knows the model/provider.
    const agentSession = this.agentService.getSession(conversationId);
    if (agentSession) {
      this.emitSessionInfo(client, conversationId, agentSession);
    }

    // Send acknowledgment
    client.emit('message:ack', { conversationId, messageId: uuid() });

    // Dispatch to agent
    try {
      await this.agentService.prompt(conversationId, data.content);
    } catch (err) {
      this.logger.error(
        `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
      client.emit('error', {
        conversationId,
        error: 'The agent failed to process your message. Please try again.',
      });
    }
  }

  /**
   * Changes the thinking level of an existing agent session.
   *
   * Emits `error` when the conversation has no active session or when the
   * requested level is not among the session's available levels; otherwise
   * applies the level and emits a fresh `session:info`.
   *
   * @param client - The socket that sent the request.
   * @param data - Conversation id and requested thinking level.
   */
  @SubscribeMessage('set:thinking')
  handleSetThinking(
    @ConnectedSocket() client: Socket,
    @MessageBody() data: SetThinkingPayload,
  ): void {
    const session = this.agentService.getSession(data.conversationId);
    if (!session) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: 'No active session for this conversation.',
      });
      return;
    }

    const validLevels = session.piSession.getAvailableThinkingLevels();
    // NOTE(review): the `as never` casts bridge the payload's level type and the
    // Pi session's level type; the `includes` check is the runtime validation.
    if (!validLevels.includes(data.level as never)) {
      client.emit('error', {
        conversationId: data.conversationId,
        error: `Invalid thinking level "${data.level}". Available: ${validLevels.join(', ')}`,
      });
      return;
    }

    session.piSession.setThinkingLevel(data.level as never);
    this.logger.log(
      `Thinking level set to "${data.level}" for conversation ${data.conversationId}`,
    );

    this.emitSessionInfo(client, data.conversationId, session);
  }

  /**
   * Executes a slash command and returns its result to the sender via
   * `command:result`. The user id falls back to `'unknown'` when the socket
   * carries no user.
   *
   * @param client - The socket that sent the command.
   * @param payload - The slash command to execute.
   */
  @SubscribeMessage('command:execute')
  async handleCommandExecute(
    @ConnectedSocket() client: Socket,
    @MessageBody() payload: SlashCommandPayload,
  ): Promise<void> {
    const userId = (client.data.user as { id: string } | undefined)?.id ?? 'unknown';
    const result = await this.commandExecutor.execute(payload, userId);
    client.emit('command:result', result);
  }

  /**
   * Emits `system:reload` with the given payload through the gateway's server.
   *
   * @param payload - Reload details forwarded unchanged to clients.
   */
  broadcastReload(payload: SystemReloadPayload): void {
    this.server.emit('system:reload', payload);
    this.logger.log('Broadcasted system:reload to all connected clients');
  }

  /**
   * Emits `session:info` describing the provider, model and thinking
   * configuration of an agent session.
   *
   * @param client - Socket that receives the event.
   * @param conversationId - Conversation the session belongs to.
   * @param agentSession - The session whose details are reported.
   */
  private emitSessionInfo(
    client: Socket,
    conversationId: string,
    agentSession: NonNullable<ReturnType<AgentService['getSession']>>,
  ): void {
    const piSession = agentSession.piSession;
    client.emit('session:info', {
      conversationId,
      provider: agentSession.provider,
      modelId: agentSession.modelId,
      thinkingLevel: piSession.thinkingLevel,
      availableThinkingLevels: piSession.getAvailableThinkingLevels(),
    });
  }

  /**
   * Translates an agent session event into the matching `agent:*` socket event.
   *
   * Events for a socket that is no longer connected are dropped with a warning.
   * Event types without a case below are ignored.
   *
   * @param client - Socket that receives the relayed event.
   * @param conversationId - Conversation the event belongs to.
   * @param event - The agent session event to relay.
   */
  private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
    if (!client.connected) {
      this.logger.warn(
        `Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
      );
      return;
    }

    switch (event.type) {
      case 'agent_start':
        client.emit('agent:start', { conversationId });
        break;

      case 'agent_end': {
        // Gather usage stats from the Pi session; `usage` is omitted when the
        // session (or its stats) is unavailable.
        const agentSession = this.agentService.getSession(conversationId);
        const piSession = agentSession?.piSession;
        const stats = piSession?.getSessionStats();
        const contextUsage = piSession?.getContextUsage();

        client.emit('agent:end', {
          conversationId,
          usage: stats
            ? {
                provider: agentSession?.provider ?? 'unknown',
                modelId: agentSession?.modelId ?? 'unknown',
                thinkingLevel: piSession?.thinkingLevel ?? 'off',
                tokens: stats.tokens,
                cost: stats.cost,
                context: {
                  percent: contextUsage?.percent ?? null,
                  window: contextUsage?.contextWindow ?? 0,
                },
              }
            : undefined,
        });
        break;
      }

      case 'message_update': {
        // Only text and thinking deltas are forwarded; other assistant message
        // event types are not relayed.
        const assistantEvent = event.assistantMessageEvent;
        if (assistantEvent.type === 'text_delta') {
          client.emit('agent:text', {
            conversationId,
            text: assistantEvent.delta,
          });
        } else if (assistantEvent.type === 'thinking_delta') {
          client.emit('agent:thinking', {
            conversationId,
            text: assistantEvent.delta,
          });
        }
        break;
      }

      case 'tool_execution_start':
        client.emit('agent:tool:start', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
        });
        break;

      case 'tool_execution_end':
        client.emit('agent:tool:end', {
          conversationId,
          toolCallId: event.toolCallId,
          toolName: event.toolName,
          isError: event.isError,
        });
        break;
    }
  }
}
|