feat(chat): persist messages to DB via ConversationsRepo (M1-001/002/003)
Wire ChatGateway to ConversationsRepo so every user message is saved on receipt (M1-001) and every assistant response is saved on agent:end with accumulated streaming text (M1-002). Metadata includes model, provider, tokenUsage (input/output/cache counts), toolCalls array, and timestamp on each message record (M1-003). Auto-creates the conversation row when the agent session is new. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,15 +12,29 @@ import {
|
||||
import { Server, Socket } from 'socket.io';
|
||||
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
|
||||
import type { Auth } from '@mosaic/auth';
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
|
||||
import { AgentService } from '../agent/agent.service.js';
|
||||
import { AUTH } from '../auth/auth.tokens.js';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { CommandRegistryService } from '../commands/command-registry.service.js';
|
||||
import { CommandExecutorService } from '../commands/command-executor.service.js';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import { ChatSocketMessageDto } from './chat.dto.js';
|
||||
import { validateSocketSession } from './chat.gateway-auth.js';
|
||||
|
||||
/** Per-client state tracking streaming accumulation for persistence. */
|
||||
interface ClientSession {
|
||||
conversationId: string;
|
||||
cleanup: () => void;
|
||||
/** Accumulated assistant response text for the current turn. */
|
||||
assistantText: string;
|
||||
/** Tool calls observed during the current turn. */
|
||||
toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
|
||||
/** Tool calls in-flight (started but not ended yet). */
|
||||
pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
|
||||
}
|
||||
|
||||
@WebSocketGateway({
|
||||
cors: {
|
||||
origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
|
||||
@@ -32,14 +46,12 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
server!: Server;
|
||||
|
||||
private readonly logger = new Logger(ChatGateway.name);
|
||||
private readonly clientSessions = new Map<
|
||||
string,
|
||||
{ conversationId: string; cleanup: () => void }
|
||||
>();
|
||||
private readonly clientSessions = new Map<string, ClientSession>();
|
||||
|
||||
constructor(
|
||||
@Inject(AgentService) private readonly agentService: AgentService,
|
||||
@Inject(AUTH) private readonly auth: Auth,
|
||||
@Inject(BRAIN) private readonly brain: Brain,
|
||||
@Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
|
||||
@Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
|
||||
) {}
|
||||
@@ -80,6 +92,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
@MessageBody() data: ChatSocketMessageDto,
|
||||
): Promise<void> {
|
||||
const conversationId = data.conversationId ?? uuid();
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id;
|
||||
|
||||
this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);
|
||||
|
||||
@@ -87,7 +100,6 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
try {
|
||||
let agentSession = this.agentService.getSession(conversationId);
|
||||
if (!agentSession) {
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id;
|
||||
agentSession = await this.agentService.createSession(conversationId, {
|
||||
provider: data.provider,
|
||||
modelId: data.modelId,
|
||||
@@ -107,6 +119,30 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure conversation record exists in the DB before persisting messages
|
||||
if (userId) {
|
||||
await this.ensureConversation(conversationId, userId);
|
||||
}
|
||||
|
||||
// Persist the user message
|
||||
if (userId) {
|
||||
try {
|
||||
await this.brain.conversations.addMessage({
|
||||
conversationId,
|
||||
role: 'user',
|
||||
content: data.content,
|
||||
metadata: {
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
});
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to persist user message for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Always clean up previous listener to prevent leak
|
||||
const existing = this.clientSessions.get(client.id);
|
||||
if (existing) {
|
||||
@@ -118,7 +154,13 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
this.relayEvent(client, conversationId, event);
|
||||
});
|
||||
|
||||
this.clientSessions.set(client.id, { conversationId, cleanup });
|
||||
this.clientSessions.set(client.id, {
|
||||
conversationId,
|
||||
cleanup,
|
||||
assistantText: '',
|
||||
toolCalls: [],
|
||||
pendingToolCalls: new Map(),
|
||||
});
|
||||
|
||||
// Track channel connection
|
||||
this.agentService.addChannel(conversationId, `websocket:${client.id}`);
|
||||
@@ -208,6 +250,28 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
this.logger.log('Broadcasted system:reload to all connected clients');
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure a conversation record exists in the DB.
|
||||
* Creates it if absent — safe to call concurrently since a duplicate insert
|
||||
* would fail on the PK constraint and be caught here.
|
||||
*/
|
||||
private async ensureConversation(conversationId: string, userId: string): Promise<void> {
|
||||
try {
|
||||
const existing = await this.brain.conversations.findById(conversationId);
|
||||
if (!existing) {
|
||||
await this.brain.conversations.create({
|
||||
id: conversationId,
|
||||
userId,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to ensure conversation record for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
|
||||
if (!client.connected) {
|
||||
this.logger.warn(
|
||||
@@ -217,9 +281,17 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
}
|
||||
|
||||
switch (event.type) {
|
||||
case 'agent_start':
|
||||
case 'agent_start': {
|
||||
// Reset accumulation buffers for the new turn
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.assistantText = '';
|
||||
cs.toolCalls = [];
|
||||
cs.pendingToolCalls.clear();
|
||||
}
|
||||
client.emit('agent:start', { conversationId });
|
||||
break;
|
||||
}
|
||||
|
||||
case 'agent_end': {
|
||||
// Gather usage stats from the Pi session
|
||||
@@ -228,28 +300,76 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
const stats = piSession?.getSessionStats();
|
||||
const contextUsage = piSession?.getContextUsage();
|
||||
|
||||
const usagePayload = stats
|
||||
? {
|
||||
provider: agentSession?.provider ?? 'unknown',
|
||||
modelId: agentSession?.modelId ?? 'unknown',
|
||||
thinkingLevel: piSession?.thinkingLevel ?? 'off',
|
||||
tokens: stats.tokens,
|
||||
cost: stats.cost,
|
||||
context: {
|
||||
percent: contextUsage?.percent ?? null,
|
||||
window: contextUsage?.contextWindow ?? 0,
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
client.emit('agent:end', {
|
||||
conversationId,
|
||||
usage: stats
|
||||
? {
|
||||
provider: agentSession?.provider ?? 'unknown',
|
||||
modelId: agentSession?.modelId ?? 'unknown',
|
||||
thinkingLevel: piSession?.thinkingLevel ?? 'off',
|
||||
tokens: stats.tokens,
|
||||
cost: stats.cost,
|
||||
context: {
|
||||
percent: contextUsage?.percent ?? null,
|
||||
window: contextUsage?.contextWindow ?? 0,
|
||||
},
|
||||
}
|
||||
: undefined,
|
||||
usage: usagePayload,
|
||||
});
|
||||
|
||||
// Persist the assistant message with metadata
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id;
|
||||
if (cs && userId && cs.assistantText.trim().length > 0) {
|
||||
const metadata: Record<string, unknown> = {
|
||||
timestamp: new Date().toISOString(),
|
||||
model: agentSession?.modelId ?? 'unknown',
|
||||
provider: agentSession?.provider ?? 'unknown',
|
||||
toolCalls: cs.toolCalls,
|
||||
};
|
||||
|
||||
if (stats?.tokens) {
|
||||
metadata['tokenUsage'] = {
|
||||
input: stats.tokens.input,
|
||||
output: stats.tokens.output,
|
||||
cacheRead: stats.tokens.cacheRead,
|
||||
cacheWrite: stats.tokens.cacheWrite,
|
||||
total: stats.tokens.total,
|
||||
};
|
||||
}
|
||||
|
||||
this.brain.conversations
|
||||
.addMessage({
|
||||
conversationId,
|
||||
role: 'assistant',
|
||||
content: cs.assistantText,
|
||||
metadata,
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
this.logger.error(
|
||||
`Failed to persist assistant message for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
});
|
||||
|
||||
// Reset accumulation
|
||||
cs.assistantText = '';
|
||||
cs.toolCalls = [];
|
||||
cs.pendingToolCalls.clear();
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'message_update': {
|
||||
const assistantEvent = event.assistantMessageEvent;
|
||||
if (assistantEvent.type === 'text_delta') {
|
||||
// Accumulate assistant text for persistence
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.assistantText += assistantEvent.delta;
|
||||
}
|
||||
client.emit('agent:text', {
|
||||
conversationId,
|
||||
text: assistantEvent.delta,
|
||||
@@ -263,15 +383,36 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
break;
|
||||
}
|
||||
|
||||
case 'tool_execution_start':
|
||||
case 'tool_execution_start': {
|
||||
// Track pending tool call for later recording
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.pendingToolCalls.set(event.toolCallId, {
|
||||
toolName: event.toolName,
|
||||
args: event.args,
|
||||
});
|
||||
}
|
||||
client.emit('agent:tool:start', {
|
||||
conversationId,
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case 'tool_execution_end':
|
||||
case 'tool_execution_end': {
|
||||
// Finalise tool call record
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
const pending = cs.pendingToolCalls.get(event.toolCallId);
|
||||
cs.toolCalls.push({
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
args: pending?.args ?? null,
|
||||
isError: event.isError,
|
||||
});
|
||||
cs.pendingToolCalls.delete(event.toolCallId);
|
||||
}
|
||||
client.emit('agent:tool:end', {
|
||||
conversationId,
|
||||
toolCallId: event.toolCallId,
|
||||
@@ -279,6 +420,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
isError: event.isError,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user