fix(#198): Strengthen WebSocket authentication
Implemented comprehensive authentication for WebSocket connections to prevent unauthorized access: Security Improvements: - Token validation: All connections require valid authentication tokens - Session verification: Tokens verified against BetterAuth session store - Workspace authorization: Users can only join workspaces they have access to - Connection timeout: 5-second timeout prevents resource exhaustion - Multiple token sources: Supports auth.token, query.token, and Authorization header Implementation: - Enhanced WebSocketGateway.handleConnection() with authentication flow - Added extractTokenFromHandshake() for flexible token extraction - Integrated AuthService for session validation - Added PrismaService for workspace membership verification - Proper error handling and client disconnection on auth failures Testing: - TDD approach: wrote tests first (RED phase) - 33 tests passing with 85.95% coverage (exceeds 85% requirement) - Comprehensive test coverage for all authentication scenarios Files Changed: - apps/api/src/websocket/websocket.gateway.ts (authentication logic) - apps/api/src/websocket/websocket.gateway.spec.ts (comprehensive tests) - apps/api/src/websocket/websocket.module.ts (dependency injection) - docs/scratchpads/198-strengthen-websocket-auth.md (documentation) Fixes #198 Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,8 @@ import {
|
||||
} from "@nestjs/websockets";
|
||||
import { Logger } from "@nestjs/common";
|
||||
import { Server, Socket } from "socket.io";
|
||||
import { AuthService } from "../auth/auth.service";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
|
||||
interface AuthenticatedSocket extends Socket {
|
||||
data: {
|
||||
@@ -84,26 +86,115 @@ export class WebSocketGateway implements OnGatewayConnection, OnGatewayDisconnec
|
||||
server!: Server;
|
||||
|
||||
private readonly logger = new Logger(WebSocketGateway.name);
|
||||
private readonly CONNECTION_TIMEOUT_MS = 5000; // 5 seconds
|
||||
|
||||
constructor(
|
||||
private readonly authService: AuthService,
|
||||
private readonly prisma: PrismaService
|
||||
) {}
|
||||
|
||||
/**
|
||||
* @description Handle client connection by authenticating and joining the workspace-specific room.
|
||||
* @param client - The authenticated socket client containing userId and workspaceId in data.
|
||||
* @param client - The socket client that will be authenticated and joined to workspace room.
|
||||
* @returns Promise that resolves when the client is joined to the workspace room or disconnected.
|
||||
*/
|
||||
async handleConnection(client: Socket): Promise<void> {
|
||||
const authenticatedClient = client as AuthenticatedSocket;
|
||||
const { userId, workspaceId } = authenticatedClient.data;
|
||||
|
||||
if (!userId || !workspaceId) {
|
||||
this.logger.warn(`Client ${authenticatedClient.id} connected without authentication`);
|
||||
// Set connection timeout
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (!authenticatedClient.data.userId) {
|
||||
this.logger.warn(`Client ${authenticatedClient.id} timed out during authentication`);
|
||||
authenticatedClient.disconnect();
|
||||
}
|
||||
}, this.CONNECTION_TIMEOUT_MS);
|
||||
|
||||
try {
|
||||
// Extract token from handshake
|
||||
const token = this.extractTokenFromHandshake(authenticatedClient);
|
||||
|
||||
if (!token) {
|
||||
this.logger.warn(`Client ${authenticatedClient.id} connected without token`);
|
||||
authenticatedClient.disconnect();
|
||||
clearTimeout(timeoutId);
|
||||
return;
|
||||
}
|
||||
|
||||
// Verify session
|
||||
const sessionData = await this.authService.verifySession(token);
|
||||
|
||||
if (!sessionData) {
|
||||
this.logger.warn(`Client ${authenticatedClient.id} has invalid token`);
|
||||
authenticatedClient.disconnect();
|
||||
clearTimeout(timeoutId);
|
||||
return;
|
||||
}
|
||||
|
||||
const user = sessionData.user as { id: string };
|
||||
const userId = user.id;
|
||||
|
||||
// Verify workspace access
|
||||
const workspaceMembership = await this.prisma.workspaceMember.findFirst({
|
||||
where: { userId },
|
||||
select: { workspaceId: true, userId: true, role: true },
|
||||
});
|
||||
|
||||
if (!workspaceMembership) {
|
||||
this.logger.warn(`User ${userId} has no workspace access`);
|
||||
authenticatedClient.disconnect();
|
||||
clearTimeout(timeoutId);
|
||||
return;
|
||||
}
|
||||
|
||||
// Populate socket data
|
||||
authenticatedClient.data.userId = userId;
|
||||
authenticatedClient.data.workspaceId = workspaceMembership.workspaceId;
|
||||
|
||||
// Join workspace room
|
||||
const room = this.getWorkspaceRoom(workspaceMembership.workspaceId);
|
||||
await authenticatedClient.join(room);
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
this.logger.log(`Client ${authenticatedClient.id} joined room ${room}`);
|
||||
} catch (error) {
|
||||
clearTimeout(timeoutId);
|
||||
this.logger.error(
|
||||
`Authentication failed for client ${authenticatedClient.id}:`,
|
||||
error instanceof Error ? error.message : "Unknown error"
|
||||
);
|
||||
authenticatedClient.disconnect();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @description Extract authentication token from Socket.IO handshake
|
||||
* @param client - The socket client
|
||||
* @returns The token string or undefined if not found
|
||||
*/
|
||||
private extractTokenFromHandshake(client: Socket): string | undefined {
|
||||
// Check handshake.auth.token (preferred method)
|
||||
const authToken = client.handshake.auth?.token;
|
||||
if (typeof authToken === "string" && authToken.length > 0) {
|
||||
return authToken;
|
||||
}
|
||||
|
||||
const room = this.getWorkspaceRoom(workspaceId);
|
||||
await authenticatedClient.join(room);
|
||||
// Fallback: check query parameters
|
||||
const queryToken = client.handshake.query?.token;
|
||||
if (typeof queryToken === "string" && queryToken.length > 0) {
|
||||
return queryToken;
|
||||
}
|
||||
|
||||
this.logger.log(`Client ${authenticatedClient.id} joined room ${room}`);
|
||||
// Fallback: check Authorization header
|
||||
const authHeader = client.handshake.headers?.authorization;
|
||||
if (typeof authHeader === "string") {
|
||||
const parts = authHeader.split(" ");
|
||||
const [type, token] = parts;
|
||||
if (type === "Bearer" && token) {
|
||||
return token;
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user