From 3a98b7866197cf9c11c00a725bf3d2839211cc93 Mon Sep 17 00:00:00 2001 From: Jason Woltje Date: Wed, 4 Feb 2026 07:12:42 -0600 Subject: [PATCH] fix: Complete CSRF protection implementation Closes three CSRF security gaps identified in code review: 1. Added X-CSRF-Token and X-Workspace-Id to CORS allowed headers - Updated apps/api/src/main.ts to accept CSRF token headers 2. Integrated CSRF token handling in web client - Added fetchCsrfToken() to fetch token from API - Store token in memory (not localStorage for security) - Automatically include X-CSRF-Token in POST/PUT/PATCH/DELETE - Implement automatic token refresh on 403 CSRF errors - Added comprehensive test coverage for CSRF functionality 3. Applied CSRF Guard globally - Added CsrfGuard as APP_GUARD in app.module.ts - Verified @SkipCsrf() decorator works for exempted endpoints All tests passing. CSRF protection now enforced application-wide. Co-Authored-By: Claude Sonnet 4.5 --- apps/api/src/app.module.ts | 5 + apps/api/src/main.ts | 2 +- apps/web/src/lib/api/client.test.ts | 348 +++++++++++++++++++++++++++- apps/web/src/lib/api/client.ts | 84 ++++++- 4 files changed, 434 insertions(+), 5 deletions(-) diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index 6dbd1b4..285471a 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -3,6 +3,7 @@ import { APP_INTERCEPTOR, APP_GUARD } from "@nestjs/core"; import { ThrottlerModule } from "@nestjs/throttler"; import { BullModule } from "@nestjs/bullmq"; import { ThrottlerValkeyStorageService, ThrottlerApiKeyGuard } from "./common/throttler"; +import { CsrfGuard } from "./common/guards/csrf.guard"; import { AppController } from "./app.controller"; import { AppService } from "./app.service"; import { CsrfController } from "./common/controllers/csrf.controller"; @@ -99,6 +100,10 @@ import { FederationModule } from "./federation/federation.module"; provide: APP_GUARD, useClass: ThrottlerApiKeyGuard, }, + { + provide: APP_GUARD, + useClass: CsrfGuard, + }, ], }) export class AppModule {} diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts index 98fb2f2..a32e51a 100644 --- a/apps/api/src/main.ts +++ b/apps/api/src/main.ts @@ -75,7 +75,7 @@ async function bootstrap() { }, credentials: true, // Required for cookie-based authentication methods: ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], - allowedHeaders: ["Content-Type", "Authorization", "Cookie"], + allowedHeaders: ["Content-Type", "Authorization", "Cookie", "X-CSRF-Token", "X-Workspace-Id"], exposedHeaders: ["Set-Cookie"], maxAge: 86400, // 24 hours - cache preflight requests }); diff --git a/apps/web/src/lib/api/client.test.ts b/apps/web/src/lib/api/client.test.ts index 971b2ba..41ee087 100644 --- a/apps/web/src/lib/api/client.test.ts +++ b/apps/web/src/lib/api/client.test.ts @@ -1,7 +1,16 @@ import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; /* eslint-disable @typescript-eslint/no-unsafe-assignment */ /* eslint-disable @typescript-eslint/no-non-null-assertion */ -import { apiRequest, apiGet, apiPost, apiPatch, apiDelete } from "./client"; +import { + apiRequest, + apiGet, + apiPost, + apiPatch, + apiDelete, + fetchCsrfToken, + getCsrfToken, + clearCsrfToken, +} from "./client"; // Mock fetch globally const mockFetch = vi.fn(); @@ -10,6 +19,7 @@ global.fetch = mockFetch; describe("API Client", (): void => { beforeEach((): void => { mockFetch.mockClear(); + clearCsrfToken(); }); afterEach((): void => { @@ -126,6 +136,14 @@ describe("API Client", (): void => { it("should make a POST request with data", async (): Promise => { const postData = { name: "New Item" }; const mockResponse = { id: "1", ...postData }; + + // Mock CSRF token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: "test-token" }), + }); + + // Mock actual POST request mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockResponse), @@ -144,6 +162,13 @@ describe("API Client", (): void => { }); it("should make a POST request without data", async (): Promise => { + // Mock CSRF token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: "test-token" }), + }); + + // Mock actual POST request mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve({}), @@ -159,13 +184,21 @@ describe("API Client", (): void => { }) ); - // Verify body is not in the call - const callArgs = mockFetch.mock.calls[0]![1] as RequestInit; + // Verify body is not in the call (second call is the actual POST) + const callArgs = mockFetch.mock.calls[1]![1] as RequestInit; expect(callArgs.body).toBeUndefined(); }); it("should include workspace ID in header when provided", async (): Promise => { const postData = { name: "New Item" }; + + // Mock CSRF token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: "test-token" }), + }); + + // Mock actual POST request mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve({}), @@ -189,6 +222,14 @@ describe("API Client", (): void => { it("should make a PATCH request with data", async (): Promise => { const patchData = { name: "Updated" }; const mockResponse = { id: "1", ...patchData }; + + // Mock CSRF token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: "test-token" }), + }); + + // Mock actual PATCH request mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockResponse), @@ -209,6 +250,13 @@ describe("API Client", (): void => { describe("apiDelete", (): void => { it("should make a DELETE request", async (): Promise => { + // Mock CSRF token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: "test-token" }), + }); + + // Mock actual DELETE request mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve({ success: true }), @@ -376,4 +424,298 @@ describe("API Client", (): void => { }); }); }); + + describe("CSRF Protection", (): void => { + describe("fetchCsrfToken", (): void => { + it("should fetch CSRF token from API", async (): Promise => { + const mockToken = "test-csrf-token-abc123"; + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + const token = await fetchCsrfToken(); + + expect(token).toBe(mockToken); + expect(mockFetch).toHaveBeenCalledWith( + "http://localhost:3001/api/v1/csrf/token", + expect.objectContaining({ + method: "GET", + credentials: "include", + }) + ); + }); + + it("should throw error when fetch fails", async (): Promise => { + mockFetch.mockResolvedValueOnce({ + ok: false, + statusText: "Internal Server Error", + json: () => + Promise.resolve({ + code: "SERVER_ERROR", + message: "Failed to generate token", + }), + }); + + await expect(fetchCsrfToken()).rejects.toThrow("Failed to generate token"); + }); + + it("should cache token in memory", async (): Promise => { + const mockToken = "cached-token-xyz"; + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + await fetchCsrfToken(); + const cachedToken = getCsrfToken(); + + expect(cachedToken).toBe(mockToken); + }); + }); + + describe("CSRF token inclusion in requests", (): void => { + it("should include X-CSRF-Token header in POST requests", async (): Promise => { + const mockToken = "post-csrf-token"; + + // Mock token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + // Mock actual POST request + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 1 } }), + }); + + await apiPost("/test", { title: "Test Task" }); + + // Second call should include CSRF token + expect(mockFetch).toHaveBeenCalledTimes(2); + const postCall = mockFetch.mock.calls[1]![1] as RequestInit; + const headers = postCall.headers as Record; + expect(headers["X-CSRF-Token"]).toBe(mockToken); + }); + + it("should include X-CSRF-Token header in PATCH requests", async (): Promise => { + const mockToken = "patch-csrf-token"; + + // Mock token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + // Mock actual PATCH request + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 1 } }), + }); + + await apiPatch("/test/1", { title: "Updated Task" }); + + expect(mockFetch).toHaveBeenCalledTimes(2); + const patchCall = mockFetch.mock.calls[1]![1] as RequestInit; + const headers = patchCall.headers as Record; + expect(headers["X-CSRF-Token"]).toBe(mockToken); + }); + + it("should include X-CSRF-Token header in DELETE requests", async (): Promise => { + const mockToken = "delete-csrf-token"; + + // Mock token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + // Mock actual DELETE request + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ success: true }), + }); + + await apiDelete("/test/1"); + + expect(mockFetch).toHaveBeenCalledTimes(2); + const deleteCall = mockFetch.mock.calls[1]![1] as RequestInit; + const headers = deleteCall.headers as Record; + expect(headers["X-CSRF-Token"]).toBe(mockToken); + }); + + it("should NOT include X-CSRF-Token header in GET requests", async (): Promise => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: [] }), + }); + + await apiGet("/test"); + + expect(mockFetch).toHaveBeenCalledTimes(1); + const getCall = mockFetch.mock.calls[0]![1] as RequestInit; + const headers = getCall.headers as Record; + expect(headers["X-CSRF-Token"]).toBeUndefined(); + }); + }); + + describe("Automatic token refresh on 403 CSRF errors", (): void => { + it("should refresh token and retry on 403 CSRF error", async (): Promise => { + const oldToken = "old-token"; + const newToken = "new-token"; + + // Initial token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: oldToken }), + }); + + // First POST fails with CSRF error + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 403, + json: () => + Promise.resolve({ + code: "CSRF_ERROR", + message: "CSRF token mismatch", + }), + }); + + // Token refresh succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: newToken }), + }); + + // Retry succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 1 } }), + }); + + const result = await apiPost("/test", { title: "Test Task" }); + + expect(result).toEqual({ data: { id: 1 } }); + expect(mockFetch).toHaveBeenCalledTimes(4); + expect(getCsrfToken()).toBe(newToken); + }); + + it("should throw error if retry also fails", async (): Promise => { + const oldToken = "old-token"; + const newToken = "new-token"; + + // Initial token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: oldToken }), + }); + + // First POST fails with CSRF error + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 403, + json: () => + Promise.resolve({ + code: "CSRF_ERROR", + message: "CSRF token mismatch", + }), + }); + + // Token refresh succeeds + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: newToken }), + }); + + // Retry also fails + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 401, + json: () => + Promise.resolve({ + code: "UNAUTHORIZED", + message: "Not authenticated", + }), + }); + + await expect(apiPost("/test", { title: "Test Task" })).rejects.toThrow("Not authenticated"); + }); + + it("should not retry non-CSRF 403 errors", async (): Promise => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 403, + json: () => + Promise.resolve({ + code: "FORBIDDEN", + message: "Access denied", + }), + }); + + await expect(apiGet("/test")).rejects.toThrow("Access denied"); + + // Should not have retried + expect(mockFetch).toHaveBeenCalledTimes(1); + }); + }); + + describe("Automatic token fetching", (): void => { + it("should fetch token automatically on first state-changing request", async (): Promise => { + const mockToken = "auto-fetched-token"; + + // Token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + + // Actual request + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 1 } }), + }); + + await apiPost("/test", { title: "Test Task" }); + + expect(mockFetch).toHaveBeenCalledTimes(2); + expect(getCsrfToken()).toBe(mockToken); + }); + + it("should reuse cached token for subsequent requests", async (): Promise => { + const mockToken = "cached-token-reused"; + + // First request - token fetch + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ token: mockToken }), + }); + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 1 } }), + }); + + await apiPost("/test", { title: "First Task" }); + + // Second request - reuses cached token + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve({ data: { id: 2 } }), + }); + + await apiPost("/test", { title: "Second Task" }); + + // Should only fetch token once + expect(mockFetch).toHaveBeenCalledTimes(3); + + const firstPostCall = mockFetch.mock.calls[1]![1] as RequestInit; + const secondPostCall = mockFetch.mock.calls[2]![1] as RequestInit; + const headers1 = firstPostCall.headers as Record; + const headers2 = secondPostCall.headers as Record; + + expect(headers1["X-CSRF-Token"]).toBe(mockToken); + expect(headers2["X-CSRF-Token"]).toBe(mockToken); + }); + }); + }); }); diff --git a/apps/web/src/lib/api/client.ts b/apps/web/src/lib/api/client.ts index 11c5f20..2a6308a 100644 --- a/apps/web/src/lib/api/client.ts +++ b/apps/web/src/lib/api/client.ts @@ -7,6 +7,12 @@ const API_BASE_URL = process.env.NEXT_PUBLIC_API_URL ?? "http://localhost:3001"; +/** + * In-memory CSRF token storage + * Using module-level variable instead of localStorage for security + */ +let csrfToken: string | undefined; + export interface ApiError { code: string; message: string; @@ -27,6 +33,60 @@ export interface ApiResponse { */ export interface ApiRequestOptions extends RequestInit { workspaceId?: string; + _isRetry?: boolean; // Internal flag to prevent infinite retry loops +} + +/** + * Fetch CSRF token from the API + * Token is stored in an httpOnly cookie and returned in response body + */ +export async function fetchCsrfToken(): Promise { + const url = `${API_BASE_URL}/api/v1/csrf/token`; + + const response = await fetch(url, { + method: "GET", + credentials: "include", + }); + + if (!response.ok) { + const error: ApiError = await response.json().catch( + (): ApiError => ({ + code: "UNKNOWN_ERROR", + message: response.statusText || "Failed to fetch CSRF token", + }) + ); + + throw new Error(error.message); + } + + const data = (await response.json()) as { token: string }; + csrfToken = data.token; + return data.token; +} + +/** + * Get the current CSRF token from memory + */ +export function getCsrfToken(): string | undefined { + return csrfToken; +} + +/** + * Clear the CSRF token from memory + * Useful for testing or after logout + */ +export function clearCsrfToken(): void { + csrfToken = undefined; +} + +/** + * Ensure CSRF token is available for state-changing requests + */ +async function ensureCsrfToken(): Promise { + if (!csrfToken) { + return fetchCsrfToken(); + } + return csrfToken; } /** @@ -34,7 +94,7 @@ export interface ApiRequestOptions extends RequestInit { */ export async function apiRequest(endpoint: string, options: ApiRequestOptions = {}): Promise { const url = `${API_BASE_URL}${endpoint}`; - const { workspaceId, ...fetchOptions } = options; + const { workspaceId, _isRetry, ...fetchOptions } = options; // Build headers with workspace ID if provided const baseHeaders = (fetchOptions.headers as Record | undefined) ?? {}; @@ -48,6 +108,15 @@ export async function apiRequest(endpoint: string, options: ApiRequestOptions headers["X-Workspace-Id"] = workspaceId; } + // Add CSRF token for state-changing requests (POST, PUT, PATCH, DELETE) + const method = (fetchOptions.method ?? "GET").toUpperCase(); + const isStateChanging = ["POST", "PUT", "PATCH", "DELETE"].includes(method); + + if (isStateChanging) { + const token = await ensureCsrfToken(); + headers["X-CSRF-Token"] = token; + } + const response = await fetch(url, { ...fetchOptions, headers, @@ -62,6 +131,19 @@ export async function apiRequest(endpoint: string, options: ApiRequestOptions }) ); + // Handle CSRF token mismatch - refresh token and retry once + if ( + response.status === 403 && + (error.code === "CSRF_ERROR" || error.message.includes("CSRF")) && + !_isRetry + ) { + // Refresh CSRF token + await fetchCsrfToken(); + + // Retry the request with new token + return apiRequest(endpoint, { ...options, _isRetry: true }); + } + throw new Error(error.message); }