diff --git a/apps/gateway/src/agent/agent.service.ts b/apps/gateway/src/agent/agent.service.ts index 2e90197..4910c7e 100644 --- a/apps/gateway/src/agent/agent.service.ts +++ b/apps/gateway/src/agent/agent.service.ts @@ -16,6 +16,10 @@ import { ProviderService } from './provider.service.js'; import { createBrainTools } from './tools/brain-tools.js'; import { createCoordTools } from './tools/coord-tools.js'; import { createMemoryTools } from './tools/memory-tools.js'; +import { createFileTools } from './tools/file-tools.js'; +import { createGitTools } from './tools/git-tools.js'; +import { createShellTools } from './tools/shell-tools.js'; +import { createWebTools } from './tools/web-tools.js'; import type { SessionInfoDto } from './session.dto.js'; export interface AgentSessionOptions { @@ -50,10 +54,18 @@ export class AgentService implements OnModuleDestroy { @Inject(EmbeddingService) private readonly embeddingService: EmbeddingService, @Inject(CoordService) private readonly coordService: CoordService, ) { + const fileBaseDir = process.env['AGENT_FILE_SANDBOX_DIR'] ?? process.cwd(); + const gitDefaultCwd = process.env['AGENT_GIT_CWD'] ?? process.cwd(); + const shellDefaultCwd = process.env['AGENT_SHELL_CWD'] ?? process.cwd(); + this.customTools = [ ...createBrainTools(brain), ...createCoordTools(coordService), ...createMemoryTools(memory, embeddingService.available ? 
/**
 * Resolve `inputPath` against `baseDir` and reject any result that lands
 * outside of `baseDir`.
 *
 * The check is purely lexical (`resolve` / `relative`): symbolic links are not
 * resolved here, so a symlink that lives inside the base directory can still
 * point somewhere outside of it.
 *
 * @param baseDir - Sandbox root that every path must stay within.
 * @param inputPath - Relative path, or an absolute path located inside `baseDir`.
 * @returns The absolute, normalized path.
 * @throws Error when the path resolves outside of `baseDir`.
 */
function resolveSafe(baseDir: string, inputPath: string): string {
  const resolved = resolve(baseDir, inputPath);
  const rel = relative(baseDir, resolved);

  // Only a leading ".." *segment* climbs out of the sandbox. A plain prefix
  // test would also reject ordinary entries such as "..cache/data.txt".
  const climbsOut = rel === '..' || rel.startsWith('../') || rel.startsWith('..\\');

  // If `rel` is not a genuine relative route from `baseDir` (for example when
  // relative() hands back an absolute path), joining it onto `baseDir` no
  // longer reproduces `resolved`.
  const roundTrips = resolve(resolved) === resolve(join(baseDir, rel));

  if (climbsOut || !roundTrips) {
    throw new Error(`Path escape detected: "${inputPath}" resolves outside base directory`);
  }
  return resolved;
}
Path is resolved relative to the sandbox base directory.', + parameters: Type.Object({ + path: Type.String({ + description: 'File path (relative to sandbox base or absolute within it)', + }), + encoding: Type.Optional( + Type.String({ description: 'Encoding: utf8 (default), base64, hex' }), + ), + }), + async execute(_toolCallId, params) { + const { path, encoding } = params as { path: string; encoding?: string }; + let safePath: string; + try { + safePath = resolveSafe(baseDir, path); + } catch (err) { + return { + content: [{ type: 'text' as const, text: `Error: ${String(err)}` }], + details: undefined, + }; + } + + try { + const info = await stat(safePath); + if (!info.isFile()) { + return { + content: [{ type: 'text' as const, text: `Error: path is not a file: ${path}` }], + details: undefined, + }; + } + if (info.size > MAX_READ_BYTES) { + return { + content: [ + { + type: 'text' as const, + text: `Error: file too large (${info.size} bytes, limit ${MAX_READ_BYTES} bytes)`, + }, + ], + details: undefined, + }; + } + const enc = (encoding ?? 'utf8') as BufferEncoding; + const content = await readFile(safePath, { encoding: enc }); + return { + content: [{ type: 'text' as const, text: String(content) }], + details: undefined, + }; + } catch (err) { + return { + content: [{ type: 'text' as const, text: `Error reading file: ${String(err)}` }], + details: undefined, + }; + } + }, + }; + + const writeFileTool: ToolDefinition = { + name: 'fs_write_file', + label: 'Write File', + description: + 'Write content to a file. Path is resolved relative to the sandbox base directory. 
Overwrites existing file.', + parameters: Type.Object({ + path: Type.String({ + description: 'File path (relative to sandbox base or absolute within it)', + }), + content: Type.String({ description: 'Content to write' }), + encoding: Type.Optional(Type.String({ description: 'Encoding: utf8 (default), base64' })), + }), + async execute(_toolCallId, params) { + const { path, content, encoding } = params as { + path: string; + content: string; + encoding?: string; + }; + let safePath: string; + try { + safePath = resolveSafe(baseDir, path); + } catch (err) { + return { + content: [{ type: 'text' as const, text: `Error: ${String(err)}` }], + details: undefined, + }; + } + + if (Buffer.byteLength(content, 'utf8') > MAX_WRITE_BYTES) { + return { + content: [ + { + type: 'text' as const, + text: `Error: content too large (limit ${MAX_WRITE_BYTES} bytes)`, + }, + ], + details: undefined, + }; + } + + try { + const enc = (encoding ?? 'utf8') as BufferEncoding; + await writeFile(safePath, content, { encoding: enc }); + return { + content: [{ type: 'text' as const, text: `File written successfully: ${path}` }], + details: undefined, + }; + } catch (err) { + return { + content: [{ type: 'text' as const, text: `Error writing file: ${String(err)}` }], + details: undefined, + }; + } + }, + }; + + const listDirectoryTool: ToolDefinition = { + name: 'fs_list_directory', + label: 'List Directory', + description: 'List files and directories at a given path within the sandbox base directory.', + parameters: Type.Object({ + path: Type.Optional( + Type.String({ + description: 'Directory path (relative to sandbox base). Defaults to base directory.', + }), + ), + }), + async execute(_toolCallId, params) { + const { path } = params as { path?: string }; + const target = path ?? 
const execAsync = promisify(exec);

const GIT_TIMEOUT_MS = 15_000;
const MAX_OUTPUT_BYTES = 100 * 1024; // 100 KB

/**
 * Git subcommands `runGit` is willing to execute; everything else is rejected.
 * NOTE(review): `branch` and `tag` can modify refs when given arguments, so
 * callers must only pass listing flags to them.
 */
const ALLOWED_GIT_SUBCOMMANDS: readonly string[] = [
  'status',
  'log',
  'diff',
  'show',
  'branch',
  'tag',
  'ls-files',
];

type GitResult = { stdout: string; stderr: string; error?: string };

/**
 * Quote one argument so the shell started by `exec` passes it through
 * verbatim.
 *
 * JSON-style double quotes are not sufficient for a POSIX shell: `$(...)`,
 * backticks and `$VAR` are still expanded inside them. Single quotes disable
 * every expansion, and an embedded single quote is spelled `'\''`.
 *
 * On Windows the quoting is best-effort: arguments containing characters that
 * cannot be passed safely are refused instead of escaped.
 * NOTE(review): switching to `execFile` would remove the shell entirely, but
 * needs an additional import.
 *
 * @throws Error when the argument cannot be passed safely.
 */
function quoteShellArg(arg: string): string {
  if (arg.includes('\0')) {
    throw new Error('git argument contains a NUL byte');
  }
  if (process.platform === 'win32') {
    if (/["%!^&|<>\r\n]/.test(arg)) {
      throw new Error(`git argument contains unsupported characters: ${JSON.stringify(arg)}`);
    }
    // Double trailing backslashes so they cannot escape the closing quote.
    return `"${arg.replace(/(\\+)$/, '$1$1')}"`;
  }
  return "'" + arg.replace(/'/g, "'\\''") + "'";
}

/**
 * Run an allow-listed git subcommand and capture its output.
 *
 * Never rejects: failures (blocked subcommand, unsafe argument, non-zero exit,
 * timeout, output larger than `MAX_OUTPUT_BYTES`) are reported through the
 * `error` field of the result.
 *
 * @param args - Git arguments; `args[0]` is the subcommand.
 * @param cwd - Working directory for the git process.
 */
async function runGit(args: string[], cwd?: string): Promise<GitResult> {
  const subcommand = args[0];
  if (!subcommand || !ALLOWED_GIT_SUBCOMMANDS.includes(subcommand)) {
    return {
      stdout: '',
      stderr: '',
      error: `Blocked: git subcommand "${subcommand}" is not allowed. Permitted: ${ALLOWED_GIT_SUBCOMMANDS.join(', ')}`,
    };
  }

  try {
    // Quoting happens inside the try block so an unsafe argument is reported
    // like any other failure.
    const cmd = `git ${args.map(quoteShellArg).join(' ')}`;
    const { stdout, stderr } = await execAsync(cmd, {
      cwd,
      timeout: GIT_TIMEOUT_MS,
      maxBuffer: MAX_OUTPUT_BYTES,
    });
    return { stdout, stderr };
  } catch (err: unknown) {
    const e = err as { stdout?: string; stderr?: string; message?: string };
    return {
      stdout: e.stdout ?? '',
      stderr: e.stderr ?? '',
      error: e.message ?? String(err),
    };
  }
}

/** Wrap plain text in the tool result shape used by every git tool. */
function gitTextResult(text: string) {
  return {
    content: [{ type: 'text' as const, text }],
    details: undefined,
  };
}

/** Render a `runGit` result, substituting `emptyText` when git printed nothing. */
function gitResultToText(result: GitResult, emptyText: string): string {
  return result.error ? `Error: ${result.error}\n${result.stderr}` : result.stdout || emptyText;
}

/**
 * Build the read-only git tools (`git_status`, `git_log`, `git_diff`).
 *
 * @param defaultCwd - Working directory used when a tool call does not supply `cwd`.
 */
export function createGitTools(defaultCwd?: string): ToolDefinition[] {
  const gitStatus: ToolDefinition = {
    name: 'git_status',
    label: 'Git Status',
    description: 'Show the working tree status (staged, unstaged, untracked files).',
    parameters: Type.Object({
      cwd: Type.Optional(Type.String({ description: 'Repository working directory.' })),
    }),
    async execute(_toolCallId, params) {
      const { cwd } = params as { cwd?: string };
      const result = await runGit(['status', '--short', '--branch'], cwd ?? defaultCwd);
      return gitTextResult(gitResultToText(result, '(no output)'));
    },
  };

  const gitLog: ToolDefinition = {
    name: 'git_log',
    label: 'Git Log',
    description: 'Show recent commit history.',
    parameters: Type.Object({
      limit: Type.Optional(Type.Number({ description: 'Number of commits to show (default 20)' })),
      oneline: Type.Optional(
        Type.Boolean({ description: 'Compact one-line format (default true)' }),
      ),
      cwd: Type.Optional(Type.String({ description: 'Repository working directory.' })),
    }),
    async execute(_toolCallId, params) {
      const { limit, oneline, cwd } = params as {
        limit?: number;
        oneline?: boolean;
        cwd?: string;
      };
      // Anything that is not a positive integer falls back to the default.
      const count =
        typeof limit === 'number' && Number.isInteger(limit) && limit > 0 ? limit : 20;
      const args = ['log', `--max-count=${count}`];
      if (oneline !== false) args.push('--oneline');
      const result = await runGit(args, cwd ?? defaultCwd);
      return gitTextResult(gitResultToText(result, '(no commits)'));
    },
  };

  const gitDiff: ToolDefinition = {
    name: 'git_diff',
    label: 'Git Diff',
    description: 'Show changes between commits, working tree, or staged changes.',
    parameters: Type.Object({
      staged: Type.Optional(
        Type.Boolean({ description: 'Show staged (cached) changes instead of unstaged' }),
      ),
      ref: Type.Optional(
        Type.String({ description: 'Compare against this ref (commit SHA, branch, or tag)' }),
      ),
      path: Type.Optional(
        Type.String({ description: 'Limit diff to a specific file or directory' }),
      ),
      cwd: Type.Optional(Type.String({ description: 'Repository working directory.' })),
    }),
    async execute(_toolCallId, params) {
      const { staged, ref, path, cwd } = params as {
        staged?: boolean;
        ref?: string;
        path?: string;
        cwd?: string;
      };
      // `ref` is placed before "--", so a value starting with "-" would be
      // parsed by git as an option (e.g. --output=<file>) instead of a revision.
      if (ref?.startsWith('-')) {
        return gitTextResult(`Error: invalid ref "${ref}": a ref must not start with "-"`);
      }
      const args = ['diff'];
      if (staged) args.push('--cached');
      if (ref) args.push(ref);
      args.push('--');
      if (path) args.push(path);
      const result = await runGit(args, cwd ?? defaultCwd);
      return gitTextResult(gitResultToText(result, '(no diff)'));
    },
  };

  return [gitStatus, gitLog, gitDiff];
}
// Binary names that shell_exec refuses to run. This is a denylist and therefore
// best-effort only; it is not a security boundary on its own.
const BLOCKED_COMMANDS = new Set([
  'rm',
  'rmdir',
  'mkfs',
  'dd',
  'format',
  'fdisk',
  'parted',
  'shred',
  'wipefs',
  'sudo',
  'su',
  'chown',
  'chmod',
  'passwd',
  'useradd',
  'userdel',
  'groupadd',
  'shutdown',
  'reboot',
  'halt',
  'poweroff',
  'kill',
  'killall',
  'pkill',
  'curl',
  'wget',
  'nc',
  'netcat',
  'ncat',
  'ssh',
  'scp',
  'sftp',
  'rsync',
  'iptables',
  'ip6tables',
  'nft',
  'ufw',
  'firewall-cmd',
  'docker',
  'podman',
  'kubectl',
  'helm',
  'terraform',
  'ansible',
  'crontab',
  'at',
  'batch',
]);

/** Strip quoting characters and any directory prefix from a command word. */
function commandWordToName(word: string): string {
  const unquoted = word.replace(/["'\\]/g, '');
  return unquoted.split('/').pop() ?? unquoted;
}

/**
 * Return the binary name that decides whether `command` is allowed.
 *
 * The command line is split at shell control operators (`;`, `&`, `|`,
 * newlines, grouping characters, backticks and `$(`). Within each segment,
 * leading `VAR=value` assignments are skipped and the first remaining word is
 * taken as the command. The first segment whose command is in
 * `BLOCKED_COMMANDS` wins; otherwise the name of the first command is returned
 * (an empty string for an empty command line).
 *
 * Splitting ignores quoting, so an operator inside a quoted string can cause a
 * false positive. Commands launched indirectly (for example through `xargs`,
 * `find -exec`, `eval` or a nested shell) are not detected.
 */
function extractBaseCommand(command: string): string {
  const segments = command.split(/[;&|\n(){}`]|\$\(/);
  let firstName: string | undefined;

  for (const segment of segments) {
    const words = segment
      .trim()
      .split(/\s+/)
      .filter((w) => w !== '');
    // Skip environment assignments that precede the command word.
    const commandWord = words.find((w) => !/^[A-Za-z_][A-Za-z0-9_]*=/.test(w));
    if (commandWord === undefined) continue;

    const name = commandWordToName(commandWord);
    if (BLOCKED_COMMANDS.has(name)) return name;
    if (firstName === undefined) firstName = name;
  }

  return firstName ?? '';
}

/**
 * Run `command` through `sh -c`, collecting stdout and stderr.
 *
 * stdout is capped at `MAX_OUTPUT_BYTES`; when the cap is hit the child is sent
 * SIGTERM and a truncation marker is appended. stderr is capped at the same
 * number of bytes without a marker. After `options.timeoutMs` the child is sent
 * SIGTERM, followed by SIGKILL two seconds later if it has not closed yet.
 *
 * The returned promise always resolves: spawn failures are reported through
 * `stderr` with `exitCode: null`.
 */
function runCommand(
  command: string,
  options: { timeoutMs: number; cwd?: string },
): Promise<{ stdout: string; stderr: string; exitCode: number | null; timedOut: boolean }> {
  return new Promise((resolve) => {
    const child = spawn('sh', ['-c', command], {
      cwd: options.cwd,
      stdio: ['ignore', 'pipe', 'pipe'],
      detached: false,
    });

    let stdout = '';
    let stderr = '';
    let stderrBytes = 0;
    let timedOut = false;
    let totalBytes = 0;
    let truncated = false;
    let killTimer: ReturnType<typeof setTimeout> | undefined;

    child.stdout?.on('data', (chunk: Buffer) => {
      if (truncated) return;
      totalBytes += chunk.length;
      if (totalBytes > MAX_OUTPUT_BYTES) {
        stdout += chunk.subarray(0, MAX_OUTPUT_BYTES - (totalBytes - chunk.length)).toString();
        stdout += '\n[output truncated at 100 KB limit]';
        truncated = true;
        child.kill('SIGTERM');
      } else {
        stdout += chunk.toString();
      }
    });

    child.stderr?.on('data', (chunk: Buffer) => {
      // Count bytes, not string length, and never exceed the cap by a chunk.
      const room = MAX_OUTPUT_BYTES - stderrBytes;
      if (room <= 0) return;
      stderr += chunk.subarray(0, room).toString();
      stderrBytes += Math.min(chunk.length, room);
    });

    const timer = setTimeout(() => {
      timedOut = true;
      child.kill('SIGTERM');
      killTimer = setTimeout(() => {
        try {
          child.kill('SIGKILL');
        } catch {
          // already exited
        }
      }, 2000);
    }, options.timeoutMs);

    // Cancel both timers once the child is done so no stray SIGKILL timer
    // lingers after the result has been delivered.
    const clearTimers = (): void => {
      clearTimeout(timer);
      if (killTimer !== undefined) clearTimeout(killTimer);
    };

    child.on('close', (exitCode) => {
      clearTimers();
      resolve({ stdout, stderr, exitCode, timedOut });
    });

    child.on('error', (err) => {
      clearTimers();
      resolve({ stdout, stderr: stderr + String(err), exitCode: null, timedOut: false });
    });
  });
}
/**
 * Hostname patterns that must never be fetched: loopback, private IPv4 ranges,
 * link-local addresses and the IPv6 loopback / unique-local / link-local
 * ranges. Patterns are tested against the normalized hostname (lower-cased,
 * without IPv6 brackets or a trailing dot).
 */
const BLOCKED_HOSTNAMES = [
  /^localhost$/i,
  /\.localhost$/i,
  /^127\./,
  /^10\./,
  /^172\.(1[6-9]|2\d|3[01])\./,
  /^192\.168\./,
  /^0\.0\.0\.0$/,
  /^169\.254\./,
  /^::$/,
  /^::1$/,
  /^f[cd][0-9a-f]{2}:/i, // fc00::/7 unique-local
  /^fe[89ab][0-9a-f]:/i, // fe80::/10 link-local
];

/**
 * Bring a URL hostname into the form the block patterns expect.
 *
 * `URL.hostname` keeps the square brackets around IPv6 literals, and an
 * IPv4-mapped IPv6 address such as `::ffff:7f00:1` hides the IPv4 address it
 * stands for, so both are unwrapped here.
 */
function normalizeHostname(hostname: string): string {
  let host = hostname.toLowerCase();
  if (host.startsWith('[') && host.endsWith(']')) {
    host = host.slice(1, -1);
  }
  if (host.endsWith('.')) {
    host = host.slice(0, -1);
  }

  const mappedDotted = /^::ffff:(\d+\.\d+\.\d+\.\d+)$/.exec(host);
  if (mappedDotted) {
    return mappedDotted[1] ?? host;
  }

  const mappedHex = /^::ffff:([0-9a-f]{1,4}):([0-9a-f]{1,4})$/.exec(host);
  if (mappedHex) {
    const hi = parseInt(mappedHex[1] ?? '0', 16);
    const lo = parseInt(mappedHex[2] ?? '0', 16);
    return `${hi >> 8}.${hi & 0xff}.${lo >> 8}.${lo & 0xff}`;
  }

  return host;
}

/**
 * Decide whether a URL may be requested.
 *
 * Only the literal hostname is inspected: a public DNS name that resolves to a
 * private address is not detected here.
 *
 * @param urlString - URL supplied by the tool caller.
 * @returns `null` when the URL is acceptable, otherwise a human-readable
 *   reason (invalid URL, unsupported protocol, or private/local host).
 */
function isBlockedUrl(urlString: string): string | null {
  let parsed: URL;
  try {
    parsed = new URL(urlString);
  } catch {
    return `Invalid URL: ${urlString}`;
  }
  if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
    return `Unsupported protocol: ${parsed.protocol}. Only http and https are allowed.`;
  }
  const hostname = parsed.hostname;
  const normalized = normalizeHostname(hostname);
  if (BLOCKED_HOSTNAMES.some((pattern) => pattern.test(normalized))) {
    return `Blocked: requests to "${hostname}" are not allowed (private/local addresses).`;
  }
  return null;
}
/**
 * Build the HTTP tools (`web_get`, `web_post`).
 *
 * Both tools validate the target with `isBlockedUrl`, cap the timeout at
 * 30 seconds and delegate the request to `fetchWithLimit`. Failures are
 * returned as text results rather than thrown.
 */
export function createWebTools(): ToolDefinition[] {
  const MAX_TIMEOUT_MS = 30_000;

  /** Wrap plain text in the tool result shape. */
  const textResult = (text: string) => ({
    content: [{ type: 'text' as const, text }],
    details: undefined,
  });

  /**
   * Clamp a caller-supplied timeout. Missing, non-finite or non-positive
   * values fall back to the default instead of aborting the request at once.
   */
  const effectiveTimeout = (timeout?: number): number =>
    typeof timeout === 'number' && Number.isFinite(timeout) && timeout > 0
      ? Math.min(timeout, MAX_TIMEOUT_MS)
      : DEFAULT_TIMEOUT_MS;

  const errorMessage = (err: unknown): string =>
    err instanceof Error ? err.message : String(err);

  const formatResponse = (result: { text: string; status: number; contentType: string }) =>
    textResult(`HTTP ${result.status} (${result.contentType})\n\n${result.text}`);

  const webGet: ToolDefinition = {
    name: 'web_get',
    label: 'HTTP GET',
    description:
      'Perform an HTTP GET request and return the response body. Private/local addresses are blocked.',
    parameters: Type.Object({
      url: Type.String({ description: 'URL to fetch (http/https only)' }),
      headers: Type.Optional(
        Type.Record(Type.String(), Type.String(), {
          description: 'Optional request headers as key-value pairs',
        }),
      ),
      timeout: Type.Optional(
        Type.Number({ description: 'Timeout in milliseconds (default 15000, max 30000)' }),
      ),
    }),
    async execute(_toolCallId, params) {
      const { url, headers, timeout } = params as {
        url: string;
        headers?: Record<string, string>;
        timeout?: number;
      };

      const blocked = isBlockedUrl(url);
      if (blocked) {
        return textResult(`Error: ${blocked}`);
      }

      try {
        const result = await fetchWithLimit(
          url,
          { method: 'GET', headers: headers ?? {} },
          effectiveTimeout(timeout),
        );
        return formatResponse(result);
      } catch (err) {
        return textResult(`Error fetching URL: ${errorMessage(err)}`);
      }
    },
  };

  const webPost: ToolDefinition = {
    name: 'web_post',
    label: 'HTTP POST',
    description:
      'Perform an HTTP POST request with a JSON or text body. Private/local addresses are blocked.',
    parameters: Type.Object({
      url: Type.String({ description: 'URL to POST to (http/https only)' }),
      body: Type.String({ description: 'Request body (JSON string or plain text)' }),
      contentType: Type.Optional(
        Type.String({ description: 'Content-Type header (default: application/json)' }),
      ),
      headers: Type.Optional(
        Type.Record(Type.String(), Type.String(), {
          description: 'Optional additional request headers',
        }),
      ),
      timeout: Type.Optional(
        Type.Number({ description: 'Timeout in milliseconds (default 15000, max 30000)' }),
      ),
    }),
    async execute(_toolCallId, params) {
      const { url, body, contentType, headers, timeout } = params as {
        url: string;
        body: string;
        contentType?: string;
        headers?: Record<string, string>;
        timeout?: number;
      };

      const blocked = isBlockedUrl(url);
      if (blocked) {
        return textResult(`Error: ${blocked}`);
      }

      const ct = contentType ?? 'application/json';

      try {
        const result = await fetchWithLimit(
          url,
          {
            method: 'POST',
            // Caller-supplied headers are spread last, so they take precedence
            // over the Content-Type derived from `contentType`.
            headers: { 'Content-Type': ct, ...(headers ?? {}) },
            body,
          },
          effectiveTimeout(timeout),
        );
        return formatResponse(result);
      } catch (err) {
        return textResult(`Error posting to URL: ${errorMessage(err)}`);
      }
    },
  };

  return [webGet, webPost];
}