Compare commits

..

2 Commits

Author SHA1 Message Date
869bf88104 fix(api): add AuthModule to OrchestratorModule — prevents AuthGuard DI crash
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2026-03-01 13:21:53 -06:00
d2e645a9b2 fix(api): skip throttler for widget polling; add orchestrator agents/events endpoints
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
2026-03-01 13:20:25 -06:00
9 changed files with 396 additions and 50 deletions

View File

@@ -58,6 +58,7 @@ import { ContainerReaperModule } from "./container-reaper/container-reaper.modul
import { FleetSettingsModule } from "./fleet-settings/fleet-settings.module";
import { OnboardingModule } from "./onboarding/onboarding.module";
import { ChatProxyModule } from "./chat-proxy/chat-proxy.module";
import { OrchestratorModule } from "./orchestrator/orchestrator.module";
@Module({
imports: [
@@ -137,6 +138,7 @@ import { ChatProxyModule } from "./chat-proxy/chat-proxy.module";
FleetSettingsModule,
OnboardingModule,
ChatProxyModule,
OrchestratorModule,
],
controllers: [AppController, CsrfController],
providers: [

View File

@@ -0,0 +1,194 @@
import { beforeEach, describe, expect, it, vi, afterEach } from "vitest";
import type { Response } from "express";
import { AgentStatus } from "@prisma/client";
import { OrchestratorController } from "./orchestrator.controller";
import { PrismaService } from "../prisma/prisma.service";
import { AuthGuard } from "../auth/guards/auth.guard";
describe("OrchestratorController", () => {
const mockPrismaService = {
agent: {
findMany: vi.fn(),
},
};
let controller: OrchestratorController;
beforeEach(() => {
vi.clearAllMocks();
controller = new OrchestratorController(mockPrismaService as unknown as PrismaService);
});
afterEach(() => {
vi.useRealTimers();
});
describe("getAgents", () => {
it("returns active agents with API widget shape", async () => {
mockPrismaService.agent.findMany.mockResolvedValue([
{
id: "agent-1",
name: "Planner",
status: AgentStatus.WORKING,
role: "planner",
createdAt: new Date("2026-02-28T10:00:00.000Z"),
},
]);
const result = await controller.getAgents();
expect(result).toEqual([
{
id: "agent-1",
name: "Planner",
status: AgentStatus.WORKING,
type: "planner",
createdAt: new Date("2026-02-28T10:00:00.000Z"),
},
]);
expect(mockPrismaService.agent.findMany).toHaveBeenCalledWith({
where: {
status: {
not: AgentStatus.TERMINATED,
},
},
orderBy: {
createdAt: "desc",
},
select: {
id: true,
name: true,
status: true,
role: true,
createdAt: true,
},
});
});
it("falls back to type=agent when role is missing", async () => {
mockPrismaService.agent.findMany.mockResolvedValue([
{
id: "agent-2",
name: null,
status: AgentStatus.IDLE,
role: null,
createdAt: new Date("2026-02-28T11:00:00.000Z"),
},
]);
const result = await controller.getAgents();
expect(result[0]).toMatchObject({
id: "agent-2",
type: "agent",
});
});
});
describe("streamEvents", () => {
it("sets SSE headers and writes initial data payload", async () => {
const onHandlers: Record<string, (() => void) | undefined> = {};
const mockRes = {
setHeader: vi.fn(),
write: vi.fn(),
end: vi.fn(),
on: vi.fn((event: string, handler: () => void) => {
onHandlers[event] = handler;
return mockRes;
}),
} as unknown as Response;
mockPrismaService.agent.findMany.mockResolvedValue([
{
id: "agent-1",
name: "Worker",
status: AgentStatus.WORKING,
role: "worker",
createdAt: new Date("2026-02-28T12:00:00.000Z"),
},
]);
await controller.streamEvents(mockRes);
expect(mockRes.setHeader).toHaveBeenCalledWith("Content-Type", "text/event-stream");
expect(mockRes.setHeader).toHaveBeenCalledWith("Cache-Control", "no-cache");
expect(mockRes.setHeader).toHaveBeenCalledWith("Connection", "keep-alive");
expect(mockRes.setHeader).toHaveBeenCalledWith("X-Accel-Buffering", "no");
expect(mockRes.write).toHaveBeenCalledWith(
expect.stringContaining('"type":"agents:updated"')
);
expect(typeof onHandlers.close).toBe("function");
});
it("polls every 5 seconds and only emits when payload changes", async () => {
vi.useFakeTimers();
const onHandlers: Record<string, (() => void) | undefined> = {};
const mockRes = {
setHeader: vi.fn(),
write: vi.fn(),
end: vi.fn(),
on: vi.fn((event: string, handler: () => void) => {
onHandlers[event] = handler;
return mockRes;
}),
} as unknown as Response;
const firstPayload = [
{
id: "agent-1",
name: "Worker",
status: AgentStatus.WORKING,
role: "worker",
createdAt: new Date("2026-02-28T12:00:00.000Z"),
},
];
const secondPayload = [
{
id: "agent-1",
name: "Worker",
status: AgentStatus.WAITING,
role: "worker",
createdAt: new Date("2026-02-28T12:00:00.000Z"),
},
];
mockPrismaService.agent.findMany
.mockResolvedValueOnce(firstPayload)
.mockResolvedValueOnce(firstPayload)
.mockResolvedValueOnce(secondPayload);
await controller.streamEvents(mockRes);
// 1 initial data event
const getDataEventCalls = () =>
mockRes.write.mock.calls.filter(
(call) => typeof call[0] === "string" && call[0].startsWith("data: ")
);
expect(getDataEventCalls()).toHaveLength(1);
// No change after first poll => no new data event
await vi.advanceTimersByTimeAsync(5000);
expect(getDataEventCalls()).toHaveLength(1);
// Status changed on second poll => emits new data event
await vi.advanceTimersByTimeAsync(5000);
expect(getDataEventCalls()).toHaveLength(2);
onHandlers.close?.();
expect(mockRes.end).toHaveBeenCalledTimes(1);
});
});
describe("security", () => {
it("uses AuthGuard at the controller level", () => {
const guards = Reflect.getMetadata("__guards__", OrchestratorController) as unknown[];
const guardClasses = guards.map((guard) => guard);
expect(guardClasses).toContain(AuthGuard);
});
});
});

View File

@@ -0,0 +1,115 @@
import { Controller, Get, Res, UseGuards } from "@nestjs/common";
import { AgentStatus } from "@prisma/client";
import type { Response } from "express";
import { AuthGuard } from "../auth/guards/auth.guard";
import { PrismaService } from "../prisma/prisma.service";
const AGENT_POLL_INTERVAL_MS = 5_000;
const SSE_HEARTBEAT_MS = 15_000;
interface OrchestratorAgentDto {
id: string;
name: string | null;
status: AgentStatus;
type: string;
createdAt: Date;
}
@Controller("orchestrator")
@UseGuards(AuthGuard)
export class OrchestratorController {
constructor(private readonly prisma: PrismaService) {}
@Get("agents")
async getAgents(): Promise<OrchestratorAgentDto[]> {
return this.fetchActiveAgents();
}
@Get("events")
async streamEvents(@Res() res: Response): Promise<void> {
res.setHeader("Content-Type", "text/event-stream");
res.setHeader("Cache-Control", "no-cache");
res.setHeader("Connection", "keep-alive");
res.setHeader("X-Accel-Buffering", "no");
if (typeof res.flushHeaders === "function") {
res.flushHeaders();
}
let isClosed = false;
let previousSnapshot = "";
const emitSnapshotIfChanged = async (): Promise<void> => {
if (isClosed) {
return;
}
try {
const agents = await this.fetchActiveAgents();
const snapshot = JSON.stringify(agents);
if (snapshot !== previousSnapshot) {
previousSnapshot = snapshot;
res.write(
`data: ${JSON.stringify({
type: "agents:updated",
agents,
timestamp: new Date().toISOString(),
})}\n\n`
);
}
} catch (error: unknown) {
const message = error instanceof Error ? error.message : String(error);
res.write(`event: error\n`);
res.write(`data: ${JSON.stringify({ error: message })}\n\n`);
}
};
await emitSnapshotIfChanged();
const pollInterval = setInterval(() => {
void emitSnapshotIfChanged();
}, AGENT_POLL_INTERVAL_MS);
const heartbeatInterval = setInterval(() => {
if (!isClosed) {
res.write(": keepalive\n\n");
}
}, SSE_HEARTBEAT_MS);
res.on("close", () => {
isClosed = true;
clearInterval(pollInterval);
clearInterval(heartbeatInterval);
res.end();
});
}
private async fetchActiveAgents(): Promise<OrchestratorAgentDto[]> {
const agents = await this.prisma.agent.findMany({
where: {
status: {
not: AgentStatus.TERMINATED,
},
},
orderBy: {
createdAt: "desc",
},
select: {
id: true,
name: true,
status: true,
role: true,
createdAt: true,
},
});
return agents.map((agent) => ({
id: agent.id,
name: agent.name,
status: agent.status,
type: agent.role ?? "agent",
createdAt: agent.createdAt,
}));
}
}

View File

@@ -0,0 +1,10 @@
import { Module } from "@nestjs/common";
import { AuthModule } from "../auth/auth.module";
import { PrismaModule } from "../prisma/prisma.module";
import { OrchestratorController } from "./orchestrator.controller";
@Module({
imports: [AuthModule, PrismaModule],
controllers: [OrchestratorController],
})
export class OrchestratorModule {}

View File

@@ -0,0 +1,31 @@
import { describe, expect, it } from "vitest";
import { WidgetsController } from "./widgets.controller";
const THROTTLER_SKIP_DEFAULT_KEY = "THROTTLER:SKIPdefault";
describe("WidgetsController throttler metadata", () => {
it("marks widget data polling endpoints to skip throttling", () => {
const pollingHandlers = [
WidgetsController.prototype.getStatCardData,
WidgetsController.prototype.getChartData,
WidgetsController.prototype.getListData,
WidgetsController.prototype.getCalendarPreviewData,
WidgetsController.prototype.getActiveProjectsData,
WidgetsController.prototype.getAgentChainsData,
];
for (const handler of pollingHandlers) {
expect(Reflect.getMetadata(THROTTLER_SKIP_DEFAULT_KEY, handler)).toBe(true);
}
});
it("does not skip throttling for non-polling widget routes", () => {
expect(
Reflect.getMetadata(THROTTLER_SKIP_DEFAULT_KEY, WidgetsController.prototype.findAll)
).toBe(undefined);
expect(
Reflect.getMetadata(THROTTLER_SKIP_DEFAULT_KEY, WidgetsController.prototype.findByName)
).toBe(undefined);
});
});

View File

@@ -1,4 +1,5 @@
import { Controller, Get, Post, Body, Param, UseGuards, Request } from "@nestjs/common";
import { SkipThrottle as SkipThrottler } from "@nestjs/throttler";
import { WidgetsService } from "./widgets.service";
import { WidgetDataService } from "./widget-data.service";
import { AuthGuard } from "../auth/guards/auth.guard";
@@ -43,6 +44,7 @@ export class WidgetsController {
* Get stat card widget data
*/
@Post("data/stat-card")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getStatCardData(@Request() req: RequestWithWorkspace, @Body() query: StatCardQueryDto) {
return this.widgetDataService.getStatCardData(req.workspace.id, query);
@@ -53,6 +55,7 @@ export class WidgetsController {
* Get chart widget data
*/
@Post("data/chart")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getChartData(@Request() req: RequestWithWorkspace, @Body() query: ChartQueryDto) {
return this.widgetDataService.getChartData(req.workspace.id, query);
@@ -63,6 +66,7 @@ export class WidgetsController {
* Get list widget data
*/
@Post("data/list")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getListData(@Request() req: RequestWithWorkspace, @Body() query: ListQueryDto) {
return this.widgetDataService.getListData(req.workspace.id, query);
@@ -73,6 +77,7 @@ export class WidgetsController {
* Get calendar preview widget data
*/
@Post("data/calendar-preview")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getCalendarPreviewData(
@Request() req: RequestWithWorkspace,
@@ -86,6 +91,7 @@ export class WidgetsController {
* Get active projects widget data
*/
@Post("data/active-projects")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getActiveProjectsData(@Request() req: RequestWithWorkspace) {
return this.widgetDataService.getActiveProjectsData(req.workspace.id);
@@ -96,6 +102,7 @@ export class WidgetsController {
* Get agent chains widget data (active agent sessions)
*/
@Post("data/agent-chains")
@SkipThrottler()
@UseGuards(WorkspaceGuard)
async getAgentChainsData(@Request() req: RequestWithWorkspace) {
return this.widgetDataService.getAgentChainsData(req.workspace.id);

View File

@@ -85,14 +85,6 @@ const INITIAL_FORM: ProviderFormState = {
isActive: true,
};
function mapProviderTypeToApi(type: string): "ollama" | "openai" | "claude" {
if (type === "ollama" || type === "claude") {
return type;
}
return "openai";
}
function getErrorMessage(error: unknown, fallback: string): string {
if (error instanceof Error && error.message.trim().length > 0) {
return error.message;
@@ -101,6 +93,18 @@ function getErrorMessage(error: unknown, fallback: string): string {
return fallback;
}
function buildProviderName(displayName: string, type: string): string {
const slug = displayName
.trim()
.toLowerCase()
.replace(/[^a-z0-9]+/g, "-")
.replace(/^-+/, "")
.replace(/-+$/, "");
const candidate = `${type}-${slug.length > 0 ? slug : "provider"}`;
return candidate.slice(0, 100);
}
function normalizeProviderModels(models: unknown): FleetProviderModel[] {
if (!Array.isArray(models)) {
return [];
@@ -149,11 +153,11 @@ function modelsToEditorText(models: unknown): string {
.join("\n");
}
function parseModelsText(value: string): string[] {
function parseModelsText(value: string): FleetProviderModel[] {
const seen = new Set<string>();
return value
.split(/\r?\n/g)
.split(/\n|,/g)
.map((segment) => segment.trim())
.filter((segment) => segment.length > 0)
.filter((segment) => {
@@ -162,7 +166,8 @@ function parseModelsText(value: string): string[] {
}
seen.add(segment);
return true;
});
})
.map((id) => ({ id, name: id }));
}
function maskApiKey(value: string): string {
@@ -274,7 +279,6 @@ export default function ProvidersSettingsPage(): ReactElement {
}
const models = parseModelsText(form.modelsText);
const providerModels = models.map((id) => ({ id, name: id }));
const baseUrl = form.baseUrl.trim();
const apiKey = form.apiKey.trim();
@@ -285,7 +289,7 @@ export default function ProvidersSettingsPage(): ReactElement {
const updatePayload: UpdateFleetProviderRequest = {
displayName,
isActive: form.isActive,
models: providerModels,
models,
};
if (baseUrl.length > 0) {
@@ -299,27 +303,21 @@ export default function ProvidersSettingsPage(): ReactElement {
await updateFleetProvider(editingProvider.id, updatePayload);
setSuccessMessage(`Updated provider "${displayName}".`);
} else {
const config: CreateFleetProviderRequest["config"] = {};
const createPayload: CreateFleetProviderRequest = {
name: buildProviderName(displayName, form.type),
displayName,
type: form.type,
models,
};
if (baseUrl.length > 0) {
config.endpoint = baseUrl;
createPayload.baseUrl = baseUrl;
}
if (apiKey.length > 0) {
config.apiKey = apiKey;
createPayload.apiKey = apiKey;
}
if (models.length > 0) {
config.models = models;
}
const createPayload: CreateFleetProviderRequest = {
displayName,
providerType: mapProviderTypeToApi(form.type),
config,
isEnabled: form.isActive,
};
await createFleetProvider(createPayload);
setSuccessMessage(`Added provider "${displayName}".`);
}

View File

@@ -34,25 +34,17 @@ describe("createFleetProvider", (): void => {
vi.mocked(client.apiPost).mockResolvedValueOnce({ id: "provider-1" } as never);
await createFleetProvider({
providerType: "openai",
name: "openai-main",
displayName: "OpenAI Main",
config: {
endpoint: "https://api.openai.com/v1",
apiKey: "sk-test",
models: ["gpt-4.1-mini", "gpt-4o-mini"],
},
isEnabled: true,
type: "openai",
apiKey: "sk-test",
});
expect(client.apiPost).toHaveBeenCalledWith("/api/fleet-settings/providers", {
providerType: "openai",
name: "openai-main",
displayName: "OpenAI Main",
config: {
endpoint: "https://api.openai.com/v1",
apiKey: "sk-test",
models: ["gpt-4.1-mini", "gpt-4o-mini"],
},
isEnabled: true,
type: "openai",
apiKey: "sk-test",
});
});
});

View File

@@ -16,16 +16,13 @@ export interface FleetProvider {
}
export interface CreateFleetProviderRequest {
providerType: "ollama" | "openai" | "claude";
name: string;
displayName: string;
config: {
endpoint?: string;
apiKey?: string;
models?: string[];
timeout?: number;
};
isDefault?: boolean;
isEnabled?: boolean;
type: string;
baseUrl?: string;
apiKey?: string;
apiType?: string;
models?: FleetProviderModel[];
}
export interface UpdateFleetProviderRequest {