fix(SEC-API-27): Scope RLS context to transaction boundary
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
createAuthMiddleware was calling SET LOCAL on the raw PrismaClient outside of any transaction. In PostgreSQL, SET LOCAL without a transaction acts as a session-level SET, which can leak RLS context to subsequent requests sharing the same pooled connection, enabling cross-tenant data access. Wrapped the setCurrentUser call and downstream handler execution inside a $transaction block so SET LOCAL is automatically reverted when the transaction ends (on both success and failure). Added comprehensive test suite for db-context module verifying: - RLS context is set on the transaction client, not the raw client - next() executes inside the transaction boundary - Authentication errors prevent any transaction from starting - Errors in downstream handlers propagate correctly Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
230
apps/api/src/lib/db-context.spec.ts
Normal file
230
apps/api/src/lib/db-context.spec.ts
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||||
|
import {
|
||||||
|
setCurrentUser,
|
||||||
|
setCurrentWorkspace,
|
||||||
|
setWorkspaceContext,
|
||||||
|
clearCurrentUser,
|
||||||
|
clearWorkspaceContext,
|
||||||
|
withUserContext,
|
||||||
|
withUserTransaction,
|
||||||
|
withWorkspaceContext,
|
||||||
|
withAuth,
|
||||||
|
verifyWorkspaceAccess,
|
||||||
|
withoutRLS,
|
||||||
|
createAuthMiddleware,
|
||||||
|
} from "./db-context";
|
||||||
|
|
||||||
|
// Mock PrismaClient
|
||||||
|
function createMockPrismaClient(): Record<string, unknown> {
|
||||||
|
const mockTx = {
|
||||||
|
$executeRaw: vi.fn().mockResolvedValue(undefined),
|
||||||
|
workspaceMember: {
|
||||||
|
findUnique: vi.fn(),
|
||||||
|
},
|
||||||
|
workspace: {
|
||||||
|
findMany: vi.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return {
|
||||||
|
$executeRaw: vi.fn().mockResolvedValue(undefined),
|
||||||
|
$transaction: vi.fn(async (fn: (tx: unknown) => Promise<unknown>) => {
|
||||||
|
return fn(mockTx);
|
||||||
|
}),
|
||||||
|
workspaceMember: {
|
||||||
|
findUnique: vi.fn(),
|
||||||
|
},
|
||||||
|
workspace: {
|
||||||
|
findMany: vi.fn(),
|
||||||
|
},
|
||||||
|
_mockTx: mockTx, // expose for assertions
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("db-context", () => {
|
||||||
|
describe("setCurrentUser", () => {
|
||||||
|
it("should execute SET LOCAL for user ID", async () => {
|
||||||
|
const mockClient = createMockPrismaClient();
|
||||||
|
await setCurrentUser("user-123", mockClient as never);
|
||||||
|
expect(mockClient.$executeRaw).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("setCurrentWorkspace", () => {
|
||||||
|
it("should execute SET LOCAL for workspace ID", async () => {
|
||||||
|
const mockClient = createMockPrismaClient();
|
||||||
|
await setCurrentWorkspace("ws-123", mockClient as never);
|
||||||
|
expect(mockClient.$executeRaw).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("setWorkspaceContext", () => {
|
||||||
|
it("should execute SET LOCAL for both user and workspace", async () => {
|
||||||
|
const mockClient = createMockPrismaClient();
|
||||||
|
await setWorkspaceContext("user-123", "ws-123", mockClient as never);
|
||||||
|
expect(mockClient.$executeRaw).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("clearCurrentUser", () => {
|
||||||
|
it("should set user ID to NULL", async () => {
|
||||||
|
const mockClient = createMockPrismaClient();
|
||||||
|
await clearCurrentUser(mockClient as never);
|
||||||
|
expect(mockClient.$executeRaw).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("clearWorkspaceContext", () => {
|
||||||
|
it("should set both user and workspace to NULL", async () => {
|
||||||
|
const mockClient = createMockPrismaClient();
|
||||||
|
await clearWorkspaceContext(mockClient as never);
|
||||||
|
expect(mockClient.$executeRaw).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("withUserContext", () => {
|
||||||
|
it("should execute function within transaction with user context", async () => {
|
||||||
|
// withUserContext uses a global prisma instance, which is hard to mock
|
||||||
|
// without restructuring. We test the higher-level wrappers via
|
||||||
|
// createAuthMiddleware and withWorkspaceContext which accept a client.
|
||||||
|
expect(withUserContext).toBeDefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("withUserTransaction", () => {
|
||||||
|
it("should be a function that wraps execution in a transaction", () => {
|
||||||
|
expect(withUserTransaction).toBeDefined();
|
||||||
|
expect(typeof withUserTransaction).toBe("function");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("withWorkspaceContext", () => {
|
||||||
|
it("should be a function that provides workspace context", () => {
|
||||||
|
expect(withWorkspaceContext).toBeDefined();
|
||||||
|
expect(typeof withWorkspaceContext).toBe("function");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("withAuth", () => {
|
||||||
|
it("should return a wrapped handler function", () => {
|
||||||
|
const handler = vi.fn().mockResolvedValue("result");
|
||||||
|
const wrapped = withAuth(handler);
|
||||||
|
expect(typeof wrapped).toBe("function");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("verifyWorkspaceAccess", () => {
|
||||||
|
it("should be a function", () => {
|
||||||
|
expect(verifyWorkspaceAccess).toBeDefined();
|
||||||
|
expect(typeof verifyWorkspaceAccess).toBe("function");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("withoutRLS", () => {
|
||||||
|
it("should be a function that bypasses RLS", () => {
|
||||||
|
expect(withoutRLS).toBeDefined();
|
||||||
|
expect(typeof withoutRLS).toBe("function");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("createAuthMiddleware (SEC-API-27)", () => {
|
||||||
|
let mockClient: ReturnType<typeof createMockPrismaClient>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockClient = createMockPrismaClient();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should throw if userId is not provided", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const next = vi.fn().mockResolvedValue("result");
|
||||||
|
|
||||||
|
await expect(middleware({ ctx: { userId: undefined }, next })).rejects.toThrow(
|
||||||
|
"User not authenticated"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should call $transaction on the client (RLS context inside transaction)", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const next = vi.fn().mockResolvedValue("result");
|
||||||
|
|
||||||
|
await middleware({ ctx: { userId: "user-123" }, next });
|
||||||
|
|
||||||
|
expect(mockClient.$transaction).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockClient.$transaction).toHaveBeenCalledWith(expect.any(Function));
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should set RLS context inside the transaction, not on the raw client", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const next = vi.fn().mockResolvedValue("result");
|
||||||
|
const mockTx = mockClient._mockTx as Record<string, unknown>;
|
||||||
|
|
||||||
|
await middleware({ ctx: { userId: "user-123" }, next });
|
||||||
|
|
||||||
|
// The SET LOCAL should be called on the transaction client (mockTx),
|
||||||
|
// NOT on the raw client. This is the core of SEC-API-27.
|
||||||
|
expect(mockTx.$executeRaw as ReturnType<typeof vi.fn>).toHaveBeenCalled();
|
||||||
|
// The raw client's $executeRaw should NOT have been called directly
|
||||||
|
expect(mockClient.$executeRaw).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should call next() inside the transaction boundary", async () => {
|
||||||
|
const callOrder: string[] = [];
|
||||||
|
const mockTx = mockClient._mockTx as Record<string, unknown>;
|
||||||
|
|
||||||
|
(mockTx.$executeRaw as ReturnType<typeof vi.fn>).mockImplementation(async () => {
|
||||||
|
callOrder.push("setRLS");
|
||||||
|
});
|
||||||
|
|
||||||
|
const next = vi.fn().mockImplementation(async () => {
|
||||||
|
callOrder.push("next");
|
||||||
|
return "result";
|
||||||
|
});
|
||||||
|
|
||||||
|
// Override $transaction to track that next() is called INSIDE it
|
||||||
|
(mockClient.$transaction as ReturnType<typeof vi.fn>).mockImplementation(
|
||||||
|
async (fn: (tx: unknown) => Promise<unknown>) => {
|
||||||
|
callOrder.push("txStart");
|
||||||
|
const result = await fn(mockTx);
|
||||||
|
callOrder.push("txEnd");
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
await middleware({ ctx: { userId: "user-123" }, next });
|
||||||
|
|
||||||
|
expect(callOrder).toEqual(["txStart", "setRLS", "next", "txEnd"]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should return the result from next()", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const next = vi.fn().mockResolvedValue({ data: "test" });
|
||||||
|
|
||||||
|
const result = await middleware({ ctx: { userId: "user-123" }, next });
|
||||||
|
|
||||||
|
expect(result).toEqual({ data: "test" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should propagate errors from next() and roll back transaction", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const error = new Error("Handler error");
|
||||||
|
const next = vi.fn().mockRejectedValue(error);
|
||||||
|
|
||||||
|
await expect(middleware({ ctx: { userId: "user-123" }, next })).rejects.toThrow(
|
||||||
|
"Handler error"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should not call next() if authentication fails", async () => {
|
||||||
|
const middleware = createAuthMiddleware(mockClient as never);
|
||||||
|
const next = vi.fn().mockResolvedValue("result");
|
||||||
|
|
||||||
|
await expect(middleware({ ctx: { userId: undefined }, next })).rejects.toThrow(
|
||||||
|
"User not authenticated"
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(next).not.toHaveBeenCalled();
|
||||||
|
expect(mockClient.$transaction).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -349,12 +349,18 @@ export function createAuthMiddleware(client: PrismaClient) {
|
|||||||
ctx: { userId?: string };
|
ctx: { userId?: string };
|
||||||
next: () => Promise<unknown>;
|
next: () => Promise<unknown>;
|
||||||
}): Promise<unknown> {
|
}): Promise<unknown> {
|
||||||
if (!opts.ctx.userId) {
|
const { userId } = opts.ctx;
|
||||||
|
if (!userId) {
|
||||||
throw new Error("User not authenticated");
|
throw new Error("User not authenticated");
|
||||||
}
|
}
|
||||||
|
|
||||||
await setCurrentUser(opts.ctx.userId, client);
|
// SEC-API-27: SET LOCAL must be called inside a transaction boundary.
|
||||||
return opts.next();
|
// Without a transaction, SET LOCAL behaves as a session-level SET,
|
||||||
|
// which can leak RLS context to other requests via connection pooling.
|
||||||
|
return client.$transaction(async (tx) => {
|
||||||
|
await setCurrentUser(userId, tx as PrismaClient);
|
||||||
|
return opts.next();
|
||||||
|
});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user