fix(#338): Add rate limiting to orchestrator API
- Add @nestjs/throttler for rate limiting support - Configure multiple throttle profiles: default (100/min), strict (10/min for spawn/kill), status (200/min for polling) - Apply strict rate limits to spawn and kill endpoints to prevent DoS - Apply higher rate limits to status/health endpoints for monitoring - Add OrchestratorThrottlerGuard with X-Forwarded-For support for proxy setups - Add unit tests for throttler guard Refs #338 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
122
apps/orchestrator/src/common/guards/throttler.guard.spec.ts
Normal file
122
apps/orchestrator/src/common/guards/throttler.guard.spec.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { ExecutionContext } from "@nestjs/common";
|
||||
import { ThrottlerException, ThrottlerModuleOptions, ThrottlerStorage } from "@nestjs/throttler";
|
||||
import { Reflector } from "@nestjs/core";
|
||||
import { OrchestratorThrottlerGuard } from "./throttler.guard";
|
||||
|
||||
describe("OrchestratorThrottlerGuard", () => {
|
||||
let guard: OrchestratorThrottlerGuard;
|
||||
|
||||
beforeEach(() => {
|
||||
// Create guard with minimal mocks for testing protected methods
|
||||
const options: ThrottlerModuleOptions = {
|
||||
throttlers: [{ name: "default", ttl: 60000, limit: 100 }],
|
||||
};
|
||||
const storageService = {} as ThrottlerStorage;
|
||||
const reflector = {} as Reflector;
|
||||
|
||||
guard = new OrchestratorThrottlerGuard(options, storageService, reflector);
|
||||
});
|
||||
|
||||
describe("getTracker", () => {
|
||||
it("should extract IP from X-Forwarded-For header", async () => {
|
||||
const req = {
|
||||
headers: {
|
||||
"x-forwarded-for": "192.168.1.1, 10.0.0.1",
|
||||
},
|
||||
ip: "127.0.0.1",
|
||||
};
|
||||
|
||||
// Access protected method for testing
|
||||
const tracker = await (
|
||||
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||
).getTracker(req);
|
||||
|
||||
expect(tracker).toBe("192.168.1.1");
|
||||
});
|
||||
|
||||
it("should handle X-Forwarded-For as array", async () => {
|
||||
const req = {
|
||||
headers: {
|
||||
"x-forwarded-for": ["192.168.1.1, 10.0.0.1"],
|
||||
},
|
||||
ip: "127.0.0.1",
|
||||
};
|
||||
|
||||
const tracker = await (
|
||||
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||
).getTracker(req);
|
||||
|
||||
expect(tracker).toBe("192.168.1.1");
|
||||
});
|
||||
|
||||
it("should fallback to request IP when no X-Forwarded-For", async () => {
|
||||
const req = {
|
||||
headers: {},
|
||||
ip: "192.168.2.2",
|
||||
};
|
||||
|
||||
const tracker = await (
|
||||
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||
).getTracker(req);
|
||||
|
||||
expect(tracker).toBe("192.168.2.2");
|
||||
});
|
||||
|
||||
it("should fallback to connection remoteAddress when no IP", async () => {
|
||||
const req = {
|
||||
headers: {},
|
||||
connection: {
|
||||
remoteAddress: "192.168.3.3",
|
||||
},
|
||||
};
|
||||
|
||||
const tracker = await (
|
||||
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||
).getTracker(req);
|
||||
|
||||
expect(tracker).toBe("192.168.3.3");
|
||||
});
|
||||
|
||||
it("should return 'unknown' when no IP available", async () => {
|
||||
const req = {
|
||||
headers: {},
|
||||
};
|
||||
|
||||
const tracker = await (
|
||||
guard as unknown as { getTracker: (req: unknown) => Promise<string> }
|
||||
).getTracker(req);
|
||||
|
||||
expect(tracker).toBe("unknown");
|
||||
});
|
||||
});
|
||||
|
||||
describe("throwThrottlingException", () => {
|
||||
it("should throw ThrottlerException with endpoint info", () => {
|
||||
const mockRequest = {
|
||||
url: "/agents/spawn",
|
||||
};
|
||||
|
||||
const mockContext = {
|
||||
switchToHttp: vi.fn().mockReturnValue({
|
||||
getRequest: vi.fn().mockReturnValue(mockRequest),
|
||||
}),
|
||||
} as unknown as ExecutionContext;
|
||||
|
||||
expect(() => {
|
||||
(
|
||||
guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void }
|
||||
).throwThrottlingException(mockContext);
|
||||
}).toThrow(ThrottlerException);
|
||||
|
||||
try {
|
||||
(
|
||||
guard as unknown as { throwThrottlingException: (context: ExecutionContext) => void }
|
||||
).throwThrottlingException(mockContext);
|
||||
} catch (error) {
|
||||
expect(error).toBeInstanceOf(ThrottlerException);
|
||||
expect((error as ThrottlerException).message).toContain("/agents/spawn");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
63
apps/orchestrator/src/common/guards/throttler.guard.ts
Normal file
63
apps/orchestrator/src/common/guards/throttler.guard.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
import { Injectable, ExecutionContext } from "@nestjs/common";
|
||||
import { ThrottlerGuard, ThrottlerException } from "@nestjs/throttler";
|
||||
|
||||
interface RequestWithHeaders {
|
||||
headers?: Record<string, string | string[]>;
|
||||
ip?: string;
|
||||
connection?: { remoteAddress?: string };
|
||||
url?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* OrchestratorThrottlerGuard - Rate limiting guard for orchestrator API endpoints
|
||||
*
|
||||
* Uses the X-Forwarded-For header for client IP identification when behind a proxy,
|
||||
* falling back to the direct connection IP.
|
||||
*
|
||||
* Usage:
|
||||
* @UseGuards(OrchestratorThrottlerGuard)
|
||||
* @Controller('agents')
|
||||
* export class AgentsController { ... }
|
||||
*/
|
||||
@Injectable()
|
||||
export class OrchestratorThrottlerGuard extends ThrottlerGuard {
|
||||
/**
|
||||
* Get the client IP address for rate limiting tracking
|
||||
* Prioritizes X-Forwarded-For header for proxy setups
|
||||
*/
|
||||
protected getTracker(req: Record<string, unknown>): Promise<string> {
|
||||
const request = req as RequestWithHeaders;
|
||||
const headers = request.headers;
|
||||
|
||||
// Check X-Forwarded-For header first (for proxied requests)
|
||||
if (headers) {
|
||||
const forwardedFor = headers["x-forwarded-for"];
|
||||
if (forwardedFor) {
|
||||
// Get the first IP in the chain (original client)
|
||||
const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor;
|
||||
if (ips) {
|
||||
const clientIp = ips.split(",")[0]?.trim();
|
||||
if (clientIp) {
|
||||
return Promise.resolve(clientIp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to direct connection IP
|
||||
const ip = request.ip ?? request.connection?.remoteAddress ?? "unknown";
|
||||
return Promise.resolve(ip);
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom error message for rate limit exceeded
|
||||
*/
|
||||
protected throwThrottlingException(context: ExecutionContext): Promise<void> {
|
||||
const request = context.switchToHttp().getRequest<RequestWithHeaders>();
|
||||
const endpoint = request.url ?? "unknown";
|
||||
|
||||
throw new ThrottlerException(
|
||||
`Rate limit exceeded for endpoint ${endpoint}. Please try again later.`
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user