fix(#338): Add rate limiting to orchestrator API
- Add @nestjs/throttler for rate limiting support - Configure multiple throttle profiles: default (100/min), strict (10/min for spawn/kill), status (200/min for polling) - Apply strict rate limits to spawn and kill endpoints to prevent DoS - Apply higher rate limits to status/health endpoints for monitoring - Add OrchestratorThrottlerGuard with X-Forwarded-For support for proxy setups - Add unit tests for throttler guard Refs #338 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -26,6 +26,7 @@
|
|||||||
"@nestjs/config": "^4.0.2",
|
"@nestjs/config": "^4.0.2",
|
||||||
"@nestjs/core": "^11.1.12",
|
"@nestjs/core": "^11.1.12",
|
||||||
"@nestjs/platform-express": "^11.1.12",
|
"@nestjs/platform-express": "^11.1.12",
|
||||||
|
"@nestjs/throttler": "^6.5.0",
|
||||||
"bullmq": "^5.67.2",
|
"bullmq": "^5.67.2",
|
||||||
"class-transformer": "^0.5.1",
|
"class-transformer": "^0.5.1",
|
||||||
"class-validator": "^0.14.1",
|
"class-validator": "^0.14.1",
|
||||||
|
|||||||
@@ -12,21 +12,28 @@ import {
|
|||||||
HttpCode,
|
HttpCode,
|
||||||
UseGuards,
|
UseGuards,
|
||||||
} from "@nestjs/common";
|
} from "@nestjs/common";
|
||||||
|
import { Throttle } from "@nestjs/throttler";
|
||||||
import { QueueService } from "../../queue/queue.service";
|
import { QueueService } from "../../queue/queue.service";
|
||||||
import { AgentSpawnerService } from "../../spawner/agent-spawner.service";
|
import { AgentSpawnerService } from "../../spawner/agent-spawner.service";
|
||||||
import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service";
|
import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service";
|
||||||
import { KillswitchService } from "../../killswitch/killswitch.service";
|
import { KillswitchService } from "../../killswitch/killswitch.service";
|
||||||
import { SpawnAgentDto, SpawnAgentResponseDto } from "./dto/spawn-agent.dto";
|
import { SpawnAgentDto, SpawnAgentResponseDto } from "./dto/spawn-agent.dto";
|
||||||
import { OrchestratorApiKeyGuard } from "../../common/guards/api-key.guard";
|
import { OrchestratorApiKeyGuard } from "../../common/guards/api-key.guard";
|
||||||
|
import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Controller for agent management endpoints
|
* Controller for agent management endpoints
|
||||||
*
|
*
|
||||||
* All endpoints require API key authentication via X-API-Key header.
|
* All endpoints require API key authentication via X-API-Key header.
|
||||||
* Set ORCHESTRATOR_API_KEY environment variable to configure the expected key.
|
* Set ORCHESTRATOR_API_KEY environment variable to configure the expected key.
|
||||||
|
*
|
||||||
|
* Rate limits:
|
||||||
|
* - Status endpoints: 200 requests/minute
|
||||||
|
* - Spawn/kill endpoints: 10 requests/minute (strict)
|
||||||
|
* - Default: 100 requests/minute
|
||||||
*/
|
*/
|
||||||
@Controller("agents")
|
@Controller("agents")
|
||||||
@UseGuards(OrchestratorApiKeyGuard)
|
@UseGuards(OrchestratorApiKeyGuard, OrchestratorThrottlerGuard)
|
||||||
export class AgentsController {
|
export class AgentsController {
|
||||||
private readonly logger = new Logger(AgentsController.name);
|
private readonly logger = new Logger(AgentsController.name);
|
||||||
|
|
||||||
@@ -43,6 +50,7 @@ export class AgentsController {
|
|||||||
* @returns Agent spawn response with agentId and status
|
* @returns Agent spawn response with agentId and status
|
||||||
*/
|
*/
|
||||||
@Post("spawn")
|
@Post("spawn")
|
||||||
|
@Throttle({ strict: { limit: 10, ttl: 60000 } })
|
||||||
@UsePipes(new ValidationPipe({ transform: true, whitelist: true }))
|
@UsePipes(new ValidationPipe({ transform: true, whitelist: true }))
|
||||||
async spawn(@Body() dto: SpawnAgentDto): Promise<SpawnAgentResponseDto> {
|
async spawn(@Body() dto: SpawnAgentDto): Promise<SpawnAgentResponseDto> {
|
||||||
this.logger.log(`Received spawn request for task: ${dto.taskId}`);
|
this.logger.log(`Received spawn request for task: ${dto.taskId}`);
|
||||||
@@ -81,6 +89,7 @@ export class AgentsController {
|
|||||||
* @returns Array of all agent sessions with their status
|
* @returns Array of all agent sessions with their status
|
||||||
*/
|
*/
|
||||||
@Get()
|
@Get()
|
||||||
|
@Throttle({ status: { limit: 200, ttl: 60000 } })
|
||||||
listAgents(): {
|
listAgents(): {
|
||||||
agentId: string;
|
agentId: string;
|
||||||
taskId: string;
|
taskId: string;
|
||||||
@@ -123,6 +132,7 @@ export class AgentsController {
|
|||||||
* @returns Agent status details
|
* @returns Agent status details
|
||||||
*/
|
*/
|
||||||
@Get(":agentId/status")
|
@Get(":agentId/status")
|
||||||
|
@Throttle({ status: { limit: 200, ttl: 60000 } })
|
||||||
async getAgentStatus(@Param("agentId") agentId: string): Promise<{
|
async getAgentStatus(@Param("agentId") agentId: string): Promise<{
|
||||||
agentId: string;
|
agentId: string;
|
||||||
taskId: string;
|
taskId: string;
|
||||||
@@ -181,6 +191,7 @@ export class AgentsController {
|
|||||||
* @returns Success message
|
* @returns Success message
|
||||||
*/
|
*/
|
||||||
@Post(":agentId/kill")
|
@Post(":agentId/kill")
|
||||||
|
@Throttle({ strict: { limit: 10, ttl: 60000 } })
|
||||||
@HttpCode(200)
|
@HttpCode(200)
|
||||||
async killAgent(@Param("agentId") agentId: string): Promise<{ message: string }> {
|
async killAgent(@Param("agentId") agentId: string): Promise<{ message: string }> {
|
||||||
this.logger.warn(`Received kill request for agent: ${agentId}`);
|
this.logger.warn(`Received kill request for agent: ${agentId}`);
|
||||||
@@ -204,6 +215,7 @@ export class AgentsController {
|
|||||||
* @returns Summary of kill operation
|
* @returns Summary of kill operation
|
||||||
*/
|
*/
|
||||||
@Post("kill-all")
|
@Post("kill-all")
|
||||||
|
@Throttle({ strict: { limit: 10, ttl: 60000 } })
|
||||||
@HttpCode(200)
|
@HttpCode(200)
|
||||||
async killAllAgents(): Promise<{
|
async killAllAgents(): Promise<{
|
||||||
message: string;
|
message: string;
|
||||||
|
|||||||
@@ -1,12 +1,22 @@
|
|||||||
import { Controller, Get } from "@nestjs/common";
|
import { Controller, Get, UseGuards } from "@nestjs/common";
|
||||||
|
import { Throttle } from "@nestjs/throttler";
|
||||||
import { HealthService } from "./health.service";
|
import { HealthService } from "./health.service";
|
||||||
|
import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Health check controller for orchestrator service
|
||||||
|
*
|
||||||
|
* Rate limits:
|
||||||
|
* - Health endpoints: 200 requests/minute (higher for monitoring)
|
||||||
|
*/
|
||||||
@Controller("health")
|
@Controller("health")
|
||||||
|
@UseGuards(OrchestratorThrottlerGuard)
|
||||||
export class HealthController {
|
export class HealthController {
|
||||||
constructor(private readonly healthService: HealthService) {}
|
constructor(private readonly healthService: HealthService) {}
|
||||||
|
|
||||||
@Get()
|
@Get()
|
||||||
check() {
|
@Throttle({ status: { limit: 200, ttl: 60000 } })
|
||||||
|
check(): { status: string; uptime: number; timestamp: string } {
|
||||||
return {
|
return {
|
||||||
status: "healthy",
|
status: "healthy",
|
||||||
uptime: this.healthService.getUptime(),
|
uptime: this.healthService.getUptime(),
|
||||||
@@ -15,7 +25,8 @@ export class HealthController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Get("ready")
|
@Get("ready")
|
||||||
ready() {
|
@Throttle({ status: { limit: 200, ttl: 60000 } })
|
||||||
|
ready(): { ready: boolean } {
|
||||||
// NOTE: Check Valkey connection, Docker daemon (see issue #TBD)
|
// NOTE: Check Valkey connection, Docker daemon (see issue #TBD)
|
||||||
return { ready: true };
|
return { ready: true };
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import { Module } from "@nestjs/common";
|
import { Module } from "@nestjs/common";
|
||||||
import { HealthController } from "./health.controller";
|
import { HealthController } from "./health.controller";
|
||||||
|
import { HealthService } from "./health.service";
|
||||||
|
|
||||||
@Module({
|
@Module({
|
||||||
controllers: [HealthController],
|
controllers: [HealthController],
|
||||||
|
providers: [HealthService],
|
||||||
})
|
})
|
||||||
export class HealthModule {}
|
export class HealthModule {}
|
||||||
|
|||||||
@@ -1,12 +1,19 @@
|
|||||||
import { Module } from "@nestjs/common";
|
import { Module } from "@nestjs/common";
|
||||||
import { ConfigModule } from "@nestjs/config";
|
import { ConfigModule } from "@nestjs/config";
|
||||||
import { BullModule } from "@nestjs/bullmq";
|
import { BullModule } from "@nestjs/bullmq";
|
||||||
|
import { ThrottlerModule } from "@nestjs/throttler";
|
||||||
import { HealthModule } from "./api/health/health.module";
|
import { HealthModule } from "./api/health/health.module";
|
||||||
import { AgentsModule } from "./api/agents/agents.module";
|
import { AgentsModule } from "./api/agents/agents.module";
|
||||||
import { CoordinatorModule } from "./coordinator/coordinator.module";
|
import { CoordinatorModule } from "./coordinator/coordinator.module";
|
||||||
import { BudgetModule } from "./budget/budget.module";
|
import { BudgetModule } from "./budget/budget.module";
|
||||||
import { orchestratorConfig } from "./config/orchestrator.config";
|
import { orchestratorConfig } from "./config/orchestrator.config";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rate limiting configuration:
|
||||||
|
* - 'default': Standard API endpoints (100 requests per minute)
|
||||||
|
* - 'strict': Spawn/kill endpoints (10 requests per minute) - prevents DoS
|
||||||
|
* - 'status': Status/health endpoints (200 requests per minute) - higher for polling
|
||||||
|
*/
|
||||||
@Module({
|
@Module({
|
||||||
imports: [
|
imports: [
|
||||||
ConfigModule.forRoot({
|
ConfigModule.forRoot({
|
||||||
@@ -19,6 +26,23 @@ import { orchestratorConfig } from "./config/orchestrator.config";
|
|||||||
port: parseInt(process.env.VALKEY_PORT ?? "6379"),
|
port: parseInt(process.env.VALKEY_PORT ?? "6379"),
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
ThrottlerModule.forRoot([
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
ttl: 60000, // 1 minute
|
||||||
|
limit: 100, // 100 requests per minute
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "strict",
|
||||||
|
ttl: 60000, // 1 minute
|
||||||
|
limit: 10, // 10 requests per minute for spawn/kill
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "status",
|
||||||
|
ttl: 60000, // 1 minute
|
||||||
|
limit: 200, // 200 requests per minute for status endpoints
|
||||||
|
},
|
||||||
|
]),
|
||||||
HealthModule,
|
HealthModule,
|
||||||
AgentsModule,
|
AgentsModule,
|
||||||
CoordinatorModule,
|
CoordinatorModule,
|
||||||
|
|||||||
122
apps/orchestrator/src/common/guards/throttler.guard.spec.ts
Normal file
122
apps/orchestrator/src/common/guards/throttler.guard.spec.ts
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||||
|
import { ExecutionContext } from "@nestjs/common";
|
||||||
|
import { ThrottlerException, ThrottlerModuleOptions, ThrottlerStorage } from "@nestjs/throttler";
|
||||||
|
import { Reflector } from "@nestjs/core";
|
||||||
|
import { OrchestratorThrottlerGuard } from "./throttler.guard";
|
||||||
|
|
||||||
|
describe("OrchestratorThrottlerGuard", () => {
|
||||||
|
let guard: OrchestratorThrottlerGuard;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Create guard with minimal mocks for testing protected methods
|
||||||
|
const options: ThrottlerModuleOptions = {
|
||||||
|
throttlers: [{ name: "default", ttl: 60000, limit: 100 }],
|
||||||
|
};
|
||||||
|
const storageService = {} as ThrottlerStorage;
|
||||||
|
const reflector = {} as Reflector;
|
||||||
|
|
||||||
|
guard = new OrchestratorThrottlerGuard(options, storageService, reflector);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("getTracker", () => {
|
||||||
|
it("should extract IP from X-Forwarded-For header", async () => {
|
||||||
|
const req = {
|
||||||
|
headers: {
|
||||||
|
"x-forwarded-for": "192.168.1.1, 10.0.0.1",
|
||||||
|
},
|
||||||
|
ip: "127.0.0.1",
|
||||||
|
};
|
||||||
|
|
||||||
|
// Access protected method for testing
|
||||||
|
const tracker = await (
|
||||||
|
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||||
|
).getTracker(req);
|
||||||
|
|
||||||
|
expect(tracker).toBe("192.168.1.1");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should handle X-Forwarded-For as array", async () => {
|
||||||
|
const req = {
|
||||||
|
headers: {
|
||||||
|
"x-forwarded-for": ["192.168.1.1, 10.0.0.1"],
|
||||||
|
},
|
||||||
|
ip: "127.0.0.1",
|
||||||
|
};
|
||||||
|
|
||||||
|
const tracker = await (
|
||||||
|
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||||
|
).getTracker(req);
|
||||||
|
|
||||||
|
expect(tracker).toBe("192.168.1.1");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should fallback to request IP when no X-Forwarded-For", async () => {
|
||||||
|
const req = {
|
||||||
|
headers: {},
|
||||||
|
ip: "192.168.2.2",
|
||||||
|
};
|
||||||
|
|
||||||
|
const tracker = await (
|
||||||
|
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||||
|
).getTracker(req);
|
||||||
|
|
||||||
|
expect(tracker).toBe("192.168.2.2");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should fallback to connection remoteAddress when no IP", async () => {
|
||||||
|
const req = {
|
||||||
|
headers: {},
|
||||||
|
connection: {
|
||||||
|
remoteAddress: "192.168.3.3",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const tracker = await (
|
||||||
|
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||||
|
).getTracker(req);
|
||||||
|
|
||||||
|
expect(tracker).toBe("192.168.3.3");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should return 'unknown' when no IP available", async () => {
|
||||||
|
const req = {
|
||||||
|
headers: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
const tracker = await (
|
||||||
|
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||||
|
).getTracker(req);
|
||||||
|
|
||||||
|
expect(tracker).toBe("unknown");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("throwThrottlingException", () => {
|
||||||
|
it("should throw ThrottlerException with endpoint info", () => {
|
||||||
|
const mockRequest = {
|
||||||
|
url: "/agents/spawn",
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockContext = {
|
||||||
|
switchToHttp: vi.fn().mockReturnValue({
|
||||||
|
getRequest: vi.fn().mockReturnValue(mockRequest),
|
||||||
|
}),
|
||||||
|
} as unknown as ExecutionContext;
|
||||||
|
|
||||||
|
expect(() => {
|
||||||
|
(
|
||||||
|
guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void }
|
||||||
|
).throwThrottlingException(mockContext);
|
||||||
|
}).toThrow(ThrottlerException);
|
||||||
|
|
||||||
|
try {
|
||||||
|
(
|
||||||
|
guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void }
|
||||||
|
).throwThrottlingException(mockContext);
|
||||||
|
} catch (error) {
|
||||||
|
expect(error).toBeInstanceOf(ThrottlerException);
|
||||||
|
expect((error as ThrottlerException).message).toContain("/agents/spawn");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
63
apps/orchestrator/src/common/guards/throttler.guard.ts
Normal file
63
apps/orchestrator/src/common/guards/throttler.guard.ts
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import { Injectable, ExecutionContext } from "@nestjs/common";
|
||||||
|
import { ThrottlerGuard, ThrottlerException } from "@nestjs/throttler";
|
||||||
|
|
||||||
|
interface RequestWithHeaders {
|
||||||
|
headers?: Record<string, string | string[]>;
|
||||||
|
ip?: string;
|
||||||
|
connection?: { remoteAddress?: string };
|
||||||
|
url?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OrchestratorThrottlerGuard - Rate limiting guard for orchestrator API endpoints
|
||||||
|
*
|
||||||
|
* Uses the X-Forwarded-For header for client IP identification when behind a proxy,
|
||||||
|
* falling back to the direct connection IP.
|
||||||
|
*
|
||||||
|
* Usage:
|
||||||
|
* @UseGuards(OrchestratorThrottlerGuard)
|
||||||
|
* @Controller('agents')
|
||||||
|
* export class AgentsController { ... }
|
||||||
|
*/
|
||||||
|
@Injectable()
|
||||||
|
export class OrchestratorThrottlerGuard extends ThrottlerGuard {
|
||||||
|
/**
|
||||||
|
* Get the client IP address for rate limiting tracking
|
||||||
|
* Prioritizes X-Forwarded-For header for proxy setups
|
||||||
|
*/
|
||||||
|
protected getTracker(req: Record<string, unknown>): Promise<string> {
|
||||||
|
const request = req as RequestWithHeaders;
|
||||||
|
const headers = request.headers;
|
||||||
|
|
||||||
|
// Check X-Forwarded-For header first (for proxied requests)
|
||||||
|
if (headers) {
|
||||||
|
const forwardedFor = headers["x-forwarded-for"];
|
||||||
|
if (forwardedFor) {
|
||||||
|
// Get the first IP in the chain (original client)
|
||||||
|
const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor;
|
||||||
|
if (ips) {
|
||||||
|
const clientIp = ips.split(",")[0]?.trim();
|
||||||
|
if (clientIp) {
|
||||||
|
return Promise.resolve(clientIp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to direct connection IP
|
||||||
|
const ip = request.ip ?? request.connection?.remoteAddress ?? "unknown";
|
||||||
|
return Promise.resolve(ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Custom error message for rate limit exceeded
|
||||||
|
*/
|
||||||
|
protected throwThrottlingException(context: ExecutionContext): Promise<void> {
|
||||||
|
const request = context.switchToHttp().getRequest<RequestWithHeaders>();
|
||||||
|
const endpoint = request.url ?? "unknown";
|
||||||
|
|
||||||
|
throw new ThrottlerException(
|
||||||
|
`Rate limit exceeded for endpoint ${endpoint}. Please try again later.`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
5
pnpm-lock.yaml
generated
5
pnpm-lock.yaml
generated
@@ -292,6 +292,9 @@ importers:
|
|||||||
'@nestjs/platform-express':
|
'@nestjs/platform-express':
|
||||||
specifier: ^11.1.12
|
specifier: ^11.1.12
|
||||||
version: 11.1.12(@nestjs/common@11.1.12(class-transformer@0.5.1)(class-validator@0.14.3)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.12)
|
version: 11.1.12(@nestjs/common@11.1.12(class-transformer@0.5.1)(class-validator@0.14.3)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.12)
|
||||||
|
'@nestjs/throttler':
|
||||||
|
specifier: ^6.5.0
|
||||||
|
version: 6.5.0(@nestjs/common@11.1.12(class-transformer@0.5.1)(class-validator@0.14.3)(reflect-metadata@0.2.2)(rxjs@7.8.2))(@nestjs/core@11.1.12)(reflect-metadata@0.2.2)
|
||||||
bullmq:
|
bullmq:
|
||||||
specifier: ^5.67.2
|
specifier: ^5.67.2
|
||||||
version: 5.67.2
|
version: 5.67.2
|
||||||
@@ -454,6 +457,8 @@ importers:
|
|||||||
specifier: ^3.0.8
|
specifier: ^3.0.8
|
||||||
version: 3.2.4(@types/node@22.19.7)(jiti@2.6.1)(jsdom@26.1.0)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
|
version: 3.2.4(@types/node@22.19.7)(jiti@2.6.1)(jsdom@26.1.0)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
|
||||||
|
|
||||||
|
packages/cli-tools: {}
|
||||||
|
|
||||||
packages/config:
|
packages/config:
|
||||||
dependencies:
|
dependencies:
|
||||||
'@eslint/js':
|
'@eslint/js':
|
||||||
|
|||||||
Reference in New Issue
Block a user