Compare commits
143 Commits
v0.0.5
...
360d7fe96d
| Author | SHA1 | Date | |
|---|---|---|---|
| 360d7fe96d | |||
| 08da6b76d1 | |||
| 5d4efb467c | |||
| 6c6bcbdb7f | |||
| cfdd2b679c | |||
| 34d4dbbabd | |||
| 78d591b697 | |||
| e95c70d329 | |||
| d8ac088f3a | |||
| 0d7f3c6d14 | |||
| eddcca7533 | |||
| ad06e00f99 | |||
| 5b089392fd | |||
| 02ff3b3256 | |||
| 1d14ddcfe7 | |||
| 05a805eeca | |||
| ebf99d9ff7 | |||
| cf51fd6749 | |||
| bb22857fde | |||
| 5261048d67 | |||
| 36095ad80f | |||
| d06866f501 | |||
| 02e40f6c3c | |||
| de64695ac5 | |||
| dd108b9ab4 | |||
| f3e90df2a0 | |||
| 721e6bbc52 | |||
| 27848bf42e | |||
| 061edcaa78 | |||
| cbb729f377 | |||
| cfb491e127 | |||
| 20808b9b84 | |||
| fd61a36b01 | |||
| c0a7bae977 | |||
| 68e056ac91 | |||
| 77ba13b41b | |||
| 307bb427d6 | |||
| b89503fa8c | |||
| 254da35300 | |||
| 99926cdba2 | |||
| 25f880416a | |||
| 1138148543 | |||
| 4b70b603b3 | |||
| 2e7711fe65 | |||
| 417a57fa00 | |||
| 714fee52b9 | |||
| 133668f5b2 | |||
| 3b81bc9f3d | |||
| cbfd6fb996 | |||
| 3f8553ce07 | |||
| bf668e18f1 | |||
| 1f2b8125c6 | |||
| 93645295d5 | |||
| 7a52652be6 | |||
| 791c8f505e | |||
| 12653477d6 | |||
| dedfa0d9ac | |||
| c1d3dfd77e | |||
| f0476cae92 | |||
| b6effdcd6b | |||
| 39ef2ff123 | |||
| a989b5e549 | |||
| ff27e944a1 | |||
| 0821393c1d | |||
| 24f5c0699a | |||
| 96409c40bf | |||
| 8628f4f93a | |||
| b649b5c987 | |||
| b4d03a8b49 | |||
| 85aeebbde2 | |||
| a4bb563779 | |||
| 7f6464bbda | |||
| f0741e045f | |||
| 5a1991924c | |||
| bd5d14d07f | |||
| d5a1791dc5 | |||
| bd81c12071 | |||
| 4da255bf04 | |||
| 82c10a7b33 | |||
| d31070177c | |||
| 3792576566 | |||
| cd57c75e41 | |||
| 237a863dfd | |||
| cb92ba16c1 | |||
| 70e9f2c6bc | |||
| a760401407 | |||
| 22a5e9791c | |||
| d1bef49b4e | |||
| 76abf11eba | |||
| c4850fe6c1 | |||
| 0809f4e787 | |||
| 6a4c020179 | |||
| 3bb401641e | |||
| 54b821d8bd | |||
| 09e649fc7e | |||
| f208f72dc0 | |||
| d42cd68ea4 | |||
| 07647c8382 | |||
| 8633823257 | |||
| d0999a8e37 | |||
| ea800e3f14 | |||
| 5d2e6fae63 | |||
| fcd22c788a | |||
| ab61a15edc | |||
| 2c60459851 | |||
| ea524a6ba1 | |||
| 997a6d134f | |||
| 8aaf229483 | |||
| 049bb719e8 | |||
| 014ebdacda | |||
| 72a73c859c | |||
| 6d2b81f6e4 | |||
| 9d01a0d484 | |||
| d5102f62fa | |||
| a881e707e2 | |||
| 7d04874f3c | |||
| 9f036242fa | |||
| c4e52085e3 | |||
| 84e1868028 | |||
| f94f9f672b | |||
| cd29fc8708 | |||
| 6e22c0fdeb | |||
| 1f4d54e474 | |||
| b7a39b45d7 | |||
| 1bfdc91f90 | |||
| 58a90ac9d7 | |||
| 684dbdc6a4 | |||
| e92de12cf9 | |||
| 1f784a6a04 | |||
| ab37c2e69f | |||
| c8f3e0db44 | |||
| 02772a3910 | |||
| 85a25fd995 | |||
| 20f302367c | |||
| 54c6bfded0 | |||
| ca5472bc31 | |||
| 55b5a31c3c | |||
| 01e9891243 | |||
| 446a424c1f | |||
| 02a0d515d9 | |||
| 2bf3816efc | |||
| 96902bab44 | |||
| 280c5351e2 |
154
.env.example
154
.env.example
@@ -1,20 +1,154 @@
|
||||
# Database (port 5433 avoids conflict with host PostgreSQL)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Mosaic — Environment Variables Reference
|
||||
# Copy this file to .env and fill in the values for your deployment.
|
||||
# Lines beginning with # are comments; optional vars are commented out.
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ─── Database (PostgreSQL 17 + pgvector) ─────────────────────────────────────
|
||||
# Full connection string used by the gateway, ORM, and migration runner.
|
||||
# Port 5433 avoids conflict with a host-side PostgreSQL instance.
|
||||
DATABASE_URL=postgresql://mosaic:mosaic@localhost:5433/mosaic
|
||||
|
||||
# Valkey (Redis-compatible, port 6380 avoids conflict with host Redis/Valkey)
|
||||
# Docker Compose host-port override for the PostgreSQL container (default: 5433)
|
||||
# PG_HOST_PORT=5433
|
||||
|
||||
|
||||
# ─── Queue (Valkey 8 / Redis-compatible) ─────────────────────────────────────
|
||||
# Port 6380 avoids conflict with a host-side Redis/Valkey instance.
|
||||
VALKEY_URL=redis://localhost:6380
|
||||
|
||||
# Docker Compose host port overrides (optional)
|
||||
# PG_HOST_PORT=5433
|
||||
# Docker Compose host-port override for the Valkey container (default: 6380)
|
||||
# VALKEY_HOST_PORT=6380
|
||||
|
||||
# OpenTelemetry
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318
|
||||
OTEL_SERVICE_NAME=mosaic-gateway
|
||||
|
||||
# Auth (BetterAuth)
|
||||
# ─── Gateway ─────────────────────────────────────────────────────────────────
|
||||
# TCP port the NestJS/Fastify gateway listens on (default: 4000)
|
||||
GATEWAY_PORT=4000
|
||||
|
||||
# Comma-separated list of allowed CORS origins.
|
||||
# Must include the web app origin in production.
|
||||
GATEWAY_CORS_ORIGIN=http://localhost:3000
|
||||
|
||||
|
||||
# ─── Auth (BetterAuth) ───────────────────────────────────────────────────────
|
||||
# REQUIRED — random secret used to sign sessions and tokens.
|
||||
# Generate with: openssl rand -base64 32
|
||||
BETTER_AUTH_SECRET=change-me-to-a-random-32-char-string
|
||||
|
||||
# Public base URL of the gateway (used by BetterAuth for callback URLs)
|
||||
BETTER_AUTH_URL=http://localhost:4000
|
||||
|
||||
# Gateway
|
||||
GATEWAY_PORT=4000
|
||||
|
||||
# ─── Web App (Next.js) ───────────────────────────────────────────────────────
|
||||
# Public gateway URL — accessible from the browser, not just the server.
|
||||
NEXT_PUBLIC_GATEWAY_URL=http://localhost:4000
|
||||
|
||||
|
||||
# ─── OpenTelemetry ───────────────────────────────────────────────────────────
|
||||
# OTLP HTTP endpoint (otel-collector or any OpenTelemetry-compatible backend)
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318
|
||||
|
||||
# Service name shown in traces
|
||||
OTEL_SERVICE_NAME=mosaic-gateway
|
||||
|
||||
|
||||
# ─── AI Providers ────────────────────────────────────────────────────────────
|
||||
|
||||
# Ollama (local models — set OLLAMA_BASE_URL to enable)
|
||||
# OLLAMA_BASE_URL=http://localhost:11434
|
||||
# OLLAMA_HOST is a legacy alias for OLLAMA_BASE_URL
|
||||
# OLLAMA_HOST=http://localhost:11434
|
||||
# Comma-separated list of Ollama model IDs to register (default: llama3.2,codellama,mistral)
|
||||
# OLLAMA_MODELS=llama3.2,codellama,mistral
|
||||
|
||||
# Anthropic (claude-sonnet-4-6, claude-opus-4-6, claude-haiku-4-5)
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
|
||||
# OpenAI (gpt-4o, gpt-4o-mini, o3-mini)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# Z.ai / GLM (glm-4.5, glm-4.5-air, glm-4.5-flash)
|
||||
# ZAI_API_KEY=...
|
||||
|
||||
# Custom providers — JSON array of provider configs
|
||||
# Format: [{"id":"<id>","baseUrl":"<url>","apiKey":"<key>","models":[{"id":"<model-id>","name":"<label>"}]}]
|
||||
# MOSAIC_CUSTOM_PROVIDERS=
|
||||
|
||||
|
||||
# ─── Embedding Service ───────────────────────────────────────────────────────
|
||||
# OpenAI-compatible embeddings endpoint (default: OpenAI)
|
||||
# EMBEDDING_API_URL=https://api.openai.com/v1
|
||||
# EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
|
||||
# ─── Log Summarization Service ───────────────────────────────────────────────
|
||||
# OpenAI-compatible chat completions endpoint for log summarization (default: OpenAI)
|
||||
# SUMMARIZATION_API_URL=https://api.openai.com/v1
|
||||
# SUMMARIZATION_MODEL=gpt-4o-mini
|
||||
|
||||
# Cron schedule for summarization job (default: every 6 hours)
|
||||
# SUMMARIZATION_CRON=0 */6 * * *
|
||||
|
||||
# Cron schedule for log tier management (default: daily at 03:00)
|
||||
# TIER_MANAGEMENT_CRON=0 3 * * *
|
||||
|
||||
|
||||
# ─── Agent ───────────────────────────────────────────────────────────────────
|
||||
# Filesystem sandbox root for agent file tools (default: process.cwd())
|
||||
# AGENT_FILE_SANDBOX_DIR=/var/lib/mosaic/sandbox
|
||||
|
||||
# Comma-separated list of tool names available to non-admin users.
|
||||
# Leave unset to allow all tools for all authenticated users.
|
||||
# AGENT_USER_TOOLS=read_file,list_directory,search_files
|
||||
|
||||
# System prompt injected into every agent session (optional)
|
||||
# AGENT_SYSTEM_PROMPT=You are a helpful assistant.
|
||||
|
||||
|
||||
# ─── MCP Servers ─────────────────────────────────────────────────────────────
|
||||
# JSON array of MCP server configs — set to enable MCP tool integration.
|
||||
# Each entry: {"name":"<id>","url":"<http-or-sse-url>"}
|
||||
# MCP_SERVERS=[{"name":"my-mcp","url":"http://localhost:3100/sse"}]
|
||||
|
||||
|
||||
# ─── Coordinator ─────────────────────────────────────────────────────────────
|
||||
# Root directory used to scope coordinator (worktree/repo) operations.
|
||||
# Defaults to the monorepo root auto-detected from process.cwd().
|
||||
# MOSAIC_WORKSPACE_ROOT=/home/user/projects/mosaic
|
||||
|
||||
|
||||
# ─── Discord Plugin (optional — set DISCORD_BOT_TOKEN to enable) ─────────────
|
||||
# DISCORD_BOT_TOKEN=
|
||||
# DISCORD_GUILD_ID=
|
||||
# DISCORD_GATEWAY_URL=http://localhost:4000
|
||||
|
||||
|
||||
# ─── Telegram Plugin (optional — set TELEGRAM_BOT_TOKEN to enable) ───────────
|
||||
# TELEGRAM_BOT_TOKEN=
|
||||
# TELEGRAM_GATEWAY_URL=http://localhost:4000
|
||||
|
||||
|
||||
# ─── SSO Providers (add credentials to enable) ───────────────────────────────
|
||||
|
||||
# --- Authentik (optional — set AUTHENTIK_CLIENT_ID to enable) ---
|
||||
# AUTHENTIK_ISSUER=https://auth.example.com/application/o/mosaic/
|
||||
# AUTHENTIK_CLIENT_ID=
|
||||
# AUTHENTIK_CLIENT_SECRET=
|
||||
|
||||
# --- WorkOS (optional — set WORKOS_CLIENT_ID to enable) ---
|
||||
# WORKOS_ISSUER=https://your-company.authkit.app
|
||||
# WORKOS_CLIENT_ID=client_...
|
||||
# WORKOS_CLIENT_SECRET=sk_live_...
|
||||
|
||||
# --- Keycloak (optional — set KEYCLOAK_CLIENT_ID to enable) ---
|
||||
# KEYCLOAK_ISSUER=https://auth.example.com/realms/master
|
||||
# Legacy alternative if you prefer to compose the issuer from separate vars:
|
||||
# KEYCLOAK_URL=https://auth.example.com
|
||||
# KEYCLOAK_REALM=master
|
||||
# KEYCLOAK_CLIENT_ID=mosaic
|
||||
# KEYCLOAK_CLIENT_SECRET=
|
||||
|
||||
# Feature flags — set to true alongside provider credentials to show SSO buttons in the UI
|
||||
# NEXT_PUBLIC_WORKOS_ENABLED=true
|
||||
# NEXT_PUBLIC_KEYCLOAK_ENABLED=true
|
||||
|
||||
@@ -1,4 +1 @@
|
||||
#!/bin/sh
|
||||
. "$(dirname "$0")/_/husky.sh"
|
||||
|
||||
npx lint-staged
|
||||
|
||||
@@ -1,4 +1 @@
|
||||
#!/bin/sh
|
||||
. "$(dirname "$0")/_/husky.sh"
|
||||
|
||||
pnpm typecheck && pnpm lint && pnpm format:check
|
||||
|
||||
@@ -4,3 +4,4 @@ pnpm-lock.yaml
|
||||
**/node_modules
|
||||
**/drizzle
|
||||
**/.next
|
||||
.claude/
|
||||
|
||||
@@ -1,57 +1,61 @@
|
||||
variables:
|
||||
- &node_image 'node:22-alpine'
|
||||
- &install_deps |
|
||||
corepack enable
|
||||
pnpm install --frozen-lockfile
|
||||
- &enable_pnpm 'corepack enable'
|
||||
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
|
||||
# Turbo remote cache (turbo.mosaicstack.dev) is configured via Woodpecker
|
||||
# repository-level environment variables (TURBO_API, TURBO_TEAM, TURBO_TOKEN).
|
||||
# This avoids from_secret which is blocked on pull_request events.
|
||||
# If the env vars aren't set, turbo falls back to local cache only.
|
||||
|
||||
steps:
|
||||
install:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- corepack enable
|
||||
- pnpm install --frozen-lockfile
|
||||
|
||||
typecheck:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- *enable_pnpm
|
||||
- pnpm typecheck
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
# lint, format, and test are independent — run in parallel after typecheck
|
||||
lint:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- *enable_pnpm
|
||||
- pnpm lint
|
||||
depends_on:
|
||||
- install
|
||||
- typecheck
|
||||
|
||||
format:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- *enable_pnpm
|
||||
- pnpm format:check
|
||||
depends_on:
|
||||
- install
|
||||
- typecheck
|
||||
|
||||
test:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- *enable_pnpm
|
||||
- pnpm test
|
||||
depends_on:
|
||||
- install
|
||||
- typecheck
|
||||
|
||||
build:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
- *enable_pnpm
|
||||
- pnpm build
|
||||
depends_on:
|
||||
- typecheck
|
||||
- lint
|
||||
- format
|
||||
- test
|
||||
|
||||
25
AGENTS.md
25
AGENTS.md
@@ -53,3 +53,28 @@ pnpm typecheck && pnpm lint && pnpm format:check # Quality gates
|
||||
- ESM everywhere (`"type": "module"`, `.js` extensions in imports)
|
||||
- NodeNext module resolution in all tsconfigs
|
||||
- Scratchpads are mandatory for non-trivial tasks
|
||||
|
||||
## docs/TASKS.md — Schema (CANONICAL)
|
||||
|
||||
The `agent` column specifies the required model for each task. **This is set at task creation by the orchestrator and must not be changed by workers.**
|
||||
|
||||
| Value | When to use | Budget |
|
||||
| -------- | ----------------------------------------------------------- | -------------------------- |
|
||||
| `codex` | All coding tasks (default for implementation) | OpenAI credits — preferred |
|
||||
| `glm-5` | Cost-sensitive coding where Codex is unavailable | Z.ai credits |
|
||||
| `haiku` | Review gates, verify tasks, status checks, docs-only | Cheapest Claude tier |
|
||||
| `sonnet` | Complex planning, multi-file reasoning, architecture review | Claude quota |
|
||||
| `opus` | Major cross-cutting architecture decisions ONLY | Most expensive — minimize |
|
||||
| `—` | No preference / auto-select cheapest capable | Pipeline decides |
|
||||
|
||||
Pipeline crons read this column and spawn accordingly. Workers never modify `docs/TASKS.md` — only the orchestrator writes it.
|
||||
|
||||
**Full schema:**
|
||||
|
||||
```
|
||||
| id | status | description | issue | agent | repo | branch | depends_on | estimate | notes |
|
||||
```
|
||||
|
||||
- `status`: `not-started` | `in-progress` | `done` | `failed` | `blocked` | `needs-qa`
|
||||
- `agent`: model value from table above (set before spawning)
|
||||
- `estimate`: token budget e.g. `8K`, `25K`
|
||||
|
||||
@@ -8,23 +8,30 @@
|
||||
"build": "tsc",
|
||||
"dev": "tsx watch src/main.ts",
|
||||
"lint": "eslint src",
|
||||
"typecheck": "tsc --noEmit",
|
||||
"typecheck": "tsc --noEmit -p tsconfig.typecheck.json",
|
||||
"test": "vitest run --passWithNoTests"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.80.0",
|
||||
"@fastify/helmet": "^13.0.2",
|
||||
"@mariozechner/pi-ai": "~0.57.1",
|
||||
"@mariozechner/pi-coding-agent": "~0.57.1",
|
||||
"@modelcontextprotocol/sdk": "^1.27.1",
|
||||
"@mosaic/auth": "workspace:^",
|
||||
"@mosaic/brain": "workspace:^",
|
||||
"@mosaic/coord": "workspace:^",
|
||||
"@mosaic/db": "workspace:^",
|
||||
"@mosaic/discord-plugin": "workspace:^",
|
||||
"@mosaic/log": "workspace:^",
|
||||
"@mosaic/memory": "workspace:^",
|
||||
"@mosaic/queue": "workspace:^",
|
||||
"@mosaic/telegram-plugin": "workspace:^",
|
||||
"@mosaic/types": "workspace:^",
|
||||
"@nestjs/common": "^11.0.0",
|
||||
"@nestjs/core": "^11.0.0",
|
||||
"@nestjs/platform-fastify": "^11.0.0",
|
||||
"@nestjs/platform-socket.io": "^11.0.0",
|
||||
"@nestjs/throttler": "^6.5.0",
|
||||
"@nestjs/websockets": "^11.0.0",
|
||||
"@opentelemetry/auto-instrumentations-node": "^0.71.0",
|
||||
"@opentelemetry/exporter-metrics-otlp-http": "^0.213.0",
|
||||
@@ -35,12 +42,17 @@
|
||||
"@opentelemetry/semantic-conventions": "^1.40.0",
|
||||
"@sinclair/typebox": "^0.34.48",
|
||||
"better-auth": "^1.5.5",
|
||||
"class-transformer": "^0.5.1",
|
||||
"class-validator": "^0.15.1",
|
||||
"dotenv": "^17.3.1",
|
||||
"fastify": "^5.0.0",
|
||||
"node-cron": "^4.2.1",
|
||||
"openai": "^6.32.0",
|
||||
"reflect-metadata": "^0.2.0",
|
||||
"rxjs": "^7.8.0",
|
||||
"socket.io": "^4.8.0",
|
||||
"uuid": "^11.0.0"
|
||||
"uuid": "^11.0.0",
|
||||
"zod": "^4.3.6"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.0.0",
|
||||
|
||||
605
apps/gateway/src/__tests__/conversation-persistence.test.ts
Normal file
605
apps/gateway/src/__tests__/conversation-persistence.test.ts
Normal file
@@ -0,0 +1,605 @@
|
||||
/**
|
||||
* Integration tests for conversation persistence and context resume (M1-008).
|
||||
*
|
||||
* Verifies the full flow end-to-end using in-memory mocks:
|
||||
* 1. User messages are persisted when sent via ChatGateway.
|
||||
* 2. Assistant responses are persisted with metadata on agent:end.
|
||||
* 3. Conversation history is loaded and injected into context on session resume.
|
||||
* 4. The search endpoint returns matching messages.
|
||||
*/
|
||||
|
||||
import { BadRequestException, NotFoundException } from '@nestjs/common';
|
||||
import { describe, expect, it, vi, beforeEach } from 'vitest';
|
||||
import type { ConversationHistoryMessage } from '../agent/agent.service.js';
|
||||
import { ConversationsController } from '../conversations/conversations.controller.js';
|
||||
import type { Message } from '@mosaic/brain';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared test data
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const USER_ID = 'user-test-001';
|
||||
const CONV_ID = 'conv-test-001';
|
||||
|
||||
function makeConversation(overrides?: Record<string, unknown>) {
|
||||
return {
|
||||
id: CONV_ID,
|
||||
userId: USER_ID,
|
||||
title: null,
|
||||
projectId: null,
|
||||
archived: false,
|
||||
createdAt: new Date('2026-01-01T00:00:00Z'),
|
||||
updatedAt: new Date('2026-01-01T00:00:00Z'),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeMessage(
|
||||
role: 'user' | 'assistant' | 'system',
|
||||
content: string,
|
||||
overrides?: Record<string, unknown>,
|
||||
) {
|
||||
return {
|
||||
id: `msg-${role}-${Math.random().toString(36).slice(2)}`,
|
||||
conversationId: CONV_ID,
|
||||
role,
|
||||
content,
|
||||
metadata: null,
|
||||
createdAt: new Date('2026-01-01T00:01:00Z'),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper: build a mock ConversationsRepo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function createMockBrain(options?: {
|
||||
conversation?: ReturnType<typeof makeConversation> | undefined;
|
||||
messages?: ReturnType<typeof makeMessage>[];
|
||||
searchResults?: Array<{
|
||||
messageId: string;
|
||||
conversationId: string;
|
||||
conversationTitle: string | null;
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
createdAt: Date;
|
||||
}>;
|
||||
}) {
|
||||
const conversation = options?.conversation;
|
||||
const messages = options?.messages ?? [];
|
||||
const searchResults = options?.searchResults ?? [];
|
||||
|
||||
return {
|
||||
conversations: {
|
||||
findAll: vi.fn().mockResolvedValue(conversation ? [conversation] : []),
|
||||
findById: vi.fn().mockResolvedValue(conversation),
|
||||
create: vi.fn().mockResolvedValue(conversation ?? makeConversation()),
|
||||
update: vi.fn().mockResolvedValue(conversation),
|
||||
remove: vi.fn().mockResolvedValue(true),
|
||||
findMessages: vi.fn().mockResolvedValue(messages),
|
||||
addMessage: vi.fn().mockImplementation((data: unknown) => {
|
||||
const d = data as {
|
||||
conversationId: string;
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
metadata?: Record<string, unknown>;
|
||||
};
|
||||
return Promise.resolve(makeMessage(d.role, d.content, { metadata: d.metadata ?? null }));
|
||||
}),
|
||||
searchMessages: vi.fn().mockResolvedValue(searchResults),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. ConversationsRepo: addMessage persists user message
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('ConversationsRepo.addMessage — user message persistence', () => {
|
||||
it('persists a user message and returns the saved record', async () => {
|
||||
const brain = createMockBrain({ conversation: makeConversation() });
|
||||
|
||||
const result = await brain.conversations.addMessage(
|
||||
{
|
||||
conversationId: CONV_ID,
|
||||
role: 'user',
|
||||
content: 'Hello, agent!',
|
||||
metadata: { timestamp: '2026-01-01T00:01:00.000Z' },
|
||||
},
|
||||
USER_ID,
|
||||
);
|
||||
|
||||
expect(brain.conversations.addMessage).toHaveBeenCalledOnce();
|
||||
expect(result).toBeDefined();
|
||||
expect(result!.role).toBe('user');
|
||||
expect(result!.content).toBe('Hello, agent!');
|
||||
expect(result!.conversationId).toBe(CONV_ID);
|
||||
});
|
||||
|
||||
it('returns undefined when conversation does not belong to the user', async () => {
|
||||
// Simulate the repo enforcement: ownership mismatch returns undefined
|
||||
const brain = createMockBrain({ conversation: undefined });
|
||||
brain.conversations.addMessage = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const result = await brain.conversations.addMessage(
|
||||
{ conversationId: CONV_ID, role: 'user', content: 'Hello' },
|
||||
'other-user',
|
||||
);
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. ConversationsRepo.addMessage — assistant response with metadata
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('ConversationsRepo.addMessage — assistant response metadata', () => {
|
||||
it('persists assistant message with model, provider, tokens and toolCalls metadata', async () => {
|
||||
const assistantMetadata = {
|
||||
timestamp: '2026-01-01T00:02:00.000Z',
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
provider: 'anthropic',
|
||||
toolCalls: [
|
||||
{
|
||||
toolCallId: 'tc-001',
|
||||
toolName: 'read_file',
|
||||
args: { path: '/foo/bar.ts' },
|
||||
isError: false,
|
||||
},
|
||||
],
|
||||
tokenUsage: {
|
||||
input: 1000,
|
||||
output: 250,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 1250,
|
||||
},
|
||||
};
|
||||
|
||||
const brain = createMockBrain({ conversation: makeConversation() });
|
||||
|
||||
const result = await brain.conversations.addMessage(
|
||||
{
|
||||
conversationId: CONV_ID,
|
||||
role: 'assistant',
|
||||
content: 'Here is the file content you requested.',
|
||||
metadata: assistantMetadata,
|
||||
},
|
||||
USER_ID,
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result!.role).toBe('assistant');
|
||||
expect(result!.content).toBe('Here is the file content you requested.');
|
||||
expect(result!.metadata).toMatchObject({
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
provider: 'anthropic',
|
||||
tokenUsage: { input: 1000, output: 250, total: 1250 },
|
||||
});
|
||||
expect((result!.metadata as Record<string, unknown>)['toolCalls']).toHaveLength(1);
|
||||
expect(
|
||||
(
|
||||
(result!.metadata as Record<string, unknown>)['toolCalls'] as Array<Record<string, unknown>>
|
||||
)[0]!['toolName'],
|
||||
).toBe('read_file');
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. ChatGateway.loadConversationHistory — session resume loads history
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('Conversation resume — history loading', () => {
|
||||
it('maps DB messages to ConversationHistoryMessage shape', () => {
|
||||
// Simulate what ChatGateway.loadConversationHistory does:
|
||||
// convert DB Message rows to ConversationHistoryMessage for context injection.
|
||||
const dbMessages = [
|
||||
makeMessage('user', 'What is the capital of France?', {
|
||||
createdAt: new Date('2026-01-01T00:01:00Z'),
|
||||
}),
|
||||
makeMessage('assistant', 'The capital of France is Paris.', {
|
||||
createdAt: new Date('2026-01-01T00:01:05Z'),
|
||||
}),
|
||||
makeMessage('user', 'And Germany?', { createdAt: new Date('2026-01-01T00:02:00Z') }),
|
||||
makeMessage('assistant', 'The capital of Germany is Berlin.', {
|
||||
createdAt: new Date('2026-01-01T00:02:05Z'),
|
||||
}),
|
||||
];
|
||||
|
||||
// Replicate the mapping logic from ChatGateway
|
||||
const history: ConversationHistoryMessage[] = dbMessages.map((msg) => ({
|
||||
role: msg.role as 'user' | 'assistant' | 'system',
|
||||
content: msg.content,
|
||||
createdAt: msg.createdAt,
|
||||
}));
|
||||
|
||||
expect(history).toHaveLength(4);
|
||||
expect(history[0]).toEqual({
|
||||
role: 'user',
|
||||
content: 'What is the capital of France?',
|
||||
createdAt: new Date('2026-01-01T00:01:00Z'),
|
||||
});
|
||||
expect(history[1]).toEqual({
|
||||
role: 'assistant',
|
||||
content: 'The capital of France is Paris.',
|
||||
createdAt: new Date('2026-01-01T00:01:05Z'),
|
||||
});
|
||||
expect(history[2]!.role).toBe('user');
|
||||
expect(history[3]!.role).toBe('assistant');
|
||||
});
|
||||
|
||||
it('returns empty array when conversation has no messages', async () => {
|
||||
const brain = createMockBrain({ conversation: makeConversation(), messages: [] });
|
||||
|
||||
const messages = await brain.conversations.findMessages(CONV_ID, USER_ID);
|
||||
expect(messages).toHaveLength(0);
|
||||
|
||||
// Gateway produces empty history → no context injection
|
||||
const history: ConversationHistoryMessage[] = (messages as Message[]).map((msg) => ({
|
||||
role: msg.role as 'user' | 'assistant' | 'system',
|
||||
content: msg.content,
|
||||
createdAt: msg.createdAt,
|
||||
}));
|
||||
expect(history).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns empty array when conversation does not belong to the user', async () => {
|
||||
const brain = createMockBrain({ conversation: undefined });
|
||||
brain.conversations.findMessages = vi.fn().mockResolvedValue([]);
|
||||
|
||||
const messages = await brain.conversations.findMessages(CONV_ID, 'other-user');
|
||||
expect(messages).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('preserves message order (ascending by createdAt)', async () => {
|
||||
const ordered = [
|
||||
makeMessage('user', 'First', { createdAt: new Date('2026-01-01T00:01:00Z') }),
|
||||
makeMessage('assistant', 'Second', { createdAt: new Date('2026-01-01T00:01:05Z') }),
|
||||
makeMessage('user', 'Third', { createdAt: new Date('2026-01-01T00:02:00Z') }),
|
||||
];
|
||||
const brain = createMockBrain({ conversation: makeConversation(), messages: ordered });
|
||||
|
||||
const messages = await brain.conversations.findMessages(CONV_ID, USER_ID);
|
||||
expect(messages[0]!.content).toBe('First');
|
||||
expect(messages[1]!.content).toBe('Second');
|
||||
expect(messages[2]!.content).toBe('Third');
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. AgentService.buildHistoryPromptSection — context injection format
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('AgentService — buildHistoryPromptSection (context injection)', () => {
|
||||
/**
|
||||
* Replicate the private method logic to test it in isolation.
|
||||
* The real method lives in AgentService but is private; we mirror the
|
||||
* exact logic here so the test is independent of the service's constructor.
|
||||
*/
|
||||
function buildHistoryPromptSection(
|
||||
history: ConversationHistoryMessage[],
|
||||
contextWindow: number,
|
||||
_sessionId: string,
|
||||
): string {
|
||||
const TOKEN_BUDGET = Math.floor(contextWindow * 0.8);
|
||||
const HISTORY_HEADER = '## Conversation History (resumed session)\n\n';
|
||||
|
||||
const formatMessage = (msg: ConversationHistoryMessage): string => {
|
||||
const roleLabel =
|
||||
msg.role === 'user' ? 'User' : msg.role === 'assistant' ? 'Assistant' : 'System';
|
||||
return `**${roleLabel}:** ${msg.content}`;
|
||||
};
|
||||
|
||||
const estimateTokens = (text: string) => Math.ceil(text.length / 4);
|
||||
|
||||
const formatted = history.map((msg) => formatMessage(msg));
|
||||
const fullHistory = formatted.join('\n\n');
|
||||
const fullTokens = estimateTokens(HISTORY_HEADER + fullHistory);
|
||||
|
||||
if (fullTokens <= TOKEN_BUDGET) {
|
||||
return HISTORY_HEADER + fullHistory;
|
||||
}
|
||||
|
||||
// History exceeds budget — summarize oldest messages, keep recent verbatim
|
||||
const SUMMARY_RESERVE = Math.floor(TOKEN_BUDGET * 0.2);
|
||||
const verbatimBudget = TOKEN_BUDGET - SUMMARY_RESERVE;
|
||||
|
||||
let verbatimTokens = 0;
|
||||
let verbatimCutIndex = history.length;
|
||||
for (let i = history.length - 1; i >= 0; i--) {
|
||||
const t = estimateTokens(formatted[i]!);
|
||||
if (verbatimTokens + t > verbatimBudget) break;
|
||||
verbatimTokens += t;
|
||||
verbatimCutIndex = i;
|
||||
}
|
||||
|
||||
const summarizedMessages = history.slice(0, verbatimCutIndex);
|
||||
const verbatimMessages = history.slice(verbatimCutIndex);
|
||||
|
||||
let summaryText = '';
|
||||
if (summarizedMessages.length > 0) {
|
||||
const topics = summarizedMessages
|
||||
.filter((m) => m.role === 'user')
|
||||
.map((m) => m.content.slice(0, 120).replace(/\n/g, ' '))
|
||||
.join('; ');
|
||||
summaryText =
|
||||
`**Previous conversation summary** (${summarizedMessages.length} messages omitted for brevity):\n` +
|
||||
`Topics discussed: ${topics || '(no user messages in summarized portion)'}`;
|
||||
}
|
||||
|
||||
const verbatimSection = verbatimMessages.map((m) => formatMessage(m)).join('\n\n');
|
||||
|
||||
const parts: string[] = [HISTORY_HEADER];
|
||||
if (summaryText) parts.push(summaryText);
|
||||
if (verbatimSection) parts.push(verbatimSection);
|
||||
|
||||
return parts.join('\n\n');
|
||||
}
|
||||
|
||||
it('includes header and all messages when history fits within context budget', () => {
|
||||
const history: ConversationHistoryMessage[] = [
|
||||
{ role: 'user', content: 'Hello', createdAt: new Date() },
|
||||
{ role: 'assistant', content: 'Hi there!', createdAt: new Date() },
|
||||
];
|
||||
|
||||
const result = buildHistoryPromptSection(history, 8192, 'session-1');
|
||||
|
||||
expect(result).toContain('## Conversation History (resumed session)');
|
||||
expect(result).toContain('**User:** Hello');
|
||||
expect(result).toContain('**Assistant:** Hi there!');
|
||||
});
|
||||
|
||||
it('labels roles correctly (user, assistant, system)', () => {
|
||||
const history: ConversationHistoryMessage[] = [
|
||||
{ role: 'system', content: 'You are helpful.', createdAt: new Date() },
|
||||
{ role: 'user', content: 'Ping', createdAt: new Date() },
|
||||
{ role: 'assistant', content: 'Pong', createdAt: new Date() },
|
||||
];
|
||||
|
||||
const result = buildHistoryPromptSection(history, 8192, 'session-2');
|
||||
|
||||
expect(result).toContain('**System:** You are helpful.');
|
||||
expect(result).toContain('**User:** Ping');
|
||||
expect(result).toContain('**Assistant:** Pong');
|
||||
});
|
||||
|
||||
it('summarizes old messages when history exceeds 80% of context window', () => {
|
||||
// Create enough messages to exceed a tiny context window budget
|
||||
const longContent = 'A'.repeat(200);
|
||||
const history: ConversationHistoryMessage[] = Array.from({ length: 20 }, (_, i) => ({
|
||||
role: (i % 2 === 0 ? 'user' : 'assistant') as 'user' | 'assistant',
|
||||
content: `${longContent} message ${i}`,
|
||||
createdAt: new Date(),
|
||||
}));
|
||||
|
||||
// Use a small context window so history definitely exceeds 80%
|
||||
const result = buildHistoryPromptSection(history, 512, 'session-3');
|
||||
|
||||
// Should contain the summary prefix
|
||||
expect(result).toContain('messages omitted for brevity');
|
||||
expect(result).toContain('Topics discussed:');
|
||||
});
|
||||
|
||||
it('returns only header for empty history', () => {
|
||||
const result = buildHistoryPromptSection([], 8192, 'session-4');
|
||||
// With empty history, the full history join is '' and the section is just the header
|
||||
expect(result).toContain('## Conversation History (resumed session)');
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. ConversationsController.search — GET /api/conversations/search
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('ConversationsController — search endpoint', () => {
|
||||
let brain: ReturnType<typeof createMockBrain>;
|
||||
let controller: ConversationsController;
|
||||
|
||||
beforeEach(() => {
|
||||
const searchResults = [
|
||||
{
|
||||
messageId: 'msg-001',
|
||||
conversationId: CONV_ID,
|
||||
conversationTitle: 'Test Chat',
|
||||
role: 'user' as const,
|
||||
content: 'What is the capital of France?',
|
||||
createdAt: new Date('2026-01-01T00:01:00Z'),
|
||||
},
|
||||
{
|
||||
messageId: 'msg-002',
|
||||
conversationId: CONV_ID,
|
||||
conversationTitle: 'Test Chat',
|
||||
role: 'assistant' as const,
|
||||
content: 'The capital of France is Paris.',
|
||||
createdAt: new Date('2026-01-01T00:01:05Z'),
|
||||
},
|
||||
];
|
||||
brain = createMockBrain({ searchResults });
|
||||
controller = new ConversationsController(brain as never);
|
||||
});
|
||||
|
||||
it('returns matching messages for a valid search query', async () => {
|
||||
const results = await controller.search({ q: 'France' }, { id: USER_ID });
|
||||
|
||||
expect(brain.conversations.searchMessages).toHaveBeenCalledWith(USER_ID, 'France', 20, 0);
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0]).toMatchObject({
|
||||
messageId: 'msg-001',
|
||||
role: 'user',
|
||||
content: 'What is the capital of France?',
|
||||
});
|
||||
expect(results[1]).toMatchObject({
|
||||
messageId: 'msg-002',
|
||||
role: 'assistant',
|
||||
content: 'The capital of France is Paris.',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses custom limit and offset when provided', async () => {
|
||||
await controller.search({ q: 'Paris', limit: 5, offset: 10 }, { id: USER_ID });
|
||||
|
||||
expect(brain.conversations.searchMessages).toHaveBeenCalledWith(USER_ID, 'Paris', 5, 10);
|
||||
});
|
||||
|
||||
it('throws BadRequestException when query is empty', async () => {
|
||||
await expect(controller.search({ q: '' }, { id: USER_ID })).rejects.toBeInstanceOf(
|
||||
BadRequestException,
|
||||
);
|
||||
await expect(controller.search({ q: ' ' }, { id: USER_ID })).rejects.toBeInstanceOf(
|
||||
BadRequestException,
|
||||
);
|
||||
});
|
||||
|
||||
it('trims whitespace from query before passing to repo', async () => {
|
||||
await controller.search({ q: ' Berlin ' }, { id: USER_ID });
|
||||
|
||||
expect(brain.conversations.searchMessages).toHaveBeenCalledWith(
|
||||
USER_ID,
|
||||
'Berlin',
|
||||
expect.any(Number),
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
it('returns empty array when no messages match', async () => {
|
||||
brain.conversations.searchMessages = vi.fn().mockResolvedValue([]);
|
||||
|
||||
const results = await controller.search({ q: 'xyzzy-no-match' }, { id: USER_ID });
|
||||
|
||||
expect(results).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 6. ConversationsController — messages CRUD
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('ConversationsController — message CRUD', () => {
|
||||
it('listMessages returns 404 when conversation is not owned by user', async () => {
|
||||
const brain = createMockBrain({ conversation: undefined });
|
||||
const controller = new ConversationsController(brain as never);
|
||||
|
||||
await expect(controller.listMessages(CONV_ID, { id: USER_ID })).rejects.toBeInstanceOf(
|
||||
NotFoundException,
|
||||
);
|
||||
});
|
||||
|
||||
it('listMessages returns the messages for an owned conversation', async () => {
|
||||
const msgs = [makeMessage('user', 'Test message'), makeMessage('assistant', 'Test reply')];
|
||||
const brain = createMockBrain({ conversation: makeConversation(), messages: msgs });
|
||||
const controller = new ConversationsController(brain as never);
|
||||
|
||||
const result = await controller.listMessages(CONV_ID, { id: USER_ID });
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0]!.role).toBe('user');
|
||||
expect(result[1]!.role).toBe('assistant');
|
||||
});
|
||||
|
||||
it('addMessage returns the persisted message', async () => {
|
||||
const brain = createMockBrain({ conversation: makeConversation() });
|
||||
const controller = new ConversationsController(brain as never);
|
||||
|
||||
const result = await controller.addMessage(
|
||||
CONV_ID,
|
||||
{ role: 'user', content: 'Persisted content' },
|
||||
{ id: USER_ID },
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.role).toBe('user');
|
||||
expect(result.content).toBe('Persisted content');
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 7. End-to-end persistence flow simulation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe('End-to-end persistence flow', () => {
|
||||
it('simulates a full conversation: persist user message → persist assistant response → resume with history', async () => {
|
||||
// ── Step 1: Conversation is created ────────────────────────────────────
|
||||
const brain = createMockBrain({ conversation: makeConversation() });
|
||||
|
||||
await brain.conversations.create({ id: CONV_ID, userId: USER_ID });
|
||||
expect(brain.conversations.create).toHaveBeenCalledOnce();
|
||||
|
||||
// ── Step 2: User message is persisted ──────────────────────────────────
|
||||
const userMsg = await brain.conversations.addMessage(
|
||||
{
|
||||
conversationId: CONV_ID,
|
||||
role: 'user',
|
||||
content: 'Explain monads in simple terms.',
|
||||
metadata: { timestamp: '2026-01-01T00:01:00.000Z' },
|
||||
},
|
||||
USER_ID,
|
||||
);
|
||||
|
||||
expect(userMsg).toBeDefined();
|
||||
expect(userMsg!.role).toBe('user');
|
||||
|
||||
// ── Step 3: Assistant response is persisted with metadata ───────────────
|
||||
const assistantMeta = {
|
||||
timestamp: '2026-01-01T00:01:10.000Z',
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
provider: 'anthropic',
|
||||
toolCalls: [],
|
||||
tokenUsage: { input: 500, output: 120, cacheRead: 0, cacheWrite: 0, total: 620 },
|
||||
};
|
||||
|
||||
const assistantMsg = await brain.conversations.addMessage(
|
||||
{
|
||||
conversationId: CONV_ID,
|
||||
role: 'assistant',
|
||||
content: 'A monad is a design pattern that wraps values in a context...',
|
||||
metadata: assistantMeta,
|
||||
},
|
||||
USER_ID,
|
||||
);
|
||||
|
||||
expect(assistantMsg).toBeDefined();
|
||||
expect(assistantMsg!.role).toBe('assistant');
|
||||
|
||||
// ── Step 4: On session resume, history is loaded ────────────────────────
|
||||
const storedMessages = [
|
||||
makeMessage('user', 'Explain monads in simple terms.', {
|
||||
createdAt: new Date('2026-01-01T00:01:00Z'),
|
||||
metadata: { timestamp: '2026-01-01T00:01:00.000Z' },
|
||||
}),
|
||||
makeMessage('assistant', 'A monad is a design pattern that wraps values in a context...', {
|
||||
createdAt: new Date('2026-01-01T00:01:10Z'),
|
||||
metadata: assistantMeta,
|
||||
}),
|
||||
];
|
||||
|
||||
brain.conversations.findMessages = vi.fn().mockResolvedValue(storedMessages);
|
||||
|
||||
const dbMessages = await brain.conversations.findMessages(CONV_ID, USER_ID);
|
||||
expect(dbMessages).toHaveLength(2);
|
||||
|
||||
// ── Step 5: History is mapped for context injection ─────────────────────
|
||||
const history: ConversationHistoryMessage[] = (dbMessages as Message[]).map((msg) => ({
|
||||
role: msg.role as 'user' | 'assistant' | 'system',
|
||||
content: msg.content,
|
||||
createdAt: msg.createdAt,
|
||||
}));
|
||||
|
||||
expect(history[0]).toMatchObject({
|
||||
role: 'user',
|
||||
content: 'Explain monads in simple terms.',
|
||||
});
|
||||
expect(history[1]).toMatchObject({
|
||||
role: 'assistant',
|
||||
content: 'A monad is a design pattern that wraps values in a context...',
|
||||
});
|
||||
|
||||
// ── Step 6: History roles are valid for injection ───────────────────────
|
||||
for (const msg of history) {
|
||||
expect(['user', 'assistant', 'system']).toContain(msg.role);
|
||||
expect(typeof msg.content).toBe('string');
|
||||
expect(msg.createdAt).toBeInstanceOf(Date);
|
||||
}
|
||||
});
|
||||
});
|
||||
470
apps/gateway/src/__tests__/cross-user-isolation.test.ts
Normal file
470
apps/gateway/src/__tests__/cross-user-isolation.test.ts
Normal file
@@ -0,0 +1,470 @@
|
||||
/**
 * Integration test: cross-user data isolation (M2-007).
 *
 * Verifies that every repository query path is scoped to the requesting user —
 * no user can read, write, or enumerate another user's records.
 *
 * Test strategy:
 * - Two real users (User A, User B) are inserted directly into the database.
 * - Realistic data (conversations + messages, agent configs, preferences,
 *   insights) is created for each user.
 * - A shared system agent is inserted so both users can see it via
 *   findAccessible().
 * - All assertions are made against the live database (no mocks).
 * - All inserted rows are cleaned up in the afterAll hook.
 *
 * Requires: DATABASE_URL pointing at a running PostgreSQL instance with
 * pgvector enabled and the Mosaic schema already applied.
 */
|
||||
|
||||
import { afterAll, beforeAll, describe, expect, it } from 'vitest';
|
||||
import { createDb } from '@mosaic/db';
|
||||
import { createConversationsRepo } from '@mosaic/brain';
|
||||
import { createAgentsRepo } from '@mosaic/brain';
|
||||
import { createPreferencesRepo, createInsightsRepo } from '@mosaic/memory';
|
||||
import { users, conversations, messages, agents, preferences, insights } from '@mosaic/db';
|
||||
import { eq } from '@mosaic/db';
|
||||
import type { DbHandle } from '@mosaic/db';
|
||||
|
||||
// ─── Fixed IDs so the afterAll cleanup is deterministic ──────────────────────
// The 'aaaaaaaa-'/'bbbbbbbb-' UUID prefixes mark User A / User B ownership at
// a glance; 'ffffffff-' marks the shared system agent with no owner.

const USER_A_ID = 'test-iso-user-a';
const USER_B_ID = 'test-iso-user-b';
const CONV_A_ID = 'aaaaaaaa-0000-0000-0000-000000000001';
const CONV_B_ID = 'bbbbbbbb-0000-0000-0000-000000000001';
const MSG_A_ID = 'aaaaaaaa-0000-0000-0000-000000000002';
const MSG_B_ID = 'bbbbbbbb-0000-0000-0000-000000000002';
const AGENT_A_ID = 'aaaaaaaa-0000-0000-0000-000000000003';
const AGENT_B_ID = 'bbbbbbbb-0000-0000-0000-000000000003';
const AGENT_SYS_ID = 'ffffffff-0000-0000-0000-000000000001';
const PREF_A_ID = 'aaaaaaaa-0000-0000-0000-000000000004';
const PREF_B_ID = 'bbbbbbbb-0000-0000-0000-000000000004';
const INSIGHT_A_ID = 'aaaaaaaa-0000-0000-0000-000000000005';
const INSIGHT_B_ID = 'bbbbbbbb-0000-0000-0000-000000000005';

// ─── Test fixture ─────────────────────────────────────────────────────────────

// Shared live-DB handle: opened in beforeAll, closed in afterAll.
let handle: DbHandle;
|
||||
|
||||
// Seeds both users and one row of each record type per user against the live
// database. Every insert uses onConflictDoNothing() so a previously aborted
// run (leftover rows) does not fail the fixture.
beforeAll(async () => {
  handle = createDb();
  const db = handle.db;

  // Insert two users
  await db
    .insert(users)
    .values([
      {
        id: USER_A_ID,
        name: 'Isolation Test User A',
        email: 'test-iso-user-a@example.invalid',
        emailVerified: false,
      },
      {
        id: USER_B_ID,
        name: 'Isolation Test User B',
        email: 'test-iso-user-b@example.invalid',
        emailVerified: false,
      },
    ])
    .onConflictDoNothing();

  // Conversations — one per user
  await db
    .insert(conversations)
    .values([
      { id: CONV_A_ID, userId: USER_A_ID, title: 'User A conversation' },
      { id: CONV_B_ID, userId: USER_B_ID, title: 'User B conversation' },
    ])
    .onConflictDoNothing();

  // Messages — one per conversation
  await db
    .insert(messages)
    .values([
      {
        id: MSG_A_ID,
        conversationId: CONV_A_ID,
        role: 'user',
        content: 'Hello from User A',
      },
      {
        id: MSG_B_ID,
        conversationId: CONV_B_ID,
        role: 'user',
        content: 'Hello from User B',
      },
    ])
    .onConflictDoNothing();

  // Agent configs — private agents (one per user) + one system agent
  // (ownerId: null + isSystem: true is what findAccessible() exposes to all).
  await db
    .insert(agents)
    .values([
      {
        id: AGENT_A_ID,
        name: 'Agent A (private)',
        provider: 'test',
        model: 'test-model',
        ownerId: USER_A_ID,
        isSystem: false,
      },
      {
        id: AGENT_B_ID,
        name: 'Agent B (private)',
        provider: 'test',
        model: 'test-model',
        ownerId: USER_B_ID,
        isSystem: false,
      },
      {
        id: AGENT_SYS_ID,
        name: 'Shared System Agent',
        provider: 'test',
        model: 'test-model',
        ownerId: null,
        isSystem: true,
      },
    ])
    .onConflictDoNothing();

  // Preferences — one per user (same key, different values) so per-user key
  // lookups can prove they never return the other user's row.
  await db
    .insert(preferences)
    .values([
      {
        id: PREF_A_ID,
        userId: USER_A_ID,
        key: 'theme',
        value: 'dark',
        category: 'appearance',
      },
      {
        id: PREF_B_ID,
        userId: USER_B_ID,
        key: 'theme',
        value: 'light',
        category: 'appearance',
      },
    ])
    .onConflictDoNothing();

  // Insights — no embedding to keep the fixture simple; embedding-based search
  // is tested separately with a zero-vector that falls outside maxDistance
  await db
    .insert(insights)
    .values([
      {
        id: INSIGHT_A_ID,
        userId: USER_A_ID,
        content: 'User A insight',
        source: 'user',
        category: 'general',
        relevanceScore: 1.0,
      },
      {
        id: INSIGHT_B_ID,
        userId: USER_B_ID,
        content: 'User B insight',
        source: 'user',
        category: 'general',
        relevanceScore: 1.0,
      },
    ])
    .onConflictDoNothing();
});
|
||||
|
||||
// Removes every fixture row by its fixed id, children before parents so FK
// constraints are never violated, then closes the DB handle.
afterAll(async () => {
  // handle is unset if beforeAll failed before createDb(); nothing to clean.
  if (!handle) return;
  const db = handle.db;

  // Delete in dependency order (FK constraints)
  await db.delete(messages).where(eq(messages.id, MSG_A_ID));
  await db.delete(messages).where(eq(messages.id, MSG_B_ID));
  await db.delete(conversations).where(eq(conversations.id, CONV_A_ID));
  await db.delete(conversations).where(eq(conversations.id, CONV_B_ID));
  await db.delete(agents).where(eq(agents.id, AGENT_A_ID));
  await db.delete(agents).where(eq(agents.id, AGENT_B_ID));
  await db.delete(agents).where(eq(agents.id, AGENT_SYS_ID));
  await db.delete(preferences).where(eq(preferences.id, PREF_A_ID));
  await db.delete(preferences).where(eq(preferences.id, PREF_B_ID));
  await db.delete(insights).where(eq(insights.id, INSIGHT_A_ID));
  await db.delete(insights).where(eq(insights.id, INSIGHT_B_ID));
  await db.delete(users).where(eq(users.id, USER_A_ID));
  await db.delete(users).where(eq(users.id, USER_B_ID));

  await handle.close();
});
|
||||
|
||||
// ─── Conversations isolation ──────────────────────────────────────────────────
|
||||
|
||||
describe('ConversationsRepo — cross-user isolation', () => {
|
||||
it('User A can find their own conversation by id', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const conv = await repo.findById(CONV_A_ID, USER_A_ID);
|
||||
expect(conv).toBeDefined();
|
||||
expect(conv!.id).toBe(CONV_A_ID);
|
||||
});
|
||||
|
||||
it('User B cannot find User A conversation by id (returns undefined)', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const conv = await repo.findById(CONV_A_ID, USER_B_ID);
|
||||
expect(conv).toBeUndefined();
|
||||
});
|
||||
|
||||
it('User A cannot find User B conversation by id (returns undefined)', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const conv = await repo.findById(CONV_B_ID, USER_A_ID);
|
||||
expect(conv).toBeUndefined();
|
||||
});
|
||||
|
||||
it('findAll returns only own conversations for User A', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const convs = await repo.findAll(USER_A_ID);
|
||||
const ids = convs.map((c) => c.id);
|
||||
expect(ids).toContain(CONV_A_ID);
|
||||
expect(ids).not.toContain(CONV_B_ID);
|
||||
});
|
||||
|
||||
it('findAll returns only own conversations for User B', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const convs = await repo.findAll(USER_B_ID);
|
||||
const ids = convs.map((c) => c.id);
|
||||
expect(ids).toContain(CONV_B_ID);
|
||||
expect(ids).not.toContain(CONV_A_ID);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Messages isolation ───────────────────────────────────────────────────────
|
||||
|
||||
describe('ConversationsRepo.findMessages — cross-user isolation', () => {
|
||||
it('User A can read messages from their own conversation', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const msgs = await repo.findMessages(CONV_A_ID, USER_A_ID);
|
||||
const ids = msgs.map((m) => m.id);
|
||||
expect(ids).toContain(MSG_A_ID);
|
||||
});
|
||||
|
||||
it('User B cannot read messages from User A conversation (returns empty array)', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const msgs = await repo.findMessages(CONV_A_ID, USER_B_ID);
|
||||
expect(msgs).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('User A cannot read messages from User B conversation (returns empty array)', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const msgs = await repo.findMessages(CONV_B_ID, USER_A_ID);
|
||||
expect(msgs).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('addMessage is rejected when user does not own the conversation', async () => {
|
||||
const repo = createConversationsRepo(handle.db);
|
||||
const result = await repo.addMessage(
|
||||
{
|
||||
conversationId: CONV_A_ID,
|
||||
role: 'user',
|
||||
content: 'Attempted injection by User B',
|
||||
},
|
||||
USER_B_ID,
|
||||
);
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Agent configs isolation ──────────────────────────────────────────────────
|
||||
|
||||
describe('AgentsRepo.findAccessible — cross-user isolation', () => {
|
||||
it('User A sees their own private agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const accessible = await repo.findAccessible(USER_A_ID);
|
||||
const ids = accessible.map((a) => a.id);
|
||||
expect(ids).toContain(AGENT_A_ID);
|
||||
});
|
||||
|
||||
it('User A does NOT see User B private agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const accessible = await repo.findAccessible(USER_A_ID);
|
||||
const ids = accessible.map((a) => a.id);
|
||||
expect(ids).not.toContain(AGENT_B_ID);
|
||||
});
|
||||
|
||||
it('User B does NOT see User A private agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const accessible = await repo.findAccessible(USER_B_ID);
|
||||
const ids = accessible.map((a) => a.id);
|
||||
expect(ids).not.toContain(AGENT_A_ID);
|
||||
});
|
||||
|
||||
it('Both users can see the shared system agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const accessibleA = await repo.findAccessible(USER_A_ID);
|
||||
const accessibleB = await repo.findAccessible(USER_B_ID);
|
||||
expect(accessibleA.map((a) => a.id)).toContain(AGENT_SYS_ID);
|
||||
expect(accessibleB.map((a) => a.id)).toContain(AGENT_SYS_ID);
|
||||
});
|
||||
|
||||
it('findSystem returns the system agent for any caller', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const system = await repo.findSystem();
|
||||
const ids = system.map((a) => a.id);
|
||||
expect(ids).toContain(AGENT_SYS_ID);
|
||||
});
|
||||
|
||||
it('update with ownerId prevents User B from modifying User A agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const result = await repo.update(AGENT_A_ID, { model: 'hacked' }, USER_B_ID);
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Verify the agent was not actually mutated
|
||||
const unchanged = await repo.findById(AGENT_A_ID);
|
||||
expect(unchanged?.model).toBe('test-model');
|
||||
});
|
||||
|
||||
it('remove prevents User B from deleting User A agent', async () => {
|
||||
const repo = createAgentsRepo(handle.db);
|
||||
const deleted = await repo.remove(AGENT_A_ID, USER_B_ID);
|
||||
expect(deleted).toBe(false);
|
||||
|
||||
// Verify the agent still exists
|
||||
const still = await repo.findById(AGENT_A_ID);
|
||||
expect(still).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Preferences isolation ────────────────────────────────────────────────────
|
||||
|
||||
describe('PreferencesRepo — cross-user isolation', () => {
|
||||
it('User A can retrieve their own preferences', async () => {
|
||||
const repo = createPreferencesRepo(handle.db);
|
||||
const prefs = await repo.findByUser(USER_A_ID);
|
||||
const ids = prefs.map((p) => p.id);
|
||||
expect(ids).toContain(PREF_A_ID);
|
||||
});
|
||||
|
||||
it('User A preferences do not contain User B preferences', async () => {
|
||||
const repo = createPreferencesRepo(handle.db);
|
||||
const prefs = await repo.findByUser(USER_A_ID);
|
||||
const ids = prefs.map((p) => p.id);
|
||||
expect(ids).not.toContain(PREF_B_ID);
|
||||
});
|
||||
|
||||
it('User B preferences do not contain User A preferences', async () => {
|
||||
const repo = createPreferencesRepo(handle.db);
|
||||
const prefs = await repo.findByUser(USER_B_ID);
|
||||
const ids = prefs.map((p) => p.id);
|
||||
expect(ids).not.toContain(PREF_A_ID);
|
||||
});
|
||||
|
||||
it('findByUserAndKey is scoped to the requesting user', async () => {
|
||||
const repo = createPreferencesRepo(handle.db);
|
||||
// Both users have key "theme" — each should only see their own value
|
||||
const prefA = await repo.findByUserAndKey(USER_A_ID, 'theme');
|
||||
const prefB = await repo.findByUserAndKey(USER_B_ID, 'theme');
|
||||
|
||||
expect(prefA).toBeDefined();
|
||||
// Drizzle returns JSONB values as parsed JS values; '"dark"' (JSON string) → 'dark'
|
||||
expect(prefA!.value).toBe('dark');
|
||||
expect(prefB).toBeDefined();
|
||||
expect(prefB!.value).toBe('light');
|
||||
});
|
||||
|
||||
it('remove is scoped to the requesting user (cannot delete another user pref)', async () => {
|
||||
const repo = createPreferencesRepo(handle.db);
|
||||
// User B tries to delete User A's "theme" preference — should silently fail
|
||||
const deleted = await repo.remove(USER_B_ID, 'theme');
|
||||
// This only deletes USER_B's own "theme" row; re-insert it for afterAll cleanup
|
||||
expect(deleted).toBe(true); // deletes User B's OWN theme pref
|
||||
|
||||
// User A's theme pref must be untouched
|
||||
const prefA = await repo.findByUserAndKey(USER_A_ID, 'theme');
|
||||
expect(prefA).toBeDefined();
|
||||
|
||||
// Re-insert User B's preference so afterAll cleanup still finds it
|
||||
await repo.upsert({
|
||||
id: PREF_B_ID,
|
||||
userId: USER_B_ID,
|
||||
key: 'theme',
|
||||
value: 'light',
|
||||
category: 'appearance',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Insights isolation ───────────────────────────────────────────────────────
|
||||
|
||||
describe('InsightsRepo — cross-user isolation', () => {
|
||||
it('User A can retrieve their own insights', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const list = await repo.findByUser(USER_A_ID);
|
||||
const ids = list.map((i) => i.id);
|
||||
expect(ids).toContain(INSIGHT_A_ID);
|
||||
});
|
||||
|
||||
it('User A insights do not contain User B insights', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const list = await repo.findByUser(USER_A_ID);
|
||||
const ids = list.map((i) => i.id);
|
||||
expect(ids).not.toContain(INSIGHT_B_ID);
|
||||
});
|
||||
|
||||
it('User B insights do not contain User A insights', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const list = await repo.findByUser(USER_B_ID);
|
||||
const ids = list.map((i) => i.id);
|
||||
expect(ids).not.toContain(INSIGHT_A_ID);
|
||||
});
|
||||
|
||||
it('findById is scoped to the requesting user', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const own = await repo.findById(INSIGHT_A_ID, USER_A_ID);
|
||||
const cross = await repo.findById(INSIGHT_A_ID, USER_B_ID);
|
||||
|
||||
expect(own).toBeDefined();
|
||||
expect(cross).toBeUndefined();
|
||||
});
|
||||
|
||||
it('searchByEmbedding returns only own insights', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
// Our test insights have no embedding — the query filters WHERE embedding IS NOT NULL
|
||||
// so the result set is empty, which already proves no cross-user leakage.
|
||||
// Using a 1536-dimension zero vector as the query embedding.
|
||||
const zeroVector = Array<number>(1536).fill(0);
|
||||
|
||||
const resultsA = await repo.searchByEmbedding(USER_A_ID, zeroVector, 50, 2.0);
|
||||
const resultsB = await repo.searchByEmbedding(USER_B_ID, zeroVector, 50, 2.0);
|
||||
|
||||
// The raw SQL query returns row objects directly (not wrapped in { insight }).
|
||||
// Cast via unknown to extract id safely regardless of the return shape.
|
||||
const toId = (r: unknown): string =>
|
||||
((r as Record<string, unknown>)['id'] as string | undefined) ??
|
||||
((r as Record<string, Record<string, unknown>>)['insight']?.['id'] as string | undefined) ??
|
||||
'';
|
||||
const idsInA = resultsA.map(toId);
|
||||
const idsInB = resultsB.map(toId);
|
||||
|
||||
// User B's insight must never appear in User A's search results
|
||||
expect(idsInA).not.toContain(INSIGHT_B_ID);
|
||||
// User A's insight must never appear in User B's search results
|
||||
expect(idsInB).not.toContain(INSIGHT_A_ID);
|
||||
});
|
||||
|
||||
it('update is scoped to the requesting user', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const result = await repo.update(INSIGHT_A_ID, USER_B_ID, { content: 'hacked' });
|
||||
expect(result).toBeUndefined();
|
||||
|
||||
// Verify the insight was not mutated
|
||||
const unchanged = await repo.findById(INSIGHT_A_ID, USER_A_ID);
|
||||
expect(unchanged?.content).toBe('User A insight');
|
||||
});
|
||||
|
||||
it('remove is scoped to the requesting user', async () => {
|
||||
const repo = createInsightsRepo(handle.db);
|
||||
const deleted = await repo.remove(INSIGHT_A_ID, USER_B_ID);
|
||||
expect(deleted).toBe(false);
|
||||
|
||||
// Verify the insight still exists
|
||||
const still = await repo.findById(INSIGHT_A_ID, USER_A_ID);
|
||||
expect(still).toBeDefined();
|
||||
});
|
||||
});
|
||||
150
apps/gateway/src/__tests__/resource-ownership.test.ts
Normal file
150
apps/gateway/src/__tests__/resource-ownership.test.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
import { ForbiddenException, NotFoundException } from '@nestjs/common';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { ConversationsController } from '../conversations/conversations.controller.js';
|
||||
import { MissionsController } from '../missions/missions.controller.js';
|
||||
import { ProjectsController } from '../projects/projects.controller.js';
|
||||
import { TasksController } from '../tasks/tasks.controller.js';
|
||||
|
||||
function createBrain() {
|
||||
return {
|
||||
conversations: {
|
||||
findAll: vi.fn(),
|
||||
findById: vi.fn(),
|
||||
create: vi.fn(),
|
||||
update: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
findMessages: vi.fn(),
|
||||
addMessage: vi.fn(),
|
||||
},
|
||||
projects: {
|
||||
findAll: vi.fn(),
|
||||
findAllForUser: vi.fn(),
|
||||
findById: vi.fn(),
|
||||
create: vi.fn(),
|
||||
update: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
},
|
||||
missions: {
|
||||
findAll: vi.fn(),
|
||||
findAllByUser: vi.fn(),
|
||||
findById: vi.fn(),
|
||||
findByIdAndUser: vi.fn(),
|
||||
findByProject: vi.fn(),
|
||||
create: vi.fn(),
|
||||
update: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
},
|
||||
missionTasks: {
|
||||
findByMissionAndUser: vi.fn(),
|
||||
findByIdAndUser: vi.fn(),
|
||||
create: vi.fn(),
|
||||
update: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
},
|
||||
tasks: {
|
||||
findAll: vi.fn(),
|
||||
findById: vi.fn(),
|
||||
findByProject: vi.fn(),
|
||||
findByMission: vi.fn(),
|
||||
findByStatus: vi.fn(),
|
||||
create: vi.fn(),
|
||||
update: vi.fn(),
|
||||
remove: vi.fn(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('Resource ownership checks', () => {
|
||||
it('forbids access to another user conversation', async () => {
|
||||
const brain = createBrain();
|
||||
// The repo enforces ownership via the WHERE clause; it returns undefined when the
|
||||
// conversation does not belong to the requesting user.
|
||||
brain.conversations.findById.mockResolvedValue(undefined);
|
||||
const controller = new ConversationsController(brain as never);
|
||||
|
||||
await expect(controller.findOne('conv-1', { id: 'user-1' })).rejects.toBeInstanceOf(
|
||||
NotFoundException,
|
||||
);
|
||||
});
|
||||
|
||||
it('forbids access to another user project', async () => {
|
||||
const brain = createBrain();
|
||||
brain.projects.findById.mockResolvedValue({ id: 'project-1', ownerId: 'user-2' });
|
||||
const teamsService = { canAccessProject: vi.fn().mockResolvedValue(false) };
|
||||
const controller = new ProjectsController(brain as never, teamsService as never);
|
||||
|
||||
await expect(controller.findOne('project-1', { id: 'user-1' })).rejects.toBeInstanceOf(
|
||||
ForbiddenException,
|
||||
);
|
||||
});
|
||||
|
||||
it('forbids access to a mission owned by another user', async () => {
|
||||
const brain = createBrain();
|
||||
// findByIdAndUser returns undefined when the mission doesn't belong to the user
|
||||
brain.missions.findByIdAndUser.mockResolvedValue(undefined);
|
||||
const controller = new MissionsController(brain as never);
|
||||
|
||||
await expect(controller.findOne('mission-1', { id: 'user-1' })).rejects.toBeInstanceOf(
|
||||
NotFoundException,
|
||||
);
|
||||
});
|
||||
|
||||
it('forbids access to a task owned by another project owner', async () => {
|
||||
const brain = createBrain();
|
||||
brain.tasks.findById.mockResolvedValue({ id: 'task-1', projectId: 'project-1' });
|
||||
brain.projects.findById.mockResolvedValue({ id: 'project-1', ownerId: 'user-2' });
|
||||
const controller = new TasksController(brain as never);
|
||||
|
||||
await expect(controller.findOne('task-1', { id: 'user-1' })).rejects.toBeInstanceOf(
|
||||
ForbiddenException,
|
||||
);
|
||||
});
|
||||
|
||||
it('forbids creating a task with an unowned project', async () => {
|
||||
const brain = createBrain();
|
||||
brain.projects.findById.mockResolvedValue({ id: 'project-1', ownerId: 'user-2' });
|
||||
const controller = new TasksController(brain as never);
|
||||
|
||||
await expect(
|
||||
controller.create(
|
||||
{
|
||||
title: 'Task',
|
||||
projectId: 'project-1',
|
||||
},
|
||||
{ id: 'user-1' },
|
||||
),
|
||||
).rejects.toBeInstanceOf(ForbiddenException);
|
||||
});
|
||||
|
||||
it('forbids listing tasks for an unowned project', async () => {
|
||||
const brain = createBrain();
|
||||
brain.projects.findById.mockResolvedValue({ id: 'project-1', ownerId: 'user-2' });
|
||||
const controller = new TasksController(brain as never);
|
||||
|
||||
await expect(
|
||||
controller.list({ id: 'user-1' }, 'project-1', undefined, undefined),
|
||||
).rejects.toBeInstanceOf(ForbiddenException);
|
||||
});
|
||||
|
||||
it('lists only tasks for the current user owned projects when no filter is provided', async () => {
|
||||
const brain = createBrain();
|
||||
brain.projects.findAll.mockResolvedValue([
|
||||
{ id: 'project-1', ownerId: 'user-1' },
|
||||
{ id: 'project-2', ownerId: 'user-2' },
|
||||
]);
|
||||
brain.missions.findAll.mockResolvedValue([{ id: 'mission-1', projectId: 'project-1' }]);
|
||||
brain.tasks.findAll.mockResolvedValue([
|
||||
{ id: 'task-1', projectId: 'project-1' },
|
||||
{ id: 'task-2', missionId: 'mission-1' },
|
||||
{ id: 'task-3', projectId: 'project-2' },
|
||||
]);
|
||||
const controller = new TasksController(brain as never);
|
||||
|
||||
await expect(
|
||||
controller.list({ id: 'user-1' }, undefined, undefined, undefined),
|
||||
).resolves.toEqual([
|
||||
{ id: 'task-1', projectId: 'project-1' },
|
||||
{ id: 'task-2', missionId: 'mission-1' },
|
||||
]);
|
||||
});
|
||||
});
|
||||
73
apps/gateway/src/admin/admin-health.controller.ts
Normal file
73
apps/gateway/src/admin/admin-health.controller.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { Controller, Get, Inject, UseGuards } from '@nestjs/common';
|
||||
import { sql, type Db } from '@mosaic/db';
|
||||
import { createQueue } from '@mosaic/queue';
|
||||
import { DB } from '../database/database.module.js';
|
||||
import { AgentService } from '../agent/agent.service.js';
|
||||
import { ProviderService } from '../agent/provider.service.js';
|
||||
import { AdminGuard } from './admin.guard.js';
|
||||
import type { HealthStatusDto, ServiceStatusDto } from './admin.dto.js';
|
||||
|
||||
@Controller('api/admin/health')
|
||||
@UseGuards(AdminGuard)
|
||||
export class AdminHealthController {
|
||||
constructor(
|
||||
@Inject(DB) private readonly db: Db,
|
||||
@Inject(AgentService) private readonly agentService: AgentService,
|
||||
@Inject(ProviderService) private readonly providerService: ProviderService,
|
||||
) {}
|
||||
|
||||
@Get()
|
||||
async check(): Promise<HealthStatusDto> {
|
||||
const [database, cache] = await Promise.all([this.checkDatabase(), this.checkCache()]);
|
||||
|
||||
const sessions = this.agentService.listSessions();
|
||||
const providers = this.providerService.listProviders();
|
||||
|
||||
const allOk = database.status === 'ok' && cache.status === 'ok';
|
||||
|
||||
return {
|
||||
status: allOk ? 'ok' : 'degraded',
|
||||
database,
|
||||
cache,
|
||||
agentPool: { activeSessions: sessions.length },
|
||||
providers: providers.map((p) => ({
|
||||
id: p.id,
|
||||
name: p.name,
|
||||
available: p.available,
|
||||
modelCount: p.models.length,
|
||||
})),
|
||||
checkedAt: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
private async checkDatabase(): Promise<ServiceStatusDto> {
|
||||
const start = Date.now();
|
||||
try {
|
||||
await this.db.execute(sql`SELECT 1`);
|
||||
return { status: 'ok', latencyMs: Date.now() - start };
|
||||
} catch (err) {
|
||||
return {
|
||||
status: 'error',
|
||||
latencyMs: Date.now() - start,
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async checkCache(): Promise<ServiceStatusDto> {
|
||||
const start = Date.now();
|
||||
const handle = createQueue();
|
||||
try {
|
||||
await handle.redis.ping();
|
||||
return { status: 'ok', latencyMs: Date.now() - start };
|
||||
} catch (err) {
|
||||
return {
|
||||
status: 'error',
|
||||
latencyMs: Date.now() - start,
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
};
|
||||
} finally {
|
||||
await handle.close().catch(() => {});
|
||||
}
|
||||
}
|
||||
}
|
||||
146
apps/gateway/src/admin/admin.controller.ts
Normal file
146
apps/gateway/src/admin/admin.controller.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
import {
|
||||
Body,
|
||||
Controller,
|
||||
Delete,
|
||||
Get,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
Inject,
|
||||
InternalServerErrorException,
|
||||
NotFoundException,
|
||||
Param,
|
||||
Patch,
|
||||
Post,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
import { eq, type Db, users as usersTable } from '@mosaic/db';
|
||||
import type { Auth } from '@mosaic/auth';
|
||||
import { AUTH } from '../auth/auth.tokens.js';
|
||||
import { DB } from '../database/database.module.js';
|
||||
import { AdminGuard } from './admin.guard.js';
|
||||
import type {
|
||||
BanUserDto,
|
||||
CreateUserDto,
|
||||
UpdateUserRoleDto,
|
||||
UserDto,
|
||||
UserListDto,
|
||||
} from './admin.dto.js';
|
||||
|
||||
type UserRow = typeof usersTable.$inferSelect;
|
||||
|
||||
function toUserDto(u: UserRow): UserDto {
|
||||
return {
|
||||
id: u.id,
|
||||
name: u.name,
|
||||
email: u.email,
|
||||
role: u.role,
|
||||
banned: u.banned ?? false,
|
||||
banReason: u.banReason ?? null,
|
||||
createdAt: u.createdAt.toISOString(),
|
||||
updatedAt: u.updatedAt.toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
async function requireUpdated(
|
||||
db: Db,
|
||||
id: string,
|
||||
update: Partial<Omit<UserRow, 'id' | 'createdAt'>>,
|
||||
): Promise<UserDto> {
|
||||
const [updated] = await db
|
||||
.update(usersTable)
|
||||
.set({ ...update, updatedAt: new Date() })
|
||||
.where(eq(usersTable.id, id))
|
||||
.returning();
|
||||
if (!updated) throw new InternalServerErrorException('Update returned no rows');
|
||||
return toUserDto(updated);
|
||||
}
|
||||
|
||||
/**
 * Admin-only user management endpoints: list/get/create users, change roles,
 * ban/unban, and delete. Every route is protected by AdminGuard.
 */
@Controller('api/admin/users')
@UseGuards(AdminGuard)
export class AdminController {
  constructor(
    @Inject(DB) private readonly db: Db,
    @Inject(AUTH) private readonly auth: Auth,
  ) {}

  /** List every user, ordered by creation time, with a total count. */
  @Get()
  async listUsers(): Promise<UserListDto> {
    const rows = await this.db.select().from(usersTable).orderBy(usersTable.createdAt);
    const userList: UserDto[] = rows.map(toUserDto);
    return { users: userList, total: userList.length };
  }

  /** Fetch a single user by id. 404 when absent. */
  @Get(':id')
  async getUser(@Param('id') id: string): Promise<UserDto> {
    const [user] = await this.db.select().from(usersTable).where(eq(usersTable.id, id)).limit(1);
    if (!user) throw new NotFoundException('User not found');
    return toUserDto(user);
  }

  /**
   * Create a user via the better-auth API (so the password is hashed by the
   * auth layer), then re-read the row from our own schema for the response.
   */
  @Post()
  async createUser(@Body() body: CreateUserDto): Promise<UserDto> {
    // Use auth API to create user so password is properly hashed
    // NOTE(review): structural cast because better-auth's admin createUser
    // endpoint is not part of our Auth type — verify against the installed
    // better-auth version when upgrading.
    const authApi = this.auth.api as unknown as {
      createUser: (opts: {
        body: { name: string; email: string; password: string; role?: string };
      }) => Promise<{
        user: { id: string; name: string; email: string; createdAt: unknown; updatedAt: unknown };
      }>;
    };

    const result = await authApi.createUser({
      body: {
        name: body.name,
        email: body.email,
        password: body.password,
        // New users default to the least-privileged 'member' role.
        role: body.role ?? 'member',
      },
    });

    // Re-fetch from DB to get full row with our schema
    const [user] = await this.db
      .select()
      .from(usersTable)
      .where(eq(usersTable.id, result.user.id))
      .limit(1);

    if (!user) throw new InternalServerErrorException('User created but not found in DB');
    return toUserDto(user);
  }

  /** Replace a user's role. 404 when the user does not exist. */
  @Patch(':id/role')
  async setRole(@Param('id') id: string, @Body() body: UpdateUserRoleDto): Promise<UserDto> {
    await this.ensureExists(id);
    return requireUpdated(this.db, id, { role: body.role });
  }

  /** Ban a user, recording an optional reason. Returns the updated user. */
  @Post(':id/ban')
  @HttpCode(HttpStatus.OK)
  async banUser(@Param('id') id: string, @Body() body: BanUserDto): Promise<UserDto> {
    await this.ensureExists(id);
    return requireUpdated(this.db, id, { banned: true, banReason: body.reason ?? null });
  }

  /** Lift a ban, clearing the reason and any ban expiry. */
  @Post(':id/unban')
  @HttpCode(HttpStatus.OK)
  async unbanUser(@Param('id') id: string): Promise<UserDto> {
    await this.ensureExists(id);
    return requireUpdated(this.db, id, { banned: false, banReason: null, banExpires: null });
  }

  /** Hard-delete a user row. 404 when absent; 204 on success. */
  @Delete(':id')
  @HttpCode(HttpStatus.NO_CONTENT)
  async deleteUser(@Param('id') id: string): Promise<void> {
    await this.ensureExists(id);
    await this.db.delete(usersTable).where(eq(usersTable.id, id));
  }

  /** Throw NotFoundException unless a user with the given id exists. */
  private async ensureExists(id: string): Promise<void> {
    const [existing] = await this.db
      .select({ id: usersTable.id })
      .from(usersTable)
      .where(eq(usersTable.id, id))
      .limit(1);
    if (!existing) throw new NotFoundException('User not found');
  }
}
|
||||
56
apps/gateway/src/admin/admin.dto.ts
Normal file
56
apps/gateway/src/admin/admin.dto.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
/** Wire-format representation of a user as returned by admin endpoints. */
export interface UserDto {
  id: string;
  name: string;
  email: string;
  role: string;
  banned: boolean;
  banReason: string | null;
  createdAt: string;
  updatedAt: string;
}

/** Paginated-style envelope for user listings (total = users.length today). */
export interface UserListDto {
  users: UserDto[];
  total: number;
}

/** Request body for creating a user. Role is optional (server defaults apply). */
export interface CreateUserDto {
  name: string;
  email: string;
  password: string;
  role?: string;
}

/** Request body for replacing a user's role. */
export interface UpdateUserRoleDto {
  role: string;
}

/** Request body for banning a user; reason is optional. */
export interface BanUserDto {
  reason?: string;
}

/** Top-level health report returned by the admin health endpoint. */
export interface HealthStatusDto {
  status: 'ok' | 'degraded' | 'error';
  database: ServiceStatusDto;
  cache: ServiceStatusDto;
  agentPool: AgentPoolStatusDto;
  providers: ProviderStatusDto[];
  checkedAt: string;
}

/** Result of a single dependency probe (database, cache). */
export interface ServiceStatusDto {
  status: 'ok' | 'error';
  latencyMs?: number;
  error?: string;
}

/** Agent pool summary: count of currently active sessions. */
export interface AgentPoolStatusDto {
  activeSessions: number;
}

/** Per-provider availability summary as exposed by the health endpoint. */
export interface ProviderStatusDto {
  id: string;
  name: string;
  available: boolean;
  modelCount: number;
}
|
||||
64
apps/gateway/src/admin/admin.guard.ts
Normal file
64
apps/gateway/src/admin/admin.guard.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
import {
|
||||
CanActivate,
|
||||
ExecutionContext,
|
||||
ForbiddenException,
|
||||
Inject,
|
||||
Injectable,
|
||||
UnauthorizedException,
|
||||
} from '@nestjs/common';
|
||||
import { fromNodeHeaders } from 'better-auth/node';
|
||||
import type { Auth } from '@mosaic/auth';
|
||||
import type { Db } from '@mosaic/db';
|
||||
import { eq, users as usersTable } from '@mosaic/db';
|
||||
import type { FastifyRequest } from 'fastify';
|
||||
import { AUTH } from '../auth/auth.tokens.js';
|
||||
import { DB } from '../database/database.module.js';
|
||||
|
||||
interface UserWithRole {
|
||||
id: string;
|
||||
role?: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class AdminGuard implements CanActivate {
|
||||
constructor(
|
||||
@Inject(AUTH) private readonly auth: Auth,
|
||||
@Inject(DB) private readonly db: Db,
|
||||
) {}
|
||||
|
||||
async canActivate(context: ExecutionContext): Promise<boolean> {
|
||||
const request = context.switchToHttp().getRequest<FastifyRequest>();
|
||||
const headers = fromNodeHeaders(request.raw.headers);
|
||||
|
||||
const result = await this.auth.api.getSession({ headers });
|
||||
|
||||
if (!result) {
|
||||
throw new UnauthorizedException('Invalid or expired session');
|
||||
}
|
||||
|
||||
const user = result.user as UserWithRole;
|
||||
|
||||
// Ensure the role field is populated. better-auth should include additionalFields
|
||||
// in the session, but as a fallback, fetch the role from the database if needed.
|
||||
let userRole = user.role;
|
||||
if (!userRole) {
|
||||
const [dbUser] = await this.db
|
||||
.select({ role: usersTable.role })
|
||||
.from(usersTable)
|
||||
.where(eq(usersTable.id, user.id))
|
||||
.limit(1);
|
||||
userRole = dbUser?.role ?? 'member';
|
||||
// Update the session user object with the fetched role
|
||||
(user as UserWithRole).role = userRole;
|
||||
}
|
||||
|
||||
if (userRole !== 'admin') {
|
||||
throw new ForbiddenException('Admin access required');
|
||||
}
|
||||
|
||||
(request as FastifyRequest & { user: unknown; session: unknown }).user = result.user;
|
||||
(request as FastifyRequest & { user: unknown; session: unknown }).session = result.session;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
10
apps/gateway/src/admin/admin.module.ts
Normal file
10
apps/gateway/src/admin/admin.module.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { AdminController } from './admin.controller.js';
|
||||
import { AdminHealthController } from './admin-health.controller.js';
|
||||
import { AdminGuard } from './admin.guard.js';
|
||||
|
||||
/** Wires up the admin REST surface: user management and health endpoints, guarded by AdminGuard. */
@Module({
  controllers: [AdminController, AdminHealthController],
  providers: [AdminGuard],
})
export class AdminModule {}
|
||||
143
apps/gateway/src/agent/__tests__/provider.service.test.ts
Normal file
143
apps/gateway/src/agent/__tests__/provider.service.test.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import { beforeEach, afterEach, describe, expect, it } from 'vitest';
|
||||
import { ProviderService } from '../provider.service.js';
|
||||
|
||||
// Every environment variable ProviderService reads; each test starts from a
// clean slate (all unset) and restores the original values afterwards.
const ENV_KEYS = [
  'ANTHROPIC_API_KEY',
  'OPENAI_API_KEY',
  'ZAI_API_KEY',
  'OLLAMA_BASE_URL',
  'OLLAMA_HOST',
  'OLLAMA_MODELS',
  'MOSAIC_CUSTOM_PROVIDERS',
] as const;

type EnvKey = (typeof ENV_KEYS)[number];

describe('ProviderService', () => {
  // Snapshot of the real process.env values, taken in beforeEach and
  // replayed in afterEach so tests cannot leak env state into each other.
  const savedEnv = new Map<EnvKey, string | undefined>();

  beforeEach(() => {
    for (const key of ENV_KEYS) {
      savedEnv.set(key, process.env[key]);
      delete process.env[key];
    }
  });

  afterEach(() => {
    for (const key of ENV_KEYS) {
      const value = savedEnv.get(key);
      if (value === undefined) {
        delete process.env[key];
      } else {
        process.env[key] = value;
      }
    }
  });

  it('skips API-key providers when env vars are missing (no models become available)', async () => {
    const service = new ProviderService();
    await service.onModuleInit();

    // Pi's built-in registry may include model definitions for all providers, but
    // without API keys none of them should be available (usable).
    const availableModels = service.listAvailableModels();
    const availableProviderIds = new Set(availableModels.map((m) => m.provider));

    expect(availableProviderIds).not.toContain('anthropic');
    expect(availableProviderIds).not.toContain('openai');
    expect(availableProviderIds).not.toContain('zai');

    // Providers list may show built-in providers, but they should not be marked available
    const providers = service.listProviders();
    for (const p of providers.filter((p) => ['anthropic', 'openai', 'zai'].includes(p.id))) {
      expect(p.available).toBe(false);
    }
  });

  it('registers Anthropic provider with correct models when ANTHROPIC_API_KEY is set', async () => {
    process.env['ANTHROPIC_API_KEY'] = 'test-anthropic';

    const service = new ProviderService();
    await service.onModuleInit();

    const providers = service.listProviders();
    const anthropic = providers.find((p) => p.id === 'anthropic');
    expect(anthropic).toBeDefined();
    expect(anthropic!.available).toBe(true);
    // NOTE(review): this expected order and the uniform 8192 maxTokens differ
    // from AnthropicAdapter's own registrations (opus 32000 / sonnet 16000);
    // presumably ProviderService defines its own model list — confirm.
    expect(anthropic!.models.map((m) => m.id)).toEqual([
      'claude-sonnet-4-6',
      'claude-opus-4-6',
      'claude-haiku-4-5',
    ]);
    // contextWindow override from Pi built-in (200000)
    for (const m of anthropic!.models) {
      expect(m.contextWindow).toBe(200000);
      // maxTokens capped at 8192 per task spec
      expect(m.maxTokens).toBe(8192);
    }
  });

  it('registers OpenAI provider with correct models when OPENAI_API_KEY is set', async () => {
    process.env['OPENAI_API_KEY'] = 'test-openai';

    const service = new ProviderService();
    await service.onModuleInit();

    const providers = service.listProviders();
    const openai = providers.find((p) => p.id === 'openai');
    expect(openai).toBeDefined();
    expect(openai!.available).toBe(true);
    expect(openai!.models.map((m) => m.id)).toEqual(['gpt-4o', 'gpt-4o-mini', 'o3-mini']);
  });

  it('registers Z.ai provider with correct models when ZAI_API_KEY is set', async () => {
    process.env['ZAI_API_KEY'] = 'test-zai';

    const service = new ProviderService();
    await service.onModuleInit();

    const providers = service.listProviders();
    const zai = providers.find((p) => p.id === 'zai');
    expect(zai).toBeDefined();
    expect(zai!.available).toBe(true);
    expect(zai!.models.map((m) => m.id)).toEqual(['glm-4.5', 'glm-4.5-air', 'glm-4.5-flash']);
  });

  it('registers all three providers when all keys are set', async () => {
    process.env['ANTHROPIC_API_KEY'] = 'test-anthropic';
    process.env['OPENAI_API_KEY'] = 'test-openai';
    process.env['ZAI_API_KEY'] = 'test-zai';

    const service = new ProviderService();
    await service.onModuleInit();

    const providerIds = service.listProviders().map((p) => p.id);
    expect(providerIds).toContain('anthropic');
    expect(providerIds).toContain('openai');
    expect(providerIds).toContain('zai');
  });

  it('can find registered Anthropic models by provider+id', async () => {
    process.env['ANTHROPIC_API_KEY'] = 'test-anthropic';

    const service = new ProviderService();
    await service.onModuleInit();

    const sonnet = service.findModel('anthropic', 'claude-sonnet-4-6');
    expect(sonnet).toBeDefined();
    expect(sonnet!.provider).toBe('anthropic');
    expect(sonnet!.id).toBe('claude-sonnet-4-6');
  });

  it('can find registered Z.ai models by provider+id', async () => {
    process.env['ZAI_API_KEY'] = 'test-zai';

    const service = new ProviderService();
    await service.onModuleInit();

    const glm = service.findModel('zai', 'glm-4.5');
    expect(glm).toBeDefined();
    expect(glm!.provider).toBe('zai');
    expect(glm!.id).toBe('glm-4.5');
  });
});
|
||||
191
apps/gateway/src/agent/adapters/anthropic.adapter.ts
Normal file
191
apps/gateway/src/agent/adapters/anthropic.adapter.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import Anthropic from '@anthropic-ai/sdk';
|
||||
import type { ModelRegistry } from '@mariozechner/pi-coding-agent';
|
||||
import type {
|
||||
CompletionEvent,
|
||||
CompletionParams,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
} from '@mosaic/types';
|
||||
|
||||
/**
|
||||
* Anthropic provider adapter.
|
||||
*
|
||||
* Registers Claude models with the Pi ModelRegistry via the Anthropic SDK.
|
||||
* Configuration is driven by environment variables:
|
||||
* ANTHROPIC_API_KEY — Anthropic API key (required)
|
||||
*/
|
||||
export class AnthropicAdapter implements IProviderAdapter {
|
||||
readonly name = 'anthropic';
|
||||
|
||||
private readonly logger = new Logger(AnthropicAdapter.name);
|
||||
private client: Anthropic | null = null;
|
||||
private registeredModels: ModelInfo[] = [];
|
||||
|
||||
constructor(private readonly registry: ModelRegistry) {}
|
||||
|
||||
async register(): Promise<void> {
|
||||
const apiKey = process.env['ANTHROPIC_API_KEY'];
|
||||
if (!apiKey) {
|
||||
this.logger.warn('Skipping Anthropic provider registration: ANTHROPIC_API_KEY not set');
|
||||
return;
|
||||
}
|
||||
|
||||
this.client = new Anthropic({ apiKey });
|
||||
|
||||
const models: ModelInfo[] = [
|
||||
{
|
||||
id: 'claude-opus-4-6',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude Opus 4.6',
|
||||
reasoning: true,
|
||||
contextWindow: 200000,
|
||||
maxTokens: 32000,
|
||||
inputTypes: ['text', 'image'],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
},
|
||||
{
|
||||
id: 'claude-sonnet-4-6',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude Sonnet 4.6',
|
||||
reasoning: true,
|
||||
contextWindow: 200000,
|
||||
maxTokens: 16000,
|
||||
inputTypes: ['text', 'image'],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
},
|
||||
{
|
||||
id: 'claude-haiku-4-5',
|
||||
provider: 'anthropic',
|
||||
name: 'Claude Haiku 4.5',
|
||||
reasoning: false,
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
inputTypes: ['text', 'image'],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
},
|
||||
];
|
||||
|
||||
this.registry.registerProvider('anthropic', {
|
||||
apiKey,
|
||||
baseUrl: 'https://api.anthropic.com',
|
||||
api: 'anthropic' as never,
|
||||
models: models.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
reasoning: m.reasoning,
|
||||
input: m.inputTypes as ('text' | 'image')[],
|
||||
cost: m.cost,
|
||||
contextWindow: m.contextWindow,
|
||||
maxTokens: m.maxTokens,
|
||||
})),
|
||||
});
|
||||
|
||||
this.registeredModels = models;
|
||||
|
||||
this.logger.log(
|
||||
`Anthropic provider registered with models: ${models.map((m) => m.id).join(', ')}`,
|
||||
);
|
||||
}
|
||||
|
||||
listModels(): ModelInfo[] {
|
||||
return this.registeredModels;
|
||||
}
|
||||
|
||||
async healthCheck(): Promise<ProviderHealth> {
|
||||
const apiKey = process.env['ANTHROPIC_API_KEY'];
|
||||
if (!apiKey) {
|
||||
return {
|
||||
status: 'down',
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: 'ANTHROPIC_API_KEY not configured',
|
||||
};
|
||||
}
|
||||
|
||||
const start = Date.now();
|
||||
|
||||
try {
|
||||
const client = this.client ?? new Anthropic({ apiKey });
|
||||
await client.models.list({ limit: 1 });
|
||||
const latencyMs = Date.now() - start;
|
||||
return { status: 'healthy', latencyMs, lastChecked: new Date().toISOString() };
|
||||
} catch (err) {
|
||||
const latencyMs = Date.now() - start;
|
||||
const error = err instanceof Error ? err.message : String(err);
|
||||
const status = error.includes('401') || error.includes('403') ? 'degraded' : 'down';
|
||||
return { status, latencyMs, lastChecked: new Date().toISOString(), error };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a completion from Anthropic using the messages API.
|
||||
* Maps Anthropic streaming events to the CompletionEvent format.
|
||||
*
|
||||
* Note: Currently reserved for future direct-completion use. The Pi SDK
|
||||
* integration routes completions through ModelRegistry / AgentSession.
|
||||
*/
|
||||
async *createCompletion(params: CompletionParams): AsyncIterable<CompletionEvent> {
|
||||
const apiKey = process.env['ANTHROPIC_API_KEY'];
|
||||
if (!apiKey) {
|
||||
throw new Error('AnthropicAdapter: ANTHROPIC_API_KEY not configured');
|
||||
}
|
||||
|
||||
const client = this.client ?? new Anthropic({ apiKey });
|
||||
|
||||
// Separate system messages from user/assistant messages
|
||||
const systemMessages = params.messages.filter((m) => m.role === 'system');
|
||||
const conversationMessages = params.messages.filter((m) => m.role !== 'system');
|
||||
|
||||
const systemPrompt =
|
||||
systemMessages.length > 0 ? systemMessages.map((m) => m.content).join('\n') : undefined;
|
||||
|
||||
const stream = await client.messages.stream({
|
||||
model: params.model,
|
||||
max_tokens: params.maxTokens ?? 1024,
|
||||
...(systemPrompt !== undefined ? { system: systemPrompt } : {}),
|
||||
messages: conversationMessages.map((m) => ({
|
||||
role: m.role as 'user' | 'assistant',
|
||||
content: m.content,
|
||||
})),
|
||||
...(params.temperature !== undefined ? { temperature: params.temperature } : {}),
|
||||
...(params.tools && params.tools.length > 0
|
||||
? {
|
||||
tools: params.tools.map((t) => ({
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
input_schema: t.parameters as Anthropic.Tool['input_schema'],
|
||||
})),
|
||||
}
|
||||
: {}),
|
||||
});
|
||||
|
||||
for await (const event of stream) {
|
||||
if (event.type === 'content_block_delta' && event.delta.type === 'text_delta') {
|
||||
yield { type: 'text_delta', content: event.delta.text };
|
||||
} else if (event.type === 'content_block_delta' && event.delta.type === 'input_json_delta') {
|
||||
yield { type: 'tool_call', name: '', arguments: event.delta.partial_json };
|
||||
} else if (event.type === 'message_delta' && event.usage) {
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
inputTokens:
|
||||
(event as { usage: { input_tokens?: number; output_tokens: number } }).usage
|
||||
.input_tokens ?? 0,
|
||||
outputTokens: event.usage.output_tokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Emit final done event with full usage from the completed message
|
||||
const finalMessage = await stream.finalMessage();
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
inputTokens: finalMessage.usage.input_tokens,
|
||||
outputTokens: finalMessage.usage.output_tokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
4
apps/gateway/src/agent/adapters/index.ts
Normal file
4
apps/gateway/src/agent/adapters/index.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export { OllamaAdapter } from './ollama.adapter.js';
|
||||
export { AnthropicAdapter } from './anthropic.adapter.js';
|
||||
export { OpenAIAdapter } from './openai.adapter.js';
|
||||
export { OpenRouterAdapter } from './openrouter.adapter.js';
|
||||
197
apps/gateway/src/agent/adapters/ollama.adapter.ts
Normal file
197
apps/gateway/src/agent/adapters/ollama.adapter.ts
Normal file
@@ -0,0 +1,197 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import type { ModelRegistry } from '@mariozechner/pi-coding-agent';
|
||||
import type {
|
||||
CompletionEvent,
|
||||
CompletionParams,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
} from '@mosaic/types';
|
||||
|
||||
/** Embedding models that Ollama ships with out of the box */
const OLLAMA_EMBEDDING_MODELS: ReadonlyArray<{
  id: string;
  contextWindow: number;
  dimensions: number;
}> = [
  { id: 'nomic-embed-text', contextWindow: 8192, dimensions: 768 },
  { id: 'mxbai-embed-large', contextWindow: 512, dimensions: 1024 },
];

// Shape of the /api/embeddings response; embedding is optional so a malformed
// response can be detected and rejected in embed().
interface OllamaEmbeddingResponse {
  embedding?: number[];
}

/**
 * Ollama provider adapter.
 *
 * Registers local Ollama models with the Pi ModelRegistry via the OpenAI-compatible
 * completions API. Also exposes embedding models and an `embed()` method for
 * vector generation (used by EmbeddingService / M3-009).
 *
 * Configuration is driven by environment variables:
 *   OLLAMA_BASE_URL or OLLAMA_HOST — base URL of the Ollama instance
 *   OLLAMA_MODELS — comma-separated list of model IDs (default: llama3.2,codellama,mistral)
 */
export class OllamaAdapter implements IProviderAdapter {
  readonly name = 'ollama';

  private readonly logger = new Logger(OllamaAdapter.name);
  // Populated by register(); includes both completion and embedding models.
  private registeredModels: ModelInfo[] = [];

  constructor(private readonly registry: ModelRegistry) {}

  /**
   * Register configured Ollama models with the Pi ModelRegistry.
   * No-op (debug log only) when no Ollama base URL is configured.
   */
  async register(): Promise<void> {
    const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
    if (!ollamaUrl) {
      this.logger.debug('Skipping Ollama provider registration: OLLAMA_BASE_URL not set');
      return;
    }

    const modelsEnv = process.env['OLLAMA_MODELS'] ?? 'llama3.2,codellama,mistral';
    const modelIds = modelsEnv
      .split(',')
      .map((id: string) => id.trim())
      .filter(Boolean);

    // NOTE(review): 'as never' works around Pi's registry typing; the dummy
    // apiKey 'ollama' satisfies the OpenAI-compatible client — confirm against
    // the installed pi-coding-agent version.
    this.registry.registerProvider('ollama', {
      baseUrl: `${ollamaUrl}/v1`,
      apiKey: 'ollama',
      api: 'openai-completions' as never,
      models: modelIds.map((id) => ({
        id,
        name: id,
        reasoning: false,
        input: ['text'] as ('text' | 'image')[],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        contextWindow: 8192,
        maxTokens: 4096,
      })),
    });

    // Chat / completion models
    const completionModels: ModelInfo[] = modelIds.map((id) => ({
      id,
      provider: 'ollama',
      name: id,
      reasoning: false,
      contextWindow: 8192,
      maxTokens: 4096,
      inputTypes: ['text'] as ('text' | 'image')[],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
    }));

    // Embedding models (tracked in registeredModels but not in Pi registry,
    // which only handles completion models)
    const embeddingModels: ModelInfo[] = OLLAMA_EMBEDDING_MODELS.map((em) => ({
      id: em.id,
      provider: 'ollama',
      name: em.id,
      reasoning: false,
      contextWindow: em.contextWindow,
      maxTokens: 0,
      inputTypes: ['text'] as ('text' | 'image')[],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
    }));

    this.registeredModels = [...completionModels, ...embeddingModels];

    this.logger.log(
      `Ollama provider registered at ${ollamaUrl} with models: ${modelIds.join(', ')} ` +
        `and embedding models: ${OLLAMA_EMBEDDING_MODELS.map((em) => em.id).join(', ')}`,
    );
  }

  /** Models registered by the last successful register() call (empty before then). */
  listModels(): ModelInfo[] {
    return this.registeredModels;
  }

  /**
   * Probe the Ollama instance via GET {base}/v1/models with a 5s timeout.
   * Non-2xx responses are 'degraded'; network/timeout failures are 'down'.
   */
  async healthCheck(): Promise<ProviderHealth> {
    const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
    if (!ollamaUrl) {
      return {
        status: 'down',
        lastChecked: new Date().toISOString(),
        error: 'OLLAMA_BASE_URL not configured',
      };
    }

    const checkUrl = `${ollamaUrl}/v1/models`;
    const start = Date.now();

    try {
      const res = await fetch(checkUrl, {
        method: 'GET',
        headers: { Accept: 'application/json' },
        signal: AbortSignal.timeout(5000),
      });
      const latencyMs = Date.now() - start;

      if (!res.ok) {
        return {
          status: 'degraded',
          latencyMs,
          lastChecked: new Date().toISOString(),
          error: `HTTP ${res.status}`,
        };
      }

      return { status: 'healthy', latencyMs, lastChecked: new Date().toISOString() };
    } catch (err) {
      const latencyMs = Date.now() - start;
      const error = err instanceof Error ? err.message : String(err);
      return { status: 'down', latencyMs, lastChecked: new Date().toISOString(), error };
    }
  }

  /**
   * Generate an embedding vector for the given text using Ollama's /api/embeddings endpoint.
   *
   * Defaults to 'nomic-embed-text' when no model is specified.
   * Intended for use by EmbeddingService (M3-009).
   *
   * @param text - The input text to embed.
   * @param model - Optional embedding model ID (default: 'nomic-embed-text').
   * @returns A float array representing the embedding vector.
   * @throws Error when no base URL is configured, the request fails, or the
   *   response lacks an embedding array.
   */
  async embed(text: string, model = 'nomic-embed-text'): Promise<number[]> {
    const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
    if (!ollamaUrl) {
      throw new Error('OllamaAdapter: OLLAMA_BASE_URL not configured');
    }

    const embeddingUrl = `${ollamaUrl}/api/embeddings`;

    // 30s timeout: embedding generation can be slow on CPU-only hosts.
    const res = await fetch(embeddingUrl, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model, prompt: text }),
      signal: AbortSignal.timeout(30000),
    });

    if (!res.ok) {
      throw new Error(`OllamaAdapter.embed: request failed with HTTP ${res.status}`);
    }

    const json = (await res.json()) as OllamaEmbeddingResponse;

    if (!Array.isArray(json.embedding)) {
      throw new Error('OllamaAdapter.embed: unexpected response — missing embedding array');
    }

    return json.embedding;
  }

  /**
   * createCompletion is reserved for future direct-completion use.
   * The current integration routes completions through Pi SDK's ModelRegistry/AgentSession.
   */
  async *createCompletion(_params: CompletionParams): AsyncIterable<CompletionEvent> {
    throw new Error(
      'OllamaAdapter.createCompletion is not yet implemented. ' +
        'Use Pi SDK AgentSession for completions.',
    );
    // Satisfy the AsyncGenerator return type — unreachable but required for TypeScript.
    yield undefined as never;
  }
}
|
||||
201
apps/gateway/src/agent/adapters/openai.adapter.ts
Normal file
201
apps/gateway/src/agent/adapters/openai.adapter.ts
Normal file
@@ -0,0 +1,201 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import OpenAI from 'openai';
|
||||
import type { ModelRegistry } from '@mariozechner/pi-coding-agent';
|
||||
import type {
|
||||
CompletionEvent,
|
||||
CompletionParams,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
} from '@mosaic/types';
|
||||
|
||||
/**
|
||||
* OpenAI provider adapter.
|
||||
*
|
||||
* Registers OpenAI models (including Codex gpt-5.4) with the Pi ModelRegistry.
|
||||
* Configuration is driven by environment variables:
|
||||
* OPENAI_API_KEY — OpenAI API key (required; adapter skips registration when absent)
|
||||
*/
|
||||
export class OpenAIAdapter implements IProviderAdapter {
|
||||
readonly name = 'openai';
|
||||
|
||||
private readonly logger = new Logger(OpenAIAdapter.name);
|
||||
private registeredModels: ModelInfo[] = [];
|
||||
private client: OpenAI | null = null;
|
||||
|
||||
/** Model ID used for Codex gpt-5.4 in the Pi registry. */
|
||||
static readonly CODEX_MODEL_ID = 'codex-gpt-5-4';
|
||||
|
||||
constructor(private readonly registry: ModelRegistry) {}
|
||||
|
||||
async register(): Promise<void> {
|
||||
const apiKey = process.env['OPENAI_API_KEY'];
|
||||
if (!apiKey) {
|
||||
this.logger.debug('Skipping OpenAI provider registration: OPENAI_API_KEY not set');
|
||||
return;
|
||||
}
|
||||
|
||||
this.client = new OpenAI({ apiKey });
|
||||
|
||||
const codexModel = {
|
||||
id: OpenAIAdapter.CODEX_MODEL_ID,
|
||||
name: 'Codex gpt-5.4',
|
||||
/** OpenAI-compatible completions API */
|
||||
api: 'openai-completions' as never,
|
||||
reasoning: false,
|
||||
input: ['text', 'image'] as ('text' | 'image')[],
|
||||
cost: { input: 0.003, output: 0.012, cacheRead: 0.0015, cacheWrite: 0 },
|
||||
contextWindow: 128_000,
|
||||
maxTokens: 16_384,
|
||||
};
|
||||
|
||||
this.registry.registerProvider('openai', {
|
||||
apiKey,
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
models: [codexModel],
|
||||
});
|
||||
|
||||
this.registeredModels = [
|
||||
{
|
||||
id: OpenAIAdapter.CODEX_MODEL_ID,
|
||||
provider: 'openai',
|
||||
name: 'Codex gpt-5.4',
|
||||
reasoning: false,
|
||||
contextWindow: 128_000,
|
||||
maxTokens: 16_384,
|
||||
inputTypes: ['text', 'image'] as ('text' | 'image')[],
|
||||
cost: { input: 0.003, output: 0.012, cacheRead: 0.0015, cacheWrite: 0 },
|
||||
},
|
||||
];
|
||||
|
||||
this.logger.log(`OpenAI provider registered with model: ${OpenAIAdapter.CODEX_MODEL_ID}`);
|
||||
}
|
||||
|
||||
listModels(): ModelInfo[] {
|
||||
return this.registeredModels;
|
||||
}
|
||||
|
||||
async healthCheck(): Promise<ProviderHealth> {
|
||||
const apiKey = process.env['OPENAI_API_KEY'];
|
||||
if (!apiKey) {
|
||||
return {
|
||||
status: 'down',
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: 'OPENAI_API_KEY not configured',
|
||||
};
|
||||
}
|
||||
|
||||
const start = Date.now();
|
||||
try {
|
||||
// Lightweight call — list models to verify key validity
|
||||
const res = await fetch('https://api.openai.com/v1/models', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
const latencyMs = Date.now() - start;
|
||||
|
||||
if (!res.ok) {
|
||||
return {
|
||||
status: 'degraded',
|
||||
latencyMs,
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: `HTTP ${res.status}`,
|
||||
};
|
||||
}
|
||||
|
||||
return { status: 'healthy', latencyMs, lastChecked: new Date().toISOString() };
|
||||
} catch (err) {
|
||||
const latencyMs = Date.now() - start;
|
||||
const error = err instanceof Error ? err.message : String(err);
|
||||
return { status: 'down', latencyMs, lastChecked: new Date().toISOString(), error };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a completion from OpenAI using the chat completions API.
|
||||
*
|
||||
* Maps OpenAI streaming chunks to the Mosaic CompletionEvent format.
|
||||
*/
|
||||
async *createCompletion(params: CompletionParams): AsyncIterable<CompletionEvent> {
|
||||
if (!this.client) {
|
||||
throw new Error(
|
||||
'OpenAIAdapter: client not initialized. ' +
|
||||
'Ensure OPENAI_API_KEY is set and register() was called.',
|
||||
);
|
||||
}
|
||||
|
||||
const stream = await this.client.chat.completions.create({
|
||||
model: params.model,
|
||||
messages: params.messages.map((m) => ({
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
})),
|
||||
...(params.temperature !== undefined && { temperature: params.temperature }),
|
||||
...(params.maxTokens !== undefined && { max_tokens: params.maxTokens }),
|
||||
...(params.tools &&
|
||||
params.tools.length > 0 && {
|
||||
tools: params.tools.map((t) => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.parameters,
|
||||
},
|
||||
})),
|
||||
}),
|
||||
stream: true,
|
||||
stream_options: { include_usage: true },
|
||||
});
|
||||
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const choice = chunk.choices[0];
|
||||
|
||||
// Accumulate usage when present (final chunk with stream_options.include_usage)
|
||||
if (chunk.usage) {
|
||||
inputTokens = chunk.usage.prompt_tokens;
|
||||
outputTokens = chunk.usage.completion_tokens;
|
||||
}
|
||||
|
||||
if (!choice) continue;
|
||||
|
||||
const delta = choice.delta;
|
||||
|
||||
// Text content delta
|
||||
if (delta.content) {
|
||||
yield { type: 'text_delta', content: delta.content };
|
||||
}
|
||||
|
||||
// Tool call delta — emit when arguments are complete
|
||||
if (delta.tool_calls) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
if (toolCallDelta.function?.name && toolCallDelta.function.arguments !== undefined) {
|
||||
yield {
|
||||
type: 'tool_call',
|
||||
name: toolCallDelta.function.name,
|
||||
arguments: toolCallDelta.function.arguments,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream finished
|
||||
if (choice.finish_reason === 'stop' || choice.finish_reason === 'tool_calls') {
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: { inputTokens, outputTokens },
|
||||
};
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback done event when stream ends without explicit finish_reason
|
||||
yield { type: 'done', usage: { inputTokens, outputTokens } };
|
||||
}
|
||||
}
|
||||
212
apps/gateway/src/agent/adapters/openrouter.adapter.ts
Normal file
212
apps/gateway/src/agent/adapters/openrouter.adapter.ts
Normal file
@@ -0,0 +1,212 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import OpenAI from 'openai';
|
||||
import type {
|
||||
CompletionEvent,
|
||||
CompletionParams,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
} from '@mosaic/types';
|
||||
|
||||
/** Base URL for OpenRouter's OpenAI-compatible REST API. */
const OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1';

/** Subset of an OpenRouter /models catalog entry consumed by this adapter. */
interface OpenRouterModel {
  id: string;
  name?: string;
  /** Context window size in tokens. */
  context_length?: number;
  top_provider?: {
    max_completion_tokens?: number;
  };
  /** Per-token prices; may arrive as strings, hence the union type. */
  pricing?: {
    prompt?: string | number;
    completion?: string | number;
  };
  architecture?: {
    /** e.g. ['text'] or ['text', 'image'] — used to derive inputTypes. */
    input_modalities?: string[];
  };
}

/** Envelope returned by GET {OPENROUTER_BASE_URL}/models. */
interface OpenRouterModelsResponse {
  data?: OpenRouterModel[];
}
|
||||
|
||||
/**
|
||||
* OpenRouter provider adapter.
|
||||
*
|
||||
* Routes completions through OpenRouter's OpenAI-compatible API.
|
||||
* Configuration is driven by the OPENROUTER_API_KEY environment variable.
|
||||
*/
|
||||
export class OpenRouterAdapter implements IProviderAdapter {
|
||||
readonly name = 'openrouter';
|
||||
|
||||
private readonly logger = new Logger(OpenRouterAdapter.name);
|
||||
private client: OpenAI | null = null;
|
||||
private registeredModels: ModelInfo[] = [];
|
||||
|
||||
async register(): Promise<void> {
|
||||
const apiKey = process.env['OPENROUTER_API_KEY'];
|
||||
if (!apiKey) {
|
||||
this.logger.debug('Skipping OpenRouter provider registration: OPENROUTER_API_KEY not set');
|
||||
return;
|
||||
}
|
||||
|
||||
this.client = new OpenAI({
|
||||
apiKey,
|
||||
baseURL: OPENROUTER_BASE_URL,
|
||||
defaultHeaders: {
|
||||
'HTTP-Referer': 'https://mosaic.ai',
|
||||
'X-Title': 'Mosaic',
|
||||
},
|
||||
});
|
||||
|
||||
try {
|
||||
this.registeredModels = await this.fetchModels(apiKey);
|
||||
this.logger.log(`OpenRouter provider registered with ${this.registeredModels.length} models`);
|
||||
} catch (err) {
|
||||
this.logger.warn(
|
||||
`OpenRouter model discovery failed: ${err instanceof Error ? err.message : String(err)}. Registering with empty model list.`,
|
||||
);
|
||||
this.registeredModels = [];
|
||||
}
|
||||
}
|
||||
|
||||
listModels(): ModelInfo[] {
|
||||
return this.registeredModels;
|
||||
}
|
||||
|
||||
async healthCheck(): Promise<ProviderHealth> {
|
||||
const apiKey = process.env['OPENROUTER_API_KEY'];
|
||||
if (!apiKey) {
|
||||
return {
|
||||
status: 'down',
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: 'OPENROUTER_API_KEY not configured',
|
||||
};
|
||||
}
|
||||
|
||||
const start = Date.now();
|
||||
try {
|
||||
const res = await fetch(`${OPENROUTER_BASE_URL}/models`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
signal: AbortSignal.timeout(5000),
|
||||
});
|
||||
const latencyMs = Date.now() - start;
|
||||
|
||||
if (!res.ok) {
|
||||
return {
|
||||
status: 'degraded',
|
||||
latencyMs,
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: `HTTP ${res.status}`,
|
||||
};
|
||||
}
|
||||
|
||||
return { status: 'healthy', latencyMs, lastChecked: new Date().toISOString() };
|
||||
} catch (err) {
|
||||
const latencyMs = Date.now() - start;
|
||||
const error = err instanceof Error ? err.message : String(err);
|
||||
return { status: 'down', latencyMs, lastChecked: new Date().toISOString(), error };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a completion through OpenRouter's OpenAI-compatible API.
|
||||
*/
|
||||
async *createCompletion(params: CompletionParams): AsyncIterable<CompletionEvent> {
|
||||
if (!this.client) {
|
||||
throw new Error('OpenRouterAdapter is not initialized. Ensure OPENROUTER_API_KEY is set.');
|
||||
}
|
||||
|
||||
const stream = await this.client.chat.completions.create({
|
||||
model: params.model,
|
||||
messages: params.messages.map((m) => ({ role: m.role, content: m.content })),
|
||||
temperature: params.temperature,
|
||||
max_tokens: params.maxTokens,
|
||||
stream: true,
|
||||
});
|
||||
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
const delta = choice.delta;
|
||||
|
||||
if (delta.content) {
|
||||
yield { type: 'text_delta', content: delta.content };
|
||||
}
|
||||
|
||||
if (choice.finish_reason === 'stop') {
|
||||
const usage = (chunk as { usage?: { prompt_tokens?: number; completion_tokens?: number } })
|
||||
.usage;
|
||||
if (usage) {
|
||||
inputTokens = usage.prompt_tokens ?? 0;
|
||||
outputTokens = usage.completion_tokens ?? 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: { inputTokens, outputTokens },
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Private helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
private async fetchModels(apiKey: string): Promise<ModelInfo[]> {
|
||||
const res = await fetch(`${OPENROUTER_BASE_URL}/models`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
signal: AbortSignal.timeout(10000),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(`OpenRouter models endpoint returned HTTP ${res.status}`);
|
||||
}
|
||||
|
||||
const json = (await res.json()) as OpenRouterModelsResponse;
|
||||
const data = json.data ?? [];
|
||||
|
||||
return data.map((model): ModelInfo => {
|
||||
const inputPrice = model.pricing?.prompt
|
||||
? parseFloat(String(model.pricing.prompt)) * 1000
|
||||
: 0;
|
||||
const outputPrice = model.pricing?.completion
|
||||
? parseFloat(String(model.pricing.completion)) * 1000
|
||||
: 0;
|
||||
|
||||
const inputModalities = model.architecture?.input_modalities ?? ['text'];
|
||||
const inputTypes = inputModalities.includes('image')
|
||||
? (['text', 'image'] as const)
|
||||
: (['text'] as const);
|
||||
|
||||
return {
|
||||
id: model.id,
|
||||
provider: 'openrouter',
|
||||
name: model.name ?? model.id,
|
||||
reasoning: false,
|
||||
contextWindow: model.context_length ?? 4096,
|
||||
maxTokens: model.top_provider?.max_completion_tokens ?? 4096,
|
||||
inputTypes: [...inputTypes],
|
||||
cost: {
|
||||
input: inputPrice,
|
||||
output: outputPrice,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
97
apps/gateway/src/agent/agent-config.dto.ts
Normal file
97
apps/gateway/src/agent/agent-config.dto.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
import {
|
||||
IsArray,
|
||||
IsBoolean,
|
||||
IsIn,
|
||||
IsObject,
|
||||
IsOptional,
|
||||
IsString,
|
||||
IsUUID,
|
||||
MaxLength,
|
||||
} from 'class-validator';
|
||||
|
||||
/** Lifecycle states accepted by the @IsIn validators on the DTOs below. */
const agentStatuses = ['idle', 'active', 'error', 'offline'] as const;
|
||||
|
||||
export class CreateAgentConfigDto {
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name!: string;
|
||||
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
provider!: string;
|
||||
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
model!: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(agentStatuses)
|
||||
status?: 'idle' | 'active' | 'error' | 'offline';
|
||||
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
projectId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(50_000)
|
||||
systemPrompt?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
allowedTools?: string[];
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
skills?: string[];
|
||||
|
||||
@IsOptional()
|
||||
@IsBoolean()
|
||||
isSystem?: boolean;
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export class UpdateAgentConfigDto {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
provider?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
model?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(agentStatuses)
|
||||
status?: 'idle' | 'active' | 'error' | 'offline';
|
||||
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
projectId?: string | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(50_000)
|
||||
systemPrompt?: string | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
allowedTools?: string[] | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
skills?: string[] | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, unknown> | null;
|
||||
}
|
||||
89
apps/gateway/src/agent/agent-configs.controller.ts
Normal file
89
apps/gateway/src/agent/agent-configs.controller.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
import {
|
||||
Body,
|
||||
Controller,
|
||||
Delete,
|
||||
ForbiddenException,
|
||||
Get,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
Inject,
|
||||
NotFoundException,
|
||||
Param,
|
||||
Patch,
|
||||
Post,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import { CreateAgentConfigDto, UpdateAgentConfigDto } from './agent-config.dto.js';
|
||||
|
||||
@Controller('api/agents')
|
||||
@UseGuards(AuthGuard)
|
||||
export class AgentConfigsController {
|
||||
constructor(@Inject(BRAIN) private readonly brain: Brain) {}
|
||||
|
||||
@Get()
|
||||
async list(@CurrentUser() user: { id: string; role?: string }) {
|
||||
return this.brain.agents.findAccessible(user.id);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
async findOne(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
const agent = await this.brain.agents.findById(id);
|
||||
if (!agent) throw new NotFoundException('Agent not found');
|
||||
if (!agent.isSystem && agent.ownerId !== user.id) {
|
||||
throw new ForbiddenException('Agent does not belong to the current user');
|
||||
}
|
||||
return agent;
|
||||
}
|
||||
|
||||
@Post()
|
||||
async create(@Body() dto: CreateAgentConfigDto, @CurrentUser() user: { id: string }) {
|
||||
return this.brain.agents.create({
|
||||
...dto,
|
||||
ownerId: user.id,
|
||||
isSystem: false,
|
||||
});
|
||||
}
|
||||
|
||||
@Patch(':id')
|
||||
async update(
|
||||
@Param('id') id: string,
|
||||
@Body() dto: UpdateAgentConfigDto,
|
||||
@CurrentUser() user: { id: string; role?: string },
|
||||
) {
|
||||
const agent = await this.brain.agents.findById(id);
|
||||
if (!agent) throw new NotFoundException('Agent not found');
|
||||
if (agent.isSystem && user.role !== 'admin') {
|
||||
throw new ForbiddenException('Only admins can update system agents');
|
||||
}
|
||||
if (!agent.isSystem && agent.ownerId !== user.id) {
|
||||
throw new ForbiddenException('Agent does not belong to the current user');
|
||||
}
|
||||
|
||||
// Pass ownerId for user agents so the repo WHERE clause enforces ownership.
|
||||
// For system agents (admin path) pass undefined so the WHERE matches only on id.
|
||||
const ownerId = agent.isSystem ? undefined : user.id;
|
||||
const updated = await this.brain.agents.update(id, dto, ownerId);
|
||||
if (!updated) throw new NotFoundException('Agent not found');
|
||||
return updated;
|
||||
}
|
||||
|
||||
@Delete(':id')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async remove(@Param('id') id: string, @CurrentUser() user: { id: string; role?: string }) {
|
||||
const agent = await this.brain.agents.findById(id);
|
||||
if (!agent) throw new NotFoundException('Agent not found');
|
||||
if (agent.isSystem) {
|
||||
throw new ForbiddenException('Cannot delete system agents');
|
||||
}
|
||||
if (agent.ownerId !== user.id) {
|
||||
throw new ForbiddenException('Agent does not belong to the current user');
|
||||
}
|
||||
// Pass ownerId so the repo WHERE clause enforces ownership at the DB level.
|
||||
const deleted = await this.brain.agents.remove(id, user.id);
|
||||
if (!deleted) throw new NotFoundException('Agent not found');
|
||||
}
|
||||
}
|
||||
@@ -2,15 +2,20 @@ import { Global, Module } from '@nestjs/common';
|
||||
import { AgentService } from './agent.service.js';
|
||||
import { ProviderService } from './provider.service.js';
|
||||
import { RoutingService } from './routing.service.js';
|
||||
import { SkillLoaderService } from './skill-loader.service.js';
|
||||
import { ProvidersController } from './providers.controller.js';
|
||||
import { SessionsController } from './sessions.controller.js';
|
||||
import { AgentConfigsController } from './agent-configs.controller.js';
|
||||
import { CoordModule } from '../coord/coord.module.js';
|
||||
import { McpClientModule } from '../mcp-client/mcp-client.module.js';
|
||||
import { SkillsModule } from '../skills/skills.module.js';
|
||||
import { GCModule } from '../gc/gc.module.js';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
imports: [CoordModule],
|
||||
providers: [ProviderService, RoutingService, AgentService],
|
||||
controllers: [ProvidersController, SessionsController],
|
||||
exports: [AgentService, ProviderService, RoutingService],
|
||||
imports: [CoordModule, McpClientModule, SkillsModule, GCModule],
|
||||
providers: [ProviderService, RoutingService, SkillLoaderService, AgentService],
|
||||
controllers: [ProvidersController, SessionsController, AgentConfigsController],
|
||||
exports: [AgentService, ProviderService, RoutingService, SkillLoaderService],
|
||||
})
|
||||
export class AgentModule {}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Inject, Injectable, Logger, type OnModuleDestroy } from '@nestjs/common';
|
||||
import { Inject, Injectable, Logger, Optional, type OnModuleDestroy } from '@nestjs/common';
|
||||
import {
|
||||
createAgentSession,
|
||||
DefaultResourceLoader,
|
||||
SessionManager,
|
||||
type AgentSession as PiAgentSession,
|
||||
type AgentSessionEvent,
|
||||
@@ -13,14 +14,65 @@ import { MEMORY } from '../memory/memory.tokens.js';
|
||||
import { EmbeddingService } from '../memory/embedding.service.js';
|
||||
import { CoordService } from '../coord/coord.service.js';
|
||||
import { ProviderService } from './provider.service.js';
|
||||
import { McpClientService } from '../mcp-client/mcp-client.service.js';
|
||||
import { SkillLoaderService } from './skill-loader.service.js';
|
||||
import { createBrainTools } from './tools/brain-tools.js';
|
||||
import { createCoordTools } from './tools/coord-tools.js';
|
||||
import { createMemoryTools } from './tools/memory-tools.js';
|
||||
import { createFileTools } from './tools/file-tools.js';
|
||||
import { createGitTools } from './tools/git-tools.js';
|
||||
import { createShellTools } from './tools/shell-tools.js';
|
||||
import { createWebTools } from './tools/web-tools.js';
|
||||
import type { SessionInfoDto } from './session.dto.js';
|
||||
import { SystemOverrideService } from '../preferences/system-override.service.js';
|
||||
import { PreferencesService } from '../preferences/preferences.service.js';
|
||||
import { SessionGCService } from '../gc/session-gc.service.js';
|
||||
|
||||
/** A single message from DB conversation history, used for context injection. */
export interface ConversationHistoryMessage {
  /** Originator of the message within the conversation. */
  role: 'user' | 'assistant' | 'system';
  /** Raw message text. */
  content: string;
  /** Creation time of the message in the database. */
  createdAt: Date;
}
|
||||
|
||||
/** Options accepted when creating an agent session. */
export interface AgentSessionOptions {
  /** Provider name used when resolving the session's model. */
  provider?: string;
  /** Model ID used when resolving the session's model. */
  modelId?: string;
  /**
   * Sandbox working directory for the session.
   * File, git, and shell tools will be restricted to this directory.
   * Falls back to AGENT_FILE_SANDBOX_DIR env var or process.cwd().
   */
  sandboxDir?: string;
  /**
   * Platform-level system prompt for this session.
   * Merged with skill prompt additions (platform prompt first, then skills).
   * Falls back to AGENT_SYSTEM_PROMPT env var when omitted.
   */
  systemPrompt?: string;
  /**
   * Explicit allowlist of tool names available in this session.
   * When set, only listed tools are registered with the agent.
   * When omitted for non-admin users, falls back to AGENT_USER_TOOLS env var.
   * Admins (isAdmin=true) always receive the full tool set unless explicitly restricted.
   */
  allowedTools?: string[];
  /** Whether the requesting user has admin privileges. Controls default tool access. */
  isAdmin?: boolean;
  /**
   * DB agent config ID. When provided, loads agent config from DB and merges
   * provider, model, systemPrompt, and allowedTools. Explicit call-site options
   * take precedence over config values.
   */
  agentConfigId?: string;
  /** ID of the user who owns this session. Used for preferences and system override lookups. */
  userId?: string;
  /**
   * Prior conversation messages to inject as context when resuming a session.
   * These messages are formatted and prepended to the system prompt so the
   * agent is aware of what was discussed in previous sessions.
   */
  conversationHistory?: ConversationHistoryMessage[];
}
|
||||
|
||||
export interface AgentSession {
|
||||
@@ -33,6 +85,14 @@ export interface AgentSession {
|
||||
createdAt: number;
|
||||
promptCount: number;
|
||||
channels: Set<string>;
|
||||
/** System prompt additions injected from enabled prompt-type skills. */
|
||||
skillPromptAdditions: string[];
|
||||
/** Resolved sandbox directory for this session. */
|
||||
sandboxDir: string;
|
||||
/** Tool names available in this session, or null when all tools are available. */
|
||||
allowedTools: string[] | null;
|
||||
/** User ID that owns this session, used for preference lookups. */
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
@@ -41,21 +101,69 @@ export class AgentService implements OnModuleDestroy {
|
||||
private readonly sessions = new Map<string, AgentSession>();
|
||||
private readonly creating = new Map<string, Promise<AgentSession>>();
|
||||
|
||||
private readonly customTools: ToolDefinition[];
|
||||
|
||||
constructor(
|
||||
@Inject(ProviderService) private readonly providerService: ProviderService,
|
||||
@Inject(BRAIN) private readonly brain: Brain,
|
||||
@Inject(MEMORY) private readonly memory: Memory,
|
||||
@Inject(EmbeddingService) private readonly embeddingService: EmbeddingService,
|
||||
@Inject(CoordService) private readonly coordService: CoordService,
|
||||
) {
|
||||
this.customTools = [
|
||||
...createBrainTools(brain),
|
||||
...createCoordTools(coordService),
|
||||
...createMemoryTools(memory, embeddingService.available ? embeddingService : null),
|
||||
@Inject(McpClientService) private readonly mcpClientService: McpClientService,
|
||||
@Inject(SkillLoaderService) private readonly skillLoaderService: SkillLoaderService,
|
||||
@Optional()
|
||||
@Inject(SystemOverrideService)
|
||||
private readonly systemOverride: SystemOverrideService | null,
|
||||
@Optional()
|
||||
@Inject(PreferencesService)
|
||||
private readonly preferencesService: PreferencesService | null,
|
||||
@Inject(SessionGCService) private readonly gc: SessionGCService,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Build the full set of custom tools scoped to the given sandbox directory and session user.
|
||||
* Brain/coord/memory/web tools are stateless with respect to cwd; file/git/shell
|
||||
* tools receive the resolved sandboxDir so they operate within the sandbox.
|
||||
* Memory tools are bound to sessionUserId so the LLM cannot access another user's data.
|
||||
*/
|
||||
private buildToolsForSandbox(
|
||||
sandboxDir: string,
|
||||
sessionUserId: string | undefined,
|
||||
): ToolDefinition[] {
|
||||
return [
|
||||
...createBrainTools(this.brain),
|
||||
...createCoordTools(this.coordService),
|
||||
...createMemoryTools(
|
||||
this.memory,
|
||||
this.embeddingService.available ? this.embeddingService : null,
|
||||
sessionUserId,
|
||||
),
|
||||
...createFileTools(sandboxDir),
|
||||
...createGitTools(sandboxDir),
|
||||
...createShellTools(sandboxDir),
|
||||
...createWebTools(),
|
||||
];
|
||||
this.logger.log(`Registered ${this.customTools.length} custom tools`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the tool allowlist for a session.
|
||||
* - Admin users: all tools unless an explicit allowedTools list is passed.
|
||||
* - Regular users: use allowedTools if provided, otherwise parse AGENT_USER_TOOLS env var.
|
||||
* Returns null when all tools should be available.
|
||||
*/
|
||||
private resolveAllowedTools(isAdmin: boolean, allowedTools?: string[]): string[] | null {
|
||||
if (allowedTools !== undefined) {
|
||||
return allowedTools.length === 0 ? [] : allowedTools;
|
||||
}
|
||||
if (isAdmin) {
|
||||
return null; // admins get everything
|
||||
}
|
||||
const envTools = process.env['AGENT_USER_TOOLS'];
|
||||
if (!envTools) {
|
||||
return null; // no restriction configured
|
||||
}
|
||||
return envTools
|
||||
.split(',')
|
||||
.map((t) => t.trim())
|
||||
.filter((t) => t.length > 0);
|
||||
}
|
||||
|
||||
async createSession(sessionId: string, options?: AgentSessionOptions): Promise<AgentSession> {
|
||||
@@ -76,22 +184,116 @@ export class AgentService implements OnModuleDestroy {
|
||||
sessionId: string,
|
||||
options?: AgentSessionOptions,
|
||||
): Promise<AgentSession> {
|
||||
const model = this.resolveModel(options);
|
||||
// Merge DB agent config when agentConfigId is provided
|
||||
let mergedOptions = options;
|
||||
if (options?.agentConfigId) {
|
||||
const agentConfig = await this.brain.agents.findById(options.agentConfigId);
|
||||
if (agentConfig) {
|
||||
mergedOptions = {
|
||||
provider: options.provider ?? agentConfig.provider,
|
||||
modelId: options.modelId ?? agentConfig.model,
|
||||
systemPrompt: options.systemPrompt ?? agentConfig.systemPrompt ?? undefined,
|
||||
allowedTools: options.allowedTools ?? agentConfig.allowedTools ?? undefined,
|
||||
sandboxDir: options.sandboxDir,
|
||||
isAdmin: options.isAdmin,
|
||||
agentConfigId: options.agentConfigId,
|
||||
};
|
||||
this.logger.log(
|
||||
`Merged agent config "${agentConfig.name}" (${agentConfig.id}) into session ${sessionId}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const model = this.resolveModel(mergedOptions);
|
||||
const providerName = model?.provider ?? 'default';
|
||||
const modelId = model?.id ?? 'default';
|
||||
|
||||
this.logger.log(
|
||||
`Creating agent session: ${sessionId} (provider=${providerName}, model=${modelId})`,
|
||||
// Resolve sandbox directory: option > env var > process.cwd()
|
||||
const sandboxDir =
|
||||
mergedOptions?.sandboxDir ?? process.env['AGENT_FILE_SANDBOX_DIR'] ?? process.cwd();
|
||||
|
||||
// Resolve allowed tool set
|
||||
const allowedTools = this.resolveAllowedTools(
|
||||
mergedOptions?.isAdmin ?? false,
|
||||
mergedOptions?.allowedTools,
|
||||
);
|
||||
|
||||
this.logger.log(
|
||||
`Creating agent session: ${sessionId} (provider=${providerName}, model=${modelId}, sandbox=${sandboxDir}, tools=${allowedTools === null ? 'all' : allowedTools.join(',') || 'none'})`,
|
||||
);
|
||||
|
||||
// Load skill tools from the catalog
|
||||
const { metaTools: skillMetaTools, promptAdditions } =
|
||||
await this.skillLoaderService.loadForSession();
|
||||
if (skillMetaTools.length > 0) {
|
||||
this.logger.log(`Attaching ${skillMetaTools.length} skill tool(s) to session ${sessionId}`);
|
||||
}
|
||||
if (promptAdditions.length > 0) {
|
||||
this.logger.log(
|
||||
`Injecting ${promptAdditions.length} skill prompt addition(s) into session ${sessionId}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Build per-session tools scoped to the sandbox directory and authenticated user
|
||||
const sandboxTools = this.buildToolsForSandbox(sandboxDir, mergedOptions?.userId);
|
||||
|
||||
// Combine static tools with dynamically discovered MCP client tools and skill tools
|
||||
const mcpTools = this.mcpClientService.getToolDefinitions();
|
||||
let allCustomTools = [...sandboxTools, ...skillMetaTools, ...mcpTools];
|
||||
if (mcpTools.length > 0) {
|
||||
this.logger.log(`Attaching ${mcpTools.length} MCP client tool(s) to session ${sessionId}`);
|
||||
}
|
||||
|
||||
// Filter tools by allowlist when a restriction is in effect
|
||||
if (allowedTools !== null) {
|
||||
const allowedSet = new Set(allowedTools);
|
||||
const before = allCustomTools.length;
|
||||
allCustomTools = allCustomTools.filter((t) => allowedSet.has(t.name));
|
||||
this.logger.log(
|
||||
`Tool restriction applied: ${allCustomTools.length}/${before} tools allowed for session ${sessionId}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Build system prompt: platform prompt + skill additions appended
|
||||
const platformPrompt =
|
||||
mergedOptions?.systemPrompt ?? process.env['AGENT_SYSTEM_PROMPT'] ?? undefined;
|
||||
|
||||
// Format conversation history for context injection (M1-004 / M1-005)
|
||||
const historyPromptSection = mergedOptions?.conversationHistory?.length
|
||||
? this.buildHistoryPromptSection(
|
||||
mergedOptions.conversationHistory,
|
||||
model?.contextWindow ?? 8192,
|
||||
sessionId,
|
||||
)
|
||||
: undefined;
|
||||
|
||||
const appendParts: string[] = [];
|
||||
if (promptAdditions.length > 0) appendParts.push(promptAdditions.join('\n\n'));
|
||||
if (historyPromptSection) appendParts.push(historyPromptSection);
|
||||
const appendSystemPrompt = appendParts.length > 0 ? appendParts.join('\n\n') : undefined;
|
||||
|
||||
// Construct a resource loader that injects the configured system prompt
|
||||
const resourceLoader = new DefaultResourceLoader({
|
||||
cwd: sandboxDir,
|
||||
noExtensions: true,
|
||||
noSkills: true,
|
||||
noPromptTemplates: true,
|
||||
noThemes: true,
|
||||
systemPrompt: platformPrompt,
|
||||
appendSystemPrompt: appendSystemPrompt,
|
||||
});
|
||||
await resourceLoader.reload();
|
||||
|
||||
let piSession: PiAgentSession;
|
||||
try {
|
||||
const result = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
modelRegistry: this.providerService.getRegistry(),
|
||||
model: model ?? undefined,
|
||||
cwd: sandboxDir,
|
||||
tools: [],
|
||||
customTools: this.customTools,
|
||||
customTools: allCustomTools,
|
||||
resourceLoader,
|
||||
});
|
||||
piSession = result.session;
|
||||
} catch (err) {
|
||||
@@ -124,6 +326,10 @@ export class AgentService implements OnModuleDestroy {
|
||||
createdAt: Date.now(),
|
||||
promptCount: 0,
|
||||
channels: new Set(),
|
||||
skillPromptAdditions: promptAdditions,
|
||||
sandboxDir,
|
||||
allowedTools,
|
||||
userId: mergedOptions?.userId,
|
||||
};
|
||||
|
||||
this.sessions.set(sessionId, session);
|
||||
@@ -132,6 +338,92 @@ export class AgentService implements OnModuleDestroy {
|
||||
return session;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate token count for a string using a rough 4-chars-per-token heuristic.
|
||||
*/
|
||||
private estimateTokens(text: string): number {
|
||||
return Math.ceil(text.length / 4);
|
||||
}
|
||||
|
||||
/**
 * Build a conversation history section for injection into the system prompt.
 * Implements M1-004 (history loading) and M1-005 (context window management).
 *
 * - Formats messages as a readable conversation transcript.
 * - If the full history exceeds 80% of the model's context window, older messages
 *   are summarized and only the most recent messages are kept verbatim.
 * - Summarization is a simple extractive approach (no LLM required).
 *
 * @param history       ordered messages, oldest first
 * @param contextWindow model context window in tokens (token counts here are
 *                      estimates via estimateTokens, not exact)
 * @param sessionId     used for log correlation only
 * @returns markdown-formatted history section ready to append to the prompt
 */
private buildHistoryPromptSection(
  history: ConversationHistoryMessage[],
  contextWindow: number,
  sessionId: string,
): string {
  // 80% of the context window is the hard budget for the injected section.
  const TOKEN_BUDGET = Math.floor(contextWindow * 0.8);
  const HISTORY_HEADER = '## Conversation History (resumed session)\n\n';

  const formatMessage = (msg: ConversationHistoryMessage): string => {
    const roleLabel =
      msg.role === 'user' ? 'User' : msg.role === 'assistant' ? 'Assistant' : 'System';
    return `**${roleLabel}:** ${msg.content}`;
  };

  const formatted = history.map((msg) => formatMessage(msg));
  const fullHistory = formatted.join('\n\n');
  const fullTokens = this.estimateTokens(HISTORY_HEADER + fullHistory);

  // Fast path: the whole transcript fits within the budget.
  if (fullTokens <= TOKEN_BUDGET) {
    this.logger.debug(
      `Session ${sessionId}: injecting full history (${history.length} msgs, ~${fullTokens} tokens)`,
    );
    return HISTORY_HEADER + fullHistory;
  }

  // History exceeds budget — summarize oldest messages, keep recent verbatim
  this.logger.log(
    `Session ${sessionId}: history (~${fullTokens} tokens) exceeds ${TOKEN_BUDGET} token budget; summarizing oldest messages`,
  );

  // Reserve 20% of the budget for the summary prefix, rest for verbatim messages
  const SUMMARY_RESERVE = Math.floor(TOKEN_BUDGET * 0.2);
  const verbatimBudget = TOKEN_BUDGET - SUMMARY_RESERVE;

  // Walk backwards from the newest message, keeping messages until the verbatim
  // budget is exhausted. verbatimCutIndex ends as the index of the OLDEST
  // message kept verbatim (history.length when even the newest doesn't fit).
  let verbatimTokens = 0;
  let verbatimCutIndex = history.length;
  for (let i = history.length - 1; i >= 0; i--) {
    const t = this.estimateTokens(formatted[i]!);
    if (verbatimTokens + t > verbatimBudget) break;
    verbatimTokens += t;
    verbatimCutIndex = i;
  }

  const summarizedMessages = history.slice(0, verbatimCutIndex);
  const verbatimMessages = history.slice(verbatimCutIndex);

  // Extractive "summary": first 120 chars of each summarized USER message,
  // newline-flattened and joined with semicolons.
  let summaryText = '';
  if (summarizedMessages.length > 0) {
    const topics = summarizedMessages
      .filter((m) => m.role === 'user')
      .map((m) => m.content.slice(0, 120).replace(/\n/g, ' '))
      .join('; ');
    summaryText =
      `**Previous conversation summary** (${summarizedMessages.length} messages omitted for brevity):\n` +
      `Topics discussed: ${topics || '(no user messages in summarized portion)'}`;
  }

  const verbatimSection = verbatimMessages.map((m) => formatMessage(m)).join('\n\n');

  const parts: string[] = [HISTORY_HEADER];
  if (summaryText) parts.push(summaryText);
  if (verbatimSection) parts.push(verbatimSection);

  const result = parts.join('\n\n');
  this.logger.log(
    `Session ${sessionId}: summarized ${summarizedMessages.length} messages, kept ${verbatimMessages.length} verbatim (~${this.estimateTokens(result)} tokens)`,
  );
  return result;
}
|
||||
|
||||
private resolveModel(options?: AgentSessionOptions) {
|
||||
if (!options?.provider && !options?.modelId) {
|
||||
return this.providerService.getDefaultModel() ?? null;
|
||||
@@ -207,8 +499,20 @@ export class AgentService implements OnModuleDestroy {
|
||||
throw new Error(`No agent session found: ${sessionId}`);
|
||||
}
|
||||
session.promptCount += 1;
|
||||
|
||||
// Prepend session-scoped system override if present (renew TTL on each turn)
|
||||
let effectiveMessage = message;
|
||||
if (this.systemOverride) {
|
||||
const override = await this.systemOverride.get(sessionId);
|
||||
if (override) {
|
||||
effectiveMessage = `[System Override]\n${override}\n\n${message}`;
|
||||
await this.systemOverride.renew(sessionId);
|
||||
this.logger.debug(`Applied system override for session ${sessionId}`);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await session.piSession.prompt(message);
|
||||
await session.piSession.prompt(effectiveMessage);
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Prompt failed for session=${sessionId}, messageLength=${message.length}`,
|
||||
@@ -244,6 +548,14 @@ export class AgentService implements OnModuleDestroy {
|
||||
session.listeners.clear();
|
||||
session.channels.clear();
|
||||
this.sessions.delete(sessionId);
|
||||
|
||||
// Run GC cleanup for this session (fire and forget, errors are logged)
|
||||
this.gc.collect(sessionId).catch((err: unknown) => {
|
||||
this.logger.error(
|
||||
`GC collect failed for session ${sessionId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
async onModuleDestroy(): Promise<void> {
|
||||
|
||||
204
apps/gateway/src/agent/model-capabilities.ts
Normal file
204
apps/gateway/src/agent/model-capabilities.ts
Normal file
@@ -0,0 +1,204 @@
|
||||
import type { ModelCapability } from '@mosaic/types';
|
||||
|
||||
/**
 * Comprehensive capability matrix for all target models.
 * Cost fields are optional and will be filled in when real pricing data is available.
 *
 * NOTE(review): contextWindow / maxOutputTokens values are hard-coded snapshots —
 * verify against current provider documentation before depending on them.
 */
export const MODEL_CAPABILITIES: ModelCapability[] = [
  // --- Anthropic ---
  {
    id: 'claude-opus-4-6',
    provider: 'anthropic',
    displayName: 'Claude Opus 4.6',
    tier: 'premium',
    contextWindow: 200000,
    maxOutputTokens: 32000,
    capabilities: {
      tools: true,
      vision: true,
      streaming: true,
      reasoning: true,
      embedding: false,
    },
  },
  {
    id: 'claude-sonnet-4-6',
    provider: 'anthropic',
    displayName: 'Claude Sonnet 4.6',
    tier: 'standard',
    contextWindow: 200000,
    maxOutputTokens: 16000,
    capabilities: {
      tools: true,
      vision: true,
      streaming: true,
      reasoning: true,
      embedding: false,
    },
  },
  {
    id: 'claude-haiku-4-5',
    provider: 'anthropic',
    displayName: 'Claude Haiku 4.5',
    tier: 'cheap',
    contextWindow: 200000,
    maxOutputTokens: 8192,
    capabilities: {
      tools: true,
      vision: true,
      streaming: true,
      reasoning: false,
      embedding: false,
    },
  },
  // --- OpenAI ---
  {
    id: 'codex-gpt-5.4',
    provider: 'openai',
    displayName: 'Codex gpt-5.4',
    tier: 'premium',
    contextWindow: 128000,
    maxOutputTokens: 16384,
    capabilities: {
      tools: true,
      vision: true,
      streaming: true,
      reasoning: true,
      embedding: false,
    },
  },
  // --- Z.ai ---
  {
    id: 'glm-5',
    provider: 'zai',
    displayName: 'GLM-5',
    tier: 'standard',
    contextWindow: 128000,
    maxOutputTokens: 8192,
    capabilities: {
      tools: true,
      vision: false,
      streaming: true,
      reasoning: false,
      embedding: false,
    },
  },
  // --- Ollama (local chat models) ---
  {
    id: 'llama3.2',
    provider: 'ollama',
    displayName: 'llama3.2',
    tier: 'local',
    contextWindow: 128000,
    maxOutputTokens: 8192,
    capabilities: {
      tools: true,
      vision: false,
      streaming: true,
      reasoning: false,
      embedding: false,
    },
  },
  {
    id: 'codellama',
    provider: 'ollama',
    displayName: 'codellama',
    tier: 'local',
    contextWindow: 16000,
    maxOutputTokens: 4096,
    capabilities: {
      tools: true,
      vision: false,
      streaming: true,
      reasoning: false,
      embedding: false,
    },
  },
  {
    id: 'mistral',
    provider: 'ollama',
    displayName: 'mistral',
    tier: 'local',
    contextWindow: 32000,
    maxOutputTokens: 8192,
    capabilities: {
      tools: true,
      vision: false,
      streaming: true,
      reasoning: false,
      embedding: false,
    },
  },
  // --- Ollama (local embedding models — maxOutputTokens 0: no text generation) ---
  {
    id: 'nomic-embed-text',
    provider: 'ollama',
    displayName: 'nomic-embed-text',
    tier: 'local',
    contextWindow: 8192,
    maxOutputTokens: 0,
    capabilities: {
      tools: false,
      vision: false,
      streaming: false,
      reasoning: false,
      embedding: true,
    },
  },
  {
    id: 'mxbai-embed-large',
    provider: 'ollama',
    displayName: 'mxbai-embed-large',
    tier: 'local',
    contextWindow: 8192,
    maxOutputTokens: 0,
    capabilities: {
      tools: false,
      vision: false,
      streaming: false,
      reasoning: false,
      embedding: true,
    },
  },
];
|
||||
|
||||
/**
|
||||
* Look up a model by its ID.
|
||||
* Returns undefined if the model is not found.
|
||||
*/
|
||||
export function getModelCapability(modelId: string): ModelCapability | undefined {
|
||||
return MODEL_CAPABILITIES.find((m) => m.id === modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Find models matching a partial capability filter.
|
||||
* All provided filter keys must match for a model to be included.
|
||||
*/
|
||||
export function findModelsByCapability(
|
||||
filter: Partial<Pick<ModelCapability, 'tier' | 'provider'>> & {
|
||||
capabilities?: Partial<ModelCapability['capabilities']>;
|
||||
},
|
||||
): ModelCapability[] {
|
||||
return MODEL_CAPABILITIES.filter((model) => {
|
||||
if (filter.tier !== undefined && model.tier !== filter.tier) return false;
|
||||
if (filter.provider !== undefined && model.provider !== filter.provider) return false;
|
||||
if (filter.capabilities) {
|
||||
for (const [key, value] of Object.entries(filter.capabilities) as [
|
||||
keyof ModelCapability['capabilities'],
|
||||
boolean,
|
||||
][]) {
|
||||
if (model.capabilities[key] !== value) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models for a specific provider.
|
||||
*/
|
||||
export function getModelsByProvider(provider: string): ModelCapability[] {
|
||||
return MODEL_CAPABILITIES.filter((m) => m.provider === provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the full list of all known models.
|
||||
*/
|
||||
export function getAllModels(): ModelCapability[] {
|
||||
return MODEL_CAPABILITIES;
|
||||
}
|
||||
17
apps/gateway/src/agent/provider.dto.ts
Normal file
17
apps/gateway/src/agent/provider.dto.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
 * Request body for POST /api/providers/test.
 */
export interface TestConnectionDto {
  /** Provider identifier to test (e.g. 'ollama', custom provider id) */
  providerId: string;
  /** Optional base URL override for ad-hoc testing against an arbitrary endpoint */
  baseUrl?: string;
}
|
||||
|
||||
/**
 * Result of a provider connectivity test.
 */
export interface TestConnectionResultDto {
  /** Provider identifier that was tested */
  providerId: string;
  /** Whether the provider endpoint responded successfully */
  reachable: boolean;
  /** Round-trip latency in milliseconds (present when reachable) */
  latencyMs?: number;
  /** Human-readable error when unreachable */
  error?: string;
  /** Model ids discovered at the remote endpoint (present when reachable) */
  discoveredModels?: string[];
}
|
||||
@@ -1,24 +1,212 @@
|
||||
import { Injectable, Logger, type OnModuleInit } from '@nestjs/common';
|
||||
import { Injectable, Logger, type OnModuleDestroy, type OnModuleInit } from '@nestjs/common';
|
||||
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
|
||||
import type { Model, Api } from '@mariozechner/pi-ai';
|
||||
import type { ModelInfo, ProviderInfo, CustomProviderConfig } from '@mosaic/types';
|
||||
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
|
||||
import type {
|
||||
CustomProviderConfig,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
ProviderInfo,
|
||||
} from '@mosaic/types';
|
||||
import {
|
||||
AnthropicAdapter,
|
||||
OllamaAdapter,
|
||||
OpenAIAdapter,
|
||||
OpenRouterAdapter,
|
||||
} from './adapters/index.js';
|
||||
import type { TestConnectionResultDto } from './provider.dto.js';
|
||||
|
||||
/** Default health check interval in seconds */
|
||||
const DEFAULT_HEALTH_INTERVAL_SECS = 60;
|
||||
|
||||
/** DI injection token for the provider adapter array. */
|
||||
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');
|
||||
|
||||
@Injectable()
|
||||
export class ProviderService implements OnModuleInit {
|
||||
export class ProviderService implements OnModuleInit, OnModuleDestroy {
|
||||
private readonly logger = new Logger(ProviderService.name);
|
||||
private registry!: ModelRegistry;
|
||||
|
||||
/**
|
||||
* Adapters registered with this service.
|
||||
* Built-in adapters (Ollama) are always present; additional adapters can be
|
||||
* supplied via the PROVIDER_ADAPTERS injection token in the future.
|
||||
*/
|
||||
private adapters: IProviderAdapter[] = [];
|
||||
|
||||
/**
|
||||
* Cached health status per provider, updated by the health check scheduler.
|
||||
*/
|
||||
private healthCache: Map<string, ProviderHealth & { modelCount: number }> = new Map();
|
||||
|
||||
/** Timer handle for the periodic health check scheduler */
|
||||
private healthCheckTimer: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
async onModuleInit(): Promise<void> {
|
||||
const authStorage = AuthStorage.create();
|
||||
const authStorage = AuthStorage.inMemory();
|
||||
this.registry = new ModelRegistry(authStorage);
|
||||
|
||||
this.registerOllamaProvider();
|
||||
// Build the default set of adapters that rely on the registry
|
||||
this.adapters = [
|
||||
new OllamaAdapter(this.registry),
|
||||
new AnthropicAdapter(this.registry),
|
||||
new OpenAIAdapter(this.registry),
|
||||
new OpenRouterAdapter(),
|
||||
];
|
||||
|
||||
// Run all adapter registrations first (Ollama, Anthropic, and any future adapters)
|
||||
await this.registerAll();
|
||||
|
||||
// Register API-key providers directly (Z.ai, custom)
|
||||
// OpenAI now has a dedicated adapter (M3-003).
|
||||
this.registerZaiProvider();
|
||||
this.registerCustomProviders();
|
||||
|
||||
const available = this.registry.getAvailable();
|
||||
this.logger.log(`Providers initialized: ${available.length} models available`);
|
||||
|
||||
// Kick off the health check scheduler
|
||||
this.startHealthCheckScheduler();
|
||||
}
|
||||
|
||||
onModuleDestroy(): void {
|
||||
if (this.healthCheckTimer !== null) {
|
||||
clearInterval(this.healthCheckTimer);
|
||||
this.healthCheckTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Health check scheduler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Start periodic health checks on all adapters.
|
||||
* Interval is configurable via PROVIDER_HEALTH_INTERVAL env (seconds, default 60).
|
||||
*/
|
||||
private startHealthCheckScheduler(): void {
|
||||
const intervalSecs =
|
||||
parseInt(process.env['PROVIDER_HEALTH_INTERVAL'] ?? '', 10) || DEFAULT_HEALTH_INTERVAL_SECS;
|
||||
const intervalMs = intervalSecs * 1000;
|
||||
|
||||
// Run an initial check immediately (non-blocking)
|
||||
void this.runScheduledHealthChecks();
|
||||
|
||||
this.healthCheckTimer = setInterval(() => {
|
||||
void this.runScheduledHealthChecks();
|
||||
}, intervalMs);
|
||||
|
||||
this.logger.log(`Provider health check scheduler started (interval: ${intervalSecs}s)`);
|
||||
}
|
||||
|
||||
private async runScheduledHealthChecks(): Promise<void> {
|
||||
for (const adapter of this.adapters) {
|
||||
try {
|
||||
const health = await adapter.healthCheck();
|
||||
const modelCount = adapter.listModels().length;
|
||||
this.healthCache.set(adapter.name, { ...health, modelCount });
|
||||
this.logger.debug(
|
||||
`Health check [${adapter.name}]: ${health.status} (${health.latencyMs ?? 'n/a'}ms)`,
|
||||
);
|
||||
} catch (err) {
|
||||
const modelCount = adapter.listModels().length;
|
||||
this.healthCache.set(adapter.name, {
|
||||
status: 'down',
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
modelCount,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the cached health status for all adapters.
|
||||
* Format: array of { name, status, latencyMs, lastChecked, modelCount }
|
||||
*/
|
||||
getProvidersHealth(): Array<{
|
||||
name: string;
|
||||
status: string;
|
||||
latencyMs?: number;
|
||||
lastChecked: string;
|
||||
modelCount: number;
|
||||
error?: string;
|
||||
}> {
|
||||
return this.adapters.map((adapter) => {
|
||||
const cached = this.healthCache.get(adapter.name);
|
||||
if (cached) {
|
||||
return {
|
||||
name: adapter.name,
|
||||
status: cached.status,
|
||||
latencyMs: cached.latencyMs,
|
||||
lastChecked: cached.lastChecked,
|
||||
modelCount: cached.modelCount,
|
||||
error: cached.error,
|
||||
};
|
||||
}
|
||||
// Not yet checked — return a pending placeholder
|
||||
return {
|
||||
name: adapter.name,
|
||||
status: 'unknown',
|
||||
lastChecked: new Date().toISOString(),
|
||||
modelCount: adapter.listModels().length,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Adapter-pattern API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Call register() on each adapter in order.
|
||||
* Errors from individual adapters are logged and do not abort the others.
|
||||
*/
|
||||
async registerAll(): Promise<void> {
|
||||
for (const adapter of this.adapters) {
|
||||
try {
|
||||
await adapter.register();
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Adapter "${adapter.name}" registration failed`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Return the adapter registered under the given provider name, or undefined.
 */
getAdapter(providerName: string): IProviderAdapter | undefined {
  for (const adapter of this.adapters) {
    if (adapter.name === providerName) return adapter;
  }
  return undefined;
}
|
||||
|
||||
/**
|
||||
* Run healthCheck() on all adapters and return results keyed by provider name.
|
||||
*/
|
||||
async healthCheckAll(): Promise<Record<string, ProviderHealth>> {
|
||||
const results: Record<string, ProviderHealth> = {};
|
||||
await Promise.all(
|
||||
this.adapters.map(async (adapter) => {
|
||||
try {
|
||||
results[adapter.name] = await adapter.healthCheck();
|
||||
} catch (err) {
|
||||
results[adapter.name] = {
|
||||
status: 'down',
|
||||
lastChecked: new Date().toISOString(),
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
};
|
||||
}
|
||||
}),
|
||||
);
|
||||
return results;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Legacy / Pi-SDK-facing API (preserved for AgentService and RoutingService)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Expose the underlying pi-coding-agent ModelRegistry (used by session creation and routing). */
getRegistry(): ModelRegistry {
  return this.registry;
}
|
||||
@@ -64,6 +252,75 @@ export class ProviderService implements OnModuleInit {
|
||||
return this.registry.getAvailable().map((m) => this.toModelInfo(m));
|
||||
}
|
||||
|
||||
/**
 * Test connectivity to a provider, optionally against an explicit base URL.
 *
 * Resolution order:
 *  1. Registered adapter with no URL override: delegate to its healthCheck().
 *  2. Explicit baseUrl: probe `<baseUrl>/models` over HTTP.
 *  3. 'ollama' with no override: probe `$OLLAMA_BASE_URL/v1/models`.
 *  4. Any other registered provider: report reachable with its registered
 *     model ids — no network round-trip is performed in this branch.
 *
 * Never throws: all failures are folded into the result's `error` field.
 */
async testConnection(providerId: string, baseUrl?: string): Promise<TestConnectionResultDto> {
  // Delegate to the adapter when one exists and no URL override is given
  const adapter = this.getAdapter(providerId);
  if (adapter && !baseUrl) {
    const health = await adapter.healthCheck();
    return {
      providerId,
      reachable: health.status !== 'down',
      latencyMs: health.latencyMs,
      error: health.error,
    };
  }

  // Resolve baseUrl: explicit override > registered provider > ollama env
  let resolvedUrl = baseUrl;

  if (!resolvedUrl) {
    const allModels = this.registry.getAll();
    const providerModels = allModels.filter((m) => m.provider === providerId);
    if (providerModels.length === 0) {
      return { providerId, reachable: false, error: `Provider '${providerId}' not found` };
    }
    // For Ollama, derive the base URL from environment
    if (providerId === 'ollama') {
      const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
      if (!ollamaUrl) {
        return { providerId, reachable: false, error: 'OLLAMA_BASE_URL not configured' };
      }
      resolvedUrl = `${ollamaUrl}/v1/models`;
    } else {
      // For other providers, we can only do a basic check
      return { providerId, reachable: true, discoveredModels: providerModels.map((m) => m.id) };
    }
  } else {
    // Strip at most one trailing slash, then probe the /models listing endpoint
    resolvedUrl = resolvedUrl.replace(/\/?$/, '') + '/models';
  }

  const start = Date.now();
  try {
    // Bound the probe to 5s so a hung endpoint cannot stall the request.
    const res = await fetch(resolvedUrl, {
      method: 'GET',
      headers: { Accept: 'application/json' },
      signal: AbortSignal.timeout(5000),
    });

    const latencyMs = Date.now() - start;

    if (!res.ok) {
      return { providerId, reachable: false, latencyMs, error: `HTTP ${res.status}` };
    }

    // Best-effort model discovery. NOTE(review): assumes an Ollama-style
    // `{ models: [...] }` body; OpenAI-compatible endpoints use `{ data: [...] }`
    // and would yield no discoveredModels here — confirm intent.
    let discoveredModels: string[] | undefined;
    try {
      const json = (await res.json()) as { models?: Array<{ id?: string; name?: string }> };
      if (Array.isArray(json.models)) {
        discoveredModels = json.models.map((m) => m.id ?? m.name ?? '').filter(Boolean);
      }
    } catch {
      // ignore parse errors — endpoint was reachable
    }

    return { providerId, reachable: true, latencyMs, discoveredModels };
  } catch (err) {
    const latencyMs = Date.now() - start;
    const message = err instanceof Error ? err.message : String(err);
    return { providerId, reachable: false, latencyMs, error: message };
  }
}
|
||||
|
||||
registerCustomProvider(config: CustomProviderConfig): void {
|
||||
this.registry.registerProvider(config.id, {
|
||||
baseUrl: config.baseUrl,
|
||||
@@ -82,32 +339,29 @@ export class ProviderService implements OnModuleInit {
|
||||
this.logger.log(`Registered custom provider: ${config.id} (${config.models.length} models)`);
|
||||
}
|
||||
|
||||
private registerOllamaProvider(): void {
|
||||
const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
|
||||
if (!ollamaUrl) return;
|
||||
// ---------------------------------------------------------------------------
|
||||
// Private helpers — direct registry registration for providers without adapters yet
|
||||
// (Z.ai will move to an adapter in M3-005)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const modelsEnv = process.env['OLLAMA_MODELS'] ?? 'llama3.2,codellama,mistral';
|
||||
const modelIds = modelsEnv
|
||||
.split(',')
|
||||
.map((m) => m.trim())
|
||||
.filter(Boolean);
|
||||
private registerZaiProvider(): void {
|
||||
const apiKey = process.env['ZAI_API_KEY'];
|
||||
if (!apiKey) {
|
||||
this.logger.debug('Skipping Z.ai provider registration: ZAI_API_KEY not set');
|
||||
return;
|
||||
}
|
||||
|
||||
this.registerCustomProvider({
|
||||
id: 'ollama',
|
||||
name: 'Ollama',
|
||||
baseUrl: `${ollamaUrl}/v1`,
|
||||
models: modelIds.map((id) => ({
|
||||
id,
|
||||
name: id,
|
||||
reasoning: false,
|
||||
contextWindow: 8192,
|
||||
maxTokens: 4096,
|
||||
})),
|
||||
const models = ['glm-4.5', 'glm-4.5-air', 'glm-4.5-flash'].map((id) =>
|
||||
this.cloneBuiltInModel('zai', id),
|
||||
);
|
||||
|
||||
this.registry.registerProvider('zai', {
|
||||
apiKey,
|
||||
baseUrl: 'https://open.bigmodel.cn/api/paas/v4',
|
||||
models,
|
||||
});
|
||||
|
||||
this.logger.log(
|
||||
`Ollama provider registered at ${ollamaUrl} with models: ${modelIds.join(', ')}`,
|
||||
);
|
||||
this.logger.log('Z.ai provider registered with 3 models');
|
||||
}
|
||||
|
||||
private registerCustomProviders(): void {
|
||||
@@ -124,6 +378,19 @@ export class ProviderService implements OnModuleInit {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Clone one of pi-ai's built-in model definitions, optionally overriding fields.
 *
 * @param provider  provider key known to pi-ai's getModel()
 * @param modelId   model id under that provider
 * @param overrides fields merged over the cloned definition (shallow)
 * @throws Error when the provider/model pair is unknown to pi-ai
 */
private cloneBuiltInModel(
  provider: string,
  modelId: string,
  overrides: Partial<Model<Api>> = {},
): Model<Api> {
  // getModel() is typed against pi-ai's literal provider/model unions; the
  // 'as never' casts deliberately bypass that so dynamic ids can be looked up.
  const model = getModel(provider as never, modelId as never) as Model<Api> | undefined;
  if (!model) {
    throw new Error(`Built-in model not found: ${provider}:${modelId}`);
  }

  // Shallow clone — override fields win over the built-in values.
  return { ...model, ...overrides };
}
|
||||
|
||||
private toModelInfo(model: Model<Api>): ModelInfo {
|
||||
return {
|
||||
id: model.id,
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { RoutingCriteria } from '@mosaic/types';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { ProviderService } from './provider.service.js';
|
||||
import { RoutingService } from './routing.service.js';
|
||||
import type { TestConnectionDto, TestConnectionResultDto } from './provider.dto.js';
|
||||
|
||||
@Controller('api/providers')
|
||||
@UseGuards(AuthGuard)
|
||||
@@ -22,6 +23,16 @@ export class ProvidersController {
|
||||
return this.providerService.listAvailableModels();
|
||||
}
|
||||
|
||||
/** GET /api/providers/health — cached per-adapter health snapshot. */
@Get('health')
health() {
  return { providers: this.providerService.getProvidersHealth() };
}
|
||||
|
||||
/** POST /api/providers/test — live connectivity test for a single provider. */
@Post('test')
testConnection(@Body() body: TestConnectionDto): Promise<TestConnectionResultDto> {
  return this.providerService.testConnection(body.providerId, body.baseUrl);
}
||||
|
||||
@Post('route')
|
||||
route(@Body() criteria: RoutingCriteria) {
|
||||
return this.routingService.route(criteria);
|
||||
|
||||
@@ -145,8 +145,11 @@ export class RoutingService {
|
||||
|
||||
/** Bucket a model into a cost tier by comparing input cost to the configured thresholds. */
private classifyTier(model: ModelInfo): CostTier {
  const inputCost = model.cost.input;
  if (inputCost <= COST_TIER_THRESHOLDS['cheap'].maxInput) return 'cheap';
  if (inputCost <= COST_TIER_THRESHOLDS['standard'].maxInput) return 'standard';
  return 'premium';
}
|
||||
|
||||
|
||||
@@ -12,3 +12,33 @@ export interface SessionListDto {
|
||||
sessions: SessionInfoDto[];
|
||||
total: number;
|
||||
}
|
||||
|
||||
/**
 * Options accepted when creating an agent session.
 * All fields are optional; omitting them falls back to env-var or process defaults.
 */
export interface CreateSessionOptionsDto {
  /** Provider name (e.g. "anthropic", "openai"). */
  provider?: string;
  /** Model ID to use for this session. */
  modelId?: string;
  /**
   * Sandbox working directory for the session.
   * File, git, and shell tools will be restricted to this directory.
   * Defaults to AGENT_FILE_SANDBOX_DIR env var or process.cwd().
   */
  sandboxDir?: string;
  /**
   * Platform-level system prompt for this session.
   * Merged with skill prompt additions (platform prompt first, then skills).
   * Falls back to AGENT_SYSTEM_PROMPT env var when omitted.
   */
  systemPrompt?: string;
  /**
   * Explicit allowlist of tool names available in this session.
   * When provided, only listed tools are registered with the agent.
   * Admins receive all tools; regular users fall back to AGENT_USER_TOOLS
   * env var (comma-separated) when this field is not supplied.
   */
  allowedTools?: string[];
}
|
||||
|
||||
59
apps/gateway/src/agent/skill-loader.service.ts
Normal file
59
apps/gateway/src/agent/skill-loader.service.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import { SkillsService } from '../skills/skills.service.js';
|
||||
import { createSkillTools } from './tools/skill-tools.js';
|
||||
|
||||
/**
 * Result of loading skills for a new agent session.
 */
export interface LoadedSkills {
  /** Meta-tools: skill_list + skill_invoke */
  metaTools: ToolDefinition[];
  /**
   * System prompt additions from enabled prompt-type skills.
   * Callers may prepend these to the session system prompt.
   */
  promptAdditions: string[];
}
|
||||
|
||||
/**
|
||||
* SkillLoaderService is responsible for:
|
||||
* 1. Providing the skill meta-tools (skill_list, skill_invoke) to agent sessions.
|
||||
* 2. Collecting system-prompt additions from enabled prompt-type skills.
|
||||
*/
|
||||
@Injectable()
|
||||
export class SkillLoaderService {
|
||||
private readonly logger = new Logger(SkillLoaderService.name);
|
||||
|
||||
constructor(@Inject(SkillsService) private readonly skillsService: SkillsService) {}
|
||||
|
||||
/**
|
||||
* Load enabled skills and return tools + prompt additions for a new session.
|
||||
*/
|
||||
async loadForSession(): Promise<LoadedSkills> {
|
||||
const metaTools = createSkillTools(this.skillsService);
|
||||
|
||||
let promptAdditions: string[] = [];
|
||||
try {
|
||||
const enabledSkills = await this.skillsService.findEnabled();
|
||||
promptAdditions = enabledSkills.flatMap((skill) => {
|
||||
const config = (skill.config ?? {}) as Record<string, unknown>;
|
||||
const skillType = (config['type'] as string | undefined) ?? 'prompt';
|
||||
if (skillType === 'prompt') {
|
||||
const addition = (config['prompt'] as string | undefined) ?? skill.description;
|
||||
return addition ? [addition] : [];
|
||||
}
|
||||
return [];
|
||||
});
|
||||
|
||||
this.logger.log(
|
||||
`Loaded ${enabledSkills.length} enabled skill(s), ` +
|
||||
`${promptAdditions.length} prompt addition(s)`,
|
||||
);
|
||||
} catch (err) {
|
||||
// Non-fatal: log and continue without prompt additions
|
||||
this.logger.warn(
|
||||
`Failed to load skill prompt additions: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
}
|
||||
|
||||
return { metaTools, promptAdditions };
|
||||
}
|
||||
}
|
||||
194
apps/gateway/src/agent/tools/file-tools.ts
Normal file
194
apps/gateway/src/agent/tools/file-tools.ts
Normal file
@@ -0,0 +1,194 @@
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import { readFile, writeFile, readdir, stat } from 'node:fs/promises';
|
||||
import { guardPath, guardPathUnsafe, SandboxEscapeError } from './path-guard.js';
|
||||
|
||||
const MAX_READ_BYTES = 512 * 1024; // 512 KB read limit
|
||||
const MAX_WRITE_BYTES = 1024 * 1024; // 1 MB write limit
|
||||
|
||||
export function createFileTools(baseDir: string): ToolDefinition[] {
|
||||
const readFileTool: ToolDefinition = {
|
||||
name: 'fs_read_file',
|
||||
label: 'Read File',
|
||||
description:
|
||||
'Read the contents of a file. Path is resolved relative to the sandbox base directory.',
|
||||
parameters: Type.Object({
|
||||
path: Type.String({
|
||||
description: 'File path (relative to sandbox base or absolute within it)',
|
||||
}),
|
||||
encoding: Type.Optional(
|
||||
Type.String({ description: 'Encoding: utf8 (default), base64, hex' }),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { path, encoding } = params as { path: string; encoding?: string };
|
||||
let safePath: string;
|
||||
try {
|
||||
safePath = guardPath(path, baseDir);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const info = await stat(safePath);
|
||||
if (!info.isFile()) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: path is not a file: ${path}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
if (info.size > MAX_READ_BYTES) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: `Error: file too large (${info.size} bytes, limit ${MAX_READ_BYTES} bytes)`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
const enc = (encoding ?? 'utf8') as BufferEncoding;
|
||||
const content = await readFile(safePath, { encoding: enc });
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: String(content) }],
|
||||
details: undefined,
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error reading file: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const writeFileTool: ToolDefinition = {
|
||||
name: 'fs_write_file',
|
||||
label: 'Write File',
|
||||
description:
|
||||
'Write content to a file. Path is resolved relative to the sandbox base directory. Overwrites existing file.',
|
||||
parameters: Type.Object({
|
||||
path: Type.String({
|
||||
description: 'File path (relative to sandbox base or absolute within it)',
|
||||
}),
|
||||
content: Type.String({ description: 'Content to write' }),
|
||||
encoding: Type.Optional(Type.String({ description: 'Encoding: utf8 (default), base64' })),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { path, content, encoding } = params as {
|
||||
path: string;
|
||||
content: string;
|
||||
encoding?: string;
|
||||
};
|
||||
let safePath: string;
|
||||
try {
|
||||
safePath = guardPathUnsafe(path, baseDir);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
if (Buffer.byteLength(content, 'utf8') > MAX_WRITE_BYTES) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: `Error: content too large (limit ${MAX_WRITE_BYTES} bytes)`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const enc = (encoding ?? 'utf8') as BufferEncoding;
|
||||
await writeFile(safePath, content, { encoding: enc });
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `File written successfully: ${path}` }],
|
||||
details: undefined,
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error writing file: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
const listDirectoryTool: ToolDefinition = {
|
||||
name: 'fs_list_directory',
|
||||
label: 'List Directory',
|
||||
description: 'List files and directories at a given path within the sandbox base directory.',
|
||||
parameters: Type.Object({
|
||||
path: Type.Optional(
|
||||
Type.String({
|
||||
description: 'Directory path (relative to sandbox base). Defaults to base directory.',
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { path } = params as { path?: string };
|
||||
const target = path ?? '.';
|
||||
let safePath: string;
|
||||
try {
|
||||
safePath = guardPath(target, baseDir);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const info = await stat(safePath);
|
||||
if (!info.isDirectory()) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: path is not a directory: ${target}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
const entries = await readdir(safePath, { withFileTypes: true });
|
||||
const items = entries.map((e) => ({
|
||||
name: e.name,
|
||||
type: e.isDirectory() ? 'directory' : e.isSymbolicLink() ? 'symlink' : 'file',
|
||||
}));
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: JSON.stringify(items, null, 2) }],
|
||||
details: undefined,
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error listing directory: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
return [readFileTool, writeFileTool, listDirectoryTool];
|
||||
}
|
||||
212
apps/gateway/src/agent/tools/git-tools.ts
Normal file
212
apps/gateway/src/agent/tools/git-tools.ts
Normal file
@@ -0,0 +1,212 @@
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import { exec } from 'node:child_process';
|
||||
import { promisify } from 'node:util';
|
||||
import { guardPath, guardPathUnsafe, SandboxEscapeError } from './path-guard.js';
|
||||
|
||||
const execAsync = promisify(exec);
|
||||
|
||||
const GIT_TIMEOUT_MS = 15_000;
|
||||
const MAX_OUTPUT_BYTES = 100 * 1024; // 100 KB
|
||||
|
||||
async function runGit(
|
||||
args: string[],
|
||||
cwd?: string,
|
||||
): Promise<{ stdout: string; stderr: string; error?: string }> {
|
||||
// Only allow specific safe read-only git subcommands
|
||||
const allowedSubcommands = ['status', 'log', 'diff', 'show', 'branch', 'tag', 'ls-files'];
|
||||
const subcommand = args[0];
|
||||
if (!subcommand || !allowedSubcommands.includes(subcommand)) {
|
||||
return {
|
||||
stdout: '',
|
||||
stderr: '',
|
||||
error: `Blocked: git subcommand "${subcommand}" is not allowed. Permitted: ${allowedSubcommands.join(', ')}`,
|
||||
};
|
||||
}
|
||||
|
||||
const cmd = `git ${args.map((a) => JSON.stringify(a)).join(' ')}`;
|
||||
try {
|
||||
const { stdout, stderr } = await execAsync(cmd, {
|
||||
cwd,
|
||||
timeout: GIT_TIMEOUT_MS,
|
||||
maxBuffer: MAX_OUTPUT_BYTES,
|
||||
});
|
||||
return { stdout, stderr };
|
||||
} catch (err: unknown) {
|
||||
const e = err as { stdout?: string; stderr?: string; message?: string };
|
||||
return {
|
||||
stdout: e.stdout ?? '',
|
||||
stderr: e.stderr ?? '',
|
||||
error: e.message ?? String(err),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function createGitTools(sandboxDir?: string): ToolDefinition[] {
|
||||
const defaultCwd = sandboxDir ?? process.cwd();
|
||||
|
||||
const gitStatus: ToolDefinition = {
|
||||
name: 'git_status',
|
||||
label: 'Git Status',
|
||||
description: 'Show the working tree status (staged, unstaged, untracked files).',
|
||||
parameters: Type.Object({
|
||||
cwd: Type.Optional(
|
||||
Type.String({
|
||||
description: 'Repository working directory (relative to sandbox or absolute within it).',
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { cwd } = params as { cwd?: string };
|
||||
let safeCwd: string;
|
||||
try {
|
||||
safeCwd = guardPath(cwd ?? '.', defaultCwd);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
const result = await runGit(['status', '--short', '--branch'], safeCwd);
|
||||
const text = result.error
|
||||
? `Error: ${result.error}\n${result.stderr}`
|
||||
: result.stdout || '(no output)';
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: text }],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const gitLog: ToolDefinition = {
|
||||
name: 'git_log',
|
||||
label: 'Git Log',
|
||||
description: 'Show recent commit history.',
|
||||
parameters: Type.Object({
|
||||
limit: Type.Optional(Type.Number({ description: 'Number of commits to show (default 20)' })),
|
||||
oneline: Type.Optional(
|
||||
Type.Boolean({ description: 'Compact one-line format (default true)' }),
|
||||
),
|
||||
cwd: Type.Optional(
|
||||
Type.String({
|
||||
description: 'Repository working directory (relative to sandbox or absolute within it).',
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { limit, oneline, cwd } = params as {
|
||||
limit?: number;
|
||||
oneline?: boolean;
|
||||
cwd?: string;
|
||||
};
|
||||
let safeCwd: string;
|
||||
try {
|
||||
safeCwd = guardPath(cwd ?? '.', defaultCwd);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
const args = ['log', `--max-count=${limit ?? 20}`];
|
||||
if (oneline !== false) args.push('--oneline');
|
||||
const result = await runGit(args, safeCwd);
|
||||
const text = result.error
|
||||
? `Error: ${result.error}\n${result.stderr}`
|
||||
: result.stdout || '(no commits)';
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: text }],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
const gitDiff: ToolDefinition = {
|
||||
name: 'git_diff',
|
||||
label: 'Git Diff',
|
||||
description: 'Show changes between commits, working tree, or staged changes.',
|
||||
parameters: Type.Object({
|
||||
staged: Type.Optional(
|
||||
Type.Boolean({ description: 'Show staged (cached) changes instead of unstaged' }),
|
||||
),
|
||||
ref: Type.Optional(
|
||||
Type.String({ description: 'Compare against this ref (commit SHA, branch, or tag)' }),
|
||||
),
|
||||
path: Type.Optional(
|
||||
Type.String({ description: 'Limit diff to a specific file or directory' }),
|
||||
),
|
||||
cwd: Type.Optional(
|
||||
Type.String({
|
||||
description: 'Repository working directory (relative to sandbox or absolute within it).',
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { staged, ref, path, cwd } = params as {
|
||||
staged?: boolean;
|
||||
ref?: string;
|
||||
path?: string;
|
||||
cwd?: string;
|
||||
};
|
||||
let safeCwd: string;
|
||||
try {
|
||||
safeCwd = guardPath(cwd ?? '.', defaultCwd);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
let safePath: string | undefined;
|
||||
if (path !== undefined) {
|
||||
try {
|
||||
safePath = guardPathUnsafe(path, defaultCwd);
|
||||
} catch (err) {
|
||||
if (err instanceof SandboxEscapeError) {
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
const args = ['diff'];
|
||||
if (staged) args.push('--cached');
|
||||
if (ref) args.push(ref);
|
||||
args.push('--');
|
||||
if (safePath !== undefined) args.push(safePath);
|
||||
const result = await runGit(args, safeCwd);
|
||||
const text = result.error
|
||||
? `Error: ${result.error}\n${result.stderr}`
|
||||
: result.stdout || '(no diff)';
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: text }],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
return [gitStatus, gitLog, gitDiff];
|
||||
}
|
||||
@@ -1,2 +1,7 @@
|
||||
// Barrel file: re-exports every agent tool factory so callers can import
// them from a single module.
export { createBrainTools } from './brain-tools.js';
export { createCoordTools } from './coord-tools.js';
export { createFileTools } from './file-tools.js';
export { createGitTools } from './git-tools.js';
export { createShellTools } from './shell-tools.js';
export { createWebTools } from './web-tools.js';
export { createSkillTools } from './skill-tools.js';
|
||||
|
||||
@@ -3,23 +3,45 @@ import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import type { Memory } from '@mosaic/memory';
|
||||
import type { EmbeddingProvider } from '@mosaic/memory';
|
||||
|
||||
/**
|
||||
* Create memory tools bound to the session's authenticated userId.
|
||||
*
|
||||
* SECURITY: userId is resolved from the authenticated session at tool-creation
|
||||
* time and is never accepted as a user-supplied or LLM-supplied parameter.
|
||||
* This prevents cross-user data access via parameter injection.
|
||||
*/
|
||||
export function createMemoryTools(
|
||||
memory: Memory,
|
||||
embeddingProvider: EmbeddingProvider | null,
|
||||
/** Authenticated user ID from the session. All memory operations are scoped to this user. */
|
||||
sessionUserId: string | undefined,
|
||||
): ToolDefinition[] {
|
||||
/** Return an error result when no session user is bound. */
|
||||
function noUserError() {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: 'Memory tools unavailable — no authenticated user bound to this session',
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
const searchMemory: ToolDefinition = {
|
||||
name: 'memory_search',
|
||||
label: 'Search Memory',
|
||||
description:
|
||||
'Search across stored insights and knowledge using natural language. Returns semantically similar results.',
|
||||
parameters: Type.Object({
|
||||
userId: Type.String({ description: 'User ID to search memory for' }),
|
||||
query: Type.String({ description: 'Natural language search query' }),
|
||||
limit: Type.Optional(Type.Number({ description: 'Max results (default 5)' })),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { userId, query, limit } = params as {
|
||||
userId: string;
|
||||
if (!sessionUserId) return noUserError();
|
||||
|
||||
const { query, limit } = params as {
|
||||
query: string;
|
||||
limit?: number;
|
||||
};
|
||||
@@ -37,7 +59,7 @@ export function createMemoryTools(
|
||||
}
|
||||
|
||||
const embedding = await embeddingProvider.embed(query);
|
||||
const results = await memory.insights.searchByEmbedding(userId, embedding, limit ?? 5);
|
||||
const results = await memory.insights.searchByEmbedding(sessionUserId, embedding, limit ?? 5);
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: JSON.stringify(results, null, 2) }],
|
||||
details: undefined,
|
||||
@@ -48,9 +70,8 @@ export function createMemoryTools(
|
||||
const getPreferences: ToolDefinition = {
|
||||
name: 'memory_get_preferences',
|
||||
label: 'Get User Preferences',
|
||||
description: 'Retrieve stored preferences for a user.',
|
||||
description: 'Retrieve stored preferences for the current session user.',
|
||||
parameters: Type.Object({
|
||||
userId: Type.String({ description: 'User ID' }),
|
||||
category: Type.Optional(
|
||||
Type.String({
|
||||
description: 'Filter by category: communication, coding, workflow, appearance, general',
|
||||
@@ -58,11 +79,13 @@ export function createMemoryTools(
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { userId, category } = params as { userId: string; category?: string };
|
||||
if (!sessionUserId) return noUserError();
|
||||
|
||||
const { category } = params as { category?: string };
|
||||
type Cat = 'communication' | 'coding' | 'workflow' | 'appearance' | 'general';
|
||||
const prefs = category
|
||||
? await memory.preferences.findByUserAndCategory(userId, category as Cat)
|
||||
: await memory.preferences.findByUser(userId);
|
||||
? await memory.preferences.findByUserAndCategory(sessionUserId, category as Cat)
|
||||
: await memory.preferences.findByUser(sessionUserId);
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: JSON.stringify(prefs, null, 2) }],
|
||||
details: undefined,
|
||||
@@ -76,7 +99,6 @@ export function createMemoryTools(
|
||||
description:
|
||||
'Store a learned user preference (e.g., "prefers tables over paragraphs", "timezone: America/Chicago").',
|
||||
parameters: Type.Object({
|
||||
userId: Type.String({ description: 'User ID' }),
|
||||
key: Type.String({ description: 'Preference key' }),
|
||||
value: Type.String({ description: 'Preference value (JSON string)' }),
|
||||
category: Type.Optional(
|
||||
@@ -86,8 +108,9 @@ export function createMemoryTools(
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { userId, key, value, category } = params as {
|
||||
userId: string;
|
||||
if (!sessionUserId) return noUserError();
|
||||
|
||||
const { key, value, category } = params as {
|
||||
key: string;
|
||||
value: string;
|
||||
category?: string;
|
||||
@@ -100,7 +123,7 @@ export function createMemoryTools(
|
||||
parsedValue = value;
|
||||
}
|
||||
const pref = await memory.preferences.upsert({
|
||||
userId,
|
||||
userId: sessionUserId,
|
||||
key,
|
||||
value: parsedValue,
|
||||
category: (category as Cat) ?? 'general',
|
||||
@@ -119,7 +142,6 @@ export function createMemoryTools(
|
||||
description:
|
||||
'Store a learned insight, decision, or knowledge extracted from the current interaction.',
|
||||
parameters: Type.Object({
|
||||
userId: Type.String({ description: 'User ID' }),
|
||||
content: Type.String({ description: 'The insight or knowledge to store' }),
|
||||
category: Type.Optional(
|
||||
Type.String({
|
||||
@@ -128,8 +150,9 @@ export function createMemoryTools(
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params) {
|
||||
const { userId, content, category } = params as {
|
||||
userId: string;
|
||||
if (!sessionUserId) return noUserError();
|
||||
|
||||
const { content, category } = params as {
|
||||
content: string;
|
||||
category?: string;
|
||||
};
|
||||
@@ -141,7 +164,7 @@ export function createMemoryTools(
|
||||
}
|
||||
|
||||
const insight = await memory.insights.create({
|
||||
userId,
|
||||
userId: sessionUserId,
|
||||
content,
|
||||
embedding,
|
||||
source: 'agent',
|
||||
|
||||
104
apps/gateway/src/agent/tools/path-guard.test.ts
Normal file
104
apps/gateway/src/agent/tools/path-guard.test.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { guardPath, guardPathUnsafe, SandboxEscapeError } from './path-guard.js';
|
||||
import path from 'node:path';
|
||||
import os from 'node:os';
|
||||
import fs from 'node:fs';
|
||||
|
||||
describe('guardPathUnsafe', () => {
|
||||
const sandbox = '/tmp/test-sandbox';
|
||||
|
||||
it('allows paths inside sandbox', () => {
|
||||
const result = guardPathUnsafe('foo/bar.txt', sandbox);
|
||||
expect(result).toBe(path.resolve(sandbox, 'foo/bar.txt'));
|
||||
});
|
||||
|
||||
it('allows sandbox root itself', () => {
|
||||
const result = guardPathUnsafe('.', sandbox);
|
||||
expect(result).toBe(path.resolve(sandbox));
|
||||
});
|
||||
|
||||
it('rejects path traversal with ../', () => {
|
||||
expect(() => guardPathUnsafe('../escape.txt', sandbox)).toThrow(SandboxEscapeError);
|
||||
});
|
||||
|
||||
it('rejects absolute path outside sandbox', () => {
|
||||
expect(() => guardPathUnsafe('/etc/passwd', sandbox)).toThrow(SandboxEscapeError);
|
||||
});
|
||||
|
||||
it('rejects deeply nested traversal', () => {
|
||||
expect(() => guardPathUnsafe('a/b/../../../../../../etc/passwd', sandbox)).toThrow(
|
||||
SandboxEscapeError,
|
||||
);
|
||||
});
|
||||
|
||||
it('rejects path that starts with sandbox name but is sibling', () => {
|
||||
expect(() => guardPathUnsafe('/tmp/test-sandbox-evil/file.txt', sandbox)).toThrow(
|
||||
SandboxEscapeError,
|
||||
);
|
||||
});
|
||||
|
||||
it('returns the resolved absolute path for nested paths', () => {
|
||||
const result = guardPathUnsafe('deep/nested/file.ts', sandbox);
|
||||
expect(result).toBe('/tmp/test-sandbox/deep/nested/file.ts');
|
||||
});
|
||||
|
||||
it('SandboxEscapeError includes the user path and sandbox in message', () => {
|
||||
let caught: unknown;
|
||||
try {
|
||||
guardPathUnsafe('../escape.txt', sandbox);
|
||||
} catch (err) {
|
||||
caught = err;
|
||||
}
|
||||
expect(caught).toBeInstanceOf(SandboxEscapeError);
|
||||
const e = caught as SandboxEscapeError;
|
||||
expect(e.userPath).toBe('../escape.txt');
|
||||
expect(e.sandboxDir).toBe(sandbox);
|
||||
expect(e.message).toContain('Path escape attempt blocked');
|
||||
});
|
||||
});
|
||||
|
||||
describe('guardPath', () => {
|
||||
let tmpDir: string;
|
||||
|
||||
it('allows an existing path inside a real temp sandbox', () => {
|
||||
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'path-guard-test-'));
|
||||
try {
|
||||
const subdir = path.join(tmpDir, 'subdir');
|
||||
fs.mkdirSync(subdir);
|
||||
const result = guardPath('subdir', tmpDir);
|
||||
expect(result).toBe(subdir);
|
||||
} finally {
|
||||
fs.rmSync(tmpDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
it('allows sandbox root itself', () => {
|
||||
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'path-guard-test-'));
|
||||
try {
|
||||
const result = guardPath('.', tmpDir);
|
||||
// realpathSync resolves the tmpdir symlinks (macOS /var -> /private/var)
|
||||
const realTmp = fs.realpathSync.native(tmpDir);
|
||||
expect(result).toBe(realTmp);
|
||||
} finally {
|
||||
fs.rmSync(tmpDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
it('rejects path traversal with ../ on existing sandbox', () => {
|
||||
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'path-guard-test-'));
|
||||
try {
|
||||
expect(() => guardPath('../escape', tmpDir)).toThrow(SandboxEscapeError);
|
||||
} finally {
|
||||
fs.rmSync(tmpDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
it('rejects absolute path outside sandbox', () => {
|
||||
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'path-guard-test-'));
|
||||
try {
|
||||
expect(() => guardPath('/etc/passwd', tmpDir)).toThrow(SandboxEscapeError);
|
||||
} finally {
|
||||
fs.rmSync(tmpDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
});
|
||||
58
apps/gateway/src/agent/tools/path-guard.ts
Normal file
58
apps/gateway/src/agent/tools/path-guard.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
|
||||
/**
|
||||
* Resolves a user-provided path and verifies it is inside the allowed sandbox directory.
|
||||
* Throws SandboxEscapeError if the resolved path is outside the sandbox.
|
||||
*
|
||||
* Uses realpathSync to resolve symlinks in the sandbox root. The user-supplied path
|
||||
* is checked for containment AFTER lexical resolution but BEFORE resolving any symlinks
|
||||
* within the user path — so symlink escape attempts are caught too.
|
||||
*
|
||||
* @param userPath - The path provided by the agent (may be relative or absolute)
|
||||
* @param sandboxDir - The allowed root directory (already validated on session creation)
|
||||
* @returns The resolved absolute path, guaranteed to be within sandboxDir
|
||||
*/
|
||||
export function guardPath(userPath: string, sandboxDir: string): string {
|
||||
const resolved = path.resolve(sandboxDir, userPath);
|
||||
const sandboxResolved = fs.realpathSync.native(sandboxDir);
|
||||
|
||||
// Normalize both paths to resolve any symlinks in the sandbox root itself.
|
||||
// For the user path, we check containment BEFORE resolving symlinks in the path
|
||||
// (so we catch symlink escape attempts too — the resolved path must still be under sandbox)
|
||||
if (!resolved.startsWith(sandboxResolved + path.sep) && resolved !== sandboxResolved) {
|
||||
throw new SandboxEscapeError(userPath, sandboxDir, resolved);
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a path without resolving symlinks in the user-provided portion.
|
||||
* Use for paths that may not exist yet (creates, writes).
|
||||
*
|
||||
* Performs a lexical containment check only using path.resolve.
|
||||
*/
|
||||
export function guardPathUnsafe(userPath: string, sandboxDir: string): string {
|
||||
const resolved = path.resolve(sandboxDir, userPath);
|
||||
const sandboxAbs = path.resolve(sandboxDir);
|
||||
|
||||
if (!resolved.startsWith(sandboxAbs + path.sep) && resolved !== sandboxAbs) {
|
||||
throw new SandboxEscapeError(userPath, sandboxDir, resolved);
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
export class SandboxEscapeError extends Error {
|
||||
constructor(
|
||||
public readonly userPath: string,
|
||||
public readonly sandboxDir: string,
|
||||
public readonly resolvedPath: string,
|
||||
) {
|
||||
super(
|
||||
`Path escape attempt blocked: "${userPath}" resolves to "${resolvedPath}" which is outside sandbox "${sandboxDir}"`,
|
||||
);
|
||||
this.name = 'SandboxEscapeError';
|
||||
}
|
||||
}
|
||||
218
apps/gateway/src/agent/tools/shell-tools.ts
Normal file
218
apps/gateway/src/agent/tools/shell-tools.ts
Normal file
@@ -0,0 +1,218 @@
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { guardPath, SandboxEscapeError } from './path-guard.js';
|
||||
|
||||
// Per-invocation execution limits for shell_exec.
const DEFAULT_TIMEOUT_MS = 30_000;
const MAX_OUTPUT_BYTES = 100 * 1024; // 100 KB

/**
 * Commands that are outright blocked for safety.
 * This is a denylist; the agent should be instructed to use
 * the least-privilege command necessary.
 *
 * NOTE(review): callers appear to check only the first token of a command
 * (via extractBaseCommand), so compound commands ("a && sudo b") are not
 * caught here — confirm this is an accepted limitation.
 */
const BLOCKED_COMMANDS = new Set([
  // Destructive filesystem / disk operations
  'rm',
  'rmdir',
  'mkfs',
  'dd',
  'format',
  'fdisk',
  'parted',
  'shred',
  'wipefs',
  // Privilege escalation and account management
  'sudo',
  'su',
  'chown',
  'chmod',
  'passwd',
  'useradd',
  'userdel',
  'groupadd',
  // Host power state
  'shutdown',
  'reboot',
  'halt',
  'poweroff',
  // Process control
  'kill',
  'killall',
  'pkill',
  // Network egress / remote access
  'curl',
  'wget',
  'nc',
  'netcat',
  'ncat',
  'ssh',
  'scp',
  'sftp',
  'rsync',
  // Firewall manipulation
  'iptables',
  'ip6tables',
  'nft',
  'ufw',
  'firewall-cmd',
  // Container / infrastructure tooling
  'docker',
  'podman',
  'kubectl',
  'helm',
  'terraform',
  'ansible',
  // Scheduled / deferred execution
  'crontab',
  'at',
  'batch',
]);
|
||||
|
||||
function extractBaseCommand(command: string): string {
|
||||
// Extract the first word (the binary name), stripping path
|
||||
const trimmed = command.trim();
|
||||
const firstToken = trimmed.split(/\s+/)[0] ?? '';
|
||||
return firstToken.split('/').pop() ?? firstToken;
|
||||
}
|
||||
|
||||
function runCommand(
|
||||
command: string,
|
||||
options: { timeoutMs: number; cwd?: string },
|
||||
): Promise<{ stdout: string; stderr: string; exitCode: number | null; timedOut: boolean }> {
|
||||
return new Promise((resolve) => {
|
||||
const child = spawn('sh', ['-c', command], {
|
||||
cwd: options.cwd,
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
detached: false,
|
||||
});
|
||||
|
||||
let stdout = '';
|
||||
let stderr = '';
|
||||
let timedOut = false;
|
||||
let totalBytes = 0;
|
||||
let truncated = false;
|
||||
|
||||
child.stdout?.on('data', (chunk: Buffer) => {
|
||||
if (truncated) return;
|
||||
totalBytes += chunk.length;
|
||||
if (totalBytes > MAX_OUTPUT_BYTES) {
|
||||
stdout += chunk.subarray(0, MAX_OUTPUT_BYTES - (totalBytes - chunk.length)).toString();
|
||||
stdout += '\n[output truncated at 100 KB limit]';
|
||||
truncated = true;
|
||||
child.kill('SIGTERM');
|
||||
} else {
|
||||
stdout += chunk.toString();
|
||||
}
|
||||
});
|
||||
|
||||
child.stderr?.on('data', (chunk: Buffer) => {
|
||||
if (stderr.length < MAX_OUTPUT_BYTES) {
|
||||
stderr += chunk.toString();
|
||||
}
|
||||
});
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
timedOut = true;
|
||||
child.kill('SIGTERM');
|
||||
setTimeout(() => {
|
||||
try {
|
||||
child.kill('SIGKILL');
|
||||
} catch {
|
||||
// already exited
|
||||
}
|
||||
}, 2000);
|
||||
}, options.timeoutMs);
|
||||
|
||||
child.on('close', (exitCode) => {
|
||||
clearTimeout(timer);
|
||||
resolve({ stdout, stderr, exitCode, timedOut });
|
||||
});
|
||||
|
||||
child.on('error', (err) => {
|
||||
clearTimeout(timer);
|
||||
resolve({ stdout, stderr: stderr + String(err), exitCode: null, timedOut: false });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Builds the shell tool set for an agent session.
 *
 * @param sandboxDir Directory commands are confined to; falls back to the
 *                   gateway process's current working directory when omitted.
 * @returns A single `shell_exec` tool definition.
 */
export function createShellTools(sandboxDir?: string): ToolDefinition[] {
  const defaultCwd = sandboxDir ?? process.cwd();

  const shellExec: ToolDefinition = {
    name: 'shell_exec',
    label: 'Shell Execute',
    description:
      'Execute a shell command with timeout and output limits. Dangerous commands (rm, sudo, docker, etc.) are blocked. Working directory is restricted to the session sandbox.',
    parameters: Type.Object({
      command: Type.String({ description: 'Shell command to execute' }),
      cwd: Type.Optional(
        Type.String({
          description:
            'Working directory for the command (relative to sandbox or absolute within it).',
        }),
      ),
      timeout: Type.Optional(
        Type.Number({ description: 'Timeout in milliseconds (default 30000, max 60000)' }),
      ),
    }),
    async execute(_toolCallId, params) {
      const { command, cwd, timeout } = params as {
        command: string;
        cwd?: string;
        timeout?: number;
      };

      // Blocklist check on the leading binary name only.
      // NOTE(review): only the first token is inspected, so compound commands
      // ("echo hi && rm -rf x"), pipes, and subshells can bypass the
      // blocklist — confirm whether the sandbox cwd restriction is considered
      // a sufficient backstop.
      const base = extractBaseCommand(command);
      if (BLOCKED_COMMANDS.has(base)) {
        return {
          content: [
            {
              type: 'text' as const,
              text: `Error: command "${base}" is blocked for safety reasons.`,
            },
          ],
          details: undefined,
        };
      }

      // Clamp the caller-supplied timeout to a 60s ceiling.
      const timeoutMs = Math.min(timeout ?? DEFAULT_TIMEOUT_MS, 60_000);
      let safeCwd: string;
      try {
        // Resolve the requested cwd inside the sandbox root; guardPath throws
        // SandboxEscapeError when the path would leave it.
        safeCwd = guardPath(cwd ?? '.', defaultCwd);
      } catch (err) {
        if (err instanceof SandboxEscapeError) {
          return {
            content: [{ type: 'text' as const, text: `Error: ${err.message}` }],
            details: undefined,
          };
        }
        // Any other guardPath failure is surfaced verbatim to the agent.
        return {
          content: [{ type: 'text' as const, text: `Error: ${String(err)}` }],
          details: undefined,
        };
      }

      const result = await runCommand(command, {
        timeoutMs,
        cwd: safeCwd,
      });

      // Timeouts return partial output so the agent can still reason about
      // what happened before the deadline.
      if (result.timedOut) {
        return {
          content: [
            {
              type: 'text' as const,
              text: `Command timed out after ${timeoutMs}ms.\nPartial stdout:\n${result.stdout}\nPartial stderr:\n${result.stderr}`,
            },
          ],
          details: undefined,
        };
      }

      // Assemble a compact report: stdout/stderr only when non-empty, always
      // the exit code.
      const parts: string[] = [];
      if (result.stdout) parts.push(`stdout:\n${result.stdout}`);
      if (result.stderr) parts.push(`stderr:\n${result.stderr}`);
      parts.push(`exit code: ${result.exitCode ?? 'null'}`);

      return {
        content: [{ type: 'text' as const, text: parts.join('\n') }],
        details: undefined,
      };
    },
  };

  return [shellExec];
}
|
||||
180
apps/gateway/src/agent/tools/skill-tools.ts
Normal file
180
apps/gateway/src/agent/tools/skill-tools.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import type { SkillsService } from '../../skills/skills.service.js';
|
||||
|
||||
/**
|
||||
* Creates meta-tools that allow agents to list and invoke skills from the catalog.
|
||||
*
|
||||
* skill_list — list all enabled skills
|
||||
* skill_invoke — execute a skill by name with parameters
|
||||
*/
|
||||
/**
 * Builds the skill meta-tools (`skill_list`, `skill_invoke`) backed by the
 * given SkillsService.
 *
 * @param skillsService Catalog accessor used to look up enabled skills.
 * @returns Two tool definitions: one to enumerate skills, one to invoke by name.
 */
export function createSkillTools(skillsService: SkillsService): ToolDefinition[] {
  const skillList: ToolDefinition = {
    name: 'skill_list',
    label: 'List Skills',
    description:
      'List all enabled skills available in the catalog. Returns name, description, type, and config for each skill.',
    parameters: Type.Object({}),
    async execute() {
      const skills = await skillsService.findEnabled();
      // Project each skill down to the fields the agent needs to choose one.
      const summary = skills.map((s) => ({
        name: s.name,
        description: s.description,
        version: s.version,
        source: s.source,
        config: s.config,
      }));
      return {
        content: [
          {
            type: 'text' as const,
            text:
              summary.length > 0
                ? JSON.stringify(summary, null, 2)
                : 'No enabled skills found in catalog.',
          },
        ],
        details: undefined,
      };
    },
  };

  const skillInvoke: ToolDefinition = {
    name: 'skill_invoke',
    label: 'Invoke Skill',
    description:
      'Invoke a skill from the catalog by name. For prompt skills, returns the prompt addition. ' +
      'For tool skills, executes the embedded logic. For workflow skills, returns the workflow steps.',
    parameters: Type.Object({
      name: Type.String({ description: 'Skill name to invoke' }),
      params: Type.Optional(
        Type.Record(Type.String(), Type.Unknown(), {
          description: 'Parameters to pass to the skill (if applicable)',
        }),
      ),
    }),
    async execute(_toolCallId, rawParams) {
      const { name, params } = rawParams as {
        name: string;
        params?: Record<string, unknown>;
      };

      // Missing and disabled skills are reported as plain text rather than
      // thrown, so the agent can recover.
      const skill = await skillsService.findByName(name);
      if (!skill) {
        return {
          content: [{ type: 'text' as const, text: `Skill not found: ${name}` }],
          details: undefined,
        };
      }

      if (!skill.enabled) {
        return {
          content: [{ type: 'text' as const, text: `Skill is disabled: ${name}` }],
          details: undefined,
        };
      }

      // A skill with no explicit type in its config is treated as a prompt skill.
      const config = (skill.config ?? {}) as Record<string, unknown>;
      const skillType = (config['type'] as string | undefined) ?? 'prompt';

      switch (skillType) {
        case 'prompt': {
          // Prompt skills contribute text; fall back to the description when
          // no explicit prompt is configured.
          const promptAddition =
            (config['prompt'] as string | undefined) ?? skill.description ?? '';
          return {
            content: [
              {
                type: 'text' as const,
                text: promptAddition
                  ? `[Skill: ${name}] ${promptAddition}`
                  : `[Skill: ${name}] No prompt content defined.`,
              },
            ],
            details: undefined,
          };
        }

        case 'tool': {
          const toolLogic = config['logic'] as string | undefined;
          if (!toolLogic) {
            return {
              content: [
                {
                  type: 'text' as const,
                  text: `[Skill: ${name}] Tool skill has no logic defined.`,
                },
              ],
              details: undefined,
            };
          }
          // Inline tool skill execution: the logic field holds a JS expression or template
          // For safety, treat it as a template that can reference params
          // (renderTemplate only substitutes {{...}} placeholders; nothing is
          // evaluated as code).
          const result = renderTemplate(toolLogic, { params: params ?? {}, skill });
          return {
            content: [{ type: 'text' as const, text: `[Skill: ${name}]\n${result}` }],
            details: undefined,
          };
        }

        case 'workflow': {
          // Workflow skills are not executed here; the step list is returned
          // for the agent to carry out itself.
          const steps = config['steps'] as unknown[] | undefined;
          if (!steps || steps.length === 0) {
            return {
              content: [
                {
                  type: 'text' as const,
                  text: `[Skill: ${name}] Workflow has no steps defined.`,
                },
              ],
              details: undefined,
            };
          }
          return {
            content: [
              {
                type: 'text' as const,
                text: `[Skill: ${name}] Workflow steps:\n${JSON.stringify(steps, null, 2)}`,
              },
            ],
            details: undefined,
          };
        }

        default: {
          // Unknown type — return full config so the agent can decide what to do
          return {
            content: [
              {
                type: 'text' as const,
                text: `[Skill: ${name}] (type: ${skillType})\n${JSON.stringify(config, null, 2)}`,
              },
            ],
            details: undefined,
          };
        }
      }
    },
  };

  return [skillList, skillInvoke];
}
|
||||
|
||||
/**
|
||||
* Minimal template renderer — replaces {{key}} with values from the context.
|
||||
* Used for tool skill logic templates.
|
||||
*/
|
||||
function renderTemplate(template: string, context: Record<string, unknown>): string {
|
||||
return template.replace(/\{\{(\w+(?:\.\w+)*)\}\}/g, (_match, path: string) => {
|
||||
const parts = path.split('.');
|
||||
let value: unknown = context;
|
||||
for (const part of parts) {
|
||||
if (value != null && typeof value === 'object') {
|
||||
value = (value as Record<string, unknown>)[part];
|
||||
} else {
|
||||
value = undefined;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return value !== undefined && value !== null ? String(value) : '';
|
||||
});
|
||||
}
|
||||
225
apps/gateway/src/agent/tools/web-tools.ts
Normal file
225
apps/gateway/src/agent/tools/web-tools.ts
Normal file
@@ -0,0 +1,225 @@
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
|
||||
const DEFAULT_TIMEOUT_MS = 15_000;
|
||||
const MAX_RESPONSE_BYTES = 512 * 1024; // 512 KB
|
||||
|
||||
/**
|
||||
* Blocked URL patterns (private IP ranges, localhost, link-local).
|
||||
*/
|
||||
const BLOCKED_HOSTNAMES = [
|
||||
/^localhost$/i,
|
||||
/^127\./,
|
||||
/^10\./,
|
||||
/^172\.(1[6-9]|2\d|3[01])\./,
|
||||
/^192\.168\./,
|
||||
/^::1$/,
|
||||
/^fc[0-9a-f][0-9a-f]:/i,
|
||||
/^fe80:/i,
|
||||
/^0\.0\.0\.0$/,
|
||||
/^169\.254\./,
|
||||
];
|
||||
|
||||
function isBlockedUrl(urlString: string): string | null {
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(urlString);
|
||||
} catch {
|
||||
return `Invalid URL: ${urlString}`;
|
||||
}
|
||||
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
|
||||
return `Unsupported protocol: ${parsed.protocol}. Only http and https are allowed.`;
|
||||
}
|
||||
const hostname = parsed.hostname;
|
||||
for (const pattern of BLOCKED_HOSTNAMES) {
|
||||
if (pattern.test(hostname)) {
|
||||
return `Blocked: requests to "${hostname}" are not allowed (private/local addresses).`;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Performs a fetch with an abort-based timeout and a streamed read that caps
 * the decoded body at MAX_RESPONSE_BYTES.
 *
 * @param url Target URL (already validated by the caller).
 * @param options Fetch options (method, headers, body) merged with the abort signal.
 * @param timeoutMs Milliseconds before the request is aborted.
 * @returns Decoded body text (with a truncation marker when capped), HTTP
 *          status, and the response content-type header ('' when absent).
 * @throws Whatever `fetch` throws (network errors, AbortError on timeout).
 */
async function fetchWithLimit(
  url: string,
  options: RequestInit,
  timeoutMs: number,
): Promise<{ text: string; status: number; contentType: string }> {
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);

  try {
    const response = await fetch(url, { ...options, signal: controller.signal });
    const contentType = response.headers.get('content-type') ?? '';

    // Stream response and enforce size limit
    const reader = response.body?.getReader();
    if (!reader) {
      // Bodyless responses (e.g. 204) yield an empty text.
      return { text: '', status: response.status, contentType };
    }

    const chunks: Uint8Array[] = [];
    let totalBytes = 0;
    let truncated = false;

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      totalBytes += value.length;
      if (totalBytes > MAX_RESPONSE_BYTES) {
        // Keep only the bytes that fit under the cap, then stop reading.
        const remaining = MAX_RESPONSE_BYTES - (totalBytes - value.length);
        chunks.push(value.subarray(0, remaining));
        truncated = true;
        // Fire-and-forget cancel; the returned promise is intentionally not awaited.
        reader.cancel();
        break;
      }
      chunks.push(value);
    }

    // Concatenate the collected chunks into one buffer before decoding, so
    // multi-byte characters split across chunk boundaries decode correctly.
    const combined = new Uint8Array(chunks.reduce((acc, c) => acc + c.length, 0));
    let offset = 0;
    for (const chunk of chunks) {
      combined.set(chunk, offset);
      offset += chunk.length;
    }

    let text = new TextDecoder().decode(combined);
    if (truncated) {
      text += '\n[response truncated at 512 KB limit]';
    }

    return { text, status: response.status, contentType };
  } finally {
    // Always clear the abort timer, on success and on throw alike.
    clearTimeout(timer);
  }
}
|
||||
|
||||
/**
 * Builds the HTTP tool set (`web_get`, `web_post`). Both tools reject
 * private/local addresses via isBlockedUrl, clamp timeouts to 30s, and cap
 * response bodies via fetchWithLimit.
 *
 * @returns Two tool definitions: GET and POST.
 */
export function createWebTools(): ToolDefinition[] {
  const webGet: ToolDefinition = {
    name: 'web_get',
    label: 'HTTP GET',
    description:
      'Perform an HTTP GET request and return the response body. Private/local addresses are blocked.',
    parameters: Type.Object({
      url: Type.String({ description: 'URL to fetch (http/https only)' }),
      headers: Type.Optional(
        Type.Record(Type.String(), Type.String(), {
          description: 'Optional request headers as key-value pairs',
        }),
      ),
      timeout: Type.Optional(
        Type.Number({ description: 'Timeout in milliseconds (default 15000, max 30000)' }),
      ),
    }),
    async execute(_toolCallId, params) {
      const { url, headers, timeout } = params as {
        url: string;
        headers?: Record<string, string>;
        timeout?: number;
      };

      // SSRF guard: refuse private/local targets before any network activity.
      const blocked = isBlockedUrl(url);
      if (blocked) {
        return {
          content: [{ type: 'text' as const, text: `Error: ${blocked}` }],
          details: undefined,
        };
      }

      // Clamp the caller-supplied timeout to a 30s ceiling.
      const timeoutMs = Math.min(timeout ?? DEFAULT_TIMEOUT_MS, 30_000);

      try {
        const result = await fetchWithLimit(
          url,
          { method: 'GET', headers: headers ?? {} },
          timeoutMs,
        );
        return {
          content: [
            {
              type: 'text' as const,
              text: `HTTP ${result.status} (${result.contentType})\n\n${result.text}`,
            },
          ],
          details: undefined,
        };
      } catch (err) {
        // Network/abort errors are reported as text, never thrown to the agent.
        const msg = err instanceof Error ? err.message : String(err);
        return {
          content: [{ type: 'text' as const, text: `Error fetching URL: ${msg}` }],
          details: undefined,
        };
      }
    },
  };

  const webPost: ToolDefinition = {
    name: 'web_post',
    label: 'HTTP POST',
    description:
      'Perform an HTTP POST request with a JSON or text body. Private/local addresses are blocked.',
    parameters: Type.Object({
      url: Type.String({ description: 'URL to POST to (http/https only)' }),
      body: Type.String({ description: 'Request body (JSON string or plain text)' }),
      contentType: Type.Optional(
        Type.String({ description: 'Content-Type header (default: application/json)' }),
      ),
      headers: Type.Optional(
        Type.Record(Type.String(), Type.String(), {
          description: 'Optional additional request headers',
        }),
      ),
      timeout: Type.Optional(
        Type.Number({ description: 'Timeout in milliseconds (default 15000, max 30000)' }),
      ),
    }),
    async execute(_toolCallId, params) {
      const { url, body, contentType, headers, timeout } = params as {
        url: string;
        body: string;
        contentType?: string;
        headers?: Record<string, string>;
        timeout?: number;
      };

      // SSRF guard: refuse private/local targets before any network activity.
      const blocked = isBlockedUrl(url);
      if (blocked) {
        return {
          content: [{ type: 'text' as const, text: `Error: ${blocked}` }],
          details: undefined,
        };
      }

      // Clamp the caller-supplied timeout to a 30s ceiling.
      const timeoutMs = Math.min(timeout ?? DEFAULT_TIMEOUT_MS, 30_000);
      const ct = contentType ?? 'application/json';

      try {
        // Extra headers are spread after Content-Type, so a caller-supplied
        // Content-Type header would override the `contentType` parameter.
        const result = await fetchWithLimit(
          url,
          {
            method: 'POST',
            headers: { 'Content-Type': ct, ...(headers ?? {}) },
            body,
          },
          timeoutMs,
        );
        return {
          content: [
            {
              type: 'text' as const,
              text: `HTTP ${result.status} (${result.contentType})\n\n${result.text}`,
            },
          ],
          details: undefined,
        };
      } catch (err) {
        // Network/abort errors are reported as text, never thrown to the agent.
        const msg = err instanceof Error ? err.message : String(err);
        return {
          content: [{ type: 'text' as const, text: `Error posting to URL: ${msg}` }],
          details: undefined,
        };
      }
    },
  };

  return [webGet, webPost];
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { APP_GUARD } from '@nestjs/core';
|
||||
import { HealthController } from './health/health.controller.js';
|
||||
import { DatabaseModule } from './database/database.module.js';
|
||||
import { AuthModule } from './auth/auth.module.js';
|
||||
@@ -13,9 +14,19 @@ import { CoordModule } from './coord/coord.module.js';
|
||||
import { MemoryModule } from './memory/memory.module.js';
|
||||
import { LogModule } from './log/log.module.js';
|
||||
import { SkillsModule } from './skills/skills.module.js';
|
||||
import { PluginModule } from './plugin/plugin.module.js';
|
||||
import { McpModule } from './mcp/mcp.module.js';
|
||||
import { AdminModule } from './admin/admin.module.js';
|
||||
import { CommandsModule } from './commands/commands.module.js';
|
||||
import { PreferencesModule } from './preferences/preferences.module.js';
|
||||
import { GCModule } from './gc/gc.module.js';
|
||||
import { ReloadModule } from './reload/reload.module.js';
|
||||
import { WorkspaceModule } from './workspace/workspace.module.js';
|
||||
import { ThrottlerGuard, ThrottlerModule } from '@nestjs/throttler';
|
||||
|
||||
@Module({
|
||||
imports: [
|
||||
ThrottlerModule.forRoot([{ name: 'default', ttl: 60_000, limit: 60 }]),
|
||||
DatabaseModule,
|
||||
AuthModule,
|
||||
BrainModule,
|
||||
@@ -29,7 +40,21 @@ import { SkillsModule } from './skills/skills.module.js';
|
||||
MemoryModule,
|
||||
LogModule,
|
||||
SkillsModule,
|
||||
PluginModule,
|
||||
McpModule,
|
||||
AdminModule,
|
||||
PreferencesModule,
|
||||
CommandsModule,
|
||||
GCModule,
|
||||
ReloadModule,
|
||||
WorkspaceModule,
|
||||
],
|
||||
controllers: [HealthController],
|
||||
providers: [
|
||||
{
|
||||
provide: APP_GUARD,
|
||||
useClass: ThrottlerGuard,
|
||||
},
|
||||
],
|
||||
})
|
||||
export class AppModule {}
|
||||
|
||||
@@ -7,16 +7,17 @@ import { AUTH } from './auth.tokens.js';
|
||||
export function mountAuthHandler(app: NestFastifyApplication): void {
|
||||
const auth = app.get<Auth>(AUTH);
|
||||
const nodeHandler = toNodeHandler(auth);
|
||||
const corsOrigin = process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000';
|
||||
|
||||
const fastify = app.getHttpAdapter().getInstance();
|
||||
|
||||
// Use Fastify's addHook to intercept auth requests at the raw HTTP level,
|
||||
// before Fastify's body parser runs. This avoids conflicts with NestJS's
|
||||
// custom content-type parser.
|
||||
// BetterAuth is mounted at the raw HTTP level via Fastify's onRequest hook,
|
||||
// bypassing NestJS middleware (including CORS). We must set CORS headers
|
||||
// manually on the raw response before handing off to BetterAuth.
|
||||
fastify.addHook(
|
||||
'onRequest',
|
||||
(
|
||||
req: { raw: IncomingMessage; url: string },
|
||||
req: { raw: IncomingMessage; url: string; method: string },
|
||||
reply: { raw: ServerResponse; hijack: () => void },
|
||||
done: () => void,
|
||||
) => {
|
||||
@@ -25,6 +26,27 @@ export function mountAuthHandler(app: NestFastifyApplication): void {
|
||||
return;
|
||||
}
|
||||
|
||||
const origin = req.raw.headers.origin;
|
||||
const allowed = corsOrigin.split(',').map((o) => o.trim());
|
||||
|
||||
if (origin && allowed.includes(origin)) {
|
||||
reply.raw.setHeader('Access-Control-Allow-Origin', origin);
|
||||
reply.raw.setHeader('Access-Control-Allow-Credentials', 'true');
|
||||
reply.raw.setHeader(
|
||||
'Access-Control-Allow-Methods',
|
||||
'GET, POST, PUT, PATCH, DELETE, OPTIONS',
|
||||
);
|
||||
reply.raw.setHeader('Access-Control-Allow-Headers', 'Content-Type, Authorization, Cookie');
|
||||
}
|
||||
|
||||
// Handle preflight
|
||||
if (req.method === 'OPTIONS') {
|
||||
reply.hijack();
|
||||
reply.raw.writeHead(204);
|
||||
reply.raw.end();
|
||||
return;
|
||||
}
|
||||
|
||||
reply.hijack();
|
||||
nodeHandler(req.raw as IncomingMessage, reply.raw as ServerResponse)
|
||||
.then(() => {
|
||||
|
||||
@@ -3,9 +3,11 @@ import { createAuth, type Auth } from '@mosaic/auth';
|
||||
import type { Db } from '@mosaic/db';
|
||||
import { DB } from '../database/database.module.js';
|
||||
import { AUTH } from './auth.tokens.js';
|
||||
import { SsoController } from './sso.controller.js';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
controllers: [SsoController],
|
||||
providers: [
|
||||
{
|
||||
provide: AUTH,
|
||||
|
||||
11
apps/gateway/src/auth/resource-ownership.ts
Normal file
11
apps/gateway/src/auth/resource-ownership.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { ForbiddenException } from '@nestjs/common';
|
||||
|
||||
export function assertOwner(
|
||||
ownerId: string | null | undefined,
|
||||
userId: string,
|
||||
resourceName: string,
|
||||
): void {
|
||||
if (!ownerId || ownerId !== userId) {
|
||||
throw new ForbiddenException(`${resourceName} does not belong to the current user`);
|
||||
}
|
||||
}
|
||||
40
apps/gateway/src/auth/sso.controller.spec.ts
Normal file
40
apps/gateway/src/auth/sso.controller.spec.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import { SsoController } from './sso.controller.js';
|
||||
|
||||
describe('SsoController', () => {
  // Each test stubs provider env vars; restore them so tests stay isolated.
  afterEach(() => {
    vi.unstubAllEnvs();
  });

  it('lists configured OIDC providers', () => {
    // Configure WorkOS via its three env vars and expect discovery to mark it
    // as a configured OIDC provider with team sync on the organization claim.
    vi.stubEnv('WORKOS_CLIENT_ID', 'workos-client');
    vi.stubEnv('WORKOS_CLIENT_SECRET', 'workos-secret');
    vi.stubEnv('WORKOS_ISSUER', 'https://auth.workos.com/sso/client_123');

    const controller = new SsoController();
    const providers = controller.list();

    expect(providers.find((provider) => provider.id === 'workos')).toMatchObject({
      configured: true,
      loginMode: 'oidc',
      callbackPath: '/api/auth/oauth2/callback/workos',
      teamSync: { enabled: true, claim: 'organization_id' },
    });
  });

  it('prefers SAML fallback for Keycloak when only the SAML login URL is configured', () => {
    // With only the SAML login URL set, discovery should report loginMode
    // 'saml' and expose the fallback login URL.
    vi.stubEnv('KEYCLOAK_SAML_LOGIN_URL', 'https://sso.example.com/realms/mosaic/protocol/saml');

    const controller = new SsoController();
    const providers = controller.list();

    expect(providers.find((provider) => provider.id === 'keycloak')).toMatchObject({
      configured: true,
      loginMode: 'saml',
      samlFallback: {
        configured: true,
        loginUrl: 'https://sso.example.com/realms/mosaic/protocol/saml',
      },
    });
  });
});
|
||||
10
apps/gateway/src/auth/sso.controller.ts
Normal file
10
apps/gateway/src/auth/sso.controller.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Controller, Get } from '@nestjs/common';
|
||||
import { buildSsoDiscovery, type SsoProviderDiscovery } from '@mosaic/auth';
|
||||
|
||||
/**
 * Read-only SSO discovery endpoint.
 *
 * GET /api/sso/providers returns the provider discovery list produced by
 * `buildSsoDiscovery()` from `@mosaic/auth`; this controller holds no state
 * of its own.
 */
@Controller('api/sso/providers')
export class SsoController {
  // Delegates entirely to the shared auth package.
  @Get()
  list(): SsoProviderDiscovery[] {
    return buildSsoDiscovery();
  }
}
|
||||
80
apps/gateway/src/chat/__tests__/chat-security.test.ts
Normal file
80
apps/gateway/src/chat/__tests__/chat-security.test.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { resolve } from 'node:path';
|
||||
import { validateSync } from 'class-validator';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { SendMessageDto } from '../../conversations/conversations.dto.js';
|
||||
import { ChatRequestDto } from '../chat.dto.js';
|
||||
import { validateSocketSession } from '../chat.gateway-auth.js';
|
||||
|
||||
describe('Chat controller source hardening', () => {
  // Source-text assertion (not a behavioral test): reads the controller file
  // and checks the guard decorator and current-user parameter are present.
  it('applies AuthGuard and reads the current user', () => {
    const source = readFileSync(resolve('src/chat/chat.controller.ts'), 'utf8');

    expect(source).toContain('@UseGuards(AuthGuard)');
    expect(source).toContain('@CurrentUser() user: { id: string }');
  });
});

describe('WebSocket session authentication', () => {
  it('returns null when the handshake does not resolve to a session', async () => {
    // Auth stub resolves no session → helper must reject the handshake.
    const result = await validateSocketSession(
      {},
      {
        api: {
          getSession: vi.fn().mockResolvedValue(null),
        },
      },
    );

    expect(result).toBeNull();
  });

  it('returns the resolved session when Better Auth accepts the headers', async () => {
    const session = { user: { id: 'user-1' }, session: { id: 'session-1' } };

    // Auth stub resolves a session → helper passes it through unchanged.
    const result = await validateSocketSession(
      { cookie: 'session=abc' },
      {
        api: {
          getSession: vi.fn().mockResolvedValue(session),
        },
      },
    );

    expect(result).toEqual(session);
  });
});

describe('Chat DTO validation', () => {
  it('rejects unsupported message roles', () => {
    // 'moderator' is not an accepted role on SendMessageDto.
    const dto = Object.assign(new SendMessageDto(), {
      content: 'hello',
      role: 'moderator',
    });

    const errors = validateSync(dto);

    expect(errors.length).toBeGreaterThan(0);
  });

  it('rejects oversized conversation message content above 10000 characters', () => {
    // One character past the 10_000 limit must fail validation.
    const dto = Object.assign(new SendMessageDto(), {
      content: 'x'.repeat(10_001),
      role: 'user',
    });

    const errors = validateSync(dto);

    expect(errors.length).toBeGreaterThan(0);
  });

  it('rejects oversized chat content above 10000 characters', () => {
    // Same limit enforced on the HTTP chat DTO.
    const dto = Object.assign(new ChatRequestDto(), {
      content: 'x'.repeat(10_001),
    });

    const errors = validateSync(dto);

    expect(errors.length).toBeGreaterThan(0);
  });
});
|
||||
@@ -1,12 +1,20 @@
|
||||
import { Controller, Post, Body, Logger, HttpException, HttpStatus, Inject } from '@nestjs/common';
|
||||
import {
|
||||
Controller,
|
||||
Post,
|
||||
Body,
|
||||
Logger,
|
||||
HttpException,
|
||||
HttpStatus,
|
||||
Inject,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
|
||||
import { Throttle } from '@nestjs/throttler';
|
||||
import { AgentService } from '../agent/agent.service.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
|
||||
interface ChatRequest {
|
||||
conversationId?: string;
|
||||
content: string;
|
||||
}
|
||||
import { ChatRequestDto } from './chat.dto.js';
|
||||
|
||||
interface ChatResponse {
|
||||
conversationId: string;
|
||||
@@ -14,13 +22,18 @@ interface ChatResponse {
|
||||
}
|
||||
|
||||
@Controller('api/chat')
|
||||
@UseGuards(AuthGuard)
|
||||
export class ChatController {
|
||||
private readonly logger = new Logger(ChatController.name);
|
||||
|
||||
constructor(@Inject(AgentService) private readonly agentService: AgentService) {}
|
||||
|
||||
@Post()
|
||||
async chat(@Body() body: ChatRequest): Promise<ChatResponse> {
|
||||
@Throttle({ default: { limit: 10, ttl: 60_000 } })
|
||||
async chat(
|
||||
@Body() body: ChatRequestDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
): Promise<ChatResponse> {
|
||||
const conversationId = body.conversationId ?? uuid();
|
||||
|
||||
try {
|
||||
@@ -36,6 +49,8 @@ export class ChatController {
|
||||
throw new HttpException('Agent session unavailable', HttpStatus.SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
this.logger.debug(`Handling chat request for user=${user.id}, conversation=${conversationId}`);
|
||||
|
||||
let responseText = '';
|
||||
|
||||
const done = new Promise<void>((resolve, reject) => {
|
||||
|
||||
35
apps/gateway/src/chat/chat.dto.ts
Normal file
35
apps/gateway/src/chat/chat.dto.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { IsOptional, IsString, IsUUID, MaxLength } from 'class-validator';
|
||||
|
||||
/**
 * Body of POST /api/chat.
 *
 * conversationId: optional UUID of an existing conversation; omitted to start
 * a new one. content: the user message, capped at 10_000 characters.
 */
export class ChatRequestDto {
  @IsOptional()
  @IsUUID()
  conversationId?: string;

  @IsString()
  @MaxLength(10_000)
  content!: string;
}
|
||||
|
||||
/**
 * Payload of the WebSocket 'message' event on the /chat namespace.
 *
 * Mirrors ChatRequestDto (optional conversation UUID, 10_000-char content
 * cap) and adds optional model-selection fields.
 */
export class ChatSocketMessageDto {
  @IsOptional()
  @IsUUID()
  conversationId?: string;

  @IsString()
  @MaxLength(10_000)
  content!: string;

  // Optional model provider name (free-form string, length-capped).
  @IsOptional()
  @IsString()
  @MaxLength(255)
  provider?: string;

  // Optional model identifier within the provider.
  @IsOptional()
  @IsString()
  @MaxLength(255)
  modelId?: string;

  // Optional UUID of a preconfigured agent to handle the message.
  @IsOptional()
  @IsUUID()
  agentId?: string;
}
|
||||
30
apps/gateway/src/chat/chat.gateway-auth.ts
Normal file
30
apps/gateway/src/chat/chat.gateway-auth.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import type { IncomingHttpHeaders } from 'node:http';
|
||||
import { fromNodeHeaders } from 'better-auth/node';
|
||||
|
||||
export interface SocketSessionResult {
|
||||
session: unknown;
|
||||
user: { id: string };
|
||||
}
|
||||
|
||||
export interface SessionAuth {
|
||||
api: {
|
||||
getSession(context: { headers: Headers }): Promise<SocketSessionResult | null>;
|
||||
};
|
||||
}
|
||||
|
||||
export async function validateSocketSession(
|
||||
headers: IncomingHttpHeaders,
|
||||
auth: SessionAuth,
|
||||
): Promise<SocketSessionResult | null> {
|
||||
const sessionHeaders = fromNodeHeaders(headers);
|
||||
const result = await auth.api.getSession({ headers: sessionHeaders });
|
||||
|
||||
if (!result) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
session: result.session,
|
||||
user: { id: result.user.id },
|
||||
};
|
||||
}
|
||||
@@ -11,18 +11,34 @@ import {
|
||||
} from '@nestjs/websockets';
|
||||
import { Server, Socket } from 'socket.io';
|
||||
import type { AgentSessionEvent } from '@mariozechner/pi-coding-agent';
|
||||
import { AgentService } from '../agent/agent.service.js';
|
||||
import type { Auth } from '@mosaic/auth';
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import type { SetThinkingPayload, SlashCommandPayload, SystemReloadPayload } from '@mosaic/types';
|
||||
import { AgentService, type ConversationHistoryMessage } from '../agent/agent.service.js';
|
||||
import { AUTH } from '../auth/auth.tokens.js';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { CommandRegistryService } from '../commands/command-registry.service.js';
|
||||
import { CommandExecutorService } from '../commands/command-executor.service.js';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
import { ChatSocketMessageDto } from './chat.dto.js';
|
||||
import { validateSocketSession } from './chat.gateway-auth.js';
|
||||
|
||||
interface ChatMessage {
|
||||
conversationId?: string;
|
||||
content: string;
|
||||
provider?: string;
|
||||
modelId?: string;
|
||||
/** Per-client state tracking streaming accumulation for persistence. */
|
||||
interface ClientSession {
|
||||
conversationId: string;
|
||||
cleanup: () => void;
|
||||
/** Accumulated assistant response text for the current turn. */
|
||||
assistantText: string;
|
||||
/** Tool calls observed during the current turn. */
|
||||
toolCalls: Array<{ toolCallId: string; toolName: string; args: unknown; isError: boolean }>;
|
||||
/** Tool calls in-flight (started but not ended yet). */
|
||||
pendingToolCalls: Map<string, { toolName: string; args: unknown }>;
|
||||
}
|
||||
|
||||
@WebSocketGateway({
|
||||
cors: { origin: '*' },
|
||||
cors: {
|
||||
origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
|
||||
},
|
||||
namespace: '/chat',
|
||||
})
|
||||
export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect {
|
||||
@@ -30,19 +46,34 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
server!: Server;
|
||||
|
||||
private readonly logger = new Logger(ChatGateway.name);
|
||||
private readonly clientSessions = new Map<
|
||||
string,
|
||||
{ conversationId: string; cleanup: () => void }
|
||||
>();
|
||||
private readonly clientSessions = new Map<string, ClientSession>();
|
||||
|
||||
constructor(@Inject(AgentService) private readonly agentService: AgentService) {}
|
||||
constructor(
|
||||
@Inject(AgentService) private readonly agentService: AgentService,
|
||||
@Inject(AUTH) private readonly auth: Auth,
|
||||
@Inject(BRAIN) private readonly brain: Brain,
|
||||
@Inject(CommandRegistryService) private readonly commandRegistry: CommandRegistryService,
|
||||
@Inject(CommandExecutorService) private readonly commandExecutor: CommandExecutorService,
|
||||
) {}
|
||||
|
||||
afterInit(): void {
|
||||
this.logger.log('Chat WebSocket gateway initialized');
|
||||
}
|
||||
|
||||
handleConnection(client: Socket): void {
|
||||
async handleConnection(client: Socket): Promise<void> {
|
||||
const session = await validateSocketSession(client.handshake.headers, this.auth);
|
||||
if (!session) {
|
||||
this.logger.warn(`Rejected unauthenticated WebSocket client: ${client.id}`);
|
||||
client.disconnect();
|
||||
return;
|
||||
}
|
||||
|
||||
client.data.user = session.user;
|
||||
client.data.session = session.session;
|
||||
this.logger.log(`Client connected: ${client.id}`);
|
||||
|
||||
// Broadcast command manifest to the newly connected client
|
||||
client.emit('commands:manifest', { manifest: this.commandRegistry.getManifest() });
|
||||
}
|
||||
|
||||
handleDisconnect(client: Socket): void {
|
||||
@@ -58,9 +89,10 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
@SubscribeMessage('message')
|
||||
async handleMessage(
|
||||
@ConnectedSocket() client: Socket,
|
||||
@MessageBody() data: ChatMessage,
|
||||
@MessageBody() data: ChatSocketMessageDto,
|
||||
): Promise<void> {
|
||||
const conversationId = data.conversationId ?? uuid();
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id;
|
||||
|
||||
this.logger.log(`Message from ${client.id} in conversation ${conversationId}`);
|
||||
|
||||
@@ -68,10 +100,22 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
try {
|
||||
let agentSession = this.agentService.getSession(conversationId);
|
||||
if (!agentSession) {
|
||||
// When resuming an existing conversation, load prior messages to inject as context (M1-004)
|
||||
const conversationHistory = await this.loadConversationHistory(conversationId, userId);
|
||||
|
||||
agentSession = await this.agentService.createSession(conversationId, {
|
||||
provider: data.provider,
|
||||
modelId: data.modelId,
|
||||
agentConfigId: data.agentId,
|
||||
userId,
|
||||
conversationHistory: conversationHistory.length > 0 ? conversationHistory : undefined,
|
||||
});
|
||||
|
||||
if (conversationHistory.length > 0) {
|
||||
this.logger.log(
|
||||
`Loaded ${conversationHistory.length} prior messages for conversation=${conversationId}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
@@ -85,6 +129,33 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure conversation record exists in the DB before persisting messages
|
||||
if (userId) {
|
||||
await this.ensureConversation(conversationId, userId);
|
||||
}
|
||||
|
||||
// Persist the user message
|
||||
if (userId) {
|
||||
try {
|
||||
await this.brain.conversations.addMessage(
|
||||
{
|
||||
conversationId,
|
||||
role: 'user',
|
||||
content: data.content,
|
||||
metadata: {
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
},
|
||||
userId,
|
||||
);
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to persist user message for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Always clean up previous listener to prevent leak
|
||||
const existing = this.clientSessions.get(client.id);
|
||||
if (existing) {
|
||||
@@ -96,11 +167,32 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
this.relayEvent(client, conversationId, event);
|
||||
});
|
||||
|
||||
this.clientSessions.set(client.id, { conversationId, cleanup });
|
||||
this.clientSessions.set(client.id, {
|
||||
conversationId,
|
||||
cleanup,
|
||||
assistantText: '',
|
||||
toolCalls: [],
|
||||
pendingToolCalls: new Map(),
|
||||
});
|
||||
|
||||
// Track channel connection
|
||||
this.agentService.addChannel(conversationId, `websocket:${client.id}`);
|
||||
|
||||
// Send session info so the client knows the model/provider
|
||||
{
|
||||
const agentSession = this.agentService.getSession(conversationId);
|
||||
if (agentSession) {
|
||||
const piSession = agentSession.piSession;
|
||||
client.emit('session:info', {
|
||||
conversationId,
|
||||
provider: agentSession.provider,
|
||||
modelId: agentSession.modelId,
|
||||
thinkingLevel: piSession.thinkingLevel,
|
||||
availableThinkingLevels: piSession.getAvailableThinkingLevels(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Send acknowledgment
|
||||
client.emit('message:ack', { conversationId, messageId: uuid() });
|
||||
|
||||
@@ -119,6 +211,109 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
}
|
||||
}
|
||||
|
||||
@SubscribeMessage('set:thinking')
|
||||
handleSetThinking(
|
||||
@ConnectedSocket() client: Socket,
|
||||
@MessageBody() data: SetThinkingPayload,
|
||||
): void {
|
||||
const session = this.agentService.getSession(data.conversationId);
|
||||
if (!session) {
|
||||
client.emit('error', {
|
||||
conversationId: data.conversationId,
|
||||
error: 'No active session for this conversation.',
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const validLevels = session.piSession.getAvailableThinkingLevels();
|
||||
if (!validLevels.includes(data.level as never)) {
|
||||
client.emit('error', {
|
||||
conversationId: data.conversationId,
|
||||
error: `Invalid thinking level "${data.level}". Available: ${validLevels.join(', ')}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
session.piSession.setThinkingLevel(data.level as never);
|
||||
this.logger.log(
|
||||
`Thinking level set to "${data.level}" for conversation ${data.conversationId}`,
|
||||
);
|
||||
|
||||
client.emit('session:info', {
|
||||
conversationId: data.conversationId,
|
||||
provider: session.provider,
|
||||
modelId: session.modelId,
|
||||
thinkingLevel: session.piSession.thinkingLevel,
|
||||
availableThinkingLevels: session.piSession.getAvailableThinkingLevels(),
|
||||
});
|
||||
}
|
||||
|
||||
@SubscribeMessage('command:execute')
|
||||
async handleCommandExecute(
|
||||
@ConnectedSocket() client: Socket,
|
||||
@MessageBody() payload: SlashCommandPayload,
|
||||
): Promise<void> {
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id ?? 'unknown';
|
||||
const result = await this.commandExecutor.execute(payload, userId);
|
||||
client.emit('command:result', result);
|
||||
}
|
||||
|
||||
broadcastReload(payload: SystemReloadPayload): void {
|
||||
this.server.emit('system:reload', payload);
|
||||
this.logger.log('Broadcasted system:reload to all connected clients');
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure a conversation record exists in the DB.
|
||||
* Creates it if absent — safe to call concurrently since a duplicate insert
|
||||
* would fail on the PK constraint and be caught here.
|
||||
*/
|
||||
private async ensureConversation(conversationId: string, userId: string): Promise<void> {
|
||||
try {
|
||||
const existing = await this.brain.conversations.findById(conversationId, userId);
|
||||
if (!existing) {
|
||||
await this.brain.conversations.create({
|
||||
id: conversationId,
|
||||
userId,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to ensure conversation record for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load prior conversation messages from DB for context injection on session resume (M1-004).
|
||||
* Returns an empty array when no history exists, the conversation is not owned by the user,
|
||||
* or userId is not provided.
|
||||
*/
|
||||
private async loadConversationHistory(
|
||||
conversationId: string,
|
||||
userId: string | undefined,
|
||||
): Promise<ConversationHistoryMessage[]> {
|
||||
if (!userId) return [];
|
||||
|
||||
try {
|
||||
const messages = await this.brain.conversations.findMessages(conversationId, userId);
|
||||
if (messages.length === 0) return [];
|
||||
|
||||
return messages.map((msg) => ({
|
||||
role: msg.role as 'user' | 'assistant' | 'system',
|
||||
content: msg.content,
|
||||
createdAt: msg.createdAt,
|
||||
}));
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to load conversation history for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private relayEvent(client: Socket, conversationId: string, event: AgentSessionEvent): void {
|
||||
if (!client.connected) {
|
||||
this.logger.warn(
|
||||
@@ -128,17 +323,98 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
}
|
||||
|
||||
switch (event.type) {
|
||||
case 'agent_start':
|
||||
case 'agent_start': {
|
||||
// Reset accumulation buffers for the new turn
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.assistantText = '';
|
||||
cs.toolCalls = [];
|
||||
cs.pendingToolCalls.clear();
|
||||
}
|
||||
client.emit('agent:start', { conversationId });
|
||||
break;
|
||||
}
|
||||
|
||||
case 'agent_end':
|
||||
client.emit('agent:end', { conversationId });
|
||||
case 'agent_end': {
|
||||
// Gather usage stats from the Pi session
|
||||
const agentSession = this.agentService.getSession(conversationId);
|
||||
const piSession = agentSession?.piSession;
|
||||
const stats = piSession?.getSessionStats();
|
||||
const contextUsage = piSession?.getContextUsage();
|
||||
|
||||
const usagePayload = stats
|
||||
? {
|
||||
provider: agentSession?.provider ?? 'unknown',
|
||||
modelId: agentSession?.modelId ?? 'unknown',
|
||||
thinkingLevel: piSession?.thinkingLevel ?? 'off',
|
||||
tokens: stats.tokens,
|
||||
cost: stats.cost,
|
||||
context: {
|
||||
percent: contextUsage?.percent ?? null,
|
||||
window: contextUsage?.contextWindow ?? 0,
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
client.emit('agent:end', {
|
||||
conversationId,
|
||||
usage: usagePayload,
|
||||
});
|
||||
|
||||
// Persist the assistant message with metadata
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
const userId = (client.data.user as { id: string } | undefined)?.id;
|
||||
if (cs && userId && cs.assistantText.trim().length > 0) {
|
||||
const metadata: Record<string, unknown> = {
|
||||
timestamp: new Date().toISOString(),
|
||||
model: agentSession?.modelId ?? 'unknown',
|
||||
provider: agentSession?.provider ?? 'unknown',
|
||||
toolCalls: cs.toolCalls,
|
||||
};
|
||||
|
||||
if (stats?.tokens) {
|
||||
metadata['tokenUsage'] = {
|
||||
input: stats.tokens.input,
|
||||
output: stats.tokens.output,
|
||||
cacheRead: stats.tokens.cacheRead,
|
||||
cacheWrite: stats.tokens.cacheWrite,
|
||||
total: stats.tokens.total,
|
||||
};
|
||||
}
|
||||
|
||||
this.brain.conversations
|
||||
.addMessage(
|
||||
{
|
||||
conversationId,
|
||||
role: 'assistant',
|
||||
content: cs.assistantText,
|
||||
metadata,
|
||||
},
|
||||
userId,
|
||||
)
|
||||
.catch((err: unknown) => {
|
||||
this.logger.error(
|
||||
`Failed to persist assistant message for conversation=${conversationId}`,
|
||||
err instanceof Error ? err.stack : String(err),
|
||||
);
|
||||
});
|
||||
|
||||
// Reset accumulation
|
||||
cs.assistantText = '';
|
||||
cs.toolCalls = [];
|
||||
cs.pendingToolCalls.clear();
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'message_update': {
|
||||
const assistantEvent = event.assistantMessageEvent;
|
||||
if (assistantEvent.type === 'text_delta') {
|
||||
// Accumulate assistant text for persistence
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.assistantText += assistantEvent.delta;
|
||||
}
|
||||
client.emit('agent:text', {
|
||||
conversationId,
|
||||
text: assistantEvent.delta,
|
||||
@@ -152,15 +428,36 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
break;
|
||||
}
|
||||
|
||||
case 'tool_execution_start':
|
||||
case 'tool_execution_start': {
|
||||
// Track pending tool call for later recording
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
cs.pendingToolCalls.set(event.toolCallId, {
|
||||
toolName: event.toolName,
|
||||
args: event.args,
|
||||
});
|
||||
}
|
||||
client.emit('agent:tool:start', {
|
||||
conversationId,
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case 'tool_execution_end':
|
||||
case 'tool_execution_end': {
|
||||
// Finalise tool call record
|
||||
const cs = this.clientSessions.get(client.id);
|
||||
if (cs) {
|
||||
const pending = cs.pendingToolCalls.get(event.toolCallId);
|
||||
cs.toolCalls.push({
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
args: pending?.args ?? null,
|
||||
isError: event.isError,
|
||||
});
|
||||
cs.pendingToolCalls.delete(event.toolCallId);
|
||||
}
|
||||
client.emit('agent:tool:end', {
|
||||
conversationId,
|
||||
toolCallId: event.toolCallId,
|
||||
@@ -168,6 +465,7 @@ export class ChatGateway implements OnGatewayInit, OnGatewayConnection, OnGatewa
|
||||
isError: event.isError,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { forwardRef, Module } from '@nestjs/common';
|
||||
import { CommandsModule } from '../commands/commands.module.js';
|
||||
import { ChatGateway } from './chat.gateway.js';
|
||||
import { ChatController } from './chat.controller.js';
|
||||
|
||||
@Module({
|
||||
imports: [forwardRef(() => CommandsModule)],
|
||||
controllers: [ChatController],
|
||||
providers: [ChatGateway],
|
||||
exports: [ChatGateway],
|
||||
})
|
||||
export class ChatModule {}
|
||||
|
||||
213
apps/gateway/src/commands/command-executor-p8012.spec.ts
Normal file
213
apps/gateway/src/commands/command-executor-p8012.spec.ts
Normal file
@@ -0,0 +1,213 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { CommandExecutorService } from './command-executor.service.js';
|
||||
import type { SlashCommandPayload } from '@mosaic/types';
|
||||
|
||||
// Minimal mock implementations
|
||||
const mockRegistry = {
|
||||
getManifest: vi.fn(() => ({
|
||||
version: 1,
|
||||
commands: [
|
||||
{ name: 'provider', aliases: [], scope: 'agent', execution: 'hybrid', available: true },
|
||||
{ name: 'mission', aliases: [], scope: 'agent', execution: 'socket', available: true },
|
||||
{ name: 'agent', aliases: ['a'], scope: 'agent', execution: 'socket', available: true },
|
||||
{ name: 'prdy', aliases: [], scope: 'agent', execution: 'socket', available: true },
|
||||
{ name: 'tools', aliases: [], scope: 'agent', execution: 'socket', available: true },
|
||||
],
|
||||
skills: [],
|
||||
})),
|
||||
};
|
||||
|
||||
const mockAgentService = {
|
||||
getSession: vi.fn(() => undefined),
|
||||
};
|
||||
|
||||
const mockSystemOverride = {
|
||||
set: vi.fn(),
|
||||
get: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
renew: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSessionGC = {
|
||||
sweepOrphans: vi.fn(() => ({ orphanedSessions: 0, totalCleaned: [], duration: 0 })),
|
||||
};
|
||||
|
||||
const mockRedis = {
|
||||
set: vi.fn().mockResolvedValue('OK'),
|
||||
get: vi.fn(),
|
||||
del: vi.fn(),
|
||||
};
|
||||
|
||||
function buildService(): CommandExecutorService {
|
||||
return new CommandExecutorService(
|
||||
mockRegistry as never,
|
||||
mockAgentService as never,
|
||||
mockSystemOverride as never,
|
||||
mockSessionGC as never,
|
||||
mockRedis as never,
|
||||
null,
|
||||
null,
|
||||
);
|
||||
}
|
||||
|
||||
describe('CommandExecutorService — P8-012 commands', () => {
|
||||
let service: CommandExecutorService;
|
||||
const userId = 'user-123';
|
||||
const conversationId = 'conv-456';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
service = buildService();
|
||||
});
|
||||
|
||||
// /provider login — missing provider name
|
||||
it('/provider login with no provider name returns usage error', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'provider', args: 'login', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('Usage: /provider login');
|
||||
expect(result.command).toBe('provider');
|
||||
});
|
||||
|
||||
// /provider login anthropic — success with URL containing poll token
|
||||
it('/provider login <name> returns success with URL and poll token', async () => {
|
||||
const payload: SlashCommandPayload = {
|
||||
command: 'provider',
|
||||
args: 'login anthropic',
|
||||
conversationId,
|
||||
};
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('provider');
|
||||
expect(result.message).toContain('anthropic');
|
||||
expect(result.message).toContain('http');
|
||||
// data should contain loginUrl and pollToken
|
||||
expect(result.data).toBeDefined();
|
||||
const data = result.data as Record<string, unknown>;
|
||||
expect(typeof data['loginUrl']).toBe('string');
|
||||
expect(typeof data['pollToken']).toBe('string');
|
||||
expect(data['loginUrl'] as string).toContain('anthropic');
|
||||
expect(data['loginUrl'] as string).toContain(data['pollToken'] as string);
|
||||
// Verify Valkey was called
|
||||
expect(mockRedis.set).toHaveBeenCalledOnce();
|
||||
const [key, value, , ttl] = mockRedis.set.mock.calls[0] as [string, string, string, number];
|
||||
expect(key).toContain('mosaic:auth:poll:');
|
||||
const stored = JSON.parse(value) as { status: string; provider: string; userId: string };
|
||||
expect(stored.status).toBe('pending');
|
||||
expect(stored.provider).toBe('anthropic');
|
||||
expect(stored.userId).toBe(userId);
|
||||
expect(ttl).toBe(300);
|
||||
});
|
||||
|
||||
// /provider with no args — returns usage
|
||||
it('/provider with no args returns usage message', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'provider', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('Usage: /provider');
|
||||
});
|
||||
|
||||
// /provider list
|
||||
it('/provider list returns success', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'provider', args: 'list', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('provider');
|
||||
});
|
||||
|
||||
// /provider logout with no name — usage error
|
||||
it('/provider logout with no name returns error', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'provider', args: 'logout', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('Usage: /provider logout');
|
||||
});
|
||||
|
||||
// /provider unknown subcommand
|
||||
it('/provider unknown subcommand returns error', async () => {
|
||||
const payload: SlashCommandPayload = {
|
||||
command: 'provider',
|
||||
args: 'unknown',
|
||||
conversationId,
|
||||
};
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('Unknown subcommand');
|
||||
});
|
||||
|
||||
// /mission status
|
||||
it('/mission status returns stub message', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'mission', args: 'status', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('mission');
|
||||
expect(result.message).toContain('Mission status');
|
||||
});
|
||||
|
||||
// /mission with no args
|
||||
it('/mission with no args returns status stub', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'mission', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('Mission status');
|
||||
});
|
||||
|
||||
// /mission set <id>
|
||||
it('/mission set <id> returns confirmation', async () => {
|
||||
const payload: SlashCommandPayload = {
|
||||
command: 'mission',
|
||||
args: 'set my-mission-123',
|
||||
conversationId,
|
||||
};
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('my-mission-123');
|
||||
});
|
||||
|
||||
// /agent list
|
||||
it('/agent list returns stub message', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'agent', args: 'list', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('agent');
|
||||
expect(result.message).toContain('agent');
|
||||
});
|
||||
|
||||
// /agent with no args
|
||||
it('/agent with no args returns usage', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'agent', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('Usage: /agent');
|
||||
});
|
||||
|
||||
// /agent <id> — switch
|
||||
it('/agent <id> returns switch confirmation', async () => {
|
||||
const payload: SlashCommandPayload = {
|
||||
command: 'agent',
|
||||
args: 'my-agent-id',
|
||||
conversationId,
|
||||
};
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('my-agent-id');
|
||||
});
|
||||
|
||||
// /prdy
|
||||
it('/prdy returns PRD wizard message', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'prdy', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('prdy');
|
||||
expect(result.message).toContain('mosaic prdy');
|
||||
});
|
||||
|
||||
// /tools
|
||||
it('/tools returns tools stub message', async () => {
|
||||
const payload: SlashCommandPayload = { command: 'tools', conversationId };
|
||||
const result = await service.execute(payload, userId);
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.command).toBe('tools');
|
||||
expect(result.message).toContain('tools');
|
||||
});
|
||||
});
|
||||
373
apps/gateway/src/commands/command-executor.service.ts
Normal file
373
apps/gateway/src/commands/command-executor.service.ts
Normal file
@@ -0,0 +1,373 @@
|
||||
import { forwardRef, Inject, Injectable, Logger, Optional } from '@nestjs/common';
|
||||
import type { QueueHandle } from '@mosaic/queue';
|
||||
import type { SlashCommandPayload, SlashCommandResultPayload } from '@mosaic/types';
|
||||
import { AgentService } from '../agent/agent.service.js';
|
||||
import { ChatGateway } from '../chat/chat.gateway.js';
|
||||
import { SessionGCService } from '../gc/session-gc.service.js';
|
||||
import { SystemOverrideService } from '../preferences/system-override.service.js';
|
||||
import { ReloadService } from '../reload/reload.service.js';
|
||||
import { COMMANDS_REDIS } from './commands.tokens.js';
|
||||
import { CommandRegistryService } from './command-registry.service.js';
|
||||
|
||||
@Injectable()
|
||||
export class CommandExecutorService {
|
||||
private readonly logger = new Logger(CommandExecutorService.name);
|
||||
|
||||
constructor(
|
||||
@Inject(CommandRegistryService) private readonly registry: CommandRegistryService,
|
||||
@Inject(AgentService) private readonly agentService: AgentService,
|
||||
@Inject(SystemOverrideService) private readonly systemOverride: SystemOverrideService,
|
||||
@Inject(SessionGCService) private readonly sessionGC: SessionGCService,
|
||||
@Inject(COMMANDS_REDIS) private readonly redis: QueueHandle['redis'],
|
||||
@Optional()
|
||||
@Inject(forwardRef(() => ReloadService))
|
||||
private readonly reloadService: ReloadService | null,
|
||||
@Optional()
|
||||
@Inject(forwardRef(() => ChatGateway))
|
||||
private readonly chatGateway: ChatGateway | null,
|
||||
) {}
|
||||
|
||||
async execute(payload: SlashCommandPayload, userId: string): Promise<SlashCommandResultPayload> {
|
||||
const { command, args, conversationId } = payload;
|
||||
|
||||
const def = this.registry.getManifest().commands.find((c) => c.name === command);
|
||||
if (!def) {
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: false,
|
||||
message: `Unknown command: /${command}`,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
switch (command) {
|
||||
case 'model':
|
||||
return await this.handleModel(args ?? null, conversationId);
|
||||
case 'thinking':
|
||||
return await this.handleThinking(args ?? null, conversationId);
|
||||
case 'system':
|
||||
return await this.handleSystem(args ?? null, conversationId);
|
||||
case 'new':
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Start a new conversation by selecting New Conversation.',
|
||||
};
|
||||
case 'clear':
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Conversation display cleared.',
|
||||
};
|
||||
case 'compact':
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Context compaction requested.',
|
||||
};
|
||||
case 'retry':
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Retry last message requested.',
|
||||
};
|
||||
case 'gc': {
|
||||
// Admin-only: system-wide GC sweep across all sessions
|
||||
const result = await this.sessionGC.sweepOrphans();
|
||||
return {
|
||||
command: 'gc',
|
||||
success: true,
|
||||
message: `GC sweep complete: ${result.orphanedSessions} orphaned sessions cleaned in ${result.duration}ms.`,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
case 'agent':
|
||||
return await this.handleAgent(args ?? null, conversationId);
|
||||
case 'provider':
|
||||
return await this.handleProvider(args ?? null, userId, conversationId);
|
||||
case 'mission':
|
||||
return await this.handleMission(args ?? null, conversationId, userId);
|
||||
case 'prdy':
|
||||
return {
|
||||
command: 'prdy',
|
||||
success: true,
|
||||
message:
|
||||
'PRD wizard: run `mosaic prdy` in your project workspace to create or update a PRD.',
|
||||
conversationId,
|
||||
};
|
||||
case 'tools':
|
||||
return await this.handleTools(conversationId, userId);
|
||||
case 'reload': {
|
||||
if (!this.reloadService) {
|
||||
return {
|
||||
command: 'reload',
|
||||
conversationId,
|
||||
success: false,
|
||||
message: 'ReloadService is not available.',
|
||||
};
|
||||
}
|
||||
const reloadResult = await this.reloadService.reload('command');
|
||||
this.chatGateway?.broadcastReload(reloadResult);
|
||||
return {
|
||||
command: 'reload',
|
||||
success: true,
|
||||
message: reloadResult.message,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
default:
|
||||
return {
|
||||
command,
|
||||
conversationId,
|
||||
success: false,
|
||||
message: `Command /${command} is not yet implemented.`,
|
||||
};
|
||||
}
|
||||
} catch (err) {
|
||||
this.logger.error(`Command /${command} failed: ${err}`);
|
||||
return { command, conversationId, success: false, message: String(err) };
|
||||
}
|
||||
}
|
||||
|
||||
private async handleModel(
|
||||
args: string | null,
|
||||
conversationId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
if (!args) {
|
||||
return {
|
||||
command: 'model',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Usage: /model <model-name>',
|
||||
};
|
||||
}
|
||||
// Update agent session model if session is active
|
||||
// For now, acknowledge the request — full wiring done in P8-012
|
||||
const session = this.agentService.getSession(conversationId);
|
||||
if (!session) {
|
||||
return {
|
||||
command: 'model',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: `Model switch to "${args}" requested. No active session for this conversation.`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
command: 'model',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: `Model switch to "${args}" requested.`,
|
||||
};
|
||||
}
|
||||
|
||||
private async handleThinking(
|
||||
args: string | null,
|
||||
conversationId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
const level = args?.toLowerCase();
|
||||
if (!level || !['none', 'low', 'medium', 'high', 'auto'].includes(level)) {
|
||||
return {
|
||||
command: 'thinking',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Usage: /thinking <none|low|medium|high|auto>',
|
||||
};
|
||||
}
|
||||
return {
|
||||
command: 'thinking',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: `Thinking level set to "${level}".`,
|
||||
};
|
||||
}
|
||||
|
||||
private async handleSystem(
|
||||
args: string | null,
|
||||
conversationId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
if (!args || args.trim().length === 0) {
|
||||
// Clear the override when called with no args
|
||||
await this.systemOverride.clear(conversationId);
|
||||
return {
|
||||
command: 'system',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: 'Session system prompt override cleared.',
|
||||
};
|
||||
}
|
||||
|
||||
await this.systemOverride.set(conversationId, args.trim());
|
||||
return {
|
||||
command: 'system',
|
||||
conversationId,
|
||||
success: true,
|
||||
message: `Session system prompt override set (expires in 5 minutes of inactivity).`,
|
||||
};
|
||||
}
|
||||
|
||||
private async handleAgent(
|
||||
args: string | null,
|
||||
conversationId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
if (!args) {
|
||||
return {
|
||||
command: 'agent',
|
||||
success: true,
|
||||
message: 'Usage: /agent <agent-id> to switch, or /agent list to see available agents.',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
if (args === 'list') {
|
||||
return {
|
||||
command: 'agent',
|
||||
success: true,
|
||||
message: 'Agent listing: use the web dashboard for full agent management.',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
// Switch agent — stub for now (full implementation in P8-015)
|
||||
return {
|
||||
command: 'agent',
|
||||
success: true,
|
||||
message: `Agent switch to "${args}" requested. Restart conversation to apply.`,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
private async handleProvider(
|
||||
args: string | null,
|
||||
userId: string,
|
||||
conversationId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
if (!args) {
|
||||
return {
|
||||
command: 'provider',
|
||||
success: true,
|
||||
message: 'Usage: /provider list | /provider login <name> | /provider logout <name>',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
const spaceIdx = args.indexOf(' ');
|
||||
const subcommand = spaceIdx >= 0 ? args.slice(0, spaceIdx) : args;
|
||||
const providerName = spaceIdx >= 0 ? args.slice(spaceIdx + 1).trim() : '';
|
||||
|
||||
switch (subcommand) {
|
||||
case 'list':
|
||||
return {
|
||||
command: 'provider',
|
||||
success: true,
|
||||
message: 'Use the web dashboard to manage providers.',
|
||||
conversationId,
|
||||
};
|
||||
|
||||
case 'login': {
|
||||
if (!providerName) {
|
||||
return {
|
||||
command: 'provider',
|
||||
success: false,
|
||||
message: 'Usage: /provider login <provider-name>',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
const pollToken = crypto.randomUUID();
|
||||
const key = `mosaic:auth:poll:${pollToken}`;
|
||||
// Store pending state in Valkey (TTL 5 minutes)
|
||||
await this.redis.set(
|
||||
key,
|
||||
JSON.stringify({ status: 'pending', provider: providerName, userId }),
|
||||
'EX',
|
||||
300,
|
||||
);
|
||||
// In production this would construct an OAuth URL
|
||||
const loginUrl = `${process.env['MOSAIC_BASE_URL'] ?? 'http://localhost:3000'}/auth/provider/${providerName}?token=${pollToken}`;
|
||||
return {
|
||||
command: 'provider',
|
||||
success: true,
|
||||
message: `Open this URL to authenticate with ${providerName}:\n${loginUrl}`,
|
||||
conversationId,
|
||||
data: { loginUrl, pollToken, provider: providerName },
|
||||
};
|
||||
}
|
||||
|
||||
case 'logout': {
|
||||
if (!providerName) {
|
||||
return {
|
||||
command: 'provider',
|
||||
success: false,
|
||||
message: 'Usage: /provider logout <provider-name>',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
return {
|
||||
command: 'provider',
|
||||
success: true,
|
||||
message: `Logout from ${providerName}: use the web dashboard to revoke provider tokens.`,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
default:
|
||||
return {
|
||||
command: 'provider',
|
||||
success: false,
|
||||
message: `Unknown subcommand: ${subcommand}. Use list, login, or logout.`,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async handleMission(
|
||||
args: string | null,
|
||||
conversationId: string,
|
||||
_userId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
if (!args || args === 'status') {
|
||||
// TODO: fetch active mission from DB when MissionsService is available
|
||||
return {
|
||||
command: 'mission',
|
||||
success: true,
|
||||
message: 'Mission status: use the web dashboard for full mission management.',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
if (args.startsWith('set ')) {
|
||||
const missionId = args.slice(4).trim();
|
||||
return {
|
||||
command: 'mission',
|
||||
success: true,
|
||||
message: `Mission set to ${missionId}. Session context updated.`,
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
command: 'mission',
|
||||
success: true,
|
||||
message: 'Usage: /mission [status|set <id>|list|tasks]',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
|
||||
private async handleTools(
|
||||
conversationId: string,
|
||||
_userId: string,
|
||||
): Promise<SlashCommandResultPayload> {
|
||||
// TODO: fetch tool list from active agent session
|
||||
return {
|
||||
command: 'tools',
|
||||
success: true,
|
||||
message:
|
||||
'Available tools depend on the active agent configuration. Use the web dashboard to configure tool access.',
|
||||
conversationId,
|
||||
};
|
||||
}
|
||||
}
|
||||
53
apps/gateway/src/commands/command-registry.service.spec.ts
Normal file
53
apps/gateway/src/commands/command-registry.service.spec.ts
Normal file
@@ -0,0 +1,53 @@
|
||||
import { describe, it, expect, beforeEach } from 'vitest';
import { CommandRegistryService } from './command-registry.service.js';
import type { CommandDef } from '@mosaic/types';

// Minimal valid command definition reused across the tests below.
const mockCmd: CommandDef = {
  name: 'test',
  description: 'Test command',
  aliases: ['t'],
  scope: 'core',
  execution: 'local',
  available: true,
};

describe('CommandRegistryService', () => {
  let service: CommandRegistryService;

  beforeEach(() => {
    // Fresh instance per test; onModuleInit is NOT called here, so the
    // registry starts empty unless a test seeds it explicitly.
    service = new CommandRegistryService();
  });

  it('starts with empty manifest', () => {
    expect(service.getManifest().commands).toHaveLength(0);
  });

  it('registers a command', () => {
    service.registerCommand(mockCmd);
    expect(service.getManifest().commands).toHaveLength(1);
  });

  it('updates existing command by name', () => {
    // Re-registering the same name must replace the entry, not append.
    service.registerCommand(mockCmd);
    service.registerCommand({ ...mockCmd, description: 'Updated' });
    expect(service.getManifest().commands).toHaveLength(1);
    expect(service.getManifest().commands[0]?.description).toBe('Updated');
  });

  it('onModuleInit registers core commands', () => {
    service.onModuleInit();
    const manifest = service.getManifest();
    expect(manifest.commands.length).toBeGreaterThan(5);
    expect(manifest.commands.some((c) => c.name === 'model')).toBe(true);
    expect(manifest.commands.some((c) => c.name === 'help')).toBe(true);
  });

  it('manifest includes skills array', () => {
    const manifest = service.getManifest();
    expect(Array.isArray(manifest.skills)).toBe(true);
  });

  it('manifest version is 1', () => {
    expect(service.getManifest().version).toBe(1);
  });
});
|
||||
273
apps/gateway/src/commands/command-registry.service.ts
Normal file
273
apps/gateway/src/commands/command-registry.service.ts
Normal file
@@ -0,0 +1,273 @@
|
||||
import { Injectable, type OnModuleInit } from '@nestjs/common';
|
||||
import type { CommandDef, CommandManifest } from '@mosaic/types';
|
||||
|
||||
@Injectable()
|
||||
export class CommandRegistryService implements OnModuleInit {
|
||||
private readonly commands: CommandDef[] = [];
|
||||
|
||||
registerCommand(def: CommandDef): void {
|
||||
const existing = this.commands.findIndex((c) => c.name === def.name);
|
||||
if (existing >= 0) {
|
||||
this.commands[existing] = def;
|
||||
} else {
|
||||
this.commands.push(def);
|
||||
}
|
||||
}
|
||||
|
||||
registerCommands(defs: CommandDef[]): void {
|
||||
for (const def of defs) {
|
||||
this.registerCommand(def);
|
||||
}
|
||||
}
|
||||
|
||||
getManifest(): CommandManifest {
|
||||
return {
|
||||
version: 1,
|
||||
commands: [...this.commands],
|
||||
skills: [],
|
||||
};
|
||||
}
|
||||
|
||||
onModuleInit(): void {
|
||||
this.registerCommands([
|
||||
{
|
||||
name: 'model',
|
||||
description: 'Switch the active model',
|
||||
aliases: ['m'],
|
||||
args: [
|
||||
{
|
||||
name: 'model-name',
|
||||
type: 'string',
|
||||
optional: false,
|
||||
description: 'Model name to switch to',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'thinking',
|
||||
description: 'Set thinking level (none/low/medium/high/auto)',
|
||||
aliases: ['t'],
|
||||
args: [
|
||||
{
|
||||
name: 'level',
|
||||
type: 'enum',
|
||||
optional: false,
|
||||
values: ['none', 'low', 'medium', 'high', 'auto'],
|
||||
description: 'Thinking level',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'new',
|
||||
description: 'Start a new conversation',
|
||||
aliases: ['n'],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'clear',
|
||||
description: 'Clear conversation context and GC session artifacts',
|
||||
aliases: [],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'compact',
|
||||
description: 'Request context compaction',
|
||||
aliases: [],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'retry',
|
||||
description: 'Retry the last message',
|
||||
aliases: [],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'rename',
|
||||
description: 'Rename current conversation',
|
||||
aliases: [],
|
||||
args: [
|
||||
{ name: 'name', type: 'string', optional: false, description: 'New conversation name' },
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'rest',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'history',
|
||||
description: 'Show conversation history',
|
||||
aliases: [],
|
||||
args: [
|
||||
{
|
||||
name: 'limit',
|
||||
type: 'string',
|
||||
optional: true,
|
||||
description: 'Number of messages to show',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'rest',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'export',
|
||||
description: 'Export conversation to markdown or JSON',
|
||||
aliases: [],
|
||||
args: [
|
||||
{
|
||||
name: 'format',
|
||||
type: 'enum',
|
||||
optional: true,
|
||||
values: ['md', 'json'],
|
||||
description: 'Export format',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'rest',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'preferences',
|
||||
description: 'View or set user preferences',
|
||||
aliases: ['pref'],
|
||||
args: [
|
||||
{
|
||||
name: 'action',
|
||||
type: 'enum',
|
||||
optional: true,
|
||||
values: ['show', 'set', 'reset'],
|
||||
description: 'Action to perform',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'rest',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'system',
|
||||
description: 'Set session-scoped system prompt override',
|
||||
aliases: [],
|
||||
args: [
|
||||
{
|
||||
name: 'override',
|
||||
type: 'string',
|
||||
optional: false,
|
||||
description: 'System prompt text to inject for this session',
|
||||
},
|
||||
],
|
||||
scope: 'core',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'status',
|
||||
description: 'Show session and connection status',
|
||||
aliases: ['s'],
|
||||
scope: 'core',
|
||||
execution: 'hybrid',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'help',
|
||||
description: 'Show available commands',
|
||||
aliases: ['h'],
|
||||
scope: 'core',
|
||||
execution: 'local',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'gc',
|
||||
description: 'Trigger garbage collection sweep (admin only — system-wide)',
|
||||
aliases: [],
|
||||
scope: 'admin',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'agent',
|
||||
description: 'Switch or list available agents',
|
||||
aliases: ['a'],
|
||||
args: [
|
||||
{
|
||||
name: 'args',
|
||||
type: 'string',
|
||||
optional: true,
|
||||
description: 'list or <agent-id>',
|
||||
},
|
||||
],
|
||||
scope: 'agent',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'provider',
|
||||
description: 'Manage LLM providers (list/login/logout)',
|
||||
aliases: [],
|
||||
args: [
|
||||
{
|
||||
name: 'args',
|
||||
type: 'string',
|
||||
optional: true,
|
||||
description: 'list | login <name> | logout <name>',
|
||||
},
|
||||
],
|
||||
scope: 'agent',
|
||||
execution: 'hybrid',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'mission',
|
||||
description: 'View or set active mission',
|
||||
aliases: [],
|
||||
args: [
|
||||
{
|
||||
name: 'args',
|
||||
type: 'string',
|
||||
optional: true,
|
||||
description: 'status | set <id> | list | tasks',
|
||||
},
|
||||
],
|
||||
scope: 'agent',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'prdy',
|
||||
description: 'Launch PRD wizard',
|
||||
aliases: [],
|
||||
scope: 'agent',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'tools',
|
||||
description: 'List available agent tools',
|
||||
aliases: [],
|
||||
scope: 'agent',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
{
|
||||
name: 'reload',
|
||||
description: 'Soft-reload gateway plugins and command manifest (admin)',
|
||||
aliases: [],
|
||||
scope: 'admin',
|
||||
execution: 'socket',
|
||||
available: true,
|
||||
},
|
||||
]);
|
||||
}
|
||||
}
|
||||
253
apps/gateway/src/commands/commands.integration.spec.ts
Normal file
253
apps/gateway/src/commands/commands.integration.spec.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
/**
 * Integration tests for the gateway command system (P8-019)
 *
 * Covers:
 * - CommandRegistryService.getManifest() returns 12+ core commands
 * - All core commands have correct execution types
 * - Alias resolution works for all defined aliases
 * - CommandExecutorService routes known/unknown commands correctly
 * - /gc handler calls SessionGCService.sweepOrphans
 * - /system handler calls SystemOverrideService.set
 * - Unknown command returns descriptive error
 */
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { CommandRegistryService } from './command-registry.service.js';
import { CommandExecutorService } from './command-executor.service.js';
import type { SlashCommandPayload } from '@mosaic/types';

// ─── Mocks ───────────────────────────────────────────────────────────────────

// Collaborator doubles for the executor's constructor dependencies.
const mockAgentService = {
  getSession: vi.fn(() => undefined),
};

const mockSystemOverride = {
  set: vi.fn().mockResolvedValue(undefined),
  get: vi.fn().mockResolvedValue(null),
  clear: vi.fn().mockResolvedValue(undefined),
  renew: vi.fn().mockResolvedValue(undefined),
};

// sweepOrphans result shape drives the "/gc" message assertions below.
const mockSessionGC = {
  sweepOrphans: vi.fn().mockResolvedValue({ orphanedSessions: 3, totalCleaned: [], duration: 12 }),
};

const mockRedis = {
  set: vi.fn().mockResolvedValue('OK'),
  get: vi.fn().mockResolvedValue(null),
  del: vi.fn().mockResolvedValue(0),
  keys: vi.fn().mockResolvedValue([]),
};

// ─── Helpers ─────────────────────────────────────────────────────────────────

// Registry pre-seeded with the core command set.
function buildRegistry(): CommandRegistryService {
  const svc = new CommandRegistryService();
  svc.onModuleInit(); // seed core commands
  return svc;
}

// NOTE(review): the `as never` casts bypass constructor typing; acceptable in
// tests, but keep the mock shapes in sync with the real service interfaces.
function buildExecutor(registry: CommandRegistryService): CommandExecutorService {
  return new CommandExecutorService(
    registry as never,
    mockAgentService as never,
    mockSystemOverride as never,
    mockSessionGC as never,
    mockRedis as never,
    null, // reloadService (optional)
    null, // chatGateway (optional)
  );
}

// ─── Registry Tests ───────────────────────────────────────────────────────────

describe('CommandRegistryService — integration', () => {
  let registry: CommandRegistryService;

  beforeEach(() => {
    registry = buildRegistry();
  });

  it('getManifest() returns 12 or more core commands after onModuleInit', () => {
    const manifest = registry.getManifest();
    expect(manifest.commands.length).toBeGreaterThanOrEqual(12);
  });

  it('manifest version is 1', () => {
    expect(registry.getManifest().version).toBe(1);
  });

  it('manifest.skills is an array', () => {
    expect(Array.isArray(registry.getManifest().skills)).toBe(true);
  });

  it('all commands have required fields: name, description, execution, scope, available', () => {
    for (const cmd of registry.getManifest().commands) {
      expect(typeof cmd.name).toBe('string');
      expect(typeof cmd.description).toBe('string');
      expect(['local', 'socket', 'rest', 'hybrid']).toContain(cmd.execution);
      expect(['core', 'agent', 'admin']).toContain(cmd.scope);
      expect(typeof cmd.available).toBe('boolean');
    }
  });

  // Execution type verification for core commands
  const expectedExecutionTypes: Record<string, string> = {
    model: 'socket',
    thinking: 'socket',
    new: 'socket',
    clear: 'socket',
    compact: 'socket',
    retry: 'socket',
    rename: 'rest',
    history: 'rest',
    export: 'rest',
    preferences: 'rest',
    system: 'socket',
    help: 'local',
    gc: 'socket',
    agent: 'socket',
    provider: 'hybrid',
    mission: 'socket',
    prdy: 'socket',
    tools: 'socket',
    reload: 'socket',
  };

  // One parameterized test per command/execution pair.
  for (const [name, expectedExecution] of Object.entries(expectedExecutionTypes)) {
    it(`command "${name}" has execution type "${expectedExecution}"`, () => {
      const cmd = registry.getManifest().commands.find((c) => c.name === name);
      expect(cmd, `command "${name}" not found`).toBeDefined();
      expect(cmd!.execution).toBe(expectedExecution);
    });
  }

  // Alias resolution checks
  const expectedAliases: Array<[string, string]> = [
    ['m', 'model'],
    ['t', 'thinking'],
    ['n', 'new'],
    ['a', 'agent'],
    ['s', 'status'],
    ['h', 'help'],
    ['pref', 'preferences'],
  ];

  for (const [alias, commandName] of expectedAliases) {
    it(`alias "/${alias}" resolves to command "${commandName}" via aliases array`, () => {
      const cmd = registry
        .getManifest()
        .commands.find((c) => c.name === commandName || c.aliases?.includes(alias));
      expect(cmd, `command with alias "${alias}" not found`).toBeDefined();
    });
  }
});

// ─── Executor Tests ───────────────────────────────────────────────────────────

describe('CommandExecutorService — integration', () => {
  let registry: CommandRegistryService;
  let executor: CommandExecutorService;
  const userId = 'user-integ-001';
  const conversationId = 'conv-integ-001';

  beforeEach(() => {
    // Reset mock call history so per-test call assertions stay isolated.
    vi.clearAllMocks();
    registry = buildRegistry();
    executor = buildExecutor(registry);
  });

  // Unknown command returns error
  it('unknown command returns success:false with descriptive message', async () => {
    const payload: SlashCommandPayload = { command: 'nonexistent', conversationId };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(false);
    expect(result.message).toContain('nonexistent');
    expect(result.command).toBe('nonexistent');
  });

  // /gc handler calls SessionGCService.sweepOrphans (admin-only, no userId arg)
  it('/gc calls SessionGCService.sweepOrphans without arguments', async () => {
    const payload: SlashCommandPayload = { command: 'gc', conversationId };
    const result = await executor.execute(payload, userId);
    expect(mockSessionGC.sweepOrphans).toHaveBeenCalledWith();
    expect(result.success).toBe(true);
    expect(result.message).toContain('GC sweep complete');
    expect(result.message).toContain('3 orphaned sessions');
  });

  // /system with args calls SystemOverrideService.set
  it('/system with text calls SystemOverrideService.set', async () => {
    const override = 'You are a helpful assistant.';
    const payload: SlashCommandPayload = { command: 'system', args: override, conversationId };
    const result = await executor.execute(payload, userId);
    expect(mockSystemOverride.set).toHaveBeenCalledWith(conversationId, override);
    expect(result.success).toBe(true);
    expect(result.message).toContain('override set');
  });

  // /system with no args clears the override
  it('/system with no args calls SystemOverrideService.clear', async () => {
    const payload: SlashCommandPayload = { command: 'system', conversationId };
    const result = await executor.execute(payload, userId);
    expect(mockSystemOverride.clear).toHaveBeenCalledWith(conversationId);
    expect(result.success).toBe(true);
    expect(result.message).toContain('cleared');
  });

  // /model with model name returns success
  it('/model with a model name returns success', async () => {
    const payload: SlashCommandPayload = {
      command: 'model',
      args: 'claude-3-opus',
      conversationId,
    };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(true);
    expect(result.command).toBe('model');
    expect(result.message).toContain('claude-3-opus');
  });

  // /thinking with valid level returns success
  it('/thinking with valid level returns success', async () => {
    const payload: SlashCommandPayload = { command: 'thinking', args: 'high', conversationId };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(true);
    expect(result.message).toContain('high');
  });

  // /thinking with invalid level returns usage message
  // (note: still success:true — the handler replies with usage text rather
  // than failing; this pins that behavior)
  it('/thinking with invalid level returns usage message', async () => {
    const payload: SlashCommandPayload = { command: 'thinking', args: 'invalid', conversationId };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(true);
    expect(result.message).toContain('Usage:');
  });

  // /new command returns success
  it('/new returns success', async () => {
    const payload: SlashCommandPayload = { command: 'new', conversationId };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(true);
    expect(result.command).toBe('new');
  });

  // /reload without reloadService returns failure
  it('/reload without ReloadService returns failure', async () => {
    const payload: SlashCommandPayload = { command: 'reload', conversationId };
    const result = await executor.execute(payload, userId);
    expect(result.success).toBe(false);
    expect(result.message).toContain('ReloadService');
  });

  // Commands not yet fully implemented return a fallback response
  const stubCommands = ['clear', 'compact', 'retry'];
  for (const cmd of stubCommands) {
    it(`/${cmd} returns success (stub)`, async () => {
      const payload: SlashCommandPayload = { command: cmd, conversationId };
      const result = await executor.execute(payload, userId);
      expect(result.success).toBe(true);
      expect(result.command).toBe(cmd);
    });
  }
});
|
||||
37
apps/gateway/src/commands/commands.module.ts
Normal file
37
apps/gateway/src/commands/commands.module.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { forwardRef, Inject, Module, type OnApplicationShutdown } from '@nestjs/common';
import { createQueue, type QueueHandle } from '@mosaic/queue';
import { ChatModule } from '../chat/chat.module.js';
import { GCModule } from '../gc/gc.module.js';
import { ReloadModule } from '../reload/reload.module.js';
import { CommandExecutorService } from './command-executor.service.js';
import { CommandRegistryService } from './command-registry.service.js';
import { COMMANDS_REDIS } from './commands.tokens.js';

// Internal DI token for the queue handle this module owns (and closes on shutdown).
const COMMANDS_QUEUE_HANDLE = 'COMMANDS_QUEUE_HANDLE';

/**
 * Wires the slash-command registry and executor into the gateway, giving the
 * executor a Redis connection via the module-owned queue handle.
 */
@Module({
  // forwardRef breaks the circular imports with Reload/Chat modules.
  imports: [GCModule, forwardRef(() => ReloadModule), forwardRef(() => ChatModule)],
  providers: [
    {
      provide: COMMANDS_QUEUE_HANDLE,
      useFactory: (): QueueHandle => {
        return createQueue();
      },
    },
    {
      // Expose only the handle's Redis connection under COMMANDS_REDIS so
      // consumers do not depend on the full queue handle.
      provide: COMMANDS_REDIS,
      useFactory: (handle: QueueHandle) => handle.redis,
      inject: [COMMANDS_QUEUE_HANDLE],
    },
    CommandRegistryService,
    CommandExecutorService,
  ],
  exports: [CommandRegistryService, CommandExecutorService],
})
export class CommandsModule implements OnApplicationShutdown {
  constructor(@Inject(COMMANDS_QUEUE_HANDLE) private readonly handle: QueueHandle) {}

  // Close the owned queue handle on shutdown; errors are deliberately
  // swallowed so teardown cannot fail because Redis is already gone.
  async onApplicationShutdown(): Promise<void> {
    await this.handle.close().catch(() => {});
  }
}
|
||||
1
apps/gateway/src/commands/commands.tokens.ts
Normal file
1
apps/gateway/src/commands/commands.tokens.ts
Normal file
@@ -0,0 +1 @@
|
||||
// DI token for the Redis connection provided by CommandsModule (handle.redis).
export const COMMANDS_REDIS = 'COMMANDS_REDIS';
|
||||
@@ -1,7 +1,9 @@
|
||||
import {
|
||||
BadRequestException,
|
||||
Body,
|
||||
Controller,
|
||||
Delete,
|
||||
ForbiddenException,
|
||||
Get,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
@@ -10,16 +12,18 @@ import {
|
||||
Param,
|
||||
Patch,
|
||||
Post,
|
||||
Query,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import type {
|
||||
import {
|
||||
CreateConversationDto,
|
||||
UpdateConversationDto,
|
||||
SendMessageDto,
|
||||
SearchMessagesDto,
|
||||
} from './conversations.dto.js';
|
||||
|
||||
@Controller('api/conversations')
|
||||
@@ -32,9 +36,19 @@ export class ConversationsController {
|
||||
return this.brain.conversations.findAll(user.id);
|
||||
}
|
||||
|
||||
@Get('search')
|
||||
async search(@Query() dto: SearchMessagesDto, @CurrentUser() user: { id: string }) {
|
||||
if (!dto.q || dto.q.trim().length === 0) {
|
||||
throw new BadRequestException('Query parameter "q" is required and must not be empty');
|
||||
}
|
||||
const limit = dto.limit ?? 20;
|
||||
const offset = dto.offset ?? 0;
|
||||
return this.brain.conversations.searchMessages(user.id, dto.q.trim(), limit, offset);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
async findOne(@Param('id') id: string) {
|
||||
const conversation = await this.brain.conversations.findById(id);
|
||||
async findOne(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
const conversation = await this.brain.conversations.findById(id, user.id);
|
||||
if (!conversation) throw new NotFoundException('Conversation not found');
|
||||
return conversation;
|
||||
}
|
||||
@@ -49,35 +63,47 @@ export class ConversationsController {
|
||||
}
|
||||
|
||||
@Patch(':id')
|
||||
async update(@Param('id') id: string, @Body() dto: UpdateConversationDto) {
|
||||
const conversation = await this.brain.conversations.update(id, dto);
|
||||
async update(
|
||||
@Param('id') id: string,
|
||||
@Body() dto: UpdateConversationDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const conversation = await this.brain.conversations.update(id, user.id, dto);
|
||||
if (!conversation) throw new NotFoundException('Conversation not found');
|
||||
return conversation;
|
||||
}
|
||||
|
||||
@Delete(':id')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async remove(@Param('id') id: string) {
|
||||
const deleted = await this.brain.conversations.remove(id);
|
||||
async remove(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
const deleted = await this.brain.conversations.remove(id, user.id);
|
||||
if (!deleted) throw new NotFoundException('Conversation not found');
|
||||
}
|
||||
|
||||
@Get(':id/messages')
|
||||
async listMessages(@Param('id') id: string) {
|
||||
const conversation = await this.brain.conversations.findById(id);
|
||||
async listMessages(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
// Verify ownership explicitly to return a clear 404 rather than an empty list.
|
||||
const conversation = await this.brain.conversations.findById(id, user.id);
|
||||
if (!conversation) throw new NotFoundException('Conversation not found');
|
||||
return this.brain.conversations.findMessages(id);
|
||||
return this.brain.conversations.findMessages(id, user.id);
|
||||
}
|
||||
|
||||
@Post(':id/messages')
|
||||
async addMessage(@Param('id') id: string, @Body() dto: SendMessageDto) {
|
||||
const conversation = await this.brain.conversations.findById(id);
|
||||
if (!conversation) throw new NotFoundException('Conversation not found');
|
||||
return this.brain.conversations.addMessage({
|
||||
conversationId: id,
|
||||
role: dto.role,
|
||||
content: dto.content,
|
||||
metadata: dto.metadata,
|
||||
});
|
||||
async addMessage(
|
||||
@Param('id') id: string,
|
||||
@Body() dto: SendMessageDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const message = await this.brain.conversations.addMessage(
|
||||
{
|
||||
conversationId: id,
|
||||
role: dto.role,
|
||||
content: dto.content,
|
||||
metadata: dto.metadata,
|
||||
},
|
||||
user.id,
|
||||
);
|
||||
if (!message) throw new ForbiddenException('Conversation not found or access denied');
|
||||
return message;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,71 @@
|
||||
export interface CreateConversationDto {
|
||||
import {
|
||||
IsBoolean,
|
||||
IsIn,
|
||||
IsInt,
|
||||
IsObject,
|
||||
IsOptional,
|
||||
IsString,
|
||||
IsUUID,
|
||||
Max,
|
||||
MaxLength,
|
||||
Min,
|
||||
} from 'class-validator';
|
||||
import { Type } from 'class-transformer';
|
||||
|
||||
/** Query DTO for GET /api/conversations/search. */
export class SearchMessagesDto {
  // Search text; required, capped at 500 chars. Emptiness is re-checked in
  // the controller (MaxLength alone does not reject '').
  @IsString()
  @MaxLength(500)
  q!: string;

  // Page size; coerced from the query string, clamped to 1–100, default 20.
  @IsOptional()
  @Type(() => Number)
  @IsInt()
  @Min(1)
  @Max(100)
  limit?: number = 20;

  // Pagination offset; coerced from the query string, non-negative, default 0.
  @IsOptional()
  @Type(() => Number)
  @IsInt()
  @Min(0)
  offset?: number = 0;
}
|
||||
|
||||
/** Body DTO for POST /api/conversations. */
export class CreateConversationDto {
  // Optional display title, max 255 chars.
  @IsOptional()
  @IsString()
  @MaxLength(255)
  title?: string;

  // Optional owning project id (UUID).
  @IsOptional()
  @IsUUID()
  projectId?: string;
}
|
||||
|
||||
export interface UpdateConversationDto {
|
||||
/** Body DTO for PATCH /api/conversations/:id. All fields optional (partial update). */
export class UpdateConversationDto {
  @IsOptional()
  @IsString()
  @MaxLength(255)
  title?: string;

  // null detaches the conversation from its project; IsOptional skips
  // validation for both undefined and null.
  @IsOptional()
  @IsUUID()
  projectId?: string | null;

  // Archive / unarchive flag.
  @IsOptional()
  @IsBoolean()
  archived?: boolean;
}
|
||||
|
||||
export interface SendMessageDto {
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
/** Body DTO for POST /api/conversations/:id/messages. */
export class SendMessageDto {
  // Message author role; restricted to the three chat roles.
  @IsIn(['user', 'assistant', 'system'])
  role!: 'user' | 'assistant' | 'system';

  // Message body, capped at 10k chars.
  @IsString()
  @MaxLength(10_000)
  content!: string;

  // Free-form per-message metadata.
  @IsOptional()
  @IsObject()
  metadata?: Record<string, unknown>;
}
|
||||
|
||||
@@ -44,6 +44,10 @@ function resolveAndValidatePath(raw: string | undefined): string {
|
||||
return resolved;
|
||||
}
|
||||
|
||||
/**
|
||||
* File-based coord endpoints for agent tool consumption.
|
||||
* DB-backed mission CRUD has moved to MissionsController at /api/missions.
|
||||
*/
|
||||
@Controller('api/coord')
|
||||
@UseGuards(AuthGuard)
|
||||
export class CoordController {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
// ── File-based coord DTOs (legacy file-system backed) ──
|
||||
|
||||
export interface CoordMissionStatusDto {
|
||||
mission: {
|
||||
id: string;
|
||||
@@ -47,3 +49,42 @@ export interface CoordTaskDetailDto {
|
||||
startedAt: string;
|
||||
};
|
||||
}
|
||||
|
||||
// ── DB-backed coord DTOs ──
// Shapes for mission/task CRUD against the database (as opposed to the
// legacy file-system backed coord DTOs above).

/** Payload for creating a DB-backed mission. */
export interface CreateDbMissionDto {
  name: string;
  description?: string;
  projectId?: string;
  phase?: string;
  milestones?: Record<string, unknown>[];
  config?: Record<string, unknown>;
  status?: 'planning' | 'active' | 'paused' | 'completed' | 'failed';
}

/** Partial-update payload for a DB-backed mission; every field optional. */
export interface UpdateDbMissionDto {
  name?: string;
  description?: string;
  projectId?: string;
  phase?: string;
  milestones?: Record<string, unknown>[];
  config?: Record<string, unknown>;
  status?: 'planning' | 'active' | 'paused' | 'completed' | 'failed';
}

/** Payload for creating a task under a mission (missionId required). */
export interface CreateMissionTaskDto {
  missionId: string;
  taskId?: string;
  status?: 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
  description?: string;
  notes?: string;
  pr?: string;
}

/** Partial-update payload for a mission task. */
export interface UpdateMissionTaskDto {
  taskId?: string;
  status?: 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
  description?: string;
  notes?: string;
  pr?: string;
}
|
||||
|
||||
@@ -12,6 +12,10 @@ import {
|
||||
import { promises as fs } from 'node:fs';
|
||||
import path from 'node:path';
|
||||
|
||||
/**
|
||||
* File-based coord operations for agent tool consumption.
|
||||
* DB-backed mission CRUD is handled directly by MissionsController via Brain repos.
|
||||
*/
|
||||
@Injectable()
|
||||
export class CoordService {
|
||||
private readonly logger = new Logger(CoordService.name);
|
||||
|
||||
31
apps/gateway/src/gc/gc.module.ts
Normal file
31
apps/gateway/src/gc/gc.module.ts
Normal file
@@ -0,0 +1,31 @@
|
||||
import { Module, type OnApplicationShutdown, Inject } from '@nestjs/common';
|
||||
import { createQueue, type QueueHandle } from '@mosaic/queue';
|
||||
import { SessionGCService } from './session-gc.service.js';
|
||||
import { REDIS } from './gc.tokens.js';
|
||||
|
||||
const GC_QUEUE_HANDLE = 'GC_QUEUE_HANDLE';
|
||||
|
||||
@Module({
|
||||
providers: [
|
||||
{
|
||||
provide: GC_QUEUE_HANDLE,
|
||||
useFactory: (): QueueHandle => {
|
||||
return createQueue();
|
||||
},
|
||||
},
|
||||
{
|
||||
provide: REDIS,
|
||||
useFactory: (handle: QueueHandle) => handle.redis,
|
||||
inject: [GC_QUEUE_HANDLE],
|
||||
},
|
||||
SessionGCService,
|
||||
],
|
||||
exports: [SessionGCService],
|
||||
})
|
||||
export class GCModule implements OnApplicationShutdown {
|
||||
constructor(@Inject(GC_QUEUE_HANDLE) private readonly handle: QueueHandle) {}
|
||||
|
||||
async onApplicationShutdown(): Promise<void> {
|
||||
await this.handle.close().catch(() => {});
|
||||
}
|
||||
}
|
||||
1
apps/gateway/src/gc/gc.tokens.ts
Normal file
1
apps/gateway/src/gc/gc.tokens.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const REDIS = 'REDIS';
|
||||
112
apps/gateway/src/gc/session-gc.service.spec.ts
Normal file
112
apps/gateway/src/gc/session-gc.service.spec.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import type { QueueHandle } from '@mosaic/queue';
|
||||
import type { LogService } from '@mosaic/log';
|
||||
import { SessionGCService } from './session-gc.service.js';
|
||||
|
||||
type MockRedis = {
|
||||
scan: ReturnType<typeof vi.fn>;
|
||||
del: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
describe('SessionGCService', () => {
|
||||
let service: SessionGCService;
|
||||
let mockRedis: MockRedis;
|
||||
let mockLogService: { logs: { promoteToWarm: ReturnType<typeof vi.fn> } };
|
||||
|
||||
/**
|
||||
* Helper: build a scan mock that returns all provided keys in a single
|
||||
* cursor iteration (cursor '0' in → ['0', keys] out).
|
||||
*/
|
||||
function makeScanMock(keys: string[]): ReturnType<typeof vi.fn> {
|
||||
return vi.fn().mockResolvedValue(['0', keys]);
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
mockRedis = {
|
||||
scan: makeScanMock([]),
|
||||
del: vi.fn().mockResolvedValue(0),
|
||||
};
|
||||
|
||||
mockLogService = {
|
||||
logs: {
|
||||
promoteToWarm: vi.fn().mockResolvedValue(0),
|
||||
},
|
||||
};
|
||||
|
||||
// Suppress logger output in tests
|
||||
vi.spyOn(Logger.prototype, 'log').mockImplementation(() => {});
|
||||
|
||||
service = new SessionGCService(
|
||||
mockRedis as unknown as QueueHandle['redis'],
|
||||
mockLogService as unknown as LogService,
|
||||
);
|
||||
});
|
||||
|
||||
it('collect() deletes Valkey keys for session', async () => {
|
||||
mockRedis.scan = makeScanMock(['mosaic:session:abc:system', 'mosaic:session:abc:foo']);
|
||||
const result = await service.collect('abc');
|
||||
expect(mockRedis.del).toHaveBeenCalledWith(
|
||||
'mosaic:session:abc:system',
|
||||
'mosaic:session:abc:foo',
|
||||
);
|
||||
expect(result.cleaned.valkeyKeys).toBe(2);
|
||||
});
|
||||
|
||||
it('collect() with no keys returns empty cleaned valkeyKeys', async () => {
|
||||
mockRedis.scan = makeScanMock([]);
|
||||
const result = await service.collect('abc');
|
||||
expect(result.cleaned.valkeyKeys).toBeUndefined();
|
||||
});
|
||||
|
||||
it('collect() returns sessionId in result', async () => {
|
||||
const result = await service.collect('test-session-id');
|
||||
expect(result.sessionId).toBe('test-session-id');
|
||||
});
|
||||
|
||||
it('fullCollect() deletes all session keys', async () => {
|
||||
mockRedis.scan = makeScanMock(['mosaic:session:abc:system', 'mosaic:session:xyz:foo']);
|
||||
const result = await service.fullCollect();
|
||||
expect(mockRedis.del).toHaveBeenCalled();
|
||||
expect(result.valkeyKeys).toBe(2);
|
||||
});
|
||||
|
||||
it('fullCollect() with no keys returns 0 valkeyKeys', async () => {
|
||||
mockRedis.scan = makeScanMock([]);
|
||||
const result = await service.fullCollect();
|
||||
expect(result.valkeyKeys).toBe(0);
|
||||
expect(mockRedis.del).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('fullCollect() returns duration', async () => {
|
||||
const result = await service.fullCollect();
|
||||
expect(result.duration).toBeGreaterThanOrEqual(0);
|
||||
});
|
||||
|
||||
it('sweepOrphans() extracts unique session IDs and collects them', async () => {
|
||||
// First scan call returns the global session list; subsequent calls return
|
||||
// per-session keys during collect().
|
||||
mockRedis.scan = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce([
|
||||
'0',
|
||||
['mosaic:session:abc:system', 'mosaic:session:abc:messages', 'mosaic:session:xyz:system'],
|
||||
])
|
||||
// collect('abc') scan
|
||||
.mockResolvedValueOnce(['0', ['mosaic:session:abc:system', 'mosaic:session:abc:messages']])
|
||||
// collect('xyz') scan
|
||||
.mockResolvedValueOnce(['0', ['mosaic:session:xyz:system']]);
|
||||
mockRedis.del.mockResolvedValue(1);
|
||||
|
||||
const result = await service.sweepOrphans();
|
||||
expect(result.orphanedSessions).toBeGreaterThanOrEqual(0);
|
||||
expect(result.duration).toBeGreaterThanOrEqual(0);
|
||||
});
|
||||
|
||||
it('sweepOrphans() returns empty when no session keys', async () => {
|
||||
mockRedis.scan = makeScanMock([]);
|
||||
const result = await service.sweepOrphans();
|
||||
expect(result.orphanedSessions).toBe(0);
|
||||
expect(result.totalCleaned).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
164
apps/gateway/src/gc/session-gc.service.ts
Normal file
164
apps/gateway/src/gc/session-gc.service.ts
Normal file
@@ -0,0 +1,164 @@
|
||||
import { Inject, Injectable, Logger, type OnModuleInit } from '@nestjs/common';
|
||||
import type { QueueHandle } from '@mosaic/queue';
|
||||
import type { LogService } from '@mosaic/log';
|
||||
import { LOG_SERVICE } from '../log/log.tokens.js';
|
||||
import { REDIS } from './gc.tokens.js';
|
||||
|
||||
export interface GCResult {
|
||||
sessionId: string;
|
||||
cleaned: {
|
||||
valkeyKeys?: number;
|
||||
logsDemoted?: number;
|
||||
tempFilesRemoved?: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface GCSweepResult {
|
||||
orphanedSessions: number;
|
||||
totalCleaned: GCResult[];
|
||||
duration: number;
|
||||
}
|
||||
|
||||
export interface FullGCResult {
|
||||
valkeyKeys: number;
|
||||
logsDemoted: number;
|
||||
jobsPurged: number;
|
||||
tempFilesRemoved: number;
|
||||
duration: number;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class SessionGCService implements OnModuleInit {
|
||||
private readonly logger = new Logger(SessionGCService.name);
|
||||
|
||||
constructor(
|
||||
@Inject(REDIS) private readonly redis: QueueHandle['redis'],
|
||||
@Inject(LOG_SERVICE) private readonly logService: LogService,
|
||||
) {}
|
||||
|
||||
onModuleInit(): void {
|
||||
// Fire-and-forget: run full GC asynchronously so it does not block the
|
||||
// NestJS bootstrap chain. Cold-start GC typically takes 100–500 ms
|
||||
// depending on Valkey key count; deferring it removes that latency from
|
||||
// the TTFB of the first HTTP request.
|
||||
this.fullCollect()
|
||||
.then((result) => {
|
||||
this.logger.log(
|
||||
`Full GC complete: ${result.valkeyKeys} Valkey keys, ` +
|
||||
`${result.logsDemoted} logs demoted, ` +
|
||||
`${result.jobsPurged} jobs purged, ` +
|
||||
`${result.tempFilesRemoved} temp dirs removed ` +
|
||||
`(${result.duration}ms)`,
|
||||
);
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
this.logger.error('Cold-start GC failed', err instanceof Error ? err.stack : String(err));
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan Valkey for all keys matching a pattern using SCAN (non-blocking).
|
||||
* KEYS is avoided because it blocks the Valkey event loop for the full scan
|
||||
* duration, which can cause latency spikes under production key volumes.
|
||||
*/
|
||||
private async scanKeys(pattern: string): Promise<string[]> {
|
||||
const collected: string[] = [];
|
||||
let cursor = '0';
|
||||
do {
|
||||
const [nextCursor, keys] = await this.redis.scan(cursor, 'MATCH', pattern, 'COUNT', 100);
|
||||
cursor = nextCursor;
|
||||
collected.push(...keys);
|
||||
} while (cursor !== '0');
|
||||
return collected;
|
||||
}
|
||||
|
||||
/**
|
||||
* Immediate cleanup for a single session (call from destroySession).
|
||||
*/
|
||||
async collect(sessionId: string): Promise<GCResult> {
|
||||
const result: GCResult = { sessionId, cleaned: {} };
|
||||
|
||||
// 1. Valkey: delete all session-scoped keys
|
||||
const pattern = `mosaic:session:${sessionId}:*`;
|
||||
const valkeyKeys = await this.scanKeys(pattern);
|
||||
if (valkeyKeys.length > 0) {
|
||||
await this.redis.del(...valkeyKeys);
|
||||
result.cleaned.valkeyKeys = valkeyKeys.length;
|
||||
}
|
||||
|
||||
// 2. PG: demote hot-tier agent_logs for this session to warm
|
||||
const cutoff = new Date(); // demote all hot logs for this session
|
||||
const logsDemoted = await this.logService.logs.promoteToWarm(cutoff);
|
||||
if (logsDemoted > 0) {
|
||||
result.cleaned.logsDemoted = logsDemoted;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sweep GC — find orphaned artifacts from dead sessions.
|
||||
* System-wide operation: only call from admin-authorized paths or internal
|
||||
* scheduled jobs. Individual session cleanup is handled by collect().
|
||||
*/
|
||||
async sweepOrphans(): Promise<GCSweepResult> {
|
||||
const start = Date.now();
|
||||
const cleaned: GCResult[] = [];
|
||||
|
||||
// 1. Find all session-scoped Valkey keys (non-blocking SCAN)
|
||||
const allSessionKeys = await this.scanKeys('mosaic:session:*');
|
||||
|
||||
// Extract unique session IDs from keys
|
||||
const sessionIds = new Set<string>();
|
||||
for (const key of allSessionKeys) {
|
||||
const match = key.match(/^mosaic:session:([^:]+):/);
|
||||
if (match) sessionIds.add(match[1]!);
|
||||
}
|
||||
|
||||
// 2. For each session ID, collect stale keys
|
||||
for (const sessionId of sessionIds) {
|
||||
const gcResult = await this.collect(sessionId);
|
||||
if (Object.keys(gcResult.cleaned).length > 0) {
|
||||
cleaned.push(gcResult);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
orphanedSessions: cleaned.length,
|
||||
totalCleaned: cleaned,
|
||||
duration: Date.now() - start,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Full GC — aggressive collection for cold start.
|
||||
* Assumes no sessions survived the restart.
|
||||
*/
|
||||
async fullCollect(): Promise<FullGCResult> {
|
||||
const start = Date.now();
|
||||
|
||||
// 1. Valkey: delete ALL session-scoped keys (non-blocking SCAN)
|
||||
const sessionKeys = await this.scanKeys('mosaic:session:*');
|
||||
if (sessionKeys.length > 0) {
|
||||
await this.redis.del(...sessionKeys);
|
||||
}
|
||||
|
||||
// 2. NOTE: channel keys are NOT collected on cold start
|
||||
// (discord/telegram plugins may reconnect and resume)
|
||||
|
||||
// 3. PG: demote stale hot-tier logs older than 24h to warm
|
||||
const hotCutoff = new Date(Date.now() - 24 * 60 * 60 * 1000);
|
||||
const logsDemoted = await this.logService.logs.promoteToWarm(hotCutoff);
|
||||
|
||||
// 4. No summarization job purge API available yet
|
||||
const jobsPurged = 0;
|
||||
|
||||
return {
|
||||
valkeyKeys: sessionKeys.length,
|
||||
logsDemoted,
|
||||
jobsPurged,
|
||||
tempFilesRemoved: 0,
|
||||
duration: Date.now() - start,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,28 @@
|
||||
import { Injectable, Logger, type OnModuleInit, type OnModuleDestroy } from '@nestjs/common';
|
||||
import {
|
||||
Inject,
|
||||
Injectable,
|
||||
Logger,
|
||||
type OnModuleInit,
|
||||
type OnModuleDestroy,
|
||||
} from '@nestjs/common';
|
||||
import cron from 'node-cron';
|
||||
import { SummarizationService } from './summarization.service.js';
|
||||
import { SessionGCService } from '../gc/session-gc.service.js';
|
||||
|
||||
@Injectable()
|
||||
export class CronService implements OnModuleInit, OnModuleDestroy {
|
||||
private readonly logger = new Logger(CronService.name);
|
||||
private readonly tasks: cron.ScheduledTask[] = [];
|
||||
|
||||
constructor(private readonly summarization: SummarizationService) {}
|
||||
constructor(
|
||||
@Inject(SummarizationService) private readonly summarization: SummarizationService,
|
||||
@Inject(SessionGCService) private readonly sessionGC: SessionGCService,
|
||||
) {}
|
||||
|
||||
onModuleInit(): void {
|
||||
const summarizationSchedule = process.env['SUMMARIZATION_CRON'] ?? '0 */6 * * *'; // every 6 hours
|
||||
const tierManagementSchedule = process.env['TIER_MANAGEMENT_CRON'] ?? '0 3 * * *'; // daily at 3am
|
||||
const gcSchedule = process.env['SESSION_GC_CRON'] ?? '0 4 * * *'; // daily at 4am
|
||||
|
||||
this.tasks.push(
|
||||
cron.schedule(summarizationSchedule, () => {
|
||||
@@ -29,8 +40,16 @@ export class CronService implements OnModuleInit, OnModuleDestroy {
|
||||
}),
|
||||
);
|
||||
|
||||
this.tasks.push(
|
||||
cron.schedule(gcSchedule, () => {
|
||||
this.sessionGC.sweepOrphans().catch((err) => {
|
||||
this.logger.error(`Session GC sweep failed: ${err}`);
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
this.logger.log(
|
||||
`Cron scheduled: summarization="${summarizationSchedule}", tier="${tierManagementSchedule}"`,
|
||||
`Cron scheduled: summarization="${summarizationSchedule}", tier="${tierManagementSchedule}", gc="${gcSchedule}"`,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ import { LOG_SERVICE } from './log.tokens.js';
|
||||
import { LogController } from './log.controller.js';
|
||||
import { SummarizationService } from './summarization.service.js';
|
||||
import { CronService } from './cron.service.js';
|
||||
import { GCModule } from '../gc/gc.module.js';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
imports: [GCModule],
|
||||
providers: [
|
||||
{
|
||||
provide: LOG_SERVICE,
|
||||
|
||||
@@ -29,7 +29,7 @@ export class SummarizationService {
|
||||
constructor(
|
||||
@Inject(LOG_SERVICE) private readonly logService: LogService,
|
||||
@Inject(MEMORY) private readonly memory: Memory,
|
||||
private readonly embeddings: EmbeddingService,
|
||||
@Inject(EmbeddingService) private readonly embeddings: EmbeddingService,
|
||||
@Inject(DB) private readonly db: Db,
|
||||
) {
|
||||
this.apiKey = process.env['OPENAI_API_KEY'];
|
||||
@@ -137,7 +137,7 @@ export class SummarizationService {
|
||||
|
||||
const promoted = await this.logService.logs.promoteToCold(warmCutoff);
|
||||
const purged = await this.logService.logs.purge(coldCutoff);
|
||||
const decayed = await this.memory.insights.decayOldInsights(decayCutoff);
|
||||
const decayed = await this.memory.insights.decayAllInsights(decayCutoff);
|
||||
|
||||
this.logger.log(
|
||||
`Tier management: ${promoted} logs→cold, ${purged} purged, ${decayed} insights decayed`,
|
||||
|
||||
@@ -1,19 +1,58 @@
|
||||
import { config } from 'dotenv';
|
||||
import { resolve } from 'node:path';
|
||||
|
||||
// Load .env from monorepo root (cwd is apps/gateway when run via pnpm filter)
|
||||
config({ path: resolve(process.cwd(), '../../.env') });
|
||||
config(); // Also load apps/gateway/.env if present (overrides)
|
||||
|
||||
import './tracing.js';
|
||||
import 'reflect-metadata';
|
||||
import { NestFactory } from '@nestjs/core';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { Logger, ValidationPipe } from '@nestjs/common';
|
||||
import { FastifyAdapter, type NestFastifyApplication } from '@nestjs/platform-fastify';
|
||||
import helmet from '@fastify/helmet';
|
||||
import { listSsoStartupWarnings } from '@mosaic/auth';
|
||||
import { AppModule } from './app.module.js';
|
||||
import { mountAuthHandler } from './auth/auth.controller.js';
|
||||
import { mountMcpHandler } from './mcp/mcp.controller.js';
|
||||
import { McpService } from './mcp/mcp.service.js';
|
||||
|
||||
async function bootstrap(): Promise<void> {
|
||||
const logger = new Logger('Bootstrap');
|
||||
const app = await NestFactory.create<NestFastifyApplication>(AppModule, new FastifyAdapter());
|
||||
|
||||
if (!process.env['BETTER_AUTH_SECRET']) {
|
||||
throw new Error('BETTER_AUTH_SECRET is required');
|
||||
}
|
||||
|
||||
for (const warning of listSsoStartupWarnings()) {
|
||||
logger.warn(warning);
|
||||
}
|
||||
|
||||
const app = await NestFactory.create<NestFastifyApplication>(
|
||||
AppModule,
|
||||
new FastifyAdapter({ bodyLimit: 1_048_576 }),
|
||||
);
|
||||
|
||||
app.enableCors({
|
||||
origin: process.env['GATEWAY_CORS_ORIGIN'] ?? 'http://localhost:3000',
|
||||
credentials: true,
|
||||
methods: ['GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'],
|
||||
});
|
||||
|
||||
await app.register(helmet as never, { contentSecurityPolicy: false });
|
||||
app.useGlobalPipes(
|
||||
new ValidationPipe({
|
||||
whitelist: true,
|
||||
forbidNonWhitelisted: true,
|
||||
transform: true,
|
||||
}),
|
||||
);
|
||||
|
||||
mountAuthHandler(app);
|
||||
mountMcpHandler(app, app.get(McpService));
|
||||
|
||||
const port = process.env['GATEWAY_PORT'] ?? 4000;
|
||||
await app.listen(port as number, '0.0.0.0');
|
||||
const port = Number(process.env['GATEWAY_PORT'] ?? 4000);
|
||||
await app.listen(port, '0.0.0.0');
|
||||
logger.log(`Gateway listening on port ${port}`);
|
||||
}
|
||||
|
||||
|
||||
33
apps/gateway/src/mcp-client/mcp-client.dto.ts
Normal file
33
apps/gateway/src/mcp-client/mcp-client.dto.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* DTOs for MCP client configuration and tool discovery.
|
||||
*/
|
||||
|
||||
export interface McpServerConfigDto {
|
||||
/** Unique name identifying this MCP server */
|
||||
name: string;
|
||||
/** URL of the MCP server (streamable HTTP or SSE endpoint) */
|
||||
url: string;
|
||||
/** Optional HTTP headers to send with requests (e.g., Authorization) */
|
||||
headers?: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface McpToolDto {
|
||||
/** Namespaced tool name: "<serverName>__<toolName>" */
|
||||
name: string;
|
||||
/** Human-readable description of the tool */
|
||||
description: string;
|
||||
/** JSON Schema for tool input parameters */
|
||||
inputSchema: Record<string, unknown>;
|
||||
/** MCP server this tool belongs to */
|
||||
serverName: string;
|
||||
/** Original tool name on the remote server */
|
||||
remoteName: string;
|
||||
}
|
||||
|
||||
export interface McpServerStatusDto {
|
||||
name: string;
|
||||
url: string;
|
||||
connected: boolean;
|
||||
toolCount: number;
|
||||
error?: string;
|
||||
}
|
||||
8
apps/gateway/src/mcp-client/mcp-client.module.ts
Normal file
8
apps/gateway/src/mcp-client/mcp-client.module.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { McpClientService } from './mcp-client.service.js';
|
||||
|
||||
@Module({
|
||||
providers: [McpClientService],
|
||||
exports: [McpClientService],
|
||||
})
|
||||
export class McpClientModule {}
|
||||
331
apps/gateway/src/mcp-client/mcp-client.service.ts
Normal file
331
apps/gateway/src/mcp-client/mcp-client.service.ts
Normal file
@@ -0,0 +1,331 @@
|
||||
import { Injectable, Logger, OnModuleInit, OnModuleDestroy } from '@nestjs/common';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { Type } from '@sinclair/typebox';
|
||||
import type { ToolDefinition } from '@mariozechner/pi-coding-agent';
|
||||
import type { McpServerConfigDto, McpToolDto, McpServerStatusDto } from './mcp-client.dto.js';
|
||||
|
||||
interface ConnectedServer {
|
||||
config: McpServerConfigDto;
|
||||
client: Client;
|
||||
tools: McpToolDto[];
|
||||
connected: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* McpClientService connects to external MCP servers, discovers their tools,
|
||||
* and bridges them into Pi SDK ToolDefinition format for agent sessions.
|
||||
*
|
||||
* Configuration is read from the MCP_SERVERS environment variable:
|
||||
* MCP_SERVERS='[{"name":"my-server","url":"http://localhost:3001/mcp","headers":{"Authorization":"Bearer token"}}]'
|
||||
*/
|
||||
@Injectable()
|
||||
export class McpClientService implements OnModuleInit, OnModuleDestroy {
|
||||
private readonly logger = new Logger(McpClientService.name);
|
||||
private readonly servers = new Map<string, ConnectedServer>();
|
||||
|
||||
async onModuleInit(): Promise<void> {
|
||||
const configs = this.loadConfigs();
|
||||
if (configs.length === 0) {
|
||||
this.logger.log('No external MCP servers configured (MCP_SERVERS not set)');
|
||||
return;
|
||||
}
|
||||
|
||||
this.logger.log(`Connecting to ${configs.length} external MCP server(s)`);
|
||||
await Promise.allSettled(configs.map((cfg) => this.connectServer(cfg)));
|
||||
}
|
||||
|
||||
async onModuleDestroy(): Promise<void> {
|
||||
this.logger.log(`Disconnecting from ${this.servers.size} MCP server(s)`);
|
||||
const disconnects = Array.from(this.servers.values()).map((s) => this.disconnectServer(s));
|
||||
await Promise.allSettled(disconnects);
|
||||
this.servers.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns all bridged Pi SDK ToolDefinitions from all connected MCP servers.
|
||||
*/
|
||||
getToolDefinitions(): ToolDefinition[] {
|
||||
const tools: ToolDefinition[] = [];
|
||||
for (const server of this.servers.values()) {
|
||||
if (!server.connected) continue;
|
||||
for (const mcpTool of server.tools) {
|
||||
tools.push(this.bridgeTool(server.client, mcpTool));
|
||||
}
|
||||
}
|
||||
return tools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns status information for all configured MCP servers.
|
||||
*/
|
||||
getServerStatuses(): McpServerStatusDto[] {
|
||||
return Array.from(this.servers.values()).map((s) => ({
|
||||
name: s.config.name,
|
||||
url: s.config.url,
|
||||
connected: s.connected,
|
||||
toolCount: s.tools.length,
|
||||
error: s.error,
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to reconnect a server that has been disconnected.
|
||||
*/
|
||||
async reconnectServer(serverName: string): Promise<void> {
|
||||
const existing = this.servers.get(serverName);
|
||||
if (!existing) {
|
||||
throw new Error(`MCP server not found: ${serverName}`);
|
||||
}
|
||||
if (existing.connected) return;
|
||||
|
||||
this.logger.log(`Reconnecting to MCP server: ${serverName}`);
|
||||
await this.connectServer(existing.config);
|
||||
}
|
||||
|
||||
// ─── Private helpers ──────────────────────────────────────────────────────
|
||||
|
||||
private loadConfigs(): McpServerConfigDto[] {
|
||||
const raw = process.env['MCP_SERVERS'];
|
||||
if (!raw) return [];
|
||||
|
||||
try {
|
||||
const parsed: unknown = JSON.parse(raw);
|
||||
if (!Array.isArray(parsed)) {
|
||||
this.logger.warn('MCP_SERVERS must be a JSON array — ignoring');
|
||||
return [];
|
||||
}
|
||||
|
||||
const configs: McpServerConfigDto[] = [];
|
||||
for (const item of parsed) {
|
||||
if (
|
||||
typeof item === 'object' &&
|
||||
item !== null &&
|
||||
'name' in item &&
|
||||
typeof (item as Record<string, unknown>)['name'] === 'string' &&
|
||||
'url' in item &&
|
||||
typeof (item as Record<string, unknown>)['url'] === 'string'
|
||||
) {
|
||||
const cfg = item as McpServerConfigDto;
|
||||
configs.push({
|
||||
name: cfg.name,
|
||||
url: cfg.url,
|
||||
headers: cfg.headers,
|
||||
});
|
||||
} else {
|
||||
this.logger.warn(`Skipping invalid MCP server config entry: ${JSON.stringify(item)}`);
|
||||
}
|
||||
}
|
||||
|
||||
return configs;
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Failed to parse MCP_SERVERS: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private async connectServer(config: McpServerConfigDto): Promise<void> {
|
||||
const serverEntry: ConnectedServer = {
|
||||
config,
|
||||
client: new Client({ name: 'mosaic-gateway', version: '1.0.0' }),
|
||||
tools: [],
|
||||
connected: false,
|
||||
};
|
||||
|
||||
// Preserve existing entry if reconnecting
|
||||
this.servers.set(config.name, serverEntry);
|
||||
|
||||
try {
|
||||
const url = new URL(config.url);
|
||||
const headers = config.headers ?? {};
|
||||
|
||||
// Attempt StreamableHTTP first, fall back to SSE
|
||||
let connected = false;
|
||||
|
||||
try {
|
||||
const transport = new StreamableHTTPClientTransport(url, { requestInit: { headers } });
|
||||
await serverEntry.client.connect(transport);
|
||||
connected = true;
|
||||
this.logger.log(`Connected to MCP server "${config.name}" via StreamableHTTP`);
|
||||
} catch (streamErr) {
|
||||
this.logger.warn(
|
||||
`StreamableHTTP failed for "${config.name}", trying SSE: ${streamErr instanceof Error ? streamErr.message : String(streamErr)}`,
|
||||
);
|
||||
|
||||
// Reset client for SSE attempt
|
||||
serverEntry.client = new Client({ name: 'mosaic-gateway', version: '1.0.0' });
|
||||
|
||||
try {
|
||||
const transport = new SSEClientTransport(url, { requestInit: { headers } });
|
||||
await serverEntry.client.connect(transport);
|
||||
connected = true;
|
||||
this.logger.log(`Connected to MCP server "${config.name}" via SSE`);
|
||||
} catch (sseErr) {
|
||||
throw new Error(
|
||||
`Both transports failed for "${config.name}": SSE error: ${sseErr instanceof Error ? sseErr.message : String(sseErr)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!connected) return;
|
||||
|
||||
// Discover tools
|
||||
const toolsResult = await serverEntry.client.listTools();
|
||||
serverEntry.tools = toolsResult.tools.map((t) => ({
|
||||
name: `${config.name}__${t.name}`,
|
||||
description: t.description ?? `Tool ${t.name} from MCP server ${config.name}`,
|
||||
inputSchema: (t.inputSchema as Record<string, unknown>) ?? {},
|
||||
serverName: config.name,
|
||||
remoteName: t.name,
|
||||
}));
|
||||
|
||||
serverEntry.connected = true;
|
||||
this.logger.log(
|
||||
`Discovered ${serverEntry.tools.length} tool(s) from MCP server "${config.name}"`,
|
||||
);
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
serverEntry.error = message;
|
||||
serverEntry.connected = false;
|
||||
this.logger.error(`Failed to connect to MCP server "${config.name}": ${message}`);
|
||||
}
|
||||
}
|
||||
|
||||
private async disconnectServer(server: ConnectedServer): Promise<void> {
|
||||
try {
|
||||
await server.client.close();
|
||||
} catch (err) {
|
||||
this.logger.warn(
|
||||
`Error closing MCP client for "${server.config.name}": ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Bridges a single McpToolDto into a Pi SDK ToolDefinition.
|
||||
* The MCP inputSchema is converted to a TypeBox schema representation.
|
||||
*/
|
||||
private bridgeTool(client: Client, mcpTool: McpToolDto): ToolDefinition {
|
||||
const schema = this.inputSchemaToTypeBox(mcpTool.inputSchema);
|
||||
|
||||
return {
|
||||
name: mcpTool.name,
|
||||
label: mcpTool.remoteName,
|
||||
description: mcpTool.description,
|
||||
parameters: schema,
|
||||
execute: async (_toolCallId: string, params: unknown) => {
|
||||
try {
|
||||
const result = await client.callTool({
|
||||
name: mcpTool.remoteName,
|
||||
arguments: (params as Record<string, unknown>) ?? {},
|
||||
});
|
||||
|
||||
// MCP callTool returns { content: [...], isError?: boolean }
|
||||
const content = Array.isArray(result.content) ? result.content : [];
|
||||
const textParts = content
|
||||
.filter((c): c is { type: 'text'; text: string } => c.type === 'text')
|
||||
.map((c) => c.text)
|
||||
.join('\n');
|
||||
|
||||
if (result.isError) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: `MCP tool error from "${mcpTool.serverName}/${mcpTool.remoteName}": ${textParts || 'Unknown error'}`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
content:
|
||||
content.length > 0
|
||||
? (content as { type: 'text'; text: string }[])
|
||||
: [{ type: 'text' as const, text: '' }],
|
||||
details: undefined,
|
||||
};
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
this.logger.error(
|
||||
`MCP tool call failed: ${mcpTool.serverName}/${mcpTool.remoteName}: ${message}`,
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: `Failed to call MCP tool "${mcpTool.name}": ${message}`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a JSON Schema object to a TypeBox-compatible schema.
|
||||
* For simplicity, maps the inputSchema properties to TypeBox Type.Object.
|
||||
* Unknown/complex schemas fall back to Type.Object with Type.Unknown values.
|
||||
*/
|
||||
private inputSchemaToTypeBox(
|
||||
inputSchema: Record<string, unknown>,
|
||||
): ReturnType<typeof Type.Object> {
|
||||
const properties = inputSchema['properties'];
|
||||
|
||||
if (!properties || typeof properties !== 'object') {
|
||||
return Type.Object({});
|
||||
}
|
||||
|
||||
const required: string[] = Array.isArray(inputSchema['required'])
|
||||
? (inputSchema['required'] as string[])
|
||||
: [];
|
||||
|
||||
const tbProps: Record<string, ReturnType<typeof Type.String>> = {};
|
||||
|
||||
for (const [key, schemaDef] of Object.entries(properties as Record<string, unknown>)) {
|
||||
const def = schemaDef as Record<string, unknown>;
|
||||
const desc = typeof def['description'] === 'string' ? def['description'] : undefined;
|
||||
const isOptional = !required.includes(key);
|
||||
const base = this.jsonSchemaToTypeBox(def);
|
||||
tbProps[key] = isOptional
|
||||
? (Type.Optional(base) as unknown as ReturnType<typeof Type.String>)
|
||||
: (base as unknown as ReturnType<typeof Type.String>);
|
||||
if (desc && tbProps[key]) {
|
||||
// Attach description via metadata
|
||||
(tbProps[key] as Record<string, unknown>)['description'] = desc;
|
||||
}
|
||||
}
|
||||
|
||||
return Type.Object(tbProps as Parameters<typeof Type.Object>[0]);
|
||||
}
|
||||
|
||||
private jsonSchemaToTypeBox(
|
||||
def: Record<string, unknown>,
|
||||
):
|
||||
| ReturnType<typeof Type.String>
|
||||
| ReturnType<typeof Type.Number>
|
||||
| ReturnType<typeof Type.Boolean>
|
||||
| ReturnType<typeof Type.Unknown> {
|
||||
const type = def['type'];
|
||||
const desc = typeof def['description'] === 'string' ? { description: def['description'] } : {};
|
||||
|
||||
switch (type) {
|
||||
case 'string':
|
||||
return Type.String(desc);
|
||||
case 'number':
|
||||
case 'integer':
|
||||
return Type.Number(desc);
|
||||
case 'boolean':
|
||||
return Type.Boolean(desc);
|
||||
default:
|
||||
return Type.Unknown(desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
1
apps/gateway/src/mcp-client/mcp-client.tokens.ts
Normal file
1
apps/gateway/src/mcp-client/mcp-client.tokens.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const MCP_CLIENT_SERVICE = 'MCP_CLIENT_SERVICE';
|
||||
142
apps/gateway/src/mcp/mcp.controller.ts
Normal file
142
apps/gateway/src/mcp/mcp.controller.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
import type { IncomingMessage, ServerResponse } from 'node:http';
|
||||
import { Logger } from '@nestjs/common';
|
||||
import { fromNodeHeaders } from 'better-auth/node';
|
||||
import type { Auth } from '@mosaic/auth';
|
||||
import type { NestFastifyApplication } from '@nestjs/platform-fastify';
|
||||
import type { McpService } from './mcp.service.js';
|
||||
import { AUTH } from '../auth/auth.tokens.js';
|
||||
|
||||
/**
|
||||
* Mounts the MCP streamable HTTP transport endpoint at /mcp on the Fastify instance.
|
||||
*
|
||||
* This follows the same low-level Fastify hook pattern used by the auth controller,
|
||||
* bypassing NestJS routing to directly delegate to the MCP SDK transport handlers.
|
||||
*
|
||||
* Endpoint: POST /mcp (and GET /mcp for SSE stream reconnect)
|
||||
* Auth: Requires a valid BetterAuth session (cookie or Authorization header).
|
||||
* Session: Stateful — each initialized client gets a session ID via Mcp-Session-Id header.
|
||||
*/
|
||||
export function mountMcpHandler(app: NestFastifyApplication, mcpService: McpService): void {
|
||||
const auth = app.get<Auth>(AUTH);
|
||||
const logger = new Logger('McpController');
|
||||
const fastify = app.getHttpAdapter().getInstance();
|
||||
|
||||
fastify.addHook(
|
||||
'onRequest',
|
||||
(
|
||||
req: { raw: IncomingMessage; url: string; method: string },
|
||||
reply: { raw: ServerResponse; hijack: () => void },
|
||||
done: () => void,
|
||||
) => {
|
||||
if (!req.url.startsWith('/mcp')) {
|
||||
done();
|
||||
return;
|
||||
}
|
||||
|
||||
reply.hijack();
|
||||
|
||||
handleMcpRequest(req, reply, auth, mcpService, logger).catch((err: unknown) => {
|
||||
logger.error(
|
||||
`MCP request handler error: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
if (!reply.raw.headersSent) {
|
||||
reply.raw.writeHead(500, { 'Content-Type': 'application/json' });
|
||||
}
|
||||
if (!reply.raw.writableEnded) {
|
||||
reply.raw.end(JSON.stringify({ error: 'Internal server error' }));
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
async function handleMcpRequest(
|
||||
req: { raw: IncomingMessage; url: string; method: string },
|
||||
reply: { raw: ServerResponse; hijack: () => void },
|
||||
auth: Auth,
|
||||
mcpService: McpService,
|
||||
logger: Logger,
|
||||
): Promise<void> {
|
||||
// ─── Authentication ─────────────────────────────────────────────────────
|
||||
const headers = fromNodeHeaders(req.raw.headers);
|
||||
const result = await auth.api.getSession({ headers });
|
||||
|
||||
if (!result) {
|
||||
reply.raw.writeHead(401, { 'Content-Type': 'application/json' });
|
||||
reply.raw.end(JSON.stringify({ error: 'Unauthorized: valid session required' }));
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = result.user.id;
|
||||
|
||||
// ─── Session routing ─────────────────────────────────────────────────────
|
||||
const sessionId = req.raw.headers['mcp-session-id'];
|
||||
|
||||
if (typeof sessionId === 'string' && sessionId.length > 0) {
|
||||
// Existing session request
|
||||
const transport = mcpService.getSession(sessionId);
|
||||
if (!transport) {
|
||||
logger.warn(`MCP session not found: ${sessionId}`);
|
||||
reply.raw.writeHead(404, { 'Content-Type': 'application/json' });
|
||||
reply.raw.end(JSON.stringify({ error: 'Session not found' }));
|
||||
return;
|
||||
}
|
||||
|
||||
await transport.handleRequest(req.raw, reply.raw);
|
||||
return;
|
||||
}
|
||||
|
||||
// ─── Initialize new session ───────────────────────────────────────────────
|
||||
// Only POST requests can initialize a new session (must be initialize message)
|
||||
if (req.method !== 'POST') {
|
||||
reply.raw.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
reply.raw.end(
|
||||
JSON.stringify({
|
||||
error: 'New session must be established via POST with initialize message',
|
||||
}),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse body to verify this is an initialize request before creating a session
|
||||
let body: unknown;
|
||||
try {
|
||||
body = await readRequestBody(req.raw);
|
||||
} catch (err) {
|
||||
logger.warn(
|
||||
`Failed to parse MCP request body: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
reply.raw.writeHead(400, { 'Content-Type': 'application/json' });
|
||||
reply.raw.end(JSON.stringify({ error: 'Invalid request body' }));
|
||||
return;
|
||||
}
|
||||
|
||||
// Create new session and handle this initializing request
|
||||
const { transport } = mcpService.createSession(userId);
|
||||
logger.log(`New MCP session created for user ${userId}`);
|
||||
|
||||
await transport.handleRequest(req.raw, reply.raw, body);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads and parses the JSON body from a Node.js IncomingMessage.
|
||||
*/
|
||||
function readRequestBody(req: IncomingMessage): Promise<unknown> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const chunks: Buffer[] = [];
|
||||
req.on('data', (chunk: Buffer) => chunks.push(chunk));
|
||||
req.on('end', () => {
|
||||
const raw = Buffer.concat(chunks).toString('utf8');
|
||||
if (!raw) {
|
||||
resolve(undefined);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
resolve(JSON.parse(raw));
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
});
|
||||
req.on('error', reject);
|
||||
});
|
||||
}
|
||||
19
apps/gateway/src/mcp/mcp.dto.ts
Normal file
19
apps/gateway/src/mcp/mcp.dto.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
/**
|
||||
* MCP (Model Context Protocol) DTOs
|
||||
*
|
||||
* Defines the data transfer objects for the MCP streamable HTTP transport.
|
||||
* See: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
|
||||
*/
|
||||
|
||||
export interface McpToolDescriptor {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface McpServerInfo {
|
||||
name: string;
|
||||
version: string;
|
||||
protocolVersion: string;
|
||||
tools: McpToolDescriptor[];
|
||||
}
|
||||
10
apps/gateway/src/mcp/mcp.module.ts
Normal file
10
apps/gateway/src/mcp/mcp.module.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { McpService } from './mcp.service.js';
|
||||
import { CoordModule } from '../coord/coord.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [CoordModule],
|
||||
providers: [McpService],
|
||||
exports: [McpService],
|
||||
})
|
||||
export class McpModule {}
|
||||
429
apps/gateway/src/mcp/mcp.service.ts
Normal file
429
apps/gateway/src/mcp/mcp.service.ts
Normal file
@@ -0,0 +1,429 @@
|
||||
import { Injectable, Logger, Inject, OnModuleDestroy } from '@nestjs/common';
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { z } from 'zod';
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import type { Memory } from '@mosaic/memory';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { MEMORY } from '../memory/memory.tokens.js';
|
||||
import { EmbeddingService } from '../memory/embedding.service.js';
|
||||
import { CoordService } from '../coord/coord.service.js';
|
||||
|
||||
interface SessionEntry {
|
||||
server: McpServer;
|
||||
transport: StreamableHTTPServerTransport;
|
||||
createdAt: Date;
|
||||
userId: string;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class McpService implements OnModuleDestroy {
|
||||
private readonly logger = new Logger(McpService.name);
|
||||
private readonly sessions = new Map<string, SessionEntry>();
|
||||
|
||||
constructor(
|
||||
@Inject(BRAIN) private readonly brain: Brain,
|
||||
@Inject(MEMORY) private readonly memory: Memory,
|
||||
@Inject(EmbeddingService) private readonly embeddings: EmbeddingService,
|
||||
@Inject(CoordService) private readonly coordService: CoordService,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Creates a new MCP session with its own server + transport pair.
|
||||
* Returns the transport for use by the controller.
|
||||
*/
|
||||
createSession(userId: string): { sessionId: string; transport: StreamableHTTPServerTransport } {
|
||||
const sessionId = randomUUID();
|
||||
|
||||
const transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => sessionId,
|
||||
onsessioninitialized: (id) => {
|
||||
this.logger.log(`MCP session initialized: ${id} for user ${userId}`);
|
||||
},
|
||||
});
|
||||
|
||||
const server = new McpServer(
|
||||
{ name: 'mosaic-gateway', version: '1.0.0' },
|
||||
{ capabilities: { tools: {} } },
|
||||
);
|
||||
|
||||
this.registerTools(server, userId);
|
||||
|
||||
transport.onclose = () => {
|
||||
this.logger.log(`MCP session closed: ${sessionId}`);
|
||||
this.sessions.delete(sessionId);
|
||||
};
|
||||
|
||||
server.connect(transport).catch((err: unknown) => {
|
||||
this.logger.error(
|
||||
`MCP server connect error for session ${sessionId}: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
});
|
||||
|
||||
this.sessions.set(sessionId, { server, transport, createdAt: new Date(), userId });
|
||||
return { sessionId, transport };
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the transport for an existing session, or null if not found.
|
||||
*/
|
||||
getSession(sessionId: string): StreamableHTTPServerTransport | null {
|
||||
return this.sessions.get(sessionId)?.transport ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers all platform tools on the given McpServer instance.
|
||||
*/
|
||||
private registerTools(server: McpServer, _userId: string): void {
|
||||
// ─── Brain: Project tools ────────────────────────────────────────────
|
||||
|
||||
server.registerTool(
|
||||
'brain_list_projects',
|
||||
{
|
||||
description: 'List all projects in the brain.',
|
||||
inputSchema: z.object({}),
|
||||
},
|
||||
async () => {
|
||||
const projects = await this.brain.projects.findAll();
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: JSON.stringify(projects, null, 2) }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'brain_get_project',
|
||||
{
|
||||
description: 'Get a project by ID.',
|
||||
inputSchema: z.object({
|
||||
id: z.string().describe('Project ID (UUID)'),
|
||||
}),
|
||||
},
|
||||
async ({ id }) => {
|
||||
const project = await this.brain.projects.findById(id);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: project ? JSON.stringify(project, null, 2) : `Project not found: ${id}`,
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
// ─── Brain: Task tools ───────────────────────────────────────────────
|
||||
|
||||
server.registerTool(
|
||||
'brain_list_tasks',
|
||||
{
|
||||
description: 'List tasks, optionally filtered by project, mission, or status.',
|
||||
inputSchema: z.object({
|
||||
projectId: z.string().optional().describe('Filter by project ID'),
|
||||
missionId: z.string().optional().describe('Filter by mission ID'),
|
||||
status: z.string().optional().describe('Filter by status'),
|
||||
}),
|
||||
},
|
||||
async ({ projectId, missionId, status }) => {
|
||||
type TaskStatus = 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
|
||||
let tasks;
|
||||
if (projectId) tasks = await this.brain.tasks.findByProject(projectId);
|
||||
else if (missionId) tasks = await this.brain.tasks.findByMission(missionId);
|
||||
else if (status) tasks = await this.brain.tasks.findByStatus(status as TaskStatus);
|
||||
else tasks = await this.brain.tasks.findAll();
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(tasks, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'brain_create_task',
|
||||
{
|
||||
description: 'Create a new task in the brain.',
|
||||
inputSchema: z.object({
|
||||
title: z.string().describe('Task title'),
|
||||
description: z.string().optional().describe('Task description'),
|
||||
projectId: z.string().optional().describe('Project ID'),
|
||||
missionId: z.string().optional().describe('Mission ID'),
|
||||
priority: z.string().optional().describe('Priority: low, medium, high, critical'),
|
||||
}),
|
||||
},
|
||||
async (params) => {
|
||||
type Priority = 'low' | 'medium' | 'high' | 'critical';
|
||||
const task = await this.brain.tasks.create({
|
||||
...params,
|
||||
priority: params.priority as Priority | undefined,
|
||||
});
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(task, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'brain_update_task',
|
||||
{
|
||||
description: 'Update an existing task.',
|
||||
inputSchema: z.object({
|
||||
id: z.string().describe('Task ID'),
|
||||
title: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
status: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('not-started, in-progress, blocked, done, cancelled'),
|
||||
priority: z.string().optional(),
|
||||
}),
|
||||
},
|
||||
async ({ id, ...updates }) => {
|
||||
type TaskStatus = 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
|
||||
type Priority = 'low' | 'medium' | 'high' | 'critical';
|
||||
const task = await this.brain.tasks.update(id, {
|
||||
...updates,
|
||||
status: updates.status as TaskStatus | undefined,
|
||||
priority: updates.priority as Priority | undefined,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: task ? JSON.stringify(task, null, 2) : `Task not found: ${id}`,
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
// ─── Brain: Mission tools ────────────────────────────────────────────
|
||||
|
||||
server.registerTool(
|
||||
'brain_list_missions',
|
||||
{
|
||||
description: 'List all missions, optionally filtered by project.',
|
||||
inputSchema: z.object({
|
||||
projectId: z.string().optional().describe('Filter by project ID'),
|
||||
}),
|
||||
},
|
||||
async ({ projectId }) => {
|
||||
const missions = projectId
|
||||
? await this.brain.missions.findByProject(projectId)
|
||||
: await this.brain.missions.findAll();
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(missions, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'brain_list_conversations',
|
||||
{
|
||||
description: 'List conversations for a user.',
|
||||
inputSchema: z.object({
|
||||
userId: z.string().describe('User ID'),
|
||||
}),
|
||||
},
|
||||
async ({ userId }) => {
|
||||
const conversations = await this.brain.conversations.findAll(userId);
|
||||
return {
|
||||
content: [{ type: 'text' as const, text: JSON.stringify(conversations, null, 2) }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
// ─── Memory tools ────────────────────────────────────────────────────
|
||||
|
||||
server.registerTool(
|
||||
'memory_search',
|
||||
{
|
||||
description:
|
||||
'Search across stored insights and knowledge using natural language. Returns semantically similar results.',
|
||||
inputSchema: z.object({
|
||||
userId: z.string().describe('User ID to search memory for'),
|
||||
query: z.string().describe('Natural language search query'),
|
||||
limit: z.number().optional().describe('Max results (default 5)'),
|
||||
}),
|
||||
},
|
||||
async ({ userId, query, limit }) => {
|
||||
if (!this.embeddings.available) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: 'Semantic search unavailable — no embedding provider configured',
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
const embedding = await this.embeddings.embed(query);
|
||||
const results = await this.memory.insights.searchByEmbedding(userId, embedding, limit ?? 5);
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(results, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'memory_get_preferences',
|
||||
{
|
||||
description: 'Retrieve stored preferences for a user.',
|
||||
inputSchema: z.object({
|
||||
userId: z.string().describe('User ID'),
|
||||
category: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Filter by category: communication, coding, workflow, appearance, general'),
|
||||
}),
|
||||
},
|
||||
async ({ userId, category }) => {
|
||||
type Cat = 'communication' | 'coding' | 'workflow' | 'appearance' | 'general';
|
||||
const prefs = category
|
||||
? await this.memory.preferences.findByUserAndCategory(userId, category as Cat)
|
||||
: await this.memory.preferences.findByUser(userId);
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(prefs, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'memory_save_preference',
|
||||
{
|
||||
description:
|
||||
'Store a learned user preference (e.g., "prefers tables over paragraphs", "timezone: America/Chicago").',
|
||||
inputSchema: z.object({
|
||||
userId: z.string().describe('User ID'),
|
||||
key: z.string().describe('Preference key'),
|
||||
value: z.string().describe('Preference value (JSON string)'),
|
||||
category: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Category: communication, coding, workflow, appearance, general'),
|
||||
}),
|
||||
},
|
||||
async ({ userId, key, value, category }) => {
|
||||
type Cat = 'communication' | 'coding' | 'workflow' | 'appearance' | 'general';
|
||||
let parsedValue: unknown;
|
||||
try {
|
||||
parsedValue = JSON.parse(value);
|
||||
} catch {
|
||||
parsedValue = value;
|
||||
}
|
||||
const pref = await this.memory.preferences.upsert({
|
||||
userId,
|
||||
key,
|
||||
value: parsedValue,
|
||||
category: (category as Cat) ?? 'general',
|
||||
source: 'agent',
|
||||
});
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(pref, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'memory_save_insight',
|
||||
{
|
||||
description:
|
||||
'Store a learned insight, decision, or knowledge extracted from the current interaction.',
|
||||
inputSchema: z.object({
|
||||
userId: z.string().describe('User ID'),
|
||||
content: z.string().describe('The insight or knowledge to store'),
|
||||
category: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Category: decision, learning, preference, fact, pattern, general'),
|
||||
}),
|
||||
},
|
||||
async ({ userId, content, category }) => {
|
||||
type Cat = 'decision' | 'learning' | 'preference' | 'fact' | 'pattern' | 'general';
|
||||
const embedding = this.embeddings.available ? await this.embeddings.embed(content) : null;
|
||||
const insight = await this.memory.insights.create({
|
||||
userId,
|
||||
content,
|
||||
embedding,
|
||||
source: 'agent',
|
||||
category: (category as Cat) ?? 'learning',
|
||||
});
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(insight, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
// ─── Coord tools ─────────────────────────────────────────────────────
|
||||
|
||||
server.registerTool(
|
||||
'coord_mission_status',
|
||||
{
|
||||
description:
|
||||
'Get the current orchestration mission status including milestones, tasks, and active session.',
|
||||
inputSchema: z.object({
|
||||
projectPath: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Project path. Defaults to gateway working directory.'),
|
||||
}),
|
||||
},
|
||||
async ({ projectPath }) => {
|
||||
const resolvedPath = projectPath ?? process.cwd();
|
||||
const status = await this.coordService.getMissionStatus(resolvedPath);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: status ? JSON.stringify(status, null, 2) : 'No active coord mission found.',
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'coord_list_tasks',
|
||||
{
|
||||
description: 'List all tasks from the orchestration TASKS.md file.',
|
||||
inputSchema: z.object({
|
||||
projectPath: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Project path. Defaults to gateway working directory.'),
|
||||
}),
|
||||
},
|
||||
async ({ projectPath }) => {
|
||||
const resolvedPath = projectPath ?? process.cwd();
|
||||
const tasks = await this.coordService.listTasks(resolvedPath);
|
||||
return { content: [{ type: 'text' as const, text: JSON.stringify(tasks, null, 2) }] };
|
||||
},
|
||||
);
|
||||
|
||||
server.registerTool(
|
||||
'coord_task_detail',
|
||||
{
|
||||
description: 'Get detailed status for a specific orchestration task.',
|
||||
inputSchema: z.object({
|
||||
taskId: z.string().describe('Task ID (e.g. P2-005)'),
|
||||
projectPath: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Project path. Defaults to gateway working directory.'),
|
||||
}),
|
||||
},
|
||||
async ({ taskId, projectPath }) => {
|
||||
const resolvedPath = projectPath ?? process.cwd();
|
||||
const detail = await this.coordService.getTaskStatus(resolvedPath, taskId);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text: detail
|
||||
? JSON.stringify(detail, null, 2)
|
||||
: `Task ${taskId} not found in coord mission.`,
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
async onModuleDestroy(): Promise<void> {
|
||||
this.logger.log(`Closing ${this.sessions.size} MCP sessions on shutdown`);
|
||||
const closePromises = Array.from(this.sessions.values()).map(({ transport }) =>
|
||||
transport.close().catch((err: unknown) => {
|
||||
this.logger.warn(
|
||||
`Error closing MCP transport: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
}),
|
||||
);
|
||||
await Promise.all(closePromises);
|
||||
this.sessions.clear();
|
||||
}
|
||||
}
|
||||
1
apps/gateway/src/mcp/mcp.tokens.ts
Normal file
1
apps/gateway/src/mcp/mcp.tokens.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const MCP_SERVICE = 'MCP_SERVICE';
|
||||
@@ -1,36 +1,122 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import type { EmbeddingProvider } from '@mosaic/memory';
|
||||
|
||||
const DEFAULT_MODEL = 'text-embedding-3-small';
|
||||
const DEFAULT_DIMENSIONS = 1536;
|
||||
// ---------------------------------------------------------------------------
|
||||
// Environment-driven configuration
|
||||
//
|
||||
// EMBEDDING_PROVIDER — 'ollama' (default) | 'openai'
|
||||
// EMBEDDING_MODEL — model id, defaults differ per provider
|
||||
// EMBEDDING_DIMENSIONS — integer, defaults differ per provider
|
||||
// OLLAMA_BASE_URL — base URL for Ollama (used when provider=ollama)
|
||||
// EMBEDDING_API_URL — full base URL for OpenAI-compatible API
|
||||
// OPENAI_API_KEY — required for OpenAI provider
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface EmbeddingResponse {
|
||||
const OLLAMA_DEFAULT_MODEL = 'nomic-embed-text';
|
||||
const OLLAMA_DEFAULT_DIMENSIONS = 768;
|
||||
|
||||
const OPENAI_DEFAULT_MODEL = 'text-embedding-3-small';
|
||||
const OPENAI_DEFAULT_DIMENSIONS = 1536;
|
||||
|
||||
/** Known dimension mismatch: warn if pgvector column likely has wrong size */
|
||||
const PGVECTOR_SCHEMA_DIMENSIONS = 1536;
|
||||
|
||||
type EmbeddingBackend = 'ollama' | 'openai';
|
||||
|
||||
interface OllamaEmbeddingResponse {
|
||||
embedding: number[];
|
||||
}
|
||||
|
||||
interface OpenAIEmbeddingResponse {
|
||||
data: Array<{ embedding: number[]; index: number }>;
|
||||
model: string;
|
||||
usage: { prompt_tokens: number; total_tokens: number };
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates embeddings via the OpenAI-compatible embeddings API.
|
||||
* Supports OpenAI, Azure OpenAI, and any provider with a compatible endpoint.
|
||||
* Provider-agnostic embedding service.
|
||||
*
|
||||
* Defaults to Ollama's native embedding API using nomic-embed-text (768 dims).
|
||||
* Falls back to the OpenAI-compatible API when EMBEDDING_PROVIDER=openai or
|
||||
* when OPENAI_API_KEY is set and EMBEDDING_PROVIDER is not explicitly set to ollama.
|
||||
*
|
||||
* Dimension mismatch detection: if the configured dimensions differ from the
|
||||
* pgvector schema (1536), a warning is logged with re-embedding instructions.
|
||||
*/
|
||||
@Injectable()
|
||||
export class EmbeddingService implements EmbeddingProvider {
|
||||
private readonly logger = new Logger(EmbeddingService.name);
|
||||
private readonly apiKey: string | undefined;
|
||||
private readonly baseUrl: string;
|
||||
private readonly backend: EmbeddingBackend;
|
||||
private readonly model: string;
|
||||
readonly dimensions: number;
|
||||
|
||||
readonly dimensions = DEFAULT_DIMENSIONS;
|
||||
// Ollama-specific
|
||||
private readonly ollamaBaseUrl: string | undefined;
|
||||
|
||||
// OpenAI-compatible
|
||||
private readonly openaiApiKey: string | undefined;
|
||||
private readonly openaiBaseUrl: string;
|
||||
|
||||
constructor() {
|
||||
this.apiKey = process.env['OPENAI_API_KEY'];
|
||||
this.baseUrl = process.env['EMBEDDING_API_URL'] ?? 'https://api.openai.com/v1';
|
||||
this.model = process.env['EMBEDDING_MODEL'] ?? DEFAULT_MODEL;
|
||||
// Determine backend
|
||||
const providerEnv = process.env['EMBEDDING_PROVIDER'];
|
||||
const openaiKey = process.env['OPENAI_API_KEY'];
|
||||
const ollamaUrl = process.env['OLLAMA_BASE_URL'] ?? process.env['OLLAMA_HOST'];
|
||||
|
||||
if (providerEnv === 'openai') {
|
||||
this.backend = 'openai';
|
||||
} else if (providerEnv === 'ollama') {
|
||||
this.backend = 'ollama';
|
||||
} else if (process.env['EMBEDDING_API_URL']) {
|
||||
// Legacy: explicit API URL configured → use openai-compat path
|
||||
this.backend = 'openai';
|
||||
} else if (ollamaUrl) {
|
||||
// Ollama available and no explicit override → prefer Ollama
|
||||
this.backend = 'ollama';
|
||||
} else if (openaiKey) {
|
||||
// OpenAI key present → use OpenAI
|
||||
this.backend = 'openai';
|
||||
} else {
|
||||
// Nothing configured — default to ollama (will return zeros when unavailable)
|
||||
this.backend = 'ollama';
|
||||
}
|
||||
|
||||
// Set model and dimension defaults based on backend
|
||||
if (this.backend === 'ollama') {
|
||||
this.model = process.env['EMBEDDING_MODEL'] ?? OLLAMA_DEFAULT_MODEL;
|
||||
this.dimensions =
|
||||
parseInt(process.env['EMBEDDING_DIMENSIONS'] ?? '', 10) || OLLAMA_DEFAULT_DIMENSIONS;
|
||||
this.ollamaBaseUrl = ollamaUrl;
|
||||
this.openaiApiKey = undefined;
|
||||
this.openaiBaseUrl = '';
|
||||
} else {
|
||||
this.model = process.env['EMBEDDING_MODEL'] ?? OPENAI_DEFAULT_MODEL;
|
||||
this.dimensions =
|
||||
parseInt(process.env['EMBEDDING_DIMENSIONS'] ?? '', 10) || OPENAI_DEFAULT_DIMENSIONS;
|
||||
this.ollamaBaseUrl = undefined;
|
||||
this.openaiApiKey = openaiKey;
|
||||
this.openaiBaseUrl = process.env['EMBEDDING_API_URL'] ?? 'https://api.openai.com/v1';
|
||||
}
|
||||
|
||||
// Warn on dimension mismatch with the current schema
|
||||
if (this.dimensions !== PGVECTOR_SCHEMA_DIMENSIONS) {
|
||||
this.logger.warn(
|
||||
`Embedding dimensions (${this.dimensions}) differ from pgvector schema (${PGVECTOR_SCHEMA_DIMENSIONS}). ` +
|
||||
`If insights already contain ${PGVECTOR_SCHEMA_DIMENSIONS}-dim vectors, similarity search will fail. ` +
|
||||
`To fix: truncate the insights table and re-embed, or run a migration to ALTER COLUMN embedding TYPE vector(${this.dimensions}).`,
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.log(
|
||||
`EmbeddingService initialized: backend=${this.backend}, model=${this.model}, dimensions=${this.dimensions}`,
|
||||
);
|
||||
}
|
||||
|
||||
get available(): boolean {
|
||||
return !!this.apiKey;
|
||||
if (this.backend === 'ollama') {
|
||||
return !!this.ollamaBaseUrl;
|
||||
}
|
||||
return !!this.openaiApiKey;
|
||||
}
|
||||
|
||||
async embed(text: string): Promise<number[]> {
|
||||
@@ -39,16 +125,60 @@ export class EmbeddingService implements EmbeddingProvider {
|
||||
}
|
||||
|
||||
async embedBatch(texts: string[]): Promise<number[][]> {
|
||||
if (!this.apiKey) {
|
||||
this.logger.warn('No OPENAI_API_KEY configured — returning zero vectors');
|
||||
if (!this.available) {
|
||||
const reason =
|
||||
this.backend === 'ollama'
|
||||
? 'OLLAMA_BASE_URL not configured'
|
||||
: 'No OPENAI_API_KEY configured';
|
||||
this.logger.warn(`${reason} — returning zero vectors`);
|
||||
return texts.map(() => new Array<number>(this.dimensions).fill(0));
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.baseUrl}/embeddings`, {
|
||||
if (this.backend === 'ollama') {
|
||||
return this.embedBatchOllama(texts);
|
||||
}
|
||||
return this.embedBatchOpenAI(texts);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ollama backend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
private async embedBatchOllama(texts: string[]): Promise<number[][]> {
|
||||
const baseUrl = this.ollamaBaseUrl!;
|
||||
const results: number[][] = [];
|
||||
|
||||
// Ollama's /api/embeddings endpoint processes one text at a time
|
||||
for (const text of texts) {
|
||||
const response = await fetch(`${baseUrl}/api/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model: this.model, prompt: text }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const body = await response.text();
|
||||
this.logger.error(`Ollama embedding API error: ${response.status} ${body}`);
|
||||
throw new Error(`Ollama embedding API returned ${response.status}`);
|
||||
}
|
||||
|
||||
const json = (await response.json()) as OllamaEmbeddingResponse;
|
||||
results.push(json.embedding);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OpenAI-compatible backend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
private async embedBatchOpenAI(texts: string[]): Promise<number[][]> {
|
||||
const response = await fetch(`${this.openaiBaseUrl}/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
Authorization: `Bearer ${this.openaiApiKey}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: this.model,
|
||||
@@ -63,7 +193,7 @@ export class EmbeddingService implements EmbeddingProvider {
|
||||
throw new Error(`Embedding API returned ${response.status}`);
|
||||
}
|
||||
|
||||
const json = (await response.json()) as EmbeddingResponse;
|
||||
const json = (await response.json()) as OpenAIEmbeddingResponse;
|
||||
return json.data.sort((a, b) => a.index - b.index).map((d) => d.embedding);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
import type { Memory } from '@mosaic/memory';
|
||||
import { MEMORY } from './memory.tokens.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import { EmbeddingService } from './embedding.service.js';
|
||||
import type { UpsertPreferenceDto, CreateInsightDto, SearchMemoryDto } from './memory.dto.js';
|
||||
|
||||
@@ -23,33 +24,33 @@ import type { UpsertPreferenceDto, CreateInsightDto, SearchMemoryDto } from './m
|
||||
export class MemoryController {
|
||||
constructor(
|
||||
@Inject(MEMORY) private readonly memory: Memory,
|
||||
private readonly embeddings: EmbeddingService,
|
||||
@Inject(EmbeddingService) private readonly embeddings: EmbeddingService,
|
||||
) {}
|
||||
|
||||
// ─── Preferences ────────────────────────────────────────────────────
|
||||
|
||||
@Get('preferences')
|
||||
async listPreferences(@Query('userId') userId: string, @Query('category') category?: string) {
|
||||
async listPreferences(@CurrentUser() user: { id: string }, @Query('category') category?: string) {
|
||||
if (category) {
|
||||
return this.memory.preferences.findByUserAndCategory(
|
||||
userId,
|
||||
user.id,
|
||||
category as Parameters<typeof this.memory.preferences.findByUserAndCategory>[1],
|
||||
);
|
||||
}
|
||||
return this.memory.preferences.findByUser(userId);
|
||||
return this.memory.preferences.findByUser(user.id);
|
||||
}
|
||||
|
||||
@Get('preferences/:key')
|
||||
async getPreference(@Query('userId') userId: string, @Param('key') key: string) {
|
||||
const pref = await this.memory.preferences.findByUserAndKey(userId, key);
|
||||
async getPreference(@CurrentUser() user: { id: string }, @Param('key') key: string) {
|
||||
const pref = await this.memory.preferences.findByUserAndKey(user.id, key);
|
||||
if (!pref) throw new NotFoundException('Preference not found');
|
||||
return pref;
|
||||
}
|
||||
|
||||
@Post('preferences')
|
||||
async upsertPreference(@Query('userId') userId: string, @Body() dto: UpsertPreferenceDto) {
|
||||
async upsertPreference(@CurrentUser() user: { id: string }, @Body() dto: UpsertPreferenceDto) {
|
||||
return this.memory.preferences.upsert({
|
||||
userId,
|
||||
userId: user.id,
|
||||
key: dto.key,
|
||||
value: dto.value,
|
||||
category: dto.category,
|
||||
@@ -59,33 +60,33 @@ export class MemoryController {
|
||||
|
||||
@Delete('preferences/:key')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async removePreference(@Query('userId') userId: string, @Param('key') key: string) {
|
||||
const deleted = await this.memory.preferences.remove(userId, key);
|
||||
async removePreference(@CurrentUser() user: { id: string }, @Param('key') key: string) {
|
||||
const deleted = await this.memory.preferences.remove(user.id, key);
|
||||
if (!deleted) throw new NotFoundException('Preference not found');
|
||||
}
|
||||
|
||||
// ─── Insights ───────────────────────────────────────────────────────
|
||||
|
||||
@Get('insights')
|
||||
async listInsights(@Query('userId') userId: string, @Query('limit') limit?: string) {
|
||||
return this.memory.insights.findByUser(userId, limit ? Number(limit) : undefined);
|
||||
async listInsights(@CurrentUser() user: { id: string }, @Query('limit') limit?: string) {
|
||||
return this.memory.insights.findByUser(user.id, limit ? Number(limit) : undefined);
|
||||
}
|
||||
|
||||
@Get('insights/:id')
|
||||
async getInsight(@Param('id') id: string) {
|
||||
const insight = await this.memory.insights.findById(id);
|
||||
async getInsight(@CurrentUser() user: { id: string }, @Param('id') id: string) {
|
||||
const insight = await this.memory.insights.findById(id, user.id);
|
||||
if (!insight) throw new NotFoundException('Insight not found');
|
||||
return insight;
|
||||
}
|
||||
|
||||
@Post('insights')
|
||||
async createInsight(@Query('userId') userId: string, @Body() dto: CreateInsightDto) {
|
||||
async createInsight(@CurrentUser() user: { id: string }, @Body() dto: CreateInsightDto) {
|
||||
const embedding = this.embeddings.available
|
||||
? await this.embeddings.embed(dto.content)
|
||||
: undefined;
|
||||
|
||||
return this.memory.insights.create({
|
||||
userId,
|
||||
userId: user.id,
|
||||
content: dto.content,
|
||||
source: dto.source,
|
||||
category: dto.category,
|
||||
@@ -96,15 +97,15 @@ export class MemoryController {
|
||||
|
||||
@Delete('insights/:id')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async removeInsight(@Param('id') id: string) {
|
||||
const deleted = await this.memory.insights.remove(id);
|
||||
async removeInsight(@CurrentUser() user: { id: string }, @Param('id') id: string) {
|
||||
const deleted = await this.memory.insights.remove(id, user.id);
|
||||
if (!deleted) throw new NotFoundException('Insight not found');
|
||||
}
|
||||
|
||||
// ─── Search ─────────────────────────────────────────────────────────
|
||||
|
||||
@Post('search')
|
||||
async searchMemory(@Query('userId') userId: string, @Body() dto: SearchMemoryDto) {
|
||||
async searchMemory(@CurrentUser() user: { id: string }, @Body() dto: SearchMemoryDto) {
|
||||
if (!this.embeddings.available) {
|
||||
return {
|
||||
query: dto.query,
|
||||
@@ -115,7 +116,7 @@ export class MemoryController {
|
||||
|
||||
const queryEmbedding = await this.embeddings.embed(dto.query);
|
||||
const results = await this.memory.insights.searchByEmbedding(
|
||||
userId,
|
||||
user.id,
|
||||
queryEmbedding,
|
||||
dto.limit ?? 10,
|
||||
dto.maxDistance ?? 0.8,
|
||||
|
||||
@@ -15,37 +15,55 @@ import {
|
||||
import type { Brain } from '@mosaic/brain';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import type { CreateMissionDto, UpdateMissionDto } from './missions.dto.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import {
|
||||
CreateMissionDto,
|
||||
UpdateMissionDto,
|
||||
CreateMissionTaskDto,
|
||||
UpdateMissionTaskDto,
|
||||
} from './missions.dto.js';
|
||||
|
||||
@Controller('api/missions')
|
||||
@UseGuards(AuthGuard)
|
||||
export class MissionsController {
|
||||
constructor(@Inject(BRAIN) private readonly brain: Brain) {}
|
||||
|
||||
// ── Missions CRUD (user-scoped) ──
|
||||
|
||||
@Get()
|
||||
async list() {
|
||||
return this.brain.missions.findAll();
|
||||
async list(@CurrentUser() user: { id: string }) {
|
||||
return this.brain.missions.findAllByUser(user.id);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
async findOne(@Param('id') id: string) {
|
||||
const mission = await this.brain.missions.findById(id);
|
||||
async findOne(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(id, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
return mission;
|
||||
}
|
||||
|
||||
@Post()
|
||||
async create(@Body() dto: CreateMissionDto) {
|
||||
async create(@Body() dto: CreateMissionDto, @CurrentUser() user: { id: string }) {
|
||||
return this.brain.missions.create({
|
||||
name: dto.name,
|
||||
description: dto.description,
|
||||
projectId: dto.projectId,
|
||||
userId: user.id,
|
||||
phase: dto.phase,
|
||||
milestones: dto.milestones,
|
||||
config: dto.config,
|
||||
status: dto.status,
|
||||
});
|
||||
}
|
||||
|
||||
@Patch(':id')
|
||||
async update(@Param('id') id: string, @Body() dto: UpdateMissionDto) {
|
||||
async update(
|
||||
@Param('id') id: string,
|
||||
@Body() dto: UpdateMissionDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const existing = await this.brain.missions.findByIdAndUser(id, user.id);
|
||||
if (!existing) throw new NotFoundException('Mission not found');
|
||||
const mission = await this.brain.missions.update(id, dto);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
return mission;
|
||||
@@ -53,8 +71,82 @@ export class MissionsController {
|
||||
|
||||
@Delete(':id')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async remove(@Param('id') id: string) {
|
||||
async remove(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
const existing = await this.brain.missions.findByIdAndUser(id, user.id);
|
||||
if (!existing) throw new NotFoundException('Mission not found');
|
||||
const deleted = await this.brain.missions.remove(id);
|
||||
if (!deleted) throw new NotFoundException('Mission not found');
|
||||
}
|
||||
|
||||
// ── Mission Tasks sub-routes ──
|
||||
|
||||
@Get(':missionId/tasks')
|
||||
async listTasks(@Param('missionId') missionId: string, @CurrentUser() user: { id: string }) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(missionId, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
return this.brain.missionTasks.findByMissionAndUser(missionId, user.id);
|
||||
}
|
||||
|
||||
@Get(':missionId/tasks/:taskId')
|
||||
async getTask(
|
||||
@Param('missionId') missionId: string,
|
||||
@Param('taskId') taskId: string,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(missionId, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
const task = await this.brain.missionTasks.findByIdAndUser(taskId, user.id);
|
||||
if (!task) throw new NotFoundException('Mission task not found');
|
||||
return task;
|
||||
}
|
||||
|
||||
@Post(':missionId/tasks')
|
||||
async createTask(
|
||||
@Param('missionId') missionId: string,
|
||||
@Body() dto: CreateMissionTaskDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(missionId, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
return this.brain.missionTasks.create({
|
||||
missionId,
|
||||
taskId: dto.taskId,
|
||||
userId: user.id,
|
||||
status: dto.status,
|
||||
description: dto.description,
|
||||
notes: dto.notes,
|
||||
pr: dto.pr,
|
||||
});
|
||||
}
|
||||
|
||||
@Patch(':missionId/tasks/:taskId')
|
||||
async updateTask(
|
||||
@Param('missionId') missionId: string,
|
||||
@Param('taskId') taskId: string,
|
||||
@Body() dto: UpdateMissionTaskDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(missionId, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
const existing = await this.brain.missionTasks.findByIdAndUser(taskId, user.id);
|
||||
if (!existing) throw new NotFoundException('Mission task not found');
|
||||
const updated = await this.brain.missionTasks.update(taskId, dto);
|
||||
if (!updated) throw new NotFoundException('Mission task not found');
|
||||
return updated;
|
||||
}
|
||||
|
||||
@Delete(':missionId/tasks/:taskId')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async removeTask(
|
||||
@Param('missionId') missionId: string,
|
||||
@Param('taskId') taskId: string,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
const mission = await this.brain.missions.findByIdAndUser(missionId, user.id);
|
||||
if (!mission) throw new NotFoundException('Mission not found');
|
||||
const existing = await this.brain.missionTasks.findByIdAndUser(taskId, user.id);
|
||||
if (!existing) throw new NotFoundException('Mission task not found');
|
||||
const deleted = await this.brain.missionTasks.remove(taskId);
|
||||
if (!deleted) throw new NotFoundException('Mission task not found');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,123 @@
|
||||
export interface CreateMissionDto {
|
||||
name: string;
|
||||
import { IsArray, IsIn, IsObject, IsOptional, IsString, IsUUID, MaxLength } from 'class-validator';
|
||||
|
||||
const missionStatuses = ['planning', 'active', 'paused', 'completed', 'failed'] as const;
|
||||
const taskStatuses = ['not-started', 'in-progress', 'blocked', 'done', 'cancelled'] as const;
|
||||
|
||||
export class CreateMissionDto {
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name!: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
projectId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(missionStatuses)
|
||||
status?: 'planning' | 'active' | 'paused' | 'completed' | 'failed';
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
phase?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
milestones?: Record<string, unknown>[];
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface UpdateMissionDto {
|
||||
export class UpdateMissionDto {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
projectId?: string | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(missionStatuses)
|
||||
status?: 'planning' | 'active' | 'paused' | 'completed' | 'failed';
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
phase?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
milestones?: Record<string, unknown>[];
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
config?: Record<string, unknown>;
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
metadata?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
export class CreateMissionTaskDto {
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
taskId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(taskStatuses)
|
||||
status?: 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
notes?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
pr?: string;
|
||||
}
|
||||
|
||||
export class UpdateMissionTaskDto {
|
||||
@IsOptional()
|
||||
@IsUUID()
|
||||
taskId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(taskStatuses)
|
||||
status?: 'not-started' | 'in-progress' | 'blocked' | 'done' | 'cancelled';
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
notes?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
pr?: string;
|
||||
}
|
||||
|
||||
11
apps/gateway/src/plugin/plugin.interface.ts
Normal file
11
apps/gateway/src/plugin/plugin.interface.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
export interface IChannelPlugin {
|
||||
readonly name: string;
|
||||
start(): Promise<void>;
|
||||
stop(): Promise<void>;
|
||||
/** Called when a new project is bootstrapped. Return channelId if a channel was created. */
|
||||
onProjectCreated?(project: {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
}): Promise<{ channelId: string } | null>;
|
||||
}
|
||||
117
apps/gateway/src/plugin/plugin.module.ts
Normal file
117
apps/gateway/src/plugin/plugin.module.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
import {
|
||||
Global,
|
||||
Inject,
|
||||
Logger,
|
||||
Module,
|
||||
type OnModuleDestroy,
|
||||
type OnModuleInit,
|
||||
} from '@nestjs/common';
|
||||
import { DiscordPlugin } from '@mosaic/discord-plugin';
|
||||
import { TelegramPlugin } from '@mosaic/telegram-plugin';
|
||||
import { PluginService } from './plugin.service.js';
|
||||
import type { IChannelPlugin } from './plugin.interface.js';
|
||||
import { PLUGIN_REGISTRY } from './plugin.tokens.js';
|
||||
|
||||
class DiscordChannelPluginAdapter implements IChannelPlugin {
|
||||
readonly name = 'discord';
|
||||
|
||||
constructor(private readonly plugin: DiscordPlugin) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
await this.plugin.start();
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
await this.plugin.stop();
|
||||
}
|
||||
|
||||
async onProjectCreated(project: {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
}): Promise<{ channelId: string } | null> {
|
||||
return this.plugin.createProjectChannel(project);
|
||||
}
|
||||
}
|
||||
|
||||
class TelegramChannelPluginAdapter implements IChannelPlugin {
|
||||
readonly name = 'telegram';
|
||||
|
||||
constructor(private readonly plugin: TelegramPlugin) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
await this.plugin.start();
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
await this.plugin.stop();
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_GATEWAY_URL = 'http://localhost:4000';
|
||||
|
||||
function createPluginRegistry(): IChannelPlugin[] {
|
||||
const plugins: IChannelPlugin[] = [];
|
||||
const discordToken = process.env['DISCORD_BOT_TOKEN'];
|
||||
const discordGuildId = process.env['DISCORD_GUILD_ID'];
|
||||
const discordGatewayUrl = process.env['DISCORD_GATEWAY_URL'] ?? DEFAULT_GATEWAY_URL;
|
||||
|
||||
if (discordToken) {
|
||||
plugins.push(
|
||||
new DiscordChannelPluginAdapter(
|
||||
new DiscordPlugin({
|
||||
token: discordToken,
|
||||
guildId: discordGuildId,
|
||||
gatewayUrl: discordGatewayUrl,
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
const telegramToken = process.env['TELEGRAM_BOT_TOKEN'];
|
||||
const telegramGatewayUrl = process.env['TELEGRAM_GATEWAY_URL'] ?? DEFAULT_GATEWAY_URL;
|
||||
|
||||
if (telegramToken) {
|
||||
plugins.push(
|
||||
new TelegramChannelPluginAdapter(
|
||||
new TelegramPlugin({
|
||||
token: telegramToken,
|
||||
gatewayUrl: telegramGatewayUrl,
|
||||
}),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return plugins;
|
||||
}
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
providers: [
|
||||
{
|
||||
provide: PLUGIN_REGISTRY,
|
||||
useFactory: (): IChannelPlugin[] => createPluginRegistry(),
|
||||
},
|
||||
PluginService,
|
||||
],
|
||||
exports: [PluginService, PLUGIN_REGISTRY],
|
||||
})
|
||||
export class PluginModule implements OnModuleInit, OnModuleDestroy {
|
||||
private readonly logger = new Logger(PluginModule.name);
|
||||
|
||||
constructor(@Inject(PLUGIN_REGISTRY) private readonly plugins: IChannelPlugin[]) {}
|
||||
|
||||
async onModuleInit(): Promise<void> {
|
||||
for (const plugin of this.plugins) {
|
||||
this.logger.log(`Starting plugin: ${plugin.name}`);
|
||||
await plugin.start();
|
||||
}
|
||||
}
|
||||
|
||||
async onModuleDestroy(): Promise<void> {
|
||||
for (const plugin of [...this.plugins].reverse()) {
|
||||
this.logger.log(`Stopping plugin: ${plugin.name}`);
|
||||
await plugin.stop();
|
||||
}
|
||||
}
|
||||
}
|
||||
16
apps/gateway/src/plugin/plugin.service.ts
Normal file
16
apps/gateway/src/plugin/plugin.service.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import { Inject, Injectable } from '@nestjs/common';
|
||||
import { PLUGIN_REGISTRY } from './plugin.tokens.js';
|
||||
import type { IChannelPlugin } from './plugin.interface.js';
|
||||
|
||||
@Injectable()
|
||||
export class PluginService {
|
||||
constructor(@Inject(PLUGIN_REGISTRY) private readonly plugins: IChannelPlugin[]) {}
|
||||
|
||||
getPlugins(): IChannelPlugin[] {
|
||||
return this.plugins;
|
||||
}
|
||||
|
||||
getPlugin(name: string): IChannelPlugin | undefined {
|
||||
return this.plugins.find((plugin: IChannelPlugin) => plugin.name === name);
|
||||
}
|
||||
}
|
||||
1
apps/gateway/src/plugin/plugin.tokens.ts
Normal file
1
apps/gateway/src/plugin/plugin.tokens.ts
Normal file
@@ -0,0 +1 @@
|
||||
export const PLUGIN_REGISTRY = Symbol('PLUGIN_REGISTRY');
|
||||
44
apps/gateway/src/preferences/preferences.controller.ts
Normal file
44
apps/gateway/src/preferences/preferences.controller.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import {
|
||||
Body,
|
||||
Controller,
|
||||
Delete,
|
||||
Get,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
Inject,
|
||||
Param,
|
||||
Post,
|
||||
UseGuards,
|
||||
} from '@nestjs/common';
|
||||
import { PreferencesService } from './preferences.service.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
|
||||
@Controller('api/preferences')
|
||||
@UseGuards(AuthGuard)
|
||||
export class PreferencesController {
|
||||
constructor(@Inject(PreferencesService) private readonly preferences: PreferencesService) {}
|
||||
|
||||
@Get()
|
||||
async show(@CurrentUser() user: { id: string }): Promise<Record<string, unknown>> {
|
||||
return this.preferences.getEffective(user.id);
|
||||
}
|
||||
|
||||
@Post()
|
||||
@HttpCode(HttpStatus.OK)
|
||||
async set(
|
||||
@CurrentUser() user: { id: string },
|
||||
@Body() body: { key: string; value: unknown },
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
return this.preferences.set(user.id, body.key, body.value);
|
||||
}
|
||||
|
||||
@Delete(':key')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
async reset(
|
||||
@CurrentUser() user: { id: string },
|
||||
@Param('key') key: string,
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
return this.preferences.reset(user.id, key);
|
||||
}
|
||||
}
|
||||
12
apps/gateway/src/preferences/preferences.module.ts
Normal file
12
apps/gateway/src/preferences/preferences.module.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { Global, Module } from '@nestjs/common';
|
||||
import { PreferencesService } from './preferences.service.js';
|
||||
import { PreferencesController } from './preferences.controller.js';
|
||||
import { SystemOverrideService } from './system-override.service.js';
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
controllers: [PreferencesController],
|
||||
providers: [PreferencesService, SystemOverrideService],
|
||||
exports: [PreferencesService, SystemOverrideService],
|
||||
})
|
||||
export class PreferencesModule {}
|
||||
152
apps/gateway/src/preferences/preferences.service.spec.ts
Normal file
152
apps/gateway/src/preferences/preferences.service.spec.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { PreferencesService, PLATFORM_DEFAULTS, IMMUTABLE_KEYS } from './preferences.service.js';
|
||||
import type { Db } from '@mosaic/db';
|
||||
|
||||
/**
|
||||
* Build a mock Drizzle DB where the select chain supports:
|
||||
* db.select().from().where() → resolves to `listRows`
|
||||
* db.insert().values().onConflictDoUpdate() → resolves to []
|
||||
*/
|
||||
function makeMockDb(listRows: Array<{ key: string; value: unknown }> = []): Db {
|
||||
const chainWithLimit = {
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
then: (resolve: (v: typeof listRows) => unknown) => Promise.resolve(listRows).then(resolve),
|
||||
};
|
||||
const selectFrom = {
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnValue(chainWithLimit),
|
||||
};
|
||||
const deleteResult = {
|
||||
where: vi.fn().mockResolvedValue([]),
|
||||
};
|
||||
// Single-round-trip upsert chain: insert().values().onConflictDoUpdate()
|
||||
const insertResult = {
|
||||
values: vi.fn().mockReturnThis(),
|
||||
onConflictDoUpdate: vi.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
return {
|
||||
select: vi.fn().mockReturnValue(selectFrom),
|
||||
delete: vi.fn().mockReturnValue(deleteResult),
|
||||
insert: vi.fn().mockReturnValue(insertResult),
|
||||
} as unknown as Db;
|
||||
}
|
||||
|
||||
describe('PreferencesService', () => {
|
||||
describe('getEffective', () => {
|
||||
it('returns platform defaults when user has no overrides', async () => {
|
||||
const db = makeMockDb([]);
|
||||
const service = new PreferencesService(db);
|
||||
const result = await service.getEffective('user-1');
|
||||
|
||||
expect(result['agent.thinkingLevel']).toBe('auto');
|
||||
expect(result['agent.streamingEnabled']).toBe(true);
|
||||
expect(result['session.autoCompactEnabled']).toBe(true);
|
||||
expect(result['session.autoCompactThreshold']).toBe(0.8);
|
||||
});
|
||||
|
||||
it('applies user overrides for mutable keys', async () => {
|
||||
const db = makeMockDb([
|
||||
{ key: 'agent.thinkingLevel', value: 'high' },
|
||||
{ key: 'response.language', value: 'es' },
|
||||
]);
|
||||
|
||||
const service = new PreferencesService(db);
|
||||
const result = await service.getEffective('user-1');
|
||||
|
||||
expect(result['agent.thinkingLevel']).toBe('high');
|
||||
expect(result['response.language']).toBe('es');
|
||||
});
|
||||
|
||||
it('ignores user overrides for immutable keys — enforcement always wins', async () => {
|
||||
const db = makeMockDb([
|
||||
{ key: 'limits.maxThinkingLevel', value: 'high' },
|
||||
{ key: 'limits.rateLimit', value: 9999 },
|
||||
]);
|
||||
|
||||
const service = new PreferencesService(db);
|
||||
const result = await service.getEffective('user-1');
|
||||
|
||||
// Should still be null (platform default), not the user-supplied values
|
||||
expect(result['limits.maxThinkingLevel']).toBeNull();
|
||||
expect(result['limits.rateLimit']).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('set', () => {
|
||||
it('returns error when attempting to override an immutable key', async () => {
|
||||
const db = makeMockDb();
|
||||
const service = new PreferencesService(db);
|
||||
|
||||
const result = await service.set('user-1', 'limits.maxThinkingLevel', 'high');
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('platform enforcement');
|
||||
});
|
||||
|
||||
it('returns error when attempting to override limits.rateLimit', async () => {
|
||||
const db = makeMockDb();
|
||||
const service = new PreferencesService(db);
|
||||
|
||||
const result = await service.set('user-1', 'limits.rateLimit', 100);
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('platform enforcement');
|
||||
});
|
||||
|
||||
it('upserts a mutable preference and returns success', async () => {
|
||||
// Single-round-trip INSERT … ON CONFLICT DO UPDATE path.
|
||||
const db = makeMockDb([]);
|
||||
const service = new PreferencesService(db);
|
||||
const result = await service.set('user-1', 'agent.thinkingLevel', 'high');
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('"agent.thinkingLevel"');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reset', () => {
|
||||
it('returns error when attempting to reset an immutable key', async () => {
|
||||
const db = makeMockDb();
|
||||
const service = new PreferencesService(db);
|
||||
|
||||
const result = await service.reset('user-1', 'limits.rateLimit');
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.message).toContain('platform enforcement');
|
||||
});
|
||||
|
||||
it('deletes user override and returns default value in message', async () => {
|
||||
const db = makeMockDb();
|
||||
const service = new PreferencesService(db);
|
||||
const result = await service.reset('user-1', 'agent.thinkingLevel');
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.message).toContain('"auto"'); // platform default for agent.thinkingLevel
|
||||
});
|
||||
});
|
||||
|
||||
describe('IMMUTABLE_KEYS', () => {
|
||||
it('contains only the enforcement keys', () => {
|
||||
expect(IMMUTABLE_KEYS.has('limits.maxThinkingLevel')).toBe(true);
|
||||
expect(IMMUTABLE_KEYS.has('limits.rateLimit')).toBe(true);
|
||||
expect(IMMUTABLE_KEYS.has('agent.thinkingLevel')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('PLATFORM_DEFAULTS', () => {
|
||||
it('has all expected keys', () => {
|
||||
const expectedKeys = [
|
||||
'agent.defaultModel',
|
||||
'agent.thinkingLevel',
|
||||
'agent.streamingEnabled',
|
||||
'response.language',
|
||||
'response.codeAnnotations',
|
||||
'safety.confirmDestructiveTools',
|
||||
'session.autoCompactThreshold',
|
||||
'session.autoCompactEnabled',
|
||||
'limits.maxThinkingLevel',
|
||||
'limits.rateLimit',
|
||||
];
|
||||
for (const key of expectedKeys) {
|
||||
expect(Object.prototype.hasOwnProperty.call(PLATFORM_DEFAULTS, key)).toBe(true);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
118
apps/gateway/src/preferences/preferences.service.ts
Normal file
118
apps/gateway/src/preferences/preferences.service.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { eq, and, sql, type Db, preferences as preferencesTable } from '@mosaic/db';
|
||||
import { DB } from '../database/database.module.js';
|
||||
|
||||
export const PLATFORM_DEFAULTS: Record<string, unknown> = {
|
||||
'agent.defaultModel': null,
|
||||
'agent.thinkingLevel': 'auto',
|
||||
'agent.streamingEnabled': true,
|
||||
'response.language': 'auto',
|
||||
'response.codeAnnotations': true,
|
||||
'safety.confirmDestructiveTools': true,
|
||||
'session.autoCompactThreshold': 0.8,
|
||||
'session.autoCompactEnabled': true,
|
||||
'limits.maxThinkingLevel': null,
|
||||
'limits.rateLimit': null,
|
||||
};
|
||||
|
||||
export const IMMUTABLE_KEYS = new Set<string>(['limits.maxThinkingLevel', 'limits.rateLimit']);
|
||||
|
||||
@Injectable()
|
||||
export class PreferencesService {
|
||||
private readonly logger = new Logger(PreferencesService.name);
|
||||
|
||||
constructor(@Inject(DB) private readonly db: Db) {}
|
||||
|
||||
/**
|
||||
* Returns the effective preference set for a user:
|
||||
* Platform defaults → user overrides (mutable keys only) → enforcements re-applied last
|
||||
*/
|
||||
async getEffective(userId: string): Promise<Record<string, unknown>> {
|
||||
const userPrefs = await this.getUserPrefs(userId);
|
||||
const result: Record<string, unknown> = { ...PLATFORM_DEFAULTS };
|
||||
|
||||
for (const [key, value] of Object.entries(userPrefs)) {
|
||||
if (!IMMUTABLE_KEYS.has(key)) {
|
||||
result[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-apply immutable keys (enforcements always win)
|
||||
for (const key of IMMUTABLE_KEYS) {
|
||||
result[key] = PLATFORM_DEFAULTS[key];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async set(
|
||||
userId: string,
|
||||
key: string,
|
||||
value: unknown,
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
if (IMMUTABLE_KEYS.has(key)) {
|
||||
return {
|
||||
success: false,
|
||||
message: `Cannot override "${key}" — this is a platform enforcement. Contact your admin.`,
|
||||
};
|
||||
}
|
||||
|
||||
await this.upsertPref(userId, key, value);
|
||||
return { success: true, message: `Preference "${key}" set to ${JSON.stringify(value)}.` };
|
||||
}
|
||||
|
||||
async reset(userId: string, key: string): Promise<{ success: boolean; message: string }> {
|
||||
if (IMMUTABLE_KEYS.has(key)) {
|
||||
return { success: false, message: `Cannot reset "${key}" — it is a platform enforcement.` };
|
||||
}
|
||||
|
||||
await this.deletePref(userId, key);
|
||||
const defaultVal = PLATFORM_DEFAULTS[key];
|
||||
return {
|
||||
success: true,
|
||||
message: `Preference "${key}" reset to default: ${JSON.stringify(defaultVal)}.`,
|
||||
};
|
||||
}
|
||||
|
||||
private async getUserPrefs(userId: string): Promise<Record<string, unknown>> {
|
||||
const rows = await this.db
|
||||
.select({ key: preferencesTable.key, value: preferencesTable.value })
|
||||
.from(preferencesTable)
|
||||
.where(eq(preferencesTable.userId, userId));
|
||||
|
||||
const result: Record<string, unknown> = {};
|
||||
for (const row of rows) {
|
||||
result[row.key] = row.value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private async upsertPref(userId: string, key: string, value: unknown): Promise<void> {
|
||||
// Single-round-trip upsert using INSERT … ON CONFLICT DO UPDATE.
|
||||
// Previously this was two queries (SELECT + INSERT/UPDATE), which doubled
|
||||
// the DB round-trips and introduced a TOCTOU window under concurrent writes.
|
||||
await this.db
|
||||
.insert(preferencesTable)
|
||||
.values({
|
||||
userId,
|
||||
key,
|
||||
value: value as never,
|
||||
mutable: true,
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [preferencesTable.userId, preferencesTable.key],
|
||||
set: {
|
||||
value: sql`excluded.value`,
|
||||
updatedAt: sql`now()`,
|
||||
},
|
||||
});
|
||||
this.logger.debug(`Upserted preference "${key}" for user ${userId}`);
|
||||
}
|
||||
|
||||
private async deletePref(userId: string, key: string): Promise<void> {
|
||||
await this.db
|
||||
.delete(preferencesTable)
|
||||
.where(and(eq(preferencesTable.userId, userId), eq(preferencesTable.key, key)));
|
||||
this.logger.debug(`Deleted preference "${key}" for user ${userId}`);
|
||||
}
|
||||
}
|
||||
131
apps/gateway/src/preferences/system-override.service.ts
Normal file
131
apps/gateway/src/preferences/system-override.service.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { createQueue, type QueueHandle } from '@mosaic/queue';
|
||||
|
||||
const SESSION_SYSTEM_KEY = (sessionId: string) => `mosaic:session:${sessionId}:system`;
|
||||
const SESSION_SYSTEM_FRAGMENTS_KEY = (sessionId: string) =>
|
||||
`mosaic:session:${sessionId}:system:fragments`;
|
||||
const SYSTEM_OVERRIDE_TTL_SECONDS = 604800; // 7 days
|
||||
|
||||
interface OverrideFragment {
|
||||
text: string;
|
||||
addedAt: number;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class SystemOverrideService {
|
||||
private readonly logger = new Logger(SystemOverrideService.name);
|
||||
private readonly handle: QueueHandle;
|
||||
|
||||
constructor() {
|
||||
this.handle = createQueue();
|
||||
}
|
||||
|
||||
async set(sessionId: string, override: string): Promise<void> {
|
||||
// Load existing fragments
|
||||
const existing = await this.handle.redis.get(SESSION_SYSTEM_FRAGMENTS_KEY(sessionId));
|
||||
const fragments: OverrideFragment[] = existing
|
||||
? (JSON.parse(existing) as OverrideFragment[])
|
||||
: [];
|
||||
|
||||
// Append new fragment
|
||||
fragments.push({ text: override, addedAt: Date.now() });
|
||||
|
||||
// Condense fragments into one coherent override
|
||||
const texts = fragments.map((f) => f.text);
|
||||
const condensed = await this.condenseOverrides(texts);
|
||||
|
||||
// Store both: fragments array and condensed result
|
||||
const pipeline = this.handle.redis.pipeline();
|
||||
pipeline.setex(
|
||||
SESSION_SYSTEM_FRAGMENTS_KEY(sessionId),
|
||||
SYSTEM_OVERRIDE_TTL_SECONDS,
|
||||
JSON.stringify(fragments),
|
||||
);
|
||||
pipeline.setex(SESSION_SYSTEM_KEY(sessionId), SYSTEM_OVERRIDE_TTL_SECONDS, condensed);
|
||||
await pipeline.exec();
|
||||
|
||||
this.logger.debug(
|
||||
`Set system override for session ${sessionId} (${fragments.length} fragment(s), TTL=${SYSTEM_OVERRIDE_TTL_SECONDS}s)`,
|
||||
);
|
||||
}
|
||||
|
||||
async get(sessionId: string): Promise<string | null> {
|
||||
return this.handle.redis.get(SESSION_SYSTEM_KEY(sessionId));
|
||||
}
|
||||
|
||||
async renew(sessionId: string): Promise<void> {
|
||||
const pipeline = this.handle.redis.pipeline();
|
||||
pipeline.expire(SESSION_SYSTEM_KEY(sessionId), SYSTEM_OVERRIDE_TTL_SECONDS);
|
||||
pipeline.expire(SESSION_SYSTEM_FRAGMENTS_KEY(sessionId), SYSTEM_OVERRIDE_TTL_SECONDS);
|
||||
await pipeline.exec();
|
||||
}
|
||||
|
||||
async clear(sessionId: string): Promise<void> {
|
||||
await this.handle.redis.del(
|
||||
SESSION_SYSTEM_KEY(sessionId),
|
||||
SESSION_SYSTEM_FRAGMENTS_KEY(sessionId),
|
||||
);
|
||||
this.logger.debug(`Cleared system override for session ${sessionId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge an array of override fragments into one coherent string.
|
||||
* If only one fragment exists, returns it as-is.
|
||||
* For multiple fragments, calls Haiku to produce a merged instruction.
|
||||
* Falls back to newline concatenation if the LLM call fails.
|
||||
*/
|
||||
async condenseOverrides(fragments: string[]): Promise<string> {
|
||||
if (fragments.length === 0) return '';
|
||||
if (fragments.length === 1) return fragments[0]!;
|
||||
|
||||
const numbered = fragments.map((f, i) => `${i + 1}. ${f}`).join('\n');
|
||||
const prompt =
|
||||
`Merge these system prompt instructions into one coherent paragraph. ` +
|
||||
`If instructions conflict, favor the most recently added (last in the list). ` +
|
||||
`Be concise — output only the merged instruction, nothing else.\n\n` +
|
||||
`Instructions (oldest first):\n${numbered}`;
|
||||
|
||||
const apiKey = process.env['ANTHROPIC_API_KEY'];
|
||||
if (!apiKey) {
|
||||
this.logger.warn('ANTHROPIC_API_KEY not set — falling back to newline concatenation');
|
||||
return fragments.join('\n');
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('https://api.anthropic.com/v1/messages', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': apiKey,
|
||||
'anthropic-version': '2023-06-01',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: 'claude-haiku-4-5-20251001',
|
||||
max_tokens: 1024,
|
||||
messages: [{ role: 'user', content: prompt }],
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Anthropic API error ${response.status}: ${errorText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
content: Array<{ type: string; text: string }>;
|
||||
};
|
||||
|
||||
const textBlock = data.content.find((c) => c.type === 'text');
|
||||
if (!textBlock) {
|
||||
throw new Error('No text block in Anthropic response');
|
||||
}
|
||||
|
||||
return textBlock.text.trim();
|
||||
} catch (err) {
|
||||
this.logger.error(
|
||||
`Condensation LLM call failed — falling back to newline concatenation: ${String(err)}`,
|
||||
);
|
||||
return fragments.join('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import {
|
||||
Body,
|
||||
Controller,
|
||||
Delete,
|
||||
ForbiddenException,
|
||||
Get,
|
||||
HttpCode,
|
||||
HttpStatus,
|
||||
@@ -16,23 +17,25 @@ import type { Brain } from '@mosaic/brain';
|
||||
import { BRAIN } from '../brain/brain.tokens.js';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import type { CreateProjectDto, UpdateProjectDto } from './projects.dto.js';
|
||||
import { TeamsService } from '../workspace/teams.service.js';
|
||||
import { CreateProjectDto, UpdateProjectDto } from './projects.dto.js';
|
||||
|
||||
@Controller('api/projects')
|
||||
@UseGuards(AuthGuard)
|
||||
export class ProjectsController {
|
||||
constructor(@Inject(BRAIN) private readonly brain: Brain) {}
|
||||
constructor(
|
||||
@Inject(BRAIN) private readonly brain: Brain,
|
||||
private readonly teamsService: TeamsService,
|
||||
) {}
|
||||
|
||||
@Get()
|
||||
async list() {
|
||||
return this.brain.projects.findAll();
|
||||
async list(@CurrentUser() user: { id: string }) {
|
||||
return this.brain.projects.findAllForUser(user.id);
|
||||
}
|
||||
|
||||
@Get(':id')
|
||||
async findOne(@Param('id') id: string) {
|
||||
const project = await this.brain.projects.findById(id);
|
||||
if (!project) throw new NotFoundException('Project not found');
|
||||
return project;
|
||||
async findOne(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
return this.getAccessibleProject(id, user.id);
|
||||
}
|
||||
|
||||
@Post()
|
||||
@@ -46,7 +49,12 @@ export class ProjectsController {
|
||||
}
|
||||
|
||||
@Patch(':id')
|
||||
async update(@Param('id') id: string, @Body() dto: UpdateProjectDto) {
|
||||
async update(
|
||||
@Param('id') id: string,
|
||||
@Body() dto: UpdateProjectDto,
|
||||
@CurrentUser() user: { id: string },
|
||||
) {
|
||||
await this.getAccessibleProject(id, user.id);
|
||||
const project = await this.brain.projects.update(id, dto);
|
||||
if (!project) throw new NotFoundException('Project not found');
|
||||
return project;
|
||||
@@ -54,8 +62,22 @@ export class ProjectsController {
|
||||
|
||||
@Delete(':id')
|
||||
@HttpCode(HttpStatus.NO_CONTENT)
|
||||
async remove(@Param('id') id: string) {
|
||||
async remove(@Param('id') id: string, @CurrentUser() user: { id: string }) {
|
||||
await this.getAccessibleProject(id, user.id);
|
||||
const deleted = await this.brain.projects.remove(id);
|
||||
if (!deleted) throw new NotFoundException('Project not found');
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify the requesting user can access the project — either as the direct
|
||||
* owner or as a member of the owning team. Throws NotFoundException when the
|
||||
* project does not exist and ForbiddenException when the user lacks access.
|
||||
*/
|
||||
private async getAccessibleProject(id: string, userId: string) {
|
||||
const project = await this.brain.projects.findById(id);
|
||||
if (!project) throw new NotFoundException('Project not found');
|
||||
const canAccess = await this.teamsService.canAccessProject(userId, id);
|
||||
if (!canAccess) throw new ForbiddenException('Project does not belong to the current user');
|
||||
return project;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,38 @@
|
||||
export interface CreateProjectDto {
|
||||
name: string;
|
||||
import { IsIn, IsObject, IsOptional, IsString, MaxLength } from 'class-validator';
|
||||
|
||||
const projectStatuses = ['active', 'paused', 'completed', 'archived'] as const;
|
||||
|
||||
export class CreateProjectDto {
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name!: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(projectStatuses)
|
||||
status?: 'active' | 'paused' | 'completed' | 'archived';
|
||||
}
|
||||
|
||||
export interface UpdateProjectDto {
|
||||
export class UpdateProjectDto {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(255)
|
||||
name?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(10_000)
|
||||
description?: string | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsIn(projectStatuses)
|
||||
status?: 'active' | 'paused' | 'completed' | 'archived';
|
||||
|
||||
@IsOptional()
|
||||
@IsObject()
|
||||
metadata?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { Module } from '@nestjs/common';
|
||||
import { ProjectsController } from './projects.controller.js';
|
||||
import { WorkspaceModule } from '../workspace/workspace.module.js';
|
||||
|
||||
@Module({
|
||||
imports: [WorkspaceModule],
|
||||
controllers: [ProjectsController],
|
||||
})
|
||||
export class ProjectsModule {}
|
||||
|
||||
20
apps/gateway/src/reload/mosaic-plugin.interface.ts
Normal file
20
apps/gateway/src/reload/mosaic-plugin.interface.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
export interface MosaicPlugin {
|
||||
/** Called when the plugin is loaded/reloaded */
|
||||
onLoad(): Promise<void>;
|
||||
|
||||
/** Called before the plugin is unloaded during reload */
|
||||
onUnload(): Promise<void>;
|
||||
|
||||
/** Plugin identifier for registry */
|
||||
readonly pluginName: string;
|
||||
}
|
||||
|
||||
export function isMosaicPlugin(obj: unknown): obj is MosaicPlugin {
|
||||
return (
|
||||
typeof obj === 'object' &&
|
||||
obj !== null &&
|
||||
typeof (obj as MosaicPlugin).onLoad === 'function' &&
|
||||
typeof (obj as MosaicPlugin).onUnload === 'function' &&
|
||||
typeof (obj as MosaicPlugin).pluginName === 'string'
|
||||
);
|
||||
}
|
||||
22
apps/gateway/src/reload/reload.controller.ts
Normal file
22
apps/gateway/src/reload/reload.controller.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { Controller, HttpCode, HttpStatus, Inject, Post, UseGuards } from '@nestjs/common';
|
||||
import type { SystemReloadPayload } from '@mosaic/types';
|
||||
import { AdminGuard } from '../admin/admin.guard.js';
|
||||
import { ChatGateway } from '../chat/chat.gateway.js';
|
||||
import { ReloadService } from './reload.service.js';
|
||||
|
||||
@Controller('api/admin')
|
||||
@UseGuards(AdminGuard)
|
||||
export class ReloadController {
|
||||
constructor(
|
||||
@Inject(ReloadService) private readonly reloadService: ReloadService,
|
||||
@Inject(ChatGateway) private readonly chatGateway: ChatGateway,
|
||||
) {}
|
||||
|
||||
@Post('reload')
|
||||
@HttpCode(HttpStatus.OK)
|
||||
async triggerReload(): Promise<SystemReloadPayload> {
|
||||
const result = await this.reloadService.reload('rest');
|
||||
this.chatGateway.broadcastReload(result);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user