Files
stack/apps/api/src/federation/command.service.ts
Jason Woltje aabf97fe4e fix(#283): Enforce connection status validation in queries
Move status validation from post-retrieval checks into Prisma WHERE
clauses. This prevents TOCTOU issues and ensures only ACTIVE
connections are retrieved. Removed redundant status checks after
retrieval in both query and command services.

Security improvement: Enforces status=ACTIVE in database query rather
than checking after retrieval, preventing race conditions.

Fixes #283

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-03 21:32:47 -06:00

400 lines
12 KiB
TypeScript

/**
* Command Service
*
* Handles federated command messages.
*/
import { Injectable, Logger } from "@nestjs/common";
import { ModuleRef } from "@nestjs/core";
import { HttpService } from "@nestjs/axios";
import { randomUUID } from "crypto";
import { firstValueFrom } from "rxjs";
import { PrismaService } from "../prisma/prisma.service";
import { FederationService } from "./federation.service";
import { SignatureService } from "./signature.service";
import {
FederationConnectionStatus,
FederationMessageType,
FederationMessageStatus,
} from "@prisma/client";
import type { CommandMessage, CommandResponse, CommandMessageDetails } from "./types/message.types";
import { CommandProcessingError, UnknownCommandTypeError } from "./errors/command.errors";
/**
 * Sends, receives and tracks federated COMMAND messages.
 *
 * Outgoing commands are signed, persisted as federation messages and POSTed to
 * the remote instance's `/api/v1/federation/incoming/command` endpoint.
 * Incoming commands and command responses are only accepted when they carry a
 * valid signature and come from a remote instance with an ACTIVE federation
 * connection (status is enforced in the database query, not checked afterwards).
 */
@Injectable()
export class CommandService {
  private readonly logger = new Logger(CommandService.name);

  constructor(
    private readonly prisma: PrismaService,
    private readonly federationService: FederationService,
    private readonly signatureService: SignatureService,
    private readonly httpService: HttpService,
    private readonly moduleRef: ModuleRef
  ) {}

  /**
   * Send a signed command to a remote instance over an ACTIVE connection.
   *
   * The command is stored as a PENDING federation message before delivery. If
   * the HTTP POST fails, the stored message is marked FAILED and an error is
   * thrown.
   *
   * @param workspaceId - Workspace that must own the connection.
   * @param connectionId - Federation connection to send the command over.
   * @param commandType - Command identifier forwarded to the remote instance.
   * @param payload - Command arguments; included in the signed content.
   * @returns Details of the stored message as created (status PENDING).
   * @throws Error("Connection not found") when no ACTIVE connection with this
   *   id exists in the workspace.
   * @throws Error("Failed to send command") when delivery to the remote fails.
   */
  async sendCommand(
    workspaceId: string,
    connectionId: string,
    commandType: string,
    payload: Record<string, unknown>
  ): Promise<CommandMessageDetails> {
    // Status is part of the WHERE clause so a non-ACTIVE connection is never
    // returned (no check-after-read race).
    const connection = await this.prisma.federationConnection.findUnique({
      where: {
        id: connectionId,
        workspaceId,
        status: FederationConnectionStatus.ACTIVE,
      },
    });
    if (!connection) {
      throw new Error("Connection not found");
    }

    const identity = await this.federationService.getInstanceIdentity();

    const messageId = randomUUID();
    const timestamp = Date.now();
    const commandPayload: Record<string, unknown> = {
      messageId,
      instanceId: identity.instanceId,
      commandType,
      payload,
      timestamp,
    };

    const signature = await this.signatureService.signMessage(commandPayload);
    // Same fields, in the same order, as the signed payload, plus the signature.
    const signedCommand = {
      messageId,
      instanceId: identity.instanceId,
      commandType,
      payload,
      timestamp,
      signature,
    } as CommandMessage;

    // Persist before sending so a delivery failure can be recorded against it.
    const message = await this.prisma.federationMessage.create({
      data: {
        workspaceId,
        connectionId,
        messageType: FederationMessageType.COMMAND,
        messageId,
        commandType,
        payload: payload as never,
        status: FederationMessageStatus.PENDING,
        signature,
      },
    });

    try {
      const remoteUrl = `${connection.remoteUrl}/api/v1/federation/incoming/command`;
      await firstValueFrom(this.httpService.post(remoteUrl, signedCommand));
      this.logger.log(`Command sent to ${connection.remoteUrl}: ${messageId}`);
    } catch (error) {
      this.logger.error(
        `Failed to send command to ${connection.remoteUrl}`,
        error instanceof Error ? error.stack : String(error)
      );
      await this.prisma.federationMessage.update({
        where: { id: message.id },
        data: {
          status: FederationMessageStatus.FAILED,
          error: error instanceof Error ? error.message : "Unknown error",
        },
      });
      throw new Error("Failed to send command");
    }

    return this.mapToCommandMessageDetails(message);
  }

  /**
   * Handle a command received from a remote instance and build a signed response.
   *
   * Steps:
   * 1. Check the timestamp.
   * 2. Require an ACTIVE connection for the sender's instanceId.
   * 3. Verify the signature.
   * 4. Dispatch `agent.`-prefixed commands to FederationAgentService.
   *
   * Any other command type raises UnknownCommandTypeError. It is reported as an
   * unsuccessful response, assuming it extends CommandProcessingError (NOTE(review):
   * confirm). Any other (system) error is logged and re-thrown.
   *
   * NOTE(review): the incoming messageId is not checked for duplicates here, so
   * a captured command could be re-submitted within the accepted timestamp window.
   *
   * @param commandMessage - The signed command as received.
   * @returns A response signed by the local instance, correlated to the command.
   * @throws Error on an invalid timestamp, an unknown/inactive sender, or a bad signature.
   */
  async handleIncomingCommand(commandMessage: CommandMessage): Promise<CommandResponse> {
    this.logger.log(
      `Received command from ${commandMessage.instanceId}: ${commandMessage.messageId}`
    );

    if (!this.signatureService.validateTimestamp(commandMessage.timestamp)) {
      throw new Error("Command timestamp is outside acceptable range");
    }

    // Status enforced in the query: only ACTIVE connections may send commands.
    const connection = await this.prisma.federationConnection.findFirst({
      where: {
        remoteInstanceId: commandMessage.instanceId,
        status: FederationConnectionStatus.ACTIVE,
      },
    });
    if (!connection) {
      throw new Error("No connection found for remote instance");
    }

    // The signature covers every field except the signature itself.
    const { signature, ...messageToVerify } = commandMessage;
    const verificationResult = await this.signatureService.verifyMessage(
      messageToVerify,
      signature,
      commandMessage.instanceId
    );
    if (!verificationResult.valid) {
      throw new Error(verificationResult.error ?? "Invalid signature");
    }

    let responseData: unknown;
    let success = true;
    let errorMessage: string | undefined;
    try {
      if (commandMessage.commandType.startsWith("agent.")) {
        // Import FederationAgentService dynamically to avoid circular dependency
        const { FederationAgentService } = await import("./federation-agent.service");
        const federationAgentService = this.moduleRef.get(FederationAgentService, {
          strict: false,
        });
        const agentResponse = await federationAgentService.handleAgentCommand(
          commandMessage.instanceId,
          commandMessage.commandType,
          commandMessage.payload
        );
        success = agentResponse.success;
        responseData = agentResponse.data;
        errorMessage = agentResponse.error;
      } else {
        throw new UnknownCommandTypeError(commandMessage.commandType);
      }
    } catch (error) {
      // Only expected business-logic errors become a failed response; system
      // errors (DB failures, network issues, ...) propagate to the caller.
      if (error instanceof CommandProcessingError) {
        success = false;
        errorMessage = error.message;
        this.logger.warn(`Command processing failed (business logic): ${errorMessage}`, {
          commandType: commandMessage.commandType,
          instanceId: commandMessage.instanceId,
          messageId: commandMessage.messageId,
        });
      } else {
        this.logger.error(`System error during command processing: ${String(error)}`, {
          commandType: commandMessage.commandType,
          instanceId: commandMessage.instanceId,
          messageId: commandMessage.messageId,
          error: error instanceof Error ? error.stack : String(error),
        });
        throw error;
      }
    }

    const identity = await this.federationService.getInstanceIdentity();
    const responseMessageId = randomUUID();
    const responseTimestamp = Date.now();
    const responsePayload: Record<string, unknown> = {
      messageId: responseMessageId,
      correlationId: commandMessage.messageId,
      instanceId: identity.instanceId,
      success,
      timestamp: responseTimestamp,
    };
    if (responseData !== undefined) {
      responsePayload.data = responseData;
    }
    if (errorMessage !== undefined) {
      responsePayload.error = errorMessage;
    }

    const responseSignature = await this.signatureService.signMessage(responsePayload);

    // Emit fields in exactly the order they were signed in, so the receiver
    // (which strips `signature` and re-verifies the rest) sees an identical
    // object even if SignatureService serialises without sorting keys.
    return {
      messageId: responseMessageId,
      correlationId: commandMessage.messageId,
      instanceId: identity.instanceId,
      success,
      timestamp: responseTimestamp,
      ...(responseData !== undefined ? { data: responseData } : {}),
      ...(errorMessage !== undefined ? { error: errorMessage } : {}),
      signature: responseSignature,
    } as CommandResponse;
  }

  /**
   * List command messages for a workspace, newest first.
   *
   * @param workspaceId - Workspace whose messages are listed.
   * @param status - Optional status filter; when omitted, all statuses are returned.
   * @returns Mapped command message details ordered by `createdAt` descending.
   */
  async getCommandMessages(
    workspaceId: string,
    status?: FederationMessageStatus
  ): Promise<CommandMessageDetails[]> {
    const messages = await this.prisma.federationMessage.findMany({
      where: {
        workspaceId,
        messageType: FederationMessageType.COMMAND,
        ...(status ? { status } : {}),
      },
      orderBy: { createdAt: "desc" },
    });
    return messages.map((msg) => this.mapToCommandMessageDetails(msg));
  }

  /**
   * Get a single command message belonging to a workspace.
   *
   * Only COMMAND messages are matched, so ids of other federation message
   * types are reported as not found.
   *
   * @param workspaceId - Workspace that must own the message.
   * @param messageId - Database id of the message (not the federation messageId).
   * @throws Error("Command message not found") when no matching COMMAND message exists.
   */
  async getCommandMessage(workspaceId: string, messageId: string): Promise<CommandMessageDetails> {
    const message = await this.prisma.federationMessage.findUnique({
      where: {
        id: messageId,
        workspaceId,
        messageType: FederationMessageType.COMMAND,
      },
    });
    if (!message) {
      throw new Error("Command message not found");
    }
    return this.mapToCommandMessageDetails(message);
  }

  /**
   * Record a remote instance's response to a previously sent command.
   *
   * The response is accepted only when all of the following hold:
   * - its timestamp passes validation;
   * - the original COMMAND message exists;
   * - the responding instance is the remote end of that command's connection,
   *   and that connection is still ACTIVE;
   * - the signature verifies against the responding instance.
   *
   * @param response - The signed command response as received.
   * @throws Error when any of the checks above fails.
   */
  async processCommandResponse(response: CommandResponse): Promise<void> {
    this.logger.log(`Received response for command: ${response.correlationId}`);

    if (!this.signatureService.validateTimestamp(response.timestamp)) {
      throw new Error("Response timestamp is outside acceptable range");
    }

    const message = await this.prisma.federationMessage.findFirst({
      where: {
        messageId: response.correlationId,
        messageType: FederationMessageType.COMMAND,
      },
    });
    if (!message) {
      throw new Error("Original command message not found");
    }

    // The signature below only proves the response came from
    // `response.instanceId`. Bind that sender to the connection the command
    // was actually sent over (and require it to be ACTIVE, enforced in the
    // query) so one peer cannot overwrite the result of another peer's command.
    const connection = await this.prisma.federationConnection.findFirst({
      where: {
        id: message.connectionId,
        remoteInstanceId: response.instanceId,
        status: FederationConnectionStatus.ACTIVE,
      },
    });
    if (!connection) {
      throw new Error("Response sender does not match an active connection for this command");
    }

    const { signature, ...responseToVerify } = response;
    const verificationResult = await this.signatureService.verifyMessage(
      responseToVerify,
      signature,
      response.instanceId
    );
    if (!verificationResult.valid) {
      throw new Error(verificationResult.error ?? "Invalid signature");
    }

    await this.prisma.federationMessage.update({
      where: { id: message.id },
      data: {
        status: response.success
          ? FederationMessageStatus.DELIVERED
          : FederationMessageStatus.FAILED,
        deliveredAt: new Date(),
        ...(response.data !== undefined ? { response: response.data as never } : {}),
        ...(response.error !== undefined ? { error: response.error } : {}),
      },
    });

    this.logger.log(`Command response processed: ${response.correlationId}`);
  }

  /**
   * Map a Prisma FederationMessage row to CommandMessageDetails.
   *
   * Nullable columns are omitted from the result rather than copied as `null`.
   * `payload` is copied only when it is a non-null object.
   */
  private mapToCommandMessageDetails(message: {
    id: string;
    workspaceId: string;
    connectionId: string;
    messageType: FederationMessageType;
    messageId: string;
    correlationId: string | null;
    query: string | null;
    commandType: string | null;
    payload: unknown;
    response: unknown;
    status: FederationMessageStatus;
    error: string | null;
    createdAt: Date;
    updatedAt: Date;
    deliveredAt: Date | null;
  }): CommandMessageDetails {
    const details: CommandMessageDetails = {
      id: message.id,
      workspaceId: message.workspaceId,
      connectionId: message.connectionId,
      messageType: message.messageType,
      messageId: message.messageId,
      response: message.response,
      status: message.status,
      createdAt: message.createdAt,
      updatedAt: message.updatedAt,
    };
    if (message.correlationId !== null) {
      details.correlationId = message.correlationId;
    }
    if (message.commandType !== null) {
      details.commandType = message.commandType;
    }
    if (message.payload !== null && typeof message.payload === "object") {
      details.payload = message.payload as Record<string, unknown>;
    }
    if (message.error !== null) {
      details.error = message.error;
    }
    if (message.deliveredAt !== null) {
      details.deliveredAt = message.deliveredAt;
    }
    return details;
  }
}