diff --git a/apps/api/src/speech/speech.gateway.spec.ts b/apps/api/src/speech/speech.gateway.spec.ts new file mode 100644 index 0000000..dac50e1 --- /dev/null +++ b/apps/api/src/speech/speech.gateway.spec.ts @@ -0,0 +1,683 @@ +/** + * SpeechGateway Tests + * + * Issue #397: WebSocket streaming transcription endpoint tests. + * Written FIRST following TDD (Red-Green-Refactor). + * + * Tests cover: + * - Authentication via handshake token + * - Session lifecycle: start -> audio chunks -> stop + * - Transcription result emission + * - Session cleanup on disconnect + * - Error handling + * - Buffer size limit enforcement + */ + +import { describe, it, expect, beforeEach, vi } from "vitest"; +import { Socket } from "socket.io"; +import { SpeechGateway } from "./speech.gateway"; +import { SpeechService } from "./speech.service"; +import { AuthService } from "../auth/auth.service"; +import { PrismaService } from "../prisma/prisma.service"; +import type { SpeechConfig } from "./speech.config"; +import type { TranscriptionResult } from "./interfaces/speech-types"; + +// ========================================== +// Test helpers +// ========================================== + +interface AuthenticatedSocket extends Socket { + data: { + userId?: string; + workspaceId?: string; + }; +} + +function createMockConfig(): SpeechConfig { + return { + stt: { + enabled: true, + baseUrl: "http://localhost:8000/v1", + model: "test-model", + language: "en", + }, + tts: { + default: { enabled: true, url: "http://localhost:8880/v1", voice: "test", format: "mp3" }, + premium: { enabled: false, url: "" }, + fallback: { enabled: false, url: "" }, + }, + limits: { + maxUploadSize: 25_000_000, + maxDurationSeconds: 600, + maxTextLength: 4096, + }, + }; +} + +function createMockSocket(overrides?: Partial): AuthenticatedSocket { + return { + id: "test-socket-id", + join: vi.fn(), + leave: vi.fn(), + emit: vi.fn(), + disconnect: vi.fn(), + data: {}, + handshake: { + auth: { token: "valid-token" }, + query: {}, + headers: {}, + }, + ...overrides, + } as unknown as AuthenticatedSocket; +} + +function createMockAuthService(): { + verifySession: ReturnType; +} { + return { + verifySession: vi.fn().mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }), + }; +} + +function createMockPrismaService(): { + workspaceMember: { findFirst: ReturnType }; +} { + return { + workspaceMember: { + findFirst: vi.fn().mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }), + }, + }; +} + +function createMockSpeechService(): { + transcribe: ReturnType; + isSTTAvailable: ReturnType; +} { + return { + transcribe: vi.fn().mockResolvedValue({ + text: "Hello world", + language: "en", + durationSeconds: 2.5, + } satisfies TranscriptionResult), + isSTTAvailable: vi.fn().mockReturnValue(true), + }; +} + +// ========================================== +// Tests +// ========================================== + +describe("SpeechGateway", () => { + let gateway: SpeechGateway; + let mockAuthService: ReturnType; + let mockPrismaService: ReturnType; + let mockSpeechService: ReturnType; + let mockConfig: SpeechConfig; + let mockClient: AuthenticatedSocket; + + beforeEach(() => { + mockAuthService = createMockAuthService(); + mockPrismaService = createMockPrismaService(); + mockSpeechService = createMockSpeechService(); + mockConfig = createMockConfig(); + mockClient = createMockSocket(); + + gateway = new SpeechGateway( + mockAuthService as unknown as AuthService, + mockPrismaService as unknown as PrismaService, + mockSpeechService as unknown as SpeechService, + mockConfig + ); + + vi.clearAllMocks(); + }); + + // ========================================== + // Authentication + // ========================================== + describe("handleConnection", () => { + it("should authenticate client and populate socket data on valid token", async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + + await gateway.handleConnection(mockClient); + + expect(mockAuthService.verifySession).toHaveBeenCalledWith("valid-token"); + expect(mockClient.data.userId).toBe("user-123"); + expect(mockClient.data.workspaceId).toBe("workspace-456"); + }); + + it("should disconnect client without token", async () => { + const clientNoToken = createMockSocket({ + handshake: { auth: {}, query: {}, headers: {} }, + } as Partial); + + await gateway.handleConnection(clientNoToken); + + expect(clientNoToken.disconnect).toHaveBeenCalled(); + }); + + it("should disconnect client with invalid token", async () => { + mockAuthService.verifySession.mockResolvedValue(null); + + await gateway.handleConnection(mockClient); + + expect(mockClient.disconnect).toHaveBeenCalled(); + }); + + it("should disconnect client without workspace access", async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue(null); + + await gateway.handleConnection(mockClient); + + expect(mockClient.disconnect).toHaveBeenCalled(); + }); + + it("should disconnect client when auth throws", async () => { + mockAuthService.verifySession.mockRejectedValue(new Error("Auth failure")); + + await gateway.handleConnection(mockClient); + + expect(mockClient.disconnect).toHaveBeenCalled(); + }); + + it("should extract token from handshake.query as fallback", async () => { + const clientQueryToken = createMockSocket({ + handshake: { + auth: {}, + query: { token: "query-token" }, + headers: {}, + }, + } as Partial); + + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + + await gateway.handleConnection(clientQueryToken); + + expect(mockAuthService.verifySession).toHaveBeenCalledWith("query-token"); + }); + }); + + // ========================================== + // start-transcription + // ========================================== + describe("handleStartTranscription", () => { + beforeEach(async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + await gateway.handleConnection(mockClient); + vi.clearAllMocks(); + }); + + it("should create a transcription session", () => { + gateway.handleStartTranscription(mockClient, { language: "en" }); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-started", + expect.objectContaining({ sessionId: expect.any(String) }) + ); + }); + + it("should create a session with optional language parameter", () => { + gateway.handleStartTranscription(mockClient, { language: "fr" }); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-started", + expect.objectContaining({ sessionId: expect.any(String) }) + ); + }); + + it("should create a session with no options", () => { + gateway.handleStartTranscription(mockClient, {}); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-started", + expect.objectContaining({ sessionId: expect.any(String) }) + ); + }); + + it("should emit error if client is not authenticated", () => { + const unauthClient = createMockSocket(); + // Not connected through handleConnection, so no userId set + + gateway.handleStartTranscription(unauthClient, {}); + + expect(unauthClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should replace existing session if one already exists", () => { + gateway.handleStartTranscription(mockClient, {}); + gateway.handleStartTranscription(mockClient, { language: "de" }); + + // Should have emitted transcription-started twice (no error) + const startedCalls = (mockClient.emit as ReturnType).mock.calls.filter( + (call: unknown[]) => call[0] === "transcription-started" + ); + expect(startedCalls).toHaveLength(2); + }); + }); + + // ========================================== + // audio-chunk + // ========================================== + describe("handleAudioChunk", () => { + beforeEach(async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + await gateway.handleConnection(mockClient); + vi.clearAllMocks(); + gateway.handleStartTranscription(mockClient, {}); + vi.clearAllMocks(); + }); + + it("should accumulate audio data in the session", () => { + const chunk = Buffer.from("audio-data-1"); + gateway.handleAudioChunk(mockClient, chunk); + + // No error emitted + const errorCalls = (mockClient.emit as ReturnType).mock.calls.filter( + (call: unknown[]) => call[0] === "transcription-error" + ); + expect(errorCalls).toHaveLength(0); + }); + + it("should accept Uint8Array data and convert to Buffer", () => { + const chunk = new Uint8Array([1, 2, 3, 4]); + gateway.handleAudioChunk(mockClient, chunk); + + const errorCalls = (mockClient.emit as ReturnType).mock.calls.filter( + (call: unknown[]) => call[0] === "transcription-error" + ); + expect(errorCalls).toHaveLength(0); + }); + + it("should emit error if no active session exists", () => { + const noSessionClient = createMockSocket({ id: "no-session" }); + noSessionClient.data = { userId: "user-123", workspaceId: "workspace-456" }; + + const chunk = Buffer.from("audio-data"); + gateway.handleAudioChunk(noSessionClient, chunk); + + expect(noSessionClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should emit error if client is not authenticated", () => { + const unauthClient = createMockSocket({ id: "unauth" }); + // Not authenticated + + const chunk = Buffer.from("audio-data"); + gateway.handleAudioChunk(unauthClient, chunk); + + expect(unauthClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should emit error when buffer size exceeds max upload size", () => { + // Set a very small max upload size + const smallConfig = createMockConfig(); + smallConfig.limits.maxUploadSize = 10; + + const limitedGateway = new SpeechGateway( + mockAuthService as unknown as AuthService, + mockPrismaService as unknown as PrismaService, + mockSpeechService as unknown as SpeechService, + smallConfig + ); + + // We need to manually set up the authenticated client in the new gateway + const limitedClient = createMockSocket({ id: "limited-client" }); + limitedClient.data = { userId: "user-123", workspaceId: "workspace-456" }; + + // Start session directly (since handleConnection populates data) + limitedGateway.handleStartTranscription(limitedClient, {}); + vi.clearAllMocks(); + + // Send a chunk that exceeds the limit + const largeChunk = Buffer.alloc(20, "a"); + limitedGateway.handleAudioChunk(limitedClient, largeChunk); + + expect(limitedClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.stringContaining("exceeds") }) + ); + }); + + it("should emit error when accumulated buffer size exceeds max upload size", () => { + const smallConfig = createMockConfig(); + smallConfig.limits.maxUploadSize = 15; + + const limitedGateway = new SpeechGateway( + mockAuthService as unknown as AuthService, + mockPrismaService as unknown as PrismaService, + mockSpeechService as unknown as SpeechService, + smallConfig + ); + + const limitedClient = createMockSocket({ id: "limited-client-2" }); + limitedClient.data = { userId: "user-123", workspaceId: "workspace-456" }; + + limitedGateway.handleStartTranscription(limitedClient, {}); + vi.clearAllMocks(); + + // Send two chunks that together exceed the limit + const chunk1 = Buffer.alloc(10, "a"); + const chunk2 = Buffer.alloc(10, "b"); + limitedGateway.handleAudioChunk(limitedClient, chunk1); + limitedGateway.handleAudioChunk(limitedClient, chunk2); + + expect(limitedClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.stringContaining("exceeds") }) + ); + }); + }); + + // ========================================== + // stop-transcription + // ========================================== + describe("handleStopTranscription", () => { + beforeEach(async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + await gateway.handleConnection(mockClient); + vi.clearAllMocks(); + }); + + it("should transcribe accumulated audio and emit final result", async () => { + gateway.handleStartTranscription(mockClient, { language: "en" }); + + const chunk1 = Buffer.from("audio-part-1"); + const chunk2 = Buffer.from("audio-part-2"); + gateway.handleAudioChunk(mockClient, chunk1); + gateway.handleAudioChunk(mockClient, chunk2); + + vi.clearAllMocks(); + + const expectedResult: TranscriptionResult = { + text: "Hello world", + language: "en", + durationSeconds: 2.5, + }; + mockSpeechService.transcribe.mockResolvedValue(expectedResult); + + await gateway.handleStopTranscription(mockClient); + + // Should have called transcribe with concatenated buffer + expect(mockSpeechService.transcribe).toHaveBeenCalledWith( + expect.any(Buffer), + expect.objectContaining({}) + ); + + // Should emit transcription-final + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-final", + expect.objectContaining({ text: "Hello world" }) + ); + }); + + it("should pass language option to SpeechService.transcribe", async () => { + gateway.handleStartTranscription(mockClient, { language: "fr" }); + gateway.handleAudioChunk(mockClient, Buffer.from("audio")); + + vi.clearAllMocks(); + + await gateway.handleStopTranscription(mockClient); + + expect(mockSpeechService.transcribe).toHaveBeenCalledWith( + expect.any(Buffer), + expect.objectContaining({ language: "fr" }) + ); + }); + + it("should clean up session after stop", async () => { + gateway.handleStartTranscription(mockClient, {}); + gateway.handleAudioChunk(mockClient, Buffer.from("audio")); + + await gateway.handleStopTranscription(mockClient); + + vi.clearAllMocks(); + + // Sending more audio after stop should error (no session) + gateway.handleAudioChunk(mockClient, Buffer.from("more-audio")); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should emit transcription-error when transcription fails", async () => { + gateway.handleStartTranscription(mockClient, {}); + gateway.handleAudioChunk(mockClient, Buffer.from("audio")); + + vi.clearAllMocks(); + + mockSpeechService.transcribe.mockRejectedValue(new Error("STT service down")); + + await gateway.handleStopTranscription(mockClient); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.stringContaining("STT service down") }) + ); + }); + + it("should emit error if no active session exists", async () => { + await gateway.handleStopTranscription(mockClient); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should emit error if client is not authenticated", async () => { + const unauthClient = createMockSocket({ id: "unauth-stop" }); + + await gateway.handleStopTranscription(unauthClient); + + expect(unauthClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should emit error when stopping with no audio chunks received", async () => { + gateway.handleStartTranscription(mockClient, {}); + + vi.clearAllMocks(); + + await gateway.handleStopTranscription(mockClient); + + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.stringContaining("No audio") }) + ); + }); + }); + + // ========================================== + // handleDisconnect + // ========================================== + describe("handleDisconnect", () => { + beforeEach(async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + await gateway.handleConnection(mockClient); + vi.clearAllMocks(); + }); + + it("should clean up active session on disconnect", () => { + gateway.handleStartTranscription(mockClient, {}); + gateway.handleAudioChunk(mockClient, Buffer.from("audio")); + + gateway.handleDisconnect(mockClient); + + // Session should be gone. Verify by trying to add a chunk to a new + // socket with the same ID (should error since session was cleaned up). + const newClient = createMockSocket({ id: mockClient.id }); + newClient.data = { userId: "user-123", workspaceId: "workspace-456" }; + + gateway.handleAudioChunk(newClient, Buffer.from("more")); + + expect(newClient.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.any(String) }) + ); + }); + + it("should not throw when disconnecting client without active session", () => { + expect(() => gateway.handleDisconnect(mockClient)).not.toThrow(); + }); + + it("should not throw when disconnecting unauthenticated client", () => { + const unauthClient = createMockSocket({ id: "unauth-disconnect" }); + expect(() => gateway.handleDisconnect(unauthClient)).not.toThrow(); + }); + }); + + // ========================================== + // Edge cases + // ========================================== + describe("edge cases", () => { + beforeEach(async () => { + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-123" }, + session: { id: "session-123" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-123", + workspaceId: "workspace-456", + role: "MEMBER", + }); + await gateway.handleConnection(mockClient); + vi.clearAllMocks(); + }); + + it("should handle multiple start-stop cycles for the same client", async () => { + // First cycle + gateway.handleStartTranscription(mockClient, {}); + gateway.handleAudioChunk(mockClient, Buffer.from("cycle-1")); + await gateway.handleStopTranscription(mockClient); + + vi.clearAllMocks(); + + // Second cycle + gateway.handleStartTranscription(mockClient, { language: "de" }); + gateway.handleAudioChunk(mockClient, Buffer.from("cycle-2")); + await gateway.handleStopTranscription(mockClient); + + expect(mockSpeechService.transcribe).toHaveBeenCalledTimes(1); + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-final", + expect.objectContaining({ text: "Hello world" }) + ); + }); + + it("should isolate sessions between different clients", async () => { + const client2 = createMockSocket({ id: "client-2" }); + client2.data = { userId: "user-456", workspaceId: "workspace-789" }; + + // Client 2 also needs to be "connected" + mockAuthService.verifySession.mockResolvedValue({ + user: { id: "user-456" }, + session: { id: "session-456" }, + }); + mockPrismaService.workspaceMember.findFirst.mockResolvedValue({ + userId: "user-456", + workspaceId: "workspace-789", + role: "MEMBER", + }); + await gateway.handleConnection(client2); + vi.clearAllMocks(); + + // Start sessions for both clients + gateway.handleStartTranscription(mockClient, {}); + gateway.handleStartTranscription(client2, {}); + + // Send audio to client 1 only + gateway.handleAudioChunk(mockClient, Buffer.from("audio-for-client-1")); + + // Stop client 2 (no audio) + await gateway.handleStopTranscription(client2); + + // Client 2 should get an error (no audio received) + expect(client2.emit).toHaveBeenCalledWith( + "transcription-error", + expect.objectContaining({ message: expect.stringContaining("No audio") }) + ); + + vi.clearAllMocks(); + + // Stop client 1 (has audio) -- should succeed + await gateway.handleStopTranscription(mockClient); + expect(mockSpeechService.transcribe).toHaveBeenCalled(); + expect(mockClient.emit).toHaveBeenCalledWith( + "transcription-final", + expect.objectContaining({ text: "Hello world" }) + ); + }); + }); +}); diff --git a/apps/api/src/speech/speech.gateway.ts b/apps/api/src/speech/speech.gateway.ts new file mode 100644 index 0000000..907ec57 --- /dev/null +++ b/apps/api/src/speech/speech.gateway.ts @@ -0,0 +1,366 @@ +/** + * SpeechGateway + * + * WebSocket gateway for real-time streaming transcription. + * Uses a separate `/speech` namespace from the main WebSocket gateway. + * + * Protocol: + * 1. Client connects with auth token in handshake + * 2. Client emits `start-transcription` with optional { language } + * 3. Client streams audio via `audio-chunk` events (Buffer/Uint8Array) + * 4. Client emits `stop-transcription` to finalize + * 5. Server responds with `transcription-final` containing the result + * + * Session management: + * - One active transcription session per client + * - Audio chunks accumulated in memory (Buffer array) + * - On stop: chunks concatenated and sent to SpeechService.transcribe() + * - Sessions cleaned up on disconnect + * + * Rate limiting: + * - Total accumulated audio size is capped by config limits.maxUploadSize + * + * Issue #397 + */ + +import { + WebSocketGateway as WSGateway, + WebSocketServer, + SubscribeMessage, + OnGatewayConnection, + OnGatewayDisconnect, +} from "@nestjs/websockets"; +import { Logger, Inject } from "@nestjs/common"; +import { Server, Socket } from "socket.io"; +import { AuthService } from "../auth/auth.service"; +import { PrismaService } from "../prisma/prisma.service"; +import { SpeechService } from "./speech.service"; +import { speechConfig, type SpeechConfig } from "./speech.config"; + +// ========================================== +// Types +// ========================================== + +interface AuthenticatedSocket extends Socket { + data: { + userId?: string; + workspaceId?: string; + }; +} + +interface TranscriptionSession { + chunks: Buffer[]; + totalSize: number; + language: string | undefined; + startedAt: Date; +} + +interface StartTranscriptionPayload { + language?: string; +} + +// ========================================== +// Gateway +// ========================================== + +@WSGateway({ + namespace: "/speech", + cors: { + origin: process.env.WEB_URL ?? "http://localhost:3000", + credentials: true, + }, +}) +export class SpeechGateway implements OnGatewayConnection, OnGatewayDisconnect { + @WebSocketServer() + server!: Server; + + private readonly logger = new Logger(SpeechGateway.name); + private readonly sessions = new Map(); + private readonly CONNECTION_TIMEOUT_MS = 5000; + + constructor( + private readonly authService: AuthService, + private readonly prisma: PrismaService, + private readonly speechService: SpeechService, + @Inject(speechConfig.KEY) + private readonly config: SpeechConfig + ) {} + + // ========================================== + // Connection lifecycle + // ========================================== + + /** + * Authenticate client on connection using the same pattern as the main WebSocket gateway. + * Extracts token from handshake, verifies session, and checks workspace membership. + */ + async handleConnection(client: Socket): Promise { + const authenticatedClient = client as AuthenticatedSocket; + + const timeoutId = setTimeout(() => { + if (!authenticatedClient.data.userId) { + this.logger.warn(`Client ${authenticatedClient.id} timed out during authentication`); + authenticatedClient.disconnect(); + } + }, this.CONNECTION_TIMEOUT_MS); + + try { + const token = this.extractTokenFromHandshake(authenticatedClient); + + if (!token) { + this.logger.warn(`Client ${authenticatedClient.id} connected without token`); + authenticatedClient.disconnect(); + clearTimeout(timeoutId); + return; + } + + const sessionData = await this.authService.verifySession(token); + + if (!sessionData) { + this.logger.warn(`Client ${authenticatedClient.id} has invalid token`); + authenticatedClient.disconnect(); + clearTimeout(timeoutId); + return; + } + + const user = sessionData.user as { id: string }; + const userId = user.id; + + const workspaceMembership = await this.prisma.workspaceMember.findFirst({ + where: { userId }, + select: { workspaceId: true, userId: true, role: true }, + }); + + if (!workspaceMembership) { + this.logger.warn(`User ${userId} has no workspace access`); + authenticatedClient.disconnect(); + clearTimeout(timeoutId); + return; + } + + authenticatedClient.data.userId = userId; + authenticatedClient.data.workspaceId = workspaceMembership.workspaceId; + + clearTimeout(timeoutId); + this.logger.log( + `Speech client ${authenticatedClient.id} connected (user: ${userId}, workspace: ${workspaceMembership.workspaceId})` + ); + } catch (error) { + clearTimeout(timeoutId); + this.logger.error( + `Authentication failed for speech client ${authenticatedClient.id}:`, + error instanceof Error ? error.message : "Unknown error" + ); + authenticatedClient.disconnect(); + } + } + + /** + * Clean up transcription session on client disconnect. + */ + handleDisconnect(client: Socket): void { + const authenticatedClient = client as AuthenticatedSocket; + const sessionId = authenticatedClient.id; + + if (this.sessions.has(sessionId)) { + this.sessions.delete(sessionId); + this.logger.log(`Cleaned up transcription session for client ${sessionId}`); + } + + this.logger.debug(`Speech client ${sessionId} disconnected`); + } + + // ========================================== + // Transcription events + // ========================================== + + /** + * Start a new transcription session for the client. + * Replaces any existing session for this client. + * + * @param client - The connected socket client + * @param payload - Optional parameters: { language?: string } + */ + @SubscribeMessage("start-transcription") + handleStartTranscription(client: Socket, payload: StartTranscriptionPayload): void { + const authenticatedClient = client as AuthenticatedSocket; + + if (!authenticatedClient.data.userId) { + authenticatedClient.emit("transcription-error", { + message: "Not authenticated. Connect with a valid token.", + }); + return; + } + + const sessionId = authenticatedClient.id; + + // Clean up any existing session for this client + if (this.sessions.has(sessionId)) { + this.sessions.delete(sessionId); + this.logger.debug(`Replaced existing session for client ${sessionId}`); + } + + const language = payload.language; + + const session: TranscriptionSession = { + chunks: [], + totalSize: 0, + language, + startedAt: new Date(), + }; + + this.sessions.set(sessionId, session); + + authenticatedClient.emit("transcription-started", { + sessionId, + language, + }); + + this.logger.debug( + `Transcription session started for client ${sessionId} (language: ${language ?? "auto"})` + ); + } + + /** + * Receive an audio chunk and accumulate it in the active session. + * Enforces maximum buffer size from configuration. + * + * @param client - The connected socket client + * @param data - Audio data as Buffer or Uint8Array + */ + @SubscribeMessage("audio-chunk") + handleAudioChunk(client: Socket, data: Buffer | Uint8Array): void { + const authenticatedClient = client as AuthenticatedSocket; + + if (!authenticatedClient.data.userId) { + authenticatedClient.emit("transcription-error", { + message: "Not authenticated. Connect with a valid token.", + }); + return; + } + + const sessionId = authenticatedClient.id; + const session = this.sessions.get(sessionId); + + if (!session) { + authenticatedClient.emit("transcription-error", { + message: "No active transcription session. Send start-transcription first.", + }); + return; + } + + const chunk = Buffer.isBuffer(data) ? data : Buffer.from(data); + const newTotalSize = session.totalSize + chunk.length; + + if (newTotalSize > this.config.limits.maxUploadSize) { + authenticatedClient.emit("transcription-error", { + message: `Audio buffer size (${String(newTotalSize)} bytes) exceeds maximum allowed size (${String(this.config.limits.maxUploadSize)} bytes).`, + }); + // Clean up the session on overflow + this.sessions.delete(sessionId); + return; + } + + session.chunks.push(chunk); + session.totalSize = newTotalSize; + } + + /** + * Stop the transcription session, concatenate audio chunks, and transcribe. + * Emits `transcription-final` on success or `transcription-error` on failure. + * + * @param client - The connected socket client + */ + @SubscribeMessage("stop-transcription") + async handleStopTranscription(client: Socket): Promise { + const authenticatedClient = client as AuthenticatedSocket; + + if (!authenticatedClient.data.userId) { + authenticatedClient.emit("transcription-error", { + message: "Not authenticated. Connect with a valid token.", + }); + return; + } + + const sessionId = authenticatedClient.id; + const session = this.sessions.get(sessionId); + + if (!session) { + authenticatedClient.emit("transcription-error", { + message: "No active transcription session. Send start-transcription first.", + }); + return; + } + + // Always remove session before processing (prevents double-stop) + this.sessions.delete(sessionId); + + if (session.chunks.length === 0) { + authenticatedClient.emit("transcription-error", { + message: "No audio data received. Send audio-chunk events before stopping.", + }); + return; + } + + try { + const audioBuffer = Buffer.concat(session.chunks); + const options: { language?: string } = {}; + if (session.language) { + options.language = session.language; + } + + this.logger.debug( + `Transcribing ${String(audioBuffer.length)} bytes for client ${sessionId} (language: ${session.language ?? "auto"})` + ); + + const result = await this.speechService.transcribe(audioBuffer, options); + + authenticatedClient.emit("transcription-final", { + text: result.text, + language: result.language, + durationSeconds: result.durationSeconds, + confidence: result.confidence, + segments: result.segments, + }); + + this.logger.debug(`Transcription complete for client ${sessionId}: "${result.text}"`); + } catch (error: unknown) { + const message = error instanceof Error ? error.message : String(error); + this.logger.error(`Transcription failed for client ${sessionId}: ${message}`); + authenticatedClient.emit("transcription-error", { + message: `Transcription failed: ${message}`, + }); + } + } + + // ========================================== + // Private helpers + // ========================================== + + /** + * Extract authentication token from Socket.IO handshake. + * Checks auth.token, query.token, and Authorization header (in that order). + */ + private extractTokenFromHandshake(client: Socket): string | undefined { + const authToken = client.handshake.auth.token as unknown; + if (typeof authToken === "string" && authToken.length > 0) { + return authToken; + } + + const queryToken = client.handshake.query.token as unknown; + if (typeof queryToken === "string" && queryToken.length > 0) { + return queryToken; + } + + const authHeader = client.handshake.headers.authorization as unknown; + if (typeof authHeader === "string") { + const parts = authHeader.split(" "); + const [type, token] = parts; + if (type === "Bearer" && token) { + return token; + } + } + + return undefined; + } +} diff --git a/apps/api/src/speech/speech.module.ts b/apps/api/src/speech/speech.module.ts index d2151ef..42978f9 100644 --- a/apps/api/src/speech/speech.module.ts +++ b/apps/api/src/speech/speech.module.ts @@ -11,15 +11,18 @@ * * Imports: * - ConfigModule.forFeature(speechConfig) for speech configuration + * - AuthModule for WebSocket authentication + * - PrismaModule for workspace membership queries * * Providers: * - SpeechService: High-level speech operations with provider selection + * - SpeechGateway: WebSocket gateway for streaming transcription (Issue #397) * - TTS_PROVIDERS: Map populated by factory based on config * * Exports: * - SpeechService for use by other modules (e.g., controllers, brain) * - * Issue #389, #390, #391 + * Issue #389, #390, #391, #397 */ import { Module, type OnModuleInit, Logger } from "@nestjs/common"; @@ -32,15 +35,19 @@ import { } from "./speech.config"; import { SpeechService } from "./speech.service"; import { SpeechController } from "./speech.controller"; +import { SpeechGateway } from "./speech.gateway"; import { STT_PROVIDER, TTS_PROVIDERS } from "./speech.constants"; import { SpeachesSttProvider } from "./providers/speaches-stt.provider"; import { createTTSProviders } from "./providers/tts-provider.factory"; +import { AuthModule } from "../auth/auth.module"; +import { PrismaModule } from "../prisma/prisma.module"; @Module({ - imports: [ConfigModule.forFeature(speechConfig)], + imports: [ConfigModule.forFeature(speechConfig), AuthModule, PrismaModule], controllers: [SpeechController], providers: [ SpeechService, + SpeechGateway, // STT provider: conditionally register SpeachesSttProvider when STT is enabled ...(isSttEnabled() ? [