diff --git a/apps/gateway/src/chat/chat.gateway.ts b/apps/gateway/src/chat/chat.gateway.ts
index 8d68249..1d4916e 100644
--- a/apps/gateway/src/chat/chat.gateway.ts
+++ b/apps/gateway/src/chat/chat.gateway.ts
@@ -12,15 +12,29 @@ import {
 import { Server, Socket } from 'socket.io';
 import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
 import type { Auth } from '@mosaic/auth';
+import type { Brain } from '@mosaic/brain';
 import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
 import { AgentService } from '../agent/agent.service.js';
 import { AUTH } from '../auth/auth.tokens.js';
+import { BRAIN } from '../brain/brain.tokens.js';
 import { CommandRegistryService } from '../commands/command-registry.service.js';
 import { CommandExecutorService } from '../commands/command-executor.service.js';
 import { v4 as uuid } from 'uuid';
 import { ChatSocketMessageDto } from './chat.dto.js';
 import { validateSocketSession } from './chat.gateway-auth.js';
 
+/** Per-client state tracking streaming accumulation for persistence. */
+interface ClientSession {
+  conversationId: string;
+  cleanup: () => void;
+  /** Accumulated assistant response text for the current turn. */
+  assistantText: string;
+  /** Tool calls observed during the current turn. */
+  toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
+  /** Tool calls in-flight (started but not ended yet). */
+  pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
+}
+
 @WebSocketGateway({
   cors: {
     origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
@@ -32,14 +46,12 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
   server!: Server;
 
   private readonly logger = new Logger(ChatGateway.name);
-  private readonly clientSessions = new Map<
-    string,
-    { conversationId: string; cleanup: () => void }
-  >();
+  private readonly clientSessions = new Map<string, ClientSession>();
 
   constructor(
     @Inject(AgentService) private readonly agentService: AgentService,
     @Inject(AUTH) private readonly auth: Auth,
+    @Inject(BRAIN) private readonly brain: Brain,
     @Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
     @Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
   ) {}
@@ -80,6 +92,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
     @MessageBody() data: ChatSocketMessageDto,
   ): Promise<void> {
     const conversationId = data.conversationId ?? uuid();
+    const userId = (client.data.user as { id: string } | undefined)?.id;
 
     this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);
 
@@ -87,7 +100,6 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
     try {
       let agentSession = this.agentService.getSession(conversationId);
       if (!agentSession) {
-        const userId = (client.data.user as { id: string } | undefined)?.id;
         agentSession = await this.agentService.createSession(conversationId, {
           provider: data.provider,
           modelId: data.modelId,
@@ -107,6 +119,30 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
       return;
     }
 
+    // Ensure conversation record exists in the DB before persisting messages
+    if (userId) {
+      await this.ensureConversation(conversationId, userId);
+    }
+
+    // Persist the user message
+    if (userId) {
+      try {
+        await this.brain.conversations.addMessage({
+          conversationId,
+          role: 'user',
+          content: data.content,
+          metadata: {
+            timestamp: new Date().toISOString(),
+          },
+        });
+      } catch (err) {
+        this.logger.error(
+          `Failed to persist user message for conversation=${conversationId}`,
+          err instanceof Error ? err.stack : String(err),
+        );
+      }
+    }
+
     // Always clean up previous listener to prevent leak
     const existing = this.clientSessions.get(client.id);
     if (existing) {
@@ -118,7 +154,13 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
       this.relayEvent(client, conversationId, event);
     });
 
-    this.clientSessions.set(client.id, { conversationId, cleanup });
+    this.clientSessions.set(client.id, {
+      conversationId,
+      cleanup,
+      assistantText: '',
+      toolCalls: [],
+      pendingToolCalls: new Map(),
+    });
 
     // Track channel connection
     this.agentService.addChannel(conversationId, `websocket:${client.id}`);
@@ -208,6 +250,28 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
     this.logger.log('Broadcasted system:reload to all connected clients');
   }
 
+  /**
+   * Ensure a conversation record exists in the DB.
+   * Creates it if absent — safe to call concurrently since a duplicate insert
+   * would fail on the PK constraint and be caught here.
+   */
+  private async ensureConversation(conversationId: string, userId: string): Promise<void> {
+    try {
+      const existing = await this.brain.conversations.findById(conversationId);
+      if (!existing) {
+        await this.brain.conversations.create({
+          id: conversationId,
+          userId,
+        });
+      }
+    } catch (err) {
+      this.logger.error(
+        `Failed to ensure conversation record for conversation=${conversationId}`,
+        err instanceof Error ? err.stack : String(err),
+      );
+    }
+  }
+
   private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
     if (!client.connected) {
       this.logger.warn(
@@ -217,9 +281,17 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
     }
 
     switch (event.type) {
-      case 'agent_start':
+      case 'agent_start': {
+        // Reset accumulation buffers for the new turn
+        const cs = this.clientSessions.get(client.id);
+        if (cs) {
+          cs.assistantText = '';
+          cs.toolCalls = [];
+          cs.pendingToolCalls.clear();
+        }
         client.emit('agent:start', { conversationId });
         break;
+      }
 
       case 'agent_end': {
         // Gather usage stats from the Pi session
@@ -228,28 +300,76 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
         const stats = piSession?.getSessionStats();
         const contextUsage = piSession?.getContextUsage();
 
+        const usagePayload = stats
+          ? {
+              provider: agentSession?.provider ?? 'unknown',
+              modelId: agentSession?.modelId ?? 'unknown',
+              thinkingLevel: piSession?.thinkingLevel ?? 'off',
+              tokens: stats.tokens,
+              cost: stats.cost,
+              context: {
+                percent: contextUsage?.percent ?? null,
+                window: contextUsage?.contextWindow ?? 0,
+              },
+            }
+          : undefined;
+
         client.emit('agent:end', {
           conversationId,
-          usage: stats
-            ? {
-                provider: agentSession?.provider ?? 'unknown',
-                modelId: agentSession?.modelId ?? 'unknown',
-                thinkingLevel: piSession?.thinkingLevel ?? 'off',
-                tokens: stats.tokens,
-                cost: stats.cost,
-                context: {
-                  percent: contextUsage?.percent ?? null,
-                  window: contextUsage?.contextWindow ?? 0,
-                },
-              }
-            : undefined,
+          usage: usagePayload,
         });
+
+        // Persist the assistant message with metadata
+        const cs = this.clientSessions.get(client.id);
+        const userId = (client.data.user as { id: string } | undefined)?.id;
+        if (cs && userId && cs.assistantText.trim().length > 0) {
+          const metadata: Record<string, unknown> = {
+            timestamp: new Date().toISOString(),
+            model: agentSession?.modelId ?? 'unknown',
+            provider: agentSession?.provider ?? 'unknown',
+            toolCalls: cs.toolCalls,
+          };
+
+          if (stats?.tokens) {
+            metadata['tokenUsage'] = {
+              input: stats.tokens.input,
+              output: stats.tokens.output,
+              cacheRead: stats.tokens.cacheRead,
+              cacheWrite: stats.tokens.cacheWrite,
+              total: stats.tokens.total,
+            };
+          }
+
+          this.brain.conversations
+            .addMessage({
+              conversationId,
+              role: 'assistant',
+              content: cs.assistantText,
+              metadata,
+            })
+            .catch((err: unknown) => {
+              this.logger.error(
+                `Failed to persist assistant message for conversation=${conversationId}`,
+                err instanceof Error ? err.stack : String(err),
+              );
+            });
+
+          // Reset accumulation
+          cs.assistantText = '';
+          cs.toolCalls = [];
+          cs.pendingToolCalls.clear();
+        }
         break;
       }
 
       case 'message_update': {
         const assistantEvent = event.assistantMessageEvent;
         if (assistantEvent.type === 'text_delta') {
+          // Accumulate assistant text for persistence
+          const cs = this.clientSessions.get(client.id);
+          if (cs) {
+            cs.assistantText += assistantEvent.delta;
+          }
           client.emit('agent:text', {
             conversationId,
             text: assistantEvent.delta,
@@ -263,15 +383,36 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
         break;
       }
 
-      case 'tool_execution_start':
+      case 'tool_execution_start': {
+        // Track pending tool call for later recording
+        const cs = this.clientSessions.get(client.id);
+        if (cs) {
+          cs.pendingToolCalls.set(event.toolCallId, {
+            toolName: event.toolName,
+            args: event.args,
+          });
+        }
         client.emit('agent:tool:start', {
           conversationId,
           toolCallId: event.toolCallId,
           toolName: event.toolName,
         });
         break;
+      }
 
-      case 'tool_execution_end':
+      case 'tool_execution_end': {
+        // Finalise tool call record
+        const cs = this.clientSessions.get(client.id);
+        if (cs) {
+          const pending = cs.pendingToolCalls.get(event.toolCallId);
+          cs.toolCalls.push({
+            toolCallId: event.toolCallId,
+            toolName: event.toolName,
+            args: pending?.args ?? null,
+            isError: event.isError,
+          });
+          cs.pendingToolCalls.delete(event.toolCallId);
+        }
         client.emit('agent:tool:end', {
           conversationId,
           toolCallId: event.toolCallId,
@@ -279,6 +420,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
           isError: event.isError,
         });
         break;
+      }
     }
   }
 }