From b31f864057e31323b50aba229e6adba3c5de3149 Mon Sep 17 00:00:00 2001 From: Jason Woltje Date: Fri, 6 Mar 2026 18:18:35 +0000 Subject: [PATCH] fix: atomic create/update and mget list (#5) Co-authored-by: Jason Woltje Co-committed-by: Jason Woltje --- packages/queue/src/task-repository.ts | 167 ++++++++++++++----- packages/queue/tests/task-atomic.test.ts | 100 +++++++++++ packages/queue/tests/task-repository.test.ts | 146 +++++++++++++++- 3 files changed, 369 insertions(+), 44 deletions(-) diff --git a/packages/queue/src/task-repository.ts b/packages/queue/src/task-repository.ts index 30d504a..d200021 100644 --- a/packages/queue/src/task-repository.ts +++ b/packages/queue/src/task-repository.ts @@ -17,6 +17,14 @@ const LANE_SET = new Set(TASK_LANES); const DEFAULT_KEY_PREFIX = 'mosaic:queue'; const MAX_ATOMIC_RETRIES = 8; +const UPDATE_ALLOWED_STATUS_TRANSITIONS: Readonly> = { + pending: ['blocked'], + blocked: ['pending'], + claimed: ['in-progress'], + 'in-progress': ['claimed'], + completed: [], + failed: [], +}; interface RepositoryKeys { readonly taskIds: string; @@ -25,6 +33,7 @@ interface RepositoryKeys { export interface RedisTaskClient { get(key: string): Promise; + mget(...keys: string[]): Promise<(string | null)[]>; set(key: string, value: string, mode?: 'NX' | 'XX'): Promise<'OK' | null>; smembers(key: string): Promise; sadd(key: string, member: string): Promise; @@ -144,19 +153,50 @@ export class RedisTaskRepository { updatedAt: timestamp, }; - const saveResult = await this.client.set( - this.keys.task(task.taskId), - JSON.stringify(task), - 'NX', - ); + const taskKey = this.keys.task(task.taskId); + const serializedTask = JSON.stringify(task); - if (saveResult !== 'OK') { - throw new TaskAlreadyExistsError(task.taskId); + for (let attempt = 0; attempt < MAX_ATOMIC_RETRIES; attempt += 1) { + await this.client.watch(taskKey); + + try { + const transaction = this.client.multi(); + transaction.set(taskKey, serializedTask, 'NX'); + transaction.sadd(this.keys.taskIds, task.taskId); + const execResult = await transaction.exec(); + + if (execResult === null) { + continue; + } + + const setResult = execResult[0]; + + if (setResult === undefined) { + throw new TaskAtomicConflictError(task.taskId); + } + + const [setError, setReply] = setResult; + + if (setError !== null) { + throw setError; + } + + if (setReply !== 'OK') { + throw new TaskAlreadyExistsError(task.taskId); + } + + const saddResult = execResult[1]; + if (saddResult !== undefined && saddResult[0] !== null) { + throw saddResult[0]; + } + + return task; + } finally { + await this.client.unwatch(); + } } - await this.client.sadd(this.keys.taskIds, task.taskId); - - return task; + throw new TaskAlreadyExistsError(task.taskId); } public async get(taskId: string): Promise { @@ -171,8 +211,28 @@ export class RedisTaskRepository { public async list(filters: TaskListFilters = {}): Promise { const taskIds = await this.client.smembers(this.keys.taskIds); - const records = await Promise.all(taskIds.map(async (taskId) => this.get(taskId))); - const tasks = records.filter((task): task is Task => task !== null); + + if (taskIds.length === 0) { + return []; + } + + const taskKeys = taskIds.map((taskId) => this.keys.task(taskId)); + const records = await this.client.mget(...taskKeys); + const tasks: Task[] = []; + + for (const [index, rawTask] of records.entries()) { + if (rawTask === null || rawTask === undefined) { + continue; + } + + const taskId = taskIds[index]; + + if (taskId === undefined) { + continue; + } + + tasks.push(deserializeTask(taskId, rawTask)); + } return tasks .filter((task) => @@ -186,33 +246,17 @@ export class RedisTaskRepository { } public async update(taskId: string, patch: TaskUpdateInput): Promise { - const existing = await this.get(taskId); + return this.mutateTaskAtomically(taskId, (existing, now) => { + assertUpdatePatchIsAllowed(taskId, existing, patch); - if (existing === null) { - throw new TaskNotFoundError(taskId); - } - - const updated: Task = { - ...existing, - ...patch, - dependencies: - patch.dependencies === undefined ? existing.dependencies : [...patch.dependencies], - updatedAt: this.now(), - }; - - const saveResult = await this.client.set( - this.keys.task(taskId), - JSON.stringify(updated), - 'XX', - ); - - if (saveResult !== 'OK') { - throw new TaskNotFoundError(taskId); - } - - await this.client.sadd(this.keys.taskIds, taskId); - - return updated; + return { + ...existing, + ...patch, + dependencies: + patch.dependencies === undefined ? existing.dependencies : [...patch.dependencies], + updatedAt: now, + }; + }); } public async claim(taskId: string, input: ClaimTaskInput): Promise { @@ -358,6 +402,26 @@ export class RedisTaskRepository { continue; } + const setResult = execResult[0]; + if (setResult === undefined) { + throw new TaskAtomicConflictError(taskId); + } + + const [setError, setReply] = setResult; + + if (setError !== null) { + throw setError; + } + + if (setReply !== 'OK') { + throw new TaskNotFoundError(taskId); + } + + const saddResult = execResult[1]; + if (saddResult !== undefined && saddResult[0] !== null) { + throw saddResult[0]; + } + return updated; } finally { await this.client.unwatch(); @@ -384,6 +448,33 @@ function matchesFilters(task: Task, filters: TaskListFilters): boolean { return true; } +function assertUpdatePatchIsAllowed(taskId: string, task: Task, patch: TaskUpdateInput): void { + if (patch.status !== undefined && !canTransitionStatusViaUpdate(task.status, patch.status)) { + throw new TaskTransitionError(taskId, task.status, 'update'); + } + + if ( + patch.claimedBy !== undefined || + patch.claimedAt !== undefined || + patch.claimTTL !== undefined || + patch.completedAt !== undefined || + patch.failedAt !== undefined || + patch.failureReason !== undefined || + patch.completionSummary !== undefined || + patch.retryCount !== undefined + ) { + throw new TaskTransitionError(taskId, task.status, 'update'); + } +} + +function canTransitionStatusViaUpdate(from: TaskStatus, to: TaskStatus): boolean { + if (from === to) { + return true; + } + + return UPDATE_ALLOWED_STATUS_TRANSITIONS[from].includes(to); +} + function canClaimTask(task: Task, now: number): boolean { if (task.status === 'pending') { return true; diff --git a/packages/queue/tests/task-atomic.test.ts b/packages/queue/tests/task-atomic.test.ts index b2a7d1a..2646f6d 100644 --- a/packages/queue/tests/task-atomic.test.ts +++ b/packages/queue/tests/task-atomic.test.ts @@ -2,6 +2,7 @@ import { describe, expect, it } from 'vitest'; import { RedisTaskRepository, + TaskAlreadyExistsError, TaskOwnershipError, TaskTransitionError, type RedisTaskClient, @@ -112,6 +113,10 @@ class InMemoryAtomicRedisClient implements RedisTaskClient { return Promise.resolve(this.backend.kv.get(key) ?? null); } + public mget(...keys: string[]): Promise<(string | null)[]> { + return Promise.resolve(keys.map((key) => this.backend.kv.get(key) ?? null)); + } + public set( key: string, value: string, @@ -167,6 +172,25 @@ class InMemoryAtomicRedisClient implements RedisTaskClient { } } +class StrictAtomicRedisClient extends InMemoryAtomicRedisClient { + public override set( + key: string, + value: string, + mode?: 'NX' | 'XX', + ): Promise<'OK' | null> { + void key; + void value; + void mode; + throw new Error('Direct set() is not allowed in strict atomic tests.'); + } + + public override sadd(key: string, member: string): Promise { + void key; + void member; + throw new Error('Direct sadd() is not allowed in strict atomic tests.'); + } +} + function createRepositoryPair(now: () => number): [RedisTaskRepository, RedisTaskRepository] { const backend = new InMemoryRedisBackend(); @@ -182,7 +206,56 @@ function createRepositoryPair(now: () => number): [RedisTaskRepository, RedisTas ]; } +function createStrictRepositoryPair( + now: () => number, +): [RedisTaskRepository, RedisTaskRepository] { + const backend = new InMemoryRedisBackend(); + + return [ + new RedisTaskRepository({ + client: new StrictAtomicRedisClient(backend), + now, + }), + new RedisTaskRepository({ + client: new StrictAtomicRedisClient(backend), + now, + }), + ]; +} + describe('RedisTaskRepository atomic transitions', () => { + it('creates atomically under concurrent create race', async () => { + const [repositoryA, repositoryB] = createStrictRepositoryPair( + () => 1_700_000_000_000, + ); + + const [createA, createB] = await Promise.allSettled([ + repositoryA.create({ + project: 'queue', + mission: 'phase1', + taskId: 'MQ-004-CREATE', + title: 'create race', + }), + repositoryB.create({ + project: 'queue', + mission: 'phase1', + taskId: 'MQ-004-CREATE', + title: 'create race duplicate', + }), + ]); + + const fulfilled = [createA, createB].filter( + (result) => result.status === 'fulfilled', + ); + const rejected = [createA, createB].filter( + (result) => result.status === 'rejected', + ); + + expect(fulfilled).toHaveLength(1); + expect(rejected).toHaveLength(1); + expect(rejected[0]?.reason).toBeInstanceOf(TaskAlreadyExistsError); + }); + it('claims a pending task once and blocks concurrent double-claim', async () => { let timestamp = 1_700_000_000_000; const now = (): number => timestamp; @@ -356,4 +429,31 @@ describe('RedisTaskRepository atomic transitions', () => { }), ).rejects.toBeInstanceOf(TaskOwnershipError); }); + + it('merges concurrent non-conflicting update patches atomically', async () => { + const [repositoryA, repositoryB] = createRepositoryPair(() => 1_700_000_000_000); + + await repositoryA.create({ + project: 'queue', + mission: 'phase1', + taskId: 'MQ-004-UPD', + title: 'Original title', + description: 'Original description', + }); + + await Promise.all([ + repositoryA.update('MQ-004-UPD', { + title: 'Updated title', + }), + repositoryB.update('MQ-004-UPD', { + description: 'Updated description', + }), + ]); + + const latest = await repositoryA.get('MQ-004-UPD'); + + expect(latest).not.toBeNull(); + expect(latest?.title).toBe('Updated title'); + expect(latest?.description).toBe('Updated description'); + }); }); diff --git a/packages/queue/tests/task-repository.test.ts b/packages/queue/tests/task-repository.test.ts index 226d546..d0232fc 100644 --- a/packages/queue/tests/task-repository.test.ts +++ b/packages/queue/tests/task-repository.test.ts @@ -3,21 +3,81 @@ import { describe, expect, it } from 'vitest'; import { RedisTaskRepository, TaskAlreadyExistsError, + TaskTransitionError, type RedisTaskClient, type RedisTaskTransaction, } from '../src/task-repository.js'; class NoopRedisTransaction implements RedisTaskTransaction { - public set(): RedisTaskTransaction { + private readonly operations: ( + | { + readonly type: 'set'; + readonly key: string; + readonly value: string; + readonly mode?: 'NX' | 'XX'; + } + | { + readonly type: 'sadd'; + readonly key: string; + readonly member: string; + } + )[] = []; + + public constructor( + private readonly kv: Map, + private readonly sets: Map>, + ) {} + + public set(key: string, value: string, mode?: 'NX' | 'XX'): RedisTaskTransaction { + this.operations.push({ + type: 'set', + key, + value, + mode, + }); return this; } - public sadd(): RedisTaskTransaction { + public sadd(key: string, member: string): RedisTaskTransaction { + this.operations.push({ + type: 'sadd', + key, + member, + }); return this; } public exec(): Promise { - return Promise.resolve([]); + const results: (readonly [Error | null, unknown])[] = []; + + for (const operation of this.operations) { + if (operation.type === 'set') { + const exists = this.kv.has(operation.key); + + if (operation.mode === 'NX' && exists) { + results.push([null, null]); + continue; + } + + if (operation.mode === 'XX' && !exists) { + results.push([null, null]); + continue; + } + + this.kv.set(operation.key, operation.value); + results.push([null, 'OK']); + continue; + } + + const values = this.sets.get(operation.key) ?? new Set(); + const beforeSize = values.size; + + values.add(operation.member); + this.sets.set(operation.key, values); + results.push([null, values.size === beforeSize ? 0 : 1]); + } + + return Promise.resolve(results); } } @@ -29,6 +89,10 @@ class InMemoryRedisClient implements RedisTaskClient { return Promise.resolve(this.kv.get(key) ?? null); } + public mget(...keys: string[]): Promise<(string | null)[]> { + return Promise.resolve(keys.map((key) => this.kv.get(key) ?? null)); + } + public set( key: string, value: string, @@ -71,7 +135,24 @@ class InMemoryRedisClient implements RedisTaskClient { } public multi(): RedisTaskTransaction { - return new NoopRedisTransaction(); + return new NoopRedisTransaction(this.kv, this.sets); + } +} + +class MgetTrackingRedisClient extends InMemoryRedisClient { + public getCalls = 0; + public mgetCalls = 0; + public lastMgetKeys: string[] = []; + + public override get(key: string): Promise { + this.getCalls += 1; + return super.get(key); + } + + public override mget(...keys: string[]): Promise<(string | null)[]> { + this.mgetCalls += 1; + this.lastMgetKeys = [...keys]; + return super.mget(...keys); } } @@ -141,8 +222,9 @@ describe('RedisTaskRepository CRUD', () => { title: 'Claimed task', }); - await repository.update('MQ-003B', { - status: 'claimed', + await repository.claim('MQ-003B', { + agentId: 'agent-a', + ttlSeconds: 60, }); const byProject = await repository.list({ @@ -160,6 +242,39 @@ describe('RedisTaskRepository CRUD', () => { expect(byStatus.map((task) => task.taskId)).toEqual(['MQ-003B']); }); + it('lists 3+ tasks with a single mget call', async () => { + const client = new MgetTrackingRedisClient(); + const repository = new RedisTaskRepository({ + client, + }); + + await repository.create({ + project: 'queue', + mission: 'phase-list', + taskId: 'MQ-MGET-001', + title: 'Task one', + }); + await repository.create({ + project: 'queue', + mission: 'phase-list', + taskId: 'MQ-MGET-002', + title: 'Task two', + }); + await repository.create({ + project: 'queue', + mission: 'phase-list', + taskId: 'MQ-MGET-003', + title: 'Task three', + }); + + const tasks = await repository.list(); + + expect(tasks).toHaveLength(3); + expect(client.mgetCalls).toBe(1); + expect(client.getCalls).toBe(0); + expect(client.lastMgetKeys).toHaveLength(3); + }); + it('updates mutable fields and preserves immutable fields', async () => { const repository = new RedisTaskRepository({ client: new InMemoryRedisClient(), @@ -195,4 +310,23 @@ describe('RedisTaskRepository CRUD', () => { expect(updated.taskId).toBe('MQ-003'); expect(updated.updatedAt).toBe(1_700_000_000_001); }); + + it('rejects status transitions through update()', async () => { + const repository = new RedisTaskRepository({ + client: new InMemoryRedisClient(), + }); + + await repository.create({ + project: 'queue', + mission: 'phase1', + taskId: 'MQ-003-TRANSITION', + title: 'Transition guard', + }); + + await expect( + repository.update('MQ-003-TRANSITION', { + status: 'completed', + }), + ).rejects.toBeInstanceOf(TaskTransitionError); + }); });