diff --git a/apps/orchestrator/src/api/agents/agent-control.service.ts b/apps/orchestrator/src/api/agents/agent-control.service.ts new file mode 100644 index 0000000..e45ae0f --- /dev/null +++ b/apps/orchestrator/src/api/agents/agent-control.service.ts @@ -0,0 +1,68 @@ +import { Injectable } from "@nestjs/common"; +import type { Prisma } from "@prisma/client"; +import { PrismaService } from "../../prisma/prisma.service"; + +@Injectable() +export class AgentControlService { + constructor(private readonly prisma: PrismaService) {} + + private toJsonValue(value: Record): Prisma.InputJsonValue { + return value as Prisma.InputJsonValue; + } + + private async createOperatorAuditLog( + agentId: string, + operatorId: string, + action: "inject" | "pause" | "resume", + payload: Record + ): Promise { + await this.prisma.operatorAuditLog.create({ + data: { + sessionId: agentId, + userId: operatorId, + provider: "internal", + action, + metadata: this.toJsonValue({ payload }), + }, + }); + } + + async injectMessage(agentId: string, operatorId: string, message: string): Promise { + const treeEntry = await this.prisma.agentSessionTree.findUnique({ + where: { sessionId: agentId }, + select: { id: true }, + }); + + if (treeEntry) { + await this.prisma.agentConversationMessage.create({ + data: { + sessionId: agentId, + role: "operator", + content: message, + provider: "internal", + metadata: this.toJsonValue({}), + }, + }); + } + + await this.createOperatorAuditLog(agentId, operatorId, "inject", { message }); + } + + async pauseAgent(agentId: string, operatorId: string): Promise { + await this.prisma.agentSessionTree.updateMany({ + where: { sessionId: agentId }, + data: { status: "paused" }, + }); + + await this.createOperatorAuditLog(agentId, operatorId, "pause", {}); + } + + async resumeAgent(agentId: string, operatorId: string): Promise { + await this.prisma.agentSessionTree.updateMany({ + where: { sessionId: agentId }, + data: { status: "running" }, + }); + + await this.createOperatorAuditLog(agentId, operatorId, "resume", {}); + } +} diff --git a/apps/orchestrator/src/api/agents/agents-killswitch.controller.spec.ts b/apps/orchestrator/src/api/agents/agents-killswitch.controller.spec.ts index 399f1aa..59abe78 100644 --- a/apps/orchestrator/src/api/agents/agents-killswitch.controller.spec.ts +++ b/apps/orchestrator/src/api/agents/agents-killswitch.controller.spec.ts @@ -6,6 +6,7 @@ import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service"; import { KillswitchService } from "../../killswitch/killswitch.service"; import { AgentEventsService } from "./agent-events.service"; import { AgentMessagesService } from "./agent-messages.service"; +import { AgentControlService } from "./agent-control.service"; import type { KillAllResult } from "../../killswitch/killswitch.service"; describe("AgentsController - Killswitch Endpoints", () => { @@ -35,6 +36,11 @@ describe("AgentsController - Killswitch Endpoints", () => { getReplayMessages: ReturnType; getMessagesAfter: ReturnType; }; + let mockControlService: { + injectMessage: ReturnType; + pauseAgent: ReturnType; + resumeAgent: ReturnType; + }; beforeEach(() => { mockKillswitchService = { @@ -77,13 +83,20 @@ describe("AgentsController - Killswitch Endpoints", () => { getMessagesAfter: vi.fn().mockResolvedValue([]), }; + mockControlService = { + injectMessage: vi.fn().mockResolvedValue(undefined), + pauseAgent: vi.fn().mockResolvedValue(undefined), + resumeAgent: vi.fn().mockResolvedValue(undefined), + }; + controller = new AgentsController( mockQueueService as unknown as QueueService, mockSpawnerService as unknown as AgentSpawnerService, mockLifecycleService as unknown as AgentLifecycleService, mockKillswitchService as unknown as KillswitchService, mockEventsService as unknown as AgentEventsService, - mockMessagesService as unknown as AgentMessagesService + mockMessagesService as unknown as AgentMessagesService, + mockControlService as unknown as AgentControlService ); }); diff --git a/apps/orchestrator/src/api/agents/agents.controller.spec.ts b/apps/orchestrator/src/api/agents/agents.controller.spec.ts index baa2721..5b30e86 100644 --- a/apps/orchestrator/src/api/agents/agents.controller.spec.ts +++ b/apps/orchestrator/src/api/agents/agents.controller.spec.ts @@ -5,6 +5,7 @@ import { AgentLifecycleService } from "../../spawner/agent-lifecycle.service"; import { KillswitchService } from "../../killswitch/killswitch.service"; import { AgentEventsService } from "./agent-events.service"; import { AgentMessagesService } from "./agent-messages.service"; +import { AgentControlService } from "./agent-control.service"; import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; describe("AgentsController", () => { @@ -36,6 +37,11 @@ describe("AgentsController", () => { getReplayMessages: ReturnType; getMessagesAfter: ReturnType; }; + let controlService: { + injectMessage: ReturnType; + pauseAgent: ReturnType; + resumeAgent: ReturnType; + }; beforeEach(() => { // Create mock services @@ -81,6 +87,12 @@ describe("AgentsController", () => { getMessagesAfter: vi.fn().mockResolvedValue([]), }; + controlService = { + injectMessage: vi.fn().mockResolvedValue(undefined), + pauseAgent: vi.fn().mockResolvedValue(undefined), + resumeAgent: vi.fn().mockResolvedValue(undefined), + }; + // Create controller with mocked services controller = new AgentsController( queueService as unknown as QueueService, @@ -88,7 +100,8 @@ describe("AgentsController", () => { lifecycleService as unknown as AgentLifecycleService, killswitchService as unknown as KillswitchService, eventsService as unknown as AgentEventsService, - messagesService as unknown as AgentMessagesService + messagesService as unknown as AgentMessagesService, + controlService as unknown as AgentControlService ); }); @@ -378,6 +391,47 @@ describe("AgentsController", () => { }); }); + describe("agent control endpoints", () => { + const agentId = "0b64079f-4487-42b9-92eb-cf8ea0042a64"; + + it("should inject an operator message", async () => { + const req = { apiKey: "control-key" }; + + const result = await controller.injectAgentMessage( + agentId, + { message: "pause and summarize" }, + req + ); + + expect(controlService.injectMessage).toHaveBeenCalledWith( + agentId, + "control-key", + "pause and summarize" + ); + expect(result).toEqual({ message: `Message injected into agent ${agentId}` }); + }); + + it("should default operator id when request api key is missing", async () => { + await controller.injectAgentMessage(agentId, { message: "continue" }, {}); + + expect(controlService.injectMessage).toHaveBeenCalledWith(agentId, "operator", "continue"); + }); + + it("should pause an agent", async () => { + const result = await controller.pauseAgent(agentId, {}, { apiKey: "ops-user" }); + + expect(controlService.pauseAgent).toHaveBeenCalledWith(agentId, "ops-user"); + expect(result).toEqual({ message: `Agent ${agentId} paused` }); + }); + + it("should resume an agent", async () => { + const result = await controller.resumeAgent(agentId, {}, { apiKey: "ops-user" }); + + expect(controlService.resumeAgent).toHaveBeenCalledWith(agentId, "ops-user"); + expect(result).toEqual({ message: `Agent ${agentId} resumed` }); + }); + }); + describe("getAgentMessages", () => { it("should return paginated message history", async () => { const agentId = "0b64079f-4487-42b9-92eb-cf8ea0042a64"; diff --git a/apps/orchestrator/src/api/agents/agents.controller.ts b/apps/orchestrator/src/api/agents/agents.controller.ts index a94efac..aa82360 100644 --- a/apps/orchestrator/src/api/agents/agents.controller.ts +++ b/apps/orchestrator/src/api/agents/agents.controller.ts @@ -14,6 +14,7 @@ import { Sse, MessageEvent, Query, + Request, } from "@nestjs/common"; import type { AgentConversationMessage } from "@prisma/client"; import { Throttle } from "@nestjs/throttler"; @@ -28,6 +29,9 @@ import { OrchestratorThrottlerGuard } from "../../common/guards/throttler.guard" import { AgentEventsService } from "./agent-events.service"; import { GetMessagesQueryDto } from "./dto/get-messages-query.dto"; import { AgentMessagesService } from "./agent-messages.service"; +import { AgentControlService } from "./agent-control.service"; +import { InjectAgentDto } from "./dto/inject-agent.dto"; +import { PauseAgentDto, ResumeAgentDto } from "./dto/control-agent.dto"; /** * Controller for agent management endpoints @@ -51,7 +55,8 @@ export class AgentsController { private readonly lifecycleService: AgentLifecycleService, private readonly killswitchService: KillswitchService, private readonly eventsService: AgentEventsService, - private readonly messagesService: AgentMessagesService + private readonly messagesService: AgentMessagesService, + private readonly agentControlService: AgentControlService ) {} /** @@ -374,6 +379,57 @@ export class AgentsController { } } + @Post(":agentId/inject") + @Throttle({ default: { limit: 10, ttl: 60000 } }) + @HttpCode(200) + @UsePipes(new ValidationPipe({ transform: true, whitelist: true })) + async injectAgentMessage( + @Param("agentId", ParseUUIDPipe) agentId: string, + @Body() dto: InjectAgentDto, + @Request() req: { apiKey?: string } + ): Promise<{ message: string }> { + const operatorId = req.apiKey ?? "operator"; + await this.agentControlService.injectMessage(agentId, operatorId, dto.message); + + return { + message: `Message injected into agent ${agentId}`, + }; + } + + @Post(":agentId/pause") + @Throttle({ default: { limit: 10, ttl: 60000 } }) + @HttpCode(200) + @UsePipes(new ValidationPipe({ transform: true, whitelist: true })) + async pauseAgent( + @Param("agentId", ParseUUIDPipe) agentId: string, + @Body() _dto: PauseAgentDto, + @Request() req: { apiKey?: string } + ): Promise<{ message: string }> { + const operatorId = req.apiKey ?? "operator"; + await this.agentControlService.pauseAgent(agentId, operatorId); + + return { + message: `Agent ${agentId} paused`, + }; + } + + @Post(":agentId/resume") + @Throttle({ default: { limit: 10, ttl: 60000 } }) + @HttpCode(200) + @UsePipes(new ValidationPipe({ transform: true, whitelist: true })) + async resumeAgent( + @Param("agentId", ParseUUIDPipe) agentId: string, + @Body() _dto: ResumeAgentDto, + @Request() req: { apiKey?: string } + ): Promise<{ message: string }> { + const operatorId = req.apiKey ?? "operator"; + await this.agentControlService.resumeAgent(agentId, operatorId); + + return { + message: `Agent ${agentId} resumed`, + }; + } + /** * Kill all active agents * @returns Summary of kill operation diff --git a/apps/orchestrator/src/api/agents/agents.module.ts b/apps/orchestrator/src/api/agents/agents.module.ts index d59ce73..903f226 100644 --- a/apps/orchestrator/src/api/agents/agents.module.ts +++ b/apps/orchestrator/src/api/agents/agents.module.ts @@ -8,10 +8,16 @@ import { OrchestratorApiKeyGuard } from "../../common/guards/api-key.guard"; import { AgentEventsService } from "./agent-events.service"; import { PrismaModule } from "../../prisma/prisma.module"; import { AgentMessagesService } from "./agent-messages.service"; +import { AgentControlService } from "./agent-control.service"; @Module({ imports: [QueueModule, SpawnerModule, KillswitchModule, ValkeyModule, PrismaModule], controllers: [AgentsController], - providers: [OrchestratorApiKeyGuard, AgentEventsService, AgentMessagesService], + providers: [ + OrchestratorApiKeyGuard, + AgentEventsService, + AgentMessagesService, + AgentControlService, + ], }) export class AgentsModule {} diff --git a/apps/orchestrator/src/api/agents/dto/control-agent.dto.ts b/apps/orchestrator/src/api/agents/dto/control-agent.dto.ts new file mode 100644 index 0000000..bf529e1 --- /dev/null +++ b/apps/orchestrator/src/api/agents/dto/control-agent.dto.ts @@ -0,0 +1,3 @@ +export class PauseAgentDto {} + +export class ResumeAgentDto {} diff --git a/apps/orchestrator/src/api/agents/dto/inject-agent.dto.ts b/apps/orchestrator/src/api/agents/dto/inject-agent.dto.ts new file mode 100644 index 0000000..b13cecb --- /dev/null +++ b/apps/orchestrator/src/api/agents/dto/inject-agent.dto.ts @@ -0,0 +1,7 @@ +import { IsNotEmpty, IsString } from "class-validator"; + +export class InjectAgentDto { + @IsString() + @IsNotEmpty() + message!: string; +}