import { Injectable, Logger, OnModuleDestroy, OnModuleInit } from "@nestjs/common";
import { PrismaClient } from "@prisma/client";
import { VaultService } from "../vault/vault.service";
import { createAccountEncryptionExtension } from "./account-encryption.extension";
import { createLlmEncryptionExtension } from "./llm-encryption.extension";
import { getRlsClient } from "./rls-context.provider";

/**
 * Minimal client surface needed to set/clear the RLS session variables.
 * Satisfied by the base PrismaClient as well as interactive-transaction clients,
 * so callers can pass `tx` from `$transaction(async (tx) => …)` without a cast.
 */
type RawExecutor = Pick<PrismaClient, "$executeRaw">;

/**
 * Prisma service that manages the database connection lifecycle.
 * Extends PrismaClient to provide connection management, health checks and
 * Row-Level Security (RLS) context helpers.
 *
 * IMPORTANT: VaultService is required (not optional) for encryption/decryption
 * of sensitive Account tokens. It automatically falls back to AES-256-GCM when
 * OpenBao is unavailable.
 *
 * Encryption is handled via Prisma Client Extensions ($extends) on the `account`
 * and `llmProviderInstance` models. The extended client is stored in `_xClient`
 * and only those two model getters are overridden — all other models use the
 * base PrismaClient inheritance, preserving full type safety.
 */
@Injectable()
export class PrismaService extends PrismaClient implements OnModuleInit, OnModuleDestroy {
  private readonly logger = new Logger(PrismaService.name);

  // Extended client with encryption hooks for account + llmProviderInstance.
  // Remains null until onModuleInit() runs; the model getters below fall back to
  // the base client while it is null.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  private _xClient: any = null;

  constructor(private readonly vaultService: VaultService) {
    super({
      log: process.env.NODE_ENV === "development" ? ["query", "info", "warn", "error"] : ["error"],
    });
  }

  /**
   * Connects to the database and registers the encryption extensions when the
   * NestJS module initializes. Each step logs its own failure message and rethrows.
   */
  async onModuleInit(): Promise<void> {
    try {
      await this.$connect();
    } catch (error) {
      this.logger.error("Failed to connect to database", error);
      throw error;
    }
    this.logger.log("Database connection established");

    try {
      // Register encryption extensions (replaces removed $use() middleware)
      this._xClient = this.$extends(createAccountEncryptionExtension(this.vaultService)).$extends(
        createLlmEncryptionExtension(this.vaultService)
      );
    } catch (error) {
      this.logger.error("Failed to register encryption extensions", error);
      throw error;
    }
    this.logger.log("Encryption extensions registered (Account, LlmProviderInstance)");
  }

  // Override only the 2 models that need encryption hooks.
  // All other models (user, task, workspace, etc.) use base PrismaClient via inheritance.
  // Cast _xClient to PrismaClient to preserve the accessor return types for consumers.
  override get account() {
    return this._xClient ? (this._xClient as PrismaClient).account : super.account;
  }

  override get llmProviderInstance() {
    return this._xClient
      ? (this._xClient as PrismaClient).llmProviderInstance
      : super.llmProviderInstance;
  }

  /**
   * Disconnect from database when NestJS module is destroyed
   */
  async onModuleDestroy(): Promise<void> {
    await this.$disconnect();
    this.logger.log("Database connection closed");
  }

  /**
   * Health check for database connectivity.
   * @returns true if `SELECT 1` succeeds, false otherwise (the error is logged)
   */
  async isHealthy(): Promise<boolean> {
    try {
      await this.$queryRaw`SELECT 1`;
      return true;
    } catch (error) {
      this.logger.error("Database health check failed", error);
      return false;
    }
  }

  /**
   * Get database connection info for debugging.
   * @returns Connection status, current database name and the server's product/version
   *          (the first two whitespace-separated tokens of `version()`, e.g. "PostgreSQL 16.1")
   */
  async getConnectionInfo(): Promise<{
    connected: boolean;
    database?: string;
    version?: string;
  }> {
    try {
      const result = await this.$queryRaw<{ current_database: string; version: string }[]>`
        SELECT current_database(), version()
      `;

      const row = result[0];
      if (row) {
        // version() looks like "PostgreSQL 16.1 on x86_64-pc-linux-gnu, compiled by ...".
        // Taking only the first token always yielded "PostgreSQL"; keep product + version.
        const dbVersion = row.version.split(" ").slice(0, 2).join(" ");
        return {
          connected: true,
          database: row.current_database,
          ...(dbVersion && { version: dbVersion }),
        };
      }

      return { connected: false };
    } catch (error) {
      this.logger.error("Failed to get connection info", error);
      return { connected: false };
    }
  }

  /**
   * Sets workspace context for Row-Level Security (RLS).
   * Sets the `app.current_user_id` and `app.current_workspace_id` session variables
   * for the PostgreSQL RLS policies.
   *
   * Uses `set_config(name, value, true)` instead of `SET LOCAL name = <value>`:
   * `$executeRaw` sends interpolated values as bind parameters, and PostgreSQL cannot
   * bind parameters into utility statements such as SET. `is_local = true` gives the
   * same transaction scoping as SET LOCAL, which keeps pooled connections clean.
   *
   * IMPORTANT: Pass a transaction client. On the default client (`this`) the statement
   * runs in its own implicit transaction, so the values are discarded before the next query.
   *
   * @param userId - The ID of the authenticated user
   * @param workspaceId - The ID of the workspace context
   * @param client - Client to run the statement on (defaults to `this`)
   *
   * @example
   * ```typescript
   * await prisma.$transaction(async (tx) => {
   *   await prisma.setWorkspaceContext(userId, workspaceId, tx);
   *   const tasks = await tx.task.findMany(); // Filtered by RLS
   * });
   * ```
   */
  async setWorkspaceContext(
    userId: string,
    workspaceId: string,
    client: RawExecutor = this
  ): Promise<void> {
    // Both variables in one statement: a single round trip, and they are set together.
    await client.$executeRaw`
      SELECT set_config('app.current_user_id', ${userId}, true),
             set_config('app.current_workspace_id', ${workspaceId}, true)
    `;
  }

  /**
   * Clears the RLS session variables for the current transaction by setting them to ''.
   * Typically not needed, because transaction-local settings are discarded when the
   * transaction ends. SET has no NULL value, so the empty string is used instead.
   *
   * NOTE(review): confirm the RLS policies treat '' as "no context"
   * (e.g. `NULLIF(current_setting('app.current_user_id', true), '')`).
   *
   * @param client - Client to run the statement on (defaults to `this`)
   */
  async clearWorkspaceContext(client: RawExecutor = this): Promise<void> {
    await client.$executeRaw`
      SELECT set_config('app.current_user_id', '', true),
             set_config('app.current_workspace_id', '', true)
    `;
  }

  /**
   * Executes a function with workspace context set within a transaction.
   * If `getRlsClient()` returns a client, the context is set on that client and `fn`
   * runs on it. Otherwise a new interactive transaction is opened for the call.
   *
   * @param userId - The ID of the authenticated user
   * @param workspaceId - The ID of the workspace context
   * @param fn - Function to execute with context (receives transaction client)
   * @returns The result of the function
   *
   * @example
   * ```typescript
   * const tasks = await prisma.withWorkspaceContext(userId, workspaceId, async (tx) => {
   *   return tx.task.findMany({
   *     where: { status: 'IN_PROGRESS' }
   *   });
   * });
   * ```
   */
  async withWorkspaceContext<T>(
    userId: string,
    workspaceId: string,
    fn: (tx: PrismaClient) => Promise<T>
  ): Promise<T> {
    // NOTE(review): getRlsClient() presumably returns a transaction-bound client from
    // rls-context.provider — confirm. If so, its context is overwritten for the rest
    // of that transaction.
    const rlsClient = getRlsClient();
    if (rlsClient) {
      const client = rlsClient as unknown as PrismaClient;
      await this.setWorkspaceContext(userId, workspaceId, client);
      return fn(client);
    }

    return this.$transaction(async (tx) => {
      await this.setWorkspaceContext(userId, workspaceId, tx);
      return fn(tx as PrismaClient);
    });
  }
}