// Source viewer metadata: 307 lines, 9.4 KiB, TypeScript
import { describe, it, expect, beforeEach, vi } from "vitest";
|
|
import { Test, TestingModule } from "@nestjs/testing";
|
|
import { ExecutionContext, CallHandler, InternalServerErrorException } from "@nestjs/common";
|
|
import { of, throwError } from "rxjs";
|
|
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
|
|
import { PrismaService } from "../../prisma/prisma.service";
|
|
import { getRlsClient } from "../../prisma/rls-context.provider";
|
|
import type { AuthenticatedRequest } from "../types/user.types";
|
|
|
|
/**
 * Unit tests for RlsContextInterceptor.
 *
 * The interceptor runs against a mocked PrismaService whose `$transaction`
 * invokes its callback with a mock transaction client. The assertions inspect:
 * - the raw `set_config(...)` calls issued on that client;
 * - the options passed to `$transaction`;
 * - the client exposed through `getRlsClient()` while the request is handled;
 * - how transaction failures are surfaced to the caller.
 */
describe("RlsContextInterceptor", () => {
  let interceptor: RlsContextInterceptor;
  let prismaService: PrismaService;
  let mockExecutionContext: ExecutionContext;
  let mockCallHandler: CallHandler;
  let mockTransactionClient: TransactionClient;

  beforeEach(async () => {
    // Create mock transaction client (excludes $connect, $disconnect, etc.)
    mockTransactionClient = {
      $executeRaw: vi.fn().mockResolvedValue(undefined),
    } as unknown as TransactionClient;

    // Mock PrismaService: $transaction simply runs the callback against the mock
    // transaction client. The options argument is declared so tests can assert on
    // what the interceptor passes, but the mock itself ignores it.
    const mockPrismaService = {
      $transaction: vi.fn(
        async (
          callback: (tx: TransactionClient) => Promise<unknown>,
          _options?: { timeout?: number; maxWait?: number }
        ) => {
          return callback(mockTransactionClient);
        }
      ),
    };

    const module: TestingModule = await Test.createTestingModule({
      providers: [
        RlsContextInterceptor,
        {
          provide: PrismaService,
          useValue: mockPrismaService,
        },
      ],
    }).compile();

    interceptor = module.get<RlsContextInterceptor>(RlsContextInterceptor);
    prismaService = module.get<PrismaService>(PrismaService);

    // Default call handler: emits a single response object.
    mockCallHandler = {
      handle: vi.fn(() => of({ data: "test response" })),
    };
  });

  /** Builds a minimal ExecutionContext whose HTTP request is `request`. */
  const createMockExecutionContext = (request: Partial<AuthenticatedRequest>): ExecutionContext => {
    return {
      switchToHttp: () => ({
        getRequest: () => request,
      }),
    } as ExecutionContext;
  };

  /**
   * Subscribes to the interceptor's observable and settles when the stream finishes.
   *
   * Resolves with the last emitted value once the stream completes, and rejects with
   * the stream's error. If the interceptor fails unexpectedly, the test then reports
   * the real error instead of waiting on a promise that never settles.
   *
   * @param context - execution context passed to `intercept`
   * @param handler - call handler passed to `intercept`
   * @returns the last value emitted before completion (undefined if none)
   */
  const runInterceptor = (context: ExecutionContext, handler: CallHandler): Promise<unknown> => {
    return new Promise((resolve, reject) => {
      let lastValue: unknown;
      interceptor.intercept(context, handler).subscribe({
        next: (value) => {
          lastValue = value;
        },
        error: reject,
        complete: () => resolve(lastValue),
      });
    });
  };

  describe("intercept", () => {
    it("should set user context when user is authenticated", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      const result = await runInterceptor(mockExecutionContext, mockCallHandler);

      expect(result).toEqual({ data: "test response" });
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
    });

    it("should set workspace context when workspace is present", async () => {
      const userId = "user-123";
      const workspaceId = "workspace-456";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
          workspaceId,
        },
        workspace: {
          id: workspaceId,
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      await runInterceptor(mockExecutionContext, mockCallHandler);

      // Check that user context was set
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        1,
        expect.arrayContaining(["SELECT set_config('app.current_user_id', ", ", true)"]),
        userId
      );
      // Check that workspace context was set
      expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
        2,
        expect.arrayContaining(["SELECT set_config('app.current_workspace_id', ", ", true)"]),
        workspaceId
      );
    });

    it("should not set context when user is not authenticated", async () => {
      const request: Partial<AuthenticatedRequest> = {
        user: undefined,
      };

      mockExecutionContext = createMockExecutionContext(request);

      await runInterceptor(mockExecutionContext, mockCallHandler);

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should propagate RLS client via AsyncLocalStorage", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      // Override call handler to check if RLS client is available.
      // The type is derived from getRlsClient so no extra Prisma import is needed.
      let capturedClient: ReturnType<typeof getRlsClient> | undefined;
      mockCallHandler = {
        handle: vi.fn(() => {
          capturedClient = getRlsClient();
          return of({ data: "test response" });
        }),
      };

      await runInterceptor(mockExecutionContext, mockCallHandler);

      expect(capturedClient).toBe(mockTransactionClient);
    });

    it("should handle errors and still propagate them", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      const error = new Error("Test error");
      mockCallHandler = {
        handle: vi.fn(() => throwError(() => error)),
      };

      await expect(runInterceptor(mockExecutionContext, mockCallHandler)).rejects.toThrow(error);

      // Context should still have been set before error
      expect(mockTransactionClient.$executeRaw).toHaveBeenCalled();
    });

    it("should clear RLS context after request completes", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      // Record the client seen while the request is being handled. Without this,
      // the final assertion could pass even if the context were never established.
      let clientDuringRequest: ReturnType<typeof getRlsClient> | undefined;
      mockCallHandler = {
        handle: vi.fn(() => {
          clientDuringRequest = getRlsClient();
          return of({ data: "test response" });
        }),
      };

      await runInterceptor(mockExecutionContext, mockCallHandler);

      expect(clientDuringRequest).toBe(mockTransactionClient);
      // After the observable completes, RLS context should be cleared
      expect(getRlsClient()).toBeUndefined();
    });

    it("should handle missing user.id gracefully", async () => {
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: "",
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      await runInterceptor(mockExecutionContext, mockCallHandler);

      expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
      expect(mockCallHandler.handle).toHaveBeenCalled();
    });

    it("should configure transaction with timeout and maxWait", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      await runInterceptor(mockExecutionContext, mockCallHandler);

      // Verify transaction was called with timeout options
      expect(prismaService.$transaction).toHaveBeenCalledWith(
        expect.any(Function),
        expect.objectContaining({
          timeout: 30000, // 30 seconds
          maxWait: 10000, // 10 seconds
        })
      );
    });

    it("should sanitize database errors before sending to client", async () => {
      const userId = "user-123";
      const request: Partial<AuthenticatedRequest> = {
        user: {
          id: userId,
          email: "test@example.com",
          name: "Test User",
        },
      };

      mockExecutionContext = createMockExecutionContext(request);

      // Mock transaction to throw a database error with sensitive information
      const databaseError = new Error(
        "PrismaClientKnownRequestError: Connection failed to database.internal.example.com:5432"
      );
      vi.spyOn(prismaService, "$transaction").mockRejectedValue(databaseError);

      // A settled promise can be awaited repeatedly, so one run backs all three checks.
      const errorPromise = runInterceptor(mockExecutionContext, mockCallHandler);

      await expect(errorPromise).rejects.toThrow(InternalServerErrorException);
      await expect(errorPromise).rejects.toThrow("Request processing failed");

      // Verify the detailed error was NOT sent to the client
      await expect(errorPromise).rejects.not.toThrow("database.internal.example.com");
    });
  });
});
|