Move connection status validation from post-retrieval checks into Prisma WHERE clauses, so only ACTIVE connections are ever retrieved. Remove the now-redundant post-retrieval status checks in both the query and command services. Security improvement: status=ACTIVE is enforced in the database query rather than checked after retrieval, which removes the separate check step (note: a connection can still be deactivated between the read and its use, so this narrows rather than eliminates the TOCTOU window).
Fixes #283
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
400 lines
12 KiB
TypeScript
400 lines
12 KiB
TypeScript
/**
 * Command Service
 *
 * Sends, receives, and tracks signed federated command messages exchanged with remote instances.
 */
|
|
|
|
import { Injectable, Logger } from "@nestjs/common";
|
|
import { ModuleRef } from "@nestjs/core";
|
|
import { HttpService } from "@nestjs/axios";
|
|
import { randomUUID } from "crypto";
|
|
import { firstValueFrom } from "rxjs";
|
|
import { PrismaService } from "../prisma/prisma.service";
|
|
import { FederationService } from "./federation.service";
|
|
import { SignatureService } from "./signature.service";
|
|
import {
|
|
FederationConnectionStatus,
|
|
FederationMessageType,
|
|
FederationMessageStatus,
|
|
} from "@prisma/client";
|
|
import type { CommandMessage, CommandResponse, CommandMessageDetails } from "./types/message.types";
|
|
import { CommandProcessingError, UnknownCommandTypeError } from "./errors/command.errors";
|
|
|
|
@Injectable()
export class CommandService {
  private readonly logger = new Logger(CommandService.name);

  constructor(
    private readonly prisma: PrismaService,
    private readonly federationService: FederationService,
    private readonly signatureService: SignatureService,
    private readonly httpService: HttpService,
    private readonly moduleRef: ModuleRef
  ) {}

  /**
   * Send a signed command to the remote instance of an ACTIVE connection.
   *
   * The command is stored as a PENDING FederationMessage first, then POSTed to
   * `${connection.remoteUrl}/api/v1/federation/incoming/command`. If the POST fails,
   * the stored message is marked FAILED with the error text.
   *
   * @param workspaceId - Workspace that must own the connection; also recorded on the message.
   * @param connectionId - Federation connection to send through.
   * @param commandType - Command identifier, included in the signed message.
   * @param payload - Command arguments; stored and sent unchanged.
   * @returns Details of the stored command message (as created, i.e. PENDING).
   * @throws Error("Connection not found") if no ACTIVE connection with this id exists in the workspace.
   * @throws Error("Failed to send command") if the HTTP request to the remote instance fails.
   */
  async sendCommand(
    workspaceId: string,
    connectionId: string,
    commandType: string,
    payload: Record<string, unknown>
  ): Promise<CommandMessageDetails> {
    // Status is part of the WHERE clause, so a non-ACTIVE connection is never loaded.
    // NOTE(review): the connection can still be deactivated between this read and the
    // HTTP request below; this only removes the separate post-read status check.
    const connection = await this.prisma.federationConnection.findUnique({
      where: {
        id: connectionId,
        workspaceId,
        status: FederationConnectionStatus.ACTIVE,
      },
    });

    if (!connection) {
      throw new Error("Connection not found");
    }

    const identity = await this.federationService.getInstanceIdentity();

    const messageId = randomUUID();
    const timestamp = Date.now();

    // Exactly these fields are covered by the signature. The receiving side
    // (handleIncomingCommand) strips `signature` and verifies the remainder.
    const commandPayload: Record<string, unknown> = {
      messageId,
      instanceId: identity.instanceId,
      commandType,
      payload,
      timestamp,
    };

    const signature = await this.signatureService.signMessage(commandPayload);

    const signedCommand = {
      messageId,
      instanceId: identity.instanceId,
      commandType,
      payload,
      timestamp,
      signature,
    } as CommandMessage;

    // Persist before sending so the outgoing command is tracked even if delivery fails.
    const message = await this.prisma.federationMessage.create({
      data: {
        workspaceId,
        connectionId,
        messageType: FederationMessageType.COMMAND,
        messageId,
        commandType,
        payload: payload as never,
        status: FederationMessageStatus.PENDING,
        signature,
      },
    });

    const remoteUrl = `${connection.remoteUrl}/api/v1/federation/incoming/command`;

    try {
      // NOTE(review): the HTTP response body is discarded here; the stored message stays
      // PENDING until processCommandResponse is invoked for it — confirm that path exists.
      await firstValueFrom(this.httpService.post(remoteUrl, signedCommand));

      this.logger.log(`Command sent to ${connection.remoteUrl}: ${messageId}`);
    } catch (error) {
      this.logger.error(`Failed to send command to ${connection.remoteUrl}`, error);

      await this.markMessageFailed(message.id, error);

      throw new Error("Failed to send command");
    }

    return this.mapToCommandMessageDetails(message);
  }

  /**
   * Best-effort: mark a stored message as FAILED with the given error's message.
   *
   * A failure of the status update itself is logged rather than thrown, so it does not
   * replace the original send error that the caller is about to report.
   *
   * @param id - Database id of the FederationMessage.
   * @param cause - The error that caused the failure; non-Error values record "Unknown error".
   */
  private async markMessageFailed(id: string, cause: unknown): Promise<void> {
    try {
      await this.prisma.federationMessage.update({
        where: { id },
        data: {
          status: FederationMessageStatus.FAILED,
          error: cause instanceof Error ? cause.message : "Unknown error",
        },
      });
    } catch (updateError) {
      this.logger.error(`Failed to mark command message ${id} as failed`, updateError);
    }
  }

  /**
   * Handle a command received from a remote instance and produce a signed response.
   *
   * Steps: validate the timestamp, require an ACTIVE connection whose remoteInstanceId
   * is the sender's instanceId, verify the signature over the message minus `signature`,
   * execute the command, then sign and return a response correlated to the command's
   * messageId.
   *
   * @param commandMessage - The incoming signed command.
   * @returns A signed CommandResponse; `success: false` for business-logic failures.
   * @throws Error if the timestamp is out of range, no ACTIVE connection exists for the
   *   sender, or the signature is invalid. Non-business errors raised while executing the
   *   command are re-thrown unchanged.
   */
  async handleIncomingCommand(commandMessage: CommandMessage): Promise<CommandResponse> {
    this.logger.log(
      `Received command from ${commandMessage.instanceId}: ${commandMessage.messageId}`
    );

    if (!this.signatureService.validateTimestamp(commandMessage.timestamp)) {
      throw new Error("Command timestamp is outside acceptable range");
    }

    // Only an ACTIVE connection to the sending instance is accepted (enforced in the query).
    const connection = await this.prisma.federationConnection.findFirst({
      where: {
        remoteInstanceId: commandMessage.instanceId,
        status: FederationConnectionStatus.ACTIVE,
      },
    });

    if (!connection) {
      throw new Error("No connection found for remote instance");
    }

    const { signature, ...messageToVerify } = commandMessage;
    const verificationResult = await this.signatureService.verifyMessage(
      messageToVerify,
      signature,
      commandMessage.instanceId
    );

    if (!verificationResult.valid) {
      throw new Error(verificationResult.error ?? "Invalid signature");
    }

    const outcome = await this.executeCommand(commandMessage);

    return this.buildSignedResponse(commandMessage.messageId, outcome);
  }

  /**
   * Execute an already-authenticated incoming command and describe its outcome.
   *
   * `agent.*` commands are delegated to FederationAgentService.handleAgentCommand. That
   * service is imported dynamically and resolved through ModuleRef to avoid a circular
   * dependency. Any other command type raises UnknownCommandTypeError.
   *
   * CommandProcessingError instances (expected business-logic failures) are converted
   * to `success: false`. Any other error (system failure) is logged and re-thrown.
   */
  private async executeCommand(
    commandMessage: CommandMessage
  ): Promise<{ success: boolean; data: unknown; error: string | undefined }> {
    try {
      if (!commandMessage.commandType.startsWith("agent.")) {
        throw new UnknownCommandTypeError(commandMessage.commandType);
      }

      const { FederationAgentService } = await import("./federation-agent.service");
      const federationAgentService = this.moduleRef.get(FederationAgentService, {
        strict: false,
      });

      const agentResponse = await federationAgentService.handleAgentCommand(
        commandMessage.instanceId,
        commandMessage.commandType,
        commandMessage.payload
      );

      return {
        success: agentResponse.success,
        data: agentResponse.data,
        error: agentResponse.error,
      };
    } catch (error) {
      if (error instanceof CommandProcessingError) {
        const errorMessage = error.message;
        this.logger.warn(`Command processing failed (business logic): ${errorMessage}`, {
          commandType: commandMessage.commandType,
          instanceId: commandMessage.instanceId,
          messageId: commandMessage.messageId,
        });
        return { success: false, data: undefined, error: errorMessage };
      }

      // System error (DB failure, network issue, ...): log and re-throw unchanged.
      this.logger.error(`System error during command processing: ${String(error)}`, {
        commandType: commandMessage.commandType,
        instanceId: commandMessage.instanceId,
        messageId: commandMessage.messageId,
        error: error instanceof Error ? error.stack : String(error),
      });
      throw error;
    }
  }

  /**
   * Build and sign the response for a processed command.
   *
   * `data` and `error` are included only when defined, both in the signed payload and in
   * the returned object, so the signature covers exactly the fields that are sent.
   *
   * @param correlationId - messageId of the command being answered.
   * @param outcome - Result of executeCommand.
   */
  private async buildSignedResponse(
    correlationId: string,
    outcome: { success: boolean; data: unknown; error: string | undefined }
  ): Promise<CommandResponse> {
    const identity = await this.federationService.getInstanceIdentity();

    const responseMessageId = randomUUID();
    const responseTimestamp = Date.now();

    const responsePayload: Record<string, unknown> = {
      messageId: responseMessageId,
      correlationId,
      instanceId: identity.instanceId,
      success: outcome.success,
      timestamp: responseTimestamp,
    };

    if (outcome.data !== undefined) {
      responsePayload.data = outcome.data;
    }

    if (outcome.error !== undefined) {
      responsePayload.error = outcome.error;
    }

    const responseSignature = await this.signatureService.signMessage(responsePayload);

    return {
      messageId: responseMessageId,
      correlationId,
      instanceId: identity.instanceId,
      success: outcome.success,
      ...(outcome.data !== undefined ? { data: outcome.data } : {}),
      ...(outcome.error !== undefined ? { error: outcome.error } : {}),
      timestamp: responseTimestamp,
      signature: responseSignature,
    } as CommandResponse;
  }

  /**
   * List command messages for a workspace, newest first.
   *
   * @param workspaceId - Workspace to list messages for.
   * @param status - Optional status filter; when omitted, messages of every status are returned.
   */
  async getCommandMessages(
    workspaceId: string,
    status?: FederationMessageStatus
  ): Promise<CommandMessageDetails[]> {
    const messages = await this.prisma.federationMessage.findMany({
      where: {
        workspaceId,
        messageType: FederationMessageType.COMMAND,
        ...(status ? { status } : {}),
      },
      orderBy: { createdAt: "desc" },
    });

    return messages.map((msg) => this.mapToCommandMessageDetails(msg));
  }

  /**
   * Fetch a single command message by its database id within a workspace.
   *
   * @throws Error("Command message not found") if no COMMAND message with this id exists
   *   in the workspace.
   */
  async getCommandMessage(workspaceId: string, messageId: string): Promise<CommandMessageDetails> {
    // messageType is part of the filter (as in getCommandMessages) so this endpoint cannot
    // return a non-command message from the same workspace.
    const message = await this.prisma.federationMessage.findUnique({
      where: {
        id: messageId,
        workspaceId,
        messageType: FederationMessageType.COMMAND,
      },
    });

    if (!message) {
      throw new Error("Command message not found");
    }

    return this.mapToCommandMessageDetails(message);
  }

  /**
   * Apply a response from a remote instance to the command it answers.
   *
   * The response must have a valid timestamp and signature. It must correlate to a stored
   * COMMAND message, and it must come from the instance that command was sent to. The
   * message is then marked DELIVERED (success) or FAILED, with `deliveredAt` set and any
   * response data/error recorded.
   *
   * @throws Error if the timestamp is out of range, the signature is invalid, the
   *   original command is unknown, or the responder is not the command's recipient.
   */
  async processCommandResponse(response: CommandResponse): Promise<void> {
    this.logger.log(`Received response for command: ${response.correlationId}`);

    if (!this.signatureService.validateTimestamp(response.timestamp)) {
      throw new Error("Response timestamp is outside acceptable range");
    }

    // Authenticate before touching the database, so an unauthenticated caller cannot use
    // the "not found" error to probe which correlation IDs exist.
    const { signature, ...responseToVerify } = response;
    const verificationResult = await this.signatureService.verifyMessage(
      responseToVerify,
      signature,
      response.instanceId
    );

    if (!verificationResult.valid) {
      throw new Error(verificationResult.error ?? "Invalid signature");
    }

    const message = await this.prisma.federationMessage.findFirst({
      where: {
        messageId: response.correlationId,
        messageType: FederationMessageType.COMMAND,
      },
    });

    if (!message) {
      throw new Error("Original command message not found");
    }

    // A valid signature only proves who sent the response, not that the sender received
    // this command. Require the command's connection to point at the responding instance,
    // so one federated peer cannot complete a command that was sent to another peer.
    const connection = await this.prisma.federationConnection.findFirst({
      where: {
        id: message.connectionId,
        remoteInstanceId: response.instanceId,
      },
    });

    if (!connection) {
      throw new Error("Response sender does not match command recipient");
    }

    const updateData: Record<string, unknown> = {
      status: response.success ? FederationMessageStatus.DELIVERED : FederationMessageStatus.FAILED,
      deliveredAt: new Date(),
    };

    if (response.data !== undefined) {
      updateData.response = response.data;
    }

    if (response.error !== undefined) {
      updateData.error = response.error;
    }

    await this.prisma.federationMessage.update({
      where: { id: message.id },
      data: updateData,
    });

    this.logger.log(`Command response processed: ${response.correlationId}`);
  }

  /**
   * Map a Prisma FederationMessage row to CommandMessageDetails.
   *
   * Nullable columns (correlationId, commandType, error, deliveredAt) are copied only when
   * non-null. `payload` is copied only when it is a non-null object. Absent values are
   * omitted rather than set to undefined.
   */
  private mapToCommandMessageDetails(message: {
    id: string;
    workspaceId: string;
    connectionId: string;
    messageType: FederationMessageType;
    messageId: string;
    correlationId: string | null;
    query: string | null;
    commandType: string | null;
    payload: unknown;
    response: unknown;
    status: FederationMessageStatus;
    error: string | null;
    createdAt: Date;
    updatedAt: Date;
    deliveredAt: Date | null;
  }): CommandMessageDetails {
    const details: CommandMessageDetails = {
      id: message.id,
      workspaceId: message.workspaceId,
      connectionId: message.connectionId,
      messageType: message.messageType,
      messageId: message.messageId,
      response: message.response,
      status: message.status,
      createdAt: message.createdAt,
      updatedAt: message.updatedAt,
    };

    if (message.correlationId !== null) {
      details.correlationId = message.correlationId;
    }

    if (message.commandType !== null) {
      details.commandType = message.commandType;
    }

    if (message.payload !== null && typeof message.payload === "object") {
      details.payload = message.payload as Record<string, unknown>;
    }

    if (message.error !== null) {
      details.error = message.error;
    }

    if (message.deliveredAt !== null) {
      details.deliveredAt = message.deliveredAt;
    }

    return details;
  }
}
|