Merge remote-tracking branch 'origin/fix/rls-dto-errors' into develop

This commit is contained in:
Jason Woltje
2026-01-29 20:30:20 -06:00
9 changed files with 63 additions and 50 deletions

View File

@@ -1,4 +1,12 @@
import { Controller, Get, Query, Param, UseGuards, Request } from "@nestjs/common"; import {
Controller,
Get,
Query,
Param,
UseGuards,
Request,
UnauthorizedException
} from "@nestjs/common";
import { ActivityService } from "./activity.service"; import { ActivityService } from "./activity.service";
import { EntityType } from "@prisma/client"; import { EntityType } from "@prisma/client";
import type { QueryActivityLogDto } from "./dto"; import type { QueryActivityLogDto } from "./dto";
@@ -34,7 +42,7 @@ export class ActivityController {
async findOne(@Param("id") id: string, @Request() req: any) { async findOne(@Param("id") id: string, @Request() req: any) {
const workspaceId = req.user?.workspaceId; const workspaceId = req.user?.workspaceId;
if (!workspaceId) { if (!workspaceId) {
throw new Error("User workspaceId not found"); throw new UnauthorizedException("User workspaceId not found");
} }
return this.activityService.findOne(id, workspaceId); return this.activityService.findOne(id, workspaceId);
} }
@@ -52,7 +60,7 @@ export class ActivityController {
) { ) {
const workspaceId = req.user?.workspaceId; const workspaceId = req.user?.workspaceId;
if (!workspaceId) { if (!workspaceId) {
throw new Error("User workspaceId not found"); throw new UnauthorizedException("User workspaceId not found");
} }
return this.activityService.getAuditTrail(workspaceId, entityType, entityId); return this.activityService.getAuditTrail(workspaceId, entityType, entityId);
} }

View File

@@ -14,8 +14,9 @@ import { Type } from "class-transformer";
* DTO for querying activity logs with filters and pagination * DTO for querying activity logs with filters and pagination
*/ */
export class QueryActivityLogDto { export class QueryActivityLogDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsUUID("4", { message: "userId must be a valid UUID" }) @IsUUID("4", { message: "userId must be a valid UUID" })

View File

@@ -7,13 +7,11 @@ import {
Logger, Logger,
} from "@nestjs/common"; } from "@nestjs/common";
import { PrismaService } from "../../prisma/prisma.service"; import { PrismaService } from "../../prisma/prisma.service";
import { setCurrentUser } from "../../lib/db-context";
/** /**
* WorkspaceGuard ensures that: * WorkspaceGuard ensures that:
* 1. A workspace is specified in the request (header, param, or body) * 1. A workspace is specified in the request (header, param, or body)
* 2. The authenticated user is a member of that workspace * 2. The authenticated user is a member of that workspace
* 3. The user context is set for Row-Level Security (RLS)
* *
* This guard should be used in combination with AuthGuard: * This guard should be used in combination with AuthGuard:
* *
@@ -25,7 +23,7 @@ import { setCurrentUser } from "../../lib/db-context";
* @Get() * @Get()
* async getTasks(@Workspace() workspaceId: string) { * async getTasks(@Workspace() workspaceId: string) {
* // workspaceId is verified and available * // workspaceId is verified and available
* // RLS context is automatically set * // Service layer must use withUserContext() for RLS
* } * }
* } * }
* ``` * ```
@@ -36,6 +34,9 @@ import { setCurrentUser } from "../../lib/db-context";
* - Request body: `workspaceId` field * - Request body: `workspaceId` field
* *
* Priority: Header > Param > Body * Priority: Header > Param > Body
*
* Note: RLS context must be set at the service layer using withUserContext()
* or withUserTransaction() to ensure proper transaction scoping with connection pooling.
*/ */
@Injectable() @Injectable()
export class WorkspaceGuard implements CanActivate { export class WorkspaceGuard implements CanActivate {
@@ -75,9 +76,6 @@ export class WorkspaceGuard implements CanActivate {
); );
} }
// Set RLS context for this request
await setCurrentUser(user.id, this.prisma);
// Attach workspace info to request for convenience // Attach workspace info to request for convenience
request.workspace = { request.workspace = {
id: workspaceId, id: workspaceId,

View File

@@ -12,8 +12,9 @@ import { Type } from "class-transformer";
* DTO for querying domains with filters and pagination * DTO for querying domains with filters and pagination
*/ */
export class QueryDomainsDto { export class QueryDomainsDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsString({ message: "search must be a string" }) @IsString({ message: "search must be a string" })

View File

@@ -13,8 +13,9 @@ import { Type } from "class-transformer";
* DTO for querying events with filters and pagination * DTO for querying events with filters and pagination
*/ */
export class QueryEventsDto { export class QueryEventsDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsUUID("4", { message: "projectId must be a valid UUID" }) @IsUUID("4", { message: "projectId must be a valid UUID" })

View File

@@ -14,8 +14,9 @@ import { Type } from "class-transformer";
* DTO for querying ideas with filters and pagination * DTO for querying ideas with filters and pagination
*/ */
export class QueryIdeasDto { export class QueryIdeasDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsEnum(IdeaStatus, { message: "status must be a valid IdeaStatus" }) @IsEnum(IdeaStatus, { message: "status must be a valid IdeaStatus" })

View File

@@ -22,51 +22,58 @@ function getPrismaInstance(): PrismaClient {
} }
/** /**
* Sets the current user ID for RLS policies. * Sets the current user ID for RLS policies within a transaction context.
* Must be called before executing any queries that rely on RLS. * Must be called before executing any queries that rely on RLS.
* *
* Note: SET LOCAL must be used within a transaction to ensure it's scoped
* correctly with connection pooling. This is a low-level function - prefer
* using withUserContext or withUserTransaction for most use cases.
*
* @param userId - The UUID of the current user * @param userId - The UUID of the current user
* @param client - Optional Prisma client (defaults to global prisma) * @param client - Prisma client (required - must be a transaction client)
* *
* @example * @example
* ```typescript * ```typescript
* await setCurrentUser(userId); * await prisma.$transaction(async (tx) => {
* const tasks = await prisma.task.findMany(); // Automatically filtered by RLS * await setCurrentUser(userId, tx);
* const tasks = await tx.task.findMany(); // Automatically filtered by RLS
* });
* ``` * ```
*/ */
export async function setCurrentUser( export async function setCurrentUser(
userId: string, userId: string,
client?: PrismaClient client: PrismaClient
): Promise<void> { ): Promise<void> {
const prismaClient = client || getPrismaInstance(); await client.$executeRaw`SET LOCAL app.current_user_id = ${userId}`;
await prismaClient.$executeRaw`SET LOCAL app.current_user_id = ${userId}`;
} }
/** /**
* Clears the current user context. * Clears the current user context within a transaction.
* Use this to reset the session or when switching users. * Use this to reset the session or when switching users.
* *
* @param client - Optional Prisma client (defaults to global prisma) * Note: SET LOCAL is automatically cleared at transaction end,
* so explicit clearing is typically unnecessary.
*
* @param client - Prisma client (required - must be a transaction client)
*/ */
export async function clearCurrentUser( export async function clearCurrentUser(
client?: PrismaClient client: PrismaClient
): Promise<void> { ): Promise<void> {
const prismaClient = client || getPrismaInstance(); await client.$executeRaw`SET LOCAL app.current_user_id = NULL`;
await prismaClient.$executeRaw`SET LOCAL app.current_user_id = NULL`;
} }
/** /**
* Executes a function with the current user context set. * Executes a function with the current user context set within a transaction.
* Automatically sets and clears the user context. * Automatically sets the user context and ensures it's properly scoped.
* *
* @param userId - The UUID of the current user * @param userId - The UUID of the current user
* @param fn - The function to execute with user context * @param fn - The function to execute with user context (receives transaction client)
* @returns The result of the function * @returns The result of the function
* *
* @example * @example
* ```typescript * ```typescript
* const tasks = await withUserContext(userId, async () => { * const tasks = await withUserContext(userId, async (tx) => {
* return prisma.task.findMany({ * return tx.task.findMany({
* where: { workspaceId } * where: { workspaceId }
* }); * });
* }); * });
@@ -74,16 +81,13 @@ export async function clearCurrentUser(
*/ */
export async function withUserContext<T>( export async function withUserContext<T>(
userId: string, userId: string,
fn: () => Promise<T> fn: (tx: any) => Promise<T>
): Promise<T> { ): Promise<T> {
await setCurrentUser(userId); const prismaClient = getPrismaInstance();
try { return prismaClient.$transaction(async (tx) => {
return await fn(); await setCurrentUser(userId, tx as PrismaClient);
} finally { return fn(tx);
// Note: LOCAL settings are automatically cleared at transaction end });
// but we explicitly clear here for consistency
await clearCurrentUser();
}
} }
/** /**
@@ -168,9 +172,8 @@ export async function verifyWorkspaceAccess(
userId: string, userId: string,
workspaceId: string workspaceId: string
): Promise<boolean> { ): Promise<boolean> {
const prismaClient = getPrismaInstance(); return withUserContext(userId, async (tx) => {
return withUserContext(userId, async () => { const member = await tx.workspaceMember.findUnique({
const member = await prismaClient.workspaceMember.findUnique({
where: { where: {
workspaceId_userId: { workspaceId_userId: {
workspaceId, workspaceId,
@@ -195,9 +198,8 @@ export async function verifyWorkspaceAccess(
* ``` * ```
*/ */
export async function getUserWorkspaces(userId: string) { export async function getUserWorkspaces(userId: string) {
const prismaClient = getPrismaInstance(); return withUserContext(userId, async (tx) => {
return withUserContext(userId, async () => { return tx.workspace.findMany({
return prismaClient.workspace.findMany({
include: { include: {
members: { members: {
where: { userId }, where: { userId },
@@ -219,9 +221,8 @@ export async function isWorkspaceAdmin(
userId: string, userId: string,
workspaceId: string workspaceId: string
): Promise<boolean> { ): Promise<boolean> {
const prismaClient = getPrismaInstance(); return withUserContext(userId, async (tx) => {
return withUserContext(userId, async () => { const member = await tx.workspaceMember.findUnique({
const member = await prismaClient.workspaceMember.findUnique({
where: { where: {
workspaceId_userId: { workspaceId_userId: {
workspaceId, workspaceId,

View File

@@ -14,8 +14,9 @@ import { Type } from "class-transformer";
* DTO for querying projects with filters and pagination * DTO for querying projects with filters and pagination
*/ */
export class QueryProjectsDto { export class QueryProjectsDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsEnum(ProjectStatus, { message: "status must be a valid ProjectStatus" }) @IsEnum(ProjectStatus, { message: "status must be a valid ProjectStatus" })

View File

@@ -14,8 +14,9 @@ import { Type } from "class-transformer";
* DTO for querying tasks with filters and pagination * DTO for querying tasks with filters and pagination
*/ */
export class QueryTasksDto { export class QueryTasksDto {
@IsOptional()
@IsUUID("4", { message: "workspaceId must be a valid UUID" }) @IsUUID("4", { message: "workspaceId must be a valid UUID" })
workspaceId!: string; workspaceId?: string;
@IsOptional() @IsOptional()
@IsEnum(TaskStatus, { message: "status must be a valid TaskStatus" }) @IsEnum(TaskStatus, { message: "status must be a valid TaskStatus" })