From ce7fb27c464541639633fe3b548ebb52b2b7a85f Mon Sep 17 00:00:00 2001 From: Jason Woltje Date: Thu, 5 Feb 2026 18:26:50 -0600 Subject: [PATCH] fix(#338): Add rate limiting to orchestrator API - Add @nestjs/throttler for rate limiting support - Configure multiple throttle profiles: default (100/min), strict (10/min for spawn/kill), status (200/min for polling) - Apply strict rate limits to spawn and kill endpoints to prevent DoS - Apply higher rate limits to status/health endpoints for monitoring - Add OrchestratorThrottlerGuard with X-Forwarded-For support for proxy setups - Add unit tests for throttler guard Refs #338 Co-Authored-By: Claude Opus 4.5 --- apps/orchestrator/package.json | 1 + .../src/api/agents/agents.controller.ts | 14 +- .../src/api/health/health.controller.ts | 17 ++- .../src/api/health/health.module.ts | 2 + apps/orchestrator/src/app.module.ts | 24 ++++ .../src/common/guards/throttler.guard.spec.ts | 122 ++++++++++++++++++ .../src/common/guards/throttler.guard.ts | 63 +++++++++ pnpm-lock.yaml | 5 + 8 files changed, 244 insertions(+), 4 deletions(-) create mode 100644 apps/orchestrator/src/common/guards/throttler.guard.spec.ts create mode 100644 apps/orchestrator/src/common/guards/throttler.guard.ts diff --git a/apps/orchestrator/package.json b/apps/orchestrator/package.json index 12287d8..4983a02 100644 --- a/apps/orchestrator/package.json +++ b/apps/orchestrator/package.json @@ -26,6 +26,7 @@ "@nestjs/config": "^4.0.2", "@nestjs/core": "^11.1.12", "@nestjs/platform-express": "^11.1.12", + "@nestjs/throttler": "^6.5.0", "bullmq": "^5.67.2", "class-transformer": "^0.5.1", "class-validator": "^0.14.1", diff --git a/apps/orchestrator/src/api/agents/agents.controller.ts b/apps/orchestrator/src/api/agents/agents.controller.ts index 69e4d90..3c0bd52 100644 --- a/apps/orchestrator/src/api/agents/agents.controller.ts +++ b/apps/orchestrator/src/api/agents/agents.controller.ts @@ -12,21 +12,28 @@ import { HttpCode, UseGuards, } from "@nestjs/common"; +import { Throttle } from "@nestjs/throttler"; import { QueueService } from "../../queue/queue.service"; import { AgentSpawnerService } from "../../spawner/agent-spawner.service"; import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service"; import { KillswitchService } from "../../killswitch/killswitch.service"; import { SpawnAgentDto, SpawnAgentResponseDto } from "./dto/spawn-agent.dto"; import { OrchestratorApiKeyGuard } from "../../common/guards/api-key.guard"; +import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard"; /** * Controller for agent management endpoints * * All endpoints require API key authentication via X-API-Key header. * Set ORCHESTRATOR_API_KEY environment variable to configure the expected key. + * + * Rate limits: + * - Status endpoints: 200 requests/minute + * - Spawn/kill endpoints: 10 requests/minute (strict) + * - Default: 100 requests/minute */ @Controller("agents") -@UseGuards(OrchestratorApiKeyGuard) +@UseGuards(OrchestratorApiKeyGuard, OrchestratorThrottlerGuard) export class AgentsController { private readonly logger = new Logger(AgentsController.name); @@ -43,6 +50,7 @@ export class AgentsController { * @returns Agent spawn response with agentId and status */ @Post("spawn") + @Throttle({ strict: { limit: 10, ttl: 60000 } }) @UsePipes(new ValidationPipe({ transform: true, whitelist: true })) async spawn(@Body() dto: SpawnAgentDto): Promise { this.logger.log(`Received spawn request for task: ${dto.taskId}`); @@ -81,6 +89,7 @@ export class AgentsController { * @returns Array of all agent sessions with their status */ @Get() + @Throttle({ status: { limit: 200, ttl: 60000 } }) listAgents(): { agentId: string; taskId: string; @@ -123,6 +132,7 @@ export class AgentsController { * @returns Agent status details */ @Get(":agentId/status") + @Throttle({ status: { limit: 200, ttl: 60000 } }) async getAgentStatus(@Param("agentId") agentId: string): Promise<{ agentId: string; taskId: string; @@ -181,6 +191,7 @@ export class AgentsController { * @returns Success message */ @Post(":agentId/kill") + @Throttle({ strict: { limit: 10, ttl: 60000 } }) @HttpCode(200) async killAgent(@Param("agentId") agentId: string): Promise<{ message: string }> { this.logger.warn(`Received kill request for agent: ${agentId}`); @@ -204,6 +215,7 @@ export class AgentsController { * @returns Summary of kill operation */ @Post("kill-all") + @Throttle({ strict: { limit: 10, ttl: 60000 } }) @HttpCode(200) async killAllAgents(): Promise<{ message: string; diff --git a/apps/orchestrator/src/api/health/health.controller.ts b/apps/orchestrator/src/api/health/health.controller.ts index 9401148..a0e0de6 100644 --- a/apps/orchestrator/src/api/health/health.controller.ts +++ b/apps/orchestrator/src/api/health/health.controller.ts @@ -1,12 +1,22 @@ -import { Controller, Get } from "@nestjs/common"; +import { Controller, Get, UseGuards } from "@nestjs/common"; +import { Throttle } from "@nestjs/throttler"; import { HealthService } from "./health.service"; +import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard"; +/** + * Health check controller for orchestrator service + * + * Rate limits: + * - Health endpoints: 200 requests/minute (higher for monitoring) + */ @Controller("health") +@UseGuards(OrchestratorThrottlerGuard) export class HealthController { constructor(private readonly healthService: HealthService) {} @Get() - check() { + @Throttle({ status: { limit: 200, ttl: 60000 } }) + check(): { status: string; uptime: number; timestamp: string } { return { status: "healthy", uptime: this.healthService.getUptime(), @@ -15,7 +25,8 @@ export class HealthController { } @Get("ready") - ready() { + @Throttle({ status: { limit: 200, ttl: 60000 } }) + ready(): { ready: boolean } { // NOTE: Check Valkey connection, Docker daemon (see issue #TBD) return { ready: true }; } diff --git a/apps/orchestrator/src/api/health/health.module.ts b/apps/orchestrator/src/api/health/health.module.ts index 40b7bdf..bf94834 100644 --- a/apps/orchestrator/src/api/health/health.module.ts +++ b/apps/orchestrator/src/api/health/health.module.ts @@ -1,7 +1,9 @@ import { Module } from "@nestjs/common"; import { HealthController } from "./health.controller"; +import { HealthService } from "./health.service"; @Module({ controllers: [HealthController], + providers: [HealthService], }) export class HealthModule {} diff --git a/apps/orchestrator/src/app.module.ts b/apps/orchestrator/src/app.module.ts index 55b7e24..5ff056a 100644 --- a/apps/orchestrator/src/app.module.ts +++ b/apps/orchestrator/src/app.module.ts @@ -1,12 +1,19 @@ import { Module } from "@nestjs/common"; import { ConfigModule } from "@nestjs/config"; import { BullModule } from "@nestjs/bullmq"; +import { ThrottlerModule } from "@nestjs/throttler"; import { HealthModule } from "./api/health/health.module"; import { AgentsModule } from "./api/agents/agents.module"; import { CoordinatorModule } from "./coordinator/coordinator.module"; import { BudgetModule } from "./budget/budget.module"; import { orchestratorConfig } from "./config/orchestrator.config"; +/** + * Rate limiting configuration: + * - 'default': Standard API endpoints (100 requests per minute) + * - 'strict': Spawn/kill endpoints (10 requests per minute) - prevents DoS + * - 'status': Status/health endpoints (200 requests per minute) - higher for polling + */ @Module({ imports: [ ConfigModule.forRoot({ @@ -19,6 +26,23 @@ import { orchestratorConfig } from "./config/orchestrator.config"; port: parseInt(process.env.VALKEY_PORT ?? "6379"), }, }), + ThrottlerModule.forRoot([ + { + name: "default", + ttl: 60000, // 1 minute + limit: 100, // 100 requests per minute + }, + { + name: "strict", + ttl: 60000, // 1 minute + limit: 10, // 10 requests per minute for spawn/kill + }, + { + name: "status", + ttl: 60000, // 1 minute + limit: 200, // 200 requests per minute for status endpoints + }, + ]), HealthModule, AgentsModule, CoordinatorModule, diff --git a/apps/orchestrator/src/common/guards/throttler.guard.spec.ts b/apps/orchestrator/src/common/guards/throttler.guard.spec.ts new file mode 100644 index 0000000..53cf169 --- /dev/null +++ b/apps/orchestrator/src/common/guards/throttler.guard.spec.ts @@ -0,0 +1,122 @@ +import { describe, it, expect, beforeEach, vi } from "vitest"; +import { ExecutionContext } from "@nestjs/common"; +import { ThrottlerException, ThrottlerModuleOptions, ThrottlerStorage } from "@nestjs/throttler"; +import { Reflector } from "@nestjs/core"; +import { OrchestratorThrottlerGuard } from "./throttler.guard"; + +describe("OrchestratorThrottlerGuard", () => { + let guard: OrchestratorThrottlerGuard; + + beforeEach(() => { + // Create guard with minimal mocks for testing protected methods + const options: ThrottlerModuleOptions = { + throttlers: [{ name: "default", ttl: 60000, limit: 100 }], + }; + const storageService = {} as ThrottlerStorage; + const reflector = {} as Reflector; + + guard = new OrchestratorThrottlerGuard(options, storageService, reflector); + }); + + describe("getTracker", () => { + it("should extract IP from X-Forwarded-For header", async () => { + const req = { + headers: { + "x-forwarded-for": "192.168.1.1, 10.0.0.1", + }, + ip: "127.0.0.1", + }; + + // Access protected method for testing + const tracker = await ( + guard as unknown as { getTracker: (req: unknown) => Promise } + ).getTracker(req); + + expect(tracker).toBe("192.168.1.1"); + }); + + it("should handle X-Forwarded-For as array", async () => { + const req = { + headers: { + "x-forwarded-for": ["192.168.1.1, 10.0.0.1"], + }, + ip: "127.0.0.1", + }; + + const tracker = await ( + guard as unknown as { getTracker: (req: unknown) => Promise } + ).getTracker(req); + + expect(tracker).toBe("192.168.1.1"); + }); + + it("should fallback to request IP when no X-Forwarded-For", async () => { + const req = { + headers: {}, + ip: "192.168.2.2", + }; + + const tracker = await ( + guard as unknown as { getTracker: (req: unknown) => Promise } + ).getTracker(req); + + expect(tracker).toBe("192.168.2.2"); + }); + + it("should fallback to connection remoteAddress when no IP", async () => { + const req = { + headers: {}, + connection: { + remoteAddress: "192.168.3.3", + }, + }; + + const tracker = await ( + guard as unknown as { getTracker: (req: unknown) => Promise } + ).getTracker(req); + + expect(tracker).toBe("192.168.3.3"); + }); + + it("should return 'unknown' when no IP available", async () => { + const req = { + headers: {}, + }; + + const tracker = await ( + guard as unknown as { getTracker: (req: unknown) => Promise } + ).getTracker(req); + + expect(tracker).toBe("unknown"); + }); + }); + + describe("throwThrottlingException", () => { + it("should throw ThrottlerException with endpoint info", () => { + const mockRequest = { + url: "/agents/spawn", + }; + + const mockContext = { + switchToHttp: vi.fn().mockReturnValue({ + getRequest: vi.fn().mockReturnValue(mockRequest), + }), + } as unknown as ExecutionContext; + + expect(() => { + ( + guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void } + ).throwThrottlingException(mockContext); + }).toThrow(ThrottlerException); + + try { + ( + guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void } + ).throwThrottlingException(mockContext); + } catch (error) { + expect(error).toBeInstanceOf(ThrottlerException); + expect((error as ThrottlerException).message).toContain("/agents/spawn"); + } + }); + }); +}); diff --git a/apps/orchestrator/src/common/guards/throttler.guard.ts b/apps/orchestrator/src/common/guards/throttler.guard.ts new file mode 100644 index 0000000..3158cb6 --- /dev/null +++ b/apps/orchestrator/src/common/guards/throttler.guard.ts @@ -0,0 +1,63 @@ +import { Injectable, ExecutionContext } from "@nestjs/common"; +import { ThrottlerGuard, ThrottlerException } from "@nestjs/throttler"; + +interface RequestWithHeaders { + headers?: Record; + ip?: string; + connection?: { remoteAddress?: string }; + url?: string; +} + +/** + * OrchestratorThrottlerGuard - Rate limiting guard for orchestrator API endpoints + * + * Uses the X-Forwarded-For header for client IP identification when behind a proxy, + * falling back to the direct connection IP. + * + * Usage: + * @UseGuards(OrchestratorThrottlerGuard) + * @Controller('agents') + * export class AgentsController { ... } + */ +@Injectable() +export class OrchestratorThrottlerGuard extends ThrottlerGuard { + /** + * Get the client IP address for rate limiting tracking + * Prioritizes X-Forwarded-For header for proxy setups + */ + protected getTracker(req: Record): Promise { + const request = req as RequestWithHeaders; + const headers = request.headers; + + // Check X-Forwarded-For header first (for proxied requests) + if (headers) { + const forwardedFor = headers["x-forwarded-for"]; + if (forwardedFor) { + // Get the first IP in the chain (original client) + const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor; + if (ips) { + const clientIp = ips.split(",")[0]?.trim(); + if (clientIp) { + return Promise.resolve(clientIp); + } + } + } + } + + // Fallback to direct connection IP + const ip = request.ip ?? request.connection?.remoteAddress ?? "unknown"; + return Promise.resolve(ip); + } + + /** + * Custom error message for rate limit exceeded + */ + protected throwThrottlingException(context: ExecutionContext): Promise { + const request = context.switchToHttp().getRequest(); + const endpoint = request.url ?? "unknown"; + + throw new ThrottlerException( + `Rate limit exceeded for endpoint ${endpoint}. Please try again later.` + ); + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index eecabe9..2cf9137 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -292,6 +292,9 @@ importers: '@nestjs/platform-express': specifier: ^11.1.12 version: 11.1.12(@nestjs/common@11.1.12(class-transformer@0.5.1)(class-validator@0.14.3)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.12) + '@nestjs/throttler': + specifier: ^6.5.0 + version: 6.5.0(@nestjs/common@11.1.12(class-transformer@0.5.1)(class-validator@0.14.3)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.12)(reflect-metadata@0.2.2) bullmq: specifier: ^5.67.2 version: 5.67.2 @@ -454,6 +457,8 @@ importers: specifier: ^3.0.8 version: 3.2.4(@types/node@22.19.7)(jiti@2.6.1)(jsdom@26.1.0)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2) + packages/cli-tools: {} + packages/config: dependencies: '@eslint/js':