/**
 * Account Encryption Middleware Tests
 *
 * Tests transparent encryption/decryption of OAuth tokens in the Account model
 * using Prisma middleware.
 */
import { describe, it, expect, beforeAll, afterAll, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
import { PrismaClient } from "@prisma/client";
import { CryptoService } from "../federation/crypto.service";
import { ConfigService } from "@nestjs/config";
import { registerAccountEncryptionMiddleware } from "./account-encryption.middleware";

describe("AccountEncryptionMiddleware", () => {
  /**
   * Shape these tests expect for AES ciphertext emitted by the middleware:
   * exactly three colon-separated, lowercase-hex segments (anchored full match).
   * No `g` flag, so the shared RegExp carries no `lastIndex` state between uses.
   */
  const AES_CIPHERTEXT_PATTERN = /^[0-9a-f]+:[0-9a-f]+:[0-9a-f]+$/;

  let mockPrisma: any;
  let cryptoService: CryptoService;
  // Fixed: `Partial` requires a type argument (TS2314 without it).
  let mockConfigService: Partial<ConfigService>;
  // Captured by the mocked `$use` when the middleware is registered.
  let middlewareFunction: any;

  beforeAll(() => {
    // Mock ConfigService with a valid test encryption key
    mockConfigService = {
      get: vi.fn((key: string) => {
        if (key === "ENCRYPTION_KEY") {
          // Valid 64-character hex string (32 bytes)
          return "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
        }
        return null;
      }),
    };

    cryptoService = new CryptoService(mockConfigService as ConfigService);

    // Minimal Prisma client stand-in: `$use` only records the middleware so the
    // tests can invoke it directly with hand-built params and `next` functions.
    mockPrisma = {
      $use: vi.fn((fn) => {
        middlewareFunction = fn;
      }),
    };

    // Register the middleware (populates `middlewareFunction` via the mock above)
    registerAccountEncryptionMiddleware(mockPrisma, cryptoService);
  });

  afterAll(async () => {
    // No cleanup needed for mocks
  });

  describe("Encryption on Create", () => {
    it("should encrypt accessToken on account creation", async () => {
      const plainAccessToken = "test-access-token-12345";
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "github",
            accessToken: plainAccessToken,
          },
        },
      };

      // Middleware should modify args
      const result = await callMiddleware(mockPrisma, mockParams);

      // Verify accessToken is encrypted (entire value matches hex:hex:hex)
      expect(result.args.data.accessToken).toBeDefined();
      expect(result.args.data.accessToken).not.toBe(plainAccessToken);
      expect(result.args.data.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should encrypt refreshToken on account creation", async () => {
      const plainRefreshToken = "test-refresh-token-67890";
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "github",
            refreshToken: plainRefreshToken,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.refreshToken).toBeDefined();
      expect(result.args.data.refreshToken).not.toBe(plainRefreshToken);
      expect(result.args.data.refreshToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should encrypt idToken on account creation", async () => {
      const plainIdToken = "test-id-token-abcdef";
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "oauth",
            idToken: plainIdToken,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.idToken).toBeDefined();
      expect(result.args.data.idToken).not.toBe(plainIdToken);
      expect(result.args.data.idToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should encrypt all three tokens when present", async () => {
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "oauth",
            accessToken: "access-123",
            refreshToken: "refresh-456",
            idToken: "id-789",
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);
      expect(result.args.data.refreshToken).toMatch(AES_CIPHERTEXT_PATTERN);
      expect(result.args.data.idToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should handle null tokens gracefully", async () => {
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "github",
            accessToken: null,
            refreshToken: null,
            idToken: null,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.accessToken).toBeNull();
      expect(result.args.data.refreshToken).toBeNull();
      expect(result.args.data.idToken).toBeNull();
    });

    it("should handle undefined tokens gracefully", async () => {
      const mockParams = {
        model: "Account",
        action: "create" as const,
        args: {
          data: {
            userId: "test-user-id",
            accountId: "test-account-id",
            providerId: "github",
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.accessToken).toBeUndefined();
      expect(result.args.data.refreshToken).toBeUndefined();
      expect(result.args.data.idToken).toBeUndefined();
    });
  });

  describe("Encryption on Update", () => {
    it("should encrypt accessToken on account update", async () => {
      const plainAccessToken = "updated-access-token";
      const mockParams = {
        model: "Account",
        action: "update" as const,
        args: {
          where: { id: "account-id" },
          data: {
            accessToken: plainAccessToken,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.accessToken).toBeDefined();
      expect(result.args.data.accessToken).not.toBe(plainAccessToken);
      expect(result.args.data.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should handle updateMany action", async () => {
      const mockParams = {
        model: "Account",
        action: "updateMany" as const,
        args: {
          where: { providerId: "github" },
          data: {
            accessToken: "new-token",
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      expect(result.args.data.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });

    it("should not encrypt already encrypted tokens (idempotent)", async () => {
      // Simulate a token that's already encrypted
      const encryptedToken = cryptoService.encrypt("original-token");
      const mockParams = {
        model: "Account",
        action: "update" as const,
        args: {
          where: { id: "account-id" },
          data: {
            accessToken: encryptedToken,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      // Should remain unchanged if already encrypted
      expect(result.args.data.accessToken).toBe(encryptedToken);
    });

    it("should handle upsert action (encrypt both create and update)", async () => {
      const plainCreateToken = "create-token";
      const plainUpdateToken = "update-token";
      const mockParams = {
        model: "Account",
        action: "upsert" as const,
        args: {
          where: { id: "account-id" },
          create: {
            userId: "user-id",
            accountId: "account-id",
            providerId: "github",
            accessToken: plainCreateToken,
          },
          update: {
            accessToken: plainUpdateToken,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      // Both create and update data should be encrypted
      expect(result.args.create.accessToken).toBeDefined();
      expect(result.args.create.accessToken).not.toBe(plainCreateToken);
      expect(result.args.create.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);

      expect(result.args.update.accessToken).toBeDefined();
      expect(result.args.update.accessToken).not.toBe(plainUpdateToken);
      expect(result.args.update.accessToken).toMatch(AES_CIPHERTEXT_PATTERN);
    });
  });

  describe("Decryption on Read", () => {
    it("should decrypt accessToken on findUnique", async () => {
      const plainToken = "my-access-token";
      const encryptedToken = cryptoService.encrypt(plainToken);

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      // Mock database returning encrypted data with encryptionVersion
      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: encryptedToken,
        refreshToken: null,
        idToken: null,
        encryptionVersion: "aes",
      }));

      // Call middleware - it should decrypt the result
      const result = (await middlewareFunction(mockParams, mockNext)) as any;

      expect(mockNext).toHaveBeenCalledWith(mockParams);
      expect(result.accessToken).toBe(plainToken); // Decrypted by middleware
      expect(result.encryptionVersion).toBe("aes");
    });

    it("should decrypt all tokens on findMany", async () => {
      const plainAccess = "access-token";
      const plainRefresh = "refresh-token";
      const plainId = "id-token";

      const mockParams = {
        model: "Account",
        action: "findMany" as const,
        args: {
          where: { providerId: "github" },
        },
      };

      // Mock database returning multiple encrypted records
      const mockNext = vi.fn(async () => [
        {
          id: "account-1",
          userId: "user-id",
          accountId: "account-1",
          providerId: "github",
          accessToken: cryptoService.encrypt(plainAccess),
          refreshToken: cryptoService.encrypt(plainRefresh),
          idToken: cryptoService.encrypt(plainId),
          encryptionVersion: "aes",
        },
      ]);

      const result = (await middlewareFunction(mockParams, mockNext)) as any[];

      expect(mockNext).toHaveBeenCalledWith(mockParams);
      expect(result[0].accessToken).toBe(plainAccess);
      expect(result[0].refreshToken).toBe(plainRefresh);
      expect(result[0].idToken).toBe(plainId);
    });

    it("should handle null tokens on read", async () => {
      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: null,
        refreshToken: null,
        idToken: null,
        encryptionVersion: null,
      }));

      const result = (await middlewareFunction(mockParams, mockNext)) as any;

      expect(result.accessToken).toBeNull();
      expect(result.refreshToken).toBeNull();
      expect(result.idToken).toBeNull();
    });

    it("should handle legacy plaintext tokens (backward compatibility)", async () => {
      // Simulate old data without encryptionVersion field
      const plaintextToken = "legacy-plaintext-token";

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: plaintextToken,
        refreshToken: null,
        idToken: null,
        encryptionVersion: null, // No encryption version = plaintext
      }));

      const result = (await middlewareFunction(mockParams, mockNext)) as any;

      // Should pass through unchanged (no encryptionVersion)
      expect(result.accessToken).toBe(plaintextToken);
    });

    it("should throw error on vault ciphertext when OpenBao unavailable", async () => {
      // Simulate vault Transit encryption format when OpenBao is unavailable
      const vaultCiphertext = "vault:v1:base64encodeddata";

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "oauth",
        accessToken: vaultCiphertext,
        refreshToken: null,
        idToken: null,
        encryptionVersion: "vault", // vault encryption
      }));

      // Expected to throw: no OpenBao backend is available in this test setup to
      // decrypt the vault:v1: payload.
      await expect(middlewareFunction(mockParams, mockNext)).rejects.toThrow(
        "Failed to decrypt account credentials"
      );
    });

    it("should use encryptionVersion as primary discriminator", async () => {
      // A token with the three-colon-segment shape of AES ciphertext (the last
      // segment is not strictly hex) must not be decrypted when encryptionVersion
      // is not 'aes'.
      const fakeEncryptedToken = "abc123:def456:789ghi";

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: fakeEncryptedToken,
        refreshToken: null,
        idToken: null,
        encryptionVersion: null, // No encryption version
      }));

      const result = (await middlewareFunction(mockParams, mockNext)) as any;

      // Should NOT attempt to decrypt (encryptionVersion is null)
      expect(result.accessToken).toBe(fakeEncryptedToken);
    });

    it("should throw error on corrupted encrypted data", async () => {
      // Three colon-separated segments, but the payload (last segment is not even
      // hex) is not a real ciphertext for the configured key.
      const corruptedToken = "deadbeef:cafebabe:corrupted_data_xyz";

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: corruptedToken,
        refreshToken: null,
        idToken: null,
        encryptionVersion: "aes", // Marked as encrypted
      }));

      // Should throw error - decryption failures are now propagated to prevent silent corruption
      await expect(middlewareFunction(mockParams, mockNext)).rejects.toThrow(
        "Failed to decrypt account credentials"
      );
    });

    it("should throw error on completely malformed encrypted format", async () => {
      // Test with data that doesn't match expected format at all (nine segments)
      const malformedToken = "this:is:not:valid:encrypted:data:too:many:parts";

      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
        },
      };

      const mockNext = vi.fn(async () => ({
        id: "account-id",
        userId: "user-id",
        accountId: "account-id",
        providerId: "github",
        accessToken: malformedToken,
        refreshToken: null,
        idToken: null,
        encryptionVersion: "aes",
      }));

      // Should throw error - malformed data cannot be decrypted
      await expect(middlewareFunction(mockParams, mockNext)).rejects.toThrow(
        "Failed to decrypt account credentials"
      );
    });
  });

  describe("Non-Account Models", () => {
    it("should not process other models", async () => {
      const mockParams = {
        model: "User",
        action: "create" as const,
        args: {
          data: {
            email: "test@example.com",
            name: "Test User",
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      // Should pass through unchanged
      expect(result.args.data).toEqual({
        email: "test@example.com",
        name: "Test User",
      });
    });

    it("should not process Account queries without token fields", async () => {
      const mockParams = {
        model: "Account",
        action: "findUnique" as const,
        args: {
          where: { id: "account-id" },
          select: {
            id: true,
            providerId: true,
          },
        },
      };

      const result = await callMiddleware(mockPrisma, mockParams);

      // Should pass through without modification
      expect(result.args.select).toEqual({
        id: true,
        providerId: true,
      });
    });
  });

  /**
   * Invokes the captured middleware with a `next` that echoes the (possibly
   * modified) params back, so write-path tests can inspect what the middleware
   * would have sent to the database.
   *
   * @param client - Unused; kept so existing call sites stay unchanged.
   * @param params - Prisma middleware params (`model`, `action`, `args`).
   * @returns Whatever the middleware resolves with (the echoed params here).
   * @throws Error if the middleware was never registered via `$use`.
   */
  async function callMiddleware(client: any, params: any) {
    if (!middlewareFunction) {
      throw new Error("Middleware not registered");
    }

    return middlewareFunction(params, async (p: any) => p);
  }
});