fix(#338): Add rate limiting to orchestrator API

- Add @nestjs/throttler for rate limiting support
- Configure multiple throttle profiles: default (100/min), strict (10/min for spawn/kill), status (200/min for polling)
- Apply strict rate limits to spawn and kill endpoints to prevent DoS
- Apply higher rate limits to status/health endpoints for monitoring
- Add OrchestratorThrottlerGuard with X-Forwarded-For support for proxy setups
- Add unit tests for throttler guard

Refs #338

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Jason Woltje
2026-02-05 18:26:50 -06:00
parent 3f16bbeca1
commit ce7fb27c46
8 changed files with 244 additions and 4 deletions

View File

@@ -12,21 +12,28 @@ import {
HttpCode,
UseGuards,
} from "@nestjs/common";
import { Throttle } from "@nestjs/throttler";
import { QueueService } from "../../queue/queue.service";
import { AgentSpawnerService } from "../../spawner/agent-spawner.service";
import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service";
import { KillswitchService } from "../../killswitch/killswitch.service";
import { SpawnAgentDto, SpawnAgentResponseDto } from "./dto/spawn-agent.dto";
import { OrchestratorApiKeyGuard } from "../../common/guards/api-key.guard";
import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard";
/**
* Controller for agent management endpoints
*
* All endpoints require API key authentication via X-API-Key header.
* Set ORCHESTRATOR_API_KEY environment variable to configure the expected key.
*
* Rate limits:
* - Status endpoints: 200 requests/minute
* - Spawn/kill endpoints: 10 requests/minute (strict)
* - Default: 100 requests/minute
*/
@Controller("agents")
@UseGuards(OrchestratorApiKeyGuard)
@UseGuards(OrchestratorApiKeyGuard, OrchestratorThrottlerGuard)
export class AgentsController {
private readonly logger = new Logger(AgentsController.name);
@@ -43,6 +50,7 @@ export class AgentsController {
* @returns Agent spawn response with agentId and status
*/
@Post("spawn")
@Throttle({ strict: { limit: 10, ttl: 60000 } })
@UsePipes(new ValidationPipe({ transform: true, whitelist: true }))
async spawn(@Body() dto: SpawnAgentDto): Promise<SpawnAgentResponseDto> {
this.logger.log(`Received spawn request for task: ${dto.taskId}`);
@@ -81,6 +89,7 @@ export class AgentsController {
* @returns Array of all agent sessions with their status
*/
@Get()
@Throttle({ status: { limit: 200, ttl: 60000 } })
listAgents(): {
agentId: string;
taskId: string;
@@ -123,6 +132,7 @@ export class AgentsController {
* @returns Agent status details
*/
@Get(":agentId/status")
@Throttle({ status: { limit: 200, ttl: 60000 } })
async getAgentStatus(@Param("agentId") agentId: string): Promise<{
agentId: string;
taskId: string;
@@ -181,6 +191,7 @@ export class AgentsController {
* @returns Success message
*/
@Post(":agentId/kill")
@Throttle({ strict: { limit: 10, ttl: 60000 } })
@HttpCode(200)
async killAgent(@Param("agentId") agentId: string): Promise<{ message: string }> {
this.logger.warn(`Received kill request for agent: ${agentId}`);
@@ -204,6 +215,7 @@ export class AgentsController {
* @returns Summary of kill operation
*/
@Post("kill-all")
@Throttle({ strict: { limit: 10, ttl: 60000 } })
@HttpCode(200)
async killAllAgents(): Promise<{
message: string;

View File

@@ -1,12 +1,22 @@
import { Controller, Get } from "@nestjs/common";
import { Controller, Get, UseGuards } from "@nestjs/common";
import { Throttle } from "@nestjs/throttler";
import { HealthService } from "./health.service";
import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard";
/**
* Health check controller for orchestrator service
*
* Rate limits:
* - Health endpoints: 200 requests/minute (higher for monitoring)
*/
@Controller("health")
@UseGuards(OrchestratorThrottlerGuard)
export class HealthController {
constructor(private readonly healthService: HealthService) {}
@Get()
check() {
@Throttle({ status: { limit: 200, ttl: 60000 } })
check(): { status: string; uptime: number; timestamp: string } {
return {
status: "healthy",
uptime: this.healthService.getUptime(),
@@ -15,7 +25,8 @@ export class HealthController {
}
@Get("ready")
ready() {
@Throttle({ status: { limit: 200, ttl: 60000 } })
ready(): { ready: boolean } {
// NOTE: Check Valkey connection, Docker daemon (see issue #TBD)
return { ready: true };
}

View File

@@ -1,7 +1,9 @@
import { Module } from "@nestjs/common";
import { HealthController } from "./health.controller";
import { HealthService } from "./health.service";
/**
 * Module wiring for the /health endpoints.
 *
 * Exposes HealthController (GET /health, GET /health/ready) and provides
 * HealthService, which the controller uses for uptime reporting.
 */
@Module({
  controllers: [HealthController],
  providers: [HealthService],
})
export class HealthModule {}

View File

@@ -1,12 +1,19 @@
import { Module } from "@nestjs/common";
import { ConfigModule } from "@nestjs/config";
import { BullModule } from "@nestjs/bullmq";
import { ThrottlerModule } from "@nestjs/throttler";
import { HealthModule } from "./api/health/health.module";
import { AgentsModule } from "./api/agents/agents.module";
import { CoordinatorModule } from "./coordinator/coordinator.module";
import { BudgetModule } from "./budget/budget.module";
import { orchestratorConfig } from "./config/orchestrator.config";
/**
* Rate limiting configuration:
* - 'default': Standard API endpoints (100 requests per minute)
* - 'strict': Spawn/kill endpoints (10 requests per minute) - prevents DoS
* - 'status': Status/health endpoints (200 requests per minute) - higher for polling
*/
@Module({
imports: [
ConfigModule.forRoot({
@@ -19,6 +26,23 @@ import { orchestratorConfig } from "./config/orchestrator.config";
port: parseInt(process.env.VALKEY_PORT ?? "6379"),
},
}),
ThrottlerModule.forRoot([
{
name: "default",
ttl: 60000, // 1 minute
limit: 100, // 100 requests per minute
},
{
name: "strict",
ttl: 60000, // 1 minute
limit: 10, // 10 requests per minute for spawn/kill
},
{
name: "status",
ttl: 60000, // 1 minute
limit: 200, // 200 requests per minute for status endpoints
},
]),
HealthModule,
AgentsModule,
CoordinatorModule,

View File

@@ -0,0 +1,122 @@
import { describe, it, expect, beforeEach, vi } from "vitest";
import { ExecutionContext } from "@nestjs/common";
import { ThrottlerException, ThrottlerModuleOptions, ThrottlerStorage } from "@nestjs/throttler";
import { Reflector } from "@nestjs/core";
import { OrchestratorThrottlerGuard } from "./throttler.guard";
describe("OrchestratorThrottlerGuard", () => {
  let guard: OrchestratorThrottlerGuard;

  beforeEach(() => {
    // Minimal construction: these tests exercise only the protected helper
    // methods, which never touch the storage backend or reflector metadata,
    // so empty objects are sufficient stand-ins.
    const options: ThrottlerModuleOptions = {
      throttlers: [{ name: "default", ttl: 60000, limit: 100 }],
    };
    const storageService = {} as ThrottlerStorage;
    const reflector = {} as Reflector;
    guard = new OrchestratorThrottlerGuard(options, storageService, reflector);
  });

  // Single typed escape hatch for the protected getTracker method instead of
  // repeating the unchecked cast in every test case.
  const getTracker = (req: unknown): Promise<string> =>
    (guard as unknown as { getTracker: (r: unknown) => Promise<string> }).getTracker(req);

  describe("getTracker", () => {
    it("should extract IP from X-Forwarded-For header", async () => {
      const req = {
        headers: { "x-forwarded-for": "192.168.1.1, 10.0.0.1" },
        ip: "127.0.0.1",
      };
      expect(await getTracker(req)).toBe("192.168.1.1");
    });

    it("should handle X-Forwarded-For as array", async () => {
      const req = {
        headers: { "x-forwarded-for": ["192.168.1.1, 10.0.0.1"] },
        ip: "127.0.0.1",
      };
      expect(await getTracker(req)).toBe("192.168.1.1");
    });

    it("should fallback to request IP when no X-Forwarded-For", async () => {
      expect(await getTracker({ headers: {}, ip: "192.168.2.2" })).toBe("192.168.2.2");
    });

    it("should fallback to connection remoteAddress when no IP", async () => {
      const req = {
        headers: {},
        connection: { remoteAddress: "192.168.3.3" },
      };
      expect(await getTracker(req)).toBe("192.168.3.3");
    });

    it("should return 'unknown' when no IP available", async () => {
      expect(await getTracker({ headers: {} })).toBe("unknown");
    });
  });

  describe("throwThrottlingException", () => {
    it("should throw ThrottlerException with endpoint info", () => {
      const mockRequest = { url: "/agents/spawn" };
      const mockContext = {
        switchToHttp: vi.fn().mockReturnValue({
          getRequest: vi.fn().mockReturnValue(mockRequest),
        }),
      } as unknown as ExecutionContext;

      const invoke = (): void =>
        (
          guard as unknown as {
            throwThrottlingException: (context: ExecutionContext) => void;
          }
        ).throwThrottlingException(mockContext);

      // toThrow(class) asserts the error type; toThrow(string) asserts the
      // message contains the substring. This replaces the previous manual
      // try/catch re-invocation, whose assertions were silently skipped
      // whenever the second call did not throw.
      expect(invoke).toThrow(ThrottlerException);
      expect(invoke).toThrow("/agents/spawn");
    });
  });
});

View File

@@ -0,0 +1,63 @@
import { Injectable, ExecutionContext } from "@nestjs/common";
import { ThrottlerGuard, ThrottlerException } from "@nestjs/throttler";
interface RequestWithHeaders {
  headers?: Record<string, string | string[]>;
  ip?: string;
  connection?: { remoteAddress?: string };
  url?: string;
}

/**
 * OrchestratorThrottlerGuard - Rate limiting guard for orchestrator API endpoints
 *
 * Resolves the client identity from the X-Forwarded-For header when present
 * (first entry in the chain, i.e. the original client), otherwise from the
 * direct connection.
 *
 * NOTE(review): X-Forwarded-For is client-supplied unless a trusted proxy
 * overwrites it; if this service is ever exposed without a proxy in front,
 * callers could spoof the header to evade the limit — confirm the deployment
 * topology.
 *
 * Usage:
 *   @UseGuards(OrchestratorThrottlerGuard)
 *   @Controller('agents')
 *   export class AgentsController { ... }
 */
@Injectable()
export class OrchestratorThrottlerGuard extends ThrottlerGuard {
  /**
   * Resolve the key used to bucket requests for rate limiting.
   * Preference order: first X-Forwarded-For hop, then request.ip, then the
   * socket remote address, then the literal "unknown".
   */
  protected getTracker(req: Record<string, unknown>): Promise<string> {
    const request = req as RequestWithHeaders;
    const forwarded = request.headers?.["x-forwarded-for"];
    // The header may arrive as a single string or an array of strings; the
    // original client is the first comma-separated entry of the first value.
    const firstValue = Array.isArray(forwarded) ? forwarded[0] : forwarded;
    const clientIp = firstValue?.split(",")[0]?.trim();
    if (clientIp) {
      return Promise.resolve(clientIp);
    }
    // No usable forwarded header: fall back to the direct connection.
    return Promise.resolve(request.ip ?? request.connection?.remoteAddress ?? "unknown");
  }

  /**
   * Raise a ThrottlerException whose message names the throttled endpoint,
   * so 429 responses tell clients which route tripped the limit.
   */
  protected throwThrottlingException(context: ExecutionContext): Promise<void> {
    const { url } = context.switchToHttp().getRequest<RequestWithHeaders>();
    throw new ThrottlerException(
      `Rate limit exceeded for endpoint ${url ?? "unknown"}. Please try again later.`
    );
  }
}