Files
stack/apps/api/src/common/interceptors/rls-context.interceptor.spec.ts
Jason Woltje 8424a28faa
All checks were successful
ci/woodpecker/push/api Pipeline was successful
fix(auth): use set_config for transaction-scoped RLS context
2026-02-18 23:23:15 -06:00

307 lines
9.4 KiB
TypeScript

import { describe, it, expect, beforeEach, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
import { ExecutionContext, CallHandler, InternalServerErrorException } from "@nestjs/common";
import { of, throwError } from "rxjs";
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
import { PrismaService } from "../../prisma/prisma.service";
import { getRlsClient } from "../../prisma/rls-context.provider";
import type { AuthenticatedRequest } from "../types/user.types";
/**
 * Unit tests for RlsContextInterceptor.
 *
 * The interceptor runs against a mocked PrismaService whose `$transaction`
 * immediately invokes its callback with a stub transaction client. This lets
 * each test assert on the `set_config(...)` statements issued through
 * `$executeRaw` and on the client exposed via `getRlsClient()`.
 */
describe("RlsContextInterceptor", () => {
  let interceptor: RlsContextInterceptor;
  let prismaService: PrismaService;
  let mockCallHandler: CallHandler;
  let mockTransactionClient: TransactionClient;

  beforeEach(async () => {
    // Stub transaction client (a real TransactionClient excludes $connect,
    // $disconnect, etc.); only `$executeRaw` is stubbed here.
    mockTransactionClient = {
      $executeRaw: vi.fn().mockResolvedValue(undefined),
    } as unknown as TransactionClient;

    // Mock Prisma service: `$transaction` runs the callback directly with the
    // stub client. The options argument is accepted only to mirror the real
    // signature; tests inspect it through the mock's recorded calls.
    const mockPrismaService = {
      $transaction: vi.fn(
        async (
          callback: (tx: TransactionClient) => Promise<unknown>,
          _options?: { timeout?: number; maxWait?: number }
        ) => callback(mockTransactionClient)
      ),
    };

    const module: TestingModule = await Test.createTestingModule({
      providers: [RlsContextInterceptor, { provide: PrismaService, useValue: mockPrismaService }],
    }).compile();

    interceptor = module.get<RlsContextInterceptor>(RlsContextInterceptor);
    prismaService = module.get<PrismaService>(PrismaService);

    // Default downstream handler: emits a single fixed response.
    mockCallHandler = {
      handle: vi.fn(() => of({ data: "test response" })),
    };
  });

  /** Builds a minimal ExecutionContext whose HTTP request is `request`. */
  const createMockExecutionContext = (request: Partial<AuthenticatedRequest>): ExecutionContext => {
    return {
      switchToHttp: () => ({
        getRequest: () => request,
      }),
    } as ExecutionContext;
  };

  /** Builds a request carrying an authenticated user with the given id. */
  const createUserRequest = (userId: string): Partial<AuthenticatedRequest> => ({
    user: {
      id: userId,
      email: "test@example.com",
      name: "Test User",
    },
  });

  /**
   * Runs the interceptor for `request` with the current `mockCallHandler` and
   * resolves with the first emitted value.
   *
   * Errors emitted by the interceptor reject the returned promise. A failure
   * therefore shows up as a test failure rather than a promise that never
   * settles and hangs until the test timeout.
   */
  const runIntercept = (request: Partial<AuthenticatedRequest>): Promise<unknown> =>
    new Promise((resolve, reject) => {
      interceptor.intercept(createMockExecutionContext(request), mockCallHandler).subscribe({
        next: resolve,
        error: reject,
      });
    });

  describe("intercept", () => {
    it("should set user context when user is authenticated", async () => {
      const userId = "user-123";

      const result = await runIntercept(createUserRequest(userId));

      expect(result).toEqual({ data: "test response" });
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
    });

    it("should set workspace context when workspace is present", async () => {
      const userId = "user-123";
      const workspaceId = "workspace-456";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
          workspaceId,
        },
        workspace: {
          id: workspaceId,
        },
      };

      await runIntercept(request);

      // User context is set first...
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        1,
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
      // ...followed by the workspace context.
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        2,
        expect.arrayContaining(["SELECT set_config('app.current_workspace_id', ", ", true)"]),
        workspaceId
      );
    });

    it("should not set context when user is not authenticated", async () => {
      await runIntercept({ user: undefined });

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should propagate RLS client via AsyncLocalStorage", async () => {
      // Capture whatever getRlsClient() exposes while the downstream handler runs.
      let capturedClient: ReturnType<typeof getRlsClient> | undefined;
      mockCallHandler = {
        handle: vi.fn(() => {
          capturedClient = getRlsClient();
          return of({ data: "test response" });
        }),
      };

      await runIntercept(createUserRequest("user-123"));

      expect(capturedClient).toBe(mockTransactionClient);
    });

    it("should handle errors and still propagate them", async () => {
      const error = new Error("Test error");
      mockCallHandler = {
        handle: vi.fn(() => throwError(() => error)),
      };

      await expect(runIntercept(createUserRequest("user-123"))).rejects.toThrow(error);

      // Context should still have been set before the error.
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalled();
    });

    it("should clear RLS context after request completes", async () => {
      await runIntercept(createUserRequest("user-123"));

      // Once the observable has emitted, no RLS client should be visible here.
      expect(getRlsClient()).toBeUndefined();
    });

    it("should handle missing user.id gracefully", async () => {
      await runIntercept(createUserRequest(""));

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should configure transaction with timeout and maxWait", async () => {
      await runIntercept(createUserRequest("user-123"));

      // Verify the transaction was opened with the expected timeout options.
      expect(prismaService.$transaction).toHaveBeenCalledWith(
        expect.any(Function),
        expect.objectContaining({
          timeout: 30000, // 30 seconds
          maxWait: 10000, // 10 seconds
        })
      );
    });

    it("should sanitize database errors before sending to client", async () => {
      // Simulate a database failure whose message leaks internal host details.
      const databaseError = new Error(
        "PrismaClientKnownRequestError: Connection failed to database.internal.example.com:5432"
      );
      vi.spyOn(prismaService, "$transaction").mockRejectedValue(databaseError);

      const errorPromise = runIntercept(createUserRequest("user-123"));

      await expect(errorPromise).rejects.toThrow(InternalServerErrorException);
      await expect(errorPromise).rejects.toThrow("Request processing failed");
      // The detailed database error must NOT reach the client.
      await expect(errorPromise).rejects.not.toThrow("database.internal.example.com");
    });
  });
});