199 lines
5.5 KiB
TypeScript
199 lines
5.5 KiB
TypeScript
/**
 * RLS Context Integration Tests
 *
 * Tests that the RlsContextInterceptor correctly sets the RLS context
 * and that services can access the RLS-scoped client.
 */
|
|
|
|
import { describe, it, expect, beforeEach, vi } from "vitest";
|
|
import { Test, TestingModule } from "@nestjs/testing";
|
|
import { Injectable, Controller, Get, UseGuards, UseInterceptors } from "@nestjs/common";
|
|
import { of } from "rxjs";
|
|
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
|
|
import { PrismaService } from "../../prisma/prisma.service";
|
|
import { getRlsClient } from "../../prisma/rls-context.provider";
|
|
|
|
/**
 * Mock service that uses the getRlsClient() pattern.
 *
 * Each call to findWithRls() records whether an RLS-scoped client was
 * available (getRlsClient() returned a value, so the injected PrismaService
 * fallback was not used) and appends a "findMany" entry to a query log.
 */
@Injectable()
class TestService {
  private rlsClientUsed = false;
  private queriesExecuted: string[] = [];

  constructor(private readonly prisma: PrismaService) {}

  /**
   * Resolves a client the way production services do: the RLS-scoped client
   * when one is set for the current context, otherwise the injected PrismaService.
   *
   * @returns Whether the RLS client was used, plus a snapshot of the query log.
   *   The log is copied so that later calls cannot mutate a previously returned result.
   */
  async findWithRls(): Promise<{ usedRlsClient: boolean; queries: string[] }> {
    const client = getRlsClient() ?? this.prisma;
    this.rlsClientUsed = client !== this.prisma;

    // No real query is issued; the log only records that a client was resolved.
    this.queriesExecuted.push("findMany");

    return {
      usedRlsClient: this.rlsClientUsed,
      queries: [...this.queriesExecuted],
    };
  }

  /** Clears the recorded client usage and the query log. */
  reset(): void {
    this.rlsClientUsed = false;
    this.queriesExecuted = [];
  }
}
|
|
|
|
/**
 * Minimal controller used to exercise the RLS interceptor.
 *
 * Its single GET route is wrapped by RlsContextInterceptor and delegates
 * directly to TestService.findWithRls().
 */
@Controller("test")
class TestController {
  constructor(private readonly testService: TestService) {}

  /** GET handler: returns the service's report of which client it resolved. */
  @Get()
  @UseInterceptors(RlsContextInterceptor)
  async test() {
    const report = await this.testService.findWithRls();
    return report;
  }
}
|
|
|
|
describe("RLS Context Integration", () => {
  let testService: TestService;
  let prismaService: PrismaService;
  let mockTransactionClient: TransactionClient;

  // Derive the interceptor's parameter types instead of casting mocks to `any`.
  type InterceptContext = Parameters<RlsContextInterceptor["intercept"]>[0];
  type InterceptNext = Parameters<RlsContextInterceptor["intercept"]>[1];

  /**
   * Builds a partial ExecutionContext: only switchToHttp().getRequest() is
   * provided, and it returns the given fake request object.
   */
  const createHttpContext = (request: Record<string, unknown>): InterceptContext =>
    ({
      switchToHttp: () => ({
        getRequest: () => request,
      }),
    }) as unknown as InterceptContext;

  /**
   * Runs the interceptor and resolves with the first value it emits.
   *
   * Errors, and completion without any emission, reject the promise. A broken
   * interceptor therefore fails the test immediately instead of leaving it
   * pending until the test timeout.
   */
  const runInterceptor = (
    interceptor: RlsContextInterceptor,
    context: InterceptContext,
    next: InterceptNext
  ): Promise<unknown> =>
    new Promise((resolve, reject) => {
      interceptor.intercept(context, next).subscribe({
        next: resolve,
        error: reject,
        // No-op once `next` has resolved; only matters when nothing was emitted.
        complete: () => reject(new Error("Interceptor completed without emitting a value")),
      });
    });

  beforeEach(async () => {
    // Partial mock of the transaction client (the TransactionClient type excludes
    // $connect, $disconnect, etc.); only $executeRaw is stubbed.
    mockTransactionClient = {
      $executeRaw: vi.fn().mockResolvedValue(undefined),
    } as unknown as TransactionClient;

    // Mock PrismaService whose $transaction simply invokes the callback with the mock client
    const mockPrismaService = {
      $transaction: vi.fn(async (callback: (tx: TransactionClient) => Promise<unknown>) =>
        callback(mockTransactionClient)
      ),
    };

    const moduleRef: TestingModule = await Test.createTestingModule({
      controllers: [TestController],
      providers: [
        TestService,
        RlsContextInterceptor,
        {
          provide: PrismaService,
          useValue: mockPrismaService,
        },
      ],
    }).compile();

    testService = moduleRef.get<TestService>(TestService);
    prismaService = moduleRef.get<PrismaService>(PrismaService);
  });

  describe("Service queries with RLS context", () => {
    it("should provide RLS client to services when user is authenticated", async () => {
      const userId = "user-123";
      const workspaceId = "workspace-456";

      const interceptor = new RlsContextInterceptor(prismaService);

      const mockContext = createHttpContext({
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
          workspaceId,
        },
        workspace: {
          id: workspaceId,
        },
      });

      // Simulates the controller calling the service from inside handle(), i.e.
      // while the interceptor is running the request. handle() must return an
      // Observable; `of()` emits the service's Promise, and the promise returned
      // by runInterceptor adopts it.
      const mockNext = {
        handle: vi.fn(() => of(testService.findWithRls())),
      };

      const result = await runInterceptor(interceptor, mockContext, mockNext);

      // Verify RLS client was used
      expect(result).toMatchObject({
        usedRlsClient: true,
        queries: ["findMany"],
      });

      // Verify transaction-local set_config calls were made
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_workspace_id', ", ", true)"]),
        workspaceId
      );
    });

    it("should fall back to standard client when no RLS context", async () => {
      // Call service directly without going through interceptor
      testService.reset();
      const result = await testService.findWithRls();

      expect(result).toMatchObject({
        usedRlsClient: false,
        queries: ["findMany"],
      });
    });
  });

  describe("RLS context scoping", () => {
    it("should clear RLS context after request completes", async () => {
      const userId = "user-123";

      const interceptor = new RlsContextInterceptor(prismaService);

      const mockContext = createHttpContext({
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      });

      // Capture what a service would see while the request is in flight, so the
      // "cleared afterwards" assertion below is not trivially true.
      // NOTE(review): assumes the interceptor sets RLS context for an
      // authenticated user even without a workspace — confirm.
      let clientDuringRequest: unknown;
      const mockNext = {
        handle: vi.fn(() => {
          clientDuringRequest = getRlsClient();
          return of({ data: "test" });
        }),
      };

      await runInterceptor(interceptor, mockContext, mockNext);

      expect(clientDuringRequest).toBeDefined();

      // After request completes, RLS context should be cleared
      expect(getRlsClient()).toBeUndefined();
    });
  });
});
|