diff --git a/apps/gateway/src/federation/__tests__/enrollment.service.spec.ts b/apps/gateway/src/federation/__tests__/enrollment.service.spec.ts index 4d9e9c0..8bf474e 100644 --- a/apps/gateway/src/federation/__tests__/enrollment.service.spec.ts +++ b/apps/gateway/src/federation/__tests__/enrollment.service.spec.ts @@ -14,15 +14,13 @@ * - GoneException when grant status is not pending * * redeem — success path: + * - atomically claims token BEFORE cert issuance (claim → issueCert → tx) * - calls CaService.issueCert with correct args - * - atomically marks token used (UPDATE … WHERE used_at IS NULL) - * - calls GrantsService.activateGrant - * - updates peer record (certPem, certSerial, certNotAfter, state=active) - * - inserts audit log row + * - activates grant + updates peer + writes audit log inside a transaction * - returns { certPem, certChainPem } * * redeem — replay protection: - * - GoneException when UPDATE rows-updated === 0 (concurrent request won the race) + * - GoneException when claim UPDATE returns empty array (concurrent request won) */ import 'reflect-metadata'; @@ -81,30 +79,45 @@ function makeGrant(overrides: Partial> = {}) { function makeDb({ tokenRows = [makeTokenRow()], - updateRowCount = 1, + // claimedRows is returned by the .returning() on the token-claim UPDATE. + // Empty array = concurrent request won the race (GoneException). + claimedRows = [{ token: TOKEN }], }: { tokenRows?: unknown[]; - updateRowCount?: number; + claimedRows?: unknown[]; } = {}) { - // insert().values() — used for token creation and audit log insert + // insert().values() — for createToken (outer db, not tx) const insertValues = vi.fn().mockResolvedValue(undefined); const insertMock = vi.fn().mockReturnValue({ values: insertValues }); - // select().from().where().limit() + // select().from().where().limit() — for fetching the token row const limitSelect = vi.fn().mockResolvedValue(tokenRows); const whereSelect = vi.fn().mockReturnValue({ limit: limitSelect }); const fromSelect = vi.fn().mockReturnValue({ where: whereSelect }); const selectMock = vi.fn().mockReturnValue({ from: fromSelect }); - // update().set().where() - const whereUpdate = vi.fn().mockResolvedValue({ rowCount: updateRowCount }); - const setMock = vi.fn().mockReturnValue({ where: whereUpdate }); - const updateMock = vi.fn().mockReturnValue({ set: setMock }); + // update().set().where().returning() — for the atomic token claim (outer db) + const returningMock = vi.fn().mockResolvedValue(claimedRows); + const whereClaimUpdate = vi.fn().mockReturnValue({ returning: returningMock }); + const setClaimMock = vi.fn().mockReturnValue({ where: whereClaimUpdate }); + const claimUpdateMock = vi.fn().mockReturnValue({ set: setClaimMock }); + + // transaction(cb) — cb receives txMock; txMock has update + insert + const txInsertValues = vi.fn().mockResolvedValue(undefined); + const txInsertMock = vi.fn().mockReturnValue({ values: txInsertValues }); + const txWhereUpdate = vi.fn().mockResolvedValue(undefined); + const txSetMock = vi.fn().mockReturnValue({ where: txWhereUpdate }); + const txUpdateMock = vi.fn().mockReturnValue({ set: txSetMock }); + const txMock = { update: txUpdateMock, insert: txInsertMock }; + const transactionMock = vi + .fn() + .mockImplementation(async (cb: (tx: typeof txMock) => Promise) => cb(txMock)); return { insert: insertMock, select: selectMock, - update: updateMock, + update: claimUpdateMock, + transaction: transactionMock, _mocks: { insertValues, insertMock, @@ -112,9 +125,17 @@ function makeDb({ whereSelect, fromSelect, selectMock, - whereUpdate, - setMock, - updateMock, + returningMock, + whereClaimUpdate, + setClaimMock, + claimUpdateMock, + txInsertValues, + txInsertMock, + txWhereUpdate, + txSetMock, + txUpdateMock, + txMock, + transactionMock, }, }; } @@ -236,14 +257,23 @@ describe('EnrollmentService.redeem — error paths', () => { await expect(service.redeem(TOKEN, '---CSR---')).rejects.toBeInstanceOf(GoneException); }); - it('throws GoneException when UPDATE returns 0 rows (concurrent replay)', async () => { - const db = makeDb({ updateRowCount: 0 }); + it('throws GoneException when token claim UPDATE returns empty array (concurrent replay)', async () => { + const db = makeDb({ claimedRows: [] }); const caService = makeCaService(); const grantsService = makeGrantsService(); const service = buildService({ db, caService, grantsService }); await expect(service.redeem(TOKEN, '---CSR---')).rejects.toBeInstanceOf(GoneException); }); + + it('does NOT call issueCert when token claim fails (no double minting)', async () => { + const db = makeDb({ claimedRows: [] }); + const caService = makeCaService(); + const service = buildService({ db, caService }); + + await expect(service.redeem(TOKEN, '---CSR---')).rejects.toBeInstanceOf(GoneException); + expect(caService.issueCert).not.toHaveBeenCalled(); + }); }); // --------------------------------------------------------------------------- @@ -263,6 +293,22 @@ describe('EnrollmentService.redeem — success path', () => { service = buildService({ db, caService, grantsService }); }); + it('claims token BEFORE calling issueCert (prevents double minting)', async () => { + const callOrder: string[] = []; + db._mocks.returningMock.mockImplementation(async () => { + callOrder.push('claim'); + return [{ token: TOKEN }]; + }); + caService.issueCert.mockImplementation(async () => { + callOrder.push('issueCert'); + return { certPem: MOCK_CERT_PEM, certChainPem: MOCK_CHAIN_PEM, serialNumber: MOCK_SERIAL }; + }); + + await service.redeem(TOKEN, MOCK_CERT_PEM); + + expect(callOrder).toEqual(['claim', 'issueCert']); + }); + it('calls CaService.issueCert with grantId, subjectUserId, csrPem, ttlSeconds=300', async () => { await service.redeem(TOKEN, MOCK_CERT_PEM); @@ -276,17 +322,26 @@ describe('EnrollmentService.redeem — success path', () => { ); }); - it('calls GrantsService.activateGrant with the grantId', async () => { + it('runs activate grant + peer update + audit inside a transaction', async () => { await service.redeem(TOKEN, MOCK_CERT_PEM); - expect(grantsService.activateGrant).toHaveBeenCalledWith(GRANT_ID); + expect(db._mocks.transactionMock).toHaveBeenCalledOnce(); + // tx.update called twice: activate grant + update peer + expect(db._mocks.txUpdateMock).toHaveBeenCalledTimes(2); + // tx.insert called once: audit log + expect(db._mocks.txInsertMock).toHaveBeenCalledOnce(); }); - it('updates the federationPeers row with certPem, certSerial, state=active', async () => { + it('activates grant (sets status=active) inside the transaction', async () => { await service.redeem(TOKEN, MOCK_CERT_PEM); - // The update mock is called twice: once for the token mark-used, once for peers - expect(db._mocks.setMock).toHaveBeenCalledWith( + expect(db._mocks.txSetMock).toHaveBeenCalledWith(expect.objectContaining({ status: 'active' })); + }); + + it('updates the federationPeers row with certPem, certSerial, state=active inside the transaction', async () => { + await service.redeem(TOKEN, MOCK_CERT_PEM); + + expect(db._mocks.txSetMock).toHaveBeenCalledWith( expect.objectContaining({ certPem: MOCK_CERT_PEM, certSerial: MOCK_SERIAL, @@ -295,12 +350,10 @@ describe('EnrollmentService.redeem — success path', () => { ); }); - it('inserts an audit log row', async () => { + it('inserts an audit log row inside the transaction', async () => { await service.redeem(TOKEN, MOCK_CERT_PEM); - // insert is called at least twice: once for token creation is not in redeem, but - // redeem calls insert for the audit log - expect(db._mocks.insertValues).toHaveBeenCalledWith( + expect(db._mocks.txInsertValues).toHaveBeenCalledWith( expect.objectContaining({ peerId: PEER_ID, grantId: GRANT_ID, diff --git a/apps/gateway/src/federation/enrollment.service.ts b/apps/gateway/src/federation/enrollment.service.ts index 19044e8..3559a16 100644 --- a/apps/gateway/src/federation/enrollment.service.ts +++ b/apps/gateway/src/federation/enrollment.service.ts @@ -3,11 +3,13 @@ * * Responsibilities: * 1. Generate time-limited single-use enrollment tokens (admin action). - * 2. Redeem a token: validate → issue cert via CaService → atomically mark - * used → activate grant → update peer record → write audit log. + * 2. Redeem a token: validate → atomically claim token → issue cert via + * CaService → transactionally activate grant + update peer + write audit. * - * Replay protection: the UPDATE … WHERE used_at IS NULL pattern ensures only - * one concurrent request can win — all others receive GoneException (410). + * Replay protection: the token is claimed (UPDATE WHERE used_at IS NULL) BEFORE + * cert issuance. This prevents double cert minting on concurrent requests. + * If cert issuance fails after claim, the token is consumed and the grant + * stays pending — admin must create a new grant. */ import { @@ -28,6 +30,7 @@ import { isNull, sql, federationEnrollmentTokens, + federationGrants, federationPeers, federationAuditLog, } from '@mosaicstack/db'; @@ -88,12 +91,12 @@ export class EnrollmentService { * 2. usedAt set → GoneException (already used) * 3. expiresAt < now → GoneException (expired) * 4. Load grant — verify status is 'pending' - * 5. Issue cert via CaService - * 6. Atomically mark token used (replay guard) - * 7. Activate grant - * 8. Update peer record (certPem, certSerial, certNotAfter, state=active) - * 9. Write audit log - * 10. Return { certPem, certChainPem } + * 5. Atomically claim token (UPDATE WHERE used_at IS NULL RETURNING token) + * — if no rows returned, concurrent request won → GoneException + * 6. Issue cert via CaService (network call, outside transaction) + * — if this fails, token is consumed; grant stays pending; admin must recreate + * 7. Transaction: activate grant + update peer record + write audit log + * 8. Return { certPem, certChainPem } */ async redeem(token: string, csrPem: string): Promise { // 1. Fetch token row @@ -134,7 +137,24 @@ export class EnrollmentService { ); } - // 5. Issue certificate via CaService + // 5. Atomically claim the token BEFORE cert issuance to prevent double-minting. + // WHERE used_at IS NULL ensures only one concurrent request wins. + // Using .returning() works on both node-postgres and PGlite without rowCount inspection. + const claimed = await this.db + .update(federationEnrollmentTokens) + .set({ usedAt: sql`NOW()` }) + .where( + and(eq(federationEnrollmentTokens.token, token), isNull(federationEnrollmentTokens.usedAt)), + ) + .returning({ token: federationEnrollmentTokens.token }); + + if (claimed.length === 0) { + throw new GoneException('Enrollment token has already been used (concurrent request)'); + } + + // 6. Issue certificate via CaService (network call — outside any transaction). + // If this throws, the token is already consumed. The grant stays pending. + // Admin must revoke the grant and create a new one. let issued; try { issued = await this.caService.issueCert({ @@ -144,62 +164,50 @@ export class EnrollmentService { ttlSeconds: 300, }); } catch (err) { + this.logger.error( + `issueCert failed after token ${token} was claimed — grant ${row.grantId} is stranded pending`, + err instanceof Error ? err.stack : String(err), + ); if (err instanceof FederationScopeError) { throw new BadRequestException((err as Error).message); } throw err; } - // 6. Atomically mark token used — WHERE used_at IS NULL prevents replay - const markResult = await this.db - .update(federationEnrollmentTokens) - .set({ usedAt: sql`NOW()` }) - .where( - and(eq(federationEnrollmentTokens.token, token), isNull(federationEnrollmentTokens.usedAt)), - ); - - // Drizzle returns rowCount on update operations - const rowsUpdated = - markResult && typeof markResult === 'object' && 'rowCount' in markResult - ? (markResult as { rowCount: number }).rowCount - : 1; // default to 1 if driver doesn't report rowCount (e.g. PGlite) - - if (rowsUpdated === 0) { - // Another concurrent request won the race - throw new GoneException('Enrollment token has already been used (concurrent request)'); - } - - // 7. Activate grant - await this.grantsService.activateGrant(row.grantId); - - // 8. Update peer record + // 7. Atomically activate grant, update peer record, and write audit log. const certNotAfter = this.extractCertNotAfter(issued.certPem); - await this.db - .update(federationPeers) - .set({ - certPem: issued.certPem, - certSerial: issued.serialNumber, - certNotAfter, - state: 'active', - }) - .where(eq(federationPeers.id, row.peerId)); + await this.db.transaction(async (tx) => { + await tx + .update(federationGrants) + .set({ status: 'active' }) + .where(eq(federationGrants.id, row.grantId)); - // 9. Write audit log - await this.db.insert(federationAuditLog).values({ - requestId: crypto.randomUUID(), - peerId: row.peerId, - grantId: row.grantId, - verb: 'enrollment', - resource: 'federation_grant', - statusCode: 200, - outcome: 'allowed', + await tx + .update(federationPeers) + .set({ + certPem: issued.certPem, + certSerial: issued.serialNumber, + certNotAfter, + state: 'active', + }) + .where(eq(federationPeers.id, row.peerId)); + + await tx.insert(federationAuditLog).values({ + requestId: crypto.randomUUID(), + peerId: row.peerId, + grantId: row.grantId, + verb: 'enrollment', + resource: 'federation_grant', + statusCode: 200, + outcome: 'allowed', + }); }); this.logger.log( `Enrollment complete — peerId=${row.peerId} grantId=${row.grantId} serial=${issued.serialNumber}`, ); - // 10. Return cert material + // 8. Return cert material return { certPem: issued.certPem, certChainPem: issued.certChainPem,