import { describe, it, expect, beforeEach, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
import { ExecutionContext, CallHandler, InternalServerErrorException } from "@nestjs/common";
import { of, throwError } from "rxjs";
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
import { PrismaService } from "../../prisma/prisma.service";
import { getRlsClient } from "../../prisma/rls-context.provider";
import type { AuthenticatedRequest } from "../types/user.types";

/**
 * Unit tests for RlsContextInterceptor.
 *
 * PrismaService is replaced with a mock whose `$transaction` invokes the supplied
 * callback with a stub transaction client. This lets the tests observe:
 * - which `set_config(...)` statements the interceptor issues;
 * - which client `getRlsClient()` returns while the downstream handler runs;
 * - what options the transaction is opened with.
 */
describe("RlsContextInterceptor", () => {
  let interceptor: RlsContextInterceptor;
  let prismaService: PrismaService;
  let mockExecutionContext: ExecutionContext;
  let mockCallHandler: CallHandler;
  let mockTransactionClient: TransactionClient;

  beforeEach(async () => {
    // Stub transaction client (a real TransactionClient excludes $connect, $disconnect, etc.).
    // Only $executeRaw is stubbed because it is the only member these tests assert on.
    mockTransactionClient = {
      $executeRaw: vi.fn().mockResolvedValue(undefined),
    } as unknown as TransactionClient;

    // Mock PrismaService: $transaction runs the callback against the stub client and
    // resolves with its result. vi.fn records the options argument so tests can assert on it.
    const mockPrismaService = {
      $transaction: vi.fn(
        async <T>(
          callback: (tx: TransactionClient) => Promise<T>,
          _options?: { timeout?: number; maxWait?: number }
        ): Promise<T> => callback(mockTransactionClient)
      ),
    };

    const module: TestingModule = await Test.createTestingModule({
      providers: [
        RlsContextInterceptor,
        {
          provide: PrismaService,
          useValue: mockPrismaService,
        },
      ],
    }).compile();

    interceptor = module.get(RlsContextInterceptor);
    prismaService = module.get(PrismaService);

    // Default downstream handler: emits a single response value.
    mockCallHandler = {
      handle: vi.fn(() => of({ data: "test response" })),
    };
  });

  /** Builds an ExecutionContext whose HTTP request is `request`. */
  const createMockExecutionContext = (request: Partial<AuthenticatedRequest>): ExecutionContext => {
    return {
      switchToHttp: () => ({
        getRequest: () => request,
      }),
    } as ExecutionContext;
  };

  /** Builds a request authenticated as the standard test user with the given id. */
  const createAuthenticatedRequest = (userId: string): Partial<AuthenticatedRequest> => ({
    user: {
      id: userId,
      email: "test@example.com",
      name: "Test User",
    },
  });

  /**
   * Runs the interceptor against the current mockExecutionContext and mockCallHandler.
   *
   * @returns A promise that resolves with the first value the interceptor emits.
   *   Observable errors are forwarded as rejections, so an unexpected failure fails
   *   the calling test immediately instead of hanging until the test timeout.
   */
  const runInterceptor = (): Promise<unknown> =>
    new Promise((resolve, reject) => {
      interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
        next: resolve,
        error: reject,
      });
    });

  describe("intercept", () => {
    it("should set user context when user is authenticated", async () => {
      const userId = "user-123";
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest(userId));

      const result = await runInterceptor();

      expect(result).toEqual({ data: "test response" });
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
    });

    it("should set workspace context when workspace is present", async () => {
      const userId = "user-123";
      const workspaceId = "workspace-456";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
          workspaceId,
        },
        workspace: {
          id: workspaceId,
        },
      };
      mockExecutionContext = createMockExecutionContext(request);

      await runInterceptor();

      // The user context is set first...
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        1,
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
      // ...followed by the workspace context.
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        2,
        expect.arrayContaining(["SELECT set_config('app.current_workspace_id', ", ", true)"]),
        workspaceId
      );
    });

    it("should not set context when user is not authenticated", async () => {
      mockExecutionContext = createMockExecutionContext({ user: undefined });

      await runInterceptor();

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should propagate RLS client via AsyncLocalStorage", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest("user-123"));

      // Capture whatever getRlsClient() returns while the downstream handler is running.
      let capturedClient: ReturnType<typeof getRlsClient> | undefined;
      mockCallHandler = {
        handle: vi.fn(() => {
          capturedClient = getRlsClient();
          return of({ data: "test response" });
        }),
      };

      await runInterceptor();

      expect(capturedClient).toBe(mockTransactionClient);
    });

    it("should handle errors and still propagate them", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest("user-123"));
      const error = new Error("Test error");
      mockCallHandler = {
        handle: vi.fn(() => throwError(() => error)),
      };

      await expect(runInterceptor()).rejects.toThrow(error);

      // Context should still have been set before the handler failed.
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalled();
    });

    it("should clear RLS context after request completes", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest("user-123"));

      await runInterceptor();

      // Once the observable has completed, the RLS client must not be visible to the caller.
      expect(getRlsClient()).toBeUndefined();
    });

    it("should handle missing user.id gracefully", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest(""));

      await runInterceptor();

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should configure transaction with timeout and maxWait", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest("user-123"));

      await runInterceptor();

      // Verify the transaction was opened with explicit timeout options.
      expect(prismaService.$transaction).toHaveBeenCalledWith(
        expect.any(Function),
        expect.objectContaining({
          timeout: 30000, // 30 seconds
          maxWait: 10000, // 10 seconds
        })
      );
    });

    it("should sanitize database errors before sending to client", async () => {
      mockExecutionContext = createMockExecutionContext(createAuthenticatedRequest("user-123"));

      // Make the transaction fail with an error that contains sensitive connection details.
      const databaseError = new Error(
        "PrismaClientKnownRequestError: Connection failed to database.internal.example.com:5432"
      );
      vi.spyOn(prismaService, "$transaction").mockRejectedValue(databaseError);

      const errorPromise = runInterceptor();

      await expect(errorPromise).rejects.toThrow(InternalServerErrorException);
      await expect(errorPromise).rejects.toThrow("Request processing failed");
      // The internal hostname from the original error must not reach the client.
      await expect(errorPromise).rejects.not.toThrow("database.internal.example.com");
    });
  });
});