diff --git a/packages/queue/src/index.ts b/packages/queue/src/index.ts
index 9998ac0..1a005ff 100644
--- a/packages/queue/src/index.ts
+++ b/packages/queue/src/index.ts
@@ -15,10 +15,22 @@ export type {
 export {
   RedisTaskRepository,
   TaskAlreadyExistsError,
+  TaskAtomicConflictError,
   TaskNotFoundError,
+  TaskOwnershipError,
   TaskSerializationError,
+  TaskTransitionError,
+} from './task-repository.js';
+export type {
+  ClaimTaskInput,
+  CompleteTaskInput,
+  FailTaskInput,
+  HeartbeatTaskInput,
+  RedisTaskClient,
+  RedisTaskRepositoryOptions,
+  RedisTaskTransaction,
+  ReleaseTaskInput,
 } from './task-repository.js';
-export type { RedisTaskClient, RedisTaskRepositoryOptions } from './task-repository.js';
 export {
   TASK_LANES,
   TASK_PRIORITIES,
diff --git a/packages/queue/src/task-repository.ts b/packages/queue/src/task-repository.ts
index b31f8b8..30d504a 100644
--- a/packages/queue/src/task-repository.ts
+++ b/packages/queue/src/task-repository.ts
@@ -16,6 +16,8 @@ const PRIORITY_SET = new Set(TASK_PRIORITIES);
 const LANE_SET = new Set(TASK_LANES);
 
 const DEFAULT_KEY_PREFIX = 'mosaic:queue';
+/** Upper bound on WATCH/MULTI/EXEC attempts before a mutation gives up. */
+const MAX_ATOMIC_RETRIES = 8;
 
 interface RepositoryKeys {
   readonly taskIds: string;
@@ -27,6 +29,19 @@ export interface RedisTaskClient {
   set(key: string, value: string, mode?: 'NX' | 'XX'): Promise<'OK' | null>;
   smembers(key: string): Promise<string[]>;
   sadd(key: string, member: string): Promise<number>;
+  watch(...keys: string[]): Promise<'OK'>;
+  unwatch(): Promise<'OK'>;
+  multi(): RedisTaskTransaction;
+}
+
+/**
+ * Command batch returned by `RedisTaskClient.multi()`. The repository treats a
+ * `null` result from `exec()` as an aborted transaction and retries.
+ */
+export interface RedisTaskTransaction {
+  set(key: string, value: string, mode?: 'NX' | 'XX'): RedisTaskTransaction;
+  sadd(key: string, member: string): RedisTaskTransaction;
+  exec(): Promise<readonly (readonly [Error | null, unknown])[] | null>;
 }
 
 export interface RedisTaskRepositoryOptions {
@@ -35,6 +50,35 @@ export interface RedisTaskRepositoryOptions {
   readonly now?: () => number;
 }
 
+/** Input for `claim`: the claiming agent and the claim lifetime in seconds. */
+export interface ClaimTaskInput {
+  readonly agentId: string;
+  readonly ttlSeconds: number;
+}
+
+/** Input for `release`; the ownership check is skipped when `agentId` is omitted. */
+export interface ReleaseTaskInput {
+  readonly agentId?: string;
+}
+
+/** Input for `heartbeat`; `ttlSeconds` falls back to the task's current `claimTTL`. */
+export interface HeartbeatTaskInput {
+  readonly agentId?: string;
+  readonly ttlSeconds?: number;
+}
+
+/** Input for `complete`; `summary` is stored as `completionSummary` when provided. */
+export interface CompleteTaskInput {
+  readonly agentId?: string;
+  readonly summary?: string;
+}
+
+/** Input for `fail`; `reason` is stored as `failureReason`. */
+export interface FailTaskInput {
+  readonly agentId?: string;
+  readonly reason: string;
+}
+
 export class TaskAlreadyExistsError extends Error {
   public constructor(taskId: string) {
     super(`Task ${taskId} already exists.`);
@@ -56,6 +100,32 @@ export class TaskSerializationError extends Error {
   }
 }
 
+/** Thrown when an action is not permitted from the task's current status. */
+export class TaskTransitionError extends Error {
+  public constructor(taskId: string, status: TaskStatus, action: string) {
+    super(`Task ${taskId} cannot transition from ${status} via ${action}.`);
+    this.name = 'TaskTransitionError';
+  }
+}
+
+/** Thrown when the acting agent differs from the agent recorded in `claimedBy`. */
+export class TaskOwnershipError extends Error {
+  public constructor(taskId: string, expectedAgentId: string, actualAgentId: string) {
+    super(
+      `Task ${taskId} is owned by ${actualAgentId}, not ${expectedAgentId}.`,
+    );
+    this.name = 'TaskOwnershipError';
+  }
+}
+
+/** Thrown when an atomic update is aborted on every retry attempt. */
+export class TaskAtomicConflictError extends Error {
+  public constructor(taskId: string) {
+    super(`Task ${taskId} could not be updated atomically after multiple retries.`);
+    this.name = 'TaskAtomicConflictError';
+  }
+}
+
 export class RedisTaskRepository {
   private readonly client: RedisTaskClient;
   private readonly keys: RepositoryKeys;
@@ -157,6 +227,202 @@ export class RedisTaskRepository {
 
     return updated;
   }
+
+  /**
+   * Claims a pending task, or takes over a claimed/in-progress task whose
+   * claim has expired, on behalf of `input.agentId`.
+   *
+   * @throws Error when `input.ttlSeconds` is not a finite number above 0.
+   * @throws TaskTransitionError when the task cannot be claimed.
+   */
+  public async claim(taskId: string, input: ClaimTaskInput): Promise<Task> {
+    // `NaN <= 0` is false, so non-finite values need an explicit rejection.
+    if (!Number.isFinite(input.ttlSeconds) || input.ttlSeconds <= 0) {
+      throw new Error(`Task ${taskId} claim ttl must be greater than 0 seconds.`);
+    }
+
+    return this.mutateTaskAtomically(taskId, (existing, now) => {
+      if (!canClaimTask(existing, now)) {
+        throw new TaskTransitionError(taskId, existing.status, 'claim');
+      }
+
+      const base = withoutCompletionAndFailureFields(withoutClaimFields(existing));
+
+      return {
+        ...base,
+        status: 'claimed',
+        claimedBy: input.agentId,
+        claimedAt: now,
+        claimTTL: input.ttlSeconds,
+        updatedAt: now,
+      };
+    });
+  }
+
+  /**
+   * Returns a claimed or in-progress task to `pending` and drops its claim fields.
+   *
+   * @throws TaskTransitionError when the task is neither claimed nor in progress.
+   * @throws TaskOwnershipError when `input.agentId` is given and is not the recorded owner.
+   */
+  public async release(taskId: string, input: ReleaseTaskInput = {}): Promise<Task> {
+    return this.mutateTaskAtomically(taskId, (existing, now) => {
+      if (!isClaimedLikeStatus(existing.status)) {
+        throw new TaskTransitionError(taskId, existing.status, 'release');
+      }
+
+      assertTaskOwnership(taskId, existing, input.agentId);
+
+      const base = withoutClaimFields(existing);
+
+      return {
+        ...base,
+        status: 'pending',
+        updatedAt: now,
+      };
+    });
+  }
+
+  /**
+   * Restarts the claim window of an unexpired claim, optionally with a new TTL.
+   *
+   * @throws TaskTransitionError when the task is neither claimed nor in progress,
+   *   the claim has expired, or no finite TTL above 0 is available.
+   * @throws TaskOwnershipError when `input.agentId` is given and is not the recorded owner.
+   */
+  public async heartbeat(
+    taskId: string,
+    input: HeartbeatTaskInput = {},
+  ): Promise<Task> {
+    return this.mutateTaskAtomically(taskId, (existing, now) => {
+      if (!isClaimedLikeStatus(existing.status)) {
+        throw new TaskTransitionError(taskId, existing.status, 'heartbeat');
+      }
+
+      if (isClaimExpired(existing, now)) {
+        throw new TaskTransitionError(taskId, existing.status, 'heartbeat');
+      }
+
+      assertTaskOwnership(taskId, existing, input.agentId);
+
+      const ttl = input.ttlSeconds ?? existing.claimTTL;
+
+      // Number.isFinite also rejects NaN, which `ttl <= 0` alone would accept.
+      if (ttl === undefined || !Number.isFinite(ttl) || ttl <= 0) {
+        throw new TaskTransitionError(taskId, existing.status, 'heartbeat');
+      }
+
+      return {
+        ...existing,
+        claimedAt: now,
+        claimTTL: ttl,
+        updatedAt: now,
+      };
+    });
+  }
+
+  /**
+   * Marks a claimed or in-progress task as `completed`, dropping claim and
+   * failure fields and storing `input.summary` when provided.
+   *
+   * @throws TaskTransitionError when the task is neither claimed nor in progress.
+   * @throws TaskOwnershipError when `input.agentId` is given and is not the recorded owner.
+   */
+  public async complete(
+    taskId: string,
+    input: CompleteTaskInput = {},
+  ): Promise<Task> {
+    return this.mutateTaskAtomically(taskId, (existing, now) => {
+      if (!isClaimedLikeStatus(existing.status)) {
+        throw new TaskTransitionError(taskId, existing.status, 'complete');
+      }
+
+      assertTaskOwnership(taskId, existing, input.agentId);
+
+      const base = withoutCompletionAndFailureFields(withoutClaimFields(existing));
+
+      return {
+        ...base,
+        status: 'completed',
+        completedAt: now,
+        ...(input.summary === undefined ? {} : { completionSummary: input.summary }),
+        updatedAt: now,
+      };
+    });
+  }
+
+  /**
+   * Marks a claimed or in-progress task as `failed`, storing `input.reason`
+   * and incrementing `retryCount`.
+   *
+   * @throws TaskTransitionError when the task is neither claimed nor in progress.
+   * @throws TaskOwnershipError when `input.agentId` is given and is not the recorded owner.
+   */
+  public async fail(taskId: string, input: FailTaskInput): Promise<Task> {
+    return this.mutateTaskAtomically(taskId, (existing, now) => {
+      if (!isClaimedLikeStatus(existing.status)) {
+        throw new TaskTransitionError(taskId, existing.status, 'fail');
+      }
+
+      assertTaskOwnership(taskId, existing, input.agentId);
+
+      const base = withoutCompletionAndFailureFields(withoutClaimFields(existing));
+
+      return {
+        ...base,
+        status: 'failed',
+        failedAt: now,
+        failureReason: input.reason,
+        retryCount: existing.retryCount + 1,
+        updatedAt: now,
+      };
+    });
+  }
+
+  /**
+   * Applies `mutation` to the stored task inside a WATCH/MULTI/EXEC round,
+   * retrying up to MAX_ATOMIC_RETRIES times while `exec()` resolves to `null`.
+   *
+   * @throws TaskNotFoundError when the task key holds no value.
+   * @throws TaskAtomicConflictError when every attempt is aborted.
+   */
+  private async mutateTaskAtomically(
+    taskId: string,
+    mutation: (existing: Task, now: number) => Task,
+  ): Promise<Task> {
+    const taskKey = this.keys.task(taskId);
+
+    for (let attempt = 0; attempt < MAX_ATOMIC_RETRIES; attempt += 1) {
+      await this.client.watch(taskKey);
+
+      try {
+        const raw = await this.client.get(taskKey);
+
+        if (raw === null) {
+          throw new TaskNotFoundError(taskId);
+        }
+
+        const existing = deserializeTask(taskId, raw);
+        const updated = mutation(existing, this.now());
+
+        const transaction = this.client.multi();
+        transaction.set(taskKey, JSON.stringify(updated), 'XX');
+        transaction.sadd(this.keys.taskIds, taskId);
+        const execResult = await transaction.exec();
+
+        if (execResult === null) {
+          continue;
+        }
+
+        return updated;
+      } finally {
+        // Also releases the watch when the attempt throws before exec() runs.
+        await this.client.unwatch();
+      }
+    }
+
+    throw new TaskAtomicConflictError(taskId);
+  }
 }
 
 function matchesFilters(task: Task, filters: TaskListFilters): boolean {
@@ -175,6 +441,81 @@ function matchesFilters(task: Task, filters: TaskListFilters): boolean {
   return true;
 }
 
+/** A task is claimable when pending, or when claimed/in-progress with an expired claim. */
+function canClaimTask(task: Task, now: number): boolean {
+  if (task.status === 'pending') {
+    return true;
+  }
+
+  if (!isClaimedLikeStatus(task.status)) {
+    return false;
+  }
+
+  return isClaimExpired(task, now);
+}
+
+/** True when the status is `claimed` or `in-progress`. */
+function isClaimedLikeStatus(status: TaskStatus): boolean {
+  return status === 'claimed' || status === 'in-progress';
+}
+
+/**
+ * True once `claimedAt + claimTTL * 1000 <= now`. A task missing either claim
+ * field is reported as not expired.
+ */
+function isClaimExpired(task: Task, now: number): boolean {
+  if (task.claimedAt === undefined || task.claimTTL === undefined) {
+    return false;
+  }
+
+  return task.claimedAt + task.claimTTL * 1000 <= now;
+}
+
+/**
+ * Throws TaskOwnershipError when both an expected agent and a recorded owner
+ * exist and differ; the check is skipped when either one is undefined.
+ */
+function assertTaskOwnership(
+  taskId: string,
+  task: Task,
+  expectedAgentId: string | undefined,
+): void {
+  if (expectedAgentId === undefined || task.claimedBy === undefined) {
+    return;
+  }
+
+  if (task.claimedBy !== expectedAgentId) {
+    throw new TaskOwnershipError(taskId, expectedAgentId, task.claimedBy);
+  }
+}
+
+type TaskWithoutClaimFields = Omit<Task, 'claimedBy' | 'claimedAt' | 'claimTTL'>;
+type TaskWithoutCompletionAndFailureFields = Omit<
+  Task,
+  'completedAt' | 'failedAt' | 'failureReason' | 'completionSummary'
+>;
+
+/** Returns a shallow copy of the task without claimedBy, claimedAt and claimTTL. */
+function withoutClaimFields(task: Task): TaskWithoutClaimFields {
+  const draft: Partial<Task> = { ...task };
+  delete draft.claimedBy;
+  delete draft.claimedAt;
+  delete draft.claimTTL;
+  return draft as TaskWithoutClaimFields;
+}
+
+/** Returns a shallow copy without completedAt, failedAt, failureReason and completionSummary. */
+function withoutCompletionAndFailureFields(
+  task: TaskWithoutClaimFields,
+): TaskWithoutCompletionAndFailureFields {
+  const draft: Partial<Task> = { ...task };
+  delete draft.completedAt;
+  delete draft.failedAt;
+  delete draft.failureReason;
+  delete draft.completionSummary;
+  return draft as TaskWithoutCompletionAndFailureFields;
+}
+
 function deserializeTask(taskId: string, raw: string): Task {
   let parsed: unknown;
 
diff --git a/packages/queue/tests/task-atomic.test.ts b/packages/queue/tests/task-atomic.test.ts
new file mode 100644
index 0000000..f931f27
--- /dev/null
+++ b/packages/queue/tests/task-atomic.test.ts
@@ -0,0 +1,352 @@
+import { describe, expect, it } from 'vitest';
+
+import {
+  RedisTaskRepository,
+  TaskTransitionError,
+  type RedisTaskClient,
+  type RedisTaskTransaction,
+} from '../src/task-repository.js';
+
+type QueuedOperation =
+  | {
+      readonly type: 'set';
+      readonly key: string;
+      readonly value: string;
+      readonly mode?: 'NX' | 'XX';
+    }
+  | {
+      readonly type: 'sadd';
+      readonly key: string;
+      readonly member: string;
+    };
+
+/** Shared store with a per-key revision counter used to emulate WATCH. */
+class InMemoryRedisBackend {
+  public readonly kv = new Map<string, string>();
+  public readonly sets = new Map<string, Set<string>>();
+  public readonly revisions = new Map<string, number>();
+
+  public getRevision(key: string): number {
+    return this.revisions.get(key) ?? 0;
+  }
+
+  public bumpRevision(key: string): void {
+    this.revisions.set(key, this.getRevision(key) + 1);
+  }
+}
+
+/** Queues commands and applies them on exec() unless a watched key's revision changed. */
+class InMemoryRedisTransaction implements RedisTaskTransaction {
+  private readonly operations: QueuedOperation[] = [];
+
+  public constructor(
+    private readonly backend: InMemoryRedisBackend,
+    private readonly watchedRevisions: ReadonlyMap<string, number>,
+  ) {}
+
+  public set(key: string, value: string, mode?: 'NX' | 'XX'): RedisTaskTransaction {
+    this.operations.push({
+      type: 'set',
+      key,
+      value,
+      mode,
+    });
+    return this;
+  }
+
+  public sadd(key: string, member: string): RedisTaskTransaction {
+    this.operations.push({
+      type: 'sadd',
+      key,
+      member,
+    });
+    return this;
+  }
+
+  public exec(): Promise<readonly (readonly [Error | null, unknown])[] | null> {
+    for (const [key, revision] of this.watchedRevisions.entries()) {
+      if (this.backend.getRevision(key) !== revision) {
+        return Promise.resolve(null);
+      }
+    }
+
+    const results: (readonly [Error | null, unknown])[] = [];
+
+    for (const operation of this.operations) {
+      if (operation.type === 'set') {
+        const exists = this.backend.kv.has(operation.key);
+        if (operation.mode === 'NX' && exists) {
+          results.push([null, null]);
+          continue;
+        }
+
+        if (operation.mode === 'XX' && !exists) {
+          results.push([null, null]);
+          continue;
+        }
+
+        this.backend.kv.set(operation.key, operation.value);
+        this.backend.bumpRevision(operation.key);
+        results.push([null, 'OK']);
+        continue;
+      }
+
+      const set = this.backend.sets.get(operation.key) ?? new Set<string>();
+      const before = set.size;
+
+      set.add(operation.member);
+      this.backend.sets.set(operation.key, set);
+      this.backend.bumpRevision(operation.key);
+      results.push([null, set.size === before ? 0 : 1]);
+    }
+
+    return Promise.resolve(results);
+  }
+}
+
+/** Per-connection client over the shared backend; each instance tracks its own watches. */
+class InMemoryAtomicRedisClient implements RedisTaskClient {
+  private watchedRevisions = new Map<string, number>();
+
+  public constructor(private readonly backend: InMemoryRedisBackend) {}
+
+  public get(key: string): Promise<string | null> {
+    return Promise.resolve(this.backend.kv.get(key) ?? null);
+  }
+
+  public set(
+    key: string,
+    value: string,
+    mode?: 'NX' | 'XX',
+  ): Promise<'OK' | null> {
+    const exists = this.backend.kv.has(key);
+
+    if (mode === 'NX' && exists) {
+      return Promise.resolve(null);
+    }
+
+    if (mode === 'XX' && !exists) {
+      return Promise.resolve(null);
+    }
+
+    this.backend.kv.set(key, value);
+    this.backend.bumpRevision(key);
+
+    return Promise.resolve('OK');
+  }
+
+  public smembers(key: string): Promise<string[]> {
+    return Promise.resolve([...(this.backend.sets.get(key) ?? new Set<string>())]);
+  }
+
+  public sadd(key: string, member: string): Promise<number> {
+    const values = this.backend.sets.get(key) ?? new Set<string>();
+    const before = values.size;
+
+    values.add(member);
+    this.backend.sets.set(key, values);
+    this.backend.bumpRevision(key);
+
+    return Promise.resolve(values.size === before ? 0 : 1);
+  }
+
+  public watch(...keys: string[]): Promise<'OK'> {
+    this.watchedRevisions = new Map<string, number>(
+      keys.map((key) => [key, this.backend.getRevision(key)]),
+    );
+    return Promise.resolve('OK');
+  }
+
+  public unwatch(): Promise<'OK'> {
+    this.watchedRevisions.clear();
+    return Promise.resolve('OK');
+  }
+
+  public multi(): RedisTaskTransaction {
+    const watchedSnapshot = new Map(this.watchedRevisions);
+    this.watchedRevisions.clear();
+    return new InMemoryRedisTransaction(this.backend, watchedSnapshot);
+  }
+}
+
+/** Builds two repositories with separate clients over one backend to simulate two workers. */
+function createRepositoryPair(now: () => number): [RedisTaskRepository, RedisTaskRepository] {
+  const backend = new InMemoryRedisBackend();
+
+  return [
+    new RedisTaskRepository({
+      client: new InMemoryAtomicRedisClient(backend),
+      now,
+    }),
+    new RedisTaskRepository({
+      client: new InMemoryAtomicRedisClient(backend),
+      now,
+    }),
+  ];
+}
+
+describe('RedisTaskRepository atomic transitions', () => {
+  it('claims a pending task once and blocks concurrent double-claim', async () => {
+    const timestamp = 1_700_000_000_000;
+    const now = (): number => timestamp;
+    const [repositoryA, repositoryB] = createRepositoryPair(now);
+
+    await repositoryA.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004',
+      title: 'Atomic claim',
+    });
+
+    const [claimA, claimB] = await Promise.allSettled([
+      repositoryA.claim('MQ-004', { agentId: 'agent-a', ttlSeconds: 60 }),
+      repositoryB.claim('MQ-004', { agentId: 'agent-b', ttlSeconds: 60 }),
+    ]);
+
+    const fulfilled = [claimA, claimB].filter((result) => result.status === 'fulfilled');
+    const rejected = [claimA, claimB].filter((result) => result.status === 'rejected');
+
+    expect(fulfilled).toHaveLength(1);
+    expect(rejected).toHaveLength(1);
+  });
+
+  it('allows claim takeover after TTL expiry', async () => {
+    let timestamp = 1_700_000_000_000;
+    const now = (): number => timestamp;
+    const [repositoryA, repositoryB] = createRepositoryPair(now);
+
+    await repositoryA.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-EXP',
+      title: 'TTL expiry',
+    });
+
+    await repositoryA.claim('MQ-004-EXP', {
+      agentId: 'agent-a',
+      ttlSeconds: 1,
+    });
+
+    timestamp += 2_000;
+
+    const takeover = await repositoryB.claim('MQ-004-EXP', {
+      agentId: 'agent-b',
+      ttlSeconds: 60,
+    });
+
+    expect(takeover.claimedBy).toBe('agent-b');
+  });
+
+  it('releases a claimed task back to pending', async () => {
+    const [repository] = createRepositoryPair(() => 1_700_000_000_000);
+
+    await repository.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-REL',
+      title: 'Release test',
+    });
+
+    await repository.claim('MQ-004-REL', {
+      agentId: 'agent-a',
+      ttlSeconds: 60,
+    });
+
+    const released = await repository.release('MQ-004-REL', {
+      agentId: 'agent-a',
+    });
+
+    expect(released.status).toBe('pending');
+    expect(released.claimedBy).toBeUndefined();
+    expect(released.claimedAt).toBeUndefined();
+  });
+
+  it('heartbeats, completes, and fails with valid transitions', async () => {
+    let timestamp = 1_700_000_000_000;
+    const now = (): number => timestamp;
+    const [repository] = createRepositoryPair(now);
+
+    await repository.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-HCF',
+      title: 'Transition test',
+    });
+
+    await repository.claim('MQ-004-HCF', {
+      agentId: 'agent-a',
+      ttlSeconds: 60,
+    });
+
+    timestamp += 1_000;
+    const heartbeat = await repository.heartbeat('MQ-004-HCF', {
+      agentId: 'agent-a',
+      ttlSeconds: 120,
+    });
+    expect(heartbeat.claimTTL).toBe(120);
+    expect(heartbeat.claimedAt).toBe(1_700_000_001_000);
+
+    const completed = await repository.complete('MQ-004-HCF', {
+      agentId: 'agent-a',
+      summary: 'done',
+    });
+    expect(completed.status).toBe('completed');
+    expect(completed.completionSummary).toBe('done');
+
+    await repository.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-FAIL',
+      title: 'Failure test',
+    });
+
+    await repository.claim('MQ-004-FAIL', {
+      agentId: 'agent-a',
+      ttlSeconds: 60,
+    });
+
+    const failed = await repository.fail('MQ-004-FAIL', {
+      agentId: 'agent-a',
+      reason: 'boom',
+    });
+
+    expect(failed.status).toBe('failed');
+    expect(failed.failureReason).toBe('boom');
+    expect(failed.retryCount).toBe(1);
+  });
+
+  it('rejects invalid transitions', async () => {
+    const [repository] = createRepositoryPair(() => 1_700_000_000_000);
+
+    await repository.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-INV',
+      title: 'Invalid transitions',
+    });
+
+    await expect(
+      repository.complete('MQ-004-INV', {
+        agentId: 'agent-a',
+      }),
+    ).rejects.toBeInstanceOf(TaskTransitionError);
+  });
+
+  it('rejects a claim whose TTL is not a finite number', async () => {
+    const [repository] = createRepositoryPair(() => 1_700_000_000_000);
+
+    await repository.create({
+      project: 'queue',
+      mission: 'phase1',
+      taskId: 'MQ-004-NAN',
+      title: 'Non-finite TTL',
+    });
+
+    await expect(
+      repository.claim('MQ-004-NAN', {
+        agentId: 'agent-a',
+        ttlSeconds: Number.NaN,
+      }),
+    ).rejects.toThrow('claim ttl must be greater than 0 seconds');
+  });
+});
diff --git a/packages/queue/tests/task-repository.test.ts b/packages/queue/tests/task-repository.test.ts
index 180b7df..226d546 100644
--- a/packages/queue/tests/task-repository.test.ts
+++ b/packages/queue/tests/task-repository.test.ts
@@ -4,8 +4,24 @@ import {
   RedisTaskRepository,
   TaskAlreadyExistsError,
   type RedisTaskClient,
+  type RedisTaskTransaction,
 } from '../src/task-repository.js';
 
+/** Transaction stub for the CRUD tests; exec() resolves to an empty result list. */
+class NoopRedisTransaction implements RedisTaskTransaction {
+  public set(): RedisTaskTransaction {
+    return this;
+  }
+
+  public sadd(): RedisTaskTransaction {
+    return this;
+  }
+
+  public exec(): Promise<readonly (readonly [Error | null, unknown])[] | null> {
+    return Promise.resolve([]);
+  }
+}
+
 class InMemoryRedisClient implements RedisTaskClient {
   private readonly kv = new Map<string, string>();
   private readonly sets = new Map<string, Set<string>>();
@@ -46,6 +62,18 @@ class InMemoryRedisClient implements RedisTaskClient {
 
     return Promise.resolve(values.size === beforeSize ? 0 : 1);
   }
+
+  public watch(): Promise<'OK'> {
+    return Promise.resolve('OK');
+  }
+
+  public unwatch(): Promise<'OK'> {
+    return Promise.resolve('OK');
+  }
+
+  public multi(): RedisTaskTransaction {
+    return new NoopRedisTransaction();
+  }
 }
 
 describe('RedisTaskRepository CRUD', () => {