/**
 * RLS Context Integration Tests
 *
 * Tests that the RlsContextInterceptor correctly sets RLS context
 * and that services can access the RLS-scoped client.
 */

import { describe, it, expect, beforeEach, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
import {
  Injectable,
  Controller,
  Get,
  UseGuards,
  UseInterceptors,
  type CallHandler,
  type ExecutionContext,
} from "@nestjs/common";
import { from, of } from "rxjs";
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
import { PrismaService } from "../../prisma/prisma.service";
import { getRlsClient } from "../../prisma/rls-context.provider";

/**
 * Mock service that uses the getRlsClient() pattern: prefer the RLS-scoped
 * client when getRlsClient() returns one, otherwise fall back to the injected
 * PrismaService.
 */
@Injectable()
class TestService {
  private rlsClientUsed = false;
  private queriesExecuted: string[] = [];

  constructor(private readonly prisma: PrismaService) {}

  /**
   * Resolves a client the same way a production service would and reports
   * which one was picked.
   *
   * @returns Whether the RLS-scoped client was used, plus a snapshot of the
   *   queries recorded so far.
   */
  async findWithRls(): Promise<{ usedRlsClient: boolean; queries: string[] }> {
    const client = getRlsClient() ?? this.prisma;
    this.rlsClientUsed = client !== this.prisma;

    // Track that we're using the client
    this.queriesExecuted.push("findMany");

    return {
      usedRlsClient: this.rlsClientUsed,
      // Return a copy so later calls don't mutate results already handed out.
      queries: [...this.queriesExecuted],
    };
  }

  /** Clears the recorded state so a scenario can start fresh. */
  reset(): void {
    this.rlsClientUsed = false;
    this.queriesExecuted = [];
  }
}

/**
 * Mock controller that uses the test service
 */
@Controller("test")
class TestController {
  constructor(private readonly testService: TestService) {}

  @Get()
  @UseInterceptors(RlsContextInterceptor)
  async test() {
    return this.testService.findWithRls();
  }
}

/**
 * Subscribes to the interceptor's observable and resolves with its first emitted value.
 *
 * Rejects if the observable errors or completes without emitting, so a broken
 * interceptor fails the test with the real cause instead of hanging until the
 * test timeout.
 */
function runInterceptor(
  interceptor: RlsContextInterceptor,
  context: ExecutionContext,
  next: CallHandler
): Promise<unknown> {
  return new Promise((resolve, reject) => {
    interceptor.intercept(context, next).subscribe({
      next: resolve,
      error: reject,
      // No-op if a value was already emitted (the promise is settled by then).
      complete: () => reject(new Error("Interceptor completed without emitting a value")),
    });
  });
}

describe("RLS Context Integration", () => {
  let testService: TestService;
  let prismaService: PrismaService;
  let mockTransactionClient: TransactionClient;

  beforeEach(async () => {
    // Create mock transaction client (excludes $connect, $disconnect, etc.)
    mockTransactionClient = {
      $executeRaw: vi.fn().mockResolvedValue(undefined),
    } as unknown as TransactionClient;

    // Create mock Prisma service whose $transaction runs the callback
    // against the mock transaction client.
    const mockPrismaService = {
      $transaction: vi.fn(async (callback: (tx: TransactionClient) => Promise<unknown>) => {
        return callback(mockTransactionClient);
      }),
    };

    const module: TestingModule = await Test.createTestingModule({
      controllers: [TestController],
      providers: [
        TestService,
        RlsContextInterceptor,
        {
          provide: PrismaService,
          useValue: mockPrismaService,
        },
      ],
    }).compile();

    testService = module.get(TestService);
    prismaService = module.get(PrismaService);
  });

  describe("Service queries with RLS context", () => {
    it("should provide RLS client to services when user is authenticated", async () => {
      const userId = "user-123";
      const workspaceId = "workspace-456";

      // Create interceptor instance
      const interceptor = new RlsContextInterceptor(prismaService);

      // Mock execution context
      const mockContext = {
        switchToHttp: () => ({
          getRequest: () => ({
            user: {
              id: userId,
              email: "test@example.com",
              name: "Test User",
              workspaceId,
            },
            workspace: {
              id: workspaceId,
            },
          }),
        }),
      } as any;

      // Mock call handler
      const mockNext = {
        handle: vi.fn(() => {
          // Simulates the controller calling the service. Like Nest's real
          // CallHandler, emit the resolved value via from(promise) rather than
          // emitting the Promise object itself.
          return from(testService.findWithRls());
        }),
      } as any;

      const result = await runInterceptor(interceptor, mockContext, mockNext);

      // Verify RLS client was used
      expect(result).toMatchObject({
        usedRlsClient: true,
        queries: ["findMany"],
      });

      // Verify transaction-local set_config calls were made
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_workspace_id', ", ", true)"]),
        workspaceId
      );
    });

    it("should fall back to standard client when no RLS context", async () => {
      // Call service directly without going through interceptor
      testService.reset();
      const result = await testService.findWithRls();

      expect(result).toMatchObject({
        usedRlsClient: false,
        queries: ["findMany"],
      });
    });
  });

  describe("RLS context scoping", () => {
    it("should clear RLS context after request completes", async () => {
      const userId = "user-123";

      const interceptor = new RlsContextInterceptor(prismaService);

      const mockContext = {
        switchToHttp: () => ({
          getRequest: () => ({
            user: {
              id: userId,
              email: "test@example.com",
              name: "Test User",
            },
          }),
        }),
      } as any;

      // Capture the client visible while the handler runs, so the test can
      // prove the context was actually established before checking it is cleared.
      let clientDuringRequest: unknown;
      const mockNext = {
        handle: vi.fn(() => {
          clientDuringRequest = getRlsClient();
          return of({ data: "test" });
        }),
      } as any;

      await runInterceptor(interceptor, mockContext, mockNext);

      // Without this, the assertion below would pass even if the interceptor
      // never set any RLS context at all.
      expect(clientDuringRequest).toBeDefined();

      // After request completes, RLS context should be cleared
      const client = getRlsClient();
      expect(client).toBeUndefined();
    });
  });
});