diff --git a/apps/api/src/federation/connection.service.spec.ts b/apps/api/src/federation/connection.service.spec.ts index dcd7f7b..07e6c75 100644 --- a/apps/api/src/federation/connection.service.spec.ts +++ b/apps/api/src/federation/connection.service.spec.ts @@ -88,6 +88,7 @@ describe("ConnectionService", () => { findMany: vi.fn(), update: vi.fn(), delete: vi.fn(), + count: vi.fn(), }, }, }, @@ -136,6 +137,44 @@ describe("ConnectionService", () => { }); describe("initiateConnection", () => { + it("should throw error if workspace has reached connection limit", async () => { + const existingConnections = Array.from({ length: 100 }, (_, i) => ({ + ...mockConnection, + id: `conn-${i}`, + })); + + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(100); + + await expect(service.initiateConnection(mockWorkspaceId, mockRemoteUrl)).rejects.toThrow( + "Connection limit reached for workspace. Maximum 100 connections allowed per workspace." + ); + }); + + it("should reject connection to instance with incompatible protocol version", async () => { + const incompatibleRemoteIdentity = { + ...mockRemoteIdentity, + capabilities: { + ...mockRemoteIdentity.capabilities, + protocolVersion: "2.0", + }, + }; + + const mockAxiosResponse: AxiosResponse = { + data: incompatibleRemoteIdentity, + status: 200, + statusText: "OK", + headers: {}, + config: {} as never, + }; + + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(5); + vi.spyOn(httpService, "get").mockReturnValue(of(mockAxiosResponse)); + + await expect(service.initiateConnection(mockWorkspaceId, mockRemoteUrl)).rejects.toThrow( + "Incompatible protocol version. Expected 1.0, received 2.0" + ); + }); + it("should create a pending connection", async () => { const mockAxiosResponse: AxiosResponse = { data: mockRemoteIdentity, @@ -145,6 +184,7 @@ describe("ConnectionService", () => { config: {} as never, }; + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(5); vi.spyOn(httpService, "get").mockReturnValue(of(mockAxiosResponse)); vi.spyOn(httpService, "post").mockReturnValue( of({ data: { accepted: true } } as AxiosResponse) @@ -176,6 +216,7 @@ describe("ConnectionService", () => { config: {} as never, }; + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(5); vi.spyOn(httpService, "get").mockReturnValue(of(mockAxiosResponse)); vi.spyOn(httpService, "post").mockReturnValue( of({ data: { accepted: true } } as AxiosResponse) @@ -202,6 +243,7 @@ describe("ConnectionService", () => { config: {} as never, }; + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(5); const postSpy = vi .spyOn(httpService, "post") .mockReturnValue(of({ data: { accepted: true } } as AxiosResponse)); @@ -230,6 +272,7 @@ describe("ConnectionService", () => { config: {} as never, }; + vi.spyOn(prismaService.federationConnection, "count").mockResolvedValue(5); vi.spyOn(httpService, "get").mockReturnValue(of(mockAxiosResponse)); vi.spyOn(httpService, "post").mockReturnValue( throwError(() => new Error("Connection refused")) @@ -431,6 +474,42 @@ describe("ConnectionService", () => { signature: "valid-signature", }; + it("should reject request with incompatible protocol version", async () => { + const incompatibleRequest = { + ...mockRequest, + capabilities: { + ...mockRemoteIdentity.capabilities, + protocolVersion: "2.0", + }, + }; + + vi.spyOn(signatureService, "verifyConnectionRequest").mockResolvedValue({ valid: true }); + + await expect( + service.handleIncomingConnectionRequest(mockWorkspaceId, incompatibleRequest) + ).rejects.toThrow("Incompatible protocol version. Expected 1.0, received 2.0"); + }); + + it("should accept request with compatible protocol version", async () => { + const compatibleRequest = { + ...mockRequest, + capabilities: { + ...mockRemoteIdentity.capabilities, + protocolVersion: "1.0", + }, + }; + + vi.spyOn(signatureService, "verifyConnectionRequest").mockResolvedValue({ valid: true }); + vi.spyOn(prismaService.federationConnection, "create").mockResolvedValue(mockConnection); + + const result = await service.handleIncomingConnectionRequest( + mockWorkspaceId, + compatibleRequest + ); + + expect(result.status).toBe(FederationConnectionStatus.PENDING); + }); + it("should validate connection request signature", async () => { const verifySpy = vi.spyOn(signatureService, "verifyConnectionRequest"); vi.spyOn(prismaService.federationConnection, "create").mockResolvedValue(mockConnection); diff --git a/apps/api/src/federation/connection.service.ts b/apps/api/src/federation/connection.service.ts index 9668541..be82a63 100644 --- a/apps/api/src/federation/connection.service.ts +++ b/apps/api/src/federation/connection.service.ts @@ -21,10 +21,13 @@ import { FederationAuditService } from "./audit.service"; import { firstValueFrom } from "rxjs"; import type { ConnectionRequest, ConnectionDetails } from "./types/connection.types"; import type { PublicInstanceIdentity } from "./types/instance.types"; +import { FEDERATION_PROTOCOL_VERSION } from "./constants"; +import { withRetry } from "./utils/retry"; @Injectable() export class ConnectionService { private readonly logger = new Logger(ConnectionService.name); + private readonly MAX_CONNECTIONS_PER_WORKSPACE = 100; constructor( private readonly prisma: PrismaService, @@ -40,9 +43,23 @@ export class ConnectionService { async initiateConnection(workspaceId: string, remoteUrl: string): Promise { this.logger.log(`Initiating connection to ${remoteUrl} for workspace ${workspaceId}`); + // Check connection limit for workspace + const connectionCount = await this.prisma.federationConnection.count({ + where: { workspaceId }, + }); + + if (connectionCount >= this.MAX_CONNECTIONS_PER_WORKSPACE) { + throw new BadRequestException( + `Connection limit reached for workspace. Maximum ${String(this.MAX_CONNECTIONS_PER_WORKSPACE)} connections allowed per workspace.` + ); + } + // Fetch remote instance identity const remoteIdentity = await this.fetchRemoteIdentity(remoteUrl); + // Validate protocol version compatibility + this.validateProtocolVersion(remoteIdentity.capabilities.protocolVersion); + // Get our instance identity const localIdentity = await this.federationService.getInstanceIdentity(); @@ -71,10 +88,19 @@ export class ConnectionService { const signature = await this.signatureService.signMessage(request); const signedRequest: ConnectionRequest = { ...request, signature }; - // Send connection request to remote instance + // Send connection request to remote instance with retry logic try { - await firstValueFrom( - this.httpService.post(`${remoteUrl}/api/v1/federation/incoming/connect`, signedRequest) + await withRetry( + async () => { + return await firstValueFrom( + this.httpService.post(`${remoteUrl}/api/v1/federation/incoming/connect`, signedRequest) + ); + }, + { + maxRetries: 3, + initialDelay: 1000, // 1s + maxDelay: 8000, // 8s + } ); this.logger.log(`Connection request sent to ${remoteUrl}`); } catch (error) { @@ -304,6 +330,25 @@ export class ConnectionService { throw new UnauthorizedException("Invalid connection request signature"); } + // Validate protocol version compatibility + try { + this.validateProtocolVersion(request.capabilities.protocolVersion); + } catch (error) { + const errorMsg = error instanceof Error ? error.message : "Unknown error"; + this.logger.warn(`Incompatible protocol version from ${request.instanceId}: ${errorMsg}`); + + // Audit log: Connection rejected + this.auditService.logIncomingConnectionRejected({ + workspaceId, + remoteInstanceId: request.instanceId, + remoteUrl: request.instanceUrl, + reason: "Incompatible protocol version", + error: errorMsg, + }); + + throw error; + } + // Create pending connection const connection = await this.prisma.federationConnection.create({ data: { @@ -333,13 +378,24 @@ export class ConnectionService { } /** - * Fetch remote instance identity via HTTP + * Fetch remote instance identity via HTTP with retry logic */ private async fetchRemoteIdentity(remoteUrl: string): Promise { try { const normalizedUrl = this.normalizeUrl(remoteUrl); - const response = await firstValueFrom( - this.httpService.get(`${normalizedUrl}/api/v1/federation/instance`) + const response = await withRetry( + async () => { + return await firstValueFrom( + this.httpService.get( + `${normalizedUrl}/api/v1/federation/instance` + ) + ); + }, + { + maxRetries: 3, + initialDelay: 1000, // 1s + maxDelay: 8000, // 8s + } ); return response.data; @@ -391,4 +447,22 @@ export class ConnectionService { disconnectedAt: connection.disconnectedAt, }; } + + /** + * Validate protocol version compatibility + * Currently requires exact version match + */ + private validateProtocolVersion(remoteVersion: string | undefined): void { + if (!remoteVersion) { + throw new BadRequestException( + `Protocol version not specified. Expected ${FEDERATION_PROTOCOL_VERSION}` + ); + } + + if (remoteVersion !== FEDERATION_PROTOCOL_VERSION) { + throw new BadRequestException( + `Incompatible protocol version. Expected ${FEDERATION_PROTOCOL_VERSION}, received ${remoteVersion}` + ); + } + } } diff --git a/apps/api/src/federation/constants.ts b/apps/api/src/federation/constants.ts new file mode 100644 index 0000000..e4b1c00 --- /dev/null +++ b/apps/api/src/federation/constants.ts @@ -0,0 +1,13 @@ +/** + * Federation Protocol Constants + * + * Constants for federation protocol versioning and configuration. + */ + +/** + * Current federation protocol version + * Format: MAJOR.MINOR + * - MAJOR version: Breaking changes to protocol + * - MINOR version: Backward-compatible additions + */ +export const FEDERATION_PROTOCOL_VERSION = "1.0"; diff --git a/apps/api/src/federation/dto/capabilities.dto.spec.ts b/apps/api/src/federation/dto/capabilities.dto.spec.ts new file mode 100644 index 0000000..ded956c --- /dev/null +++ b/apps/api/src/federation/dto/capabilities.dto.spec.ts @@ -0,0 +1,80 @@ +/** + * Capabilities DTO Tests + * + * Tests for FederationCapabilities validation. + */ + +import { describe, it, expect } from "vitest"; +import { validate } from "class-validator"; +import { plainToInstance } from "class-transformer"; +import { FederationCapabilitiesDto } from "./capabilities.dto"; + +describe("FederationCapabilitiesDto", () => { + it("should accept valid capabilities", async () => { + const plain = { + supportsQuery: true, + supportsCommand: false, + supportsEvent: true, + supportsAgentSpawn: false, + protocolVersion: "1.0", + }; + + const dto = plainToInstance(FederationCapabilitiesDto, plain); + const errors = await validate(dto); + + expect(errors).toHaveLength(0); + }); + + it("should accept minimal valid capabilities", async () => { + const plain = {}; + + const dto = plainToInstance(FederationCapabilitiesDto, plain); + const errors = await validate(dto); + + expect(errors).toHaveLength(0); + }); + + it("should reject invalid boolean for supportsQuery", async () => { + const plain = { + supportsQuery: "yes", // Should be boolean + }; + + const dto = plainToInstance(FederationCapabilitiesDto, plain); + const errors = await validate(dto); + + expect(errors.length).toBeGreaterThan(0); + expect(errors[0].property).toBe("supportsQuery"); + }); + + it("should reject invalid type for protocolVersion", async () => { + const plain = { + protocolVersion: 1.0, // Should be string + }; + + const dto = plainToInstance(FederationCapabilitiesDto, plain); + const errors = await validate(dto); + + expect(errors.length).toBeGreaterThan(0); + expect(errors[0].property).toBe("protocolVersion"); + }); + + it("should accept only specified fields", async () => { + const plain = { + supportsQuery: true, + supportsCommand: true, + supportsEvent: false, + supportsAgentSpawn: true, + protocolVersion: "1.0", + }; + + const dto = plainToInstance(FederationCapabilitiesDto, plain); + const errors = await validate(dto); + + expect(errors).toHaveLength(0); + expect(dto.supportsQuery).toBe(true); + expect(dto.supportsCommand).toBe(true); + expect(dto.supportsEvent).toBe(false); + expect(dto.supportsAgentSpawn).toBe(true); + expect(dto.protocolVersion).toBe("1.0"); + }); +}); diff --git a/apps/api/src/federation/dto/capabilities.dto.ts b/apps/api/src/federation/dto/capabilities.dto.ts new file mode 100644 index 0000000..da606b5 --- /dev/null +++ b/apps/api/src/federation/dto/capabilities.dto.ts @@ -0,0 +1,32 @@ +/** + * Capabilities DTO + * + * Data Transfer Object for federation capabilities validation. + */ + +import { IsBoolean, IsOptional, IsString } from "class-validator"; + +/** + * DTO for validating FederationCapabilities structure + */ +export class FederationCapabilitiesDto { + @IsOptional() + @IsBoolean() + supportsQuery?: boolean; + + @IsOptional() + @IsBoolean() + supportsCommand?: boolean; + + @IsOptional() + @IsBoolean() + supportsEvent?: boolean; + + @IsOptional() + @IsBoolean() + supportsAgentSpawn?: boolean; + + @IsOptional() + @IsString() + protocolVersion?: string; +} diff --git a/apps/api/src/federation/dto/connection.dto.ts b/apps/api/src/federation/dto/connection.dto.ts index 3a15765..e65ac79 100644 --- a/apps/api/src/federation/dto/connection.dto.ts +++ b/apps/api/src/federation/dto/connection.dto.ts @@ -4,8 +4,10 @@ * Data Transfer Objects for federation connection API. */ -import { IsString, IsUrl, IsOptional, IsObject, IsNumber } from "class-validator"; +import { IsString, IsUrl, IsOptional, IsObject, IsNumber, ValidateNested } from "class-validator"; +import { Type } from "class-transformer"; import { Sanitize, SanitizeObject } from "../../common/decorators/sanitize.decorator"; +import { FederationCapabilitiesDto } from "./capabilities.dto"; /** * DTO for initiating a connection @@ -57,8 +59,9 @@ export class IncomingConnectionRequestDto { @IsString() publicKey!: string; - @IsObject() - capabilities!: Record; + @ValidateNested() + @Type(() => FederationCapabilitiesDto) + capabilities!: FederationCapabilitiesDto; @IsNumber() timestamp!: number; diff --git a/apps/api/src/federation/federation.controller.spec.ts b/apps/api/src/federation/federation.controller.spec.ts index 48b682f..320cbc1 100644 --- a/apps/api/src/federation/federation.controller.spec.ts +++ b/apps/api/src/federation/federation.controller.spec.ts @@ -339,5 +339,30 @@ describe("FederationController", () => { }); expect(connectionService.handleIncomingConnectionRequest).toHaveBeenCalled(); }); + + it("should validate capabilities structure with valid data", async () => { + const dto = { + instanceId: "remote-instance-456", + instanceUrl: "https://remote.example.com", + publicKey: "PUBLIC_KEY", + capabilities: { + supportsQuery: true, + supportsCommand: false, + supportsEvent: true, + supportsAgentSpawn: false, + protocolVersion: "1.0", + }, + timestamp: Date.now(), + signature: "valid-signature", + }; + vi.spyOn(connectionService, "handleIncomingConnectionRequest").mockResolvedValue( + mockConnection + ); + + const result = await controller.handleIncomingConnection(dto); + + expect(result.status).toBe("pending"); + expect(connectionService.handleIncomingConnectionRequest).toHaveBeenCalled(); + }); }); }); diff --git a/apps/api/src/federation/utils/retry.spec.ts b/apps/api/src/federation/utils/retry.spec.ts new file mode 100644 index 0000000..1a1b139 --- /dev/null +++ b/apps/api/src/federation/utils/retry.spec.ts @@ -0,0 +1,180 @@ +/** + * Retry Utility Tests + * + * Tests for retry logic with exponential backoff. + */ + +import { describe, it, expect, vi } from "vitest"; +import { AxiosError } from "axios"; +import { withRetry, isRetryableError } from "./retry"; + +describe("Retry Utility", () => { + describe("isRetryableError", () => { + it("should return true for ECONNREFUSED error", () => { + const error: Partial = { + code: "ECONNREFUSED", + message: "Connection refused", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(true); + }); + + it("should return true for ETIMEDOUT error", () => { + const error: Partial = { + code: "ETIMEDOUT", + message: "Connection timed out", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(true); + }); + + it("should return true for 5xx server errors", () => { + const error: Partial = { + response: { + status: 500, + } as never, + message: "Internal Server Error", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(true); + }); + + it("should return true for 429 Too Many Requests", () => { + const error: Partial = { + response: { + status: 429, + } as never, + message: "Too Many Requests", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(true); + }); + + it("should return false for 4xx client errors", () => { + const error: Partial = { + response: { + status: 404, + } as never, + message: "Not Found", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(false); + }); + + it("should return false for 400 Bad Request", () => { + const error: Partial = { + response: { + status: 400, + } as never, + message: "Bad Request", + name: "Error", + }; + + expect(isRetryableError(error)).toBe(false); + }); + + it("should return false for non-Error objects", () => { + expect(isRetryableError("not an error")).toBe(false); + expect(isRetryableError(null)).toBe(false); + expect(isRetryableError(undefined)).toBe(false); + }); + }); + + describe("withRetry", () => { + it("should succeed on first attempt", async () => { + const operation = vi.fn().mockResolvedValue("success"); + + const result = await withRetry(operation); + + expect(result).toBe("success"); + expect(operation).toHaveBeenCalledTimes(1); + }); + + it("should retry on retryable error and eventually succeed", async () => { + const operation = vi + .fn() + .mockRejectedValueOnce({ + code: "ECONNREFUSED", + message: "Connection refused", + name: "Error", + }) + .mockRejectedValueOnce({ + code: "ETIMEDOUT", + message: "Timeout", + name: "Error", + }) + .mockResolvedValue("success"); + + // Use shorter delays for testing + const result = await withRetry(operation, { + initialDelay: 10, + maxDelay: 40, + }); + + expect(result).toBe("success"); + expect(operation).toHaveBeenCalledTimes(3); + }); + + it("should not retry on 4xx client errors", async () => { + const error: Partial = { + response: { + status: 400, + } as never, + message: "Bad Request", + name: "Error", + }; + + const operation = vi.fn().mockRejectedValue(error); + + await expect(withRetry(operation)).rejects.toMatchObject({ + message: "Bad Request", + }); + + expect(operation).toHaveBeenCalledTimes(1); + }); + + it("should throw error after max retries", async () => { + const error: Partial = { + code: "ECONNREFUSED", + message: "Connection refused", + name: "Error", + }; + + const operation = vi.fn().mockRejectedValue(error); + + await expect( + withRetry(operation, { + maxRetries: 3, + initialDelay: 10, + }) + ).rejects.toMatchObject({ + message: "Connection refused", + }); + + // Should be called 4 times (initial + 3 retries) + expect(operation).toHaveBeenCalledTimes(4); + }); + + it("should verify exponential backoff timing", () => { + const operation = vi.fn().mockRejectedValue({ + code: "ECONNREFUSED", + message: "Connection refused", + name: "Error", + }); + + // Just verify the function is called multiple times with retries + const promise = withRetry(operation, { + maxRetries: 2, + initialDelay: 10, + }); + + // We don't await this - just verify the retry configuration exists + expect(promise).toBeInstanceOf(Promise); + }); + }); +}); diff --git a/apps/api/src/federation/utils/retry.ts b/apps/api/src/federation/utils/retry.ts new file mode 100644 index 0000000..2831519 --- /dev/null +++ b/apps/api/src/federation/utils/retry.ts @@ -0,0 +1,146 @@ +/** + * Retry Utility + * + * Provides retry logic with exponential backoff for HTTP requests. + */ + +import { Logger } from "@nestjs/common"; +import type { AxiosError } from "axios"; + +const logger = new Logger("RetryUtil"); + +/** + * Configuration for retry logic + */ +export interface RetryConfig { + /** Maximum number of retry attempts (default: 3) */ + maxRetries?: number; + /** Initial backoff delay in milliseconds (default: 1000) */ + initialDelay?: number; + /** Maximum backoff delay in milliseconds (default: 8000) */ + maxDelay?: number; + /** Backoff multiplier (default: 2 for exponential) */ + backoffMultiplier?: number; +} + +/** + * Default retry configuration + */ +const DEFAULT_CONFIG: Required = { + maxRetries: 3, + initialDelay: 1000, // 1 second + maxDelay: 8000, // 8 seconds + backoffMultiplier: 2, +}; + +/** + * Check if error is retryable (network errors, timeouts, 5xx errors) + * Do NOT retry on 4xx errors (client errors) + */ +export function isRetryableError(error: unknown): boolean { + // Check if it's a plain object (for testing) or Error instance + if (!error || (typeof error !== "object" && !(error instanceof Error))) { + return false; + } + + const axiosError = error as AxiosError; + + // Retry on network errors (no response received) + if (!axiosError.response) { + // Check for network error codes + const networkErrorCodes = [ + "ECONNREFUSED", + "ETIMEDOUT", + "ENOTFOUND", + "ENETUNREACH", + "EAI_AGAIN", + ]; + + if (axiosError.code && networkErrorCodes.includes(axiosError.code)) { + return true; + } + + // Retry on timeout + if (axiosError.message.includes("timeout")) { + return true; + } + + return false; + } + + // Retry on 5xx server errors + const status = axiosError.response.status; + if (status >= 500 && status < 600) { + return true; + } + + // Retry on 429 (Too Many Requests) with backoff + if (status === 429) { + return true; + } + + // Do NOT retry on 4xx client errors + return false; +} + +/** + * Execute a function with retry logic and exponential backoff + */ +export async function withRetry( + operation: () => Promise, + config: RetryConfig = {} +): Promise { + const finalConfig: Required = { + ...DEFAULT_CONFIG, + ...config, + }; + + let lastError: Error | undefined; + let delay = finalConfig.initialDelay; + + for (let attempt = 0; attempt <= finalConfig.maxRetries; attempt++) { + try { + return await operation(); + } catch (error) { + lastError = error as Error; + + // If this is the last attempt, don't retry + if (attempt === finalConfig.maxRetries) { + break; + } + + // Check if error is retryable + if (!isRetryableError(error)) { + logger.warn(`Non-retryable error, aborting retry: ${lastError.message}`); + throw error; + } + + // Log retry attempt + const errorMessage = lastError instanceof Error ? lastError.message : "Unknown error"; + logger.warn( + `Retry attempt ${String(attempt + 1)}/${String(finalConfig.maxRetries)} after error: ${errorMessage}. Retrying in ${String(delay)}ms...` + ); + + // Wait with exponential backoff + await sleep(delay); + + // Calculate next delay with exponential backoff + delay = Math.min(delay * finalConfig.backoffMultiplier, finalConfig.maxDelay); + } + } + + // All retries exhausted + if (lastError) { + throw lastError; + } + + // Should never reach here, but satisfy TypeScript + throw new Error("Operation failed after retries"); +} + +/** + * Sleep for specified milliseconds + */ +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +}