fix(#198): Strengthen WebSocket authentication

Implemented comprehensive authentication for WebSocket connections to prevent
unauthorized access:

Security Improvements:
- Token validation: All connections require valid authentication tokens
- Session verification: Tokens verified against BetterAuth session store
- Workspace authorization: Users can only join workspaces they have access to
- Connection timeout: 5-second timeout prevents resource exhaustion
- Multiple token sources: Supports auth.token, query.token, and Authorization header

Implementation:
- Enhanced WebSocketGateway.handleConnection() with authentication flow
- Added extractTokenFromHandshake() for flexible token extraction
- Integrated AuthService for session validation
- Added PrismaService for workspace membership verification
- Proper error handling and client disconnection on auth failures

Testing:
- TDD approach: wrote tests first (RED phase)
- 33 tests passing with 85.95% coverage (exceeds 85% requirement)
- Comprehensive test coverage for all authentication scenarios

Files Changed:
- apps/api/src/websocket/websocket.gateway.ts (authentication logic)
- apps/api/src/websocket/websocket.gateway.spec.ts (comprehensive tests)
- apps/api/src/websocket/websocket.module.ts (dependency injection)
- docs/scratchpads/198-strengthen-websocket-auth.md (documentation)

Fixes #198

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Jason Woltje
2026-02-02 13:04:34 -06:00
parent 431bcb3f0f
commit 210b3d2e8f
4 changed files with 490 additions and 20 deletions

View File

@@ -6,6 +6,8 @@ import {
} from "@nestjs/websockets";
import { Logger } from "@nestjs/common";
import { Server, Socket } from "socket.io";
import { AuthService } from "../auth/auth.service";
import { PrismaService } from "../prisma/prisma.service";
interface AuthenticatedSocket extends Socket {
data: {
@@ -84,26 +86,115 @@ export class WebSocketGateway implements OnGatewayConnection, OnGatewayDisconnec
server!: Server;
private readonly logger = new Logger(WebSocketGateway.name);
private readonly CONNECTION_TIMEOUT_MS = 5000; // 5 seconds
constructor(
private readonly authService: AuthService,
private readonly prisma: PrismaService
) {}
/**
* @description Handle client connection by authenticating and joining the workspace-specific room.
* @param client - The authenticated socket client containing userId and workspaceId in data.
* @param client - The socket client that will be authenticated and joined to workspace room.
* @returns Promise that resolves when the client is joined to the workspace room or disconnected.
*/
async handleConnection(client: Socket): Promise<void> {
const authenticatedClient = client as AuthenticatedSocket;
const { userId, workspaceId } = authenticatedClient.data;
if (!userId || !workspaceId) {
this.logger.warn(`Client ${authenticatedClient.id} connected without authentication`);
// Set connection timeout
const timeoutId = setTimeout(() => {
if (!authenticatedClient.data.userId) {
this.logger.warn(`Client ${authenticatedClient.id} timed out during authentication`);
authenticatedClient.disconnect();
}
}, this.CONNECTION_TIMEOUT_MS);
try {
// Extract token from handshake
const token = this.extractTokenFromHandshake(authenticatedClient);
if (!token) {
this.logger.warn(`Client ${authenticatedClient.id} connected without token`);
authenticatedClient.disconnect();
clearTimeout(timeoutId);
return;
}
// Verify session
const sessionData = await this.authService.verifySession(token);
if (!sessionData) {
this.logger.warn(`Client ${authenticatedClient.id} has invalid token`);
authenticatedClient.disconnect();
clearTimeout(timeoutId);
return;
}
const user = sessionData.user as { id: string };
const userId = user.id;
// Verify workspace access
const workspaceMembership = await this.prisma.workspaceMember.findFirst({
where: { userId },
select: { workspaceId: true, userId: true, role: true },
});
if (!workspaceMembership) {
this.logger.warn(`User ${userId} has no workspace access`);
authenticatedClient.disconnect();
clearTimeout(timeoutId);
return;
}
// Populate socket data
authenticatedClient.data.userId = userId;
authenticatedClient.data.workspaceId = workspaceMembership.workspaceId;
// Join workspace room
const room = this.getWorkspaceRoom(workspaceMembership.workspaceId);
await authenticatedClient.join(room);
clearTimeout(timeoutId);
this.logger.log(`Client ${authenticatedClient.id} joined room ${room}`);
} catch (error) {
clearTimeout(timeoutId);
this.logger.error(
`Authentication failed for client ${authenticatedClient.id}:`,
error instanceof Error ? error.message : "Unknown error"
);
authenticatedClient.disconnect();
return;
}
}
/**
* @description Extract authentication token from Socket.IO handshake
* @param client - The socket client
* @returns The token string or undefined if not found
*/
private extractTokenFromHandshake(client: Socket): string | undefined {
// Check handshake.auth.token (preferred method)
const authToken = client.handshake.auth?.token;
if (typeof authToken === "string" && authToken.length > 0) {
return authToken;
}
const room = this.getWorkspaceRoom(workspaceId);
await authenticatedClient.join(room);
// Fallback: check query parameters
const queryToken = client.handshake.query?.token;
if (typeof queryToken === "string" && queryToken.length > 0) {
return queryToken;
}
this.logger.log(`Client ${authenticatedClient.id} joined room ${room}`);
// Fallback: check Authorization header
const authHeader = client.handshake.headers?.authorization;
if (typeof authHeader === "string") {
const parts = authHeader.split(" ");
const [type, token] = parts;
if (type === "Bearer" && token) {
return token;
}
}
return undefined;
}
/**