Files
stack/apps/api/src/prisma/prisma.service.ts
Jason Woltje e3cba37e8c
All checks were successful
ci/woodpecker/push/web Pipeline was successful
ci/woodpecker/push/api Pipeline was successful
fix(api,web): resolve RLS context SQL error, workspace guard crash, and projects response unwrapping (#531)
Co-authored-by: Jason Woltje <jason@diversecanvas.com>
Co-committed-by: Jason Woltje <jason@diversecanvas.com>
2026-02-27 04:18:35 +00:00

197 lines
6.8 KiB
TypeScript

import { Injectable, Logger, OnModuleDestroy, OnModuleInit } from "@nestjs/common";
import { PrismaClient } from "@prisma/client";
import { VaultService } from "../vault/vault.service";
import { createAccountEncryptionExtension } from "./account-encryption.extension";
import { createLlmEncryptionExtension } from "./llm-encryption.extension";
import { getRlsClient } from "./rls-context.provider";
/**
 * Prisma service that manages database connection lifecycle
 * Extends PrismaClient to provide connection management and health checks
 *
 * IMPORTANT: VaultService is required (not optional) for encryption/decryption
 * of sensitive Account tokens. It automatically falls back to AES-256-GCM when
 * OpenBao is unavailable.
 *
 * Encryption is handled via Prisma Client Extensions ($extends) on the `account`
 * and `llmProviderInstance` models. The extended client is stored in `_xClient`
 * and only those two model getters are overridden — all other models use the
 * base PrismaClient inheritance, preserving full type safety.
 *
 * NOTE: `_xClient` is only populated in `onModuleInit`. Until then the `account`
 * and `llmProviderInstance` getters fall back to the base (non-extended) delegates.
 */
@Injectable()
export class PrismaService extends PrismaClient implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(PrismaService.name);

  // Extended client with encryption hooks for account + llmProviderInstance.
  // null until onModuleInit() has registered the extensions.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private _xClient: any = null;

  constructor(private readonly vaultService: VaultService) {
    super({
      // Verbose logging (including queries) only in development; errors only otherwise.
      log: process.env.NODE_ENV === "development" ? ["query", "info", "warn", "error"] : ["error"],
    });
  }

  /**
   * Connect to the database and register the encryption extensions when the
   * NestJS module initializes.
   *
   * The two steps are logged separately so a failure in extension registration
   * is not reported as a connection failure. Both kinds of error are logged and
   * rethrown to the caller.
   */
  async onModuleInit(): Promise<void> {
    try {
      await this.$connect();
    } catch (error) {
      this.logger.error("Failed to connect to database", error);
      throw error;
    }
    this.logger.log("Database connection established");

    try {
      // Register encryption extensions (replaces removed $use() middleware)
      this._xClient = this.$extends(createAccountEncryptionExtension(this.vaultService)).$extends(
        createLlmEncryptionExtension(this.vaultService)
      );
    } catch (error) {
      this.logger.error("Failed to register encryption extensions", error);
      throw error;
    }
    this.logger.log("Encryption extensions registered (Account, LlmProviderInstance)");
  }

  // Override only the 2 models that need encryption hooks.
  // All other models (user, task, workspace, etc.) use base PrismaClient via inheritance.
  // Cast _xClient to PrismaClient to preserve the accessor return types for consumers.
  override get account() {
    if (this._xClient) return (this._xClient as PrismaClient).account;
    return super.account;
  }

  override get llmProviderInstance() {
    if (this._xClient) return (this._xClient as PrismaClient).llmProviderInstance;
    return super.llmProviderInstance;
  }

  /**
   * Disconnect from database when NestJS module is destroyed
   */
  async onModuleDestroy(): Promise<void> {
    await this.$disconnect();
    this.logger.log("Database connection closed");
  }

  /**
   * Health check for database connectivity
   * @returns true if a trivial `SELECT 1` succeeds, false otherwise (the error is logged)
   */
  async isHealthy(): Promise<boolean> {
    try {
      await this.$queryRaw`SELECT 1`;
      return true;
    } catch (error) {
      this.logger.error("Database health check failed", error);
      return false;
    }
  }

  /**
   * Get database connection info for debugging
   * @returns Connection status, the current database name, and the numeric server
   *          version (e.g. "16.2") when it can be parsed from `version()`.
   *          Any query failure is logged and reported as `{ connected: false }`.
   */
  async getConnectionInfo(): Promise<{
    connected: boolean;
    database?: string;
    version?: string;
  }> {
    try {
      const result = await this.$queryRaw<{ current_database: string; version: string }[]>`
        SELECT current_database(), version()
      `;
      const row = result[0];
      if (row) {
        // version() returns a string like "PostgreSQL 16.2 on x86_64-pc-linux-gnu, ...".
        // The first whitespace token is the product name, so take the first dotted
        // numeric token as the version instead.
        const dbVersion = /\d+(?:\.\d+)*/.exec(row.version)?.[0];
        return {
          connected: true,
          database: row.current_database,
          ...(dbVersion && { version: dbVersion }),
        };
      }
      return { connected: false };
    } catch (error) {
      this.logger.error("Failed to get connection info", error);
      return { connected: false };
    }
  }

  /**
   * Sets workspace context for Row-Level Security (RLS)
   * Sets both app.current_user_id and app.current_workspace_id for PostgreSQL RLS policies
   *
   * IMPORTANT: Pass a transaction client. The values are set with
   * `set_config(..., is_local = true)`, so they only last until the end of the
   * current transaction. With the default client (`this`) there is no explicit
   * transaction, so each statement's setting is discarded when that statement
   * finishes and later queries will not see it.
   *
   * @param userId - The ID of the authenticated user
   * @param workspaceId - The ID of the workspace context
   * @param client - Prisma client to run on (defaults to 'this'; see note above)
   *
   * @example
   * ```typescript
   * await prisma.$transaction(async (tx) => {
   *   await prisma.setWorkspaceContext(userId, workspaceId, tx);
   *   const tasks = await tx.task.findMany(); // Filtered by RLS
   * });
   * ```
   */
  async setWorkspaceContext(
    userId: string,
    workspaceId: string,
    client: PrismaClient = this
  ): Promise<void> {
    // Use set_config() instead of SET LOCAL so values are safely parameterized.
    // SET LOCAL with Prisma's tagged template produces invalid SQL (bind parameter $1
    // is not supported in SET statements by PostgreSQL).
    await client.$executeRaw`SELECT set_config('app.current_user_id', ${userId}, true)`;
    await client.$executeRaw`SELECT set_config('app.current_workspace_id', ${workspaceId}, true)`;
  }

  /**
   * Clears workspace context session variables by setting them to the empty
   * string (not NULL).
   *
   * Typically not needed: the values are set with `set_config(..., is_local = true)`
   * and are therefore discarded automatically when the transaction ends.
   *
   * @param client - Prisma client to run on (defaults to 'this')
   */
  async clearWorkspaceContext(client: PrismaClient = this): Promise<void> {
    await client.$executeRaw`SELECT set_config('app.current_user_id', '', true)`;
    await client.$executeRaw`SELECT set_config('app.current_workspace_id', '', true)`;
  }

  /**
   * Executes a function with workspace context set.
   *
   * If an RLS client is already available from `getRlsClient()`, the context is
   * set on that client and `fn` runs on it directly; no new transaction is opened,
   * and the previous context values on that client are overwritten, not restored.
   * NOTE(review): this assumes the provider's client is a transaction client, since
   * otherwise the transaction-local settings would not persist. Confirm in
   * rls-context.provider.
   *
   * Otherwise a new interactive transaction is started, the context is set inside
   * it, and `fn` runs with the transaction client.
   *
   * @param userId - The ID of the authenticated user
   * @param workspaceId - The ID of the workspace context
   * @param fn - Function to execute with context (receives the RLS/transaction client)
   * @returns The result of the function
   *
   * @example
   * ```typescript
   * const tasks = await prisma.withWorkspaceContext(userId, workspaceId, async (tx) => {
   *   return tx.task.findMany({
   *     where: { status: 'IN_PROGRESS' }
   *   });
   * });
   * ```
   */
  async withWorkspaceContext<T>(
    userId: string,
    workspaceId: string,
    fn: (tx: PrismaClient) => Promise<T>
  ): Promise<T> {
    const rlsClient = getRlsClient();
    if (rlsClient) {
      await this.setWorkspaceContext(userId, workspaceId, rlsClient as unknown as PrismaClient);
      return fn(rlsClient as unknown as PrismaClient);
    }
    return this.$transaction(async (tx) => {
      await this.setWorkspaceContext(userId, workspaceId, tx as PrismaClient);
      return fn(tx as PrismaClient);
    });
  }
}