Files
stack/apps/gateway/src/chat/chat.gateway.ts
Jarvis 774b76447d
Some checks failed
ci/woodpecker/pr/ci Pipeline failed
ci/woodpecker/push/ci Pipeline failed
fix: rename all packages from @mosaic/* to @mosaicstack/*
- Updated all package.json name fields and dependency references
- Updated all TypeScript/JavaScript imports
- Updated .woodpecker/publish.yml filters and registry paths
- Updated tools/install.sh scope default
- Updated .npmrc registry paths (worktree + host)
- Enhanced update-checker.ts with checkForAllUpdates() multi-package support
- Updated CLI update command to show table of all packages
- Added KNOWN_PACKAGES, formatAllPackagesTable, getInstallAllCommand
- Marked checkForUpdate() with @deprecated JSDoc

Closes #391
2026-04-04 21:43:23 -05:00

696 lines
24 KiB
TypeScript

import { Inject, Logger } from '@nestjs/common';
import {
WebSocketGateway,
WebSocketServer,
SubscribeMessage,
OnGatewayConnection,
OnGatewayDisconnect,
type OnGatewayInit,
ConnectedSocket,
MessageBody,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
import type { Auth } from '@mosaicstack/auth';
import type { Brain } from '@mosaicstack/brain';
import type {
SetThinkingPayload,
SlashCommandPayload,
SystemReloadPayload,
RoutingDecisionInfo,
AbortPayload,
} from '@mosaicstack/types';
import { AgentService, type ConversationHistoryMessage } from '../agent/agent.service.js';
import { AUTH } from '../auth/auth.tokens.js';
import { BRAIN } from '../brain/brain.tokens.js';
import { CommandRegistryService } from '../commands/command-registry.service.js';
import { CommandExecutorService } from '../commands/command-executor.service.js';
import { RoutingEngineService } from '../agent/routing/routing-engine.service.js';
import { v4 as uuid } from 'uuid';
import { ChatSocketMessageDto } from './chat.dto.js';
import { validateSocketSession } from './chat.gateway-auth.js';
/**
 * State the gateway keeps for each connected socket while a turn is streaming,
 * so the assistant reply can be persisted once the turn ends.
 */
interface ClientSession {
  /** Conversation this socket is currently subscribed to. */
  conversationId: string;
  /** Unsubscribes the agent-event listener registered for this socket. */
  cleanup: () => void;
  /** Assistant text accumulated so far in the current turn. */
  assistantText: string;
  /** Finished tool calls observed during the current turn. */
  toolCalls: { toolCallId: string; toolName: string; args: unknown; isError: boolean }[];
  /** Tool calls that have started but not ended yet, keyed by tool call id. */
  pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
  /** Most recent routing decision recorded for this session (M4-008). */
  lastRoutingDecision?: RoutingDecisionInfo;
}
/**
 * Model overrides chosen through the /model command (M4-007), one per conversation.
 * Key: conversationId. Value: name of the model to use for that conversation.
 */
const modelOverrides: Map<string, string> = new Map();
/**
 * Socket.IO gateway serving the `/chat` namespace.
 *
 * Authenticates sockets on connect, relays chat messages to agent sessions,
 * streams agent events back to the originating socket, and persists user and
 * assistant messages through the injected Brain when a user id is available.
 */
@WebSocketGateway({
cors: {
// Allowed origin; falls back to http://localhost:3000 when GATEWAY_CORS_ORIGIN is unset.
origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
},
namespace: '/chat',
})
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
/** Socket.IO handle injected by Nest; used for broadcasts. */
@WebSocketServer()
server!: Server;
private readonly logger = new Logger(ChatGateway.name);
/** Per-socket streaming state, keyed by socket id (`client.id`). */
private readonly clientSessions = new Map<string, ClientSession>();
/**
 * All collaborators are injected with explicit tokens.
 *
 * @param agentService Creates, looks up and prompts agent sessions; tracks channels and metrics.
 * @param auth Passed to `validateSocketSession` when a socket connects.
 * @param brain Conversation and message persistence.
 * @param commandRegistry Supplies the command manifest sent to each new client.
 * @param commandExecutor Executes slash commands received over the socket.
 * @param routingEngine Picks provider/model when the client does not specify one.
 */
constructor(
@Inject(AgentService) private readonly agentService: AgentService,
@Inject(AUTH) private readonly auth: Auth,
@Inject(BRAIN) private readonly brain: Brain,
@Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
@Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
@Inject(RoutingEngineService) private readonly routingEngine: RoutingEngineService,
) {}
/** Gateway init hook: logs that the chat gateway is ready. */
afterInit(): void {
  const message = 'Chat WebSocket gateway initialized';
  this.logger.log(message);
}
/**
 * Authenticates a newly connected socket and sends it the command manifest.
 *
 * The socket is disconnected when no valid session is found OR when session
 * validation itself throws, so an auth failure never leaves an unauthenticated
 * socket connected. On success the user and session are stored on `client.data`.
 *
 * @param client The socket that just connected.
 */
async handleConnection(client: Socket): Promise<void> {
  let session: Awaited<ReturnType<typeof validateSocketSession>>;
  try {
    session = await validateSocketSession(client.handshake.headers, this.auth);
  } catch (err) {
    // Fail closed: a validator error must not leave the socket connected.
    this.logger.error(
      `Session validation failed for WebSocket client: ${client.id}`,
      err instanceof Error ? err.stack : String(err),
    );
    client.disconnect();
    return;
  }

  if (!session) {
    this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
    client.disconnect();
    return;
  }

  client.data.user = session.user;
  client.data.session = session.session;
  this.logger.log(`Client connected: ${client.id}`);

  // Send the command manifest to the newly connected client
  client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
}
/**
 * Tears down per-socket state when a client disconnects: detaches the agent-event
 * listener, removes the socket's channel from its conversation, and forgets the
 * client session. Does nothing beyond logging when the socket had no session.
 *
 * @param client The socket that disconnected.
 */
handleDisconnect(client: Socket): void {
  const socketId = client.id;
  this.logger.log(`Client disconnected: ${socketId}`);

  const tracked = this.clientSessions.get(socketId);
  if (!tracked) {
    return;
  }

  tracked.cleanup();
  this.agentService.removeChannel(tracked.conversationId, `websocket:${socketId}`);
  this.clientSessions.delete(socketId);
}
/**
 * Handles a chat message from a client.
 *
 * Steps, in order: resolve (or generate) the conversation id; make sure an agent
 * session exists, creating one with prior history and a /model override or a
 * routing-engine decision when needed; ensure and bind the conversation record;
 * persist the user message; (re)subscribe this socket to agent events; emit
 * `session:info` and `message:ack`; finally dispatch the prompt to the agent.
 *
 * DB work (history, conversation record, message persistence) only happens when
 * `client.data.user` carries an id. Failures while creating the session or
 * prompting the agent are logged and reported to the client via an `error` event.
 *
 * @param client Socket that sent the message.
 * @param data Message payload: `content` plus optional `conversationId`, `provider`,
 *   `modelId` and `agentId`.
 */
@SubscribeMessage('message')
async handleMessage(
  @ConnectedSocket() client: Socket,
  @MessageBody() data: ChatSocketMessageDto,
): Promise<void> {
  const conversationId = data.conversationId ?? uuid();
  const userId = (client.data.user as { id: string } | undefined)?.id;
  this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);

  // Ensure agent session exists for this conversation
  let sessionRoutingDecision: RoutingDecisionInfo | undefined;
  try {
    let agentSession = this.agentService.getSession(conversationId);
    if (!agentSession) {
      // When resuming an existing conversation, load prior messages to inject as context (M1-004)
      const conversationHistory = await this.loadConversationHistory(conversationId, userId);

      // M5-004: Check if there's an existing sessionId bound to this conversation
      let existingSessionId: string | undefined;
      if (userId) {
        existingSessionId = await this.getConversationSessionId(conversationId, userId);
        if (existingSessionId) {
          this.logger.log(
            `Resuming existing sessionId=${existingSessionId} for conversation=${conversationId}`,
          );
        }
      }

      // Determine provider/model via routing engine or per-session /model override (M4-012 / M4-007)
      let resolvedProvider = data.provider;
      let resolvedModelId = data.modelId;
      const modelOverride = modelOverrides.get(conversationId);
      if (modelOverride) {
        // /model override bypasses routing engine (M4-007)
        resolvedModelId = modelOverride;
        this.logger.log(
          `Using /model override "${modelOverride}" for conversation=${conversationId}`,
        );
      } else if (!resolvedProvider && !resolvedModelId) {
        // No explicit provider/model from client — use routing engine (M4-012)
        try {
          const routingDecision = await this.routingEngine.resolve(data.content, userId);
          resolvedProvider = routingDecision.provider;
          resolvedModelId = routingDecision.model;
          sessionRoutingDecision = {
            model: routingDecision.model,
            provider: routingDecision.provider,
            ruleName: routingDecision.ruleName,
            reason: routingDecision.reason,
          };
          this.logger.log(
            `Routing decision for conversation=${conversationId}: ${routingDecision.provider}/${routingDecision.model} (rule="${routingDecision.ruleName}")`,
          );
        } catch (routingErr) {
          // Routing is best-effort: fall through with undefined provider/model.
          this.logger.warn(
            `Routing engine failed for conversation=${conversationId}, using defaults`,
            routingErr instanceof Error ? routingErr.message : String(routingErr),
          );
        }
      }

      // M5-004: Use existingSessionId as sessionId when available (session reuse)
      // NOTE(review): the session is created under `sessionIdToCreate`, but every later
      // lookup in this handler uses `conversationId`. The two only differ when a stored
      // sessionId is not the conversation id — confirm AgentService handles that case.
      const sessionIdToCreate = existingSessionId ?? conversationId;
      agentSession = await this.agentService.createSession(sessionIdToCreate, {
        provider: resolvedProvider,
        modelId: resolvedModelId,
        agentConfigId: data.agentId,
        userId,
        conversationHistory: conversationHistory.length > 0 ? conversationHistory : undefined,
      });

      if (conversationHistory.length > 0) {
        this.logger.log(
          `Loaded ${conversationHistory.length} prior messages for conversation=${conversationId}`,
        );
      }
    }
  } catch (err) {
    this.logger.error(
      `Session creation failed for client=${client.id}, conversation=${conversationId}`,
      err instanceof Error ? err.stack : String(err),
    );
    client.emit('error', {
      conversationId,
      error: 'Failed to start agent session. Please try again.',
    });
    return;
  }

  // Ensure conversation record exists in the DB before persisting messages
  // M5-004: Also bind the sessionId to the conversation record
  if (userId) {
    await this.ensureConversation(conversationId, userId);
    await this.bindSessionToConversation(conversationId, userId, conversationId);
  }

  // M5-007: Count the user message
  this.agentService.recordMessage(conversationId);

  // Persist the user message (best-effort: a failure is logged, the turn continues)
  if (userId) {
    try {
      await this.brain.conversations.addMessage(
        {
          conversationId,
          role: 'user',
          content: data.content,
          metadata: {
            timestamp: new Date().toISOString(),
          },
        },
        userId,
      );
    } catch (err) {
      this.logger.error(
        `Failed to persist user message for conversation=${conversationId}`,
        err instanceof Error ? err.stack : String(err),
      );
    }
  }

  // Always clean up the previous listener to prevent a leak. When this socket is
  // moving to a different conversation, also drop its channel from the old one so
  // the old conversation does not keep counting a socket that has left it.
  const channelId = `websocket:${client.id}`;
  const previous = this.clientSessions.get(client.id);
  if (previous) {
    previous.cleanup();
    if (previous.conversationId !== conversationId) {
      this.agentService.removeChannel(previous.conversationId, channelId);
    }
  }

  // Subscribe to agent events and relay to client
  const cleanup = this.agentService.onEvent(conversationId, (event: AgentSessionEvent) => {
    this.relayEvent(client, conversationId, event);
  });

  // Keep the earlier routing decision only when it belongs to this same conversation;
  // a decision made for another conversation must not be reported for this one.
  const carriedRoutingDecision =
    previous && previous.conversationId === conversationId
      ? previous.lastRoutingDecision
      : undefined;
  const routingDecisionToStore = sessionRoutingDecision ?? carriedRoutingDecision;

  this.clientSessions.set(client.id, {
    conversationId,
    cleanup,
    assistantText: '',
    toolCalls: [],
    pendingToolCalls: new Map(),
    lastRoutingDecision: routingDecisionToStore,
  });

  // Track channel connection
  this.agentService.addChannel(conversationId, channelId);

  // Send session info so the client knows the model/provider (M4-008: include routing decision)
  // Include agentName when a named agent config is active (M5-001)
  {
    const agentSession = this.agentService.getSession(conversationId);
    if (agentSession) {
      const piSession = agentSession.piSession;
      client.emit('session:info', {
        conversationId,
        provider: agentSession.provider,
        modelId: agentSession.modelId,
        thinkingLevel: piSession.thinkingLevel,
        availableThinkingLevels: piSession.getAvailableThinkingLevels(),
        ...(agentSession.agentName ? { agentName: agentSession.agentName } : {}),
        ...(routingDecisionToStore ? { routingDecision: routingDecisionToStore } : {}),
      });
    }
  }

  // Send acknowledgment
  client.emit('message:ack', { conversationId, messageId: uuid() });

  // Dispatch to agent
  try {
    await this.agentService.prompt(conversationId, data.content);
  } catch (err) {
    this.logger.error(
      `Agent prompt failed for client=${client.id}, conversation=${conversationId}`,
      err instanceof Error ? err.stack : String(err),
    );
    client.emit('error', {
      conversationId,
      error: 'The agent failed to process your message. Please try again.',
    });
  }
}
/**
 * Changes the thinking level of a conversation's agent session.
 *
 * Emits `error` to the caller when the conversation has no active session or the
 * requested level is not among the session's available levels; otherwise applies
 * the level and replies with a fresh `session:info`.
 *
 * @param client Socket that sent the request.
 * @param data Target conversation id and requested thinking level.
 */
@SubscribeMessage('set:thinking')
handleSetThinking(
  @ConnectedSocket() client: Socket,
  @MessageBody() data: SetThinkingPayload,
): void {
  const { conversationId, level } = data;

  const agentSession = this.agentService.getSession(conversationId);
  if (!agentSession) {
    client.emit('error', {
      conversationId,
      error: 'No active session for this conversation.',
    });
    return;
  }

  const allowedLevels = agentSession.piSession.getAvailableThinkingLevels();
  const isAllowed = allowedLevels.includes(level as never);
  if (!isAllowed) {
    client.emit('error', {
      conversationId,
      error: `Invalid thinking level "${level}". Available: ${allowedLevels.join(', ')}`,
    });
    return;
  }

  agentSession.piSession.setThinkingLevel(level as never);
  this.logger.log(`Thinking level set to "${level}" for conversation ${conversationId}`);

  // Reply with the session's state after the change.
  const agentNamePart = agentSession.agentName ? { agentName: agentSession.agentName } : {};
  client.emit('session:info', {
    conversationId,
    provider: agentSession.provider,
    modelId: agentSession.modelId,
    thinkingLevel: agentSession.piSession.thinkingLevel,
    availableThinkingLevels: agentSession.piSession.getAvailableThinkingLevels(),
    ...agentNamePart,
  });
}
/**
 * Aborts the agent operation running in a conversation.
 *
 * Emits `error` to the caller when the conversation has no active session or the
 * abort call fails; failures are logged with their stack.
 *
 * NOTE(review): the target conversation comes straight from the payload and is not
 * checked against the caller's own conversation here — confirm that is intended.
 *
 * @param client Socket that requested the abort.
 * @param data Payload carrying the conversation id to abort.
 */
@SubscribeMessage('abort')
async handleAbort(
  @ConnectedSocket() client: Socket,
  @MessageBody() data: AbortPayload,
): Promise<void> {
  const { conversationId } = data;
  this.logger.log(`Abort requested by ${client.id} for conversation ${conversationId}`);

  const agentSession = this.agentService.getSession(conversationId);
  if (!agentSession) {
    client.emit('error', { conversationId, error: 'No active session to abort.' });
    return;
  }

  try {
    await agentSession.piSession.abort();
    this.logger.log(`Agent session ${conversationId} aborted successfully`);
  } catch (err) {
    const detail = err instanceof Error ? err.stack : String(err);
    this.logger.error(`Failed to abort session ${conversationId}`, detail);
    client.emit('error', { conversationId, error: 'Failed to abort the agent operation.' });
  }
}
/**
 * Runs a slash command on behalf of the calling socket and returns the outcome
 * to that socket as `command:result`. The user id falls back to `'unknown'` when
 * the socket carries no user.
 *
 * @param client Socket that issued the command.
 * @param payload Slash command to execute.
 */
@SubscribeMessage('command:execute')
async handleCommandExecute(
  @ConnectedSocket() client: Socket,
  @MessageBody() payload: SlashCommandPayload,
): Promise<void> {
  const user = client.data.user as { id: string } | undefined;
  const userId = user?.id ?? 'unknown';

  const outcome = await this.commandExecutor.execute(payload, userId);
  client.emit('command:result', outcome);
}
/**
 * Emits `system:reload` with the given payload through the gateway's server handle
 * and logs that the broadcast happened.
 *
 * @param payload Reload details forwarded unchanged to clients.
 */
broadcastReload(payload: SystemReloadPayload): void {
  const { server, logger } = this;
  server.emit('system:reload', payload);
  logger.log('Broadcasted system:reload to all connected clients');
}
/**
 * Set or clear the per-conversation model override (M4-007 / M5-002).
 *
 * With a model name: stores the override (which makes new sessions for the
 * conversation skip the routing engine), updates the live session's model, and
 * pushes `session:info` to clients subscribed to the conversation (M5-005).
 * With `null` (or an empty string): removes the override; nothing is broadcast.
 *
 * NOTE(review): earlier documentation mentioned recording a model switch in session
 * metrics (M5-007); no such call is made here — presumably `updateSessionModel`
 * covers it; confirm.
 *
 * @param conversationId Conversation whose override is changed.
 * @param modelName Model to pin, or null to resume automatic routing.
 */
setModelOverride(conversationId: string, modelName: string | null): void {
  if (!modelName) {
    modelOverrides.delete(conversationId);
    this.logger.log(`Model override cleared: conversation=${conversationId}`);
    return;
  }

  modelOverrides.set(conversationId, modelName);
  this.logger.log(`Model override set: conversation=${conversationId} model="${modelName}"`);

  // M5-002: keep the live session's modelId in step so session:info shows the new model
  this.agentService.updateSessionModel(conversationId, modelName);

  // M5-005: tell every client subscribed to this conversation
  this.broadcastSessionInfo(conversationId);
}
/**
 * Look up the model override currently stored for a conversation.
 *
 * @param conversationId Conversation to look up.
 * @returns The overriding model name, or undefined when none is set.
 */
getModelOverride(conversationId: string): string | undefined {
  const activeOverride = modelOverrides.get(conversationId);
  return activeOverride;
}
/**
 * M5-005: Send `session:info` to every client currently subscribed to a conversation.
 * Called on model or agent switch so the TUI TopBar updates immediately.
 * Does nothing when the conversation has no agent session.
 *
 * @param conversationId Conversation whose subscribers are notified.
 * @param extra Optional values to include: `agentName` takes precedence over the
 *   session's own agent name; `routingDecision` is included only when provided.
 */
broadcastSessionInfo(
  conversationId: string,
  extra?: { agentName?: string; routingDecision?: RoutingDecisionInfo },
): void {
  const agentSession = this.agentService.getSession(conversationId);
  if (!agentSession) return;

  const piSession = agentSession.piSession;
  const resolvedAgentName = extra?.agentName ?? agentSession.agentName;
  const payload = {
    conversationId,
    provider: agentSession.provider,
    modelId: agentSession.modelId,
    thinkingLevel: piSession.thinkingLevel,
    availableThinkingLevels: piSession.getAvailableThinkingLevels(),
    ...(resolvedAgentName ? { agentName: resolvedAgentName } : {}),
    ...(extra?.routingDecision ? { routingDecision: extra.routingDecision } : {}),
  };

  // Address each subscriber through the room named after its socket id (every
  // Socket.IO socket joins that room automatically). This gateway declares a
  // namespace, so the injected handle is the namespace object, on which the
  // previous `server.sockets.sockets.get(id)` lookup is not valid; `to(id)` is
  // available on both a server and a namespace. Emitting to a room whose socket
  // has gone away is a no-op, so no explicit connected check is needed.
  for (const [clientId, session] of this.clientSessions) {
    if (session.conversationId !== conversationId) continue;
    this.server.to(clientId).emit('session:info', payload);
  }
}
/**
 * Make sure the conversation has a DB record, creating one when the lookup finds none.
 * Any error — including a failed insert when a concurrent call created the
 * record first — is caught and logged here rather than propagated.
 *
 * @param conversationId Id the record must exist under.
 * @param userId Owner used for both the lookup and the insert.
 */
private async ensureConversation(conversationId: string, userId: string): Promise<void> {
  const { conversations } = this.brain;
  try {
    const found = await conversations.findById(conversationId, userId);
    if (found) {
      return;
    }
    await conversations.create({ id: conversationId, userId });
  } catch (err) {
    const detail = err instanceof Error ? err.stack : String(err);
    this.logger.error(
      `Failed to ensure conversation record for conversation=${conversationId}`,
      detail,
    );
  }
}
/**
 * M5-004: Store the agent sessionId on the conversation record so a later resume
 * can reuse the session. Failures are logged and not propagated.
 *
 * @param conversationId Conversation record to update.
 * @param userId Owner passed to the update call.
 * @param sessionId Session id to write to the record.
 */
private async bindSessionToConversation(
  conversationId: string,
  userId: string,
  sessionId: string,
): Promise<void> {
  const patch = { sessionId };
  try {
    await this.brain.conversations.update(conversationId, userId, patch);
  } catch (err) {
    const detail = err instanceof Error ? err.stack : String(err);
    this.logger.error(
      `Failed to bind sessionId=${sessionId} to conversation=${conversationId}`,
      detail,
    );
  }
}
/**
 * M5-004: Read the sessionId stored on a conversation record.
 *
 * @param conversationId Conversation to look up.
 * @param userId Owner passed to the lookup.
 * @returns The stored sessionId, or undefined when the conversation is not found,
 *   has no sessionId, or the lookup fails (the failure is logged).
 */
private async getConversationSessionId(
  conversationId: string,
  userId: string,
): Promise<string | undefined> {
  try {
    const conversation = await this.brain.conversations.findById(conversationId, userId);
    if (!conversation) {
      return undefined;
    }
    return conversation.sessionId ?? undefined;
  } catch (err) {
    const detail = err instanceof Error ? err.stack : String(err);
    this.logger.error(`Failed to get sessionId for conversation=${conversationId}`, detail);
    return undefined;
  }
}
/**
 * Fetch a conversation's stored messages so they can be injected as context when
 * a session is resumed (M1-004).
 *
 * @param conversationId Conversation whose messages are loaded.
 * @param userId Owner passed to the lookup; when missing no lookup is made.
 * @returns Messages reduced to role, content and createdAt. Empty when userId is
 *   missing, nothing is returned by the lookup, or the lookup fails (logged).
 */
private async loadConversationHistory(
  conversationId: string,
  userId: string | undefined,
): Promise<ConversationHistoryMessage[]> {
  if (!userId) {
    return [];
  }

  try {
    const stored = await this.brain.conversations.findMessages(conversationId, userId);
    // Mapping an empty list already yields an empty array, so no separate empty check.
    return stored.map(({ role, content, createdAt }) => ({
      role: role as 'user' | 'assistant' | 'system',
      content,
      createdAt,
    }));
  } catch (err) {
    const detail = err instanceof Error ? err.stack : String(err);
    this.logger.error(
      `Failed to load conversation history for conversation=${conversationId}`,
      detail,
    );
    return [];
  }
}
/**
 * Translate one agent session event into the socket event(s) sent to a client,
 * while accumulating the turn's assistant text and tool calls in the client's
 * session so the assistant message can be persisted when the turn ends.
 *
 * Events for a socket that is no longer connected are dropped with a warning.
 * Event types not listed in the switch are ignored.
 *
 * @param client Socket the event is relayed to.
 * @param conversationId Conversation the event belongs to; included in every emit.
 * @param event Agent session event to relay.
 */
private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
if (!client.connected) {
this.logger.warn(
`Dropping event ${event.type} for disconnected client=${client.id}, conversation=${conversationId}`,
);
return;
}
switch (event.type) {
case 'agent_start': {
// Reset accumulation buffers for the new turn
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.assistantText = '';
cs.toolCalls = [];
cs.pendingToolCalls.clear();
}
client.emit('agent:start', { conversationId });
break;
}
case 'agent_end': {
// Gather usage stats from the Pi session
const agentSession = this.agentService.getSession(conversationId);
const piSession = agentSession?.piSession;
const stats = piSession?.getSessionStats();
const contextUsage = piSession?.getContextUsage();
// Usage is sent only when stats are available; otherwise `usage` is undefined.
const usagePayload = stats
? {
provider: agentSession?.provider ?? 'unknown',
modelId: agentSession?.modelId ?? 'unknown',
thinkingLevel: piSession?.thinkingLevel ?? 'off',
tokens: stats.tokens,
cost: stats.cost,
context: {
percent: contextUsage?.percent ?? null,
window: contextUsage?.contextWindow ?? 0,
},
}
: undefined;
client.emit('agent:end', {
conversationId,
usage: usagePayload,
});
// M5-007: Accumulate token usage in session metrics
if (stats?.tokens) {
this.agentService.recordTokenUsage(conversationId, {
input: stats.tokens.input ?? 0,
output: stats.tokens.output ?? 0,
cacheRead: stats.tokens.cacheRead ?? 0,
cacheWrite: stats.tokens.cacheWrite ?? 0,
total: stats.tokens.total ?? 0,
});
}
// Persist the assistant message with metadata
// Skipped when there is no client session, no user id, or only whitespace text.
const cs = this.clientSessions.get(client.id);
const userId = (client.data.user as { id: string } | undefined)?.id;
if (cs && userId && cs.assistantText.trim().length > 0) {
const metadata: Record<string, unknown> = {
timestamp: new Date().toISOString(),
model: agentSession?.modelId ?? 'unknown',
provider: agentSession?.provider ?? 'unknown',
toolCalls: cs.toolCalls,
};
if (stats?.tokens) {
metadata['tokenUsage'] = {
input: stats.tokens.input,
output: stats.tokens.output,
cacheRead: stats.tokens.cacheRead,
cacheWrite: stats.tokens.cacheWrite,
total: stats.tokens.total,
};
}
// Not awaited: persistence runs in the background and a failure is only logged.
this.brain.conversations
.addMessage(
{
conversationId,
role: 'assistant',
content: cs.assistantText,
metadata,
},
userId,
)
.catch((err: unknown) => {
this.logger.error(
`Failed to persist assistant message for conversation=${conversationId}`,
err instanceof Error ? err.stack : String(err),
);
});
// Reset accumulation
// The addMessage arguments above were built before this reset; `toolCalls` is
// reassigned (not cleared in place), so `metadata.toolCalls` keeps the old array.
cs.assistantText = '';
cs.toolCalls = [];
cs.pendingToolCalls.clear();
}
break;
}
case 'message_update': {
// Only text and thinking deltas are relayed; other assistant event types are ignored.
const assistantEvent = event.assistantMessageEvent;
if (assistantEvent.type === 'text_delta') {
// Accumulate assistant text for persistence
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.assistantText += assistantEvent.delta;
}
client.emit('agent:text', {
conversationId,
text: assistantEvent.delta,
});
} else if (assistantEvent.type === 'thinking_delta') {
// Thinking text is streamed to the client but not accumulated for persistence.
client.emit('agent:thinking', {
conversationId,
text: assistantEvent.delta,
});
}
break;
}
case 'tool_execution_start': {
// Track pending tool call for later recording
const cs = this.clientSessions.get(client.id);
if (cs) {
cs.pendingToolCalls.set(event.toolCallId, {
toolName: event.toolName,
args: event.args,
});
}
// The tool's args are kept server-side only; the client gets id and name.
client.emit('agent:tool:start', {
conversationId,
toolCallId: event.toolCallId,
toolName: event.toolName,
});
break;
}
case 'tool_execution_end': {
// Finalise tool call record
const cs = this.clientSessions.get(client.id);
if (cs) {
// Args come from the matching start event; null when no start was recorded.
const pending = cs.pendingToolCalls.get(event.toolCallId);
cs.toolCalls.push({
toolCallId: event.toolCallId,
toolName: event.toolName,
args: pending?.args ?? null,
isError: event.isError,
});
cs.pendingToolCalls.delete(event.toolCallId);
}
client.emit('agent:tool:end', {
conversationId,
toolCallId: event.toolCallId,
toolName: event.toolName,
isError: event.isError,
});
break;
}
}
}
}