/**
 * SpeechGateway
 *
 * WebSocket gateway for real-time streaming transcription.
 * Uses a separate `/speech` namespace from the main WebSocket gateway.
 *
 * Protocol:
 * 1. Client connects with auth token in handshake
 * 2. Client emits `start-transcription` with optional { language }
 * 3. Client streams audio via `audio-chunk` events (binary data: Buffer, typed array or ArrayBuffer)
 * 4. Client emits `stop-transcription` to finalize
 * 5. Server responds with `transcription-final` containing the result
 *
 * Any failure is reported to the client as a `transcription-error` event
 * carrying a `{ message }` payload.
 *
 * Session management:
 * - One active transcription session per client, keyed by socket id
 * - Audio chunks accumulated in memory (Buffer array)
 * - On stop: chunks concatenated and sent to SpeechService.transcribe()
 * - Sessions cleaned up on disconnect
 *
 * Size limits:
 * - Total accumulated audio size per session is capped by config
 *   limits.maxUploadSize; exceeding it discards the session.
 *
 * Issue #397
 */
import {
  WebSocketGateway as WSGateway,
  WebSocketServer,
  SubscribeMessage,
  OnGatewayConnection,
  OnGatewayDisconnect,
} from "@nestjs/websockets";
import { Logger, Inject } from "@nestjs/common";
import { Server, Socket } from "socket.io";
import { AuthService } from "../auth/auth.service";
import { PrismaService } from "../prisma/prisma.service";
import { SpeechService } from "./speech.service";
import { speechConfig, type SpeechConfig } from "./speech.config";

// ==========================================
// Types
// ==========================================

interface AuthenticatedSocket extends Socket {
  data: {
    userId?: string;
    workspaceId?: string;
  };
}

interface TranscriptionSession {
  chunks: Buffer[];
  totalSize: number;
  language: string | undefined;
  startedAt: Date;
}

interface StartTranscriptionPayload {
  language?: string;
}

// ==========================================
// Gateway
// ==========================================

@WSGateway({
  namespace: "/speech",
  cors: {
    // Comma-separated allow-list; blank entries (e.g. from a trailing comma) are dropped.
    origin: (process.env.TRUSTED_ORIGINS ?? process.env.WEB_URL ?? "http://localhost:3000")
      .split(",")
      .map((s) => s.trim())
      .filter((s) => s.length > 0),
    credentials: true,
  },
})
export class SpeechGateway implements OnGatewayConnection, OnGatewayDisconnect {
  @WebSocketServer()
  server!: Server;

  private readonly logger = new Logger(SpeechGateway.name);
  /** Active transcription sessions keyed by socket id. */
  private readonly sessions = new Map<string, TranscriptionSession>();
  private readonly CONNECTION_TIMEOUT_MS = 5000;

  constructor(
    private readonly authService: AuthService,
    private readonly prisma: PrismaService,
    private readonly speechService: SpeechService,
    @Inject(speechConfig.KEY)
    private readonly config: SpeechConfig
  ) {}

  // ==========================================
  // Connection lifecycle
  // ==========================================

  /**
   * Authenticate client on connection using the same pattern as the main WebSocket gateway.
   * Extracts token from handshake, verifies session, and checks workspace membership.
   * On success stores `userId`/`workspaceId` on `client.data`; on any failure emits
   * `transcription-error` and disconnects. Clients that have not authenticated within
   * CONNECTION_TIMEOUT_MS are disconnected by a timer.
   */
  async handleConnection(client: Socket): Promise<void> {
    const authenticatedClient = client as AuthenticatedSocket;
    let timedOut = false;

    const timeoutId = setTimeout(() => {
      if (!authenticatedClient.data.userId) {
        timedOut = true;
        this.logger.warn(`Client ${authenticatedClient.id} timed out during authentication`);
        this.rejectConnection(authenticatedClient, "Authentication timed out.");
      }
    }, this.CONNECTION_TIMEOUT_MS);

    try {
      const token = this.extractTokenFromHandshake(authenticatedClient);
      if (!token) {
        this.logger.warn(`Client ${authenticatedClient.id} connected without token`);
        this.rejectConnection(authenticatedClient, "Authentication failed: no token provided.");
        return;
      }

      const sessionData = await this.authService.verifySession(token);
      if (!sessionData) {
        this.logger.warn(`Client ${authenticatedClient.id} has invalid token`);
        this.rejectConnection(
          authenticatedClient,
          "Authentication failed: invalid or expired token."
        );
        return;
      }

      const user = sessionData.user as { id: string };
      const userId = user.id;

      const workspaceMembership = await this.prisma.workspaceMember.findFirst({
        where: { userId },
        select: { workspaceId: true, userId: true, role: true },
      });

      if (!workspaceMembership) {
        this.logger.warn(`User ${userId} has no workspace access`);
        this.rejectConnection(authenticatedClient, "Authentication failed: no workspace access.");
        return;
      }

      // The timeout may have fired (or the client may have left) while we were
      // awaiting; never mark an already-dropped socket as authenticated.
      if (timedOut || authenticatedClient.disconnected) {
        this.logger.debug(
          `Speech client ${authenticatedClient.id} disconnected before authentication completed`
        );
        return;
      }

      authenticatedClient.data.userId = userId;
      authenticatedClient.data.workspaceId = workspaceMembership.workspaceId;

      this.logger.log(
        `Speech client ${authenticatedClient.id} connected (user: ${userId}, workspace: ${workspaceMembership.workspaceId})`
      );
    } catch (error) {
      this.logger.error(
        `Authentication failed for speech client ${authenticatedClient.id}:`,
        error instanceof Error ? error.message : "Unknown error"
      );
      this.rejectConnection(
        authenticatedClient,
        "Authentication failed: an unexpected error occurred."
      );
    } finally {
      // Every exit path (success, rejection, exception) must cancel the auth timer.
      clearTimeout(timeoutId);
    }
  }

  /**
   * Clean up transcription session on client disconnect.
   */
  handleDisconnect(client: Socket): void {
    const sessionId = client.id;

    if (this.sessions.delete(sessionId)) {
      this.logger.log(`Cleaned up transcription session for client ${sessionId}`);
    }

    this.logger.debug(`Speech client ${sessionId} disconnected`);
  }

  // ==========================================
  // Transcription events
  // ==========================================

  /**
   * Start a new transcription session for the client.
   * Replaces any existing session for this client.
   *
   * @param client - The connected socket client
   * @param payload - Optional parameters: { language?: string }. May be omitted entirely;
   *   a missing, empty or non-string language means auto-detection.
   */
  @SubscribeMessage("start-transcription")
  handleStartTranscription(client: Socket, payload?: StartTranscriptionPayload): void {
    const authenticatedClient = client as AuthenticatedSocket;
    if (!this.ensureAuthenticated(authenticatedClient)) {
      return;
    }

    const sessionId = authenticatedClient.id;

    // Clean up any existing session for this client
    if (this.sessions.delete(sessionId)) {
      this.logger.debug(`Replaced existing session for client ${sessionId}`);
    }

    const language = this.normalizeLanguage(payload);

    const session: TranscriptionSession = {
      chunks: [],
      totalSize: 0,
      language,
      startedAt: new Date(),
    };

    this.sessions.set(sessionId, session);

    authenticatedClient.emit("transcription-started", {
      sessionId,
      language,
    });

    this.logger.debug(
      `Transcription session started for client ${sessionId} (language: ${language ?? "auto"})`
    );
  }

  /**
   * Receive an audio chunk and accumulate it in the active session.
   * Enforces maximum buffer size from configuration; on overflow the session is discarded.
   * Non-binary payloads are rejected with `transcription-error`.
   *
   * @param client - The connected socket client
   * @param data - Audio data as Buffer or Uint8Array (any ArrayBufferView or ArrayBuffer is accepted)
   */
  @SubscribeMessage("audio-chunk")
  handleAudioChunk(client: Socket, data: Buffer | Uint8Array): void {
    const authenticatedClient = client as AuthenticatedSocket;
    if (!this.ensureAuthenticated(authenticatedClient)) {
      return;
    }

    const sessionId = authenticatedClient.id;
    const session = this.sessions.get(sessionId);

    if (!session) {
      authenticatedClient.emit("transcription-error", {
        message: "No active transcription session. Send start-transcription first.",
      });
      return;
    }

    // `data` arrives from the wire, so its declared type is not guaranteed at runtime.
    const chunk = this.toAudioBuffer(data);
    if (!chunk) {
      authenticatedClient.emit("transcription-error", {
        message: "Invalid audio chunk: expected binary audio data.",
      });
      return;
    }

    const newTotalSize = session.totalSize + chunk.length;
    if (newTotalSize > this.config.limits.maxUploadSize) {
      authenticatedClient.emit("transcription-error", {
        message: `Audio buffer size (${String(newTotalSize)} bytes) exceeds maximum allowed size (${String(this.config.limits.maxUploadSize)} bytes).`,
      });
      // Clean up the session on overflow
      this.sessions.delete(sessionId);
      return;
    }

    session.chunks.push(chunk);
    session.totalSize = newTotalSize;
  }

  /**
   * Stop the transcription session, concatenate audio chunks, and transcribe.
   * Emits `transcription-final` on success or `transcription-error` on failure.
   *
   * @param client - The connected socket client
   */
  @SubscribeMessage("stop-transcription")
  async handleStopTranscription(client: Socket): Promise<void> {
    const authenticatedClient = client as AuthenticatedSocket;
    if (!this.ensureAuthenticated(authenticatedClient)) {
      return;
    }

    const sessionId = authenticatedClient.id;
    const session = this.sessions.get(sessionId);

    if (!session) {
      authenticatedClient.emit("transcription-error", {
        message: "No active transcription session. Send start-transcription first.",
      });
      return;
    }

    // Always remove session before processing (prevents double-stop and lets a new
    // start-transcription proceed while this transcription is in flight).
    this.sessions.delete(sessionId);

    // Count bytes rather than chunks: zero-length chunks carry no audio.
    if (session.totalSize === 0) {
      authenticatedClient.emit("transcription-error", {
        message: "No audio data received. Send audio-chunk events before stopping.",
      });
      return;
    }

    try {
      const audioBuffer = Buffer.concat(session.chunks, session.totalSize);
      const options: { language?: string } = {};
      if (session.language) {
        options.language = session.language;
      }

      this.logger.debug(
        `Transcribing ${String(audioBuffer.length)} bytes for client ${sessionId} (language: ${session.language ?? "auto"})`
      );

      const result = await this.speechService.transcribe(audioBuffer, options);

      authenticatedClient.emit("transcription-final", {
        text: result.text,
        language: result.language,
        durationSeconds: result.durationSeconds,
        confidence: result.confidence,
        segments: result.segments,
      });

      // Transcript content is user speech; deliberately not written to logs.
      this.logger.debug(`Transcription complete for client ${sessionId}`);
    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      this.logger.error(`Transcription failed for client ${sessionId}: ${message}`);

      authenticatedClient.emit("transcription-error", {
        message: `Transcription failed: ${message}`,
      });
    }
  }

  // ==========================================
  // Private helpers
  // ==========================================

  /**
   * Return true when the client completed handshake authentication; otherwise emit a
   * `transcription-error` to it and return false.
   */
  private ensureAuthenticated(client: AuthenticatedSocket): boolean {
    if (client.data.userId) {
      return true;
    }
    client.emit("transcription-error", {
      message: "Not authenticated. Connect with a valid token.",
    });
    return false;
  }

  /** Emit a `transcription-error` with the given message, then disconnect the client. */
  private rejectConnection(client: Socket, message: string): void {
    client.emit("transcription-error", { message });
    client.disconnect();
  }

  /**
   * Extract the requested language from an untrusted `start-transcription` payload.
   * Returns undefined (auto-detect) for a missing payload or a missing/empty/non-string language.
   */
  private normalizeLanguage(payload: unknown): string | undefined {
    if (typeof payload !== "object" || payload === null) {
      return undefined;
    }
    const { language } = payload as StartTranscriptionPayload;
    return typeof language === "string" && language.length > 0 ? language : undefined;
  }

  /**
   * Convert an untrusted `audio-chunk` payload to a Buffer of raw bytes.
   * Returns undefined when the payload is not binary data.
   */
  private toAudioBuffer(data: unknown): Buffer | undefined {
    if (Buffer.isBuffer(data)) {
      return data;
    }
    if (ArrayBuffer.isView(data)) {
      // Copy the underlying bytes; Buffer.from(typedArray) would instead coerce each
      // element to a single byte for non-Uint8 typed arrays.
      return Buffer.from(new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
    }
    if (data instanceof ArrayBuffer) {
      return Buffer.from(new Uint8Array(data));
    }
    return undefined;
  }

  /**
   * Extract authentication token from Socket.IO handshake.
   * Checks auth.token, query.token, and Authorization header (in that order).
   * The `Bearer` scheme is matched case-insensitively, as HTTP auth schemes are.
   */
  private extractTokenFromHandshake(client: Socket): string | undefined {
    const authToken = client.handshake.auth.token as unknown;
    if (typeof authToken === "string" && authToken.length > 0) {
      return authToken;
    }

    const queryToken = client.handshake.query.token as unknown;
    if (typeof queryToken === "string" && queryToken.length > 0) {
      return queryToken;
    }

    const authHeader = client.handshake.headers.authorization as unknown;
    if (typeof authHeader === "string") {
      const [scheme, token] = authHeader.trim().split(/\s+/);
      if (scheme?.toLowerCase() === "bearer" && token) {
        return token;
      }
    }

    return undefined;
  }
}