feat(chat): persist messages to DB via ConversationsRepo (M1-001/002/003)
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/pr/ci Pipeline was successful

Wire ChatGateway to ConversationsRepo so every user message is saved on
receipt (M1-001) and every assistant response is saved on agent:end with
accumulated streaming text (M1-002). Metadata includes model, provider,
tokenUsage (input/output/cache counts), toolCalls array, and timestamp
on each message record (M1-003). Auto-creates the conversation row when
the agent session is new.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-21 15:17:06 -05:00
parent 36095ad80f
commit ef3529a587

View File

@@ -12,15 +12,29 @@ import {
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import type { Auth } from '@mosaic/auth';
import type { Brain } from '@mosaic/brain';
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
import { AgentService } from '../agent/agent.service.js';
import { AUTH } from '../auth/auth.tokens.js';
import { BRAIN } from '../brain/brain.tokens.js';
import { CommandRegistryService } from '../commands/command-registry.service.js';
import { CommandExecutorService } from '../commands/command-executor.service.js';
import { v4 as uuid } from 'uuid';
import { ChatSocketMessageDto } from './chat.dto.js';
import { validateSocketSession } from './chat.gateway-auth.js';
/** Per-client state tracking streaming accumulation for persistence. */
interface ClientSession {
conversationId: string;
cleanup: () => void;
/** Accumulated assistant response text for the current turn. */
assistantText: string;
/** Tool calls observed during the current turn. */
toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
/** Tool calls in-flight (started but not ended yet). */
pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
}
@WebSocketGateway({
cors: {
origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
@@ -32,14 +46,12 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
server!: Server;
private readonly logger = new Logger(ChatGateway.name);
private readonly clientSessions = new Map<
string,
{ conversationId: string; cleanup: () => void }
>();
private readonly clientSessions = new Map<string, ClientSession>();
constructor(
@Inject(AgentService) private readonly agentService: AgentService,
@Inject(AUTH) private readonly auth: Auth,
@Inject(BRAIN) private readonly brain: Brain,
@Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
@Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
) {}
@@ -80,6 +92,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
@MessageBody() data: ChatSocketMessageDto,
): Promise<void> {
const conversationId = data.conversationId ?? uuid();
const userId = (client.data.user as { id: string } | undefined)?.id;
this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);
@@ -87,7 +100,6 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
try {
let agentSession = this.agentService.getSession(conversationId);
if (!agentSession) {
const userId = (client.data.user as { id: string } | undefined)?.id;
agentSession = await this.agentService.createSession(conversationId, {
provider: data.provider,
modelId: data.modelId,
@@ -107,6 +119,30 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
return;
}
// Ensure conversation record exists in the DB before persisting messages
if (userId) {
await this.ensureConversation(conversationId, userId);
}
// Persist the user message
if (userId) {
try {
await this.brain.conversations.addMessage({
conversationId,
role: 'user',
content: data.content,
metadata: {
timestamp: new Date().toISOString(),
},
});
} catch (err) {
this.logger.error(
`Failed to persist user message for conversation=${conversationId}`,
err instanceof Error ? err.stack : String(err),
);
}
}
// Always clean up previous listener to prevent leak
const existing = this.clientSessions.get(client.id);
if (existing) {
@@ -118,7 +154,13 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
this.relayEvent(client, conversationId, event);
});
this.clientSessions.set(client.id, { conversationId, cleanup });
this.clientSessions.set(client.id, {
conversationId,
cleanup,
assistantText: '',
toolCalls: [],
pendingToolCalls: new Map(),
});
// Track channel connection
this.agentService.addChannel(conversationId, `websocket:${client.id}`);
@@ -208,6 +250,28 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
this.logger.log('Broadcasted system:reload to all connected clients');
}
/**
* Ensure a conversation record exists in the DB.
* Creates it if absent — safe to call concurrently since a duplicate insert
* would fail on the PK constraint and be caught here.
*/
private async ensureConversation(conversationId: string, userId: string): Promise<void> {
try {
const existing = await this.brain.conversations.findById(conversationId);
if (!existing) {
await this.brain.conversations.create({
id: conversationId,
userId,
});
}
} catch (err) {
this.logger.error(
`Failed to ensure conversation record for conversation=${conversationId}`,
err instanceof Error ? err.stack : String(err),
);
}
}
private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
if (!client.connected) {
this.logger.warn(
@@ -217,9 +281,17 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
}
switch (event.type) {
case 'agent_start':
case 'agent_start': {
// Reset accumulation buffers for the new turn
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.assistantText = '';
cs.toolCalls = [];
cs.pendingToolCalls.clear();
}
client.emit('agent:start', { conversationId });
break;
}
case 'agent_end': {
// Gather usage stats from the Pi session
@@ -228,28 +300,76 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
const stats = piSession?.getSessionStats();
const contextUsage = piSession?.getContextUsage();
const usagePayload = stats
? {
provider: agentSession?.provider ?? 'unknown',
modelId: agentSession?.modelId ?? 'unknown',
thinkingLevel: piSession?.thinkingLevel ?? 'off',
tokens: stats.tokens,
cost: stats.cost,
context: {
percent: contextUsage?.percent ?? null,
window: contextUsage?.contextWindow ?? 0,
},
}
: undefined;
client.emit('agent:end', {
conversationId,
usage: stats
? {
provider: agentSession?.provider ?? 'unknown',
modelId: agentSession?.modelId ?? 'unknown',
thinkingLevel: piSession?.thinkingLevel ?? 'off',
tokens: stats.tokens,
cost: stats.cost,
context: {
percent: contextUsage?.percent ?? null,
window: contextUsage?.contextWindow ?? 0,
},
}
: undefined,
usage: usagePayload,
});
// Persist the assistant message with metadata
const cs = this.clientSessions.get(client.id);
const userId = (client.data.user as { id: string } | undefined)?.id;
if (cs && userId && cs.assistantText.trim().length > 0) {
const metadata: Record<string, unknown> = {
timestamp: new Date().toISOString(),
model: agentSession?.modelId ?? 'unknown',
provider: agentSession?.provider ?? 'unknown',
toolCalls: cs.toolCalls,
};
if (stats?.tokens) {
metadata['tokenUsage'] = {
input: stats.tokens.input,
output: stats.tokens.output,
cacheRead: stats.tokens.cacheRead,
cacheWrite: stats.tokens.cacheWrite,
total: stats.tokens.total,
};
}
this.brain.conversations
.addMessage({
conversationId,
role: 'assistant',
content: cs.assistantText,
metadata,
})
.catch((err: unknown) => {
this.logger.error(
`Failed to persist assistant message for conversation=${conversationId}`,
err instanceof Error ? err.stack : String(err),
);
});
// Reset accumulation
cs.assistantText = '';
cs.toolCalls = [];
cs.pendingToolCalls.clear();
}
break;
}
case 'message_update': {
const assistantEvent = event.assistantMessageEvent;
if (assistantEvent.type === 'text_delta') {
// Accumulate assistant text for persistence
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.assistantText += assistantEvent.delta;
}
client.emit('agent:text', {
conversationId,
text: assistantEvent.delta,
@@ -263,15 +383,36 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
break;
}
case 'tool_execution_start':
case 'tool_execution_start': {
// Track pending tool call for later recording
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.pendingToolCalls.set(event.toolCallId, {
toolName: event.toolName,
args: event.args,
});
}
client.emit('agent:tool:start', {
conversationId,
toolCallId: event.toolCallId,
toolName: event.toolName,
});
break;
}
case 'tool_execution_end':
case 'tool_execution_end': {
// Finalise tool call record
const cs = this.clientSessions.get(client.id);
if (cs) {
const pending = cs.pendingToolCalls.get(event.toolCallId);
cs.toolCalls.push({
toolCallId: event.toolCallId,
toolName: event.toolName,
args: pending?.args ?? null,
isError: event.isError,
});
cs.pendingToolCalls.delete(event.toolCallId);
}
client.emit('agent:tool:end', {
conversationId,
toolCallId: event.toolCallId,
@@ -279,6 +420,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
isError: event.isError,
});
break;
}
}
}
}