Merge pull request 'chore: upgrade Node.js runtime to v24 across codebase' (#419) from fix/auth-frontend-remediation into main
Some checks failed
Some checks failed
Reviewed-on: #419
This commit was merged in pull request #419.
This commit is contained in:
257
.env.example
257
.env.example
@@ -19,13 +19,18 @@ NEXT_PUBLIC_API_URL=http://localhost:3001
|
||||
# ======================
|
||||
# PostgreSQL Database
|
||||
# ======================
|
||||
# Bundled PostgreSQL (when database profile enabled)
|
||||
# SECURITY: Change POSTGRES_PASSWORD to a strong random password in production
|
||||
DATABASE_URL=postgresql://mosaic:REPLACE_WITH_SECURE_PASSWORD@localhost:5432/mosaic
|
||||
DATABASE_URL=postgresql://mosaic:REPLACE_WITH_SECURE_PASSWORD@postgres:5432/mosaic
|
||||
POSTGRES_USER=mosaic
|
||||
POSTGRES_PASSWORD=REPLACE_WITH_SECURE_PASSWORD
|
||||
POSTGRES_DB=mosaic
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# External PostgreSQL (managed service)
|
||||
# Disable 'database' profile and point DATABASE_URL to your external instance
|
||||
# Example: DATABASE_URL=postgresql://user:pass@rds.amazonaws.com:5432/mosaic
|
||||
|
||||
# PostgreSQL Performance Tuning (Optional)
|
||||
POSTGRES_SHARED_BUFFERS=256MB
|
||||
POSTGRES_EFFECTIVE_CACHE_SIZE=1GB
|
||||
@@ -34,12 +39,18 @@ POSTGRES_MAX_CONNECTIONS=100
|
||||
# ======================
|
||||
# Valkey Cache (Redis-compatible)
|
||||
# ======================
|
||||
VALKEY_URL=redis://localhost:6379
|
||||
VALKEY_HOST=localhost
|
||||
# Bundled Valkey (when cache profile enabled)
|
||||
VALKEY_URL=redis://valkey:6379
|
||||
VALKEY_HOST=valkey
|
||||
VALKEY_PORT=6379
|
||||
# VALKEY_PASSWORD= # Optional: Password for Valkey authentication
|
||||
VALKEY_MAXMEMORY=256mb
|
||||
|
||||
# External Redis/Valkey (managed service)
|
||||
# Disable 'cache' profile and point VALKEY_URL to your external instance
|
||||
# Example: VALKEY_URL=redis://elasticache.amazonaws.com:6379
|
||||
# Example with auth: VALKEY_URL=redis://:password@redis.example.com:6379
|
||||
|
||||
# Knowledge Module Cache Configuration
|
||||
# Set KNOWLEDGE_CACHE_ENABLED=false to disable caching (useful for development)
|
||||
KNOWLEDGE_CACHE_ENABLED=true
|
||||
@@ -49,7 +60,12 @@ KNOWLEDGE_CACHE_TTL=300
|
||||
# ======================
|
||||
# Authentication (Authentik OIDC)
|
||||
# ======================
|
||||
# Authentik Server URLs
|
||||
# Set to 'true' to enable OIDC authentication with Authentik
|
||||
# When enabled, OIDC_ISSUER, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, and OIDC_REDIRECT_URI are required
|
||||
OIDC_ENABLED=false
|
||||
|
||||
# Authentik Server URLs (required when OIDC_ENABLED=true)
|
||||
# OIDC_ISSUER must end with a trailing slash (/)
|
||||
OIDC_ISSUER=https://auth.example.com/application/o/mosaic-stack/
|
||||
OIDC_CLIENT_ID=your-client-id-here
|
||||
OIDC_CLIENT_SECRET=your-client-secret-here
|
||||
@@ -77,6 +93,14 @@ AUTHENTIK_COOKIE_DOMAIN=.localhost
|
||||
AUTHENTIK_PORT_HTTP=9000
|
||||
AUTHENTIK_PORT_HTTPS=9443
|
||||
|
||||
# ======================
|
||||
# CSRF Protection
|
||||
# ======================
|
||||
# CRITICAL: Generate a random secret for CSRF token signing
|
||||
# Required in production; auto-generated in development (not persistent across restarts)
|
||||
# Command to generate: node -e "console.log(require('crypto').randomBytes(32).toString('hex'))"
|
||||
CSRF_SECRET=REPLACE_WITH_64_CHAR_HEX_STRING
|
||||
|
||||
# ======================
|
||||
# JWT Configuration
|
||||
# ======================
|
||||
@@ -85,6 +109,59 @@ AUTHENTIK_PORT_HTTPS=9443
|
||||
JWT_SECRET=REPLACE_WITH_RANDOM_SECRET_MINIMUM_32_CHARS
|
||||
JWT_EXPIRATION=24h
|
||||
|
||||
# ======================
|
||||
# BetterAuth Configuration
|
||||
# ======================
|
||||
# CRITICAL: Generate a random secret key with at least 32 characters
|
||||
# This is used by BetterAuth for session management and CSRF protection
|
||||
# Example: openssl rand -base64 32
|
||||
BETTER_AUTH_SECRET=REPLACE_WITH_RANDOM_SECRET_MINIMUM_32_CHARS
|
||||
|
||||
# Trusted Origins (comma-separated list of additional trusted origins for CORS and auth)
|
||||
# These are added to NEXT_PUBLIC_APP_URL and NEXT_PUBLIC_API_URL automatically
|
||||
TRUSTED_ORIGINS=
|
||||
|
||||
# Cookie Domain (for cross-subdomain session sharing)
|
||||
# Leave empty for single-domain setups. Set to ".example.com" for cross-subdomain.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# ======================
|
||||
# Encryption (Credential Security)
|
||||
# ======================
|
||||
# CRITICAL: Generate a random 32-byte (256-bit) encryption key
|
||||
# This key is used for AES-256-GCM encryption of OAuth tokens and sensitive data
|
||||
# Command to generate: openssl rand -hex 32
|
||||
# SECURITY: Never commit this key to version control
|
||||
# SECURITY: Use different keys for development, staging, and production
|
||||
# SECURITY: Store production keys in a secure secrets manager (see docs/design/credential-security.md)
|
||||
ENCRYPTION_KEY=REPLACE_WITH_64_CHAR_HEX_STRING_GENERATE_WITH_OPENSSL_RAND_HEX_32
|
||||
|
||||
# ======================
|
||||
# OpenBao Secrets Management
|
||||
# ======================
|
||||
# OpenBao provides Transit encryption for sensitive credentials
|
||||
# Enable with: COMPOSE_PROFILES=openbao or COMPOSE_PROFILES=full
|
||||
# Auto-initialized on first run via openbao-init sidecar
|
||||
|
||||
# Bundled OpenBao (when openbao profile enabled)
|
||||
OPENBAO_ADDR=http://openbao:8200
|
||||
OPENBAO_PORT=8200
|
||||
|
||||
# External OpenBao/Vault (managed service)
|
||||
# Disable 'openbao' profile and set OPENBAO_ADDR to your external instance
|
||||
# Example: OPENBAO_ADDR=https://vault.example.com:8200
|
||||
# Example: OPENBAO_ADDR=https://vault.hashicorp.com:8200
|
||||
|
||||
# AppRole Authentication (Optional)
|
||||
# If not set, credentials are read from /openbao/init/approle-credentials volume
|
||||
# Required when using external OpenBao
|
||||
# OPENBAO_ROLE_ID=your-role-id-here
|
||||
# OPENBAO_SECRET_ID=your-secret-id-here
|
||||
|
||||
# Fallback Mode
|
||||
# When OpenBao is unavailable, API automatically falls back to AES-256-GCM
|
||||
# encryption using ENCRYPTION_KEY. This provides graceful degradation.
|
||||
|
||||
# ======================
|
||||
# Ollama (Optional AI Service)
|
||||
# ======================
|
||||
@@ -120,15 +197,38 @@ SEMANTIC_SEARCH_SIMILARITY_THRESHOLD=0.5
|
||||
# ======================
|
||||
NODE_ENV=development
|
||||
|
||||
# ======================
|
||||
# Docker Image Configuration
|
||||
# ======================
|
||||
# Docker image tag for pulling pre-built images from git.mosaicstack.dev registry
|
||||
# Used by docker-compose.yml (pulls images) and docker-swarm.yml
|
||||
# For local builds, use docker-compose.build.yml instead
|
||||
# Options:
|
||||
# - dev: Pull development images from registry (default, built from develop branch)
|
||||
# - latest: Pull latest stable images from registry (built from main branch)
|
||||
# - <commit-sha>: Use specific commit SHA tag (e.g., 658ec077)
|
||||
# - <version>: Use specific version tag (e.g., v1.0.0)
|
||||
IMAGE_TAG=dev
|
||||
|
||||
# ======================
|
||||
# Docker Compose Profiles
|
||||
# ======================
|
||||
# Uncomment to enable optional services:
|
||||
# COMPOSE_PROFILES=authentik,ollama # Enable both Authentik and Ollama
|
||||
# COMPOSE_PROFILES=full # Enable all optional services
|
||||
# COMPOSE_PROFILES=authentik # Enable only Authentik
|
||||
# COMPOSE_PROFILES=ollama # Enable only Ollama
|
||||
# COMPOSE_PROFILES=traefik-bundled # Enable bundled Traefik reverse proxy
|
||||
# Enable optional services via profiles. Combine multiple profiles with commas.
|
||||
#
|
||||
# Available profiles:
|
||||
# - database: PostgreSQL database (disable to use external database)
|
||||
# - cache: Valkey cache (disable to use external Redis)
|
||||
# - openbao: OpenBao secrets management (disable to use external vault or fallback encryption)
|
||||
# - authentik: Authentik OIDC authentication (disable to use external auth provider)
|
||||
# - ollama: Ollama AI/LLM service (disable to use external LLM service)
|
||||
# - traefik-bundled: Bundled Traefik reverse proxy (disable to use external proxy)
|
||||
# - full: Enable all optional services (turnkey deployment)
|
||||
#
|
||||
# Examples:
|
||||
# COMPOSE_PROFILES=full # Everything bundled (development)
|
||||
# COMPOSE_PROFILES=database,cache,openbao # Core services only
|
||||
# COMPOSE_PROFILES= # All external services (production)
|
||||
COMPOSE_PROFILES=full
|
||||
|
||||
# ======================
|
||||
# Traefik Reverse Proxy
|
||||
@@ -224,6 +324,143 @@ RATE_LIMIT_STORAGE=redis
|
||||
# multi-tenant isolation. Each Discord bot instance should be configured for
|
||||
# a single workspace.
|
||||
|
||||
# ======================
|
||||
# Matrix Bridge (Optional)
|
||||
# ======================
|
||||
# Matrix bot integration for chat-based control via Matrix protocol
|
||||
# Requires a Matrix account with an access token for the bot user
|
||||
# MATRIX_HOMESERVER_URL=https://matrix.example.com
|
||||
# MATRIX_ACCESS_TOKEN=
|
||||
# MATRIX_BOT_USER_ID=@mosaic-bot:example.com
|
||||
# MATRIX_CONTROL_ROOM_ID=!roomid:example.com
|
||||
# MATRIX_WORKSPACE_ID=your-workspace-uuid
|
||||
#
|
||||
# SECURITY: MATRIX_WORKSPACE_ID must be a valid workspace UUID from your database.
|
||||
# All Matrix commands will execute within this workspace context for proper
|
||||
# multi-tenant isolation. Each Matrix bot instance should be configured for
|
||||
# a single workspace.
|
||||
|
||||
# ======================
|
||||
# Orchestrator Configuration
|
||||
# ======================
|
||||
# API Key for orchestrator agent management endpoints
|
||||
# CRITICAL: Generate a random API key with at least 32 characters
|
||||
# Example: openssl rand -base64 32
|
||||
# Required for all /agents/* endpoints (spawn, kill, kill-all, status)
|
||||
# Health endpoints (/health/*) remain unauthenticated
|
||||
ORCHESTRATOR_API_KEY=REPLACE_WITH_RANDOM_API_KEY_MINIMUM_32_CHARS
|
||||
|
||||
# ======================
|
||||
# AI Provider Configuration
|
||||
# ======================
|
||||
# Choose the AI provider for orchestrator agents
|
||||
# Options: ollama, claude, openai
|
||||
# Default: ollama (no API key required)
|
||||
AI_PROVIDER=ollama
|
||||
|
||||
# Ollama Configuration (when AI_PROVIDER=ollama)
|
||||
# For local Ollama: http://localhost:11434
|
||||
# For remote Ollama: http://your-ollama-server:11434
|
||||
OLLAMA_MODEL=llama3.1:latest
|
||||
|
||||
# Claude API Configuration (when AI_PROVIDER=claude)
|
||||
# OPTIONAL: Only required if AI_PROVIDER=claude
|
||||
# Get your API key from: https://console.anthropic.com/
|
||||
# Note: Claude Max subscription users should use AI_PROVIDER=ollama instead
|
||||
# CLAUDE_API_KEY=sk-ant-...
|
||||
|
||||
# OpenAI API Configuration (when AI_PROVIDER=openai)
|
||||
# OPTIONAL: Only required if AI_PROVIDER=openai
|
||||
# Get your API key from: https://platform.openai.com/api-keys
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# ======================
|
||||
# Speech Services (STT / TTS)
|
||||
# ======================
|
||||
# Speech-to-Text (STT) - Whisper via Speaches
|
||||
# Set STT_ENABLED=true to enable speech-to-text transcription
|
||||
# STT_BASE_URL is required when STT_ENABLED=true
|
||||
STT_ENABLED=true
|
||||
STT_BASE_URL=http://speaches:8000/v1
|
||||
STT_MODEL=Systran/faster-whisper-large-v3-turbo
|
||||
STT_LANGUAGE=en
|
||||
|
||||
# Text-to-Speech (TTS) - Default Engine (Kokoro)
|
||||
# Set TTS_ENABLED=true to enable text-to-speech synthesis
|
||||
# TTS_DEFAULT_URL is required when TTS_ENABLED=true
|
||||
TTS_ENABLED=true
|
||||
TTS_DEFAULT_URL=http://kokoro-tts:8880/v1
|
||||
TTS_DEFAULT_VOICE=af_heart
|
||||
TTS_DEFAULT_FORMAT=mp3
|
||||
|
||||
# Text-to-Speech (TTS) - Premium Engine (Chatterbox) - Optional
|
||||
# Higher quality voice cloning engine, disabled by default
|
||||
# TTS_PREMIUM_URL is required when TTS_PREMIUM_ENABLED=true
|
||||
TTS_PREMIUM_ENABLED=false
|
||||
TTS_PREMIUM_URL=http://chatterbox-tts:8881/v1
|
||||
|
||||
# Text-to-Speech (TTS) - Fallback Engine (Piper/OpenedAI) - Optional
|
||||
# Lightweight fallback engine, disabled by default
|
||||
# TTS_FALLBACK_URL is required when TTS_FALLBACK_ENABLED=true
|
||||
TTS_FALLBACK_ENABLED=false
|
||||
TTS_FALLBACK_URL=http://openedai-speech:8000/v1
|
||||
|
||||
# Speech Service Limits
|
||||
# Maximum upload file size in bytes (default: 25MB)
|
||||
SPEECH_MAX_UPLOAD_SIZE=25000000
|
||||
# Maximum audio duration in seconds (default: 600 = 10 minutes)
|
||||
SPEECH_MAX_DURATION_SECONDS=600
|
||||
# Maximum text length for TTS in characters (default: 4096)
|
||||
SPEECH_MAX_TEXT_LENGTH=4096
|
||||
|
||||
# ======================
|
||||
# Mosaic Telemetry (Task Completion Tracking & Predictions)
|
||||
# ======================
|
||||
# Telemetry tracks task completion patterns to provide time estimates and predictions.
|
||||
# Data is sent to the Mosaic Telemetry API (a separate service).
|
||||
|
||||
# Master switch: set to false to completely disable telemetry (no HTTP calls will be made)
|
||||
MOSAIC_TELEMETRY_ENABLED=true
|
||||
|
||||
# URL of the telemetry API server
|
||||
# For Docker Compose (internal): http://telemetry-api:8000
|
||||
# For production/swarm: https://tel-api.mosaicstack.dev
|
||||
MOSAIC_TELEMETRY_SERVER_URL=http://telemetry-api:8000
|
||||
|
||||
# API key for authenticating with the telemetry server
|
||||
# Generate with: openssl rand -hex 32
|
||||
MOSAIC_TELEMETRY_API_KEY=your-64-char-hex-api-key-here
|
||||
|
||||
# Unique identifier for this Mosaic Stack instance
|
||||
# Generate with: uuidgen or python -c "import uuid; print(uuid.uuid4())"
|
||||
MOSAIC_TELEMETRY_INSTANCE_ID=your-instance-uuid-here
|
||||
|
||||
# Dry run mode: set to true to log telemetry events to console instead of sending HTTP requests
|
||||
# Useful for development and debugging telemetry payloads
|
||||
MOSAIC_TELEMETRY_DRY_RUN=false
|
||||
|
||||
# ======================
|
||||
# Matrix Dev Environment (docker-compose.matrix.yml overlay)
|
||||
# ======================
|
||||
# These variables configure the local Matrix dev environment.
|
||||
# Only used when running: docker compose -f docker/docker-compose.yml -f docker/docker-compose.matrix.yml up
|
||||
#
|
||||
# Synapse homeserver
|
||||
# SYNAPSE_CLIENT_PORT=8008
|
||||
# SYNAPSE_FEDERATION_PORT=8448
|
||||
# SYNAPSE_POSTGRES_DB=synapse
|
||||
# SYNAPSE_POSTGRES_USER=synapse
|
||||
# SYNAPSE_POSTGRES_PASSWORD=synapse_dev_password
|
||||
#
|
||||
# Element Web client
|
||||
# ELEMENT_PORT=8501
|
||||
#
|
||||
# Matrix bridge connection (set after running docker/matrix/scripts/setup-bot.sh)
|
||||
# MATRIX_HOMESERVER_URL=http://localhost:8008
|
||||
# MATRIX_ACCESS_TOKEN=<obtained from setup-bot.sh>
|
||||
# MATRIX_BOT_USER_ID=@mosaic-bot:localhost
|
||||
# MATRIX_SERVER_NAME=localhost
|
||||
|
||||
# ======================
|
||||
# Logging & Debugging
|
||||
# ======================
|
||||
|
||||
161
.env.swarm.example
Normal file
161
.env.swarm.example
Normal file
@@ -0,0 +1,161 @@
|
||||
# ==============================================
|
||||
# Mosaic Stack - Docker Swarm Configuration
|
||||
# ==============================================
|
||||
# Copy this file to .env for Docker Swarm deployment
|
||||
|
||||
# ======================
|
||||
# Application Ports (Internal)
|
||||
# ======================
|
||||
API_PORT=3001
|
||||
API_HOST=0.0.0.0
|
||||
WEB_PORT=3000
|
||||
|
||||
# ======================
|
||||
# Domain Configuration (Traefik)
|
||||
# ======================
|
||||
# These domains must be configured in your DNS or /etc/hosts
|
||||
MOSAIC_API_DOMAIN=api.mosaicstack.dev
|
||||
MOSAIC_WEB_DOMAIN=mosaic.mosaicstack.dev
|
||||
MOSAIC_AUTH_DOMAIN=auth.mosaicstack.dev
|
||||
|
||||
# ======================
|
||||
# Web Configuration
|
||||
# ======================
|
||||
# Use the Traefik domain for the API URL
|
||||
NEXT_PUBLIC_APP_URL=http://mosaic.mosaicstack.dev
|
||||
NEXT_PUBLIC_API_URL=http://api.mosaicstack.dev
|
||||
|
||||
# ======================
|
||||
# PostgreSQL Database
|
||||
# ======================
|
||||
DATABASE_URL=postgresql://mosaic:REPLACE_WITH_SECURE_PASSWORD@postgres:5432/mosaic
|
||||
POSTGRES_USER=mosaic
|
||||
POSTGRES_PASSWORD=REPLACE_WITH_SECURE_PASSWORD
|
||||
POSTGRES_DB=mosaic
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# PostgreSQL Performance Tuning
|
||||
POSTGRES_SHARED_BUFFERS=256MB
|
||||
POSTGRES_EFFECTIVE_CACHE_SIZE=1GB
|
||||
POSTGRES_MAX_CONNECTIONS=100
|
||||
|
||||
# ======================
|
||||
# Valkey Cache
|
||||
# ======================
|
||||
VALKEY_URL=redis://valkey:6379
|
||||
VALKEY_HOST=valkey
|
||||
VALKEY_PORT=6379
|
||||
VALKEY_MAXMEMORY=256mb
|
||||
|
||||
# Knowledge Module Cache Configuration
|
||||
KNOWLEDGE_CACHE_ENABLED=true
|
||||
KNOWLEDGE_CACHE_TTL=300
|
||||
|
||||
# ======================
|
||||
# Authentication (Authentik OIDC)
|
||||
# ======================
|
||||
# NOTE: Authentik services are COMMENTED OUT in docker-compose.swarm.yml by default
|
||||
# Uncomment those services if you want to run Authentik internally
|
||||
# Otherwise, use external Authentik by configuring OIDC_* variables below
|
||||
|
||||
# External Authentik Configuration (default)
|
||||
OIDC_ENABLED=true
|
||||
OIDC_ISSUER=https://auth.example.com/application/o/mosaic-stack/
|
||||
OIDC_CLIENT_ID=your-client-id-here
|
||||
OIDC_CLIENT_SECRET=your-client-secret-here
|
||||
OIDC_REDIRECT_URI=https://api.mosaicstack.dev/auth/callback/authentik
|
||||
|
||||
# Internal Authentik Configuration (only needed if uncommenting Authentik services)
|
||||
# Authentik PostgreSQL Database
|
||||
AUTHENTIK_POSTGRES_USER=authentik
|
||||
AUTHENTIK_POSTGRES_PASSWORD=REPLACE_WITH_SECURE_PASSWORD
|
||||
AUTHENTIK_POSTGRES_DB=authentik
|
||||
|
||||
# Authentik Server Configuration
|
||||
AUTHENTIK_SECRET_KEY=REPLACE_WITH_RANDOM_SECRET_MINIMUM_50_CHARS
|
||||
AUTHENTIK_ERROR_REPORTING=false
|
||||
AUTHENTIK_BOOTSTRAP_PASSWORD=REPLACE_WITH_SECURE_PASSWORD
|
||||
AUTHENTIK_BOOTSTRAP_EMAIL=admin@mosaicstack.dev
|
||||
AUTHENTIK_COOKIE_DOMAIN=.mosaicstack.dev
|
||||
|
||||
# ======================
|
||||
# JWT Configuration
|
||||
# ======================
|
||||
JWT_SECRET=REPLACE_WITH_RANDOM_SECRET_MINIMUM_32_CHARS
|
||||
JWT_EXPIRATION=24h
|
||||
|
||||
# ======================
|
||||
# Encryption (Credential Security)
|
||||
# ======================
|
||||
# Generate with: openssl rand -hex 32
|
||||
ENCRYPTION_KEY=REPLACE_WITH_64_CHAR_HEX_STRING_GENERATE_WITH_OPENSSL_RAND_HEX_32
|
||||
|
||||
# ======================
|
||||
# OpenBao Secrets Management
|
||||
# ======================
|
||||
OPENBAO_ADDR=http://openbao:8200
|
||||
OPENBAO_PORT=8200
|
||||
# For development only - remove in production
|
||||
OPENBAO_DEV_ROOT_TOKEN_ID=root
|
||||
|
||||
# ======================
|
||||
# Ollama (Optional AI Service)
|
||||
# ======================
|
||||
OLLAMA_ENDPOINT=http://ollama:11434
|
||||
OLLAMA_PORT=11434
|
||||
OLLAMA_EMBEDDING_MODEL=mxbai-embed-large
|
||||
|
||||
# Semantic Search Configuration
|
||||
SEMANTIC_SEARCH_SIMILARITY_THRESHOLD=0.5
|
||||
|
||||
# ======================
|
||||
# OpenAI API (Optional)
|
||||
# ======================
|
||||
# OPENAI_API_KEY=sk-...
|
||||
|
||||
# ======================
|
||||
# Application Environment
|
||||
# ======================
|
||||
NODE_ENV=production
|
||||
|
||||
# ======================
|
||||
# Gitea Integration (Coordinator)
|
||||
# ======================
|
||||
GITEA_URL=https://git.mosaicstack.dev
|
||||
GITEA_BOT_USERNAME=mosaic
|
||||
GITEA_BOT_TOKEN=REPLACE_WITH_COORDINATOR_BOT_API_TOKEN
|
||||
GITEA_BOT_PASSWORD=REPLACE_WITH_COORDINATOR_BOT_PASSWORD
|
||||
GITEA_REPO_OWNER=mosaic
|
||||
GITEA_REPO_NAME=stack
|
||||
GITEA_WEBHOOK_SECRET=REPLACE_WITH_RANDOM_WEBHOOK_SECRET
|
||||
COORDINATOR_API_KEY=REPLACE_WITH_RANDOM_API_KEY_MINIMUM_32_CHARS
|
||||
|
||||
# ======================
|
||||
# Coordinator Service
|
||||
# ======================
|
||||
ANTHROPIC_API_KEY=REPLACE_WITH_ANTHROPIC_API_KEY
|
||||
COORDINATOR_POLL_INTERVAL=5.0
|
||||
COORDINATOR_MAX_CONCURRENT_AGENTS=10
|
||||
COORDINATOR_ENABLED=true
|
||||
|
||||
# ======================
|
||||
# Rate Limiting
|
||||
# ======================
|
||||
RATE_LIMIT_TTL=60
|
||||
RATE_LIMIT_GLOBAL_LIMIT=100
|
||||
RATE_LIMIT_WEBHOOK_LIMIT=60
|
||||
RATE_LIMIT_COORDINATOR_LIMIT=100
|
||||
RATE_LIMIT_HEALTH_LIMIT=300
|
||||
RATE_LIMIT_STORAGE=redis
|
||||
|
||||
# ======================
|
||||
# Orchestrator Configuration
|
||||
# ======================
|
||||
ORCHESTRATOR_API_KEY=REPLACE_WITH_RANDOM_API_KEY_MINIMUM_32_CHARS
|
||||
CLAUDE_API_KEY=REPLACE_WITH_CLAUDE_API_KEY
|
||||
|
||||
# ======================
|
||||
# Logging & Debugging
|
||||
# ======================
|
||||
LOG_LEVEL=info
|
||||
DEBUG=false
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -30,10 +30,12 @@ Thumbs.db
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.test
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
.env.bak.*
|
||||
*.bak
|
||||
|
||||
# Credentials (never commit)
|
||||
.admin-credentials
|
||||
@@ -54,3 +56,6 @@ yarn-error.log*
|
||||
|
||||
# Husky
|
||||
.husky/_
|
||||
|
||||
# Orchestrator reports (generated by QA automation, cleaned up after processing)
|
||||
docs/reports/qa-automation/
|
||||
|
||||
1
.npmrc
Normal file
1
.npmrc
Normal file
@@ -0,0 +1 @@
|
||||
@mosaicstack:registry=https://git.mosaicstack.dev/api/packages/mosaic/npm/
|
||||
33
.trivyignore
Normal file
33
.trivyignore
Normal file
@@ -0,0 +1,33 @@
|
||||
# Trivy CVE Suppressions — Upstream Dependencies
|
||||
# Reviewed: 2026-02-13 | Milestone: M11-CIPipeline
|
||||
#
|
||||
# MITIGATED:
|
||||
# - Go stdlib CVEs (6): gosu rebuilt from source with Go 1.26
|
||||
# - npm bundled CVEs (5): npm removed from production Node.js images
|
||||
# - Node.js 20 → 24 LTS migration (#367): base images updated
|
||||
#
|
||||
# REMAINING: OpenBao (5 CVEs) + Next.js bundled tar (3 CVEs)
|
||||
# Re-evaluate when upgrading openbao image beyond 2.5.0 or Next.js beyond 16.1.6.
|
||||
|
||||
# === OpenBao false positives ===
|
||||
# Trivy reads Go module pseudo-version (v0.0.0-20260204...) from bin/bao
|
||||
# and reports CVEs fixed in openbao 2.0.3–2.4.4. We run openbao:2.5.0.
|
||||
CVE-2024-8185 # HIGH: DoS via Raft join (fixed in 2.0.3)
|
||||
CVE-2024-9180 # HIGH: privilege escalation (fixed in 2.0.3)
|
||||
CVE-2025-59043 # HIGH: DoS via malicious JSON (fixed in 2.4.1)
|
||||
CVE-2025-64761 # HIGH: identity group root escalation (fixed in 2.4.4)
|
||||
|
||||
# === Next.js bundled tar CVEs (upstream — waiting on Next.js release) ===
|
||||
# Next.js 16.1.6 bundles tar@7.5.2 in next/dist/compiled/tar/ (pre-compiled).
|
||||
# This is NOT a pnpm dependency — it's embedded in the Next.js package itself.
|
||||
# Affects web image only (orchestrator and API are clean).
|
||||
# npm was also removed from all production images, eliminating the npm-bundled copy.
|
||||
# To resolve: upgrade Next.js when a release bundles tar >= 7.5.7.
|
||||
CVE-2026-23745 # HIGH: tar arbitrary file overwrite via unsanitized linkpaths (fixed in 7.5.3)
|
||||
CVE-2026-23950 # HIGH: tar arbitrary file overwrite via Unicode path collision (fixed in 7.5.4)
|
||||
CVE-2026-24842 # HIGH: tar arbitrary file creation via hardlink path traversal (needs tar >= 7.5.7)
|
||||
|
||||
# === OpenBao Go stdlib (waiting on upstream rebuild) ===
|
||||
# OpenBao 2.5.0 compiled with Go 1.25.6, fix needs Go >= 1.25.7.
|
||||
# Cannot build OpenBao from source (large project). Waiting for upstream release.
|
||||
CVE-2025-68121 # CRITICAL: crypto/tls session resumption
|
||||
185
.woodpecker.yml
185
.woodpecker.yml
@@ -1,185 +0,0 @@
|
||||
# Woodpecker CI Quality Enforcement Pipeline - Monorepo
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
|
||||
variables:
|
||||
- &node_image "node:20-alpine"
|
||||
- &install_deps |
|
||||
corepack enable
|
||||
pnpm install --frozen-lockfile
|
||||
- &use_deps |
|
||||
corepack enable
|
||||
# Kaniko base command setup
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"reg.mosaicstack.dev\":{\"username\":\"$HARBOR_USER\",\"password\":\"$HARBOR_PASS\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
steps:
|
||||
install:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
|
||||
security-audit:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm audit --audit-level=high
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
lint:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm lint || true # Non-blocking while fixing legacy code
|
||||
depends_on:
|
||||
- install
|
||||
when:
|
||||
- evaluate: 'CI_PIPELINE_EVENT != "pull_request" || CI_COMMIT_BRANCH != "main"'
|
||||
|
||||
prisma-generate:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" prisma:generate
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
typecheck:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm typecheck
|
||||
depends_on:
|
||||
- prisma-generate
|
||||
|
||||
test:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm test || true # Non-blocking while fixing legacy tests
|
||||
depends_on:
|
||||
- prisma-generate
|
||||
|
||||
build:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
NODE_ENV: "production"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm build
|
||||
depends_on:
|
||||
- typecheck # Only block on critical checks
|
||||
- security-audit
|
||||
- prisma-generate
|
||||
|
||||
# ======================
|
||||
# Docker Build & Push (main/develop only)
|
||||
# ======================
|
||||
# Requires secrets: harbor_username, harbor_password
|
||||
#
|
||||
# Tagging Strategy:
|
||||
# - Always: commit SHA (e.g., 658ec077)
|
||||
# - main branch: 'latest'
|
||||
# - develop branch: 'dev'
|
||||
# - git tags: version tag (e.g., v1.0.0)
|
||||
|
||||
# Build and push API image using Kaniko
|
||||
docker-build-api:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
HARBOR_USER:
|
||||
from_secret: harbor_username
|
||||
HARBOR_PASS:
|
||||
from_secret: harbor_password
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
CI_COMMIT_SHA: ${CI_COMMIT_SHA}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS="--destination reg.mosaicstack.dev/mosaic/api:${CI_COMMIT_SHA:0:8}"
|
||||
if [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/api:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/api:dev"
|
||||
fi
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/api:$CI_COMMIT_TAG"
|
||||
fi
|
||||
/kaniko/executor --context . --dockerfile apps/api/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
|
||||
# Build and push Web image using Kaniko
|
||||
docker-build-web:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
HARBOR_USER:
|
||||
from_secret: harbor_username
|
||||
HARBOR_PASS:
|
||||
from_secret: harbor_password
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
CI_COMMIT_SHA: ${CI_COMMIT_SHA}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS="--destination reg.mosaicstack.dev/mosaic/web:${CI_COMMIT_SHA:0:8}"
|
||||
if [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/web:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/web:dev"
|
||||
fi
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/web:$CI_COMMIT_TAG"
|
||||
fi
|
||||
/kaniko/executor --context . --dockerfile apps/web/Dockerfile --build-arg NEXT_PUBLIC_API_URL=https://api.mosaicstack.dev $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
|
||||
# Build and push Postgres image using Kaniko
|
||||
docker-build-postgres:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
HARBOR_USER:
|
||||
from_secret: harbor_username
|
||||
HARBOR_PASS:
|
||||
from_secret: harbor_password
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
CI_COMMIT_SHA: ${CI_COMMIT_SHA}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS="--destination reg.mosaicstack.dev/mosaic/postgres:${CI_COMMIT_SHA:0:8}"
|
||||
if [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/postgres:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/postgres:dev"
|
||||
fi
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="$DESTINATIONS --destination reg.mosaicstack.dev/mosaic/postgres:$CI_COMMIT_TAG"
|
||||
fi
|
||||
/kaniko/executor --context docker/postgres --dockerfile docker/postgres/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
142
.woodpecker/README.md
Normal file
142
.woodpecker/README.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# Woodpecker CI Configuration for Mosaic Stack
|
||||
|
||||
## Pipeline Architecture
|
||||
|
||||
Split per-package pipelines with path filtering. Only affected packages rebuild on push.
|
||||
|
||||
```
|
||||
.woodpecker/
|
||||
├── api.yml # @mosaic/api (NestJS)
|
||||
├── web.yml # @mosaic/web (Next.js)
|
||||
├── orchestrator.yml # @mosaic/orchestrator (NestJS)
|
||||
├── coordinator.yml # mosaic-coordinator (Python/FastAPI)
|
||||
├── infra.yml # postgres + openbao Docker images
|
||||
├── codex-review.yml # AI code/security review (PRs only)
|
||||
├── README.md
|
||||
└── schemas/
|
||||
├── code-review-schema.json
|
||||
└── security-review-schema.json
|
||||
```
|
||||
|
||||
## Path Filtering
|
||||
|
||||
| Pipeline | Triggers On |
|
||||
| ------------------ | --------------------------------------------------- |
|
||||
| `api.yml` | `apps/api/**`, `packages/**`, root configs |
|
||||
| `web.yml` | `apps/web/**`, `packages/**`, root configs |
|
||||
| `orchestrator.yml` | `apps/orchestrator/**`, `packages/**`, root configs |
|
||||
| `coordinator.yml` | `apps/coordinator/**` |
|
||||
| `infra.yml` | `docker/**` |
|
||||
| `codex-review.yml` | All PRs (no path filter) |
|
||||
|
||||
**Root configs** = `pnpm-lock.yaml`, `pnpm-workspace.yaml`, `turbo.json`, `package.json`
|
||||
|
||||
## Security Chain
|
||||
|
||||
Every pipeline follows the full security chain required by the CI/CD guide:
|
||||
|
||||
```
|
||||
source scanning (lint + pnpm audit / bandit + pip-audit)
|
||||
-> docker build (Kaniko)
|
||||
-> container scanning (Trivy: HIGH,CRITICAL)
|
||||
-> package linking (Gitea registry)
|
||||
```
|
||||
|
||||
Docker builds gate on ALL quality + security steps passing.
|
||||
|
||||
## Pipeline Dependency Graphs
|
||||
|
||||
### Node.js Apps (api, web, orchestrator)
|
||||
|
||||
```
|
||||
install -> [security-audit, lint, prisma-generate*]
|
||||
prisma-generate* -> [typecheck, prisma-migrate*]
|
||||
prisma-migrate* -> test
|
||||
[all quality gates] -> build -> docker-build -> trivy -> link
|
||||
```
|
||||
|
||||
_\*prisma steps: api.yml only_
|
||||
|
||||
### Coordinator (Python)
|
||||
|
||||
```
|
||||
install -> [ruff-check, mypy, security-bandit, security-pip-audit, test]
|
||||
[all quality gates] -> docker-build -> trivy -> link
|
||||
```
|
||||
|
||||
### Infrastructure
|
||||
|
||||
```
|
||||
[docker-build-postgres, docker-build-openbao]
|
||||
-> [trivy-postgres, trivy-openbao]
|
||||
-> link
|
||||
```
|
||||
|
||||
## Docker Images
|
||||
|
||||
| Image | Registry Path | Context |
|
||||
| ------------------ | ----------------------------------------------- | ------------------- |
|
||||
| stack-api | `git.mosaicstack.dev/mosaic/stack-api` | `.` (monorepo root) |
|
||||
| stack-web | `git.mosaicstack.dev/mosaic/stack-web` | `.` (monorepo root) |
|
||||
| stack-orchestrator | `git.mosaicstack.dev/mosaic/stack-orchestrator` | `.` (monorepo root) |
|
||||
| stack-coordinator | `git.mosaicstack.dev/mosaic/stack-coordinator` | `apps/coordinator` |
|
||||
| stack-postgres | `git.mosaicstack.dev/mosaic/stack-postgres` | `docker/postgres` |
|
||||
| stack-openbao | `git.mosaicstack.dev/mosaic/stack-openbao` | `docker/openbao` |
|
||||
|
||||
## Image Tagging
|
||||
|
||||
| Condition | Tag | Purpose |
|
||||
| ---------------- | -------------------------- | -------------------------- |
|
||||
| Always | `${CI_COMMIT_SHA:0:8}` | Immutable commit reference |
|
||||
| `main` branch | `latest` | Current production release |
|
||||
| `develop` branch | `dev` | Current development build |
|
||||
| Git tag | tag value (e.g., `v1.0.0`) | Semantic version release |
|
||||
|
||||
## Required Secrets
|
||||
|
||||
Configure in Woodpecker UI (Settings > Secrets):
|
||||
|
||||
| Secret | Scope | Purpose |
|
||||
| ---------------- | ----------------- | ------------------------------------------- |
|
||||
| `gitea_username` | push, manual, tag | Gitea registry auth |
|
||||
| `gitea_token` | push, manual, tag | Gitea registry auth (`package:write` scope) |
|
||||
| `codex_api_key` | pull_request | Codex AI reviews |
|
||||
|
||||
## Codex AI Review Pipeline
|
||||
|
||||
The `codex-review.yml` pipeline runs independently on all PRs:
|
||||
|
||||
- **Code review**: Correctness, code quality, testing, performance
|
||||
- **Security review**: OWASP Top 10, hardcoded secrets, injection flaws
|
||||
|
||||
Fails on blockers or critical/high severity security findings.
|
||||
|
||||
### Local Testing
|
||||
|
||||
```bash
|
||||
~/.claude/scripts/codex/codex-code-review.sh --uncommitted
|
||||
~/.claude/scripts/codex/codex-security-review.sh --uncommitted
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "unauthorized: authentication required"
|
||||
|
||||
- Verify `gitea_username` and `gitea_token` secrets in Woodpecker
|
||||
- Verify token has `package:write` scope
|
||||
|
||||
### Trivy scan fails with HIGH/CRITICAL
|
||||
|
||||
- Check if the vulnerability is in the base image (not our code)
|
||||
- Add to `.trivyignore` if it's a known, accepted risk
|
||||
- Use `--ignore-unfixed` (already set) to skip unfixable CVEs
|
||||
|
||||
### Package linking returns 404
|
||||
|
||||
- Normal for recently pushed packages — retry logic handles this
|
||||
- If persistent: verify package name matches exactly (case-sensitive)
|
||||
|
||||
### Pipeline runs Docker builds on pull requests
|
||||
|
||||
- Docker build steps have `when: branch: [main, develop]` guards
|
||||
- PRs only run quality gates, not Docker builds
|
||||
235
.woodpecker/api.yml
Normal file
235
.woodpecker/api.yml
Normal file
@@ -0,0 +1,235 @@
|
||||
# API Pipeline - Mosaic Stack
|
||||
# Quality gates, build, and Docker publish for @mosaic/api
|
||||
#
|
||||
# Triggers on: apps/api/**, packages/**, root configs
|
||||
# Security chain: source audit + Trivy container scan
|
||||
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
path:
|
||||
include:
|
||||
- "apps/api/**"
|
||||
- "packages/**"
|
||||
- "pnpm-lock.yaml"
|
||||
- "pnpm-workspace.yaml"
|
||||
- "turbo.json"
|
||||
- "package.json"
|
||||
- ".woodpecker/api.yml"
|
||||
|
||||
variables:
|
||||
- &node_image "node:24-alpine"
|
||||
- &install_deps |
|
||||
corepack enable
|
||||
pnpm install --frozen-lockfile
|
||||
- &use_deps |
|
||||
corepack enable
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$GITEA_USER\",\"password\":\"$GITEA_TOKEN\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17.7-alpine3.22
|
||||
environment:
|
||||
POSTGRES_DB: test_db
|
||||
POSTGRES_USER: test_user
|
||||
POSTGRES_PASSWORD: test_password
|
||||
|
||||
steps:
|
||||
# === Quality Gates ===
|
||||
|
||||
install:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
|
||||
security-audit:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm audit --audit-level=high
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
lint:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" lint
|
||||
depends_on:
|
||||
- prisma-generate
|
||||
- build-shared
|
||||
|
||||
prisma-generate:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" prisma:generate
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
build-shared:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/shared" build
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
typecheck:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" typecheck
|
||||
depends_on:
|
||||
- prisma-generate
|
||||
- build-shared
|
||||
|
||||
prisma-migrate:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
DATABASE_URL: "postgresql://test_user:test_password@postgres:5432/test_db?schema=public"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" prisma migrate deploy
|
||||
depends_on:
|
||||
- prisma-generate
|
||||
|
||||
test:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
DATABASE_URL: "postgresql://test_user:test_password@postgres:5432/test_db?schema=public"
|
||||
ENCRYPTION_KEY: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/api" exec vitest run --exclude 'src/auth/auth-rls.integration.spec.ts' --exclude 'src/credentials/user-credential.model.spec.ts' --exclude 'src/job-events/job-events.performance.spec.ts' --exclude 'src/knowledge/services/fulltext-search.spec.ts'
|
||||
depends_on:
|
||||
- prisma-migrate
|
||||
|
||||
# === Build ===
|
||||
|
||||
build:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
NODE_ENV: "production"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm turbo build --filter=@mosaic/api
|
||||
depends_on:
|
||||
- lint
|
||||
- typecheck
|
||||
- test
|
||||
- security-audit
|
||||
|
||||
# === Docker Build & Push ===
|
||||
|
||||
docker-build-api:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-api:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-api:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-api:dev"
|
||||
fi
|
||||
/kaniko/executor --context . --dockerfile apps/api/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
|
||||
# === Container Security Scan ===
|
||||
|
||||
security-trivy-api:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-api:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-api
|
||||
|
||||
# === Package Linking ===
|
||||
|
||||
link-packages:
|
||||
image: alpine:3
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- apk add --no-cache curl
|
||||
- sleep 10
|
||||
- |
|
||||
set -e
|
||||
link_package() {
|
||||
PKG="$$1"
|
||||
echo "Linking $$PKG..."
|
||||
for attempt in 1 2 3; do
|
||||
STATUS=$$(curl -s -o /tmp/link-response.txt -w "%{http_code}" -X POST \
|
||||
-H "Authorization: token $$GITEA_TOKEN" \
|
||||
"https://git.mosaicstack.dev/api/v1/packages/mosaic/container/$$PKG/-/link/stack")
|
||||
if [ "$$STATUS" = "201" ] || [ "$$STATUS" = "204" ]; then
|
||||
echo " Linked $$PKG"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "400" ]; then
|
||||
echo " $$PKG already linked"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "404" ] && [ $$attempt -lt 3 ]; then
|
||||
echo " $$PKG not found yet, retrying in 5s (attempt $$attempt/3)..."
|
||||
sleep 5
|
||||
else
|
||||
echo " FAILED: $$PKG status $$STATUS"
|
||||
cat /tmp/link-response.txt
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
link_package "stack-api"
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- security-trivy-api
|
||||
90
.woodpecker/codex-review.yml
Normal file
90
.woodpecker/codex-review.yml
Normal file
@@ -0,0 +1,90 @@
|
||||
# Codex AI Review Pipeline for Woodpecker CI
|
||||
# Drop this into your repo's .woodpecker/ directory to enable automated
|
||||
# code and security reviews on every pull request.
|
||||
#
|
||||
# Required secrets:
|
||||
# - codex_api_key: OpenAI API key or Codex-compatible key
|
||||
#
|
||||
# Optional secrets:
|
||||
# - gitea_token: Gitea API token for posting PR comments (if not using tea CLI auth)
|
||||
|
||||
when:
|
||||
event: pull_request
|
||||
|
||||
variables:
|
||||
- &node_image "node:24-slim"
|
||||
- &install_codex "npm i -g @openai/codex"
|
||||
|
||||
steps:
|
||||
# --- Code Quality Review ---
|
||||
code-review:
|
||||
image: *node_image
|
||||
environment:
|
||||
CODEX_API_KEY:
|
||||
from_secret: codex_api_key
|
||||
commands:
|
||||
- *install_codex
|
||||
- apt-get update -qq && apt-get install -y -qq jq git > /dev/null 2>&1
|
||||
|
||||
# Generate the diff
|
||||
- git fetch origin ${CI_COMMIT_TARGET_BRANCH:-main}
|
||||
- DIFF=$(git diff origin/${CI_COMMIT_TARGET_BRANCH:-main}...HEAD)
|
||||
|
||||
# Run code review with structured output
|
||||
- |
|
||||
codex exec \
|
||||
--sandbox read-only \
|
||||
--output-schema .woodpecker/schemas/code-review-schema.json \
|
||||
-o /tmp/code-review.json \
|
||||
"You are an expert code reviewer. Review the following code changes for correctness, code quality, testing, performance, and documentation issues. Only flag actionable, important issues. Categorize as blocker/should-fix/suggestion. If code looks good, say so.
|
||||
|
||||
Changes:
|
||||
$DIFF"
|
||||
|
||||
# Output summary
|
||||
- echo "=== Code Review Results ==="
|
||||
- jq '.' /tmp/code-review.json
|
||||
- |
|
||||
BLOCKERS=$(jq '.stats.blockers // 0' /tmp/code-review.json)
|
||||
if [ "$BLOCKERS" -gt 0 ]; then
|
||||
echo "FAIL: $BLOCKERS blocker(s) found"
|
||||
exit 1
|
||||
fi
|
||||
echo "PASS: No blockers found"
|
||||
|
||||
# --- Security Review ---
|
||||
security-review:
|
||||
image: *node_image
|
||||
environment:
|
||||
CODEX_API_KEY:
|
||||
from_secret: codex_api_key
|
||||
commands:
|
||||
- *install_codex
|
||||
- apt-get update -qq && apt-get install -y -qq jq git > /dev/null 2>&1
|
||||
|
||||
# Generate the diff
|
||||
- git fetch origin ${CI_COMMIT_TARGET_BRANCH:-main}
|
||||
- DIFF=$(git diff origin/${CI_COMMIT_TARGET_BRANCH:-main}...HEAD)
|
||||
|
||||
# Run security review with structured output
|
||||
- |
|
||||
codex exec \
|
||||
--sandbox read-only \
|
||||
--output-schema .woodpecker/schemas/security-review-schema.json \
|
||||
-o /tmp/security-review.json \
|
||||
"You are an expert application security engineer. Review the following code changes for security vulnerabilities including OWASP Top 10, hardcoded secrets, injection flaws, auth/authz gaps, XSS, CSRF, SSRF, path traversal, and supply chain risks. Include CWE IDs and remediation steps. Only flag real security issues, not code quality.
|
||||
|
||||
Changes:
|
||||
$DIFF"
|
||||
|
||||
# Output summary
|
||||
- echo "=== Security Review Results ==="
|
||||
- jq '.' /tmp/security-review.json
|
||||
- |
|
||||
CRITICAL=$(jq '.stats.critical // 0' /tmp/security-review.json)
|
||||
HIGH=$(jq '.stats.high // 0' /tmp/security-review.json)
|
||||
if [ "$CRITICAL" -gt 0 ] || [ "$HIGH" -gt 0 ]; then
|
||||
echo "FAIL: $CRITICAL critical, $HIGH high severity finding(s)"
|
||||
exit 1
|
||||
fi
|
||||
echo "PASS: No critical or high severity findings"
|
||||
180
.woodpecker/coordinator.yml
Normal file
180
.woodpecker/coordinator.yml
Normal file
@@ -0,0 +1,180 @@
|
||||
# Coordinator Pipeline - Mosaic Stack
|
||||
# Quality gates, build, and Docker publish for mosaic-coordinator (Python)
|
||||
#
|
||||
# Triggers on: apps/coordinator/**
|
||||
# Security chain: bandit + pip-audit + Trivy container scan
|
||||
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
path:
|
||||
include:
|
||||
- "apps/coordinator/**"
|
||||
- ".woodpecker/coordinator.yml"
|
||||
|
||||
variables:
|
||||
- &python_image "python:3.11-slim"
|
||||
- &activate_venv |
|
||||
cd apps/coordinator
|
||||
. venv/bin/activate
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$GITEA_USER\",\"password\":\"$GITEA_TOKEN\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
steps:
|
||||
# === Quality Gates ===
|
||||
|
||||
install:
|
||||
image: *python_image
|
||||
commands:
|
||||
- cd apps/coordinator
|
||||
- python -m venv venv
|
||||
- . venv/bin/activate
|
||||
- pip install --no-cache-dir --upgrade "pip>=25.3"
|
||||
- pip install --no-cache-dir --extra-index-url https://git.mosaicstack.dev/api/packages/mosaic/pypi/simple/ -e ".[dev]"
|
||||
- pip install --no-cache-dir bandit pip-audit
|
||||
|
||||
ruff-check:
|
||||
image: *python_image
|
||||
commands:
|
||||
- *activate_venv
|
||||
- ruff check src/ tests/
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
mypy:
|
||||
image: *python_image
|
||||
commands:
|
||||
- *activate_venv
|
||||
- mypy src/
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
security-bandit:
|
||||
image: *python_image
|
||||
commands:
|
||||
- *activate_venv
|
||||
- bandit -r src/ -c bandit.yaml -f screen
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
security-pip-audit:
|
||||
image: *python_image
|
||||
commands:
|
||||
- *activate_venv
|
||||
- pip-audit
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
test:
|
||||
image: *python_image
|
||||
commands:
|
||||
- *activate_venv
|
||||
- pytest
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
# === Docker Build & Push ===
|
||||
|
||||
docker-build-coordinator:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-coordinator:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-coordinator:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-coordinator:dev"
|
||||
fi
|
||||
/kaniko/executor --context apps/coordinator --dockerfile apps/coordinator/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- ruff-check
|
||||
- mypy
|
||||
- security-bandit
|
||||
- security-pip-audit
|
||||
- test
|
||||
|
||||
# === Container Security Scan ===
|
||||
|
||||
security-trivy-coordinator:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-coordinator:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-coordinator
|
||||
|
||||
# === Package Linking ===
|
||||
|
||||
link-packages:
|
||||
image: alpine:3
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- apk add --no-cache curl
|
||||
- sleep 10
|
||||
- |
|
||||
set -e
|
||||
link_package() {
|
||||
PKG="$$1"
|
||||
echo "Linking $$PKG..."
|
||||
for attempt in 1 2 3; do
|
||||
STATUS=$$(curl -s -o /tmp/link-response.txt -w "%{http_code}" -X POST \
|
||||
-H "Authorization: token $$GITEA_TOKEN" \
|
||||
"https://git.mosaicstack.dev/api/v1/packages/mosaic/container/$$PKG/-/link/stack")
|
||||
if [ "$$STATUS" = "201" ] || [ "$$STATUS" = "204" ]; then
|
||||
echo " Linked $$PKG"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "400" ]; then
|
||||
echo " $$PKG already linked"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "404" ] && [ $$attempt -lt 3 ]; then
|
||||
echo " $$PKG not found yet, retrying in 5s (attempt $$attempt/3)..."
|
||||
sleep 5
|
||||
else
|
||||
echo " FAILED: $$PKG status $$STATUS"
|
||||
cat /tmp/link-response.txt
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
link_package "stack-coordinator"
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- security-trivy-coordinator
|
||||
174
.woodpecker/infra.yml
Normal file
174
.woodpecker/infra.yml
Normal file
@@ -0,0 +1,174 @@
|
||||
# Infrastructure Pipeline - Mosaic Stack
|
||||
# Docker build, Trivy scan, and publish for postgres + openbao images
|
||||
#
|
||||
# Triggers on: docker/**
|
||||
# No quality gates — infrastructure images (base image + config only)
|
||||
|
||||
when:
|
||||
- event: [push, manual, tag]
|
||||
path:
|
||||
include:
|
||||
- "docker/**"
|
||||
- ".woodpecker/infra.yml"
|
||||
|
||||
variables:
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$GITEA_USER\",\"password\":\"$GITEA_TOKEN\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
steps:
|
||||
# === Docker Build & Push ===
|
||||
|
||||
docker-build-postgres:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-postgres:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-postgres:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-postgres:dev"
|
||||
fi
|
||||
/kaniko/executor --context docker/postgres --dockerfile docker/postgres/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
|
||||
docker-build-openbao:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-openbao:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-openbao:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-openbao:dev"
|
||||
fi
|
||||
/kaniko/executor --context docker/openbao --dockerfile docker/openbao/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
|
||||
# === Container Security Scans ===
|
||||
|
||||
security-trivy-postgres:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-postgres:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-postgres
|
||||
|
||||
security-trivy-openbao:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-openbao:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-openbao
|
||||
|
||||
# === Package Linking ===
|
||||
|
||||
link-packages:
|
||||
image: alpine:3
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- apk add --no-cache curl
|
||||
- sleep 10
|
||||
- |
|
||||
set -e
|
||||
link_package() {
|
||||
PKG="$$1"
|
||||
echo "Linking $$PKG..."
|
||||
for attempt in 1 2 3; do
|
||||
STATUS=$$(curl -s -o /tmp/link-response.txt -w "%{http_code}" -X POST \
|
||||
-H "Authorization: token $$GITEA_TOKEN" \
|
||||
"https://git.mosaicstack.dev/api/v1/packages/mosaic/container/$$PKG/-/link/stack")
|
||||
if [ "$$STATUS" = "201" ] || [ "$$STATUS" = "204" ]; then
|
||||
echo " Linked $$PKG"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "400" ]; then
|
||||
echo " $$PKG already linked"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "404" ] && [ $$attempt -lt 3 ]; then
|
||||
echo " $$PKG not found yet, retrying in 5s (attempt $$attempt/3)..."
|
||||
sleep 5
|
||||
else
|
||||
echo " FAILED: $$PKG status $$STATUS"
|
||||
cat /tmp/link-response.txt
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
link_package "stack-postgres"
|
||||
link_package "stack-openbao"
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- security-trivy-postgres
|
||||
- security-trivy-openbao
|
||||
192
.woodpecker/orchestrator.yml
Normal file
192
.woodpecker/orchestrator.yml
Normal file
@@ -0,0 +1,192 @@
|
||||
# Orchestrator Pipeline - Mosaic Stack
|
||||
# Quality gates, build, and Docker publish for @mosaic/orchestrator
|
||||
#
|
||||
# Triggers on: apps/orchestrator/**, packages/**, root configs
|
||||
# Security chain: source audit + Trivy container scan
|
||||
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
path:
|
||||
include:
|
||||
- "apps/orchestrator/**"
|
||||
- "packages/**"
|
||||
- "pnpm-lock.yaml"
|
||||
- "pnpm-workspace.yaml"
|
||||
- "turbo.json"
|
||||
- "package.json"
|
||||
- ".woodpecker/orchestrator.yml"
|
||||
|
||||
variables:
|
||||
- &node_image "node:24-alpine"
|
||||
- &install_deps |
|
||||
corepack enable
|
||||
pnpm install --frozen-lockfile
|
||||
- &use_deps |
|
||||
corepack enable
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$GITEA_USER\",\"password\":\"$GITEA_TOKEN\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
steps:
|
||||
# === Quality Gates ===
|
||||
|
||||
install:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
|
||||
security-audit:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm audit --audit-level=high
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
lint:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/orchestrator" lint
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
typecheck:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/orchestrator" typecheck
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
test:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/orchestrator" test
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
# === Build ===
|
||||
|
||||
build:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
NODE_ENV: "production"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm turbo build --filter=@mosaic/orchestrator
|
||||
depends_on:
|
||||
- lint
|
||||
- typecheck
|
||||
- test
|
||||
- security-audit
|
||||
|
||||
# === Docker Build & Push ===
|
||||
|
||||
docker-build-orchestrator:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-orchestrator:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-orchestrator:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-orchestrator:dev"
|
||||
fi
|
||||
/kaniko/executor --context . --dockerfile apps/orchestrator/Dockerfile $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
|
||||
# === Container Security Scan ===
|
||||
|
||||
security-trivy-orchestrator:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-orchestrator:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-orchestrator
|
||||
|
||||
# === Package Linking ===
|
||||
|
||||
link-packages:
|
||||
image: alpine:3
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- apk add --no-cache curl
|
||||
- sleep 10
|
||||
- |
|
||||
set -e
|
||||
link_package() {
|
||||
PKG="$$1"
|
||||
echo "Linking $$PKG..."
|
||||
for attempt in 1 2 3; do
|
||||
STATUS=$$(curl -s -o /tmp/link-response.txt -w "%{http_code}" -X POST \
|
||||
-H "Authorization: token $$GITEA_TOKEN" \
|
||||
"https://git.mosaicstack.dev/api/v1/packages/mosaic/container/$$PKG/-/link/stack")
|
||||
if [ "$$STATUS" = "201" ] || [ "$$STATUS" = "204" ]; then
|
||||
echo " Linked $$PKG"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "400" ]; then
|
||||
echo " $$PKG already linked"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "404" ] && [ $$attempt -lt 3 ]; then
|
||||
echo " $$PKG not found yet, retrying in 5s (attempt $$attempt/3)..."
|
||||
sleep 5
|
||||
else
|
||||
echo " FAILED: $$PKG status $$STATUS"
|
||||
cat /tmp/link-response.txt
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
link_package "stack-orchestrator"
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- security-trivy-orchestrator
|
||||
92
.woodpecker/schemas/code-review-schema.json
Normal file
92
.woodpecker/schemas/code-review-schema.json
Normal file
@@ -0,0 +1,92 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Brief overall assessment of the code changes"
|
||||
},
|
||||
"verdict": {
|
||||
"type": "string",
|
||||
"enum": ["approve", "request-changes", "comment"],
|
||||
"description": "Overall review verdict"
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"description": "Confidence score for the review (0-1)"
|
||||
},
|
||||
"findings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"severity": {
|
||||
"type": "string",
|
||||
"enum": ["blocker", "should-fix", "suggestion"],
|
||||
"description": "Finding severity: blocker (must fix), should-fix (important), suggestion (optional)"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Short title describing the issue"
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "File path where the issue was found"
|
||||
},
|
||||
"line_start": {
|
||||
"type": "integer",
|
||||
"description": "Starting line number"
|
||||
},
|
||||
"line_end": {
|
||||
"type": "integer",
|
||||
"description": "Ending line number"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Detailed explanation of the issue"
|
||||
},
|
||||
"suggestion": {
|
||||
"type": "string",
|
||||
"description": "Suggested fix or improvement"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"severity",
|
||||
"title",
|
||||
"file",
|
||||
"line_start",
|
||||
"line_end",
|
||||
"description",
|
||||
"suggestion"
|
||||
]
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"files_reviewed": {
|
||||
"type": "integer",
|
||||
"description": "Number of files reviewed"
|
||||
},
|
||||
"blockers": {
|
||||
"type": "integer",
|
||||
"description": "Count of blocker findings"
|
||||
},
|
||||
"should_fix": {
|
||||
"type": "integer",
|
||||
"description": "Count of should-fix findings"
|
||||
},
|
||||
"suggestions": {
|
||||
"type": "integer",
|
||||
"description": "Count of suggestion findings"
|
||||
}
|
||||
},
|
||||
"required": ["files_reviewed", "blockers", "should_fix", "suggestions"]
|
||||
}
|
||||
},
|
||||
"required": ["summary", "verdict", "confidence", "findings", "stats"]
|
||||
}
|
||||
106
.woodpecker/schemas/security-review-schema.json
Normal file
106
.woodpecker/schemas/security-review-schema.json
Normal file
@@ -0,0 +1,106 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "Brief overall security assessment of the code changes"
|
||||
},
|
||||
"risk_level": {
|
||||
"type": "string",
|
||||
"enum": ["critical", "high", "medium", "low", "none"],
|
||||
"description": "Overall security risk level"
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1,
|
||||
"description": "Confidence score for the review (0-1)"
|
||||
},
|
||||
"findings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"severity": {
|
||||
"type": "string",
|
||||
"enum": ["critical", "high", "medium", "low"],
|
||||
"description": "Vulnerability severity level"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Short title describing the vulnerability"
|
||||
},
|
||||
"file": {
|
||||
"type": "string",
|
||||
"description": "File path where the vulnerability was found"
|
||||
},
|
||||
"line_start": {
|
||||
"type": "integer",
|
||||
"description": "Starting line number"
|
||||
},
|
||||
"line_end": {
|
||||
"type": "integer",
|
||||
"description": "Ending line number"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Detailed explanation of the vulnerability"
|
||||
},
|
||||
"cwe_id": {
|
||||
"type": "string",
|
||||
"description": "CWE identifier if applicable (e.g., CWE-79)"
|
||||
},
|
||||
"owasp_category": {
|
||||
"type": "string",
|
||||
"description": "OWASP Top 10 category if applicable (e.g., A03:2021-Injection)"
|
||||
},
|
||||
"remediation": {
|
||||
"type": "string",
|
||||
"description": "Specific remediation steps to fix the vulnerability"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"severity",
|
||||
"title",
|
||||
"file",
|
||||
"line_start",
|
||||
"line_end",
|
||||
"description",
|
||||
"cwe_id",
|
||||
"owasp_category",
|
||||
"remediation"
|
||||
]
|
||||
}
|
||||
},
|
||||
"stats": {
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
"files_reviewed": {
|
||||
"type": "integer",
|
||||
"description": "Number of files reviewed"
|
||||
},
|
||||
"critical": {
|
||||
"type": "integer",
|
||||
"description": "Count of critical findings"
|
||||
},
|
||||
"high": {
|
||||
"type": "integer",
|
||||
"description": "Count of high findings"
|
||||
},
|
||||
"medium": {
|
||||
"type": "integer",
|
||||
"description": "Count of medium findings"
|
||||
},
|
||||
"low": {
|
||||
"type": "integer",
|
||||
"description": "Count of low findings"
|
||||
}
|
||||
},
|
||||
"required": ["files_reviewed", "critical", "high", "medium", "low"]
|
||||
}
|
||||
},
|
||||
"required": ["summary", "risk_level", "confidence", "findings", "stats"]
|
||||
}
|
||||
203
.woodpecker/web.yml
Normal file
203
.woodpecker/web.yml
Normal file
@@ -0,0 +1,203 @@
|
||||
# Web Pipeline - Mosaic Stack
|
||||
# Quality gates, build, and Docker publish for @mosaic/web
|
||||
#
|
||||
# Triggers on: apps/web/**, packages/**, root configs
|
||||
# Security chain: source audit + Trivy container scan
|
||||
|
||||
when:
|
||||
- event: [push, pull_request, manual]
|
||||
path:
|
||||
include:
|
||||
- "apps/web/**"
|
||||
- "packages/**"
|
||||
- "pnpm-lock.yaml"
|
||||
- "pnpm-workspace.yaml"
|
||||
- "turbo.json"
|
||||
- "package.json"
|
||||
- ".woodpecker/web.yml"
|
||||
|
||||
variables:
|
||||
- &node_image "node:24-alpine"
|
||||
- &install_deps |
|
||||
corepack enable
|
||||
pnpm install --frozen-lockfile
|
||||
- &use_deps |
|
||||
corepack enable
|
||||
- &kaniko_setup |
|
||||
mkdir -p /kaniko/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$GITEA_USER\",\"password\":\"$GITEA_TOKEN\"}}}" > /kaniko/.docker/config.json
|
||||
|
||||
steps:
|
||||
# === Quality Gates ===
|
||||
|
||||
install:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *install_deps
|
||||
|
||||
security-audit:
|
||||
image: *node_image
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm audit --audit-level=high
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
build-shared:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/shared" build
|
||||
- pnpm --filter "@mosaic/ui" build
|
||||
depends_on:
|
||||
- install
|
||||
|
||||
lint:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/web" lint
|
||||
depends_on:
|
||||
- build-shared
|
||||
|
||||
typecheck:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/web" typecheck
|
||||
depends_on:
|
||||
- build-shared
|
||||
|
||||
test:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm --filter "@mosaic/web" test
|
||||
depends_on:
|
||||
- build-shared
|
||||
|
||||
# === Build ===
|
||||
|
||||
build:
|
||||
image: *node_image
|
||||
environment:
|
||||
SKIP_ENV_VALIDATION: "true"
|
||||
NODE_ENV: "production"
|
||||
commands:
|
||||
- *use_deps
|
||||
- pnpm turbo build --filter=@mosaic/web
|
||||
depends_on:
|
||||
- lint
|
||||
- typecheck
|
||||
- test
|
||||
- security-audit
|
||||
|
||||
# === Docker Build & Push ===
|
||||
|
||||
docker-build-web:
|
||||
image: gcr.io/kaniko-project/executor:debug
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- *kaniko_setup
|
||||
- |
|
||||
DESTINATIONS=""
|
||||
if [ -n "$CI_COMMIT_TAG" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-web:$CI_COMMIT_TAG"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-web:latest"
|
||||
elif [ "$CI_COMMIT_BRANCH" = "develop" ]; then
|
||||
DESTINATIONS="--destination git.mosaicstack.dev/mosaic/stack-web:dev"
|
||||
fi
|
||||
/kaniko/executor --context . --dockerfile apps/web/Dockerfile --build-arg NEXT_PUBLIC_API_URL=https://api.mosaicstack.dev $DESTINATIONS
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- build
|
||||
|
||||
# === Container Security Scan ===
|
||||
|
||||
security-trivy-web:
|
||||
image: aquasec/trivy:latest
|
||||
environment:
|
||||
GITEA_USER:
|
||||
from_secret: gitea_username
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
CI_COMMIT_BRANCH: ${CI_COMMIT_BRANCH}
|
||||
CI_COMMIT_TAG: ${CI_COMMIT_TAG}
|
||||
commands:
|
||||
- |
|
||||
if [ -n "$$CI_COMMIT_TAG" ]; then
|
||||
SCAN_TAG="$$CI_COMMIT_TAG"
|
||||
elif [ "$$CI_COMMIT_BRANCH" = "main" ]; then
|
||||
SCAN_TAG="latest"
|
||||
else
|
||||
SCAN_TAG="dev"
|
||||
fi
|
||||
mkdir -p ~/.docker
|
||||
echo "{\"auths\":{\"git.mosaicstack.dev\":{\"username\":\"$$GITEA_USER\",\"password\":\"$$GITEA_TOKEN\"}}}" > ~/.docker/config.json
|
||||
trivy image --exit-code 1 --severity HIGH,CRITICAL --ignore-unfixed \
|
||||
--ignorefile .trivyignore \
|
||||
git.mosaicstack.dev/mosaic/stack-web:$$SCAN_TAG
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- docker-build-web
|
||||
|
||||
# === Package Linking ===
|
||||
|
||||
link-packages:
|
||||
image: alpine:3
|
||||
environment:
|
||||
GITEA_TOKEN:
|
||||
from_secret: gitea_token
|
||||
commands:
|
||||
- apk add --no-cache curl
|
||||
- sleep 10
|
||||
- |
|
||||
set -e
|
||||
link_package() {
|
||||
PKG="$$1"
|
||||
echo "Linking $$PKG..."
|
||||
for attempt in 1 2 3; do
|
||||
STATUS=$$(curl -s -o /tmp/link-response.txt -w "%{http_code}" -X POST \
|
||||
-H "Authorization: token $$GITEA_TOKEN" \
|
||||
"https://git.mosaicstack.dev/api/v1/packages/mosaic/container/$$PKG/-/link/stack")
|
||||
if [ "$$STATUS" = "201" ] || [ "$$STATUS" = "204" ]; then
|
||||
echo " Linked $$PKG"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "400" ]; then
|
||||
echo " $$PKG already linked"
|
||||
return 0
|
||||
elif [ "$$STATUS" = "404" ] && [ $$attempt -lt 3 ]; then
|
||||
echo " $$PKG not found yet, retrying in 5s (attempt $$attempt/3)..."
|
||||
sleep 5
|
||||
else
|
||||
echo " FAILED: $$PKG status $$STATUS"
|
||||
cat /tmp/link-response.txt
|
||||
return 1
|
||||
fi
|
||||
done
|
||||
}
|
||||
link_package "stack-web"
|
||||
when:
|
||||
- branch: [main, develop]
|
||||
event: [push, manual, tag]
|
||||
depends_on:
|
||||
- security-trivy-web
|
||||
118
AGENTS.md
118
AGENTS.md
@@ -1,101 +1,37 @@
|
||||
# AGENTS.md — Mosaic Stack
|
||||
# Mosaic Stack — Agent Guidelines
|
||||
|
||||
Guidelines for AI agents working on this codebase.
|
||||
> **Any AI model, coding assistant, or framework working in this codebase MUST read and follow `CLAUDE.md` in the project root.**
|
||||
|
||||
## Quick Start
|
||||
`CLAUDE.md` is the authoritative source for:
|
||||
|
||||
1. Read `CLAUDE.md` for project-specific patterns
|
||||
2. Check this file for workflow and context management
|
||||
3. Use `TOOLS.md` patterns (if present) before fumbling with CLIs
|
||||
- Technology stack and versions
|
||||
- TypeScript strict mode requirements
|
||||
- ESLint Quality Rails (error-level enforcement)
|
||||
- Prettier formatting rules
|
||||
- Testing requirements (85% coverage, TDD)
|
||||
- API conventions and database patterns
|
||||
- Commit format and branch strategy
|
||||
- PDA-friendly design principles
|
||||
|
||||
## Context Management
|
||||
## Quick Rules (Read CLAUDE.md for Details)
|
||||
|
||||
Context = tokens = cost. Be smart.
|
||||
- **No `any` types** — use `unknown`, generics, or proper types
|
||||
- **Explicit return types** on all functions
|
||||
- **Type-only imports** — `import type { Foo }` for types
|
||||
- **Double quotes**, semicolons, 2-space indent, 100 char width
|
||||
- **`??` not `||`** for defaults, **`?.`** not `&&` chains
|
||||
- **All promises** must be awaited or returned
|
||||
- **85% test coverage** minimum, tests before implementation
|
||||
|
||||
| Strategy | When |
|
||||
| ----------------------------- | -------------------------------------------------------------- |
|
||||
| **Spawn sub-agents** | Isolated coding tasks, research, anything that can report back |
|
||||
| **Batch operations** | Group related API calls, don't do one-at-a-time |
|
||||
| **Check existing patterns** | Before writing new code, see how similar features were built |
|
||||
| **Minimize re-reading** | Don't re-read files you just wrote |
|
||||
| **Summarize before clearing** | Extract learnings to memory before context reset |
|
||||
## Updating Conventions
|
||||
|
||||
## Workflow (Non-Negotiable)
|
||||
If you discover new patterns, gotchas, or conventions while working in this codebase, **update `CLAUDE.md`** — not this file. This file exists solely to redirect agents that look for `AGENTS.md` to the canonical source.
|
||||
|
||||
### Code Changes
|
||||
## Per-App Context
|
||||
|
||||
```
|
||||
1. Branch → git checkout -b feature/XX-description
|
||||
2. Code → TDD: write test (RED), implement (GREEN), refactor
|
||||
3. Test → pnpm test (must pass)
|
||||
4. Push → git push origin feature/XX-description
|
||||
5. PR → Create PR to develop (not main)
|
||||
6. Review → Wait for approval or self-merge if authorized
|
||||
7. Close → Close related issues via API
|
||||
```
|
||||
Each app directory has its own `AGENTS.md` for app-specific patterns:
|
||||
|
||||
**Never merge directly to develop without a PR.**
|
||||
|
||||
### Issue Management
|
||||
|
||||
```bash
|
||||
# Get Gitea token
|
||||
TOKEN="$(jq -r '.gitea.mosaicstack.token' ~/src/jarvis-brain/credentials.json)"
|
||||
|
||||
# Create issue
|
||||
curl -s -H "Authorization: token $TOKEN" -H "Content-Type: application/json" \
|
||||
"https://git.mosaicstack.dev/api/v1/repos/mosaic/stack/issues" \
|
||||
-d '{"title":"Title","body":"Description","milestone":54}'
|
||||
|
||||
# Close issue (REQUIRED after merge)
|
||||
curl -s -X PATCH -H "Authorization: token $TOKEN" -H "Content-Type: application/json" \
|
||||
"https://git.mosaicstack.dev/api/v1/repos/mosaic/stack/issues/XX" \
|
||||
-d '{"state":"closed"}'
|
||||
|
||||
# Create PR (tea CLI works for this)
|
||||
tea pulls create --repo mosaic/stack --base develop --head feature/XX-name \
|
||||
--title "feat(#XX): Title" --description "Description"
|
||||
```
|
||||
|
||||
### Commit Messages
|
||||
|
||||
```
|
||||
<type>(#issue): Brief description
|
||||
|
||||
Detailed explanation if needed.
|
||||
|
||||
Closes #XX, #YY
|
||||
```
|
||||
|
||||
Types: `feat`, `fix`, `docs`, `test`, `refactor`, `chore`
|
||||
|
||||
## TDD Requirements
|
||||
|
||||
**All code must follow TDD. This is non-negotiable.**
|
||||
|
||||
1. **RED** — Write failing test first
|
||||
2. **GREEN** — Minimal code to pass
|
||||
3. **REFACTOR** — Clean up while tests stay green
|
||||
|
||||
Minimum 85% coverage for new code.
|
||||
|
||||
## Token-Saving Tips
|
||||
|
||||
- **Sub-agents die after task** — their context doesn't pollute main session
|
||||
- **API over CLI** when CLI needs TTY or confirmation prompts
|
||||
- **One commit** with all issue numbers, not separate commits per issue
|
||||
- **Don't re-read** files you just wrote
|
||||
- **Batch similar operations** — create all issues at once, close all at once
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Purpose |
|
||||
| ------------------------------- | ----------------------------------------- |
|
||||
| `CLAUDE.md` | Project overview, tech stack, conventions |
|
||||
| `CONTRIBUTING.md` | Human contributor guide |
|
||||
| `apps/api/prisma/schema.prisma` | Database schema |
|
||||
| `docs/` | Architecture and setup docs |
|
||||
|
||||
---
|
||||
|
||||
_Model-agnostic. Works for Claude, MiniMax, GPT, Llama, etc._
|
||||
- `apps/api/AGENTS.md`
|
||||
- `apps/web/AGENTS.md`
|
||||
- `apps/coordinator/AGENTS.md`
|
||||
- `apps/orchestrator/AGENTS.md`
|
||||
|
||||
35
CLAUDE.md
35
CLAUDE.md
@@ -1,6 +1,19 @@
|
||||
**Multi-tenant personal assistant platform with PostgreSQL backend, Authentik SSO, and MoltBot
|
||||
integration.**
|
||||
|
||||
## Conditional Documentation Loading
|
||||
|
||||
| When working on... | Load this guide |
|
||||
| ---------------------------------------- | ------------------------------------------------------------------- |
|
||||
| Orchestrating autonomous task completion | `docs/claude/orchestrator.md` |
|
||||
| Security remediation (review findings) | `docs/reports/codebase-review-2026-02-05/01-security-review.md` |
|
||||
| Code quality fixes | `docs/reports/codebase-review-2026-02-05/02-code-quality-review.md` |
|
||||
| Test coverage gaps | `docs/reports/codebase-review-2026-02-05/03-qa-test-coverage.md` |
|
||||
|
||||
## Platform Templates
|
||||
|
||||
Bootstrap templates are at `docs/templates/`. See `docs/templates/README.md` for usage.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Mosaic Stack is a standalone platform that provides:
|
||||
@@ -462,3 +475,25 @@ Related Repositories
|
||||
---
|
||||
|
||||
Mosaic Stack v0.0.x — Building the future of personal assistants.
|
||||
|
||||
## Campsite Rule (MANDATORY)
|
||||
|
||||
If you modify a line containing a policy violation, you MUST either:
|
||||
|
||||
1. **Fix the violation properly** in the same change, OR
|
||||
2. **Flag it as a deferred item** with documented rationale
|
||||
|
||||
**"It was already there" is NEVER an acceptable justification** for perpetuating a violation in code you touched. Touching it makes it yours.
|
||||
|
||||
Examples of violations you must fix when you touch the line:
|
||||
|
||||
- `as unknown as Type` double assertions — use type guards instead
|
||||
- `any` types — narrow to `unknown` with validation or define a proper interface
|
||||
- Missing error handling — add it if you're modifying the surrounding code
|
||||
- Suppressed linting rules (`// eslint-disable`) — fix the underlying issue
|
||||
|
||||
If the proper fix is too large for the current scope, you MUST:
|
||||
|
||||
- Create a TODO comment with issue reference: `// TODO(#123): Replace double assertion with type guard`
|
||||
- Document the deferral in your PR/commit description
|
||||
- Never silently carry the violation forward
|
||||
|
||||
36
Makefile
36
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: help install dev build test docker-up docker-down docker-logs docker-ps docker-build docker-restart docker-test clean
|
||||
.PHONY: help install dev build test docker-up docker-down docker-logs docker-ps docker-build docker-restart docker-test speech-up speech-down speech-logs clean matrix-up matrix-down matrix-logs matrix-setup-bot
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -24,6 +24,17 @@ help:
|
||||
@echo " make docker-test Run Docker smoke test"
|
||||
@echo " make docker-test-traefik Run Traefik integration tests"
|
||||
@echo ""
|
||||
@echo "Speech Services:"
|
||||
@echo " make speech-up Start speech services (STT + TTS)"
|
||||
@echo " make speech-down Stop speech services"
|
||||
@echo " make speech-logs View speech service logs"
|
||||
@echo ""
|
||||
@echo "Matrix Dev Environment:"
|
||||
@echo " make matrix-up Start Matrix services (Synapse + Element)"
|
||||
@echo " make matrix-down Stop Matrix services"
|
||||
@echo " make matrix-logs View Matrix service logs"
|
||||
@echo " make matrix-setup-bot Create bot account and get access token"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make db-migrate Run database migrations"
|
||||
@echo " make db-seed Seed development data"
|
||||
@@ -85,6 +96,29 @@ docker-test:
|
||||
docker-test-traefik:
|
||||
./tests/integration/docker/traefik.test.sh all
|
||||
|
||||
# Speech services
|
||||
speech-up:
|
||||
docker compose -f docker-compose.yml -f docker-compose.speech.yml up -d speaches kokoro-tts
|
||||
|
||||
speech-down:
|
||||
docker compose -f docker-compose.yml -f docker-compose.speech.yml down --remove-orphans
|
||||
|
||||
speech-logs:
|
||||
docker compose -f docker-compose.yml -f docker-compose.speech.yml logs -f speaches kokoro-tts
|
||||
|
||||
# Matrix Dev Environment
|
||||
matrix-up:
|
||||
docker compose -f docker/docker-compose.yml -f docker/docker-compose.matrix.yml up -d
|
||||
|
||||
matrix-down:
|
||||
docker compose -f docker/docker-compose.yml -f docker/docker-compose.matrix.yml down
|
||||
|
||||
matrix-logs:
|
||||
docker compose -f docker/docker-compose.yml -f docker/docker-compose.matrix.yml logs -f synapse element-web
|
||||
|
||||
matrix-setup-bot:
|
||||
docker/matrix/scripts/setup-bot.sh
|
||||
|
||||
# Database operations
|
||||
db-migrate:
|
||||
cd apps/api && pnpm prisma:migrate
|
||||
|
||||
304
README.md
304
README.md
@@ -19,29 +19,82 @@ Mosaic Stack is a modern, PDA-friendly platform designed to help users manage th
|
||||
|
||||
## Technology Stack
|
||||
|
||||
| Layer | Technology |
|
||||
| -------------- | -------------------------------------------- |
|
||||
| **Frontend** | Next.js 16 + React + TailwindCSS + Shadcn/ui |
|
||||
| **Backend** | NestJS + Prisma ORM |
|
||||
| **Database** | PostgreSQL 17 + pgvector |
|
||||
| **Cache** | Valkey (Redis-compatible) |
|
||||
| **Auth** | Authentik (OIDC) via BetterAuth |
|
||||
| **AI** | Ollama (local or remote) |
|
||||
| **Messaging** | MoltBot (stock + plugins) |
|
||||
| **Real-time** | WebSockets (Socket.io) |
|
||||
| **Monorepo** | pnpm workspaces + TurboRepo |
|
||||
| **Testing** | Vitest + Playwright |
|
||||
| **Deployment** | Docker + docker-compose |
|
||||
| Layer | Technology |
|
||||
| -------------- | ---------------------------------------------- |
|
||||
| **Frontend** | Next.js 16 + React + TailwindCSS + Shadcn/ui |
|
||||
| **Backend** | NestJS + Prisma ORM |
|
||||
| **Database** | PostgreSQL 17 + pgvector |
|
||||
| **Cache** | Valkey (Redis-compatible) |
|
||||
| **Auth** | Authentik (OIDC) via BetterAuth |
|
||||
| **AI** | Ollama (local or remote) |
|
||||
| **Messaging** | MoltBot (stock + plugins) |
|
||||
| **Real-time** | WebSockets (Socket.io) |
|
||||
| **Speech** | Speaches (STT) + Kokoro/Chatterbox/Piper (TTS) |
|
||||
| **Monorepo** | pnpm workspaces + TurboRepo |
|
||||
| **Testing** | Vitest + Playwright |
|
||||
| **Deployment** | Docker + docker-compose |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### One-Line Install (Recommended)
|
||||
|
||||
The fastest way to get Mosaic Stack running on macOS or Linux:
|
||||
|
||||
```bash
|
||||
curl -fsSL https://get.mosaicstack.dev | bash
|
||||
```
|
||||
|
||||
This installer:
|
||||
|
||||
- ✅ Detects your platform (macOS, Debian/Ubuntu, Arch, Fedora)
|
||||
- ✅ Installs all required dependencies (Docker, Node.js, etc.)
|
||||
- ✅ Generates secure secrets automatically
|
||||
- ✅ Configures the environment for you
|
||||
- ✅ Starts all services with Docker Compose
|
||||
- ✅ Validates the installation with health checks
|
||||
|
||||
**Installer Options:**
|
||||
|
||||
```bash
|
||||
# Non-interactive Docker deployment
|
||||
curl -fsSL https://get.mosaicstack.dev | bash -s -- --non-interactive --mode docker
|
||||
|
||||
# Preview installation without making changes
|
||||
curl -fsSL https://get.mosaicstack.dev | bash -s -- --dry-run
|
||||
|
||||
# With SSO and local Ollama
|
||||
curl -fsSL https://get.mosaicstack.dev | bash -s -- \
|
||||
--mode docker \
|
||||
--enable-sso --bundled-authentik \
|
||||
--ollama-mode local
|
||||
|
||||
# Skip dependency installation (if already installed)
|
||||
curl -fsSL https://get.mosaicstack.dev | bash -s -- --skip-deps
|
||||
```
|
||||
|
||||
**After Installation:**
|
||||
|
||||
```bash
|
||||
# Check system health
|
||||
./scripts/commands/doctor.sh
|
||||
|
||||
# View service logs
|
||||
docker compose logs -f
|
||||
|
||||
# Stop services
|
||||
docker compose down
|
||||
```
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Node.js 20+ and pnpm 9+
|
||||
- PostgreSQL 17+ (or use Docker)
|
||||
- Docker & Docker Compose (optional, for turnkey deployment)
|
||||
If you prefer manual installation, you'll need:
|
||||
|
||||
### Installation
|
||||
- **Docker mode:** Docker 24+ and Docker Compose
|
||||
- **Native mode:** Node.js 24+, pnpm 10+, PostgreSQL 17+
|
||||
|
||||
The installer handles these automatically.
|
||||
|
||||
### Manual Installation
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
@@ -70,10 +123,12 @@ pnpm prisma:seed
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
### Docker Deployment (Turnkey)
|
||||
### Docker Deployment
|
||||
|
||||
**Recommended for quick setup and production deployments.**
|
||||
|
||||
#### Development (Turnkey - All Services Bundled)
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://git.mosaicstack.dev/mosaic/stack mosaic-stack
|
||||
@@ -81,26 +136,63 @@ cd mosaic-stack
|
||||
|
||||
# Copy and configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your settings
|
||||
# Set COMPOSE_PROFILES=full in .env
|
||||
|
||||
# Start core services (PostgreSQL, Valkey, API, Web)
|
||||
# Start all services (PostgreSQL, Valkey, OpenBao, Authentik, Ollama, API, Web)
|
||||
docker compose up -d
|
||||
|
||||
# Or start with optional services
|
||||
docker compose --profile full up -d # Includes Authentik and Ollama
|
||||
|
||||
# View logs
|
||||
docker compose logs -f
|
||||
|
||||
# Check service status
|
||||
docker compose ps
|
||||
|
||||
# Access services
|
||||
# Web: http://localhost:3000
|
||||
# API: http://localhost:3001
|
||||
# Auth: http://localhost:9000 (if Authentik enabled)
|
||||
# Auth: http://localhost:9000
|
||||
```
|
||||
|
||||
# Stop services
|
||||
#### Production (External Managed Services)
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://git.mosaicstack.dev/mosaic/stack mosaic-stack
|
||||
cd mosaic-stack
|
||||
|
||||
# Copy environment template and example
|
||||
cp .env.example .env
|
||||
cp docker/docker-compose.example.external.yml docker-compose.override.yml
|
||||
|
||||
# Edit .env with external service URLs:
|
||||
# - DATABASE_URL=postgresql://... (RDS, Cloud SQL, etc.)
|
||||
# - VALKEY_URL=redis://... (ElastiCache, Memorystore, etc.)
|
||||
# - OPENBAO_ADDR=https://... (HashiCorp Vault, etc.)
|
||||
# - OIDC_ISSUER=https://... (Auth0, Okta, etc.)
|
||||
# - Set COMPOSE_PROFILES= (empty)
|
||||
|
||||
# Start API and Web only
|
||||
docker compose up -d
|
||||
|
||||
# View logs
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
#### Hybrid (Mix of Bundled and External)
|
||||
|
||||
```bash
|
||||
# Use bundled database/cache, external auth/secrets
|
||||
cp docker/docker-compose.example.hybrid.yml docker-compose.override.yml
|
||||
|
||||
# Edit .env:
|
||||
# - COMPOSE_PROFILES=database,cache,ollama
|
||||
# - OPENBAO_ADDR=https://... (external vault)
|
||||
# - OIDC_ISSUER=https://... (external auth)
|
||||
|
||||
# Start mixed deployment
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
**Stop services:**
|
||||
|
||||
```bash
|
||||
docker compose down
|
||||
```
|
||||
|
||||
@@ -110,11 +202,88 @@ docker compose down
|
||||
- Valkey (Redis-compatible cache)
|
||||
- Mosaic API (NestJS)
|
||||
- Mosaic Web (Next.js)
|
||||
- Mosaic Orchestrator (Agent lifecycle management)
|
||||
- Mosaic Coordinator (Task assignment & monitoring)
|
||||
- Authentik OIDC (optional, use `--profile authentik`)
|
||||
- Ollama AI (optional, use `--profile ollama`)
|
||||
|
||||
See [Docker Deployment Guide](docs/1-getting-started/4-docker-deployment/) for complete documentation.
|
||||
|
||||
### Docker Swarm Deployment (Production)
|
||||
|
||||
**Recommended for production deployments with high availability and auto-scaling.**
|
||||
|
||||
Deploy to a Docker Swarm cluster with integrated Traefik reverse proxy:
|
||||
|
||||
```bash
|
||||
# 1. Initialize swarm (if not already done)
|
||||
docker swarm init --advertise-addr <your-ip>
|
||||
|
||||
# 2. Create Traefik network
|
||||
docker network create --driver=overlay traefik-public
|
||||
|
||||
# 3. Configure environment for swarm
|
||||
cp .env.swarm.example .env
|
||||
nano .env # Configure domains, passwords, API keys
|
||||
|
||||
# 4. CRITICAL: Deploy OpenBao standalone FIRST
|
||||
# OpenBao cannot run in swarm mode - deploy as standalone container
|
||||
docker compose -f docker-compose.openbao.yml up -d
|
||||
sleep 30 # Wait for auto-initialization
|
||||
|
||||
# 5. Deploy swarm stack
|
||||
IMAGE_TAG=dev ./scripts/deploy-swarm.sh mosaic
|
||||
|
||||
# 6. Check deployment status
|
||||
docker stack services mosaic
|
||||
docker stack ps mosaic
|
||||
|
||||
# Access services via Traefik
|
||||
# Web: http://mosaic.mosaicstack.dev
|
||||
# API: http://api.mosaicstack.dev
|
||||
# Auth: http://auth.mosaicstack.dev (if using bundled Authentik)
|
||||
```
|
||||
|
||||
**Key features:**
|
||||
|
||||
- Automatic Traefik integration for routing
|
||||
- Overlay networking for multi-host deployments
|
||||
- Built-in health checks and rolling updates
|
||||
- Horizontal scaling for web and API services
|
||||
- Zero-downtime deployments
|
||||
- Service orchestration across multiple nodes
|
||||
|
||||
**Important Notes:**
|
||||
|
||||
- **OpenBao Requirement:** OpenBao MUST be deployed as standalone container (not in swarm). Use `docker-compose.openbao.yml` or external Vault.
|
||||
- Swarm does NOT support docker-compose profiles
|
||||
- To use external services (PostgreSQL, Authentik, etc.), manually comment them out in `docker-compose.swarm.yml`
|
||||
|
||||
See [Docker Swarm Deployment Guide](docs/SWARM-DEPLOYMENT.md) and [Quick Reference](docs/SWARM-QUICKREF.md) for complete documentation.
|
||||
|
||||
### Portainer Deployment
|
||||
|
||||
**Recommended for GUI-based stack management.**
|
||||
|
||||
Portainer provides a web UI for managing Docker containers and stacks. Use the Portainer-optimized compose file:
|
||||
|
||||
**File:** `docker-compose.portainer.yml`
|
||||
|
||||
**Key differences from standard compose:**
|
||||
|
||||
- No `env_file` directive (define variables in Portainer UI)
|
||||
- Port exposed on all interfaces (Portainer limitation)
|
||||
- Optimized for Portainer's stack parser
|
||||
|
||||
**Quick Steps:**
|
||||
|
||||
1. Create `mosaic_internal` overlay network in Portainer
|
||||
2. Deploy `mosaic-openbao` stack with `docker-compose.portainer.yml`
|
||||
3. Deploy `mosaic` swarm stack with `docker-compose.swarm.yml`
|
||||
4. Configure environment variables in Portainer UI
|
||||
|
||||
See [Portainer Deployment Guide](docs/PORTAINER-DEPLOYMENT.md) for detailed instructions.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
@@ -124,13 +293,29 @@ mosaic-stack/
|
||||
│ │ ├── src/
|
||||
│ │ │ ├── auth/ # BetterAuth + Authentik OIDC
|
||||
│ │ │ ├── prisma/ # Database service
|
||||
│ │ │ ├── coordinator-integration/ # Coordinator API client
|
||||
│ │ │ └── app.module.ts # Main application module
|
||||
│ │ ├── prisma/
|
||||
│ │ │ └── schema.prisma # Database schema
|
||||
│ │ └── Dockerfile
|
||||
│ └── web/ # Next.js 16 frontend (planned)
|
||||
│ ├── app/
|
||||
│ ├── components/
|
||||
│ ├── web/ # Next.js 16 frontend
|
||||
│ │ ├── app/
|
||||
│ │ ├── components/
|
||||
│ │ │ └── widgets/ # HUD widgets (agent status, etc.)
|
||||
│ │ └── Dockerfile
|
||||
│ ├── orchestrator/ # Agent lifecycle & spawning (NestJS)
|
||||
│ │ ├── src/
|
||||
│ │ │ ├── spawner/ # Agent spawning service
|
||||
│ │ │ ├── queue/ # Valkey-backed task queue
|
||||
│ │ │ ├── monitor/ # Health monitoring
|
||||
│ │ │ ├── git/ # Git worktree management
|
||||
│ │ │ └── killswitch/ # Emergency agent termination
|
||||
│ │ └── Dockerfile
|
||||
│ └── coordinator/ # Task assignment & monitoring (FastAPI)
|
||||
│ ├── src/
|
||||
│ │ ├── webhook.py # Gitea webhook receiver
|
||||
│ │ ├── parser.py # Issue metadata parser
|
||||
│ │ └── security.py # HMAC signature verification
|
||||
│ └── Dockerfile
|
||||
├── packages/
|
||||
│ ├── shared/ # Shared types & utilities
|
||||
@@ -159,23 +344,59 @@ mosaic-stack/
|
||||
└── pnpm-workspace.yaml # Workspace configuration
|
||||
```
|
||||
|
||||
## Agent Orchestration Layer (v0.0.6)
|
||||
|
||||
Mosaic Stack includes a sophisticated agent orchestration system for autonomous task execution:
|
||||
|
||||
- **Orchestrator Service** (NestJS) - Manages agent lifecycle, spawning, and health monitoring
|
||||
- **Coordinator Service** (FastAPI) - Receives Gitea webhooks, assigns tasks to agents
|
||||
- **Task Queue** - Valkey-backed queue for distributed task management
|
||||
- **Git Worktrees** - Isolated workspaces for parallel agent execution
|
||||
- **Killswitch** - Emergency stop mechanism for runaway agents
|
||||
- **Agent Dashboard** - Real-time monitoring UI with status widgets
|
||||
|
||||
See [Agent Orchestration Design](docs/design/agent-orchestration.md) for architecture details.
|
||||
|
||||
## Speech Services
|
||||
|
||||
Mosaic Stack includes integrated speech-to-text (STT) and text-to-speech (TTS) capabilities through a modular provider architecture. Each component is optional and independently configurable.
|
||||
|
||||
- **Speech-to-Text** - Transcribe audio files and real-time audio streams using Whisper (via Speaches)
|
||||
- **Text-to-Speech** - Synthesize speech with 54+ voices across 8 languages (via Kokoro, CPU-based)
|
||||
- **Premium Voice Cloning** - Clone voices from audio samples with emotion control (via Chatterbox, GPU)
|
||||
- **Fallback TTS** - Ultra-lightweight CPU fallback for low-resource environments (via Piper/OpenedAI Speech)
|
||||
- **WebSocket Streaming** - Real-time streaming transcription via Socket.IO `/speech` namespace
|
||||
- **Automatic Fallback** - TTS tier system with graceful degradation (premium -> default -> fallback)
|
||||
|
||||
**Quick Start:**
|
||||
|
||||
```bash
|
||||
# Start speech services alongside core stack
|
||||
make speech-up
|
||||
|
||||
# Or with Docker Compose directly
|
||||
docker compose -f docker-compose.yml -f docker-compose.speech.yml up -d
|
||||
```
|
||||
|
||||
See [Speech Services Documentation](docs/SPEECH.md) for architecture details, API reference, provider configuration, and deployment options.
|
||||
|
||||
## Current Implementation Status
|
||||
|
||||
### ✅ Completed (v0.0.1)
|
||||
### ✅ Completed (v0.0.1-0.0.6)
|
||||
|
||||
- **Issue #1:** Project scaffold and monorepo setup
|
||||
- **Issue #2:** PostgreSQL 17 + pgvector database schema
|
||||
- **Issue #3:** Prisma ORM integration with tests and seed data
|
||||
- **Issue #4:** Authentik OIDC authentication with BetterAuth
|
||||
- **M1-Foundation:** Project scaffold, PostgreSQL 17 + pgvector, Prisma ORM
|
||||
- **M2-MultiTenant:** Workspace isolation with RLS, team management
|
||||
- **M3-Features:** Knowledge management, tasks, calendar, authentication
|
||||
- **M4-MoltBot:** Bot integration architecture (in progress)
|
||||
- **M6-AgentOrchestration:** Orchestrator service, coordinator, agent dashboard ✅
|
||||
|
||||
**Test Coverage:** 26/26 tests passing (100%)
|
||||
**Test Coverage:** 2168+ tests passing
|
||||
|
||||
### 🚧 In Progress (v0.0.x)
|
||||
|
||||
- **Issue #5:** Multi-tenant workspace isolation (planned)
|
||||
- **Issue #6:** Frontend authentication UI ✅ **COMPLETED**
|
||||
- **Issue #7:** Activity logging system (planned)
|
||||
- **Issue #8:** Docker compose setup ✅ **COMPLETED**
|
||||
- Agent orchestration E2E testing
|
||||
- Usage budget management
|
||||
- Performance optimization
|
||||
|
||||
### 📋 Planned Features (v0.1.0 MVP)
|
||||
|
||||
@@ -561,6 +782,7 @@ Complete documentation is organized in a Bookstack-compatible structure in the `
|
||||
- **[Overview](docs/3-architecture/1-overview/)** — System design and components
|
||||
- **[Authentication](docs/3-architecture/2-authentication/)** — BetterAuth and OIDC integration
|
||||
- **[Design Principles](docs/3-architecture/3-design-principles/1-pda-friendly.md)** — PDA-friendly patterns (non-negotiable)
|
||||
- **[Telemetry](docs/telemetry.md)** — AI task completion tracking, predictions, and SDK reference
|
||||
|
||||
### 🔌 API Reference
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
# Database
|
||||
DATABASE_URL=postgresql://user:password@localhost:5432/database
|
||||
|
||||
# System Administration
|
||||
# Comma-separated list of user IDs that have system administrator privileges
|
||||
# These users can perform system-level operations across all workspaces
|
||||
# Note: Workspace ownership does NOT grant system admin access
|
||||
# SYSTEM_ADMIN_IDS=uuid1,uuid2,uuid3
|
||||
|
||||
# Federation Instance Identity
|
||||
# Display name for this Mosaic instance
|
||||
INSTANCE_NAME=Mosaic Instance
|
||||
@@ -11,3 +17,24 @@ INSTANCE_URL=http://localhost:3000
|
||||
# CRITICAL: Generate a secure random key for production!
|
||||
# Generate with: node -e "console.log(require('crypto').randomBytes(32).toString('hex'))"
|
||||
ENCRYPTION_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
|
||||
|
||||
# CSRF Protection (Required in production)
|
||||
# Secret key for HMAC binding CSRF tokens to user sessions
|
||||
# Generate with: node -e "console.log(require('crypto').randomBytes(32).toString('hex'))"
|
||||
# In development, a random key is generated if not set
|
||||
CSRF_SECRET=fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210
|
||||
|
||||
# OpenTelemetry Configuration
|
||||
# Enable/disable OpenTelemetry tracing (default: true)
|
||||
OTEL_ENABLED=true
|
||||
# Service name for telemetry (default: mosaic-api)
|
||||
OTEL_SERVICE_NAME=mosaic-api
|
||||
# OTLP exporter endpoint (default: http://localhost:4318/v1/traces)
|
||||
OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318/v1/traces
|
||||
# Alternative: Jaeger endpoint (legacy)
|
||||
# OTEL_EXPORTER_JAEGER_ENDPOINT=http://localhost:4318/v1/traces
|
||||
# Deployment environment (default: development, or uses NODE_ENV)
|
||||
# OTEL_DEPLOYMENT_ENVIRONMENT=production
|
||||
# Trace sampling ratio: 0.0 (none) to 1.0 (all) - default: 1.0
|
||||
# Use lower values in high-traffic production environments
|
||||
# OTEL_TRACES_SAMPLER_ARG=1.0
|
||||
|
||||
9
apps/api/.env.test.example
Normal file
9
apps/api/.env.test.example
Normal file
@@ -0,0 +1,9 @@
|
||||
# WARNING: These are example test credentials for local integration testing.
|
||||
# Copy this file to .env.test and customize the values for your local environment.
|
||||
# NEVER use these credentials in any shared environment or commit .env.test to git.
|
||||
|
||||
DATABASE_URL="postgresql://test:test@localhost:5432/test"
|
||||
ENCRYPTION_KEY="test-encryption-key-32-characters"
|
||||
JWT_SECRET="test-jwt-secret"
|
||||
INSTANCE_NAME="Test Instance"
|
||||
INSTANCE_URL="https://test.example.com"
|
||||
25
apps/api/AGENTS.md
Normal file
25
apps/api/AGENTS.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# api — Agent Context
|
||||
|
||||
> Part of the apps layer.
|
||||
|
||||
## Patterns
|
||||
|
||||
- **Config validation pattern**: Config files use exported validation functions + typed getter functions (not class-validator). See `auth.config.ts`, `federation.config.ts`, `speech/speech.config.ts`. Pattern: export `isXEnabled()`, `validateXConfig()`, and `getXConfig()` functions.
|
||||
- **Config registerAs**: `speech.config.ts` also exports a `registerAs("speech", ...)` factory for NestJS ConfigModule namespaced injection. Use `ConfigModule.forFeature(speechConfig)` in module imports and access via `this.config.get<string>('speech.stt.baseUrl')`.
|
||||
- **Conditional config validation**: When a service has an enabled flag (e.g., `STT_ENABLED`), URL/connection vars are only required when enabled. Validation throws with a helpful message suggesting how to disable.
|
||||
- **Boolean env parsing**: Use `value === "true" || value === "1"` pattern. No default-true -- all services default to disabled when env var is unset.
|
||||
|
||||
## Gotchas
|
||||
|
||||
- **Prisma client must be generated** before `tsc --noEmit` will pass. Run `pnpm prisma:generate` first. Pre-existing type errors from Prisma are expected in worktrees without generated client.
|
||||
- **Pre-commit hooks**: lint-staged runs on staged files. If other packages' files are staged, their lint must pass too. Only stage files you intend to commit.
|
||||
- **vitest runs all test files**: Even when targeting a specific test file, vitest loads all spec files. Many will fail if Prisma client isn't generated -- this is expected. Check only your target file's pass/fail status.
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Purpose |
|
||||
| ------------------------------------- | ---------------------------------------------------------------------- |
|
||||
| `src/speech/speech.config.ts` | Speech services env var validation and typed config (STT, TTS, limits) |
|
||||
| `src/speech/speech.config.spec.ts` | Unit tests for speech config validation (51 tests) |
|
||||
| `src/auth/auth.config.ts` | Auth/OIDC config validation (reference pattern) |
|
||||
| `src/federation/federation.config.ts` | Federation config validation (reference pattern) |
|
||||
@@ -2,7 +2,9 @@
|
||||
# Enable BuildKit features for cache mounts
|
||||
|
||||
# Base image for all stages
|
||||
FROM node:20-alpine AS base
|
||||
# Uses Debian slim (glibc) instead of Alpine (musl) because native Node.js addons
|
||||
# (matrix-sdk-crypto-nodejs, Prisma engines) require glibc-compatible binaries.
|
||||
FROM node:24-slim AS base
|
||||
|
||||
# Install pnpm globally
|
||||
RUN corepack enable && corepack prepare pnpm@10.27.0 --activate
|
||||
@@ -46,36 +48,24 @@ COPY --from=deps /app/packages/shared/node_modules ./packages/shared/node_module
|
||||
COPY --from=deps /app/packages/config/node_modules ./packages/config/node_modules
|
||||
COPY --from=deps /app/apps/api/node_modules ./apps/api/node_modules
|
||||
|
||||
# Debug: Show what we have before building
|
||||
RUN echo "=== Pre-build directory structure ===" && \
|
||||
echo "--- packages/config/typescript ---" && ls -la packages/config/typescript/ && \
|
||||
echo "--- packages/shared (top level) ---" && ls -la packages/shared/ && \
|
||||
echo "--- packages/shared/src ---" && ls -la packages/shared/src/ && \
|
||||
echo "--- apps/api (top level) ---" && ls -la apps/api/ && \
|
||||
echo "--- apps/api/src (exists?) ---" && ls apps/api/src/*.ts | head -5 && \
|
||||
echo "--- node_modules/@mosaic (symlinks?) ---" && ls -la node_modules/@mosaic/ 2>/dev/null || echo "No @mosaic in node_modules"
|
||||
|
||||
# Build the API app and its dependencies using TurboRepo
|
||||
# This ensures @mosaic/shared is built first, then prisma:generate, then the API
|
||||
# Disable turbo cache temporarily to ensure fresh build and see full output
|
||||
RUN pnpm turbo build --filter=@mosaic/api --force --verbosity=2
|
||||
|
||||
# Debug: Show what was built
|
||||
RUN echo "=== Post-build directory structure ===" && \
|
||||
echo "--- packages/shared/dist ---" && ls -la packages/shared/dist/ 2>/dev/null || echo "NO dist in shared" && \
|
||||
echo "--- apps/api/dist ---" && ls -la apps/api/dist/ 2>/dev/null || echo "NO dist in api" && \
|
||||
echo "--- apps/api/dist contents (if exists) ---" && find apps/api/dist -type f 2>/dev/null | head -10 || echo "Cannot find dist files"
|
||||
# --force disables turbo cache to ensure fresh build from source
|
||||
RUN pnpm turbo build --filter=@mosaic/api --force
|
||||
|
||||
# ======================
|
||||
# Production stage
|
||||
# ======================
|
||||
FROM node:20-alpine AS production
|
||||
FROM node:24-slim AS production
|
||||
|
||||
# Remove npm (unused in production — we use pnpm) to reduce attack surface
|
||||
RUN rm -rf /usr/local/lib/node_modules/npm /usr/local/bin/npm /usr/local/bin/npx
|
||||
|
||||
# Install dumb-init for proper signal handling
|
||||
RUN apk add --no-cache dumb-init
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends dumb-init \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN addgroup -g 1001 -S nodejs && adduser -S nestjs -u 1001
|
||||
RUN groupadd -g 1001 nodejs && useradd -m -u 1001 -g nodejs nestjs
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -93,6 +83,9 @@ COPY --from=builder --chown=nestjs:nodejs /app/apps/api/package.json ./apps/api/
|
||||
# Copy app's node_modules which contains symlinks to root node_modules
|
||||
COPY --from=builder --chown=nestjs:nodejs /app/apps/api/node_modules ./apps/api/node_modules
|
||||
|
||||
# Copy entrypoint script (runs migrations before starting app)
|
||||
COPY --from=builder --chown=nestjs:nodejs /app/apps/api/docker-entrypoint.sh ./apps/api/
|
||||
|
||||
# Set working directory to API app
|
||||
WORKDIR /app/apps/api
|
||||
|
||||
@@ -109,5 +102,5 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
# Use dumb-init to handle signals properly
|
||||
ENTRYPOINT ["dumb-init", "--"]
|
||||
|
||||
# Start the application
|
||||
CMD ["node", "dist/main.js"]
|
||||
# Run migrations then start the application
|
||||
CMD ["sh", "docker-entrypoint.sh"]
|
||||
|
||||
8
apps/api/docker-entrypoint.sh
Executable file
8
apps/api/docker-entrypoint.sh
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/bin/sh
|
||||
set -e
|
||||
|
||||
echo "Running database migrations..."
|
||||
./node_modules/.bin/prisma migrate deploy --schema ./prisma/schema.prisma
|
||||
|
||||
echo "Starting application..."
|
||||
exec node dist/main.js
|
||||
@@ -21,11 +21,13 @@
|
||||
"prisma:migrate:prod": "prisma migrate deploy",
|
||||
"prisma:studio": "prisma studio",
|
||||
"prisma:seed": "prisma db seed",
|
||||
"prisma:reset": "prisma migrate reset"
|
||||
"prisma:reset": "prisma migrate reset",
|
||||
"migrate:encrypt-llm-keys": "tsx scripts/encrypt-llm-keys.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.72.1",
|
||||
"@mosaic/shared": "workspace:*",
|
||||
"@mosaicstack/telemetry-client": "^0.1.1",
|
||||
"@nestjs/axios": "^4.0.1",
|
||||
"@nestjs/bullmq": "^11.0.4",
|
||||
"@nestjs/common": "^11.1.12",
|
||||
@@ -42,17 +44,19 @@
|
||||
"@opentelemetry/instrumentation-nestjs-core": "^0.44.0",
|
||||
"@opentelemetry/resources": "^1.30.1",
|
||||
"@opentelemetry/sdk-node": "^0.56.0",
|
||||
"@opentelemetry/sdk-trace-base": "^2.5.0",
|
||||
"@opentelemetry/semantic-conventions": "^1.28.0",
|
||||
"@prisma/client": "^6.19.2",
|
||||
"@types/marked": "^6.0.0",
|
||||
"@types/multer": "^2.0.0",
|
||||
"adm-zip": "^0.5.16",
|
||||
"archiver": "^7.0.1",
|
||||
"axios": "^1.13.4",
|
||||
"axios": "^1.13.5",
|
||||
"better-auth": "^1.4.17",
|
||||
"bullmq": "^5.67.2",
|
||||
"class-transformer": "^0.5.1",
|
||||
"class-validator": "^0.14.3",
|
||||
"cookie-parser": "^1.4.7",
|
||||
"discord.js": "^14.25.1",
|
||||
"gray-matter": "^4.0.3",
|
||||
"highlight.js": "^11.11.1",
|
||||
@@ -61,6 +65,7 @@
|
||||
"marked": "^17.0.1",
|
||||
"marked-gfm-heading-id": "^4.1.3",
|
||||
"marked-highlight": "^2.2.3",
|
||||
"matrix-bot-sdk": "^0.8.0",
|
||||
"ollama": "^0.6.3",
|
||||
"openai": "^6.17.0",
|
||||
"reflect-metadata": "^0.2.2",
|
||||
@@ -75,15 +80,18 @@
|
||||
"@nestjs/cli": "^11.0.6",
|
||||
"@nestjs/schematics": "^11.0.1",
|
||||
"@nestjs/testing": "^11.1.12",
|
||||
"@opentelemetry/context-async-hooks": "^2.5.0",
|
||||
"@swc/core": "^1.10.18",
|
||||
"@types/adm-zip": "^0.5.7",
|
||||
"@types/archiver": "^7.0.0",
|
||||
"@types/cookie-parser": "^1.4.10",
|
||||
"@types/express": "^5.0.1",
|
||||
"@types/highlight.js": "^10.1.0",
|
||||
"@types/node": "^22.13.4",
|
||||
"@types/sanitize-html": "^2.16.0",
|
||||
"@types/supertest": "^6.0.3",
|
||||
"@vitest/coverage-v8": "^4.0.18",
|
||||
"dotenv": "^17.2.4",
|
||||
"express": "^5.2.1",
|
||||
"prisma": "^6.19.2",
|
||||
"supertest": "^7.2.2",
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
-- CreateEnum
|
||||
CREATE TYPE "FederationConnectionStatus" AS ENUM ('PENDING', 'ACTIVE', 'SUSPENDED', 'DISCONNECTED');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "FederationMessageType" AS ENUM ('QUERY', 'COMMAND', 'EVENT');
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "FederationMessageStatus" AS ENUM ('PENDING', 'DELIVERED', 'FAILED', 'TIMEOUT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "federation_connections" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"remote_instance_id" TEXT NOT NULL,
|
||||
"remote_url" TEXT NOT NULL,
|
||||
"remote_public_key" TEXT NOT NULL,
|
||||
"remote_capabilities" JSONB NOT NULL DEFAULT '{}',
|
||||
"status" "FederationConnectionStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
"connected_at" TIMESTAMPTZ,
|
||||
"disconnected_at" TIMESTAMPTZ,
|
||||
|
||||
CONSTRAINT "federation_connections_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "federated_identities" (
|
||||
"id" UUID NOT NULL,
|
||||
"local_user_id" UUID NOT NULL,
|
||||
"remote_user_id" TEXT NOT NULL,
|
||||
"remote_instance_id" TEXT NOT NULL,
|
||||
"oidc_subject" TEXT NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "federated_identities_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "federation_messages" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"connection_id" UUID NOT NULL,
|
||||
"message_type" "FederationMessageType" NOT NULL,
|
||||
"message_id" TEXT NOT NULL,
|
||||
"correlation_id" TEXT,
|
||||
"query" TEXT,
|
||||
"command_type" TEXT,
|
||||
"event_type" TEXT,
|
||||
"payload" JSONB DEFAULT '{}',
|
||||
"response" JSONB DEFAULT '{}',
|
||||
"status" "FederationMessageStatus" NOT NULL DEFAULT 'PENDING',
|
||||
"error" TEXT,
|
||||
"signature" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
"delivered_at" TIMESTAMPTZ,
|
||||
|
||||
CONSTRAINT "federation_messages_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "federation_connections_workspace_id_remote_instance_id_key" ON "federation_connections"("workspace_id", "remote_instance_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_connections_workspace_id_idx" ON "federation_connections"("workspace_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_connections_workspace_id_status_idx" ON "federation_connections"("workspace_id", "status");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_connections_remote_instance_id_idx" ON "federation_connections"("remote_instance_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "federated_identities_local_user_id_remote_instance_id_key" ON "federated_identities"("local_user_id", "remote_instance_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federated_identities_local_user_id_idx" ON "federated_identities"("local_user_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federated_identities_remote_instance_id_idx" ON "federated_identities"("remote_instance_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federated_identities_oidc_subject_idx" ON "federated_identities"("oidc_subject");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "federation_messages_message_id_key" ON "federation_messages"("message_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_messages_workspace_id_idx" ON "federation_messages"("workspace_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_messages_connection_id_idx" ON "federation_messages"("connection_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_messages_message_id_idx" ON "federation_messages"("message_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_messages_correlation_id_idx" ON "federation_messages"("correlation_id");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "federation_messages_event_type_idx" ON "federation_messages"("event_type");
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "federation_connections" ADD CONSTRAINT "federation_connections_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "federated_identities" ADD CONSTRAINT "federated_identities_local_user_id_fkey" FOREIGN KEY ("local_user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "federation_messages" ADD CONSTRAINT "federation_messages_connection_id_fkey" FOREIGN KEY ("connection_id") REFERENCES "federation_connections"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "federation_messages" ADD CONSTRAINT "federation_messages_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
@@ -1,9 +1,3 @@
|
||||
-- Add eventType column to federation_messages table
|
||||
ALTER TABLE "federation_messages" ADD COLUMN "event_type" TEXT;
|
||||
|
||||
-- Add index for eventType
|
||||
CREATE INDEX "federation_messages_event_type_idx" ON "federation_messages"("event_type");
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "federation_event_subscriptions" (
|
||||
"id" UUID NOT NULL,
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Rollback: SQL Injection Hardening for is_workspace_admin() Helper Function
|
||||
-- This reverts the function to its previous implementation
|
||||
|
||||
-- =============================================================================
|
||||
-- REVERT is_workspace_admin() to original implementation
|
||||
-- =============================================================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION is_workspace_admin(workspace_uuid UUID, user_uuid UUID)
|
||||
RETURNS BOOLEAN AS $$
|
||||
BEGIN
|
||||
RETURN EXISTS (
|
||||
SELECT 1 FROM workspace_members
|
||||
WHERE workspace_id = workspace_uuid
|
||||
AND user_id = user_uuid
|
||||
AND role IN ('OWNER', 'ADMIN')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql STABLE SECURITY DEFINER;
|
||||
@@ -0,0 +1,58 @@
|
||||
-- Security Fix: SQL Injection Hardening for is_workspace_admin() Helper Function
|
||||
-- This migration adds explicit UUID validation to prevent SQL injection attacks
|
||||
--
|
||||
-- Related: #355 Code Review - Security CRIT-3
|
||||
-- Original issue: Migration 20260129221004_add_rls_policies
|
||||
|
||||
-- =============================================================================
|
||||
-- SECURITY FIX: Add explicit UUID validation to is_workspace_admin()
|
||||
-- =============================================================================
|
||||
-- The is_workspace_admin() function previously accepted UUID parameters without
|
||||
-- explicit type casting/validation. Although PostgreSQL's parameter binding provides
|
||||
-- some protection, explicit UUID type validation is a security best practice.
|
||||
--
|
||||
-- This fix adds explicit UUID validation using PostgreSQL's uuid type checking
|
||||
-- to ensure that non-UUID values cannot bypass the function's intent.
|
||||
|
||||
CREATE OR REPLACE FUNCTION is_workspace_admin(workspace_uuid UUID, user_uuid UUID)
|
||||
RETURNS BOOLEAN AS $$
|
||||
DECLARE
|
||||
-- Validate input parameters are valid UUIDs
|
||||
v_workspace_id UUID;
|
||||
v_user_id UUID;
|
||||
BEGIN
|
||||
-- Explicitly validate workspace_uuid parameter
|
||||
IF workspace_uuid IS NULL THEN
|
||||
RETURN FALSE;
|
||||
END IF;
|
||||
v_workspace_id := workspace_uuid::UUID;
|
||||
|
||||
-- Explicitly validate user_uuid parameter
|
||||
IF user_uuid IS NULL THEN
|
||||
RETURN FALSE;
|
||||
END IF;
|
||||
v_user_id := user_uuid::UUID;
|
||||
|
||||
-- Query with validated parameters
|
||||
RETURN EXISTS (
|
||||
SELECT 1 FROM workspace_members
|
||||
WHERE workspace_id = v_workspace_id
|
||||
AND user_id = v_user_id
|
||||
AND role IN ('OWNER', 'ADMIN')
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql STABLE SECURITY DEFINER;
|
||||
|
||||
-- =============================================================================
|
||||
-- NOTES
|
||||
-- =============================================================================
|
||||
-- This is a hardening fix that adds defense-in-depth to the is_workspace_admin()
|
||||
-- helper function. While PostgreSQL's parameterized queries already provide
|
||||
-- protection against SQL injection, explicit UUID type validation ensures:
|
||||
--
|
||||
-- 1. Parameters are explicitly cast to UUID type
|
||||
-- 2. NULL values are handled defensively
|
||||
-- 3. The function's intent is clear and secure
|
||||
-- 4. Compliance with security best practices
|
||||
--
|
||||
-- This change is backward compatible and does not affect existing functionality.
|
||||
@@ -0,0 +1,91 @@
|
||||
-- Row-Level Security (RLS) for Auth Tables
|
||||
-- This migration adds FORCE ROW LEVEL SECURITY and policies to accounts and sessions tables
|
||||
-- to ensure users can only access their own authentication data.
|
||||
--
|
||||
-- Related: #350 - Add RLS policies to auth tables with FORCE enforcement
|
||||
-- Design: docs/design/credential-security.md (Phase 1a)
|
||||
|
||||
-- =============================================================================
|
||||
-- ENABLE FORCE RLS ON AUTH TABLES
|
||||
-- =============================================================================
|
||||
-- FORCE means the table owner (mosaic) is also subject to RLS policies.
|
||||
-- This prevents Prisma (connecting as owner) from bypassing policies.
|
||||
|
||||
ALTER TABLE accounts ENABLE ROW LEVEL SECURITY;
|
||||
ALTER TABLE accounts FORCE ROW LEVEL SECURITY;
|
||||
|
||||
ALTER TABLE sessions ENABLE ROW LEVEL SECURITY;
|
||||
ALTER TABLE sessions FORCE ROW LEVEL SECURITY;
|
||||
|
||||
-- =============================================================================
|
||||
-- ACCOUNTS TABLE POLICIES
|
||||
-- =============================================================================
|
||||
|
||||
-- Owner bypass policy: Allow access to all rows ONLY when no RLS context is set
|
||||
-- This is required for:
|
||||
-- 1. Prisma migrations that run without RLS context
|
||||
-- 2. BetterAuth internal operations during authentication flow (when no user context)
|
||||
-- 3. Database maintenance operations
|
||||
-- When RLS context IS set (current_user_id() returns non-NULL), this policy does not apply
|
||||
--
|
||||
-- NOTE: If connecting as a PostgreSQL superuser (like the default 'mosaic' role),
|
||||
-- RLS policies are bypassed entirely. For full RLS enforcement, the application
|
||||
-- should connect as a non-superuser role. See docs/design/credential-security.md
|
||||
CREATE POLICY accounts_owner_bypass ON accounts
|
||||
FOR ALL
|
||||
USING (current_user_id() IS NULL);
|
||||
|
||||
-- User access policy: Users can only access their own accounts
|
||||
-- Uses current_user_id() helper from migration 20260129221004_add_rls_policies
|
||||
-- This policy applies to all operations: SELECT, INSERT, UPDATE, DELETE
|
||||
CREATE POLICY accounts_user_access ON accounts
|
||||
FOR ALL
|
||||
USING (user_id = current_user_id());
|
||||
|
||||
-- =============================================================================
|
||||
-- SESSIONS TABLE POLICIES
|
||||
-- =============================================================================
|
||||
|
||||
-- Owner bypass policy: Allow access to all rows ONLY when no RLS context is set
|
||||
-- See note on accounts_owner_bypass policy about superuser limitations
|
||||
CREATE POLICY sessions_owner_bypass ON sessions
|
||||
FOR ALL
|
||||
USING (current_user_id() IS NULL);
|
||||
|
||||
-- User access policy: Users can only access their own sessions
|
||||
CREATE POLICY sessions_user_access ON sessions
|
||||
FOR ALL
|
||||
USING (user_id = current_user_id());
|
||||
|
||||
-- =============================================================================
|
||||
-- VERIFICATION TABLE ANALYSIS
|
||||
-- =============================================================================
|
||||
-- The verifications table does NOT need RLS policies because:
|
||||
-- 1. It stores ephemeral verification tokens (email verification, password reset)
|
||||
-- 2. It has no user_id column - only identifier (email) and value (token)
|
||||
-- 3. Tokens are short-lived and accessed by token value, not user context
|
||||
-- 4. BetterAuth manages access control through token validation, not RLS
|
||||
-- 5. No cross-user data leakage risk since tokens are random and expire
|
||||
--
|
||||
-- Therefore, we intentionally do NOT add RLS to verifications table.
|
||||
|
||||
-- =============================================================================
|
||||
-- IMPORTANT: SUPERUSER LIMITATION
|
||||
-- =============================================================================
|
||||
-- PostgreSQL superusers (including the default 'mosaic' role) ALWAYS bypass
|
||||
-- Row-Level Security policies, even with FORCE ROW LEVEL SECURITY enabled.
|
||||
-- This is a fundamental PostgreSQL security design.
|
||||
--
|
||||
-- For production deployments with full RLS enforcement, create a dedicated
|
||||
-- non-superuser application role:
|
||||
--
|
||||
-- CREATE ROLE mosaic_app WITH LOGIN PASSWORD 'secure-password';
|
||||
-- GRANT CONNECT ON DATABASE mosaic TO mosaic_app;
|
||||
-- GRANT USAGE ON SCHEMA public TO mosaic_app;
|
||||
-- GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO mosaic_app;
|
||||
-- GRANT USAGE, SELECT ON ALL SEQUENCES IN SCHEMA public TO mosaic_app;
|
||||
--
|
||||
-- Then update DATABASE_URL to connect as mosaic_app instead of mosaic.
|
||||
-- The RLS policies will then be properly enforced for application queries.
|
||||
--
|
||||
-- See: https://www.postgresql.org/docs/current/ddl-rowsecurity.html
|
||||
@@ -0,0 +1,76 @@
|
||||
-- Rollback: User Credentials Storage with RLS Policies
|
||||
-- This migration reverses all changes from migration.sql
|
||||
--
|
||||
-- Related: #355 - Create UserCredential Prisma model with RLS policies
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP TRIGGERS AND FUNCTIONS
|
||||
-- =============================================================================
|
||||
|
||||
DROP TRIGGER IF EXISTS user_credentials_updated_at ON user_credentials;
|
||||
DROP FUNCTION IF EXISTS update_user_credentials_updated_at();
|
||||
|
||||
-- =============================================================================
|
||||
-- DISABLE RLS
|
||||
-- =============================================================================
|
||||
|
||||
ALTER TABLE user_credentials DISABLE ROW LEVEL SECURITY;
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP RLS POLICIES
|
||||
-- =============================================================================
|
||||
|
||||
DROP POLICY IF EXISTS user_credentials_owner_bypass ON user_credentials;
|
||||
DROP POLICY IF EXISTS user_credentials_user_access ON user_credentials;
|
||||
DROP POLICY IF EXISTS user_credentials_workspace_access ON user_credentials;
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP INDEXES
|
||||
-- =============================================================================
|
||||
|
||||
DROP INDEX IF EXISTS "user_credentials_user_id_workspace_id_provider_name_key";
|
||||
DROP INDEX IF EXISTS "user_credentials_scope_is_active_idx";
|
||||
DROP INDEX IF EXISTS "user_credentials_workspace_id_scope_idx";
|
||||
DROP INDEX IF EXISTS "user_credentials_user_id_scope_idx";
|
||||
DROP INDEX IF EXISTS "user_credentials_workspace_id_idx";
|
||||
DROP INDEX IF EXISTS "user_credentials_user_id_idx";
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP FOREIGN KEY CONSTRAINTS
|
||||
-- =============================================================================
|
||||
|
||||
ALTER TABLE "user_credentials" DROP CONSTRAINT IF EXISTS "user_credentials_workspace_id_fkey";
|
||||
ALTER TABLE "user_credentials" DROP CONSTRAINT IF EXISTS "user_credentials_user_id_fkey";
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP TABLE
|
||||
-- =============================================================================
|
||||
|
||||
DROP TABLE IF EXISTS "user_credentials";
|
||||
|
||||
-- =============================================================================
|
||||
-- DROP ENUMS
|
||||
-- =============================================================================
|
||||
-- NOTE: ENUM values cannot be easily removed from an existing enum type in PostgreSQL.
|
||||
-- To fully reverse this migration, you would need to:
|
||||
--
|
||||
-- 1. Remove the 'CREDENTIAL' value from EntityType enum (if not used elsewhere):
|
||||
-- ALTER TYPE "EntityType" RENAME TO "EntityType_old";
|
||||
-- CREATE TYPE "EntityType" AS ENUM (...all values except CREDENTIAL...);
|
||||
-- -- Then rebuild all dependent objects
|
||||
--
|
||||
-- 2. Remove credential-related actions from ActivityAction enum (if not used elsewhere):
|
||||
-- ALTER TYPE "ActivityAction" RENAME TO "ActivityAction_old";
|
||||
-- CREATE TYPE "ActivityAction" AS ENUM (...all values except CREDENTIAL_*...);
|
||||
-- -- Then rebuild all dependent objects
|
||||
--
|
||||
-- 3. Drop the CredentialType and CredentialScope enums:
|
||||
-- DROP TYPE IF EXISTS "CredentialType";
|
||||
-- DROP TYPE IF EXISTS "CredentialScope";
|
||||
--
|
||||
-- Due to the complexity and risk of breaking existing data/code that references
|
||||
-- these enum values, this migration does NOT automatically remove them.
|
||||
-- If you need to clean up the enums, manually execute the steps above.
|
||||
--
|
||||
-- For development environments, you can safely drop and recreate the enums manually
|
||||
-- using the SQL statements above.
|
||||
@@ -0,0 +1,184 @@
|
||||
-- User Credentials Storage with RLS Policies
|
||||
-- This migration adds the user_credentials table for secure storage of user API keys,
|
||||
-- OAuth tokens, and other credentials with encryption and RLS enforcement.
|
||||
--
|
||||
-- Related: #355 - Create UserCredential Prisma model with RLS policies
|
||||
-- Design: docs/design/credential-security.md (Phase 3a)
|
||||
|
||||
-- =============================================================================
|
||||
-- CREATE ENUMS
|
||||
-- =============================================================================
|
||||
|
||||
-- CredentialType enum: Types of credentials that can be stored
|
||||
CREATE TYPE "CredentialType" AS ENUM ('API_KEY', 'OAUTH_TOKEN', 'ACCESS_TOKEN', 'SECRET', 'PASSWORD', 'CUSTOM');
|
||||
|
||||
-- CredentialScope enum: Access scope for credentials
|
||||
CREATE TYPE "CredentialScope" AS ENUM ('USER', 'WORKSPACE', 'SYSTEM');
|
||||
|
||||
-- =============================================================================
|
||||
-- EXTEND EXISTING ENUMS
|
||||
-- =============================================================================
|
||||
|
||||
-- Add CREDENTIAL to EntityType for activity logging
|
||||
ALTER TYPE "EntityType" ADD VALUE 'CREDENTIAL';
|
||||
|
||||
-- Add credential-related actions to ActivityAction
|
||||
ALTER TYPE "ActivityAction" ADD VALUE 'CREDENTIAL_CREATED';
|
||||
ALTER TYPE "ActivityAction" ADD VALUE 'CREDENTIAL_ACCESSED';
|
||||
ALTER TYPE "ActivityAction" ADD VALUE 'CREDENTIAL_ROTATED';
|
||||
ALTER TYPE "ActivityAction" ADD VALUE 'CREDENTIAL_REVOKED';
|
||||
|
||||
-- =============================================================================
|
||||
-- CREATE USER_CREDENTIALS TABLE
|
||||
-- =============================================================================
|
||||
|
||||
CREATE TABLE "user_credentials" (
|
||||
"id" UUID NOT NULL DEFAULT uuid_generate_v4(),
|
||||
"user_id" UUID NOT NULL,
|
||||
"workspace_id" UUID,
|
||||
|
||||
-- Identity
|
||||
"name" VARCHAR(255) NOT NULL,
|
||||
"provider" VARCHAR(100) NOT NULL,
|
||||
"type" "CredentialType" NOT NULL,
|
||||
"scope" "CredentialScope" NOT NULL DEFAULT 'USER',
|
||||
|
||||
-- Encrypted storage
|
||||
"encrypted_value" TEXT NOT NULL,
|
||||
"masked_value" VARCHAR(20),
|
||||
|
||||
-- Metadata
|
||||
"description" TEXT,
|
||||
"expires_at" TIMESTAMPTZ,
|
||||
"last_used_at" TIMESTAMPTZ,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
-- Status
|
||||
"is_active" BOOLEAN NOT NULL DEFAULT true,
|
||||
"rotated_at" TIMESTAMPTZ,
|
||||
|
||||
-- Audit
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
CONSTRAINT "user_credentials_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- =============================================================================
|
||||
-- CREATE FOREIGN KEY CONSTRAINTS
|
||||
-- =============================================================================
|
||||
|
||||
ALTER TABLE "user_credentials" ADD CONSTRAINT "user_credentials_user_id_fkey"
|
||||
FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
ALTER TABLE "user_credentials" ADD CONSTRAINT "user_credentials_workspace_id_fkey"
|
||||
FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- =============================================================================
|
||||
-- CREATE INDEXES
|
||||
-- =============================================================================
|
||||
|
||||
-- Index for user lookups
|
||||
CREATE INDEX "user_credentials_user_id_idx" ON "user_credentials"("user_id");
|
||||
|
||||
-- Index for workspace lookups
|
||||
CREATE INDEX "user_credentials_workspace_id_idx" ON "user_credentials"("workspace_id");
|
||||
|
||||
-- Index for user + scope queries
|
||||
CREATE INDEX "user_credentials_user_id_scope_idx" ON "user_credentials"("user_id", "scope");
|
||||
|
||||
-- Index for workspace + scope queries
|
||||
CREATE INDEX "user_credentials_workspace_id_scope_idx" ON "user_credentials"("workspace_id", "scope");
|
||||
|
||||
-- Index for scope + active status queries
|
||||
CREATE INDEX "user_credentials_scope_is_active_idx" ON "user_credentials"("scope", "is_active");
|
||||
|
||||
-- =============================================================================
|
||||
-- CREATE UNIQUE CONSTRAINT
|
||||
-- =============================================================================
|
||||
|
||||
-- Prevent duplicate credentials per user/workspace/provider/name
|
||||
CREATE UNIQUE INDEX "user_credentials_user_id_workspace_id_provider_name_key"
|
||||
ON "user_credentials"("user_id", "workspace_id", "provider", "name");
|
||||
|
||||
-- =============================================================================
|
||||
-- ENABLE FORCE ROW LEVEL SECURITY
|
||||
-- =============================================================================
|
||||
-- FORCE means the table owner (mosaic) is also subject to RLS policies.
|
||||
-- This prevents Prisma (connecting as owner) from bypassing policies.
|
||||
|
||||
ALTER TABLE user_credentials ENABLE ROW LEVEL SECURITY;
|
||||
ALTER TABLE user_credentials FORCE ROW LEVEL SECURITY;
|
||||
|
||||
-- =============================================================================
|
||||
-- RLS POLICIES
|
||||
-- =============================================================================
|
||||
|
||||
-- Owner bypass policy: Allow access to all rows ONLY when no RLS context is set
|
||||
-- This is required for:
|
||||
-- 1. Prisma migrations that run without RLS context
|
||||
-- 2. Database maintenance operations
|
||||
-- When RLS context IS set (current_user_id() returns non-NULL), this policy does not apply
|
||||
--
|
||||
-- NOTE: If connecting as a PostgreSQL superuser (like the default 'mosaic' role),
|
||||
-- RLS policies are bypassed entirely. For full RLS enforcement, the application
|
||||
-- should connect as a non-superuser role. See docs/design/credential-security.md
|
||||
CREATE POLICY user_credentials_owner_bypass ON user_credentials
|
||||
FOR ALL
|
||||
USING (current_user_id() IS NULL);
|
||||
|
||||
-- User access policy: USER-scoped credentials visible only to owner
|
||||
-- Uses current_user_id() helper from migration 20260129221004_add_rls_policies
|
||||
CREATE POLICY user_credentials_user_access ON user_credentials
|
||||
FOR ALL
|
||||
USING (
|
||||
scope = 'USER' AND user_id = current_user_id()
|
||||
);
|
||||
|
||||
-- Workspace admin access policy: WORKSPACE-scoped credentials visible to workspace admins
|
||||
-- Uses is_workspace_admin() helper from migration 20260129221004_add_rls_policies
|
||||
CREATE POLICY user_credentials_workspace_access ON user_credentials
|
||||
FOR ALL
|
||||
USING (
|
||||
scope = 'WORKSPACE'
|
||||
AND workspace_id IS NOT NULL
|
||||
AND is_workspace_admin(workspace_id, current_user_id())
|
||||
);
|
||||
|
||||
-- SYSTEM-scoped credentials are only accessible via owner bypass policy
|
||||
-- (when current_user_id() IS NULL, which happens for admin operations)
|
||||
|
||||
-- =============================================================================
|
||||
-- AUDIT TRIGGER
|
||||
-- =============================================================================
|
||||
|
||||
-- Update updated_at timestamp on row changes
|
||||
CREATE OR REPLACE FUNCTION update_user_credentials_updated_at()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER user_credentials_updated_at
|
||||
BEFORE UPDATE ON user_credentials
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_user_credentials_updated_at();
|
||||
|
||||
-- =============================================================================
|
||||
-- NOTES
|
||||
-- =============================================================================
|
||||
-- This migration creates the foundation for secure credential storage.
|
||||
-- The encrypted_value column stores ciphertext in one of two formats:
|
||||
--
|
||||
-- 1. OpenBao Transit format (preferred): vault:v1:base64data
|
||||
-- 2. AES-256-GCM fallback format: iv:authTag:encrypted
|
||||
--
|
||||
-- The VaultService (issue #353) handles encryption/decryption with automatic
|
||||
-- fallback to CryptoService when OpenBao is unavailable.
|
||||
--
|
||||
-- RLS enforcement ensures:
|
||||
-- - USER scope: Only the credential owner can access
|
||||
-- - WORKSPACE scope: Only workspace admins can access
|
||||
-- - SYSTEM scope: Only accessible via admin/migration bypass
|
||||
@@ -0,0 +1,37 @@
|
||||
-- Encrypt existing plaintext Account tokens
|
||||
-- This migration adds an encryption_version column and marks existing records for encryption
|
||||
-- The actual encryption happens via Prisma middleware on first read/write
|
||||
|
||||
-- Add encryption_version column to track encryption state
|
||||
-- NULL = not encrypted (legacy plaintext)
|
||||
-- 'aes' = AES-256-GCM encrypted
|
||||
-- 'vault' = OpenBao Transit encrypted (Phase 2)
|
||||
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS encryption_version VARCHAR(20);
|
||||
|
||||
-- Create index for efficient queries filtering by encryption status
|
||||
-- This index is also declared in Prisma schema (@@index([encryptionVersion]))
|
||||
-- Using CREATE INDEX IF NOT EXISTS for idempotency
|
||||
CREATE INDEX IF NOT EXISTS "accounts_encryption_version_idx" ON accounts(encryption_version);
|
||||
|
||||
-- Verify index was created successfully by running:
|
||||
-- SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'accounts' AND indexname = 'accounts_encryption_version_idx';
|
||||
|
||||
-- Update statistics for query planner
|
||||
ANALYZE accounts;
|
||||
|
||||
-- Migration Note:
|
||||
-- This migration does NOT encrypt data in-place to avoid downtime and data corruption risks.
|
||||
-- Instead, the Prisma middleware (account-encryption.middleware.ts) handles encryption:
|
||||
--
|
||||
-- 1. On READ: Detects format (plaintext vs encrypted) and decrypts if needed
|
||||
-- 2. On WRITE: Encrypts tokens and sets encryption_version = 'aes'
|
||||
-- 3. Backward compatible: Plaintext tokens (encryption_version = NULL) are passed through unchanged
|
||||
--
|
||||
-- To actively encrypt existing tokens, run the companion script:
|
||||
-- node scripts/encrypt-account-tokens.js
|
||||
--
|
||||
-- This approach ensures:
|
||||
-- - Zero downtime migration
|
||||
-- - No risk of corrupting tokens during bulk encryption
|
||||
-- - Progressive encryption as tokens are accessed/refreshed
|
||||
-- - Easy rollback (middleware is idempotent)
|
||||
@@ -0,0 +1,26 @@
|
||||
-- Encrypt LLM Provider API Keys Migration
|
||||
--
|
||||
-- This migration enables transparent encryption/decryption of LLM provider API keys
|
||||
-- stored in the llm_provider_instances.config JSON field.
|
||||
--
|
||||
-- IMPORTANT: This is a data migration with no schema changes.
|
||||
--
|
||||
-- Strategy:
|
||||
-- 1. Prisma middleware (llm-encryption.middleware.ts) handles encryption/decryption
|
||||
-- 2. Middleware auto-detects encryption format:
|
||||
-- - vault:v1:... = OpenBao Transit encrypted
|
||||
-- - Otherwise = Legacy plaintext (backward compatible)
|
||||
-- 3. New API keys are always encrypted on write
|
||||
-- 4. Existing plaintext keys work until re-saved (lazy migration)
|
||||
--
|
||||
-- To actively encrypt all existing API keys NOW:
|
||||
-- pnpm --filter @mosaic/api migrate:encrypt-llm-keys
|
||||
--
|
||||
-- This approach ensures:
|
||||
-- - Zero downtime migration
|
||||
-- - No schema changes required
|
||||
-- - Backward compatible with plaintext keys
|
||||
-- - Progressive encryption as keys are accessed/updated
|
||||
-- - Easy rollback (middleware is idempotent)
|
||||
--
|
||||
-- Note: No SQL changes needed. This file exists for migration tracking only.
|
||||
@@ -0,0 +1,197 @@
|
||||
-- RecreateEnum: FormalityLevel was dropped in 20260129235248_add_link_storage_fields
|
||||
CREATE TYPE "FormalityLevel" AS ENUM ('VERY_CASUAL', 'CASUAL', 'NEUTRAL', 'FORMAL', 'VERY_FORMAL');
|
||||
|
||||
-- RecreateTable: personalities was dropped in 20260129235248_add_link_storage_fields
|
||||
-- Recreated with current schema (display_name, system_prompt, temperature, etc.)
|
||||
CREATE TABLE "personalities" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"display_name" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"system_prompt" TEXT NOT NULL,
|
||||
"temperature" DOUBLE PRECISION,
|
||||
"max_tokens" INTEGER,
|
||||
"llm_provider_instance_id" UUID,
|
||||
"is_default" BOOLEAN NOT NULL DEFAULT false,
|
||||
"is_enabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "personalities_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex: personalities
|
||||
CREATE UNIQUE INDEX "personalities_id_workspace_id_key" ON "personalities"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "personalities_workspace_id_name_key" ON "personalities"("workspace_id", "name");
|
||||
CREATE INDEX "personalities_workspace_id_idx" ON "personalities"("workspace_id");
|
||||
CREATE INDEX "personalities_workspace_id_is_default_idx" ON "personalities"("workspace_id", "is_default");
|
||||
CREATE INDEX "personalities_workspace_id_is_enabled_idx" ON "personalities"("workspace_id", "is_enabled");
|
||||
CREATE INDEX "personalities_llm_provider_instance_id_idx" ON "personalities"("llm_provider_instance_id");
|
||||
|
||||
-- AddForeignKey: personalities
|
||||
ALTER TABLE "personalities" ADD CONSTRAINT "personalities_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE "personalities" ADD CONSTRAINT "personalities_llm_provider_instance_id_fkey" FOREIGN KEY ("llm_provider_instance_id") REFERENCES "llm_provider_instances"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "cron_schedules" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"expression" TEXT NOT NULL,
|
||||
"command" TEXT NOT NULL,
|
||||
"enabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"last_run" TIMESTAMPTZ,
|
||||
"next_run" TIMESTAMPTZ,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "cron_schedules_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "workspace_llm_settings" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"default_llm_provider_id" UUID,
|
||||
"default_personality_id" UUID,
|
||||
"settings" JSONB DEFAULT '{}',
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "workspace_llm_settings_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "quality_gates" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"description" TEXT,
|
||||
"type" TEXT NOT NULL,
|
||||
"command" TEXT,
|
||||
"expected_output" TEXT,
|
||||
"is_regex" BOOLEAN NOT NULL DEFAULT false,
|
||||
"required" BOOLEAN NOT NULL DEFAULT true,
|
||||
"order" INTEGER NOT NULL DEFAULT 0,
|
||||
"is_enabled" BOOLEAN NOT NULL DEFAULT true,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "quality_gates_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "task_rejections" (
|
||||
"id" UUID NOT NULL,
|
||||
"task_id" TEXT NOT NULL,
|
||||
"workspace_id" TEXT NOT NULL,
|
||||
"agent_id" TEXT NOT NULL,
|
||||
"attempt_count" INTEGER NOT NULL,
|
||||
"failures" JSONB NOT NULL,
|
||||
"original_task" TEXT NOT NULL,
|
||||
"started_at" TIMESTAMPTZ NOT NULL,
|
||||
"rejected_at" TIMESTAMPTZ NOT NULL,
|
||||
"escalated" BOOLEAN NOT NULL DEFAULT false,
|
||||
"manual_review" BOOLEAN NOT NULL DEFAULT false,
|
||||
"resolved_at" TIMESTAMPTZ,
|
||||
"resolution" TEXT,
|
||||
|
||||
CONSTRAINT "task_rejections_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "token_budgets" (
|
||||
"id" UUID NOT NULL,
|
||||
"task_id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"agent_id" TEXT NOT NULL,
|
||||
"allocated_tokens" INTEGER NOT NULL,
|
||||
"estimated_complexity" TEXT NOT NULL,
|
||||
"input_tokens_used" INTEGER NOT NULL DEFAULT 0,
|
||||
"output_tokens_used" INTEGER NOT NULL DEFAULT 0,
|
||||
"total_tokens_used" INTEGER NOT NULL DEFAULT 0,
|
||||
"estimated_cost" DECIMAL(10,6),
|
||||
"started_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"last_updated_at" TIMESTAMPTZ NOT NULL,
|
||||
"completed_at" TIMESTAMPTZ,
|
||||
"budget_utilization" DOUBLE PRECISION,
|
||||
"suspicious_pattern" BOOLEAN NOT NULL DEFAULT false,
|
||||
"suspicious_reason" TEXT,
|
||||
|
||||
CONSTRAINT "token_budgets_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "llm_usage_logs" (
|
||||
"id" UUID NOT NULL,
|
||||
"workspace_id" UUID NOT NULL,
|
||||
"user_id" UUID NOT NULL,
|
||||
"provider" VARCHAR(50) NOT NULL,
|
||||
"model" VARCHAR(100) NOT NULL,
|
||||
"provider_instance_id" UUID,
|
||||
"prompt_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"completion_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"total_tokens" INTEGER NOT NULL DEFAULT 0,
|
||||
"cost_cents" DOUBLE PRECISION,
|
||||
"task_type" VARCHAR(50),
|
||||
"conversation_id" UUID,
|
||||
"duration_ms" INTEGER,
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "llm_usage_logs_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex: cron_schedules
|
||||
CREATE INDEX "cron_schedules_workspace_id_idx" ON "cron_schedules"("workspace_id");
|
||||
CREATE INDEX "cron_schedules_workspace_id_enabled_idx" ON "cron_schedules"("workspace_id", "enabled");
|
||||
CREATE INDEX "cron_schedules_next_run_idx" ON "cron_schedules"("next_run");
|
||||
|
||||
-- CreateIndex: workspace_llm_settings
|
||||
CREATE UNIQUE INDEX "workspace_llm_settings_workspace_id_key" ON "workspace_llm_settings"("workspace_id");
|
||||
CREATE INDEX "workspace_llm_settings_workspace_id_idx" ON "workspace_llm_settings"("workspace_id");
|
||||
CREATE INDEX "workspace_llm_settings_default_llm_provider_id_idx" ON "workspace_llm_settings"("default_llm_provider_id");
|
||||
CREATE INDEX "workspace_llm_settings_default_personality_id_idx" ON "workspace_llm_settings"("default_personality_id");
|
||||
|
||||
-- CreateIndex: quality_gates
|
||||
CREATE UNIQUE INDEX "quality_gates_workspace_id_name_key" ON "quality_gates"("workspace_id", "name");
|
||||
CREATE INDEX "quality_gates_workspace_id_idx" ON "quality_gates"("workspace_id");
|
||||
CREATE INDEX "quality_gates_workspace_id_is_enabled_idx" ON "quality_gates"("workspace_id", "is_enabled");
|
||||
|
||||
-- CreateIndex: task_rejections
|
||||
CREATE INDEX "task_rejections_task_id_idx" ON "task_rejections"("task_id");
|
||||
CREATE INDEX "task_rejections_workspace_id_idx" ON "task_rejections"("workspace_id");
|
||||
CREATE INDEX "task_rejections_agent_id_idx" ON "task_rejections"("agent_id");
|
||||
CREATE INDEX "task_rejections_escalated_idx" ON "task_rejections"("escalated");
|
||||
CREATE INDEX "task_rejections_manual_review_idx" ON "task_rejections"("manual_review");
|
||||
|
||||
-- CreateIndex: token_budgets
|
||||
CREATE UNIQUE INDEX "token_budgets_task_id_key" ON "token_budgets"("task_id");
|
||||
CREATE INDEX "token_budgets_task_id_idx" ON "token_budgets"("task_id");
|
||||
CREATE INDEX "token_budgets_workspace_id_idx" ON "token_budgets"("workspace_id");
|
||||
CREATE INDEX "token_budgets_suspicious_pattern_idx" ON "token_budgets"("suspicious_pattern");
|
||||
|
||||
-- CreateIndex: llm_usage_logs
|
||||
CREATE INDEX "llm_usage_logs_workspace_id_idx" ON "llm_usage_logs"("workspace_id");
|
||||
CREATE INDEX "llm_usage_logs_workspace_id_created_at_idx" ON "llm_usage_logs"("workspace_id", "created_at");
|
||||
CREATE INDEX "llm_usage_logs_user_id_idx" ON "llm_usage_logs"("user_id");
|
||||
CREATE INDEX "llm_usage_logs_provider_idx" ON "llm_usage_logs"("provider");
|
||||
CREATE INDEX "llm_usage_logs_model_idx" ON "llm_usage_logs"("model");
|
||||
CREATE INDEX "llm_usage_logs_provider_instance_id_idx" ON "llm_usage_logs"("provider_instance_id");
|
||||
CREATE INDEX "llm_usage_logs_task_type_idx" ON "llm_usage_logs"("task_type");
|
||||
CREATE INDEX "llm_usage_logs_conversation_id_idx" ON "llm_usage_logs"("conversation_id");
|
||||
|
||||
-- AddForeignKey: cron_schedules
|
||||
ALTER TABLE "cron_schedules" ADD CONSTRAINT "cron_schedules_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey: workspace_llm_settings
|
||||
ALTER TABLE "workspace_llm_settings" ADD CONSTRAINT "workspace_llm_settings_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE "workspace_llm_settings" ADD CONSTRAINT "workspace_llm_settings_default_llm_provider_id_fkey" FOREIGN KEY ("default_llm_provider_id") REFERENCES "llm_provider_instances"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
ALTER TABLE "workspace_llm_settings" ADD CONSTRAINT "workspace_llm_settings_default_personality_id_fkey" FOREIGN KEY ("default_personality_id") REFERENCES "personalities"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey: quality_gates
|
||||
ALTER TABLE "quality_gates" ADD CONSTRAINT "quality_gates_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
|
||||
-- AddForeignKey: llm_usage_logs
|
||||
ALTER TABLE "llm_usage_logs" ADD CONSTRAINT "llm_usage_logs_workspace_id_fkey" FOREIGN KEY ("workspace_id") REFERENCES "workspaces"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE "llm_usage_logs" ADD CONSTRAINT "llm_usage_logs_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "users"("id") ON DELETE CASCADE ON UPDATE CASCADE;
|
||||
ALTER TABLE "llm_usage_logs" ADD CONSTRAINT "llm_usage_logs_provider_instance_id_fkey" FOREIGN KEY ("provider_instance_id") REFERENCES "llm_provider_instances"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "workspaces" ADD COLUMN "matrix_room_id" TEXT;
|
||||
@@ -0,0 +1,49 @@
|
||||
-- Fix schema drift: tables, indexes, and constraints defined in schema.prisma
|
||||
-- but never created (or dropped and never recreated) by prior migrations.
|
||||
|
||||
-- ============================================
|
||||
-- CreateTable: instances (Federation module)
|
||||
-- Never created in any prior migration
|
||||
-- ============================================
|
||||
CREATE TABLE "instances" (
|
||||
"id" UUID NOT NULL,
|
||||
"instance_id" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"url" TEXT NOT NULL,
|
||||
"public_key" TEXT NOT NULL,
|
||||
"private_key" TEXT NOT NULL,
|
||||
"capabilities" JSONB NOT NULL DEFAULT '{}',
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMPTZ NOT NULL,
|
||||
|
||||
CONSTRAINT "instances_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX "instances_instance_id_key" ON "instances"("instance_id");
|
||||
|
||||
-- ============================================
|
||||
-- Recreate dropped unique index on knowledge_links
|
||||
-- Created in 20260129220645_add_knowledge_module, dropped in
|
||||
-- 20260129235248_add_link_storage_fields, never recreated.
|
||||
-- ============================================
|
||||
CREATE UNIQUE INDEX "knowledge_links_source_id_target_id_key" ON "knowledge_links"("source_id", "target_id");
|
||||
|
||||
-- ============================================
|
||||
-- Missing @@unique([id, workspaceId]) composite indexes
|
||||
-- Defined in schema.prisma but never created in migrations.
|
||||
-- (agent_tasks and runner_jobs already have these.)
|
||||
-- ============================================
|
||||
CREATE UNIQUE INDEX "tasks_id_workspace_id_key" ON "tasks"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "events_id_workspace_id_key" ON "events"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "projects_id_workspace_id_key" ON "projects"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "activity_logs_id_workspace_id_key" ON "activity_logs"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "domains_id_workspace_id_key" ON "domains"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "ideas_id_workspace_id_key" ON "ideas"("id", "workspace_id");
|
||||
CREATE UNIQUE INDEX "user_layouts_id_workspace_id_key" ON "user_layouts"("id", "workspace_id");
|
||||
|
||||
-- ============================================
|
||||
-- Missing index on agent_tasks.agent_type
|
||||
-- Defined as @@index([agentType]) in schema.prisma
|
||||
-- ============================================
|
||||
CREATE INDEX "agent_tasks_agent_type_idx" ON "agent_tasks"("agent_type");
|
||||
@@ -62,6 +62,10 @@ enum ActivityAction {
|
||||
LOGOUT
|
||||
PASSWORD_RESET
|
||||
EMAIL_VERIFIED
|
||||
CREDENTIAL_CREATED
|
||||
CREDENTIAL_ACCESSED
|
||||
CREDENTIAL_ROTATED
|
||||
CREDENTIAL_REVOKED
|
||||
}
|
||||
|
||||
enum EntityType {
|
||||
@@ -72,6 +76,7 @@ enum EntityType {
|
||||
USER
|
||||
IDEA
|
||||
DOMAIN
|
||||
CREDENTIAL
|
||||
}
|
||||
|
||||
enum IdeaStatus {
|
||||
@@ -186,6 +191,21 @@ enum FederationMessageStatus {
|
||||
TIMEOUT
|
||||
}
|
||||
|
||||
enum CredentialType {
|
||||
API_KEY
|
||||
OAUTH_TOKEN
|
||||
ACCESS_TOKEN
|
||||
SECRET
|
||||
PASSWORD
|
||||
CUSTOM
|
||||
}
|
||||
|
||||
enum CredentialScope {
|
||||
USER
|
||||
WORKSPACE
|
||||
SYSTEM
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// MODELS
|
||||
// ============================================
|
||||
@@ -221,6 +241,8 @@ model User {
|
||||
knowledgeEntryVersions KnowledgeEntryVersion[] @relation("EntryVersionAuthor")
|
||||
llmProviders LlmProviderInstance[] @relation("UserLlmProviders")
|
||||
federatedIdentities FederatedIdentity[]
|
||||
llmUsageLogs LlmUsageLog[] @relation("UserLlmUsageLogs")
|
||||
userCredentials UserCredential[] @relation("UserCredentials")
|
||||
|
||||
@@map("users")
|
||||
}
|
||||
@@ -239,39 +261,42 @@ model UserPreference {
|
||||
}
|
||||
|
||||
model Workspace {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
name String
|
||||
ownerId String @map("owner_id") @db.Uuid
|
||||
settings Json @default("{}")
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
name String
|
||||
ownerId String @map("owner_id") @db.Uuid
|
||||
settings Json @default("{}")
|
||||
matrixRoomId String? @map("matrix_room_id")
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz
|
||||
|
||||
// Relations
|
||||
owner User @relation("WorkspaceOwner", fields: [ownerId], references: [id], onDelete: Cascade)
|
||||
members WorkspaceMember[]
|
||||
teams Team[]
|
||||
tasks Task[]
|
||||
events Event[]
|
||||
projects Project[]
|
||||
activityLogs ActivityLog[]
|
||||
memoryEmbeddings MemoryEmbedding[]
|
||||
domains Domain[]
|
||||
ideas Idea[]
|
||||
relationships Relationship[]
|
||||
agents Agent[]
|
||||
agentSessions AgentSession[]
|
||||
agentTasks AgentTask[]
|
||||
userLayouts UserLayout[]
|
||||
knowledgeEntries KnowledgeEntry[]
|
||||
knowledgeTags KnowledgeTag[]
|
||||
cronSchedules CronSchedule[]
|
||||
personalities Personality[]
|
||||
llmSettings WorkspaceLlmSettings?
|
||||
qualityGates QualityGate[]
|
||||
runnerJobs RunnerJob[]
|
||||
federationConnections FederationConnection[]
|
||||
federationMessages FederationMessage[]
|
||||
federationEventSubscriptions FederationEventSubscription[]
|
||||
owner User @relation("WorkspaceOwner", fields: [ownerId], references: [id], onDelete: Cascade)
|
||||
members WorkspaceMember[]
|
||||
teams Team[]
|
||||
tasks Task[]
|
||||
events Event[]
|
||||
projects Project[]
|
||||
activityLogs ActivityLog[]
|
||||
memoryEmbeddings MemoryEmbedding[]
|
||||
domains Domain[]
|
||||
ideas Idea[]
|
||||
relationships Relationship[]
|
||||
agents Agent[]
|
||||
agentSessions AgentSession[]
|
||||
agentTasks AgentTask[]
|
||||
userLayouts UserLayout[]
|
||||
knowledgeEntries KnowledgeEntry[]
|
||||
knowledgeTags KnowledgeTag[]
|
||||
cronSchedules CronSchedule[]
|
||||
personalities Personality[]
|
||||
llmSettings WorkspaceLlmSettings?
|
||||
qualityGates QualityGate[]
|
||||
runnerJobs RunnerJob[]
|
||||
federationConnections FederationConnection[]
|
||||
federationMessages FederationMessage[]
|
||||
federationEventSubscriptions FederationEventSubscription[]
|
||||
llmUsageLogs LlmUsageLog[]
|
||||
userCredentials UserCredential[]
|
||||
|
||||
@@index([ownerId])
|
||||
@@map("workspaces")
|
||||
@@ -781,6 +806,7 @@ model Account {
|
||||
refreshTokenExpiresAt DateTime? @map("refresh_token_expires_at") @db.Timestamptz
|
||||
scope String?
|
||||
password String?
|
||||
encryptionVersion String? @map("encryption_version") @db.VarChar(20)
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz
|
||||
|
||||
@@ -789,6 +815,7 @@ model Account {
|
||||
|
||||
@@unique([providerId, accountId])
|
||||
@@index([userId])
|
||||
@@index([encryptionVersion])
|
||||
@@map("accounts")
|
||||
}
|
||||
|
||||
@@ -804,6 +831,52 @@ model Verification {
|
||||
@@map("verifications")
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// USER CREDENTIALS MODULE
|
||||
// ============================================
|
||||
|
||||
model UserCredential {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
userId String @map("user_id") @db.Uuid
|
||||
workspaceId String? @map("workspace_id") @db.Uuid
|
||||
|
||||
// Identity
|
||||
name String
|
||||
provider String // "github", "openai", "custom"
|
||||
type CredentialType
|
||||
scope CredentialScope @default(USER)
|
||||
|
||||
// Encrypted storage
|
||||
encryptedValue String @map("encrypted_value") @db.Text
|
||||
maskedValue String? @map("masked_value") @db.VarChar(20)
|
||||
|
||||
// Metadata
|
||||
description String? @db.Text
|
||||
expiresAt DateTime? @map("expires_at") @db.Timestamptz
|
||||
lastUsedAt DateTime? @map("last_used_at") @db.Timestamptz
|
||||
metadata Json @default("{}")
|
||||
|
||||
// Status
|
||||
isActive Boolean @default(true) @map("is_active")
|
||||
rotatedAt DateTime? @map("rotated_at") @db.Timestamptz
|
||||
|
||||
// Audit
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz
|
||||
updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz
|
||||
|
||||
// Relations
|
||||
user User @relation("UserCredentials", fields: [userId], references: [id], onDelete: Cascade)
|
||||
workspace Workspace? @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([userId, workspaceId, provider, name])
|
||||
@@index([userId])
|
||||
@@index([workspaceId])
|
||||
@@index([userId, scope])
|
||||
@@index([workspaceId, scope])
|
||||
@@index([scope, isActive])
|
||||
@@map("user_credentials")
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// KNOWLEDGE MODULE
|
||||
// ============================================
|
||||
@@ -1036,6 +1109,7 @@ model LlmProviderInstance {
|
||||
user User? @relation("UserLlmProviders", fields: [userId], references: [id], onDelete: Cascade)
|
||||
personalities Personality[] @relation("PersonalityLlmProvider")
|
||||
workspaceLlmSettings WorkspaceLlmSettings[] @relation("WorkspaceLlmProvider")
|
||||
llmUsageLogs LlmUsageLog[] @relation("LlmUsageLogs")
|
||||
|
||||
@@index([userId])
|
||||
@@index([providerType])
|
||||
@@ -1288,8 +1362,8 @@ model FederationConnection {
|
||||
disconnectedAt DateTime? @map("disconnected_at") @db.Timestamptz
|
||||
|
||||
// Relations
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
messages FederationMessage[]
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
messages FederationMessage[]
|
||||
eventSubscriptions FederationEventSubscription[]
|
||||
|
||||
@@unique([workspaceId, remoteInstanceId])
|
||||
@@ -1383,3 +1457,53 @@ model FederationEventSubscription {
|
||||
@@index([workspaceId, isActive])
|
||||
@@map("federation_event_subscriptions")
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// LLM USAGE TRACKING MODULE
|
||||
// ============================================
|
||||
|
||||
model LlmUsageLog {
|
||||
id String @id @default(uuid()) @db.Uuid
|
||||
workspaceId String @map("workspace_id") @db.Uuid
|
||||
userId String @map("user_id") @db.Uuid
|
||||
|
||||
// LLM provider and model info
|
||||
provider String @db.VarChar(50)
|
||||
model String @db.VarChar(100)
|
||||
providerInstanceId String? @map("provider_instance_id") @db.Uuid
|
||||
|
||||
// Token usage
|
||||
promptTokens Int @default(0) @map("prompt_tokens")
|
||||
completionTokens Int @default(0) @map("completion_tokens")
|
||||
totalTokens Int @default(0) @map("total_tokens")
|
||||
|
||||
// Optional cost (in cents for precision)
|
||||
costCents Float? @map("cost_cents")
|
||||
|
||||
// Task type for routing analytics
|
||||
taskType String? @map("task_type") @db.VarChar(50)
|
||||
|
||||
// Optional reference to conversation/session
|
||||
conversationId String? @map("conversation_id") @db.Uuid
|
||||
|
||||
// Duration in milliseconds
|
||||
durationMs Int? @map("duration_ms")
|
||||
|
||||
// Timestamp
|
||||
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz
|
||||
|
||||
// Relations
|
||||
workspace Workspace @relation(fields: [workspaceId], references: [id], onDelete: Cascade)
|
||||
user User @relation("UserLlmUsageLogs", fields: [userId], references: [id], onDelete: Cascade)
|
||||
llmProviderInstance LlmProviderInstance? @relation("LlmUsageLogs", fields: [providerInstanceId], references: [id], onDelete: SetNull)
|
||||
|
||||
@@index([workspaceId])
|
||||
@@index([workspaceId, createdAt])
|
||||
@@index([userId])
|
||||
@@index([provider])
|
||||
@@index([model])
|
||||
@@index([providerInstanceId])
|
||||
@@index([taskType])
|
||||
@@index([conversationId])
|
||||
@@map("llm_usage_logs")
|
||||
}
|
||||
|
||||
166
apps/api/scripts/encrypt-llm-keys.ts
Normal file
166
apps/api/scripts/encrypt-llm-keys.ts
Normal file
@@ -0,0 +1,166 @@
|
||||
/**
|
||||
* Data Migration: Encrypt LLM Provider API Keys
|
||||
*
|
||||
* Encrypts all plaintext API keys in llm_provider_instances.config using OpenBao Transit.
|
||||
* This script processes records in batches and runs in a transaction for safety.
|
||||
*
|
||||
* Usage:
|
||||
* pnpm --filter @mosaic/api migrate:encrypt-llm-keys
|
||||
*
|
||||
* Environment Variables:
|
||||
* DATABASE_URL - PostgreSQL connection string
|
||||
* OPENBAO_ADDR - OpenBao server address (default: http://openbao:8200)
|
||||
* APPROLE_CREDENTIALS_PATH - Path to AppRole credentials file
|
||||
*/
|
||||
|
||||
import { PrismaClient } from "@prisma/client";
|
||||
import { VaultService } from "../src/vault/vault.service";
|
||||
import { TransitKey } from "../src/vault/vault.constants";
|
||||
import { Logger } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
|
||||
interface LlmProviderConfig {
|
||||
apiKey?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
interface LlmProviderInstance {
|
||||
id: string;
|
||||
config: LlmProviderConfig;
|
||||
providerType: string;
|
||||
displayName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a value is already encrypted
|
||||
*/
|
||||
function isEncrypted(value: string): boolean {
|
||||
if (!value || typeof value !== "string") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Vault format: vault:v1:...
|
||||
if (value.startsWith("vault:v1:")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// AES format: iv:authTag:encrypted (3 colon-separated hex parts)
|
||||
const parts = value.split(":");
|
||||
if (parts.length === 3 && parts.every((part) => /^[0-9a-f]+$/i.test(part))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main migration function
|
||||
*/
|
||||
async function main(): Promise<void> {
|
||||
const logger = new Logger("EncryptLlmKeys");
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
try {
|
||||
logger.log("Starting LLM API key encryption migration...");
|
||||
|
||||
// Initialize VaultService
|
||||
const configService = new ConfigService();
|
||||
const vaultService = new VaultService(configService);
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-call
|
||||
await vaultService.onModuleInit();
|
||||
|
||||
logger.log("VaultService initialized successfully");
|
||||
|
||||
// Fetch all LLM provider instances
|
||||
const instances = await prisma.llmProviderInstance.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
config: true,
|
||||
providerType: true,
|
||||
displayName: true,
|
||||
},
|
||||
});
|
||||
|
||||
logger.log(`Found ${String(instances.length)} LLM provider instances`);
|
||||
|
||||
let encryptedCount = 0;
|
||||
let skippedCount = 0;
|
||||
let errorCount = 0;
|
||||
|
||||
// Process each instance
|
||||
for (const instance of instances as LlmProviderInstance[]) {
|
||||
try {
|
||||
const config = instance.config;
|
||||
|
||||
// Skip if no apiKey field
|
||||
if (!config.apiKey || typeof config.apiKey !== "string") {
|
||||
logger.debug(`Skipping ${instance.displayName} (${instance.id}): No API key`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip if already encrypted
|
||||
if (isEncrypted(config.apiKey)) {
|
||||
logger.debug(`Skipping ${instance.displayName} (${instance.id}): Already encrypted`);
|
||||
skippedCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Encrypt the API key
|
||||
logger.log(`Encrypting ${instance.displayName} (${instance.providerType})...`);
|
||||
|
||||
const encryptedApiKey = await vaultService.encrypt(config.apiKey, TransitKey.LLM_CONFIG);
|
||||
|
||||
// Update the instance with encrypted key
|
||||
await prisma.llmProviderInstance.update({
|
||||
where: { id: instance.id },
|
||||
data: {
|
||||
config: {
|
||||
...config,
|
||||
apiKey: encryptedApiKey,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
encryptedCount++;
|
||||
logger.log(`✓ Encrypted ${instance.displayName} (${instance.id})`);
|
||||
} catch (error: unknown) {
|
||||
errorCount++;
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`✗ Failed to encrypt ${instance.displayName} (${instance.id}): ${errorMsg}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Summary
|
||||
logger.log("\n=== Migration Summary ===");
|
||||
logger.log(`Total instances: ${String(instances.length)}`);
|
||||
logger.log(`Encrypted: ${String(encryptedCount)}`);
|
||||
logger.log(`Skipped: ${String(skippedCount)}`);
|
||||
logger.log(`Errors: ${String(errorCount)}`);
|
||||
|
||||
if (errorCount > 0) {
|
||||
logger.warn("\n⚠️ Some API keys failed to encrypt. Please review the errors above.");
|
||||
process.exit(1);
|
||||
} else if (encryptedCount === 0) {
|
||||
logger.log("\n✓ All API keys are already encrypted or no keys found.");
|
||||
} else {
|
||||
logger.log("\n✓ Migration completed successfully!");
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Migration failed: ${errorMsg}`);
|
||||
throw error;
|
||||
} finally {
|
||||
await prisma.$disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
// Run migration
|
||||
main()
|
||||
.then(() => {
|
||||
process.exit(0);
|
||||
})
|
||||
.catch((error: unknown) => {
|
||||
console.error(error);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -802,7 +802,7 @@ describe("ActivityService", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle database errors gracefully when logging activity", async () => {
|
||||
it("should handle database errors gracefully when logging activity (fire-and-forget)", async () => {
|
||||
const input: CreateActivityLogInput = {
|
||||
workspaceId: "workspace-123",
|
||||
userId: "user-123",
|
||||
@@ -814,7 +814,9 @@ describe("ActivityService", () => {
|
||||
const dbError = new Error("Database connection failed");
|
||||
mockPrismaService.activityLog.create.mockRejectedValue(dbError);
|
||||
|
||||
await expect(service.logActivity(input)).rejects.toThrow("Database connection failed");
|
||||
// Activity logging is fire-and-forget - returns null on error instead of throwing
|
||||
const result = await service.logActivity(input);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should handle extremely large details objects", async () => {
|
||||
@@ -1132,7 +1134,7 @@ describe("ActivityService", () => {
|
||||
});
|
||||
|
||||
describe("database error handling", () => {
|
||||
it("should handle database connection failures in logActivity", async () => {
|
||||
it("should handle database connection failures in logActivity (fire-and-forget)", async () => {
|
||||
const createInput: CreateActivityLogInput = {
|
||||
workspaceId: "workspace-123",
|
||||
userId: "user-123",
|
||||
@@ -1144,7 +1146,9 @@ describe("ActivityService", () => {
|
||||
const dbError = new Error("Connection refused");
|
||||
mockPrismaService.activityLog.create.mockRejectedValue(dbError);
|
||||
|
||||
await expect(service.logActivity(createInput)).rejects.toThrow("Connection refused");
|
||||
// Activity logging is fire-and-forget - returns null on error instead of throwing
|
||||
const result = await service.logActivity(createInput);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should handle Prisma timeout errors in findAll", async () => {
|
||||
|
||||
@@ -18,16 +18,25 @@ export class ActivityService {
|
||||
constructor(private readonly prisma: PrismaService) {}
|
||||
|
||||
/**
|
||||
* Create a new activity log entry
|
||||
* Create a new activity log entry (fire-and-forget)
|
||||
*
|
||||
* Activity logging failures are logged but never propagate to callers.
|
||||
* This ensures activity logging never breaks primary operations.
|
||||
*
|
||||
* @returns The created ActivityLog or null if logging failed
|
||||
*/
|
||||
async logActivity(input: CreateActivityLogInput): Promise<ActivityLog> {
|
||||
async logActivity(input: CreateActivityLogInput): Promise<ActivityLog | null> {
|
||||
try {
|
||||
return await this.prisma.activityLog.create({
|
||||
data: input as unknown as Prisma.ActivityLogCreateInput,
|
||||
});
|
||||
} catch (error) {
|
||||
this.logger.error("Failed to log activity", error);
|
||||
throw error;
|
||||
// Log the error but don't propagate - activity logging is fire-and-forget
|
||||
this.logger.error(
|
||||
`Failed to log activity: action=${input.action} entityType=${input.entityType} entityId=${input.entityId}`,
|
||||
error instanceof Error ? error.stack : String(error)
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,7 +176,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
taskId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -186,7 +195,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
taskId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -205,7 +214,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
taskId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -224,7 +233,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
taskId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -243,7 +252,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
taskId: string,
|
||||
assigneeId: string
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -262,7 +271,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
eventId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -281,7 +290,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
eventId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -300,7 +309,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
eventId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -319,7 +328,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
projectId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -338,7 +347,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
projectId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -357,7 +366,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
projectId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -375,7 +384,7 @@ export class ActivityService {
|
||||
workspaceId: string,
|
||||
userId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -393,7 +402,7 @@ export class ActivityService {
|
||||
workspaceId: string,
|
||||
userId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -412,7 +421,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
memberId: string,
|
||||
role: string
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -430,7 +439,7 @@ export class ActivityService {
|
||||
workspaceId: string,
|
||||
userId: string,
|
||||
memberId: string
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -448,7 +457,7 @@ export class ActivityService {
|
||||
workspaceId: string,
|
||||
userId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -467,7 +476,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
domainId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -486,7 +495,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
domainId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -505,7 +514,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
domainId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -524,7 +533,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
ideaId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -543,7 +552,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
ideaId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
@@ -562,7 +571,7 @@ export class ActivityService {
|
||||
userId: string,
|
||||
ideaId: string,
|
||||
details?: Prisma.JsonValue
|
||||
): Promise<ActivityLog> {
|
||||
): Promise<ActivityLog | null> {
|
||||
return this.logActivity({
|
||||
workspaceId,
|
||||
userId,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { Controller, Get } from "@nestjs/common";
|
||||
import { SkipThrottle } from "@nestjs/throttler";
|
||||
import { AppService } from "./app.service";
|
||||
import { PrismaService } from "./prisma/prisma.service";
|
||||
import type { ApiResponse, HealthStatus } from "@mosaic/shared";
|
||||
@@ -17,6 +18,7 @@ export class AppController {
|
||||
}
|
||||
|
||||
@Get("health")
|
||||
@SkipThrottle()
|
||||
async getHealth(): Promise<ApiResponse<HealthStatus>> {
|
||||
const dbHealthy = await this.prisma.isHealthy();
|
||||
const dbInfo = await this.prisma.getConnectionInfo();
|
||||
|
||||
@@ -3,8 +3,11 @@ import { APP_INTERCEPTOR, APP_GUARD } from "@nestjs/core";
|
||||
import { ThrottlerModule } from "@nestjs/throttler";
|
||||
import { BullModule } from "@nestjs/bullmq";
|
||||
import { ThrottlerValkeyStorageService, ThrottlerApiKeyGuard } from "./common/throttler";
|
||||
import { CsrfGuard } from "./common/guards/csrf.guard";
|
||||
import { CsrfService } from "./common/services/csrf.service";
|
||||
import { AppController } from "./app.controller";
|
||||
import { AppService } from "./app.service";
|
||||
import { CsrfController } from "./common/controllers/csrf.controller";
|
||||
import { PrismaModule } from "./prisma/prisma.module";
|
||||
import { DatabaseModule } from "./database/database.module";
|
||||
import { AuthModule } from "./auth/auth.module";
|
||||
@@ -20,6 +23,7 @@ import { KnowledgeModule } from "./knowledge/knowledge.module";
|
||||
import { UsersModule } from "./users/users.module";
|
||||
import { WebSocketModule } from "./websocket/websocket.module";
|
||||
import { LlmModule } from "./llm/llm.module";
|
||||
import { LlmUsageModule } from "./llm-usage/llm-usage.module";
|
||||
import { BrainModule } from "./brain/brain.module";
|
||||
import { CronModule } from "./cron/cron.module";
|
||||
import { AgentTasksModule } from "./agent-tasks/agent-tasks.module";
|
||||
@@ -32,6 +36,10 @@ import { JobEventsModule } from "./job-events/job-events.module";
|
||||
import { JobStepsModule } from "./job-steps/job-steps.module";
|
||||
import { CoordinatorIntegrationModule } from "./coordinator-integration/coordinator-integration.module";
|
||||
import { FederationModule } from "./federation/federation.module";
|
||||
import { CredentialsModule } from "./credentials/credentials.module";
|
||||
import { MosaicTelemetryModule } from "./mosaic-telemetry";
|
||||
import { SpeechModule } from "./speech/speech.module";
|
||||
import { RlsContextInterceptor } from "./common/interceptors/rls-context.interceptor";
|
||||
|
||||
@Module({
|
||||
imports: [
|
||||
@@ -54,10 +62,13 @@ import { FederationModule } from "./federation/federation.module";
|
||||
}),
|
||||
// BullMQ job queue configuration
|
||||
BullModule.forRoot({
|
||||
connection: {
|
||||
host: process.env.VALKEY_HOST ?? "localhost",
|
||||
port: parseInt(process.env.VALKEY_PORT ?? "6379", 10),
|
||||
},
|
||||
connection: (() => {
|
||||
const url = new URL(process.env.VALKEY_URL ?? "redis://localhost:6379");
|
||||
return {
|
||||
host: url.hostname,
|
||||
port: parseInt(url.port || "6379", 10),
|
||||
};
|
||||
})(),
|
||||
}),
|
||||
TelemetryModule,
|
||||
PrismaModule,
|
||||
@@ -78,6 +89,7 @@ import { FederationModule } from "./federation/federation.module";
|
||||
UsersModule,
|
||||
WebSocketModule,
|
||||
LlmModule,
|
||||
LlmUsageModule,
|
||||
BrainModule,
|
||||
CronModule,
|
||||
AgentTasksModule,
|
||||
@@ -86,18 +98,30 @@ import { FederationModule } from "./federation/federation.module";
|
||||
JobStepsModule,
|
||||
CoordinatorIntegrationModule,
|
||||
FederationModule,
|
||||
CredentialsModule,
|
||||
MosaicTelemetryModule,
|
||||
SpeechModule,
|
||||
],
|
||||
controllers: [AppController],
|
||||
controllers: [AppController, CsrfController],
|
||||
providers: [
|
||||
AppService,
|
||||
CsrfService,
|
||||
{
|
||||
provide: APP_INTERCEPTOR,
|
||||
useClass: TelemetryInterceptor,
|
||||
},
|
||||
{
|
||||
provide: APP_INTERCEPTOR,
|
||||
useClass: RlsContextInterceptor,
|
||||
},
|
||||
{
|
||||
provide: APP_GUARD,
|
||||
useClass: ThrottlerApiKeyGuard,
|
||||
},
|
||||
{
|
||||
provide: APP_GUARD,
|
||||
useClass: CsrfGuard,
|
||||
},
|
||||
],
|
||||
})
|
||||
export class AppModule {}
|
||||
|
||||
677
apps/api/src/auth/auth-rls.integration.spec.ts
Normal file
677
apps/api/src/auth/auth-rls.integration.spec.ts
Normal file
@@ -0,0 +1,677 @@
|
||||
/**
|
||||
* Auth Tables RLS Integration Tests
|
||||
*
|
||||
* Tests that RLS policies on accounts and sessions tables correctly
|
||||
* enforce user-scoped access and prevent cross-user data leakage.
|
||||
*
|
||||
* Related: #350 - Add RLS policies to auth tables with FORCE enforcement
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeAll, afterAll } from "vitest";
|
||||
import { PrismaClient, Prisma } from "@prisma/client";
|
||||
import { randomUUID as uuid } from "crypto";
|
||||
import { runWithRlsClient, getRlsClient } from "../prisma/rls-context.provider";
|
||||
|
||||
describe.skipIf(!process.env.DATABASE_URL)(
|
||||
"Auth Tables RLS Policies (requires DATABASE_URL)",
|
||||
() => {
|
||||
let prisma: PrismaClient;
|
||||
const testData: {
|
||||
users: string[];
|
||||
accounts: string[];
|
||||
sessions: string[];
|
||||
} = {
|
||||
users: [],
|
||||
accounts: [],
|
||||
sessions: [],
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Skip setup if DATABASE_URL is not available
|
||||
if (!process.env.DATABASE_URL) {
|
||||
return;
|
||||
}
|
||||
|
||||
prisma = new PrismaClient();
|
||||
await prisma.$connect();
|
||||
|
||||
// RLS policies are bypassed for superusers
|
||||
const [{ rolsuper }] = await prisma.$queryRaw<[{ rolsuper: boolean }]>`
|
||||
SELECT rolsuper FROM pg_roles WHERE rolname = current_user
|
||||
`;
|
||||
if (rolsuper) {
|
||||
throw new Error(
|
||||
"Auth RLS integration tests require a non-superuser database role. " +
|
||||
"See migration 20260207_add_auth_rls_policies for setup instructions."
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Skip cleanup if DATABASE_URL is not available or prisma not initialized
|
||||
if (!process.env.DATABASE_URL || !prisma) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Clean up test data
|
||||
if (testData.sessions.length > 0) {
|
||||
await prisma.session.deleteMany({
|
||||
where: { id: { in: testData.sessions } },
|
||||
});
|
||||
}
|
||||
|
||||
if (testData.accounts.length > 0) {
|
||||
await prisma.account.deleteMany({
|
||||
where: { id: { in: testData.accounts } },
|
||||
});
|
||||
}
|
||||
|
||||
if (testData.users.length > 0) {
|
||||
await prisma.user.deleteMany({
|
||||
where: { id: { in: testData.users } },
|
||||
});
|
||||
}
|
||||
|
||||
await prisma.$disconnect();
|
||||
} catch (error) {
|
||||
console.error(
|
||||
"Test cleanup failed:",
|
||||
error instanceof Error ? error.message : String(error)
|
||||
);
|
||||
// Re-throw to make test failure visible
|
||||
throw new Error(
|
||||
"Test cleanup failed. Database may contain orphaned test data. " +
|
||||
`Error: ${error instanceof Error ? error.message : String(error)}`
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
async function createTestUser(email: string): Promise<string> {
|
||||
const userId = uuid();
|
||||
await prisma.user.create({
|
||||
data: {
|
||||
id: userId,
|
||||
email,
|
||||
name: `Test User ${email}`,
|
||||
authProviderId: `auth-${userId}`,
|
||||
},
|
||||
});
|
||||
testData.users.push(userId);
|
||||
return userId;
|
||||
}
|
||||
|
||||
async function createTestAccount(userId: string, token: string): Promise<string> {
|
||||
const accountId = uuid();
|
||||
await prisma.account.create({
|
||||
data: {
|
||||
id: accountId,
|
||||
userId,
|
||||
accountId: `account-${accountId}`,
|
||||
providerId: "test-provider",
|
||||
accessToken: token,
|
||||
},
|
||||
});
|
||||
testData.accounts.push(accountId);
|
||||
return accountId;
|
||||
}
|
||||
|
||||
async function createTestSession(userId: string): Promise<string> {
|
||||
const sessionId = uuid();
|
||||
await prisma.session.create({
|
||||
data: {
|
||||
id: sessionId,
|
||||
userId,
|
||||
token: `session-${sessionId}-${Date.now()}`,
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
},
|
||||
});
|
||||
testData.sessions.push(sessionId);
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
describe("Account table RLS", () => {
|
||||
it("should allow user to read their own accounts when RLS context is set", async () => {
|
||||
const user1Id = await createTestUser("account-read-own@test.com");
|
||||
const account1Id = await createTestAccount(user1Id, "user1-token");
|
||||
|
||||
// Use runWithRlsClient to set RLS context
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const accounts = await client.account.findMany({
|
||||
where: { userId: user1Id },
|
||||
});
|
||||
|
||||
return accounts;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe(account1Id);
|
||||
expect(result[0].accessToken).toBe("user1-token");
|
||||
});
|
||||
|
||||
it("should prevent user from reading other users accounts", async () => {
|
||||
const user1Id = await createTestUser("account-read-self@test.com");
|
||||
const user2Id = await createTestUser("account-read-other@test.com");
|
||||
await createTestAccount(user1Id, "user1-token");
|
||||
await createTestAccount(user2Id, "user2-token");
|
||||
|
||||
// Set RLS context for user1, try to read user2's accounts
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const accounts = await client.account.findMany({
|
||||
where: { userId: user2Id },
|
||||
});
|
||||
|
||||
return accounts;
|
||||
});
|
||||
});
|
||||
|
||||
// Should return empty array due to RLS policy
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should prevent direct access by ID to other users accounts", async () => {
|
||||
const user1Id = await createTestUser("account-id-self@test.com");
|
||||
const user2Id = await createTestUser("account-id-other@test.com");
|
||||
await createTestAccount(user1Id, "user1-token");
|
||||
const account2Id = await createTestAccount(user2Id, "user2-token");
|
||||
|
||||
// Set RLS context for user1, try to read user2's account by ID
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const account = await client.account.findUnique({
|
||||
where: { id: account2Id },
|
||||
});
|
||||
|
||||
return account;
|
||||
});
|
||||
});
|
||||
|
||||
// Should return null due to RLS policy
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should allow user to create their own accounts", async () => {
|
||||
const user1Id = await createTestUser("account-create-own@test.com");
|
||||
|
||||
// Set RLS context for user1, create their own account
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const newAccount = await client.account.create({
|
||||
data: {
|
||||
id: uuid(),
|
||||
userId: user1Id,
|
||||
accountId: "new-account",
|
||||
providerId: "test-provider",
|
||||
accessToken: "new-token",
|
||||
},
|
||||
});
|
||||
|
||||
testData.accounts.push(newAccount.id);
|
||||
return newAccount;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.userId).toBe(user1Id);
|
||||
expect(result.accessToken).toBe("new-token");
|
||||
});
|
||||
|
||||
it("should prevent user from creating accounts for other users", async () => {
|
||||
const user1Id = await createTestUser("account-create-self@test.com");
|
||||
const user2Id = await createTestUser("account-create-other@test.com");
|
||||
|
||||
// Set RLS context for user1, try to create an account for user2
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const newAccount = await client.account.create({
|
||||
data: {
|
||||
id: uuid(),
|
||||
userId: user2Id, // Trying to create for user2 while logged in as user1
|
||||
accountId: "hacked-account",
|
||||
providerId: "test-provider",
|
||||
accessToken: "hacked-token",
|
||||
},
|
||||
});
|
||||
|
||||
testData.accounts.push(newAccount.id);
|
||||
return newAccount;
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should allow user to update their own accounts", async () => {
|
||||
const user1Id = await createTestUser("account-update-own@test.com");
|
||||
const account1Id = await createTestAccount(user1Id, "original-token");
|
||||
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const updated = await client.account.update({
|
||||
where: { id: account1Id },
|
||||
data: { accessToken: "updated-token" },
|
||||
});
|
||||
|
||||
return updated;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.accessToken).toBe("updated-token");
|
||||
});
|
||||
|
||||
it("should prevent user from updating other users accounts", async () => {
|
||||
const user1Id = await createTestUser("account-update-self@test.com");
|
||||
const user2Id = await createTestUser("account-update-other@test.com");
|
||||
await createTestAccount(user1Id, "user1-token");
|
||||
const account2Id = await createTestAccount(user2Id, "user2-token");
|
||||
|
||||
// Set RLS context for user1, try to update user2's account
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.account.update({
|
||||
where: { id: account2Id },
|
||||
data: { accessToken: "hacked-token" },
|
||||
});
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should prevent user from deleting other users accounts", async () => {
|
||||
const user1Id = await createTestUser("account-delete-self@test.com");
|
||||
const user2Id = await createTestUser("account-delete-other@test.com");
|
||||
await createTestAccount(user1Id, "user1-token");
|
||||
const account2Id = await createTestAccount(user2Id, "user2-token");
|
||||
|
||||
// Set RLS context for user1, try to delete user2's account
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.account.delete({
|
||||
where: { id: account2Id },
|
||||
});
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
|
||||
// Verify the record still exists and wasn't deleted
|
||||
const stillExists = await prisma.account.findUnique({ where: { id: account2Id } });
|
||||
expect(stillExists).not.toBeNull();
|
||||
expect(stillExists?.userId).toBe(user2Id);
|
||||
});
|
||||
|
||||
it("should allow user to delete their own accounts", async () => {
|
||||
const user1Id = await createTestUser("account-delete-own@test.com");
|
||||
const account1Id = await createTestAccount(user1Id, "user1-token");
|
||||
|
||||
// Set RLS context for user1, delete their own account
|
||||
await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.account.delete({
|
||||
where: { id: account1Id },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Verify the record was actually deleted
|
||||
const deleted = await prisma.account.findUnique({ where: { id: account1Id } });
|
||||
expect(deleted).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Session table RLS", () => {
|
||||
it("should allow user to read their own sessions when RLS context is set", async () => {
|
||||
const user1Id = await createTestUser("session-read-own@test.com");
|
||||
const session1Id = await createTestSession(user1Id);
|
||||
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const sessions = await client.session.findMany({
|
||||
where: { userId: user1Id },
|
||||
});
|
||||
|
||||
return sessions;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe(session1Id);
|
||||
});
|
||||
|
||||
it("should prevent user from reading other users sessions", async () => {
|
||||
const user1Id = await createTestUser("session-read-self@test.com");
|
||||
const user2Id = await createTestUser("session-read-other@test.com");
|
||||
await createTestSession(user1Id);
|
||||
await createTestSession(user2Id);
|
||||
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const sessions = await client.session.findMany({
|
||||
where: { userId: user2Id },
|
||||
});
|
||||
|
||||
return sessions;
|
||||
});
|
||||
});
|
||||
|
||||
// Should return empty array due to RLS policy
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should prevent direct access by ID to other users sessions", async () => {
|
||||
const user1Id = await createTestUser("session-id-self@test.com");
|
||||
const user2Id = await createTestUser("session-id-other@test.com");
|
||||
await createTestSession(user1Id);
|
||||
const session2Id = await createTestSession(user2Id);
|
||||
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const session = await client.session.findUnique({
|
||||
where: { id: session2Id },
|
||||
});
|
||||
|
||||
return session;
|
||||
});
|
||||
});
|
||||
|
||||
// Should return null due to RLS policy
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should allow user to create their own sessions", async () => {
|
||||
const user1Id = await createTestUser("session-create-own@test.com");
|
||||
|
||||
// Set RLS context for user1, create their own session
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const newSession = await client.session.create({
|
||||
data: {
|
||||
id: uuid(),
|
||||
userId: user1Id,
|
||||
token: `new-session-${Date.now()}`,
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
},
|
||||
});
|
||||
|
||||
testData.sessions.push(newSession.id);
|
||||
return newSession;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.userId).toBe(user1Id);
|
||||
expect(result.token).toContain("new-session");
|
||||
});
|
||||
|
||||
it("should prevent user from creating sessions for other users", async () => {
|
||||
const user1Id = await createTestUser("session-create-self@test.com");
|
||||
const user2Id = await createTestUser("session-create-other@test.com");
|
||||
|
||||
// Set RLS context for user1, try to create a session for user2
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const newSession = await client.session.create({
|
||||
data: {
|
||||
id: uuid(),
|
||||
userId: user2Id, // Trying to create for user2 while logged in as user1
|
||||
token: `hacked-session-${Date.now()}`,
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
},
|
||||
});
|
||||
|
||||
testData.sessions.push(newSession.id);
|
||||
return newSession;
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should allow user to update their own sessions", async () => {
|
||||
const user1Id = await createTestUser("session-update-own@test.com");
|
||||
const session1Id = await createTestSession(user1Id);
|
||||
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const updated = await client.session.update({
|
||||
where: { id: session1Id },
|
||||
data: { ipAddress: "192.168.1.1" },
|
||||
});
|
||||
|
||||
return updated;
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.ipAddress).toBe("192.168.1.1");
|
||||
});
|
||||
|
||||
it("should prevent user from updating other users sessions", async () => {
|
||||
const user1Id = await createTestUser("session-update-self@test.com");
|
||||
const user2Id = await createTestUser("session-update-other@test.com");
|
||||
await createTestSession(user1Id);
|
||||
const session2Id = await createTestSession(user2Id);
|
||||
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.session.update({
|
||||
where: { id: session2Id },
|
||||
data: { ipAddress: "10.0.0.1" },
|
||||
});
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should prevent user from deleting other users sessions", async () => {
|
||||
const user1Id = await createTestUser("session-delete-self@test.com");
|
||||
const user2Id = await createTestUser("session-delete-other@test.com");
|
||||
await createTestSession(user1Id);
|
||||
const session2Id = await createTestSession(user2Id);
|
||||
|
||||
await expect(
|
||||
prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.session.delete({
|
||||
where: { id: session2Id },
|
||||
});
|
||||
});
|
||||
})
|
||||
).rejects.toThrow();
|
||||
|
||||
// Verify the record still exists and wasn't deleted
|
||||
const stillExists = await prisma.session.findUnique({ where: { id: session2Id } });
|
||||
expect(stillExists).not.toBeNull();
|
||||
expect(stillExists?.userId).toBe(user2Id);
|
||||
});
|
||||
|
||||
it("should allow user to delete their own sessions", async () => {
|
||||
const user1Id = await createTestUser("session-delete-own@test.com");
|
||||
const session1Id = await createTestSession(user1Id);
|
||||
|
||||
// Set RLS context for user1, delete their own session
|
||||
await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
await client.session.delete({
|
||||
where: { id: session1Id },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Verify the record was actually deleted
|
||||
const deleted = await prisma.session.findUnique({ where: { id: session1Id } });
|
||||
expect(deleted).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Owner bypass policy", () => {
|
||||
it("should allow table owner to access all records without RLS context", async () => {
|
||||
const user1Id = await createTestUser("owner-bypass-1@test.com");
|
||||
const user2Id = await createTestUser("owner-bypass-2@test.com");
|
||||
const account1Id = await createTestAccount(user1Id, "token1");
|
||||
const account2Id = await createTestAccount(user2Id, "token2");
|
||||
|
||||
// Don't set RLS context - rely on owner bypass policy
|
||||
const accounts = await prisma.account.findMany({
|
||||
where: {
|
||||
id: { in: [account1Id, account2Id] },
|
||||
},
|
||||
});
|
||||
|
||||
// Owner should see both accounts
|
||||
expect(accounts).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("should allow migrations to work without RLS context", async () => {
|
||||
const userId = await createTestUser("migration-test@test.com");
|
||||
|
||||
// This simulates a migration or BetterAuth internal operation
|
||||
// that doesn't have RLS context set
|
||||
const newAccount = await prisma.account.create({
|
||||
data: {
|
||||
id: uuid(),
|
||||
userId,
|
||||
accountId: "migration-test-account",
|
||||
providerId: "test-migration",
|
||||
},
|
||||
});
|
||||
|
||||
expect(newAccount.id).toBeDefined();
|
||||
|
||||
// Clean up
|
||||
await prisma.account.delete({
|
||||
where: { id: newAccount.id },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("RLS context isolation", () => {
|
||||
it("should enforce RLS when context is set, even for table owner", async () => {
|
||||
const user1Id = await createTestUser("rls-enforce-1@test.com");
|
||||
const user2Id = await createTestUser("rls-enforce-2@test.com");
|
||||
const account1Id = await createTestAccount(user1Id, "token1");
|
||||
const account2Id = await createTestAccount(user2Id, "token2");
|
||||
|
||||
// With RLS context set for user1, they should only see their own account
|
||||
const result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
const accounts = await client.account.findMany({
|
||||
where: {
|
||||
id: { in: [account1Id, account2Id] },
|
||||
},
|
||||
});
|
||||
|
||||
return accounts;
|
||||
});
|
||||
});
|
||||
|
||||
// Should only see user1's account, not user2's
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe(account1Id);
|
||||
});
|
||||
|
||||
it("should allow different users to see only their own data in separate contexts", async () => {
|
||||
const user1Id = await createTestUser("context-user1@test.com");
|
||||
const user2Id = await createTestUser("context-user2@test.com");
|
||||
const session1Id = await createTestSession(user1Id);
|
||||
const session2Id = await createTestSession(user2Id);
|
||||
|
||||
// User1 context - query for both sessions, but RLS should only return user1's
|
||||
const user1Result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user1Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
return client.session.findMany({
|
||||
where: {
|
||||
id: { in: [session1Id, session2Id] },
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// User2 context - query for both sessions, but RLS should only return user2's
|
||||
const user2Result = await prisma.$transaction(async (tx) => {
|
||||
await tx.$executeRaw`SELECT set_config('app.current_user_id', ${user2Id}::text, true)`;
|
||||
|
||||
return runWithRlsClient(tx, async () => {
|
||||
const client = getRlsClient()!;
|
||||
return client.session.findMany({
|
||||
where: {
|
||||
id: { in: [session1Id, session2Id] },
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Each user should only see their own session
|
||||
expect(user1Result).toHaveLength(1);
|
||||
expect(user1Result[0].id).toBe(session1Id);
|
||||
|
||||
expect(user2Result).toHaveLength(1);
|
||||
expect(user2Result[0].id).toBe(session2Id);
|
||||
});
|
||||
});
|
||||
}
|
||||
);
|
||||
627
apps/api/src/auth/auth.config.spec.ts
Normal file
627
apps/api/src/auth/auth.config.spec.ts
Normal file
@@ -0,0 +1,627 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import type { PrismaClient } from "@prisma/client";
|
||||
|
||||
// Mock better-auth modules to inspect genericOAuth plugin configuration
|
||||
const mockGenericOAuth = vi.fn().mockReturnValue({ id: "generic-oauth" });
|
||||
const mockBetterAuth = vi.fn().mockReturnValue({ handler: vi.fn() });
|
||||
const mockPrismaAdapter = vi.fn().mockReturnValue({});
|
||||
|
||||
vi.mock("better-auth/plugins", () => ({
|
||||
genericOAuth: (...args: unknown[]) => mockGenericOAuth(...args),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth", () => ({
|
||||
betterAuth: (...args: unknown[]) => mockBetterAuth(...args),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/adapters/prisma", () => ({
|
||||
prismaAdapter: (...args: unknown[]) => mockPrismaAdapter(...args),
|
||||
}));
|
||||
|
||||
import { isOidcEnabled, validateOidcConfig, createAuth, getTrustedOrigins } from "./auth.config";
|
||||
|
||||
describe("auth.config", () => {
|
||||
// Store original env vars to restore after each test
|
||||
const originalEnv = { ...process.env };
|
||||
|
||||
beforeEach(() => {
|
||||
// Clear relevant env vars before each test
|
||||
delete process.env.OIDC_ENABLED;
|
||||
delete process.env.OIDC_ISSUER;
|
||||
delete process.env.OIDC_CLIENT_ID;
|
||||
delete process.env.OIDC_CLIENT_SECRET;
|
||||
delete process.env.OIDC_REDIRECT_URI;
|
||||
delete process.env.NODE_ENV;
|
||||
delete process.env.NEXT_PUBLIC_APP_URL;
|
||||
delete process.env.NEXT_PUBLIC_API_URL;
|
||||
delete process.env.TRUSTED_ORIGINS;
|
||||
delete process.env.COOKIE_DOMAIN;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore original env vars
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
describe("isOidcEnabled", () => {
|
||||
it("should return false when OIDC_ENABLED is not set", () => {
|
||||
expect(isOidcEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when OIDC_ENABLED is 'false'", () => {
|
||||
process.env.OIDC_ENABLED = "false";
|
||||
expect(isOidcEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when OIDC_ENABLED is '0'", () => {
|
||||
process.env.OIDC_ENABLED = "0";
|
||||
expect(isOidcEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when OIDC_ENABLED is empty string", () => {
|
||||
process.env.OIDC_ENABLED = "";
|
||||
expect(isOidcEnabled()).toBe(false);
|
||||
});
|
||||
|
||||
it("should return true when OIDC_ENABLED is 'true'", () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
expect(isOidcEnabled()).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when OIDC_ENABLED is '1'", () => {
|
||||
process.env.OIDC_ENABLED = "1";
|
||||
expect(isOidcEnabled()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("validateOidcConfig", () => {
|
||||
describe("when OIDC is disabled", () => {
|
||||
it("should not throw when OIDC_ENABLED is not set", () => {
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should not throw when OIDC_ENABLED is false even if vars are missing", () => {
|
||||
process.env.OIDC_ENABLED = "false";
|
||||
// Intentionally not setting any OIDC vars
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("when OIDC is enabled", () => {
|
||||
beforeEach(() => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
});
|
||||
|
||||
it("should throw when OIDC_ISSUER is missing", () => {
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_ISSUER");
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC authentication is enabled");
|
||||
});
|
||||
|
||||
it("should throw when OIDC_CLIENT_ID is missing", () => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_CLIENT_ID");
|
||||
});
|
||||
|
||||
it("should throw when OIDC_CLIENT_SECRET is missing", () => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_CLIENT_SECRET");
|
||||
});
|
||||
|
||||
it("should throw when OIDC_REDIRECT_URI is missing", () => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_REDIRECT_URI");
|
||||
});
|
||||
|
||||
it("should throw when all required vars are missing", () => {
|
||||
expect(() => validateOidcConfig()).toThrow(
|
||||
"OIDC_ISSUER, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_REDIRECT_URI"
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw when vars are empty strings", () => {
|
||||
process.env.OIDC_ISSUER = "";
|
||||
process.env.OIDC_CLIENT_ID = "";
|
||||
process.env.OIDC_CLIENT_SECRET = "";
|
||||
process.env.OIDC_REDIRECT_URI = "";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow(
|
||||
"OIDC_ISSUER, OIDC_CLIENT_ID, OIDC_CLIENT_SECRET, OIDC_REDIRECT_URI"
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw when vars are whitespace only", () => {
|
||||
process.env.OIDC_ISSUER = " ";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_ISSUER");
|
||||
});
|
||||
|
||||
it("should throw when OIDC_ISSUER does not end with trailing slash", () => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_ISSUER must end with a trailing slash");
|
||||
expect(() => validateOidcConfig()).toThrow("https://auth.example.com/application/o/mosaic");
|
||||
});
|
||||
|
||||
it("should not throw with valid complete configuration", () => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic-stack/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should suggest disabling OIDC in error message", () => {
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_ENABLED=false");
|
||||
});
|
||||
|
||||
describe("OIDC_REDIRECT_URI validation", () => {
|
||||
beforeEach(() => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic-stack/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
});
|
||||
|
||||
it("should throw when OIDC_REDIRECT_URI is not a valid URL", () => {
|
||||
process.env.OIDC_REDIRECT_URI = "not-a-url";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow("OIDC_REDIRECT_URI must be a valid URL");
|
||||
expect(() => validateOidcConfig()).toThrow("not-a-url");
|
||||
expect(() => validateOidcConfig()).toThrow("Parse error:");
|
||||
});
|
||||
|
||||
it("should throw when OIDC_REDIRECT_URI path does not start with /auth/callback", () => {
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/oauth/callback";
|
||||
|
||||
expect(() => validateOidcConfig()).toThrow(
|
||||
'OIDC_REDIRECT_URI path must start with "/auth/callback"'
|
||||
);
|
||||
expect(() => validateOidcConfig()).toThrow("/oauth/callback");
|
||||
});
|
||||
|
||||
it("should accept a valid OIDC_REDIRECT_URI with /auth/callback path", () => {
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should accept OIDC_REDIRECT_URI with exactly /auth/callback path", () => {
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback";
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should warn but not throw when using localhost in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
process.env.OIDC_REDIRECT_URI = "http://localhost:3000/auth/callback/authentik";
|
||||
|
||||
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("OIDC_REDIRECT_URI uses localhost")
|
||||
);
|
||||
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should warn but not throw when using 127.0.0.1 in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
process.env.OIDC_REDIRECT_URI = "http://127.0.0.1:3000/auth/callback/authentik";
|
||||
|
||||
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("OIDC_REDIRECT_URI uses localhost")
|
||||
);
|
||||
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should not warn about localhost when not in production", () => {
|
||||
process.env.NODE_ENV = "development";
|
||||
process.env.OIDC_REDIRECT_URI = "http://localhost:3000/auth/callback/authentik";
|
||||
|
||||
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
expect(() => validateOidcConfig()).not.toThrow();
|
||||
expect(warnSpy).not.toHaveBeenCalled();
|
||||
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("createAuth - genericOAuth PKCE configuration", () => {
|
||||
beforeEach(() => {
|
||||
mockGenericOAuth.mockClear();
|
||||
mockBetterAuth.mockClear();
|
||||
mockPrismaAdapter.mockClear();
|
||||
});
|
||||
|
||||
it("should enable PKCE in the genericOAuth provider config when OIDC is enabled", () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic-stack/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockGenericOAuth).toHaveBeenCalledOnce();
|
||||
const callArgs = mockGenericOAuth.mock.calls[0][0] as {
|
||||
config: Array<{ pkce?: boolean }>;
|
||||
};
|
||||
expect(callArgs.config[0].pkce).toBe(true);
|
||||
});
|
||||
|
||||
it("should not call genericOAuth when OIDC is disabled", () => {
|
||||
process.env.OIDC_ENABLED = "false";
|
||||
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockGenericOAuth).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should throw if OIDC_CLIENT_ID is missing when OIDC is enabled", () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic-stack/";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
// OIDC_CLIENT_ID deliberately not set
|
||||
|
||||
// validateOidcConfig will throw first, so we need to bypass it
|
||||
// by setting the var then deleting it after validation
|
||||
// Instead, test via the validation path which is fine — but let's
|
||||
// verify the plugin-level guard by using a direct approach:
|
||||
// Set env to pass validateOidcConfig, then delete OIDC_CLIENT_ID
|
||||
// The validateOidcConfig will catch this first, which is correct behavior
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
expect(() => createAuth(mockPrisma)).toThrow("OIDC_CLIENT_ID");
|
||||
});
|
||||
|
||||
it("should throw if OIDC_CLIENT_SECRET is missing when OIDC is enabled", () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/application/o/mosaic-stack/";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
// OIDC_CLIENT_SECRET deliberately not set
|
||||
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
expect(() => createAuth(mockPrisma)).toThrow("OIDC_CLIENT_SECRET");
|
||||
});
|
||||
|
||||
it("should throw if OIDC_ISSUER is missing when OIDC is enabled", () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_CLIENT_ID = "test-client-id";
|
||||
process.env.OIDC_CLIENT_SECRET = "test-client-secret";
|
||||
process.env.OIDC_REDIRECT_URI = "https://app.example.com/auth/callback/authentik";
|
||||
// OIDC_ISSUER deliberately not set
|
||||
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
expect(() => createAuth(mockPrisma)).toThrow("OIDC_ISSUER");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getTrustedOrigins", () => {
|
||||
it("should return localhost URLs when NODE_ENV is not production", () => {
|
||||
process.env.NODE_ENV = "development";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("http://localhost:3000");
|
||||
expect(origins).toContain("http://localhost:3001");
|
||||
});
|
||||
|
||||
it("should return localhost URLs when NODE_ENV is not set", () => {
|
||||
// NODE_ENV is deleted in beforeEach, so it's undefined here
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("http://localhost:3000");
|
||||
expect(origins).toContain("http://localhost:3001");
|
||||
});
|
||||
|
||||
it("should exclude localhost URLs in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).not.toContain("http://localhost:3000");
|
||||
expect(origins).not.toContain("http://localhost:3001");
|
||||
});
|
||||
|
||||
it("should parse TRUSTED_ORIGINS comma-separated values", () => {
|
||||
process.env.TRUSTED_ORIGINS =
|
||||
"https://app.mosaicstack.dev,https://api.mosaicstack.dev";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://app.mosaicstack.dev");
|
||||
expect(origins).toContain("https://api.mosaicstack.dev");
|
||||
});
|
||||
|
||||
it("should trim whitespace from TRUSTED_ORIGINS entries", () => {
|
||||
process.env.TRUSTED_ORIGINS =
|
||||
" https://app.mosaicstack.dev , https://api.mosaicstack.dev ";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://app.mosaicstack.dev");
|
||||
expect(origins).toContain("https://api.mosaicstack.dev");
|
||||
});
|
||||
|
||||
it("should filter out empty strings from TRUSTED_ORIGINS", () => {
|
||||
process.env.TRUSTED_ORIGINS = "https://app.mosaicstack.dev,,, ,";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://app.mosaicstack.dev");
|
||||
// No empty strings in the result
|
||||
origins.forEach((o) => expect(o).not.toBe(""));
|
||||
});
|
||||
|
||||
it("should include NEXT_PUBLIC_APP_URL", () => {
|
||||
process.env.NEXT_PUBLIC_APP_URL = "https://my-app.example.com";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://my-app.example.com");
|
||||
});
|
||||
|
||||
it("should include NEXT_PUBLIC_API_URL", () => {
|
||||
process.env.NEXT_PUBLIC_API_URL = "https://my-api.example.com";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://my-api.example.com");
|
||||
});
|
||||
|
||||
it("should deduplicate origins", () => {
|
||||
process.env.NEXT_PUBLIC_APP_URL = "http://localhost:3000";
|
||||
process.env.TRUSTED_ORIGINS = "http://localhost:3000,http://localhost:3001";
|
||||
// NODE_ENV not set, so localhost fallbacks are also added
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
const countLocalhost3000 = origins.filter((o) => o === "http://localhost:3000").length;
|
||||
const countLocalhost3001 = origins.filter((o) => o === "http://localhost:3001").length;
|
||||
expect(countLocalhost3000).toBe(1);
|
||||
expect(countLocalhost3001).toBe(1);
|
||||
});
|
||||
|
||||
it("should handle all env vars missing gracefully", () => {
|
||||
// All env vars deleted in beforeEach; NODE_ENV is also deleted (not production)
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
// Should still return localhost fallbacks since not in production
|
||||
expect(origins).toContain("http://localhost:3000");
|
||||
expect(origins).toContain("http://localhost:3001");
|
||||
expect(origins).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("should return empty array when all env vars missing in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should combine all sources correctly", () => {
|
||||
process.env.NEXT_PUBLIC_APP_URL = "https://app.mosaicstack.dev";
|
||||
process.env.NEXT_PUBLIC_API_URL = "https://api.mosaicstack.dev";
|
||||
process.env.TRUSTED_ORIGINS = "https://extra.example.com";
|
||||
process.env.NODE_ENV = "development";
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://app.mosaicstack.dev");
|
||||
expect(origins).toContain("https://api.mosaicstack.dev");
|
||||
expect(origins).toContain("https://extra.example.com");
|
||||
expect(origins).toContain("http://localhost:3000");
|
||||
expect(origins).toContain("http://localhost:3001");
|
||||
expect(origins).toHaveLength(5);
|
||||
});
|
||||
|
||||
it("should reject invalid URLs in TRUSTED_ORIGINS with a warning including error details", () => {
|
||||
process.env.TRUSTED_ORIGINS = "not-a-url,https://valid.example.com";
|
||||
process.env.NODE_ENV = "production";
|
||||
|
||||
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://valid.example.com");
|
||||
expect(origins).not.toContain("not-a-url");
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Ignoring invalid URL in TRUSTED_ORIGINS: "not-a-url"')
|
||||
);
|
||||
// Verify that error detail is included in the warning
|
||||
const warnCall = warnSpy.mock.calls.find(
|
||||
(call) => typeof call[0] === "string" && call[0].includes("not-a-url")
|
||||
);
|
||||
expect(warnCall).toBeDefined();
|
||||
expect(warnCall![0]).toMatch(/\(.*\)$/);
|
||||
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should reject non-HTTP origins in TRUSTED_ORIGINS with a warning", () => {
|
||||
process.env.TRUSTED_ORIGINS = "ftp://files.example.com,https://valid.example.com";
|
||||
process.env.NODE_ENV = "production";
|
||||
|
||||
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
|
||||
|
||||
const origins = getTrustedOrigins();
|
||||
|
||||
expect(origins).toContain("https://valid.example.com");
|
||||
expect(origins).not.toContain("ftp://files.example.com");
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Ignoring non-HTTP origin in TRUSTED_ORIGINS")
|
||||
);
|
||||
|
||||
warnSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("createAuth - session and cookie configuration", () => {
|
||||
beforeEach(() => {
|
||||
mockGenericOAuth.mockClear();
|
||||
mockBetterAuth.mockClear();
|
||||
mockPrismaAdapter.mockClear();
|
||||
});
|
||||
|
||||
it("should configure session expiresIn to 7 days (604800 seconds)", () => {
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
session: { expiresIn: number; updateAge: number };
|
||||
};
|
||||
expect(config.session.expiresIn).toBe(604800);
|
||||
});
|
||||
|
||||
it("should configure session updateAge to 2 hours (7200 seconds)", () => {
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
session: { expiresIn: number; updateAge: number };
|
||||
};
|
||||
expect(config.session.updateAge).toBe(7200);
|
||||
});
|
||||
|
||||
it("should set httpOnly cookie attribute to true", () => {
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.httpOnly).toBe(true);
|
||||
});
|
||||
|
||||
it("should set sameSite cookie attribute to lax", () => {
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.sameSite).toBe("lax");
|
||||
});
|
||||
|
||||
it("should set secure cookie attribute to true in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.secure).toBe(true);
|
||||
});
|
||||
|
||||
it("should set secure cookie attribute to false in non-production", () => {
|
||||
process.env.NODE_ENV = "development";
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.secure).toBe(false);
|
||||
});
|
||||
|
||||
it("should set cookie domain when COOKIE_DOMAIN env var is present", () => {
|
||||
process.env.COOKIE_DOMAIN = ".mosaicstack.dev";
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
domain?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.domain).toBe(".mosaicstack.dev");
|
||||
});
|
||||
|
||||
it("should not set cookie domain when COOKIE_DOMAIN env var is absent", () => {
|
||||
delete process.env.COOKIE_DOMAIN;
|
||||
const mockPrisma = {} as PrismaClient;
|
||||
createAuth(mockPrisma);
|
||||
|
||||
expect(mockBetterAuth).toHaveBeenCalledOnce();
|
||||
const config = mockBetterAuth.mock.calls[0][0] as {
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: boolean;
|
||||
secure: boolean;
|
||||
sameSite: string;
|
||||
domain?: string;
|
||||
};
|
||||
};
|
||||
};
|
||||
expect(config.advanced.defaultCookieAttributes.domain).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -3,37 +3,227 @@ import { prismaAdapter } from "better-auth/adapters/prisma";
|
||||
import { genericOAuth } from "better-auth/plugins";
|
||||
import type { PrismaClient } from "@prisma/client";
|
||||
|
||||
/**
|
||||
* Required OIDC environment variables when OIDC is enabled
|
||||
*/
|
||||
const REQUIRED_OIDC_ENV_VARS = [
|
||||
"OIDC_ISSUER",
|
||||
"OIDC_CLIENT_ID",
|
||||
"OIDC_CLIENT_SECRET",
|
||||
"OIDC_REDIRECT_URI",
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Check if OIDC authentication is enabled via environment variable
|
||||
*/
|
||||
export function isOidcEnabled(): boolean {
|
||||
const enabled = process.env.OIDC_ENABLED;
|
||||
return enabled === "true" || enabled === "1";
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates OIDC configuration at startup.
|
||||
* Throws an error if OIDC is enabled but required environment variables are missing.
|
||||
*
|
||||
* @throws Error if OIDC is enabled but required vars are missing or empty
|
||||
*/
|
||||
export function validateOidcConfig(): void {
|
||||
if (!isOidcEnabled()) {
|
||||
// OIDC is disabled, no validation needed
|
||||
return;
|
||||
}
|
||||
|
||||
const missingVars: string[] = [];
|
||||
|
||||
for (const envVar of REQUIRED_OIDC_ENV_VARS) {
|
||||
const value = process.env[envVar];
|
||||
if (!value || value.trim() === "") {
|
||||
missingVars.push(envVar);
|
||||
}
|
||||
}
|
||||
|
||||
if (missingVars.length > 0) {
|
||||
throw new Error(
|
||||
`OIDC authentication is enabled (OIDC_ENABLED=true) but required environment variables are missing or empty: ${missingVars.join(", ")}. ` +
|
||||
`Either set these variables or disable OIDC by setting OIDC_ENABLED=false.`
|
||||
);
|
||||
}
|
||||
|
||||
// Additional validation: OIDC_ISSUER should end with a trailing slash for proper discovery URL
|
||||
const issuer = process.env.OIDC_ISSUER;
|
||||
if (issuer && !issuer.endsWith("/")) {
|
||||
throw new Error(
|
||||
`OIDC_ISSUER must end with a trailing slash (/). Current value: "${issuer}". ` +
|
||||
`The discovery URL is constructed by appending ".well-known/openid-configuration" to the issuer.`
|
||||
);
|
||||
}
|
||||
|
||||
// Additional validation: OIDC_REDIRECT_URI must be a valid URL with /auth/callback path
|
||||
validateRedirectUri();
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the OIDC_REDIRECT_URI environment variable.
|
||||
* - Must be a parseable URL
|
||||
* - Path must start with /auth/callback
|
||||
* - Warns (but does not throw) if using localhost in production
|
||||
*
|
||||
* @throws Error if URL is invalid or path does not start with /auth/callback
|
||||
*/
|
||||
function validateRedirectUri(): void {
|
||||
const redirectUri = process.env.OIDC_REDIRECT_URI;
|
||||
if (!redirectUri || redirectUri.trim() === "") {
|
||||
// Already caught by REQUIRED_OIDC_ENV_VARS check above
|
||||
return;
|
||||
}
|
||||
|
||||
let parsed: URL;
|
||||
try {
|
||||
parsed = new URL(redirectUri);
|
||||
} catch (urlError: unknown) {
|
||||
const detail = urlError instanceof Error ? urlError.message : String(urlError);
|
||||
throw new Error(
|
||||
`OIDC_REDIRECT_URI must be a valid URL. Current value: "${redirectUri}". ` +
|
||||
`Parse error: ${detail}. ` +
|
||||
`Example: "https://app.example.com/auth/callback/authentik".`
|
||||
);
|
||||
}
|
||||
|
||||
if (!parsed.pathname.startsWith("/auth/callback")) {
|
||||
throw new Error(
|
||||
`OIDC_REDIRECT_URI path must start with "/auth/callback". Current path: "${parsed.pathname}". ` +
|
||||
`Example: "https://app.example.com/auth/callback/authentik".`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
process.env.NODE_ENV === "production" &&
|
||||
(parsed.hostname === "localhost" || parsed.hostname === "127.0.0.1")
|
||||
) {
|
||||
console.warn(
|
||||
`[AUTH WARNING] OIDC_REDIRECT_URI uses localhost ("${redirectUri}") in production. ` +
|
||||
`This is likely a misconfiguration. Use a public domain for production deployments.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get OIDC plugins configuration.
|
||||
* Returns empty array if OIDC is disabled, otherwise returns configured OAuth plugin.
|
||||
*/
|
||||
function getOidcPlugins(): ReturnType<typeof genericOAuth>[] {
|
||||
if (!isOidcEnabled()) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const clientId = process.env.OIDC_CLIENT_ID;
|
||||
const clientSecret = process.env.OIDC_CLIENT_SECRET;
|
||||
const issuer = process.env.OIDC_ISSUER;
|
||||
|
||||
if (!clientId) {
|
||||
throw new Error("OIDC_CLIENT_ID is required when OIDC is enabled but was not set.");
|
||||
}
|
||||
if (!clientSecret) {
|
||||
throw new Error("OIDC_CLIENT_SECRET is required when OIDC is enabled but was not set.");
|
||||
}
|
||||
if (!issuer) {
|
||||
throw new Error("OIDC_ISSUER is required when OIDC is enabled but was not set.");
|
||||
}
|
||||
|
||||
return [
|
||||
genericOAuth({
|
||||
config: [
|
||||
{
|
||||
providerId: "authentik",
|
||||
clientId,
|
||||
clientSecret,
|
||||
discoveryUrl: `${issuer}.well-known/openid-configuration`,
|
||||
pkce: true,
|
||||
scopes: ["openid", "profile", "email"],
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the list of trusted origins from environment variables.
|
||||
*
|
||||
* Sources (in order):
|
||||
* - NEXT_PUBLIC_APP_URL — primary frontend URL
|
||||
* - NEXT_PUBLIC_API_URL — API's own origin
|
||||
* - TRUSTED_ORIGINS — comma-separated additional origins
|
||||
* - localhost fallbacks — only when NODE_ENV !== "production"
|
||||
*
|
||||
* The returned list is deduplicated and empty strings are filtered out.
|
||||
*/
|
||||
export function getTrustedOrigins(): string[] {
|
||||
const origins: string[] = [];
|
||||
|
||||
// Environment-driven origins
|
||||
if (process.env.NEXT_PUBLIC_APP_URL) {
|
||||
origins.push(process.env.NEXT_PUBLIC_APP_URL);
|
||||
}
|
||||
|
||||
if (process.env.NEXT_PUBLIC_API_URL) {
|
||||
origins.push(process.env.NEXT_PUBLIC_API_URL);
|
||||
}
|
||||
|
||||
// Comma-separated additional origins (validated)
|
||||
if (process.env.TRUSTED_ORIGINS) {
|
||||
const rawOrigins = process.env.TRUSTED_ORIGINS.split(",")
|
||||
.map((o) => o.trim())
|
||||
.filter((o) => o !== "");
|
||||
for (const origin of rawOrigins) {
|
||||
try {
|
||||
const parsed = new URL(origin);
|
||||
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
||||
console.warn(`[AUTH] Ignoring non-HTTP origin in TRUSTED_ORIGINS: "${origin}"`);
|
||||
continue;
|
||||
}
|
||||
origins.push(origin);
|
||||
} catch (urlError: unknown) {
|
||||
const detail = urlError instanceof Error ? urlError.message : String(urlError);
|
||||
console.warn(`[AUTH] Ignoring invalid URL in TRUSTED_ORIGINS: "${origin}" (${detail})`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Localhost fallbacks for development only
|
||||
if (process.env.NODE_ENV !== "production") {
|
||||
origins.push("http://localhost:3000", "http://localhost:3001");
|
||||
}
|
||||
|
||||
// Deduplicate and filter empty strings
|
||||
return [...new Set(origins)].filter((o) => o !== "");
|
||||
}
|
||||
|
||||
export function createAuth(prisma: PrismaClient) {
|
||||
// Validate OIDC configuration at startup - fail fast if misconfigured
|
||||
validateOidcConfig();
|
||||
|
||||
return betterAuth({
|
||||
basePath: "/auth",
|
||||
database: prismaAdapter(prisma, {
|
||||
provider: "postgresql",
|
||||
}),
|
||||
emailAndPassword: {
|
||||
enabled: true, // Enable for now, can be disabled later
|
||||
enabled: true,
|
||||
},
|
||||
plugins: [
|
||||
genericOAuth({
|
||||
config: [
|
||||
{
|
||||
providerId: "authentik",
|
||||
clientId: process.env.OIDC_CLIENT_ID ?? "",
|
||||
clientSecret: process.env.OIDC_CLIENT_SECRET ?? "",
|
||||
discoveryUrl: `${process.env.OIDC_ISSUER ?? ""}.well-known/openid-configuration`,
|
||||
scopes: ["openid", "profile", "email"],
|
||||
},
|
||||
],
|
||||
}),
|
||||
],
|
||||
plugins: [...getOidcPlugins()],
|
||||
session: {
|
||||
expiresIn: 60 * 60 * 24, // 24 hours
|
||||
updateAge: 60 * 60 * 24, // 24 hours
|
||||
expiresIn: 60 * 60 * 24 * 7, // 7 days absolute max
|
||||
updateAge: 60 * 60 * 2, // 2 hours — minimum session age before BetterAuth refreshes the expiry on next request
|
||||
},
|
||||
trustedOrigins: [
|
||||
process.env.NEXT_PUBLIC_APP_URL ?? "http://localhost:3000",
|
||||
"http://localhost:3001", // API origin (dev)
|
||||
"https://app.mosaicstack.dev", // Production web
|
||||
"https://api.mosaicstack.dev", // Production API
|
||||
],
|
||||
advanced: {
|
||||
defaultCookieAttributes: {
|
||||
httpOnly: true,
|
||||
secure: process.env.NODE_ENV === "production",
|
||||
sameSite: "lax" as const,
|
||||
...(process.env.COOKIE_DOMAIN ? { domain: process.env.COOKIE_DOMAIN } : {}),
|
||||
},
|
||||
},
|
||||
trustedOrigins: getTrustedOrigins(),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,15 +1,41 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
|
||||
// Mock better-auth modules before importing AuthService (pulled in by AuthController)
|
||||
vi.mock("better-auth/node", () => ({
|
||||
toNodeHandler: vi.fn().mockReturnValue(vi.fn()),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth", () => ({
|
||||
betterAuth: vi.fn().mockReturnValue({
|
||||
handler: vi.fn(),
|
||||
api: { getSession: vi.fn() },
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/adapters/prisma", () => ({
|
||||
prismaAdapter: vi.fn().mockReturnValue({}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/plugins", () => ({
|
||||
genericOAuth: vi.fn().mockReturnValue({ id: "generic-oauth" }),
|
||||
}));
|
||||
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
import { HttpException, HttpStatus, UnauthorizedException } from "@nestjs/common";
|
||||
import type { AuthUser, AuthSession } from "@mosaic/shared";
|
||||
import type { Request as ExpressRequest, Response as ExpressResponse } from "express";
|
||||
import { AuthController } from "./auth.controller";
|
||||
import { AuthService } from "./auth.service";
|
||||
|
||||
describe("AuthController", () => {
|
||||
let controller: AuthController;
|
||||
let authService: AuthService;
|
||||
|
||||
const mockNodeHandler = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const mockAuthService = {
|
||||
getAuth: vi.fn(),
|
||||
getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler),
|
||||
getAuthConfig: vi.fn(),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -24,30 +50,317 @@ describe("AuthController", () => {
|
||||
}).compile();
|
||||
|
||||
controller = module.get<AuthController>(AuthController);
|
||||
authService = module.get<AuthService>(AuthService);
|
||||
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Restore mock implementations after clearAllMocks
|
||||
mockAuthService.getNodeHandler.mockReturnValue(mockNodeHandler);
|
||||
mockNodeHandler.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
describe("handleAuth", () => {
|
||||
it("should call BetterAuth handler", async () => {
|
||||
const mockHandler = vi.fn().mockResolvedValue({ status: 200 });
|
||||
mockAuthService.getAuth.mockReturnValue({ handler: mockHandler });
|
||||
|
||||
it("should delegate to BetterAuth node handler with Express req/res", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/session",
|
||||
headers: {},
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(mockAuthService.getNodeHandler).toHaveBeenCalled();
|
||||
expect(mockNodeHandler).toHaveBeenCalledWith(mockRequest, mockResponse);
|
||||
});
|
||||
|
||||
it("should throw HttpException with 500 when handler throws before headers sent", async () => {
|
||||
const handlerError = new Error("BetterAuth internal failure");
|
||||
mockNodeHandler.mockRejectedValueOnce(handlerError);
|
||||
|
||||
const mockRequest = {
|
||||
method: "POST",
|
||||
url: "/auth/sign-in",
|
||||
headers: {},
|
||||
ip: "192.168.1.10",
|
||||
socket: { remoteAddress: "192.168.1.10" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
try {
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
// Should not reach here
|
||||
expect.unreachable("Expected HttpException to be thrown");
|
||||
} catch (err) {
|
||||
expect(err).toBeInstanceOf(HttpException);
|
||||
expect((err as HttpException).getStatus()).toBe(HttpStatus.INTERNAL_SERVER_ERROR);
|
||||
expect((err as HttpException).getResponse()).toBe(
|
||||
"Unable to complete authentication. Please try again in a moment.",
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("should log warning and not throw when handler throws after headers sent", async () => {
|
||||
const handlerError = new Error("Stream interrupted");
|
||||
mockNodeHandler.mockRejectedValueOnce(handlerError);
|
||||
|
||||
const mockRequest = {
|
||||
method: "POST",
|
||||
url: "/auth/sign-up",
|
||||
headers: {},
|
||||
ip: "10.0.0.5",
|
||||
socket: { remoteAddress: "10.0.0.5" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: true,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
// Should not throw when headers already sent
|
||||
await expect(controller.handleAuth(mockRequest, mockResponse)).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it("should handle non-Error thrown values", async () => {
|
||||
mockNodeHandler.mockRejectedValueOnce("string error");
|
||||
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/callback",
|
||||
headers: {},
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
await expect(controller.handleAuth(mockRequest, mockResponse)).rejects.toThrow(
|
||||
HttpException,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getConfig", () => {
|
||||
it("should return auth config from service", async () => {
|
||||
const mockConfig = {
|
||||
providers: [
|
||||
{ id: "email", name: "Email", type: "credentials" as const },
|
||||
{ id: "authentik", name: "Authentik", type: "oauth" as const },
|
||||
],
|
||||
};
|
||||
mockAuthService.getAuthConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const result = await controller.getConfig();
|
||||
|
||||
expect(result).toEqual(mockConfig);
|
||||
expect(mockAuthService.getAuthConfig).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should return correct response shape with only email provider", async () => {
|
||||
const mockConfig = {
|
||||
providers: [{ id: "email", name: "Email", type: "credentials" as const }],
|
||||
};
|
||||
mockAuthService.getAuthConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const result = await controller.getConfig();
|
||||
|
||||
expect(result).toEqual(mockConfig);
|
||||
expect(result.providers).toHaveLength(1);
|
||||
expect(result.providers[0]).toEqual({
|
||||
id: "email",
|
||||
name: "Email",
|
||||
type: "credentials",
|
||||
});
|
||||
});
|
||||
|
||||
it("should never leak secrets in auth config response", async () => {
|
||||
// Set ALL sensitive environment variables with known values
|
||||
const sensitiveEnv: Record<string, string> = {
|
||||
OIDC_CLIENT_SECRET: "test-client-secret",
|
||||
OIDC_CLIENT_ID: "test-client-id",
|
||||
OIDC_ISSUER: "https://auth.test.com/",
|
||||
OIDC_REDIRECT_URI: "https://app.test.com/auth/callback/authentik",
|
||||
BETTER_AUTH_SECRET: "test-better-auth-secret",
|
||||
JWT_SECRET: "test-jwt-secret",
|
||||
CSRF_SECRET: "test-csrf-secret",
|
||||
DATABASE_URL: "postgresql://user:password@localhost/db",
|
||||
OIDC_ENABLED: "true",
|
||||
};
|
||||
|
||||
await controller.handleAuth(mockRequest);
|
||||
const originalEnv: Record<string, string | undefined> = {};
|
||||
for (const [key, value] of Object.entries(sensitiveEnv)) {
|
||||
originalEnv[key] = process.env[key];
|
||||
process.env[key] = value;
|
||||
}
|
||||
|
||||
expect(mockAuthService.getAuth).toHaveBeenCalled();
|
||||
expect(mockHandler).toHaveBeenCalledWith(mockRequest);
|
||||
try {
|
||||
// Mock the service to return a realistic config with both providers
|
||||
const mockConfig = {
|
||||
providers: [
|
||||
{ id: "email", name: "Email", type: "credentials" as const },
|
||||
{ id: "authentik", name: "Authentik", type: "oauth" as const },
|
||||
],
|
||||
};
|
||||
mockAuthService.getAuthConfig.mockResolvedValue(mockConfig);
|
||||
|
||||
const result = await controller.getConfig();
|
||||
const serialized = JSON.stringify(result);
|
||||
|
||||
// Assert no secret values leak into the serialized response
|
||||
const forbiddenPatterns = [
|
||||
"test-client-secret",
|
||||
"test-client-id",
|
||||
"test-better-auth-secret",
|
||||
"test-jwt-secret",
|
||||
"test-csrf-secret",
|
||||
"auth.test.com",
|
||||
"callback",
|
||||
"password",
|
||||
];
|
||||
|
||||
for (const pattern of forbiddenPatterns) {
|
||||
expect(serialized).not.toContain(pattern);
|
||||
}
|
||||
|
||||
// Assert response contains ONLY expected fields
|
||||
expect(result).toHaveProperty("providers");
|
||||
expect(Object.keys(result)).toEqual(["providers"]);
|
||||
expect(Array.isArray(result.providers)).toBe(true);
|
||||
|
||||
for (const provider of result.providers) {
|
||||
const keys = Object.keys(provider);
|
||||
expect(keys).toEqual(expect.arrayContaining(["id", "name", "type"]));
|
||||
expect(keys).toHaveLength(3);
|
||||
}
|
||||
} finally {
|
||||
// Restore original environment
|
||||
for (const [key] of Object.entries(sensitiveEnv)) {
|
||||
if (originalEnv[key] === undefined) {
|
||||
delete process.env[key];
|
||||
} else {
|
||||
process.env[key] = originalEnv[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("getSession", () => {
|
||||
it("should return user and session data", () => {
|
||||
const mockUser: AuthUser = {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
workspaceId: "workspace-123",
|
||||
};
|
||||
|
||||
const mockSession = {
|
||||
id: "session-123",
|
||||
token: "session-token",
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
};
|
||||
|
||||
const mockRequest = {
|
||||
user: mockUser,
|
||||
session: mockSession,
|
||||
};
|
||||
|
||||
const result = controller.getSession(mockRequest);
|
||||
|
||||
const expected: AuthSession = {
|
||||
user: mockUser,
|
||||
session: {
|
||||
id: mockSession.id,
|
||||
token: mockSession.token,
|
||||
expiresAt: mockSession.expiresAt,
|
||||
},
|
||||
};
|
||||
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException when req.user is undefined", () => {
|
||||
const mockRequest = {
|
||||
session: {
|
||||
id: "session-123",
|
||||
token: "session-token",
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
},
|
||||
};
|
||||
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
UnauthorizedException,
|
||||
);
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
"Missing authentication context",
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException when req.session is undefined", () => {
|
||||
const mockRequest = {
|
||||
user: {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
UnauthorizedException,
|
||||
);
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
"Missing authentication context",
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException when both req.user and req.session are undefined", () => {
|
||||
const mockRequest = {};
|
||||
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
UnauthorizedException,
|
||||
);
|
||||
expect(() => controller.getSession(mockRequest as never)).toThrow(
|
||||
"Missing authentication context",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getProfile", () => {
|
||||
it("should return user profile", () => {
|
||||
it("should return complete user profile with workspace fields", () => {
|
||||
const mockUser: AuthUser = {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
image: "https://example.com/avatar.jpg",
|
||||
emailVerified: true,
|
||||
workspaceId: "workspace-123",
|
||||
currentWorkspaceId: "workspace-456",
|
||||
workspaceRole: "admin",
|
||||
};
|
||||
|
||||
const result = controller.getProfile(mockUser);
|
||||
|
||||
expect(result).toEqual({
|
||||
id: mockUser.id,
|
||||
email: mockUser.email,
|
||||
name: mockUser.name,
|
||||
image: mockUser.image,
|
||||
emailVerified: mockUser.emailVerified,
|
||||
workspaceId: mockUser.workspaceId,
|
||||
currentWorkspaceId: mockUser.currentWorkspaceId,
|
||||
workspaceRole: mockUser.workspaceRole,
|
||||
});
|
||||
});
|
||||
|
||||
it("should return user profile with optional fields undefined", () => {
|
||||
const mockUser: AuthUser = {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
@@ -60,7 +373,107 @@ describe("AuthController", () => {
|
||||
id: mockUser.id,
|
||||
email: mockUser.email,
|
||||
name: mockUser.name,
|
||||
image: undefined,
|
||||
emailVerified: undefined,
|
||||
workspaceId: undefined,
|
||||
currentWorkspaceId: undefined,
|
||||
workspaceRole: undefined,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("getClientIp (via handleAuth)", () => {
|
||||
it("should extract IP from X-Forwarded-For with single IP", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/callback",
|
||||
headers: { "x-forwarded-for": "203.0.113.50" },
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
// Spy on the logger to verify the extracted IP
|
||||
const debugSpy = vi.spyOn(controller["logger"], "debug");
|
||||
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(debugSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("203.0.113.50"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should extract first IP from X-Forwarded-For with comma-separated IPs", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/callback",
|
||||
headers: { "x-forwarded-for": "203.0.113.50, 70.41.3.18" },
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
const debugSpy = vi.spyOn(controller["logger"], "debug");
|
||||
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(debugSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("203.0.113.50"),
|
||||
);
|
||||
// Ensure it does NOT contain the second IP in the extracted position
|
||||
expect(debugSpy).toHaveBeenCalledWith(
|
||||
expect.not.stringContaining("70.41.3.18"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should extract first IP from X-Forwarded-For as array", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/callback",
|
||||
headers: { "x-forwarded-for": ["203.0.113.50", "70.41.3.18"] },
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
const debugSpy = vi.spyOn(controller["logger"], "debug");
|
||||
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(debugSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("203.0.113.50"),
|
||||
);
|
||||
});
|
||||
|
||||
it("should fallback to req.ip when no X-Forwarded-For header", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/callback",
|
||||
headers: {},
|
||||
ip: "192.168.1.100",
|
||||
socket: { remoteAddress: "192.168.1.100" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
const mockResponse = {
|
||||
headersSent: false,
|
||||
} as unknown as ExpressResponse;
|
||||
|
||||
const debugSpy = vi.spyOn(controller["logger"], "debug");
|
||||
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(debugSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("192.168.1.100"),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,26 +1,162 @@
|
||||
import { Controller, All, Req, Get, UseGuards } from "@nestjs/common";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
import {
|
||||
Controller,
|
||||
All,
|
||||
Req,
|
||||
Res,
|
||||
Get,
|
||||
Header,
|
||||
UseGuards,
|
||||
Request,
|
||||
Logger,
|
||||
HttpException,
|
||||
HttpStatus,
|
||||
UnauthorizedException,
|
||||
} from "@nestjs/common";
|
||||
import { Throttle } from "@nestjs/throttler";
|
||||
import type { Request as ExpressRequest, Response as ExpressResponse } from "express";
|
||||
import type { AuthUser, AuthSession, AuthConfigResponse } from "@mosaic/shared";
|
||||
import { AuthService } from "./auth.service";
|
||||
import { AuthGuard } from "./guards/auth.guard";
|
||||
import { CurrentUser } from "./decorators/current-user.decorator";
|
||||
import { SkipCsrf } from "../common/decorators/skip-csrf.decorator";
|
||||
import type { AuthenticatedRequest } from "./types/better-auth-request.interface";
|
||||
|
||||
@Controller("auth")
|
||||
export class AuthController {
|
||||
private readonly logger = new Logger(AuthController.name);
|
||||
|
||||
constructor(private readonly authService: AuthService) {}
|
||||
|
||||
/**
|
||||
* Get current session
|
||||
* Returns user and session data for authenticated user
|
||||
*/
|
||||
@Get("session")
|
||||
@UseGuards(AuthGuard)
|
||||
getSession(@Request() req: AuthenticatedRequest): AuthSession {
|
||||
// Defense-in-depth: AuthGuard should guarantee these, but if someone adds
|
||||
// a route with AuthenticatedRequest and forgets @UseGuards(AuthGuard),
|
||||
// TypeScript types won't help at runtime.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
|
||||
if (!req.user || !req.session) {
|
||||
throw new UnauthorizedException("Missing authentication context");
|
||||
}
|
||||
|
||||
return {
|
||||
user: req.user,
|
||||
session: {
|
||||
id: req.session.id,
|
||||
token: req.session.token,
|
||||
expiresAt: req.session.expiresAt,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current user profile
|
||||
* Returns basic user information
|
||||
*/
|
||||
@Get("profile")
|
||||
@UseGuards(AuthGuard)
|
||||
getProfile(@CurrentUser() user: AuthUser) {
|
||||
return {
|
||||
getProfile(@CurrentUser() user: AuthUser): AuthUser {
|
||||
// Return only defined properties to maintain type safety
|
||||
const profile: AuthUser = {
|
||||
id: user.id,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
};
|
||||
|
||||
if (user.image !== undefined) {
|
||||
profile.image = user.image;
|
||||
}
|
||||
if (user.emailVerified !== undefined) {
|
||||
profile.emailVerified = user.emailVerified;
|
||||
}
|
||||
if (user.workspaceId !== undefined) {
|
||||
profile.workspaceId = user.workspaceId;
|
||||
}
|
||||
if (user.currentWorkspaceId !== undefined) {
|
||||
profile.currentWorkspaceId = user.currentWorkspaceId;
|
||||
}
|
||||
if (user.workspaceRole !== undefined) {
|
||||
profile.workspaceRole = user.workspaceRole;
|
||||
}
|
||||
|
||||
return profile;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get available authentication providers.
|
||||
* Public endpoint (no auth guard) so the frontend can discover login options
|
||||
* before the user is authenticated.
|
||||
*/
|
||||
@Get("config")
|
||||
@Header("Cache-Control", "public, max-age=300")
|
||||
async getConfig(): Promise<AuthConfigResponse> {
|
||||
return this.authService.getAuthConfig();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle all other auth routes (sign-in, sign-up, sign-out, etc.)
|
||||
* Delegates to BetterAuth
|
||||
*
|
||||
* Rate limit: "strict" tier (10 req/min) - More restrictive than normal routes
|
||||
* to prevent brute-force attacks on auth endpoints
|
||||
*
|
||||
* Security note: This catch-all route bypasses standard guards that other routes have.
|
||||
* Rate limiting and logging are applied to mitigate abuse (SEC-API-10).
|
||||
*/
|
||||
@All("*")
|
||||
async handleAuth(@Req() req: Request) {
|
||||
const auth = this.authService.getAuth();
|
||||
return auth.handler(req);
|
||||
// BetterAuth handles CSRF internally (Fetch Metadata + SameSite=Lax cookies).
|
||||
// @SkipCsrf avoids double-protection conflicts.
|
||||
// See: https://www.better-auth.com/docs/reference/security
|
||||
@SkipCsrf()
|
||||
@Throttle({ strict: { limit: 10, ttl: 60000 } })
|
||||
async handleAuth(@Req() req: ExpressRequest, @Res() res: ExpressResponse): Promise<void> {
|
||||
// Extract client IP for logging
|
||||
const clientIp = this.getClientIp(req);
|
||||
|
||||
// Log auth catch-all hits for monitoring and debugging
|
||||
this.logger.debug(`Auth catch-all: ${req.method} ${req.url} from ${clientIp}`);
|
||||
|
||||
const handler = this.authService.getNodeHandler();
|
||||
|
||||
try {
|
||||
await handler(req, res);
|
||||
} catch (error: unknown) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const stack = error instanceof Error ? error.stack : undefined;
|
||||
|
||||
this.logger.error(
|
||||
`BetterAuth handler error: ${req.method} ${req.url} from ${clientIp} - ${message}`,
|
||||
stack
|
||||
);
|
||||
|
||||
if (!res.headersSent) {
|
||||
throw new HttpException(
|
||||
"Unable to complete authentication. Please try again in a moment.",
|
||||
HttpStatus.INTERNAL_SERVER_ERROR
|
||||
);
|
||||
}
|
||||
|
||||
this.logger.error(
|
||||
`Headers already sent for failed auth request ${req.method} ${req.url} — client may have received partial response`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract client IP from request, handling proxies
|
||||
*/
|
||||
private getClientIp(req: ExpressRequest): string {
|
||||
// Check X-Forwarded-For header (for reverse proxy setups)
|
||||
const forwardedFor = req.headers["x-forwarded-for"];
|
||||
if (forwardedFor) {
|
||||
const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor;
|
||||
return ips?.split(",")[0]?.trim() ?? "unknown";
|
||||
}
|
||||
|
||||
// Fall back to direct IP
|
||||
return req.ip ?? req.socket.remoteAddress ?? "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
213
apps/api/src/auth/auth.rate-limit.spec.ts
Normal file
213
apps/api/src/auth/auth.rate-limit.spec.ts
Normal file
@@ -0,0 +1,213 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { INestApplication, HttpStatus, Logger } from "@nestjs/common";
|
||||
import request from "supertest";
|
||||
import { AuthController } from "./auth.controller";
|
||||
import { AuthService } from "./auth.service";
|
||||
import { ThrottlerModule } from "@nestjs/throttler";
|
||||
import { APP_GUARD } from "@nestjs/core";
|
||||
import { ThrottlerApiKeyGuard } from "../common/throttler";
|
||||
|
||||
/**
|
||||
* Rate Limiting Tests for Auth Controller Catch-All Route
|
||||
*
|
||||
* These tests verify that rate limiting is properly enforced on the auth
|
||||
* catch-all route to prevent brute-force attacks (SEC-API-10).
|
||||
*
|
||||
* Test Coverage:
|
||||
* - Rate limit enforcement (429 status after 10 requests in 1 minute)
|
||||
* - Retry-After header inclusion
|
||||
* - Logging occurs for auth catch-all hits
|
||||
*/
|
||||
describe("AuthController - Rate Limiting", () => {
|
||||
let app: INestApplication;
|
||||
let loggerSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
const mockNodeHandler = vi.fn(
|
||||
(_req: unknown, res: { statusCode: number; end: (body: string) => void }) => {
|
||||
res.statusCode = 200;
|
||||
res.end(JSON.stringify({}));
|
||||
return Promise.resolve();
|
||||
}
|
||||
);
|
||||
|
||||
const mockAuthService = {
|
||||
getAuth: vi.fn(),
|
||||
getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Spy on Logger.prototype.debug to verify logging
|
||||
loggerSpy = vi.spyOn(Logger.prototype, "debug").mockImplementation(() => {});
|
||||
|
||||
const moduleFixture: TestingModule = await Test.createTestingModule({
|
||||
imports: [
|
||||
ThrottlerModule.forRoot([
|
||||
{
|
||||
ttl: 60000, // 1 minute
|
||||
limit: 10, // Match the "strict" tier limit
|
||||
},
|
||||
]),
|
||||
],
|
||||
controllers: [AuthController],
|
||||
providers: [
|
||||
{ provide: AuthService, useValue: mockAuthService },
|
||||
{
|
||||
provide: APP_GUARD,
|
||||
useClass: ThrottlerApiKeyGuard,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
app = moduleFixture.createNestApplication();
|
||||
await app.init();
|
||||
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await app.close();
|
||||
loggerSpy.mockRestore();
|
||||
});
|
||||
|
||||
describe("Auth Catch-All Route - Rate Limiting", () => {
|
||||
it("should allow requests within rate limit", async () => {
|
||||
// Make 3 requests (within limit of 10)
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const response = await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
// Should not be rate limited
|
||||
expect(response.status).not.toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
expect(mockAuthService.getNodeHandler).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it("should return 429 when rate limit is exceeded", async () => {
|
||||
// Exhaust rate limit (10 requests)
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
}
|
||||
|
||||
// The 11th request should be rate limited
|
||||
const response = await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
expect(response.status).toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
});
|
||||
|
||||
it("should include Retry-After header in 429 response", async () => {
|
||||
// Exhaust rate limit (10 requests)
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
}
|
||||
|
||||
// Get rate limited response
|
||||
const response = await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
expect(response.status).toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
expect(response.headers).toHaveProperty("retry-after");
|
||||
expect(parseInt(response.headers["retry-after"])).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it("should rate limit different auth endpoints under the same limit", async () => {
|
||||
// Make 5 sign-in requests
|
||||
for (let i = 0; i < 5; i++) {
|
||||
await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
}
|
||||
|
||||
// Make 5 sign-up requests (total now 10)
|
||||
for (let i = 0; i < 5; i++) {
|
||||
await request(app.getHttpServer()).post("/auth/sign-up").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
name: "Test User",
|
||||
});
|
||||
}
|
||||
|
||||
// The 11th request (any auth endpoint) should be rate limited
|
||||
const response = await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
expect(response.status).toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Auth Catch-All Route - Logging", () => {
|
||||
it("should log auth catch-all hits with request details", async () => {
|
||||
await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
// Verify logging was called
|
||||
expect(loggerSpy).toHaveBeenCalled();
|
||||
|
||||
// Find the log call that contains our expected message
|
||||
const logCalls = loggerSpy.mock.calls;
|
||||
const authLogCall = logCalls.find(
|
||||
(call) => typeof call[0] === "string" && call[0].includes("Auth catch-all:")
|
||||
);
|
||||
|
||||
expect(authLogCall).toBeDefined();
|
||||
expect(authLogCall?.[0]).toMatch(/Auth catch-all: POST/);
|
||||
});
|
||||
|
||||
it("should log different HTTP methods correctly", async () => {
|
||||
// Test GET request
|
||||
await request(app.getHttpServer()).get("/auth/callback");
|
||||
|
||||
const logCalls = loggerSpy.mock.calls;
|
||||
const getLogCall = logCalls.find(
|
||||
(call) =>
|
||||
typeof call[0] === "string" &&
|
||||
call[0].includes("Auth catch-all:") &&
|
||||
call[0].includes("GET")
|
||||
);
|
||||
|
||||
expect(getLogCall).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Per-IP Rate Limiting", () => {
|
||||
it("should track rate limits per IP independently", async () => {
|
||||
// Note: In a real scenario, different IPs would have different limits
|
||||
// This test verifies the rate limit tracking behavior
|
||||
|
||||
// Exhaust rate limit with requests
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
}
|
||||
|
||||
// Should be rate limited now
|
||||
const response = await request(app.getHttpServer()).post("/auth/sign-in").send({
|
||||
email: "test@example.com",
|
||||
password: "password",
|
||||
});
|
||||
|
||||
expect(response.status).toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,26 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
|
||||
// Mock better-auth modules before importing AuthService
|
||||
vi.mock("better-auth/node", () => ({
|
||||
toNodeHandler: vi.fn().mockReturnValue(vi.fn()),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth", () => ({
|
||||
betterAuth: vi.fn().mockReturnValue({
|
||||
handler: vi.fn(),
|
||||
api: { getSession: vi.fn() },
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/adapters/prisma", () => ({
|
||||
prismaAdapter: vi.fn().mockReturnValue({}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/plugins", () => ({
|
||||
genericOAuth: vi.fn().mockReturnValue({ id: "generic-oauth" }),
|
||||
}));
|
||||
|
||||
import { AuthService } from "./auth.service";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
|
||||
@@ -30,6 +51,12 @@ describe("AuthService", () => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
delete process.env.OIDC_ENABLED;
|
||||
delete process.env.OIDC_ISSUER;
|
||||
});
|
||||
|
||||
describe("getAuth", () => {
|
||||
it("should return BetterAuth instance", () => {
|
||||
const auth = service.getAuth();
|
||||
@@ -62,6 +89,23 @@ describe("AuthService", () => {
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should return null when user is not found", async () => {
|
||||
mockPrismaService.user.findUnique.mockResolvedValue(null);
|
||||
|
||||
const result = await service.getUserById("nonexistent-id");
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(mockPrismaService.user.findUnique).toHaveBeenCalledWith({
|
||||
where: { id: "nonexistent-id" },
|
||||
select: {
|
||||
id: true,
|
||||
email: true,
|
||||
name: true,
|
||||
authProviderId: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("getUserByEmail", () => {
|
||||
@@ -88,6 +132,269 @@ describe("AuthService", () => {
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should return null when user is not found", async () => {
|
||||
mockPrismaService.user.findUnique.mockResolvedValue(null);
|
||||
|
||||
const result = await service.getUserByEmail("unknown@example.com");
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(mockPrismaService.user.findUnique).toHaveBeenCalledWith({
|
||||
where: { email: "unknown@example.com" },
|
||||
select: {
|
||||
id: true,
|
||||
email: true,
|
||||
name: true,
|
||||
authProviderId: true,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("isOidcProviderReachable", () => {
|
||||
const discoveryUrl = "https://auth.example.com/.well-known/openid-configuration";
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
// Reset the cache by accessing private fields via bracket notation
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthResult = false;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).consecutiveHealthFailures = 0;
|
||||
});
|
||||
|
||||
it("should return true when discovery URL returns 200", async () => {
|
||||
const mockFetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
});
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.isOidcProviderReachable();
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockFetch).toHaveBeenCalledWith(discoveryUrl, {
|
||||
signal: expect.any(AbortSignal) as AbortSignal,
|
||||
});
|
||||
});
|
||||
|
||||
it("should return false on network error", async () => {
|
||||
const mockFetch = vi.fn().mockRejectedValue(new Error("ECONNREFUSED"));
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.isOidcProviderReachable();
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false on timeout", async () => {
|
||||
const mockFetch = vi.fn().mockRejectedValue(new DOMException("The operation was aborted"));
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.isOidcProviderReachable();
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false when discovery URL returns non-200", async () => {
|
||||
const mockFetch = vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 503,
|
||||
});
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.isOidcProviderReachable();
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("should cache result for 30 seconds", async () => {
|
||||
const mockFetch = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
status: 200,
|
||||
});
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
// First call - fetches
|
||||
const result1 = await service.isOidcProviderReachable();
|
||||
expect(result1).toBe(true);
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Second call within 30s - uses cache
|
||||
const result2 = await service.isOidcProviderReachable();
|
||||
expect(result2).toBe(true);
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1); // Still 1, no new fetch
|
||||
|
||||
// Simulate cache expiry by moving lastHealthCheck back
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = Date.now() - 31_000;
|
||||
|
||||
// Third call after cache expiry - fetches again
|
||||
const result3 = await service.isOidcProviderReachable();
|
||||
expect(result3).toBe(true);
|
||||
expect(mockFetch).toHaveBeenCalledTimes(2); // Now 2
|
||||
});
|
||||
|
||||
it("should cache false results too", async () => {
|
||||
const mockFetch = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error("ECONNREFUSED"))
|
||||
.mockResolvedValueOnce({ ok: true, status: 200 });
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
// First call - fails
|
||||
const result1 = await service.isOidcProviderReachable();
|
||||
expect(result1).toBe(false);
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Second call within 30s - returns cached false
|
||||
const result2 = await service.isOidcProviderReachable();
|
||||
expect(result2).toBe(false);
|
||||
expect(mockFetch).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should escalate to error level after 3 consecutive failures", async () => {
|
||||
const mockFetch = vi.fn().mockRejectedValue(new Error("ECONNREFUSED"));
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const loggerWarn = vi.spyOn(service["logger"], "warn");
|
||||
const loggerError = vi.spyOn(service["logger"], "error");
|
||||
|
||||
// Failures 1 and 2 should log at warn level
|
||||
await service.isOidcProviderReachable();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0; // Reset cache
|
||||
await service.isOidcProviderReachable();
|
||||
|
||||
expect(loggerWarn).toHaveBeenCalledTimes(2);
|
||||
expect(loggerError).not.toHaveBeenCalled();
|
||||
|
||||
// Failure 3 should escalate to error level
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
await service.isOidcProviderReachable();
|
||||
|
||||
expect(loggerError).toHaveBeenCalledTimes(1);
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
expect.stringContaining("OIDC provider unreachable")
|
||||
);
|
||||
});
|
||||
|
||||
it("should escalate to error level after 3 consecutive non-OK responses", async () => {
|
||||
const mockFetch = vi.fn().mockResolvedValue({ ok: false, status: 503 });
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const loggerWarn = vi.spyOn(service["logger"], "warn");
|
||||
const loggerError = vi.spyOn(service["logger"], "error");
|
||||
|
||||
// Failures 1 and 2 at warn level
|
||||
await service.isOidcProviderReachable();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
await service.isOidcProviderReachable();
|
||||
|
||||
expect(loggerWarn).toHaveBeenCalledTimes(2);
|
||||
expect(loggerError).not.toHaveBeenCalled();
|
||||
|
||||
// Failure 3 at error level
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
await service.isOidcProviderReachable();
|
||||
|
||||
expect(loggerError).toHaveBeenCalledTimes(1);
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
expect.stringContaining("OIDC provider returned non-OK status")
|
||||
);
|
||||
});
|
||||
|
||||
it("should reset failure counter and log recovery on success after failures", async () => {
|
||||
const mockFetch = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(new Error("ECONNREFUSED"))
|
||||
.mockRejectedValueOnce(new Error("ECONNREFUSED"))
|
||||
.mockResolvedValueOnce({ ok: true, status: 200 });
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const loggerLog = vi.spyOn(service["logger"], "log");
|
||||
|
||||
// Two failures
|
||||
await service.isOidcProviderReachable();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
await service.isOidcProviderReachable();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
|
||||
// Recovery
|
||||
const result = await service.isOidcProviderReachable();
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(loggerLog).toHaveBeenCalledWith(
|
||||
expect.stringContaining("OIDC provider recovered after 2 consecutive failure(s)")
|
||||
);
|
||||
// Verify counter reset
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((service as any).consecutiveHealthFailures).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getAuthConfig", () => {
|
||||
it("should return only email provider when OIDC is disabled", async () => {
|
||||
delete process.env.OIDC_ENABLED;
|
||||
|
||||
const result = await service.getAuthConfig();
|
||||
|
||||
expect(result).toEqual({
|
||||
providers: [{ id: "email", name: "Email", type: "credentials" }],
|
||||
});
|
||||
});
|
||||
|
||||
it("should return both email and authentik providers when OIDC is enabled and reachable", async () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
|
||||
const mockFetch = vi.fn().mockResolvedValue({ ok: true, status: 200 });
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.getAuthConfig();
|
||||
|
||||
expect(result).toEqual({
|
||||
providers: [
|
||||
{ id: "email", name: "Email", type: "credentials" },
|
||||
{ id: "authentik", name: "Authentik", type: "oauth" },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it("should return only email provider when OIDC_ENABLED is false", async () => {
|
||||
process.env.OIDC_ENABLED = "false";
|
||||
|
||||
const result = await service.getAuthConfig();
|
||||
|
||||
expect(result).toEqual({
|
||||
providers: [{ id: "email", name: "Email", type: "credentials" }],
|
||||
});
|
||||
});
|
||||
|
||||
it("should omit authentik when OIDC is enabled but provider is unreachable", async () => {
|
||||
process.env.OIDC_ENABLED = "true";
|
||||
process.env.OIDC_ISSUER = "https://auth.example.com/";
|
||||
|
||||
// Reset cache
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(service as any).lastHealthCheck = 0;
|
||||
|
||||
const mockFetch = vi.fn().mockRejectedValue(new Error("ECONNREFUSED"));
|
||||
vi.stubGlobal("fetch", mockFetch);
|
||||
|
||||
const result = await service.getAuthConfig();
|
||||
|
||||
expect(result).toEqual({
|
||||
providers: [{ id: "email", name: "Email", type: "credentials" }],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("verifySession", () => {
|
||||
@@ -128,14 +435,268 @@ describe("AuthService", () => {
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null and log error on verification failure", async () => {
|
||||
it("should return null for 'invalid token' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Invalid token provided"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("bad-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for 'expired' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Token expired"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("expired-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for 'session not found' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Session not found"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("missing-session");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for 'unauthorized' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Unauthorized"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("unauth-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for 'invalid session' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Invalid session"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("invalid-session");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for 'session expired' auth error", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Session expired"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("expired-session");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for bare 'unauthorized' (exact match)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("unauthorized"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("unauth-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for bare 'expired' (exact match)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("expired"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("expired-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should re-throw 'certificate has expired' as infrastructure error (not auth)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error("certificate has expired"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow(
|
||||
"certificate has expired"
|
||||
);
|
||||
});
|
||||
|
||||
it("should re-throw 'Unauthorized: Access denied for user' as infrastructure error (not auth)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error("Unauthorized: Access denied for user"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow(
|
||||
"Unauthorized: Access denied for user"
|
||||
);
|
||||
});
|
||||
|
||||
it("should return null when a non-Error value is thrown", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue("string-error");
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("any-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null when getSession throws a non-Error value (string)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue("some error");
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("any-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null when getSession throws a non-Error value (object)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue({ code: "ERR_UNKNOWN" });
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("any-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should re-throw unexpected errors that are not known auth errors", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Verification failed"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const result = await service.verifySession("error-token");
|
||||
await expect(service.verifySession("error-token")).rejects.toThrow("Verification failed");
|
||||
});
|
||||
|
||||
it("should re-throw Prisma infrastructure errors", async () => {
|
||||
const auth = service.getAuth();
|
||||
const prismaError = new Error("connect ECONNREFUSED 127.0.0.1:5432");
|
||||
const mockGetSession = vi.fn().mockRejectedValue(prismaError);
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow("ECONNREFUSED");
|
||||
});
|
||||
|
||||
it("should re-throw timeout errors as infrastructure errors", async () => {
|
||||
const auth = service.getAuth();
|
||||
const timeoutError = new Error("Connection timeout after 5000ms");
|
||||
const mockGetSession = vi.fn().mockRejectedValue(timeoutError);
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow("timeout");
|
||||
});
|
||||
|
||||
it("should re-throw errors with Prisma-prefixed constructor name", async () => {
|
||||
const auth = service.getAuth();
|
||||
class PrismaClientKnownRequestError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = "PrismaClientKnownRequestError";
|
||||
}
|
||||
}
|
||||
const prismaError = new PrismaClientKnownRequestError("Database connection lost");
|
||||
const mockGetSession = vi.fn().mockRejectedValue(prismaError);
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow("Database connection lost");
|
||||
});
|
||||
|
||||
it("should redact Bearer tokens from logged error messages", async () => {
|
||||
const auth = service.getAuth();
|
||||
const errorWithToken = new Error(
|
||||
"Request failed: Bearer eyJhbGciOiJIUzI1NiJ9.secret-payload in header"
|
||||
);
|
||||
const mockGetSession = vi.fn().mockRejectedValue(errorWithToken);
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const loggerError = vi.spyOn(service["logger"], "error");
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow();
|
||||
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
"Session verification failed due to unexpected error",
|
||||
expect.stringContaining("Bearer [REDACTED]")
|
||||
);
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
"Session verification failed due to unexpected error",
|
||||
expect.not.stringContaining("eyJhbGciOiJIUzI1NiJ9")
|
||||
);
|
||||
});
|
||||
|
||||
it("should redact Bearer tokens from error stack traces", async () => {
|
||||
const auth = service.getAuth();
|
||||
const errorWithToken = new Error("Something went wrong");
|
||||
errorWithToken.stack =
|
||||
"Error: Something went wrong\n at fetch (Bearer abc123-secret-token)\n at verifySession";
|
||||
const mockGetSession = vi.fn().mockRejectedValue(errorWithToken);
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const loggerError = vi.spyOn(service["logger"], "error");
|
||||
|
||||
await expect(service.verifySession("any-token")).rejects.toThrow();
|
||||
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
"Session verification failed due to unexpected error",
|
||||
expect.stringContaining("Bearer [REDACTED]")
|
||||
);
|
||||
expect(loggerError).toHaveBeenCalledWith(
|
||||
"Session verification failed due to unexpected error",
|
||||
expect.not.stringContaining("abc123-secret-token")
|
||||
);
|
||||
});
|
||||
|
||||
it("should warn when a non-Error string value is thrown", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue("string-error");
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const loggerWarn = vi.spyOn(service["logger"], "warn");
|
||||
|
||||
const result = await service.verifySession("any-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(loggerWarn).toHaveBeenCalledWith(
|
||||
"Session verification received non-Error thrown value",
|
||||
"string-error"
|
||||
);
|
||||
});
|
||||
|
||||
it("should warn with JSON when a non-Error object is thrown", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue({ code: "ERR_UNKNOWN" });
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const loggerWarn = vi.spyOn(service["logger"], "warn");
|
||||
|
||||
const result = await service.verifySession("any-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(loggerWarn).toHaveBeenCalledWith(
|
||||
"Session verification received non-Error thrown value",
|
||||
JSON.stringify({ code: "ERR_UNKNOWN" })
|
||||
);
|
||||
});
|
||||
|
||||
it("should not warn for expected auth errors (Error instances)", async () => {
|
||||
const auth = service.getAuth();
|
||||
const mockGetSession = vi.fn().mockRejectedValue(new Error("Invalid token provided"));
|
||||
auth.api = { getSession: mockGetSession } as any;
|
||||
|
||||
const loggerWarn = vi.spyOn(service["logger"], "warn");
|
||||
|
||||
const result = await service.verifySession("bad-token");
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(loggerWarn).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,17 +1,45 @@
|
||||
import { Injectable, Logger } from "@nestjs/common";
|
||||
import type { PrismaClient } from "@prisma/client";
|
||||
import type { IncomingMessage, ServerResponse } from "http";
|
||||
import { toNodeHandler } from "better-auth/node";
|
||||
import type { AuthConfigResponse, AuthProviderConfig } from "@mosaic/shared";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { createAuth, type Auth } from "./auth.config";
|
||||
import { createAuth, isOidcEnabled, type Auth } from "./auth.config";
|
||||
|
||||
/** Duration in milliseconds to cache the OIDC health check result */
|
||||
const OIDC_HEALTH_CACHE_TTL_MS = 30_000;
|
||||
|
||||
/** Timeout in milliseconds for the OIDC discovery URL fetch */
|
||||
const OIDC_HEALTH_TIMEOUT_MS = 2_000;
|
||||
|
||||
/** Number of consecutive health-check failures before escalating to error level */
|
||||
const HEALTH_ESCALATION_THRESHOLD = 3;
|
||||
|
||||
/** Verified session shape returned by BetterAuth's getSession */
|
||||
interface VerifiedSession {
|
||||
user: Record<string, unknown>;
|
||||
session: Record<string, unknown>;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class AuthService {
|
||||
private readonly logger = new Logger(AuthService.name);
|
||||
private readonly auth: Auth;
|
||||
private readonly nodeHandler: (req: IncomingMessage, res: ServerResponse) => Promise<void>;
|
||||
|
||||
/** Timestamp of the last OIDC health check */
|
||||
private lastHealthCheck = 0;
|
||||
/** Cached result of the last OIDC health check */
|
||||
private lastHealthResult = false;
|
||||
/** Consecutive OIDC health check failure count for log-level escalation */
|
||||
private consecutiveHealthFailures = 0;
|
||||
|
||||
constructor(private readonly prisma: PrismaService) {
|
||||
// PrismaService extends PrismaClient and is compatible with BetterAuth's adapter
|
||||
// Cast is safe as PrismaService provides all required PrismaClient methods
|
||||
// TODO(#411): BetterAuth returns opaque types — replace when upstream exports typed interfaces
|
||||
this.auth = createAuth(this.prisma as unknown as PrismaClient);
|
||||
this.nodeHandler = toNodeHandler(this.auth);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -21,6 +49,14 @@ export class AuthService {
|
||||
return this.auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Node.js-compatible request handler for BetterAuth.
|
||||
* Wraps BetterAuth's Web API handler to work with Express/Node.js req/res.
|
||||
*/
|
||||
getNodeHandler(): (req: IncomingMessage, res: ServerResponse) => Promise<void> {
|
||||
return this.nodeHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user by ID
|
||||
*/
|
||||
@@ -63,12 +99,12 @@ export class AuthService {
|
||||
|
||||
/**
|
||||
* Verify session token
|
||||
* Returns session data if valid, null if invalid or expired
|
||||
* Returns session data if valid, null if invalid or expired.
|
||||
* Only known-safe auth errors return null; everything else propagates as 500.
|
||||
*/
|
||||
async verifySession(
|
||||
token: string
|
||||
): Promise<{ user: Record<string, unknown>; session: Record<string, unknown> } | null> {
|
||||
async verifySession(token: string): Promise<VerifiedSession | null> {
|
||||
try {
|
||||
// TODO(#411): BetterAuth getSession returns opaque types — replace when upstream exports typed interfaces
|
||||
const session = await this.auth.api.getSession({
|
||||
headers: {
|
||||
authorization: `Bearer ${token}`,
|
||||
@@ -83,12 +119,107 @@ export class AuthService {
|
||||
user: session.user as Record<string, unknown>,
|
||||
session: session.session as Record<string, unknown>,
|
||||
};
|
||||
} catch (error) {
|
||||
this.logger.error(
|
||||
"Session verification failed",
|
||||
error instanceof Error ? error.message : "Unknown error"
|
||||
);
|
||||
} catch (error: unknown) {
|
||||
// Only known-safe auth errors return null
|
||||
if (error instanceof Error) {
|
||||
const msg = error.message.toLowerCase();
|
||||
const isExpectedAuthError =
|
||||
msg.includes("invalid token") ||
|
||||
msg.includes("token expired") ||
|
||||
msg.includes("session expired") ||
|
||||
msg.includes("session not found") ||
|
||||
msg.includes("invalid session") ||
|
||||
msg === "unauthorized" ||
|
||||
msg === "expired";
|
||||
|
||||
if (!isExpectedAuthError) {
|
||||
// Infrastructure or unexpected — propagate as 500
|
||||
const safeMessage = (error.stack ?? error.message).replace(
|
||||
/Bearer\s+\S+/gi,
|
||||
"Bearer [REDACTED]"
|
||||
);
|
||||
this.logger.error("Session verification failed due to unexpected error", safeMessage);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
// Non-Error thrown values — log for observability, treat as auth failure
|
||||
if (!(error instanceof Error)) {
|
||||
const errorDetail = typeof error === "string" ? error : JSON.stringify(error);
|
||||
this.logger.warn("Session verification received non-Error thrown value", errorDetail);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the OIDC provider (Authentik) is reachable by fetching the discovery URL.
|
||||
* Results are cached for 30 seconds to prevent repeated network calls.
|
||||
*
|
||||
* @returns true if the provider responds with an HTTP 2xx status, false otherwise
|
||||
*/
|
||||
async isOidcProviderReachable(): Promise<boolean> {
|
||||
const now = Date.now();
|
||||
|
||||
// Return cached result if still valid
|
||||
if (now - this.lastHealthCheck < OIDC_HEALTH_CACHE_TTL_MS) {
|
||||
this.logger.debug("OIDC health check: returning cached result");
|
||||
return this.lastHealthResult;
|
||||
}
|
||||
|
||||
const discoveryUrl = `${process.env.OIDC_ISSUER ?? ""}.well-known/openid-configuration`;
|
||||
this.logger.debug(`OIDC health check: fetching ${discoveryUrl}`);
|
||||
|
||||
try {
|
||||
const response = await fetch(discoveryUrl, {
|
||||
signal: AbortSignal.timeout(OIDC_HEALTH_TIMEOUT_MS),
|
||||
});
|
||||
|
||||
this.lastHealthCheck = Date.now();
|
||||
this.lastHealthResult = response.ok;
|
||||
|
||||
if (response.ok) {
|
||||
if (this.consecutiveHealthFailures > 0) {
|
||||
this.logger.log(
|
||||
`OIDC provider recovered after ${String(this.consecutiveHealthFailures)} consecutive failure(s)`
|
||||
);
|
||||
}
|
||||
this.consecutiveHealthFailures = 0;
|
||||
} else {
|
||||
this.consecutiveHealthFailures++;
|
||||
const logLevel =
|
||||
this.consecutiveHealthFailures >= HEALTH_ESCALATION_THRESHOLD ? "error" : "warn";
|
||||
this.logger[logLevel](
|
||||
`OIDC provider returned non-OK status: ${String(response.status)} from ${discoveryUrl}`
|
||||
);
|
||||
}
|
||||
|
||||
return this.lastHealthResult;
|
||||
} catch (error: unknown) {
|
||||
this.lastHealthCheck = Date.now();
|
||||
this.lastHealthResult = false;
|
||||
this.consecutiveHealthFailures++;
|
||||
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const logLevel =
|
||||
this.consecutiveHealthFailures >= HEALTH_ESCALATION_THRESHOLD ? "error" : "warn";
|
||||
this.logger[logLevel](`OIDC provider unreachable at ${discoveryUrl}: ${message}`);
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get authentication configuration for the frontend.
|
||||
* Returns available auth providers so the UI can render login options dynamically.
|
||||
* When OIDC is enabled, performs a health check to verify the provider is reachable.
|
||||
*/
|
||||
async getAuthConfig(): Promise<AuthConfigResponse> {
|
||||
const providers: AuthProviderConfig[] = [{ id: "email", name: "Email", type: "credentials" }];
|
||||
|
||||
if (isOidcEnabled() && (await this.isOidcProviderReachable())) {
|
||||
providers.push({ id: "authentik", name: "Authentik", type: "oauth" });
|
||||
}
|
||||
|
||||
return { providers };
|
||||
}
|
||||
}
|
||||
|
||||
96
apps/api/src/auth/decorators/current-user.decorator.spec.ts
Normal file
96
apps/api/src/auth/decorators/current-user.decorator.spec.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { ExecutionContext, UnauthorizedException } from "@nestjs/common";
|
||||
import { ROUTE_ARGS_METADATA } from "@nestjs/common/constants";
|
||||
import { CurrentUser } from "./current-user.decorator";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
|
||||
/**
|
||||
* Extract the factory function from a NestJS param decorator created with createParamDecorator.
|
||||
* NestJS stores param decorator factories in metadata on a dummy class.
|
||||
*/
|
||||
function getParamDecoratorFactory(): (data: unknown, ctx: ExecutionContext) => AuthUser {
|
||||
class TestController {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
testMethod(@CurrentUser() _user: AuthUser): void {
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
const metadata = Reflect.getMetadata(ROUTE_ARGS_METADATA, TestController, "testMethod");
|
||||
|
||||
// The metadata keys are in the format "paramtype:index"
|
||||
const key = Object.keys(metadata)[0];
|
||||
return metadata[key].factory;
|
||||
}
|
||||
|
||||
function createMockExecutionContext(user?: AuthUser): ExecutionContext {
|
||||
const mockRequest = {
|
||||
...(user !== undefined ? { user } : {}),
|
||||
};
|
||||
|
||||
return {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => mockRequest,
|
||||
}),
|
||||
} as ExecutionContext;
|
||||
}
|
||||
|
||||
describe("CurrentUser decorator", () => {
|
||||
const factory = getParamDecoratorFactory();
|
||||
|
||||
const mockUser: AuthUser = {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
};
|
||||
|
||||
it("should return the user when present on the request", () => {
|
||||
const ctx = createMockExecutionContext(mockUser);
|
||||
const result = factory(undefined, ctx);
|
||||
|
||||
expect(result).toEqual(mockUser);
|
||||
});
|
||||
|
||||
it("should return the user with optional fields", () => {
|
||||
const userWithOptionalFields: AuthUser = {
|
||||
...mockUser,
|
||||
image: "https://example.com/avatar.png",
|
||||
workspaceId: "ws-123",
|
||||
workspaceRole: "owner",
|
||||
};
|
||||
|
||||
const ctx = createMockExecutionContext(userWithOptionalFields);
|
||||
const result = factory(undefined, ctx);
|
||||
|
||||
expect(result).toEqual(userWithOptionalFields);
|
||||
expect(result.image).toBe("https://example.com/avatar.png");
|
||||
expect(result.workspaceId).toBe("ws-123");
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException when user is undefined", () => {
|
||||
const ctx = createMockExecutionContext(undefined);
|
||||
|
||||
expect(() => factory(undefined, ctx)).toThrow(UnauthorizedException);
|
||||
expect(() => factory(undefined, ctx)).toThrow("No authenticated user found on request");
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException when request has no user property", () => {
|
||||
// Request object without a user property at all
|
||||
const ctx = {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => ({}),
|
||||
}),
|
||||
} as ExecutionContext;
|
||||
|
||||
expect(() => factory(undefined, ctx)).toThrow(UnauthorizedException);
|
||||
});
|
||||
|
||||
it("should ignore the data parameter", () => {
|
||||
const ctx = createMockExecutionContext(mockUser);
|
||||
|
||||
// The decorator doesn't use the data parameter, but ensure it doesn't break
|
||||
const result = factory("some-data", ctx);
|
||||
|
||||
expect(result).toEqual(mockUser);
|
||||
});
|
||||
});
|
||||
@@ -1,10 +1,16 @@
|
||||
import type { ExecutionContext } from "@nestjs/common";
|
||||
import { createParamDecorator } from "@nestjs/common";
|
||||
import type { AuthenticatedRequest, AuthenticatedUser } from "../../common/types/user.types";
|
||||
import { createParamDecorator, UnauthorizedException } from "@nestjs/common";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
import type { MaybeAuthenticatedRequest } from "../types/better-auth-request.interface";
|
||||
|
||||
export const CurrentUser = createParamDecorator(
|
||||
(_data: unknown, ctx: ExecutionContext): AuthenticatedUser | undefined => {
|
||||
const request = ctx.switchToHttp().getRequest<AuthenticatedRequest>();
|
||||
(_data: unknown, ctx: ExecutionContext): AuthUser => {
|
||||
// Use MaybeAuthenticatedRequest because the decorator doesn't know
|
||||
// whether AuthGuard ran — the null check provides defense-in-depth.
|
||||
const request = ctx.switchToHttp().getRequest<MaybeAuthenticatedRequest>();
|
||||
if (!request.user) {
|
||||
throw new UnauthorizedException("No authenticated user found on request");
|
||||
}
|
||||
return request.user;
|
||||
}
|
||||
);
|
||||
|
||||
170
apps/api/src/auth/guards/admin.guard.spec.ts
Normal file
170
apps/api/src/auth/guards/admin.guard.spec.ts
Normal file
@@ -0,0 +1,170 @@
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { ExecutionContext, ForbiddenException } from "@nestjs/common";
|
||||
import { AdminGuard } from "./admin.guard";
|
||||
|
||||
describe("AdminGuard", () => {
|
||||
const originalEnv = process.env.SYSTEM_ADMIN_IDS;
|
||||
|
||||
afterEach(() => {
|
||||
// Restore original environment
|
||||
if (originalEnv !== undefined) {
|
||||
process.env.SYSTEM_ADMIN_IDS = originalEnv;
|
||||
} else {
|
||||
delete process.env.SYSTEM_ADMIN_IDS;
|
||||
}
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
const createMockExecutionContext = (user: { id: string } | undefined): ExecutionContext => {
|
||||
const mockRequest = {
|
||||
user,
|
||||
};
|
||||
|
||||
return {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => mockRequest,
|
||||
}),
|
||||
} as ExecutionContext;
|
||||
};
|
||||
|
||||
describe("constructor", () => {
|
||||
it("should parse system admin IDs from environment variable", () => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "admin-1,admin-2,admin-3";
|
||||
const guard = new AdminGuard();
|
||||
|
||||
expect(guard.isSystemAdmin("admin-1")).toBe(true);
|
||||
expect(guard.isSystemAdmin("admin-2")).toBe(true);
|
||||
expect(guard.isSystemAdmin("admin-3")).toBe(true);
|
||||
});
|
||||
|
||||
it("should handle whitespace in admin IDs", () => {
|
||||
process.env.SYSTEM_ADMIN_IDS = " admin-1 , admin-2 , admin-3 ";
|
||||
const guard = new AdminGuard();
|
||||
|
||||
expect(guard.isSystemAdmin("admin-1")).toBe(true);
|
||||
expect(guard.isSystemAdmin("admin-2")).toBe(true);
|
||||
expect(guard.isSystemAdmin("admin-3")).toBe(true);
|
||||
});
|
||||
|
||||
it("should handle empty environment variable", () => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "";
|
||||
const guard = new AdminGuard();
|
||||
|
||||
expect(guard.isSystemAdmin("any-user")).toBe(false);
|
||||
});
|
||||
|
||||
it("should handle missing environment variable", () => {
|
||||
delete process.env.SYSTEM_ADMIN_IDS;
|
||||
const guard = new AdminGuard();
|
||||
|
||||
expect(guard.isSystemAdmin("any-user")).toBe(false);
|
||||
});
|
||||
|
||||
it("should handle single admin ID", () => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "single-admin";
|
||||
const guard = new AdminGuard();
|
||||
|
||||
expect(guard.isSystemAdmin("single-admin")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isSystemAdmin", () => {
|
||||
let guard: AdminGuard;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "admin-uuid-1,admin-uuid-2";
|
||||
guard = new AdminGuard();
|
||||
});
|
||||
|
||||
it("should return true for configured system admin", () => {
|
||||
expect(guard.isSystemAdmin("admin-uuid-1")).toBe(true);
|
||||
expect(guard.isSystemAdmin("admin-uuid-2")).toBe(true);
|
||||
});
|
||||
|
||||
it("should return false for non-admin user", () => {
|
||||
expect(guard.isSystemAdmin("regular-user-id")).toBe(false);
|
||||
});
|
||||
|
||||
it("should return false for empty string", () => {
|
||||
expect(guard.isSystemAdmin("")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("canActivate", () => {
|
||||
let guard: AdminGuard;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "admin-uuid-1,admin-uuid-2";
|
||||
guard = new AdminGuard();
|
||||
});
|
||||
|
||||
it("should return true for system admin user", () => {
|
||||
const context = createMockExecutionContext({ id: "admin-uuid-1" });
|
||||
|
||||
const result = guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it("should throw ForbiddenException for non-admin user", () => {
|
||||
const context = createMockExecutionContext({ id: "regular-user-id" });
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow(
|
||||
"This operation requires system administrator privileges"
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw ForbiddenException when user is not authenticated", () => {
|
||||
const context = createMockExecutionContext(undefined);
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("User not authenticated");
|
||||
});
|
||||
|
||||
it("should NOT grant admin access based on workspace ownership", () => {
|
||||
// This test verifies that workspace ownership alone does not grant admin access
|
||||
// The user must be explicitly listed in SYSTEM_ADMIN_IDS
|
||||
const workspaceOwnerButNotSystemAdmin = { id: "workspace-owner-id" };
|
||||
const context = createMockExecutionContext(workspaceOwnerButNotSystemAdmin);
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow(
|
||||
"This operation requires system administrator privileges"
|
||||
);
|
||||
});
|
||||
|
||||
it("should deny access when no system admins are configured", () => {
|
||||
process.env.SYSTEM_ADMIN_IDS = "";
|
||||
const guardWithNoAdmins = new AdminGuard();
|
||||
|
||||
const context = createMockExecutionContext({ id: "any-user-id" });
|
||||
|
||||
expect(() => guardWithNoAdmins.canActivate(context)).toThrow(ForbiddenException);
|
||||
});
|
||||
});
|
||||
|
||||
describe("security: workspace ownership vs system admin", () => {
|
||||
it("should require explicit system admin configuration, not implicit workspace ownership", () => {
|
||||
// Setup: user is NOT in SYSTEM_ADMIN_IDS
|
||||
process.env.SYSTEM_ADMIN_IDS = "different-admin-id";
|
||||
const guard = new AdminGuard();
|
||||
|
||||
// Even if this user owns workspaces, they should NOT have system admin access
|
||||
// because they are not in SYSTEM_ADMIN_IDS
|
||||
const context = createMockExecutionContext({ id: "workspace-owner-user-id" });
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
});
|
||||
|
||||
it("should grant access only to users explicitly listed as system admins", () => {
|
||||
const adminUserId = "explicitly-configured-admin";
|
||||
process.env.SYSTEM_ADMIN_IDS = adminUserId;
|
||||
const guard = new AdminGuard();
|
||||
|
||||
const context = createMockExecutionContext({ id: adminUserId });
|
||||
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -2,8 +2,14 @@
|
||||
* Admin Guard
|
||||
*
|
||||
* Restricts access to system-level admin operations.
|
||||
* Currently checks if user owns at least one workspace (indicating admin status).
|
||||
* Future: Replace with proper role-based access control (RBAC).
|
||||
* System administrators are configured via the SYSTEM_ADMIN_IDS environment variable.
|
||||
*
|
||||
* Configuration:
|
||||
* SYSTEM_ADMIN_IDS=uuid1,uuid2,uuid3 (comma-separated list of user IDs)
|
||||
*
|
||||
* Note: Workspace ownership does NOT grant system admin access. These are separate concepts:
|
||||
* - Workspace owner: Can manage their workspace and its members
|
||||
* - System admin: Can perform system-level operations across all workspaces
|
||||
*/
|
||||
|
||||
import {
|
||||
@@ -13,16 +19,42 @@ import {
|
||||
ForbiddenException,
|
||||
Logger,
|
||||
} from "@nestjs/common";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import type { AuthenticatedRequest } from "../../common/types/user.types";
|
||||
|
||||
@Injectable()
|
||||
export class AdminGuard implements CanActivate {
|
||||
private readonly logger = new Logger(AdminGuard.name);
|
||||
private readonly systemAdminIds: Set<string>;
|
||||
|
||||
constructor(private readonly prisma: PrismaService) {}
|
||||
constructor() {
|
||||
// Load system admin IDs from environment variable
|
||||
const adminIdsEnv = process.env.SYSTEM_ADMIN_IDS ?? "";
|
||||
this.systemAdminIds = new Set(
|
||||
adminIdsEnv
|
||||
.split(",")
|
||||
.map((id) => id.trim())
|
||||
.filter((id) => id.length > 0)
|
||||
);
|
||||
|
||||
async canActivate(context: ExecutionContext): Promise<boolean> {
|
||||
if (this.systemAdminIds.size === 0) {
|
||||
this.logger.warn(
|
||||
"No system administrators configured. Set SYSTEM_ADMIN_IDS environment variable."
|
||||
);
|
||||
} else {
|
||||
this.logger.log(
|
||||
`System administrators configured: ${String(this.systemAdminIds.size)} user(s)`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a user ID is a system administrator
|
||||
*/
|
||||
isSystemAdmin(userId: string): boolean {
|
||||
return this.systemAdminIds.has(userId);
|
||||
}
|
||||
|
||||
canActivate(context: ExecutionContext): boolean {
|
||||
const request = context.switchToHttp().getRequest<AuthenticatedRequest>();
|
||||
const user = request.user;
|
||||
|
||||
@@ -30,13 +62,7 @@ export class AdminGuard implements CanActivate {
|
||||
throw new ForbiddenException("User not authenticated");
|
||||
}
|
||||
|
||||
// Check if user owns any workspace (admin indicator)
|
||||
// TODO: Replace with proper RBAC system admin role check
|
||||
const ownedWorkspaces = await this.prisma.workspace.count({
|
||||
where: { ownerId: user.id },
|
||||
});
|
||||
|
||||
if (ownedWorkspaces === 0) {
|
||||
if (!this.isSystemAdmin(user.id)) {
|
||||
this.logger.warn(`Non-admin user ${user.id} attempted admin operation`);
|
||||
throw new ForbiddenException("This operation requires system administrator privileges");
|
||||
}
|
||||
|
||||
@@ -1,37 +1,50 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { ExecutionContext, UnauthorizedException } from "@nestjs/common";
|
||||
|
||||
// Mock better-auth modules before importing AuthGuard (which imports AuthService)
|
||||
vi.mock("better-auth/node", () => ({
|
||||
toNodeHandler: vi.fn().mockReturnValue(vi.fn()),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth", () => ({
|
||||
betterAuth: vi.fn().mockReturnValue({
|
||||
handler: vi.fn(),
|
||||
api: { getSession: vi.fn() },
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/adapters/prisma", () => ({
|
||||
prismaAdapter: vi.fn().mockReturnValue({}),
|
||||
}));
|
||||
|
||||
vi.mock("better-auth/plugins", () => ({
|
||||
genericOAuth: vi.fn().mockReturnValue({ id: "generic-oauth" }),
|
||||
}));
|
||||
|
||||
import { AuthGuard } from "./auth.guard";
|
||||
import { AuthService } from "../auth.service";
|
||||
import type { AuthService } from "../auth.service";
|
||||
|
||||
describe("AuthGuard", () => {
|
||||
let guard: AuthGuard;
|
||||
let authService: AuthService;
|
||||
|
||||
const mockAuthService = {
|
||||
verifySession: vi.fn(),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
AuthGuard,
|
||||
{
|
||||
provide: AuthService,
|
||||
useValue: mockAuthService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
guard = module.get<AuthGuard>(AuthGuard);
|
||||
authService = module.get<AuthService>(AuthService);
|
||||
beforeEach(() => {
|
||||
// Directly construct the guard with the mock to avoid NestJS DI issues
|
||||
guard = new AuthGuard(mockAuthService as unknown as AuthService);
|
||||
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
const createMockExecutionContext = (headers: any = {}): ExecutionContext => {
|
||||
const createMockExecutionContext = (
|
||||
headers: Record<string, string> = {},
|
||||
cookies: Record<string, string> = {}
|
||||
): ExecutionContext => {
|
||||
const mockRequest = {
|
||||
headers,
|
||||
cookies,
|
||||
};
|
||||
|
||||
return {
|
||||
@@ -42,57 +55,256 @@ describe("AuthGuard", () => {
|
||||
};
|
||||
|
||||
describe("canActivate", () => {
|
||||
it("should return true for valid session", async () => {
|
||||
const mockSessionData = {
|
||||
user: {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
session: {
|
||||
id: "session-123",
|
||||
},
|
||||
const mockSessionData = {
|
||||
user: {
|
||||
id: "user-123",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
session: {
|
||||
id: "session-123",
|
||||
token: "session-token",
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
},
|
||||
};
|
||||
|
||||
describe("Bearer token authentication", () => {
|
||||
it("should return true for valid Bearer token", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockAuthService.verifySession).toHaveBeenCalledWith("valid-token");
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException for invalid Bearer token", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(null);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer invalid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("Invalid or expired session");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cookie-based authentication", () => {
|
||||
it("should return true for valid session cookie", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{},
|
||||
{
|
||||
"better-auth.session_token": "cookie-token",
|
||||
}
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockAuthService.verifySession).toHaveBeenCalledWith("cookie-token");
|
||||
});
|
||||
|
||||
it("should prefer cookie over Bearer token when both present", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{
|
||||
authorization: "Bearer bearer-token",
|
||||
},
|
||||
{
|
||||
"better-auth.session_token": "cookie-token",
|
||||
}
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockAuthService.verifySession).toHaveBeenCalledWith("cookie-token");
|
||||
});
|
||||
|
||||
it("should fallback to Bearer token if no cookie", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{
|
||||
authorization: "Bearer bearer-token",
|
||||
},
|
||||
{}
|
||||
);
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockAuthService.verifySession).toHaveBeenCalledWith("bearer-token");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error handling", () => {
|
||||
it("should throw UnauthorizedException if no token provided", async () => {
|
||||
const context = createMockExecutionContext({}, {});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"No authentication token provided"
|
||||
);
|
||||
});
|
||||
|
||||
it("should propagate non-auth errors as-is (not wrap as 401)", async () => {
|
||||
const infraError = new Error("connect ECONNREFUSED 127.0.0.1:5432");
|
||||
mockAuthService.verifySession.mockRejectedValue(infraError);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer error-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(infraError);
|
||||
await expect(guard.canActivate(context)).rejects.not.toBeInstanceOf(UnauthorizedException);
|
||||
});
|
||||
|
||||
it("should propagate database errors so GlobalExceptionFilter returns 500", async () => {
|
||||
const dbError = new Error("PrismaClientKnownRequestError: Connection refused");
|
||||
mockAuthService.verifySession.mockRejectedValue(dbError);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(dbError);
|
||||
await expect(guard.canActivate(context)).rejects.not.toBeInstanceOf(UnauthorizedException);
|
||||
});
|
||||
|
||||
it("should propagate timeout errors so GlobalExceptionFilter returns 503", async () => {
|
||||
const timeoutError = new Error("Connection timeout after 5000ms");
|
||||
mockAuthService.verifySession.mockRejectedValue(timeoutError);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(timeoutError);
|
||||
await expect(guard.canActivate(context)).rejects.not.toBeInstanceOf(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
|
||||
describe("user data validation", () => {
|
||||
const mockSession = {
|
||||
id: "session-123",
|
||||
token: "session-token",
|
||||
expiresAt: new Date(Date.now() + 86400000),
|
||||
};
|
||||
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
it("should throw UnauthorizedException when user is missing id", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue({
|
||||
user: { email: "a@b.com", name: "Test" },
|
||||
session: mockSession,
|
||||
});
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"Invalid user data in session"
|
||||
);
|
||||
});
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
it("should throw UnauthorizedException when user is missing email", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue({
|
||||
user: { id: "1", name: "Test" },
|
||||
session: mockSession,
|
||||
});
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockAuthService.verifySession).toHaveBeenCalledWith("valid-token");
|
||||
});
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException if no token provided", async () => {
|
||||
const context = createMockExecutionContext({});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("No authentication token provided");
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException if session is invalid", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(null);
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer invalid-token",
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"Invalid user data in session"
|
||||
);
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("Invalid or expired session");
|
||||
});
|
||||
it("should throw UnauthorizedException when user is missing name", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue({
|
||||
user: { id: "1", email: "a@b.com" },
|
||||
session: mockSession,
|
||||
});
|
||||
|
||||
it("should throw UnauthorizedException if session verification fails", async () => {
|
||||
mockAuthService.verifySession.mockRejectedValue(new Error("Verification failed"));
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer error-token",
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"Invalid user data in session"
|
||||
);
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("Authentication failed");
|
||||
it("should throw UnauthorizedException when user is a string", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue({
|
||||
user: "not-an-object",
|
||||
session: mockSession,
|
||||
});
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(UnauthorizedException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"Invalid user data in session"
|
||||
);
|
||||
});
|
||||
|
||||
it("should reject when user is null (typeof null === 'object' causes TypeError on 'in' operator)", async () => {
|
||||
// Note: typeof null === "object" in JS, so the guard's typeof check passes
|
||||
// but "id" in null throws TypeError. The catch block propagates non-auth errors as-is.
|
||||
mockAuthService.verifySession.mockResolvedValue({
|
||||
user: null,
|
||||
session: mockSession,
|
||||
});
|
||||
|
||||
const context = createMockExecutionContext({
|
||||
authorization: "Bearer valid-token",
|
||||
});
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(TypeError);
|
||||
await expect(guard.canActivate(context)).rejects.not.toBeInstanceOf(
|
||||
UnauthorizedException
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("request attachment", () => {
|
||||
it("should attach user and session to request on success", async () => {
|
||||
mockAuthService.verifySession.mockResolvedValue(mockSessionData);
|
||||
|
||||
const mockRequest = {
|
||||
headers: {
|
||||
authorization: "Bearer valid-token",
|
||||
},
|
||||
cookies: {},
|
||||
};
|
||||
|
||||
const context = {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => mockRequest,
|
||||
}),
|
||||
} as ExecutionContext;
|
||||
|
||||
await guard.canActivate(context);
|
||||
|
||||
expect(mockRequest).toHaveProperty("user", mockSessionData.user);
|
||||
expect(mockRequest).toHaveProperty("session", mockSessionData.session);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import { Injectable, CanActivate, ExecutionContext, UnauthorizedException } from "@nestjs/common";
|
||||
import { AuthService } from "../auth.service";
|
||||
import type { AuthenticatedRequest } from "../../common/types/user.types";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
import type { MaybeAuthenticatedRequest } from "../types/better-auth-request.interface";
|
||||
|
||||
@Injectable()
|
||||
export class AuthGuard implements CanActivate {
|
||||
constructor(private readonly authService: AuthService) {}
|
||||
|
||||
async canActivate(context: ExecutionContext): Promise<boolean> {
|
||||
const request = context.switchToHttp().getRequest<AuthenticatedRequest>();
|
||||
const token = this.extractTokenFromHeader(request);
|
||||
const request = context.switchToHttp().getRequest<MaybeAuthenticatedRequest>();
|
||||
|
||||
// Try to get token from either cookie (preferred) or Authorization header
|
||||
const token = this.extractToken(request);
|
||||
|
||||
if (!token) {
|
||||
throw new UnauthorizedException("No authentication token provided");
|
||||
@@ -21,25 +24,58 @@ export class AuthGuard implements CanActivate {
|
||||
throw new UnauthorizedException("Invalid or expired session");
|
||||
}
|
||||
|
||||
// Attach user to request (with type assertion for session data structure)
|
||||
const user = sessionData.user as unknown as AuthenticatedRequest["user"];
|
||||
if (!user) {
|
||||
// Attach user and session to request
|
||||
const user = sessionData.user;
|
||||
// Validate user has required fields
|
||||
if (typeof user !== "object" || !("id" in user) || !("email" in user) || !("name" in user)) {
|
||||
throw new UnauthorizedException("Invalid user data in session");
|
||||
}
|
||||
request.user = user;
|
||||
request.user = user as unknown as AuthUser;
|
||||
request.session = sessionData.session;
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
// Re-throw if it's already an UnauthorizedException
|
||||
if (error instanceof UnauthorizedException) {
|
||||
throw error;
|
||||
}
|
||||
throw new UnauthorizedException("Authentication failed");
|
||||
// Infrastructure errors (DB down, connection refused, timeouts) must propagate
|
||||
// as 500/503 via GlobalExceptionFilter — never mask as 401
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
private extractTokenFromHeader(request: AuthenticatedRequest): string | undefined {
|
||||
/**
|
||||
* Extract token from cookie (preferred) or Authorization header
|
||||
*/
|
||||
private extractToken(request: MaybeAuthenticatedRequest): string | undefined {
|
||||
// Try cookie first (BetterAuth default)
|
||||
const cookieToken = this.extractTokenFromCookie(request);
|
||||
if (cookieToken) {
|
||||
return cookieToken;
|
||||
}
|
||||
|
||||
// Fallback to Authorization header for API clients
|
||||
return this.extractTokenFromHeader(request);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract token from cookie (BetterAuth stores session token in better-auth.session_token cookie)
|
||||
*/
|
||||
private extractTokenFromCookie(request: MaybeAuthenticatedRequest): string | undefined {
|
||||
// Express types `cookies` as `any`; cast to a known shape for type safety.
|
||||
const cookies = request.cookies as Record<string, string> | undefined;
|
||||
if (!cookies) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// BetterAuth uses 'better-auth.session_token' as the cookie name by default
|
||||
return cookies["better-auth.session_token"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract token from Authorization header (Bearer token)
|
||||
*/
|
||||
private extractTokenFromHeader(request: MaybeAuthenticatedRequest): string | undefined {
|
||||
const authHeader = request.headers.authorization;
|
||||
if (typeof authHeader !== "string") {
|
||||
return undefined;
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
/**
|
||||
* BetterAuth Request Type
|
||||
* Unified request types for authentication context.
|
||||
*
|
||||
* BetterAuth expects a Request object compatible with the Fetch API standard.
|
||||
* This extends the web standard Request interface with additional properties
|
||||
* that may be present in the Express request object at runtime.
|
||||
* Replaces the previously scattered interfaces:
|
||||
* - RequestWithSession (auth.controller.ts)
|
||||
* - AuthRequest (auth.guard.ts)
|
||||
* - BetterAuthRequest (this file, removed)
|
||||
* - RequestWithUser (current-user.decorator.ts)
|
||||
*/
|
||||
|
||||
import type { Request } from "express";
|
||||
import type { AuthUser } from "@mosaic/shared";
|
||||
|
||||
// Re-export AuthUser for use in other modules
|
||||
@@ -22,19 +25,21 @@ export interface RequestSession {
|
||||
}
|
||||
|
||||
/**
|
||||
* Web standard Request interface extended with Express-specific properties
|
||||
* This matches the Fetch API Request specification that BetterAuth expects.
|
||||
* Request that may or may not have auth data (before guard runs).
|
||||
* Used by AuthGuard and other middleware that processes requests
|
||||
* before authentication is confirmed.
|
||||
*/
|
||||
export interface BetterAuthRequest extends Request {
|
||||
// Express route parameters
|
||||
params?: Record<string, string>;
|
||||
|
||||
// Express query string parameters
|
||||
query?: Record<string, string | string[]>;
|
||||
|
||||
// Session data attached by AuthGuard after successful authentication
|
||||
session?: RequestSession;
|
||||
|
||||
// Authenticated user attached by AuthGuard
|
||||
export interface MaybeAuthenticatedRequest extends Request {
|
||||
user?: AuthUser;
|
||||
session?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request with authenticated user attached by AuthGuard.
|
||||
* After AuthGuard runs, user and session are guaranteed present.
|
||||
* Use this type in controllers/decorators that sit behind AuthGuard.
|
||||
*/
|
||||
export interface AuthenticatedRequest extends Request {
|
||||
user: AuthUser;
|
||||
session: RequestSession;
|
||||
}
|
||||
|
||||
234
apps/api/src/brain/brain-search-validation.spec.ts
Normal file
234
apps/api/src/brain/brain-search-validation.spec.ts
Normal file
@@ -0,0 +1,234 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { validate } from "class-validator";
|
||||
import { plainToInstance } from "class-transformer";
|
||||
import { BadRequestException } from "@nestjs/common";
|
||||
import { BrainSearchDto, BrainQueryDto } from "./dto";
|
||||
import { BrainService } from "./brain.service";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
|
||||
describe("Brain Search Validation", () => {
|
||||
describe("BrainSearchDto", () => {
|
||||
it("should accept a valid search query", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, { q: "meeting notes", limit: 10 });
|
||||
const errors = await validate(dto);
|
||||
expect(errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should accept empty query params", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, {});
|
||||
const errors = await validate(dto);
|
||||
expect(errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should reject search query exceeding 500 characters", async () => {
|
||||
const longQuery = "a".repeat(501);
|
||||
const dto = plainToInstance(BrainSearchDto, { q: longQuery });
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const qError = errors.find((e) => e.property === "q");
|
||||
expect(qError).toBeDefined();
|
||||
expect(qError?.constraints?.maxLength).toContain("500");
|
||||
});
|
||||
|
||||
it("should accept search query at exactly 500 characters", async () => {
|
||||
const maxQuery = "a".repeat(500);
|
||||
const dto = plainToInstance(BrainSearchDto, { q: maxQuery });
|
||||
const errors = await validate(dto);
|
||||
expect(errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should reject negative limit", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, { q: "test", limit: -1 });
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const limitError = errors.find((e) => e.property === "limit");
|
||||
expect(limitError).toBeDefined();
|
||||
expect(limitError?.constraints?.min).toContain("1");
|
||||
});
|
||||
|
||||
it("should reject zero limit", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, { q: "test", limit: 0 });
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const limitError = errors.find((e) => e.property === "limit");
|
||||
expect(limitError).toBeDefined();
|
||||
});
|
||||
|
||||
it("should reject limit exceeding 100", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, { q: "test", limit: 101 });
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const limitError = errors.find((e) => e.property === "limit");
|
||||
expect(limitError).toBeDefined();
|
||||
expect(limitError?.constraints?.max).toContain("100");
|
||||
});
|
||||
|
||||
it("should accept limit at boundaries (1 and 100)", async () => {
|
||||
const dto1 = plainToInstance(BrainSearchDto, { limit: 1 });
|
||||
const errors1 = await validate(dto1);
|
||||
expect(errors1).toHaveLength(0);
|
||||
|
||||
const dto100 = plainToInstance(BrainSearchDto, { limit: 100 });
|
||||
const errors100 = await validate(dto100);
|
||||
expect(errors100).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should reject non-integer limit", async () => {
|
||||
const dto = plainToInstance(BrainSearchDto, { limit: 10.5 });
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const limitError = errors.find((e) => e.property === "limit");
|
||||
expect(limitError).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("BrainQueryDto search and query length validation", () => {
|
||||
it("should reject query exceeding 500 characters", async () => {
|
||||
const longQuery = "a".repeat(501);
|
||||
const dto = plainToInstance(BrainQueryDto, {
|
||||
workspaceId: "550e8400-e29b-41d4-a716-446655440000",
|
||||
query: longQuery,
|
||||
});
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const queryError = errors.find((e) => e.property === "query");
|
||||
expect(queryError).toBeDefined();
|
||||
expect(queryError?.constraints?.maxLength).toContain("500");
|
||||
});
|
||||
|
||||
it("should reject search exceeding 500 characters", async () => {
|
||||
const longSearch = "b".repeat(501);
|
||||
const dto = plainToInstance(BrainQueryDto, {
|
||||
workspaceId: "550e8400-e29b-41d4-a716-446655440000",
|
||||
search: longSearch,
|
||||
});
|
||||
const errors = await validate(dto);
|
||||
expect(errors.length).toBeGreaterThan(0);
|
||||
const searchError = errors.find((e) => e.property === "search");
|
||||
expect(searchError).toBeDefined();
|
||||
expect(searchError?.constraints?.maxLength).toContain("500");
|
||||
});
|
||||
|
||||
it("should accept query at exactly 500 characters", async () => {
|
||||
const maxQuery = "a".repeat(500);
|
||||
const dto = plainToInstance(BrainQueryDto, {
|
||||
workspaceId: "550e8400-e29b-41d4-a716-446655440000",
|
||||
query: maxQuery,
|
||||
});
|
||||
const errors = await validate(dto);
|
||||
expect(errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("should accept search at exactly 500 characters", async () => {
|
||||
const maxSearch = "b".repeat(500);
|
||||
const dto = plainToInstance(BrainQueryDto, {
|
||||
workspaceId: "550e8400-e29b-41d4-a716-446655440000",
|
||||
search: maxSearch,
|
||||
});
|
||||
const errors = await validate(dto);
|
||||
expect(errors).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("BrainService.search defensive validation", () => {
|
||||
let service: BrainService;
|
||||
let prisma: {
|
||||
task: { findMany: ReturnType<typeof vi.fn> };
|
||||
event: { findMany: ReturnType<typeof vi.fn> };
|
||||
project: { findMany: ReturnType<typeof vi.fn> };
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
prisma = {
|
||||
task: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
event: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
project: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
};
|
||||
service = new BrainService(prisma as unknown as PrismaService);
|
||||
});
|
||||
|
||||
it("should throw BadRequestException for search term exceeding 500 characters", async () => {
|
||||
const longTerm = "x".repeat(501);
|
||||
await expect(service.search("workspace-id", longTerm)).rejects.toThrow(BadRequestException);
|
||||
await expect(service.search("workspace-id", longTerm)).rejects.toThrow("500");
|
||||
});
|
||||
|
||||
it("should accept search term at exactly 500 characters", async () => {
|
||||
const maxTerm = "x".repeat(500);
|
||||
await expect(service.search("workspace-id", maxTerm)).resolves.toBeDefined();
|
||||
});
|
||||
|
||||
it("should clamp limit to max 100 when higher value provided", async () => {
|
||||
await service.search("workspace-id", "test", 200);
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 100 }));
|
||||
});
|
||||
|
||||
it("should clamp limit to min 1 when negative value provided", async () => {
|
||||
await service.search("workspace-id", "test", -5);
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 1 }));
|
||||
});
|
||||
|
||||
it("should clamp limit to min 1 when zero provided", async () => {
|
||||
await service.search("workspace-id", "test", 0);
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 1 }));
|
||||
});
|
||||
|
||||
it("should pass through valid limit values unchanged", async () => {
|
||||
await service.search("workspace-id", "test", 50);
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 50 }));
|
||||
});
|
||||
});
|
||||
|
||||
describe("BrainService.query defensive validation", () => {
|
||||
let service: BrainService;
|
||||
let prisma: {
|
||||
task: { findMany: ReturnType<typeof vi.fn> };
|
||||
event: { findMany: ReturnType<typeof vi.fn> };
|
||||
project: { findMany: ReturnType<typeof vi.fn> };
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
prisma = {
|
||||
task: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
event: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
project: { findMany: vi.fn().mockResolvedValue([]) },
|
||||
};
|
||||
service = new BrainService(prisma as unknown as PrismaService);
|
||||
});
|
||||
|
||||
it("should throw BadRequestException for search field exceeding 500 characters", async () => {
|
||||
const longSearch = "y".repeat(501);
|
||||
await expect(
|
||||
service.query({ workspaceId: "workspace-id", search: longSearch })
|
||||
).rejects.toThrow(BadRequestException);
|
||||
});
|
||||
|
||||
it("should throw BadRequestException for query field exceeding 500 characters", async () => {
|
||||
const longQuery = "z".repeat(501);
|
||||
await expect(
|
||||
service.query({ workspaceId: "workspace-id", query: longQuery })
|
||||
).rejects.toThrow(BadRequestException);
|
||||
});
|
||||
|
||||
it("should clamp limit to max 100 in query method", async () => {
|
||||
await service.query({ workspaceId: "workspace-id", limit: 200 });
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 100 }));
|
||||
});
|
||||
|
||||
it("should clamp limit to min 1 in query method when negative", async () => {
|
||||
await service.query({ workspaceId: "workspace-id", limit: -10 });
|
||||
expect(prisma.task.findMany).toHaveBeenCalledWith(expect.objectContaining({ take: 1 }));
|
||||
});
|
||||
|
||||
it("should accept valid query and search within limits", async () => {
|
||||
await expect(
|
||||
service.query({
|
||||
workspaceId: "workspace-id",
|
||||
query: "test query",
|
||||
search: "test search",
|
||||
limit: 50,
|
||||
})
|
||||
).resolves.toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -250,39 +250,33 @@ describe("BrainController", () => {
|
||||
});
|
||||
|
||||
describe("search", () => {
|
||||
it("should call service.search with parameters", async () => {
|
||||
const result = await controller.search("test query", "10", mockWorkspaceId);
|
||||
it("should call service.search with parameters from DTO", async () => {
|
||||
const result = await controller.search({ q: "test query", limit: 10 }, mockWorkspaceId);
|
||||
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "test query", 10);
|
||||
expect(result).toEqual(mockQueryResult);
|
||||
});
|
||||
|
||||
it("should use default limit when not provided", async () => {
|
||||
await controller.search("test", undefined as unknown as string, mockWorkspaceId);
|
||||
it("should use default limit when not provided in DTO", async () => {
|
||||
await controller.search({ q: "test" }, mockWorkspaceId);
|
||||
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "test", 20);
|
||||
});
|
||||
|
||||
it("should cap limit at 100", async () => {
|
||||
await controller.search("test", "500", mockWorkspaceId);
|
||||
it("should handle empty search DTO", async () => {
|
||||
await controller.search({}, mockWorkspaceId);
|
||||
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "test", 100);
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "", 20);
|
||||
});
|
||||
|
||||
it("should handle empty search term", async () => {
|
||||
await controller.search(undefined as unknown as string, "10", mockWorkspaceId);
|
||||
it("should handle undefined q in DTO", async () => {
|
||||
await controller.search({ limit: 10 }, mockWorkspaceId);
|
||||
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "", 10);
|
||||
});
|
||||
|
||||
it("should handle invalid limit", async () => {
|
||||
await controller.search("test", "invalid", mockWorkspaceId);
|
||||
|
||||
expect(mockService.search).toHaveBeenCalledWith(mockWorkspaceId, "test", 20);
|
||||
});
|
||||
|
||||
it("should return search result structure", async () => {
|
||||
const result = await controller.search("test", "10", mockWorkspaceId);
|
||||
const result = await controller.search({ q: "test", limit: 10 }, mockWorkspaceId);
|
||||
|
||||
expect(result).toHaveProperty("tasks");
|
||||
expect(result).toHaveProperty("events");
|
||||
|
||||
@@ -3,6 +3,7 @@ import { BrainService } from "./brain.service";
|
||||
import { IntentClassificationService } from "./intent-classification.service";
|
||||
import {
|
||||
BrainQueryDto,
|
||||
BrainSearchDto,
|
||||
BrainContextDto,
|
||||
ClassifyIntentDto,
|
||||
IntentClassificationResultDto,
|
||||
@@ -67,13 +68,10 @@ export class BrainController {
|
||||
*/
|
||||
@Get("search")
|
||||
@RequirePermission(Permission.WORKSPACE_ANY)
|
||||
async search(
|
||||
@Query("q") searchTerm: string,
|
||||
@Query("limit") limit: string,
|
||||
@Workspace() workspaceId: string
|
||||
) {
|
||||
const parsedLimit = limit ? Math.min(parseInt(limit, 10) || 20, 100) : 20;
|
||||
return this.brainService.search(workspaceId, searchTerm || "", parsedLimit);
|
||||
async search(@Query() searchDto: BrainSearchDto, @Workspace() workspaceId: string) {
|
||||
const searchTerm = searchDto.q ?? "";
|
||||
const limit = searchDto.limit ?? 20;
|
||||
return this.brainService.search(workspaceId, searchTerm, limit);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import { Injectable, BadRequestException } from "@nestjs/common";
|
||||
import { EntityType, TaskStatus, ProjectStatus } from "@prisma/client";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import type { BrainQueryDto, BrainContextDto, TaskFilter, EventFilter, ProjectFilter } from "./dto";
|
||||
@@ -80,6 +80,11 @@ export interface BrainContext {
|
||||
}[];
|
||||
}
|
||||
|
||||
/** Maximum allowed length for search query strings */
|
||||
const MAX_SEARCH_LENGTH = 500;
|
||||
/** Maximum allowed limit for search results per entity type */
|
||||
const MAX_SEARCH_LIMIT = 100;
|
||||
|
||||
/**
|
||||
* @description Service for querying and aggregating workspace data for AI/brain operations.
|
||||
* Provides unified access to tasks, events, and projects with filtering and search capabilities.
|
||||
@@ -97,15 +102,28 @@ export class BrainService {
|
||||
*/
|
||||
async query(queryDto: BrainQueryDto): Promise<BrainQueryResult> {
|
||||
const { workspaceId, entities, search, limit = 20 } = queryDto;
|
||||
if (search && search.length > MAX_SEARCH_LENGTH) {
|
||||
throw new BadRequestException(
|
||||
`Search term must not exceed ${String(MAX_SEARCH_LENGTH)} characters`
|
||||
);
|
||||
}
|
||||
if (queryDto.query && queryDto.query.length > MAX_SEARCH_LENGTH) {
|
||||
throw new BadRequestException(
|
||||
`Query must not exceed ${String(MAX_SEARCH_LENGTH)} characters`
|
||||
);
|
||||
}
|
||||
const clampedLimit = Math.max(1, Math.min(limit, MAX_SEARCH_LIMIT));
|
||||
const includeEntities = entities ?? [EntityType.TASK, EntityType.EVENT, EntityType.PROJECT];
|
||||
const includeTasks = includeEntities.includes(EntityType.TASK);
|
||||
const includeEvents = includeEntities.includes(EntityType.EVENT);
|
||||
const includeProjects = includeEntities.includes(EntityType.PROJECT);
|
||||
|
||||
const [tasks, events, projects] = await Promise.all([
|
||||
includeTasks ? this.queryTasks(workspaceId, queryDto.tasks, search, limit) : [],
|
||||
includeEvents ? this.queryEvents(workspaceId, queryDto.events, search, limit) : [],
|
||||
includeProjects ? this.queryProjects(workspaceId, queryDto.projects, search, limit) : [],
|
||||
includeTasks ? this.queryTasks(workspaceId, queryDto.tasks, search, clampedLimit) : [],
|
||||
includeEvents ? this.queryEvents(workspaceId, queryDto.events, search, clampedLimit) : [],
|
||||
includeProjects
|
||||
? this.queryProjects(workspaceId, queryDto.projects, search, clampedLimit)
|
||||
: [],
|
||||
]);
|
||||
|
||||
// Build filters object conditionally for exactOptionalPropertyTypes
|
||||
@@ -259,10 +277,17 @@ export class BrainService {
|
||||
* @throws PrismaClientKnownRequestError if database query fails
|
||||
*/
|
||||
async search(workspaceId: string, searchTerm: string, limit = 20): Promise<BrainQueryResult> {
|
||||
if (searchTerm.length > MAX_SEARCH_LENGTH) {
|
||||
throw new BadRequestException(
|
||||
`Search term must not exceed ${String(MAX_SEARCH_LENGTH)} characters`
|
||||
);
|
||||
}
|
||||
const clampedLimit = Math.max(1, Math.min(limit, MAX_SEARCH_LIMIT));
|
||||
|
||||
const [tasks, events, projects] = await Promise.all([
|
||||
this.queryTasks(workspaceId, undefined, searchTerm, limit),
|
||||
this.queryEvents(workspaceId, undefined, searchTerm, limit),
|
||||
this.queryProjects(workspaceId, undefined, searchTerm, limit),
|
||||
this.queryTasks(workspaceId, undefined, searchTerm, clampedLimit),
|
||||
this.queryEvents(workspaceId, undefined, searchTerm, clampedLimit),
|
||||
this.queryProjects(workspaceId, undefined, searchTerm, clampedLimit),
|
||||
]);
|
||||
|
||||
return {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
IsInt,
|
||||
Min,
|
||||
Max,
|
||||
MaxLength,
|
||||
IsDateString,
|
||||
IsArray,
|
||||
ValidateNested,
|
||||
@@ -105,6 +106,7 @@ export class BrainQueryDto {
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(500, { message: "query must not exceed 500 characters" })
|
||||
query?: string;
|
||||
|
||||
@IsOptional()
|
||||
@@ -129,6 +131,7 @@ export class BrainQueryDto {
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(500, { message: "search must not exceed 500 characters" })
|
||||
search?: string;
|
||||
|
||||
@IsOptional()
|
||||
@@ -162,3 +165,17 @@ export class BrainContextDto {
|
||||
@Max(30)
|
||||
eventDays?: number;
|
||||
}
|
||||
|
||||
export class BrainSearchDto {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MaxLength(500, { message: "q must not exceed 500 characters" })
|
||||
q?: string;
|
||||
|
||||
@IsOptional()
|
||||
@Type(() => Number)
|
||||
@IsInt({ message: "limit must be an integer" })
|
||||
@Min(1, { message: "limit must be at least 1" })
|
||||
@Max(100, { message: "limit must not exceed 100" })
|
||||
limit?: number;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
export {
|
||||
BrainQueryDto,
|
||||
BrainSearchDto,
|
||||
TaskFilter,
|
||||
EventFilter,
|
||||
ProjectFilter,
|
||||
|
||||
15
apps/api/src/bridge/bridge.constants.ts
Normal file
15
apps/api/src/bridge/bridge.constants.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Bridge Module Constants
|
||||
*
|
||||
* Injection tokens for the bridge module.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Injection token for the array of active IChatProvider instances.
|
||||
*
|
||||
* Use this token to inject all configured chat providers:
|
||||
* ```
|
||||
* @Inject(CHAT_PROVIDERS) private readonly chatProviders: IChatProvider[]
|
||||
* ```
|
||||
*/
|
||||
export const CHAT_PROVIDERS = "CHAT_PROVIDERS";
|
||||
@@ -1,10 +1,13 @@
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { BridgeModule } from "./bridge.module";
|
||||
import { DiscordService } from "./discord/discord.service";
|
||||
import { MatrixService } from "./matrix/matrix.service";
|
||||
import { StitcherService } from "../stitcher/stitcher.service";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { BullMqService } from "../bullmq/bullmq.service";
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { CHAT_PROVIDERS } from "./bridge.constants";
|
||||
import type { IChatProvider } from "./interfaces";
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
|
||||
// Mock discord.js
|
||||
const mockReadyCallbacks: Array<() => void> = [];
|
||||
@@ -53,19 +56,93 @@ vi.mock("discord.js", () => {
|
||||
};
|
||||
});
|
||||
|
||||
describe("BridgeModule", () => {
|
||||
let module: TestingModule;
|
||||
// Mock matrix-bot-sdk
|
||||
vi.mock("matrix-bot-sdk", () => {
|
||||
return {
|
||||
MatrixClient: class MockMatrixClient {
|
||||
start = vi.fn().mockResolvedValue(undefined);
|
||||
stop = vi.fn();
|
||||
on = vi.fn();
|
||||
sendMessage = vi.fn().mockResolvedValue("$mock-event-id");
|
||||
},
|
||||
SimpleFsStorageProvider: class MockStorage {
|
||||
constructor(_path: string) {
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
AutojoinRoomsMixin: {
|
||||
setupOnClient: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Set environment variables
|
||||
process.env.DISCORD_BOT_TOKEN = "test-token";
|
||||
process.env.DISCORD_GUILD_ID = "test-guild-id";
|
||||
process.env.DISCORD_CONTROL_CHANNEL_ID = "test-channel-id";
|
||||
/**
|
||||
* Saved environment variables to restore after each test
|
||||
*/
|
||||
interface SavedEnvVars {
|
||||
DISCORD_BOT_TOKEN?: string;
|
||||
DISCORD_GUILD_ID?: string;
|
||||
DISCORD_CONTROL_CHANNEL_ID?: string;
|
||||
MATRIX_ACCESS_TOKEN?: string;
|
||||
MATRIX_HOMESERVER_URL?: string;
|
||||
MATRIX_BOT_USER_ID?: string;
|
||||
MATRIX_CONTROL_ROOM_ID?: string;
|
||||
MATRIX_WORKSPACE_ID?: string;
|
||||
ENCRYPTION_KEY?: string;
|
||||
}
|
||||
|
||||
describe("BridgeModule", () => {
|
||||
let savedEnv: SavedEnvVars;
|
||||
|
||||
beforeEach(() => {
|
||||
// Save current env vars
|
||||
savedEnv = {
|
||||
DISCORD_BOT_TOKEN: process.env.DISCORD_BOT_TOKEN,
|
||||
DISCORD_GUILD_ID: process.env.DISCORD_GUILD_ID,
|
||||
DISCORD_CONTROL_CHANNEL_ID: process.env.DISCORD_CONTROL_CHANNEL_ID,
|
||||
MATRIX_ACCESS_TOKEN: process.env.MATRIX_ACCESS_TOKEN,
|
||||
MATRIX_HOMESERVER_URL: process.env.MATRIX_HOMESERVER_URL,
|
||||
MATRIX_BOT_USER_ID: process.env.MATRIX_BOT_USER_ID,
|
||||
MATRIX_CONTROL_ROOM_ID: process.env.MATRIX_CONTROL_ROOM_ID,
|
||||
MATRIX_WORKSPACE_ID: process.env.MATRIX_WORKSPACE_ID,
|
||||
ENCRYPTION_KEY: process.env.ENCRYPTION_KEY,
|
||||
};
|
||||
|
||||
// Clear all bridge env vars
|
||||
delete process.env.DISCORD_BOT_TOKEN;
|
||||
delete process.env.DISCORD_GUILD_ID;
|
||||
delete process.env.DISCORD_CONTROL_CHANNEL_ID;
|
||||
delete process.env.MATRIX_ACCESS_TOKEN;
|
||||
delete process.env.MATRIX_HOMESERVER_URL;
|
||||
delete process.env.MATRIX_BOT_USER_ID;
|
||||
delete process.env.MATRIX_CONTROL_ROOM_ID;
|
||||
delete process.env.MATRIX_WORKSPACE_ID;
|
||||
|
||||
// Set encryption key (needed by StitcherService)
|
||||
process.env.ENCRYPTION_KEY = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
|
||||
|
||||
// Clear ready callbacks
|
||||
mockReadyCallbacks.length = 0;
|
||||
|
||||
module = await Test.createTestingModule({
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Restore env vars
|
||||
for (const [key, value] of Object.entries(savedEnv)) {
|
||||
if (value === undefined) {
|
||||
delete process.env[key];
|
||||
} else {
|
||||
process.env[key] = value;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Helper to compile a test module with BridgeModule
|
||||
*/
|
||||
async function compileModule(): Promise<TestingModule> {
|
||||
return Test.createTestingModule({
|
||||
imports: [BridgeModule],
|
||||
})
|
||||
.overrideProvider(PrismaService)
|
||||
@@ -73,24 +150,144 @@ describe("BridgeModule", () => {
|
||||
.overrideProvider(BullMqService)
|
||||
.useValue({})
|
||||
.compile();
|
||||
}
|
||||
|
||||
// Clear all mocks
|
||||
vi.clearAllMocks();
|
||||
/**
|
||||
* Helper to set Discord env vars
|
||||
*/
|
||||
function setDiscordEnv(): void {
|
||||
process.env.DISCORD_BOT_TOKEN = "test-discord-token";
|
||||
process.env.DISCORD_GUILD_ID = "test-guild-id";
|
||||
process.env.DISCORD_CONTROL_CHANNEL_ID = "test-channel-id";
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to set Matrix env vars
|
||||
*/
|
||||
function setMatrixEnv(): void {
|
||||
process.env.MATRIX_ACCESS_TOKEN = "test-matrix-token";
|
||||
process.env.MATRIX_HOMESERVER_URL = "https://matrix.example.com";
|
||||
process.env.MATRIX_BOT_USER_ID = "@bot:example.com";
|
||||
process.env.MATRIX_CONTROL_ROOM_ID = "!room:example.com";
|
||||
process.env.MATRIX_WORKSPACE_ID = "test-workspace-id";
|
||||
}
|
||||
|
||||
describe("with both Discord and Matrix configured", () => {
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
setDiscordEnv();
|
||||
setMatrixEnv();
|
||||
module = await compileModule();
|
||||
});
|
||||
|
||||
it("should compile the module", () => {
|
||||
expect(module).toBeDefined();
|
||||
});
|
||||
|
||||
it("should provide DiscordService", () => {
|
||||
const discordService = module.get<DiscordService>(DiscordService);
|
||||
expect(discordService).toBeDefined();
|
||||
expect(discordService).toBeInstanceOf(DiscordService);
|
||||
});
|
||||
|
||||
it("should provide MatrixService", () => {
|
||||
const matrixService = module.get<MatrixService>(MatrixService);
|
||||
expect(matrixService).toBeDefined();
|
||||
expect(matrixService).toBeInstanceOf(MatrixService);
|
||||
});
|
||||
|
||||
it("should provide CHAT_PROVIDERS with both providers", () => {
|
||||
const chatProviders = module.get<IChatProvider[]>(CHAT_PROVIDERS);
|
||||
expect(chatProviders).toBeDefined();
|
||||
expect(chatProviders).toHaveLength(2);
|
||||
expect(chatProviders[0]).toBeInstanceOf(DiscordService);
|
||||
expect(chatProviders[1]).toBeInstanceOf(MatrixService);
|
||||
});
|
||||
|
||||
it("should provide StitcherService via StitcherModule", () => {
|
||||
const stitcherService = module.get<StitcherService>(StitcherService);
|
||||
expect(stitcherService).toBeDefined();
|
||||
expect(stitcherService).toBeInstanceOf(StitcherService);
|
||||
});
|
||||
});
|
||||
|
||||
it("should be defined", () => {
|
||||
expect(module).toBeDefined();
|
||||
describe("with only Discord configured", () => {
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
setDiscordEnv();
|
||||
module = await compileModule();
|
||||
});
|
||||
|
||||
it("should compile the module", () => {
|
||||
expect(module).toBeDefined();
|
||||
});
|
||||
|
||||
it("should provide DiscordService", () => {
|
||||
const discordService = module.get<DiscordService>(DiscordService);
|
||||
expect(discordService).toBeDefined();
|
||||
expect(discordService).toBeInstanceOf(DiscordService);
|
||||
});
|
||||
|
||||
it("should provide CHAT_PROVIDERS with only Discord", () => {
|
||||
const chatProviders = module.get<IChatProvider[]>(CHAT_PROVIDERS);
|
||||
expect(chatProviders).toBeDefined();
|
||||
expect(chatProviders).toHaveLength(1);
|
||||
expect(chatProviders[0]).toBeInstanceOf(DiscordService);
|
||||
});
|
||||
});
|
||||
|
||||
it("should provide DiscordService", () => {
|
||||
const discordService = module.get<DiscordService>(DiscordService);
|
||||
expect(discordService).toBeDefined();
|
||||
expect(discordService).toBeInstanceOf(DiscordService);
|
||||
describe("with only Matrix configured", () => {
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
setMatrixEnv();
|
||||
module = await compileModule();
|
||||
});
|
||||
|
||||
it("should compile the module", () => {
|
||||
expect(module).toBeDefined();
|
||||
});
|
||||
|
||||
it("should provide MatrixService", () => {
|
||||
const matrixService = module.get<MatrixService>(MatrixService);
|
||||
expect(matrixService).toBeDefined();
|
||||
expect(matrixService).toBeInstanceOf(MatrixService);
|
||||
});
|
||||
|
||||
it("should provide CHAT_PROVIDERS with only Matrix", () => {
|
||||
const chatProviders = module.get<IChatProvider[]>(CHAT_PROVIDERS);
|
||||
expect(chatProviders).toBeDefined();
|
||||
expect(chatProviders).toHaveLength(1);
|
||||
expect(chatProviders[0]).toBeInstanceOf(MatrixService);
|
||||
});
|
||||
});
|
||||
|
||||
it("should provide StitcherService", () => {
|
||||
const stitcherService = module.get<StitcherService>(StitcherService);
|
||||
expect(stitcherService).toBeDefined();
|
||||
expect(stitcherService).toBeInstanceOf(StitcherService);
|
||||
describe("with neither bridge configured", () => {
|
||||
let module: TestingModule;
|
||||
|
||||
beforeEach(async () => {
|
||||
// No env vars set for either bridge
|
||||
module = await compileModule();
|
||||
});
|
||||
|
||||
it("should compile the module without errors", () => {
|
||||
expect(module).toBeDefined();
|
||||
});
|
||||
|
||||
it("should provide CHAT_PROVIDERS as an empty array", () => {
|
||||
const chatProviders = module.get<IChatProvider[]>(CHAT_PROVIDERS);
|
||||
expect(chatProviders).toBeDefined();
|
||||
expect(chatProviders).toHaveLength(0);
|
||||
expect(Array.isArray(chatProviders)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("CHAT_PROVIDERS token", () => {
|
||||
it("should be a string constant", () => {
|
||||
expect(CHAT_PROVIDERS).toBe("CHAT_PROVIDERS");
|
||||
expect(typeof CHAT_PROVIDERS).toBe("string");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,16 +1,81 @@
|
||||
import { Module } from "@nestjs/common";
|
||||
import { Logger, Module } from "@nestjs/common";
|
||||
import { DiscordService } from "./discord/discord.service";
|
||||
import { MatrixService } from "./matrix/matrix.service";
|
||||
import { MatrixRoomService } from "./matrix/matrix-room.service";
|
||||
import { MatrixStreamingService } from "./matrix/matrix-streaming.service";
|
||||
import { CommandParserService } from "./parser/command-parser.service";
|
||||
import { StitcherModule } from "../stitcher/stitcher.module";
|
||||
import { CHAT_PROVIDERS } from "./bridge.constants";
|
||||
import type { IChatProvider } from "./interfaces";
|
||||
|
||||
const logger = new Logger("BridgeModule");
|
||||
|
||||
/**
|
||||
* Bridge Module - Chat platform integrations
|
||||
*
|
||||
* Provides integration with chat platforms (Discord, Slack, Matrix, etc.)
|
||||
* Provides integration with chat platforms (Discord, Matrix, etc.)
|
||||
* for controlling Mosaic Stack via chat commands.
|
||||
*
|
||||
* Both services are always registered as providers, but the CHAT_PROVIDERS
|
||||
* injection token only includes bridges whose environment variables are set:
|
||||
* - Discord: included when DISCORD_BOT_TOKEN is set
|
||||
* - Matrix: included when MATRIX_ACCESS_TOKEN is set
|
||||
*
|
||||
* Both bridges can run simultaneously, and no error occurs if neither is configured.
|
||||
* Consumers should inject CHAT_PROVIDERS for bridge-agnostic access to all active providers.
|
||||
*
|
||||
* CommandParserService provides shared, platform-agnostic command parsing.
|
||||
* MatrixRoomService handles workspace-to-Matrix-room mapping.
|
||||
*/
|
||||
@Module({
|
||||
imports: [StitcherModule],
|
||||
providers: [DiscordService],
|
||||
exports: [DiscordService],
|
||||
providers: [
|
||||
CommandParserService,
|
||||
MatrixRoomService,
|
||||
MatrixStreamingService,
|
||||
DiscordService,
|
||||
MatrixService,
|
||||
{
|
||||
provide: CHAT_PROVIDERS,
|
||||
useFactory: (discord: DiscordService, matrix: MatrixService): IChatProvider[] => {
|
||||
const providers: IChatProvider[] = [];
|
||||
|
||||
if (process.env.DISCORD_BOT_TOKEN) {
|
||||
providers.push(discord);
|
||||
logger.log("Discord bridge enabled (DISCORD_BOT_TOKEN detected)");
|
||||
}
|
||||
|
||||
if (process.env.MATRIX_ACCESS_TOKEN) {
|
||||
const missingVars = [
|
||||
"MATRIX_HOMESERVER_URL",
|
||||
"MATRIX_BOT_USER_ID",
|
||||
"MATRIX_WORKSPACE_ID",
|
||||
].filter((v) => !process.env[v]);
|
||||
if (missingVars.length > 0) {
|
||||
logger.warn(
|
||||
`Matrix bridge enabled but missing: ${missingVars.join(", ")}. connect() will fail.`
|
||||
);
|
||||
}
|
||||
providers.push(matrix);
|
||||
logger.log("Matrix bridge enabled (MATRIX_ACCESS_TOKEN detected)");
|
||||
}
|
||||
|
||||
if (providers.length === 0) {
|
||||
logger.warn("No chat bridges configured. Set DISCORD_BOT_TOKEN or MATRIX_ACCESS_TOKEN.");
|
||||
}
|
||||
|
||||
return providers;
|
||||
},
|
||||
inject: [DiscordService, MatrixService],
|
||||
},
|
||||
],
|
||||
exports: [
|
||||
DiscordService,
|
||||
MatrixService,
|
||||
MatrixRoomService,
|
||||
MatrixStreamingService,
|
||||
CommandParserService,
|
||||
CHAT_PROVIDERS,
|
||||
],
|
||||
})
|
||||
export class BridgeModule {}
|
||||
|
||||
@@ -187,6 +187,7 @@ describe("DiscordService", () => {
|
||||
await service.connect();
|
||||
await service.sendThreadMessage({
|
||||
threadId: "thread-123",
|
||||
channelId: "test-channel-id",
|
||||
content: "Step completed",
|
||||
});
|
||||
|
||||
|
||||
@@ -305,6 +305,7 @@ export class DiscordService implements IChatProvider {
|
||||
// Send confirmation to thread
|
||||
await this.sendThreadMessage({
|
||||
threadId,
|
||||
channelId: message.channelId,
|
||||
content: `Job created: ${result.jobId}\nStatus: ${result.status}\nQueue: ${result.queueName}`,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ export interface ThreadCreateOptions {
|
||||
|
||||
export interface ThreadMessageOptions {
|
||||
threadId: string;
|
||||
channelId: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
@@ -76,4 +77,17 @@ export interface IChatProvider {
|
||||
* Parse a command from a message
|
||||
*/
|
||||
parseCommand(message: ChatMessage): ChatCommand | null;
|
||||
|
||||
/**
|
||||
* Edit an existing message in a channel.
|
||||
*
|
||||
* Optional method for providers that support message editing
|
||||
* (e.g., Matrix via m.replace, Discord via message.edit).
|
||||
* Used for streaming AI responses with incremental updates.
|
||||
*
|
||||
* @param channelId - The channel/room ID
|
||||
* @param messageId - The original message/event ID to edit
|
||||
* @param content - The updated message content
|
||||
*/
|
||||
editMessage?(channelId: string, messageId: string, content: string): Promise<void>;
|
||||
}
|
||||
|
||||
4
apps/api/src/bridge/matrix/index.ts
Normal file
4
apps/api/src/bridge/matrix/index.ts
Normal file
@@ -0,0 +1,4 @@
|
||||
export { MatrixService } from "./matrix.service";
|
||||
export { MatrixRoomService } from "./matrix-room.service";
|
||||
export { MatrixStreamingService } from "./matrix-streaming.service";
|
||||
export type { StreamResponseOptions } from "./matrix-streaming.service";
|
||||
1065
apps/api/src/bridge/matrix/matrix-bridge.integration.spec.ts
Normal file
1065
apps/api/src/bridge/matrix/matrix-bridge.integration.spec.ts
Normal file
File diff suppressed because it is too large
Load Diff
212
apps/api/src/bridge/matrix/matrix-room.service.spec.ts
Normal file
212
apps/api/src/bridge/matrix/matrix-room.service.spec.ts
Normal file
@@ -0,0 +1,212 @@
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { MatrixRoomService } from "./matrix-room.service";
|
||||
import { MatrixService } from "./matrix.service";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { vi, describe, it, expect, beforeEach } from "vitest";
|
||||
|
||||
// Mock matrix-bot-sdk to avoid native module import errors
|
||||
vi.mock("matrix-bot-sdk", () => {
|
||||
return {
|
||||
MatrixClient: class MockMatrixClient {},
|
||||
SimpleFsStorageProvider: class MockStorageProvider {
|
||||
constructor(_filename: string) {
|
||||
// No-op for testing
|
||||
}
|
||||
},
|
||||
AutojoinRoomsMixin: {
|
||||
setupOnClient: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe("MatrixRoomService", () => {
|
||||
let service: MatrixRoomService;
|
||||
|
||||
const mockCreateRoom = vi.fn().mockResolvedValue("!new-room:example.com");
|
||||
|
||||
const mockMatrixClient = {
|
||||
createRoom: mockCreateRoom,
|
||||
};
|
||||
|
||||
const mockMatrixService = {
|
||||
isConnected: vi.fn().mockReturnValue(true),
|
||||
getClient: vi.fn().mockReturnValue(mockMatrixClient),
|
||||
};
|
||||
|
||||
const mockPrismaService = {
|
||||
workspace: {
|
||||
findUnique: vi.fn(),
|
||||
findFirst: vi.fn(),
|
||||
update: vi.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
process.env.MATRIX_SERVER_NAME = "example.com";
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixRoomService,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: mockPrismaService,
|
||||
},
|
||||
{
|
||||
provide: MatrixService,
|
||||
useValue: mockMatrixService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<MatrixRoomService>(MatrixRoomService);
|
||||
|
||||
vi.clearAllMocks();
|
||||
// Restore defaults after clearing
|
||||
mockMatrixService.isConnected.mockReturnValue(true);
|
||||
mockCreateRoom.mockResolvedValue("!new-room:example.com");
|
||||
mockPrismaService.workspace.update.mockResolvedValue({});
|
||||
});
|
||||
|
||||
describe("provisionRoom", () => {
|
||||
it("should create a Matrix room and store the mapping", async () => {
|
||||
const roomId = await service.provisionRoom(
|
||||
"workspace-uuid-1",
|
||||
"My Workspace",
|
||||
"my-workspace"
|
||||
);
|
||||
|
||||
expect(roomId).toBe("!new-room:example.com");
|
||||
|
||||
expect(mockCreateRoom).toHaveBeenCalledWith({
|
||||
name: "Mosaic: My Workspace",
|
||||
room_alias_name: "mosaic-my-workspace",
|
||||
topic: "Mosaic workspace: My Workspace",
|
||||
preset: "private_chat",
|
||||
visibility: "private",
|
||||
});
|
||||
|
||||
expect(mockPrismaService.workspace.update).toHaveBeenCalledWith({
|
||||
where: { id: "workspace-uuid-1" },
|
||||
data: { matrixRoomId: "!new-room:example.com" },
|
||||
});
|
||||
});
|
||||
|
||||
it("should return null when Matrix is not configured (no MatrixService)", async () => {
|
||||
// Create a service without MatrixService
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixRoomService,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: mockPrismaService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const serviceWithoutMatrix = module.get<MatrixRoomService>(MatrixRoomService);
|
||||
|
||||
const roomId = await serviceWithoutMatrix.provisionRoom(
|
||||
"workspace-uuid-1",
|
||||
"My Workspace",
|
||||
"my-workspace"
|
||||
);
|
||||
|
||||
expect(roomId).toBeNull();
|
||||
expect(mockCreateRoom).not.toHaveBeenCalled();
|
||||
expect(mockPrismaService.workspace.update).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should return null when Matrix is not connected", async () => {
|
||||
mockMatrixService.isConnected.mockReturnValue(false);
|
||||
|
||||
const roomId = await service.provisionRoom(
|
||||
"workspace-uuid-1",
|
||||
"My Workspace",
|
||||
"my-workspace"
|
||||
);
|
||||
|
||||
expect(roomId).toBeNull();
|
||||
expect(mockCreateRoom).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("getRoomForWorkspace", () => {
|
||||
it("should return the room ID for a mapped workspace", async () => {
|
||||
mockPrismaService.workspace.findUnique.mockResolvedValue({
|
||||
matrixRoomId: "!mapped-room:example.com",
|
||||
});
|
||||
|
||||
const roomId = await service.getRoomForWorkspace("workspace-uuid-1");
|
||||
|
||||
expect(roomId).toBe("!mapped-room:example.com");
|
||||
expect(mockPrismaService.workspace.findUnique).toHaveBeenCalledWith({
|
||||
where: { id: "workspace-uuid-1" },
|
||||
select: { matrixRoomId: true },
|
||||
});
|
||||
});
|
||||
|
||||
it("should return null for an unmapped workspace", async () => {
|
||||
mockPrismaService.workspace.findUnique.mockResolvedValue({
|
||||
matrixRoomId: null,
|
||||
});
|
||||
|
||||
const roomId = await service.getRoomForWorkspace("workspace-uuid-2");
|
||||
|
||||
expect(roomId).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for a non-existent workspace", async () => {
|
||||
mockPrismaService.workspace.findUnique.mockResolvedValue(null);
|
||||
|
||||
const roomId = await service.getRoomForWorkspace("non-existent-uuid");
|
||||
|
||||
expect(roomId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("getWorkspaceForRoom", () => {
|
||||
it("should return the workspace ID for a mapped room", async () => {
|
||||
mockPrismaService.workspace.findFirst.mockResolvedValue({
|
||||
id: "workspace-uuid-1",
|
||||
});
|
||||
|
||||
const workspaceId = await service.getWorkspaceForRoom("!mapped-room:example.com");
|
||||
|
||||
expect(workspaceId).toBe("workspace-uuid-1");
|
||||
expect(mockPrismaService.workspace.findFirst).toHaveBeenCalledWith({
|
||||
where: { matrixRoomId: "!mapped-room:example.com" },
|
||||
select: { id: true },
|
||||
});
|
||||
});
|
||||
|
||||
it("should return null for an unmapped room", async () => {
|
||||
mockPrismaService.workspace.findFirst.mockResolvedValue(null);
|
||||
|
||||
const workspaceId = await service.getWorkspaceForRoom("!unknown-room:example.com");
|
||||
|
||||
expect(workspaceId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("linkWorkspaceToRoom", () => {
|
||||
it("should store the room mapping in the workspace", async () => {
|
||||
await service.linkWorkspaceToRoom("workspace-uuid-1", "!existing-room:example.com");
|
||||
|
||||
expect(mockPrismaService.workspace.update).toHaveBeenCalledWith({
|
||||
where: { id: "workspace-uuid-1" },
|
||||
data: { matrixRoomId: "!existing-room:example.com" },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("unlinkWorkspace", () => {
|
||||
it("should remove the room mapping from the workspace", async () => {
|
||||
await service.unlinkWorkspace("workspace-uuid-1");
|
||||
|
||||
expect(mockPrismaService.workspace.update).toHaveBeenCalledWith({
|
||||
where: { id: "workspace-uuid-1" },
|
||||
data: { matrixRoomId: null },
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
154
apps/api/src/bridge/matrix/matrix-room.service.ts
Normal file
154
apps/api/src/bridge/matrix/matrix-room.service.ts
Normal file
@@ -0,0 +1,154 @@
|
||||
import { Injectable, Logger, Optional, Inject } from "@nestjs/common";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { MatrixService } from "./matrix.service";
|
||||
import type { MatrixClient, RoomCreateOptions } from "matrix-bot-sdk";
|
||||
|
||||
/**
|
||||
* MatrixRoomService - Workspace-to-Matrix-Room mapping and provisioning
|
||||
*
|
||||
* Responsibilities:
|
||||
* - Provision Matrix rooms for Mosaic workspaces
|
||||
* - Map workspaces to Matrix room IDs
|
||||
* - Link/unlink existing rooms to workspaces
|
||||
*
|
||||
* Room provisioning creates a private Matrix room with:
|
||||
* - Name: "Mosaic: {workspace_name}"
|
||||
* - Alias: #mosaic-{workspace_slug}:{server_name}
|
||||
* - Room ID stored in workspace.matrixRoomId
|
||||
*/
|
||||
@Injectable()
|
||||
export class MatrixRoomService {
|
||||
private readonly logger = new Logger(MatrixRoomService.name);
|
||||
|
||||
constructor(
|
||||
private readonly prisma: PrismaService,
|
||||
@Optional() @Inject(MatrixService) private readonly matrixService: MatrixService | null
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Provision a Matrix room for a workspace and store the mapping.
|
||||
*
|
||||
* @param workspaceId - The workspace UUID
|
||||
* @param workspaceName - Human-readable workspace name
|
||||
* @param workspaceSlug - URL-safe workspace identifier for the room alias
|
||||
* @returns The Matrix room ID, or null if Matrix is not configured
|
||||
*/
|
||||
async provisionRoom(
|
||||
workspaceId: string,
|
||||
workspaceName: string,
|
||||
workspaceSlug: string
|
||||
): Promise<string | null> {
|
||||
if (!this.matrixService?.isConnected()) {
|
||||
this.logger.warn("Matrix is not configured or not connected; skipping room provisioning");
|
||||
return null;
|
||||
}
|
||||
|
||||
const client = this.getMatrixClient();
|
||||
if (!client) {
|
||||
this.logger.warn("Matrix client is not available; skipping room provisioning");
|
||||
return null;
|
||||
}
|
||||
|
||||
const roomOptions: RoomCreateOptions = {
|
||||
name: `Mosaic: ${workspaceName}`,
|
||||
room_alias_name: `mosaic-${workspaceSlug}`,
|
||||
topic: `Mosaic workspace: ${workspaceName}`,
|
||||
preset: "private_chat",
|
||||
visibility: "private",
|
||||
};
|
||||
|
||||
this.logger.log(
|
||||
`Provisioning Matrix room for workspace "${workspaceName}" (${workspaceId})...`
|
||||
);
|
||||
|
||||
const roomId = await client.createRoom(roomOptions);
|
||||
|
||||
// Store the room mapping
|
||||
try {
|
||||
await this.prisma.workspace.update({
|
||||
where: { id: workspaceId },
|
||||
data: { matrixRoomId: roomId },
|
||||
});
|
||||
} catch (dbError: unknown) {
|
||||
this.logger.error(
|
||||
`Failed to store room mapping for workspace ${workspaceId}, room ${roomId} may be orphaned: ${dbError instanceof Error ? dbError.message : "unknown"}`
|
||||
);
|
||||
throw dbError;
|
||||
}
|
||||
|
||||
this.logger.log(`Matrix room ${roomId} provisioned and linked to workspace ${workspaceId}`);
|
||||
|
||||
return roomId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Look up the Matrix room ID mapped to a workspace.
|
||||
*
|
||||
* @param workspaceId - The workspace UUID
|
||||
* @returns The Matrix room ID, or null if no room is mapped
|
||||
*/
|
||||
async getRoomForWorkspace(workspaceId: string): Promise<string | null> {
|
||||
const workspace = await this.prisma.workspace.findUnique({
|
||||
where: { id: workspaceId },
|
||||
select: { matrixRoomId: true },
|
||||
});
|
||||
|
||||
if (!workspace) {
|
||||
return null;
|
||||
}
|
||||
return workspace.matrixRoomId ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverse lookup: find the workspace that owns a given Matrix room.
|
||||
*
|
||||
* @param roomId - The Matrix room ID (e.g. "!abc:example.com")
|
||||
* @returns The workspace ID, or null if the room is not mapped to any workspace
|
||||
*/
|
||||
async getWorkspaceForRoom(roomId: string): Promise<string | null> {
|
||||
const workspace = await this.prisma.workspace.findFirst({
|
||||
where: { matrixRoomId: roomId },
|
||||
select: { id: true },
|
||||
});
|
||||
|
||||
return workspace?.id ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually link an existing Matrix room to a workspace.
|
||||
*
|
||||
* @param workspaceId - The workspace UUID
|
||||
* @param roomId - The Matrix room ID to link
|
||||
*/
|
||||
async linkWorkspaceToRoom(workspaceId: string, roomId: string): Promise<void> {
|
||||
await this.prisma.workspace.update({
|
||||
where: { id: workspaceId },
|
||||
data: { matrixRoomId: roomId },
|
||||
});
|
||||
|
||||
this.logger.log(`Linked workspace ${workspaceId} to Matrix room ${roomId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove the Matrix room mapping from a workspace.
|
||||
*
|
||||
* @param workspaceId - The workspace UUID
|
||||
*/
|
||||
async unlinkWorkspace(workspaceId: string): Promise<void> {
|
||||
await this.prisma.workspace.update({
|
||||
where: { id: workspaceId },
|
||||
data: { matrixRoomId: null },
|
||||
});
|
||||
|
||||
this.logger.log(`Unlinked Matrix room from workspace ${workspaceId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Access the underlying MatrixClient from the MatrixService
|
||||
* via the public getClient() accessor.
|
||||
*/
|
||||
private getMatrixClient(): MatrixClient | null {
|
||||
if (!this.matrixService) return null;
|
||||
return this.matrixService.getClient();
|
||||
}
|
||||
}
|
||||
408
apps/api/src/bridge/matrix/matrix-streaming.service.spec.ts
Normal file
408
apps/api/src/bridge/matrix/matrix-streaming.service.spec.ts
Normal file
@@ -0,0 +1,408 @@
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { MatrixStreamingService } from "./matrix-streaming.service";
|
||||
import { MatrixService } from "./matrix.service";
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from "vitest";
|
||||
import type { StreamResponseOptions } from "./matrix-streaming.service";
|
||||
|
||||
// Mock matrix-bot-sdk to prevent native module loading
|
||||
vi.mock("matrix-bot-sdk", () => {
|
||||
return {
|
||||
MatrixClient: class MockMatrixClient {},
|
||||
SimpleFsStorageProvider: class MockStorageProvider {
|
||||
constructor(_filename: string) {
|
||||
// No-op for testing
|
||||
}
|
||||
},
|
||||
AutojoinRoomsMixin: {
|
||||
setupOnClient: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// Mock MatrixClient
|
||||
const mockClient = {
|
||||
sendMessage: vi.fn().mockResolvedValue("$initial-event-id"),
|
||||
sendEvent: vi.fn().mockResolvedValue("$edit-event-id"),
|
||||
setTyping: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
// Mock MatrixService
|
||||
const mockMatrixService = {
|
||||
isConnected: vi.fn().mockReturnValue(true),
|
||||
getClient: vi.fn().mockReturnValue(mockClient),
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper: create an async iterable from an array of strings with optional delays
|
||||
*/
|
||||
async function* createTokenStream(
|
||||
tokens: string[],
|
||||
delayMs = 0
|
||||
): AsyncGenerator<string, void, undefined> {
|
||||
for (const token of tokens) {
|
||||
if (delayMs > 0) {
|
||||
await new Promise((resolve) => setTimeout(resolve, delayMs));
|
||||
}
|
||||
yield token;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper: create a token stream that throws an error mid-stream
|
||||
*/
|
||||
async function* createErrorStream(
|
||||
tokens: string[],
|
||||
errorAfter: number
|
||||
): AsyncGenerator<string, void, undefined> {
|
||||
let count = 0;
|
||||
for (const token of tokens) {
|
||||
if (count >= errorAfter) {
|
||||
throw new Error("LLM provider connection lost");
|
||||
}
|
||||
yield token;
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
describe("MatrixStreamingService", () => {
|
||||
let service: MatrixStreamingService;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true });
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixStreamingService,
|
||||
{
|
||||
provide: MatrixService,
|
||||
useValue: mockMatrixService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<MatrixStreamingService>(MatrixStreamingService);
|
||||
|
||||
// Clear all mocks
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Re-apply default mock returns after clearing
|
||||
mockMatrixService.isConnected.mockReturnValue(true);
|
||||
mockMatrixService.getClient.mockReturnValue(mockClient);
|
||||
mockClient.sendMessage.mockResolvedValue("$initial-event-id");
|
||||
mockClient.sendEvent.mockResolvedValue("$edit-event-id");
|
||||
mockClient.setTyping.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe("editMessage", () => {
|
||||
it("should send a m.replace event to edit an existing message", async () => {
|
||||
await service.editMessage("!room:example.com", "$original-event-id", "Updated content");
|
||||
|
||||
expect(mockClient.sendEvent).toHaveBeenCalledWith("!room:example.com", "m.room.message", {
|
||||
"m.new_content": {
|
||||
msgtype: "m.text",
|
||||
body: "Updated content",
|
||||
},
|
||||
"m.relates_to": {
|
||||
rel_type: "m.replace",
|
||||
event_id: "$original-event-id",
|
||||
},
|
||||
// Fallback for clients that don't support edits
|
||||
msgtype: "m.text",
|
||||
body: "* Updated content",
|
||||
});
|
||||
});
|
||||
|
||||
it("should throw error when client is not connected", async () => {
|
||||
mockMatrixService.isConnected.mockReturnValue(false);
|
||||
|
||||
await expect(
|
||||
service.editMessage("!room:example.com", "$event-id", "content")
|
||||
).rejects.toThrow("Matrix client is not connected");
|
||||
});
|
||||
|
||||
it("should throw error when client is null", async () => {
|
||||
mockMatrixService.getClient.mockReturnValue(null);
|
||||
|
||||
await expect(
|
||||
service.editMessage("!room:example.com", "$event-id", "content")
|
||||
).rejects.toThrow("Matrix client is not connected");
|
||||
});
|
||||
});
|
||||
|
||||
describe("setTypingIndicator", () => {
|
||||
it("should call client.setTyping with true and timeout", async () => {
|
||||
await service.setTypingIndicator("!room:example.com", true);
|
||||
|
||||
expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", true, 30000);
|
||||
});
|
||||
|
||||
it("should call client.setTyping with false to clear indicator", async () => {
|
||||
await service.setTypingIndicator("!room:example.com", false);
|
||||
|
||||
expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", false, undefined);
|
||||
});
|
||||
|
||||
it("should throw error when client is not connected", async () => {
|
||||
mockMatrixService.isConnected.mockReturnValue(false);
|
||||
|
||||
await expect(service.setTypingIndicator("!room:example.com", true)).rejects.toThrow(
|
||||
"Matrix client is not connected"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("sendStreamingMessage", () => {
|
||||
it("should send an initial message and return the event ID", async () => {
|
||||
const eventId = await service.sendStreamingMessage("!room:example.com", "Thinking...");
|
||||
|
||||
expect(eventId).toBe("$initial-event-id");
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "Thinking...",
|
||||
});
|
||||
});
|
||||
|
||||
it("should send a thread message when threadId is provided", async () => {
|
||||
const eventId = await service.sendStreamingMessage(
|
||||
"!room:example.com",
|
||||
"Thinking...",
|
||||
"$thread-root-id"
|
||||
);
|
||||
|
||||
expect(eventId).toBe("$initial-event-id");
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "Thinking...",
|
||||
"m.relates_to": {
|
||||
rel_type: "m.thread",
|
||||
event_id: "$thread-root-id",
|
||||
is_falling_back: true,
|
||||
"m.in_reply_to": {
|
||||
event_id: "$thread-root-id",
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should throw error when client is not connected", async () => {
|
||||
mockMatrixService.isConnected.mockReturnValue(false);
|
||||
|
||||
await expect(service.sendStreamingMessage("!room:example.com", "Test")).rejects.toThrow(
|
||||
"Matrix client is not connected"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("streamResponse", () => {
|
||||
it("should send initial 'Thinking...' message and start typing indicator", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Hello", " world"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// Should have sent initial message
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!room:example.com",
|
||||
expect.objectContaining({
|
||||
msgtype: "m.text",
|
||||
body: "Thinking...",
|
||||
})
|
||||
);
|
||||
|
||||
// Should have started typing indicator
|
||||
expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", true, 30000);
|
||||
});
|
||||
|
||||
it("should use custom initial message when provided", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Hi"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
const options: StreamResponseOptions = { initialMessage: "Processing..." };
|
||||
await service.streamResponse("!room:example.com", stream, options);
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!room:example.com",
|
||||
expect.objectContaining({
|
||||
body: "Processing...",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should edit message with accumulated tokens on completion", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Hello", " ", "world", "!"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// The final edit should contain the full accumulated text
|
||||
const sendEventCalls = mockClient.sendEvent.mock.calls;
|
||||
const lastEditCall = sendEventCalls[sendEventCalls.length - 1];
|
||||
|
||||
expect(lastEditCall).toBeDefined();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
||||
expect(lastEditCall[2]["m.new_content"].body).toBe("Hello world!");
|
||||
});
|
||||
|
||||
it("should clear typing indicator on completion", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Done"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// Last setTyping call should be false
|
||||
const typingCalls = mockClient.setTyping.mock.calls;
|
||||
const lastTypingCall = typingCalls[typingCalls.length - 1];
|
||||
|
||||
expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
|
||||
});
|
||||
|
||||
it("should rate-limit edits to at most one every 500ms", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
// Send tokens with small delays - all within one 500ms window
|
||||
const tokens = ["a", "b", "c", "d", "e"];
|
||||
const stream = createTokenStream(tokens, 50); // 50ms between tokens = 250ms total
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// With 250ms total streaming time (5 tokens * 50ms), all tokens arrive
|
||||
// within one 500ms window. We expect at most 1 intermediate edit + 1 final edit,
|
||||
// or just the final edit. The key point is that there should NOT be 5 separate edits.
|
||||
const editCalls = mockClient.sendEvent.mock.calls.filter(
|
||||
(call) => call[1] === "m.room.message"
|
||||
);
|
||||
|
||||
// Should have fewer edits than tokens (rate limiting in effect)
|
||||
expect(editCalls.length).toBeLessThanOrEqual(2);
|
||||
// Should have at least the final edit
|
||||
expect(editCalls.length).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
it("should handle errors gracefully and edit message with error notice", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const stream = createErrorStream(["Hello", " ", "world"], 2);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// Should edit message with error content
|
||||
const sendEventCalls = mockClient.sendEvent.mock.calls;
|
||||
const lastEditCall = sendEventCalls[sendEventCalls.length - 1];
|
||||
|
||||
expect(lastEditCall).toBeDefined();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
||||
const finalBody = lastEditCall[2]["m.new_content"].body as string;
|
||||
expect(finalBody).toContain("error");
|
||||
|
||||
// Should clear typing on error
|
||||
const typingCalls = mockClient.setTyping.mock.calls;
|
||||
const lastTypingCall = typingCalls[typingCalls.length - 1];
|
||||
expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
|
||||
});
|
||||
|
||||
it("should include token usage in final message when provided", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Hello"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
const options: StreamResponseOptions = {
|
||||
showTokenUsage: true,
|
||||
tokenUsage: { prompt: 10, completion: 5, total: 15 },
|
||||
};
|
||||
|
||||
await service.streamResponse("!room:example.com", stream, options);
|
||||
|
||||
const sendEventCalls = mockClient.sendEvent.mock.calls;
|
||||
const lastEditCall = sendEventCalls[sendEventCalls.length - 1];
|
||||
|
||||
expect(lastEditCall).toBeDefined();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
|
||||
const finalBody = lastEditCall[2]["m.new_content"].body as string;
|
||||
expect(finalBody).toContain("15");
|
||||
});
|
||||
|
||||
it("should throw error when client is not connected", async () => {
|
||||
mockMatrixService.isConnected.mockReturnValue(false);
|
||||
|
||||
const stream = createTokenStream(["test"]);
|
||||
|
||||
await expect(service.streamResponse("!room:example.com", stream)).rejects.toThrow(
|
||||
"Matrix client is not connected"
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle empty token stream", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const stream = createTokenStream([]);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// Should still send initial message
|
||||
expect(mockClient.sendMessage).toHaveBeenCalled();
|
||||
|
||||
// Should edit with empty/no-content message
|
||||
const sendEventCalls = mockClient.sendEvent.mock.calls;
|
||||
expect(sendEventCalls.length).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Should clear typing
|
||||
const typingCalls = mockClient.setTyping.mock.calls;
|
||||
const lastTypingCall = typingCalls[typingCalls.length - 1];
|
||||
expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
|
||||
});
|
||||
|
||||
it("should support thread context in streamResponse", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
const tokens = ["Reply"];
|
||||
const stream = createTokenStream(tokens);
|
||||
|
||||
const options: StreamResponseOptions = { threadId: "$thread-root" };
|
||||
await service.streamResponse("!room:example.com", stream, options);
|
||||
|
||||
// Initial message should include thread relation
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!room:example.com",
|
||||
expect.objectContaining({
|
||||
"m.relates_to": expect.objectContaining({
|
||||
rel_type: "m.thread",
|
||||
event_id: "$thread-root",
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should perform multiple edits for long-running streams", async () => {
|
||||
vi.useRealTimers();
|
||||
|
||||
// Create tokens with 200ms delays - total ~2000ms, should get multiple edit windows
|
||||
const tokens = Array.from({ length: 10 }, (_, i) => `token${String(i)} `);
|
||||
const stream = createTokenStream(tokens, 200);
|
||||
|
||||
await service.streamResponse("!room:example.com", stream);
|
||||
|
||||
// With 10 tokens at 200ms each = 2000ms total, at 500ms intervals
|
||||
// we expect roughly 3-4 intermediate edits + 1 final = 4-5 total
|
||||
const editCalls = mockClient.sendEvent.mock.calls.filter(
|
||||
(call) => call[1] === "m.room.message"
|
||||
);
|
||||
|
||||
// Should have multiple edits (at least 2) but far fewer than 10
|
||||
expect(editCalls.length).toBeGreaterThanOrEqual(2);
|
||||
expect(editCalls.length).toBeLessThanOrEqual(8);
|
||||
});
|
||||
});
|
||||
});
|
||||
248
apps/api/src/bridge/matrix/matrix-streaming.service.ts
Normal file
248
apps/api/src/bridge/matrix/matrix-streaming.service.ts
Normal file
@@ -0,0 +1,248 @@
|
||||
import { Injectable, Logger } from "@nestjs/common";
|
||||
import type { MatrixClient } from "matrix-bot-sdk";
|
||||
import { MatrixService } from "./matrix.service";
|
||||
|
||||
/**
|
||||
* Options for the streamResponse method
|
||||
*/
|
||||
export interface StreamResponseOptions {
|
||||
/** Custom initial message (defaults to "Thinking...") */
|
||||
initialMessage?: string;
|
||||
/** Thread root event ID for threaded responses */
|
||||
threadId?: string;
|
||||
/** Whether to show token usage in the final message */
|
||||
showTokenUsage?: boolean;
|
||||
/** Token usage stats to display in the final message */
|
||||
tokenUsage?: { prompt: number; completion: number; total: number };
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix message content for m.room.message events
|
||||
*/
|
||||
interface MatrixMessageContent {
|
||||
msgtype: string;
|
||||
body: string;
|
||||
"m.new_content"?: {
|
||||
msgtype: string;
|
||||
body: string;
|
||||
};
|
||||
"m.relates_to"?: {
|
||||
rel_type: string;
|
||||
event_id: string;
|
||||
is_falling_back?: boolean;
|
||||
"m.in_reply_to"?: {
|
||||
event_id: string;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/** Minimum interval between message edits (milliseconds) */
|
||||
const EDIT_INTERVAL_MS = 500;
|
||||
|
||||
/** Typing indicator timeout (milliseconds) */
|
||||
const TYPING_TIMEOUT_MS = 30000;
|
||||
|
||||
/**
|
||||
* Matrix Streaming Service
|
||||
*
|
||||
* Provides streaming AI response capabilities for Matrix rooms using
|
||||
* incremental message edits. Tokens from an LLM are buffered and the
|
||||
* response message is edited at rate-limited intervals, providing a
|
||||
* smooth streaming experience without excessive API calls.
|
||||
*
|
||||
* Key features:
|
||||
* - Rate-limited edits (max every 500ms)
|
||||
* - Typing indicator management during generation
|
||||
* - Graceful error handling with user-visible error notices
|
||||
* - Thread support for contextual responses
|
||||
* - LLM-agnostic design via AsyncIterable<string> token stream
|
||||
*/
|
||||
@Injectable()
|
||||
export class MatrixStreamingService {
|
||||
private readonly logger = new Logger(MatrixStreamingService.name);
|
||||
|
||||
constructor(private readonly matrixService: MatrixService) {}
|
||||
|
||||
/**
|
||||
* Edit an existing Matrix message using the m.replace relation.
|
||||
*
|
||||
* Sends a new event that replaces the content of an existing message.
|
||||
* Includes fallback content for clients that don't support edits.
|
||||
*
|
||||
* @param roomId - The Matrix room ID
|
||||
* @param eventId - The original event ID to replace
|
||||
* @param newContent - The updated message text
|
||||
*/
|
||||
async editMessage(roomId: string, eventId: string, newContent: string): Promise<void> {
|
||||
const client = this.getClientOrThrow();
|
||||
|
||||
const editContent: MatrixMessageContent = {
|
||||
"m.new_content": {
|
||||
msgtype: "m.text",
|
||||
body: newContent,
|
||||
},
|
||||
"m.relates_to": {
|
||||
rel_type: "m.replace",
|
||||
event_id: eventId,
|
||||
},
|
||||
// Fallback for clients that don't support edits
|
||||
msgtype: "m.text",
|
||||
body: `* ${newContent}`,
|
||||
};
|
||||
|
||||
await client.sendEvent(roomId, "m.room.message", editContent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the typing indicator for the bot in a room.
|
||||
*
|
||||
* @param roomId - The Matrix room ID
|
||||
* @param typing - Whether the bot is typing
|
||||
*/
|
||||
async setTypingIndicator(roomId: string, typing: boolean): Promise<void> {
|
||||
const client = this.getClientOrThrow();
|
||||
|
||||
await client.setTyping(roomId, typing, typing ? TYPING_TIMEOUT_MS : undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* Send an initial message for streaming, optionally in a thread.
|
||||
*
|
||||
* Returns the event ID of the sent message, which can be used for
|
||||
* subsequent edits via editMessage.
|
||||
*
|
||||
* @param roomId - The Matrix room ID
|
||||
* @param content - The initial message content
|
||||
* @param threadId - Optional thread root event ID
|
||||
* @returns The event ID of the sent message
|
||||
*/
|
||||
async sendStreamingMessage(roomId: string, content: string, threadId?: string): Promise<string> {
|
||||
const client = this.getClientOrThrow();
|
||||
|
||||
const messageContent: MatrixMessageContent = {
|
||||
msgtype: "m.text",
|
||||
body: content,
|
||||
};
|
||||
|
||||
if (threadId) {
|
||||
messageContent["m.relates_to"] = {
|
||||
rel_type: "m.thread",
|
||||
event_id: threadId,
|
||||
is_falling_back: true,
|
||||
"m.in_reply_to": {
|
||||
event_id: threadId,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const eventId: string = await client.sendMessage(roomId, messageContent);
|
||||
return eventId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream an AI response to a Matrix room using incremental message edits.
|
||||
*
|
||||
* This is the main streaming method. It:
|
||||
* 1. Sends an initial "Thinking..." message
|
||||
* 2. Starts the typing indicator
|
||||
* 3. Buffers incoming tokens from the async iterable
|
||||
* 4. Edits the message every 500ms with accumulated text
|
||||
* 5. On completion: sends a final clean edit, clears typing
|
||||
* 6. On error: edits message with error notice, clears typing
|
||||
*
|
||||
* @param roomId - The Matrix room ID
|
||||
* @param tokenStream - AsyncIterable that yields string tokens
|
||||
* @param options - Optional configuration for the stream
|
||||
*/
|
||||
async streamResponse(
|
||||
roomId: string,
|
||||
tokenStream: AsyncIterable<string>,
|
||||
options?: StreamResponseOptions
|
||||
): Promise<void> {
|
||||
// Validate connection before starting
|
||||
this.getClientOrThrow();
|
||||
|
||||
const initialMessage = options?.initialMessage ?? "Thinking...";
|
||||
const threadId = options?.threadId;
|
||||
|
||||
// Step 1: Send initial message
|
||||
const eventId = await this.sendStreamingMessage(roomId, initialMessage, threadId);
|
||||
|
||||
// Step 2: Start typing indicator
|
||||
await this.setTypingIndicator(roomId, true);
|
||||
|
||||
// Step 3: Buffer and stream tokens
|
||||
let accumulatedText = "";
|
||||
let lastEditTime = 0;
|
||||
let hasError = false;
|
||||
|
||||
try {
|
||||
for await (const token of tokenStream) {
|
||||
accumulatedText += token;
|
||||
|
||||
const now = Date.now();
|
||||
const elapsed = now - lastEditTime;
|
||||
|
||||
if (elapsed >= EDIT_INTERVAL_MS && accumulatedText.length > 0) {
|
||||
await this.editMessage(roomId, eventId, accumulatedText);
|
||||
lastEditTime = now;
|
||||
}
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
hasError = true;
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error occurred";
|
||||
|
||||
this.logger.error(`Stream error in room ${roomId}: ${errorMessage}`);
|
||||
|
||||
// Edit message to show error
|
||||
try {
|
||||
const errorContent = accumulatedText
|
||||
? `${accumulatedText}\n\n[Streaming error: ${errorMessage}]`
|
||||
: `[Streaming error: ${errorMessage}]`;
|
||||
|
||||
await this.editMessage(roomId, eventId, errorContent);
|
||||
} catch (editError: unknown) {
|
||||
this.logger.warn(
|
||||
`Failed to edit error message in ${roomId}: ${editError instanceof Error ? editError.message : "unknown"}`
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
// Step 4: Clear typing indicator
|
||||
try {
|
||||
await this.setTypingIndicator(roomId, false);
|
||||
} catch (typingError: unknown) {
|
||||
this.logger.warn(
|
||||
`Failed to clear typing indicator in ${roomId}: ${typingError instanceof Error ? typingError.message : "unknown"}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Final edit with clean output (if no error)
|
||||
if (!hasError) {
|
||||
let finalContent = accumulatedText || "(No response generated)";
|
||||
|
||||
if (options?.showTokenUsage && options.tokenUsage) {
|
||||
const { prompt, completion, total } = options.tokenUsage;
|
||||
finalContent += `\n\n---\nTokens: ${String(total)} (prompt: ${String(prompt)}, completion: ${String(completion)})`;
|
||||
}
|
||||
|
||||
await this.editMessage(roomId, eventId, finalContent);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the Matrix client from the parent MatrixService, or throw if not connected.
|
||||
*/
|
||||
private getClientOrThrow(): MatrixClient {
|
||||
if (!this.matrixService.isConnected()) {
|
||||
throw new Error("Matrix client is not connected");
|
||||
}
|
||||
|
||||
const client = this.matrixService.getClient();
|
||||
if (!client) {
|
||||
throw new Error("Matrix client is not connected");
|
||||
}
|
||||
|
||||
return client;
|
||||
}
|
||||
}
|
||||
979
apps/api/src/bridge/matrix/matrix.service.spec.ts
Normal file
979
apps/api/src/bridge/matrix/matrix.service.spec.ts
Normal file
@@ -0,0 +1,979 @@
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { MatrixService } from "./matrix.service";
|
||||
import { MatrixRoomService } from "./matrix-room.service";
|
||||
import { StitcherService } from "../../stitcher/stitcher.service";
|
||||
import { CommandParserService } from "../parser/command-parser.service";
|
||||
import { vi, describe, it, expect, beforeEach } from "vitest";
|
||||
import type { ChatMessage } from "../interfaces";
|
||||
|
||||
// Mock matrix-bot-sdk
|
||||
const mockMessageCallbacks: Array<(roomId: string, event: Record<string, unknown>) => void> = [];
|
||||
const mockEventCallbacks: Array<(roomId: string, event: Record<string, unknown>) => void> = [];
|
||||
|
||||
const mockClient = {
|
||||
start: vi.fn().mockResolvedValue(undefined),
|
||||
stop: vi.fn(),
|
||||
on: vi
|
||||
.fn()
|
||||
.mockImplementation(
|
||||
(event: string, callback: (roomId: string, evt: Record<string, unknown>) => void) => {
|
||||
if (event === "room.message") {
|
||||
mockMessageCallbacks.push(callback);
|
||||
}
|
||||
if (event === "room.event") {
|
||||
mockEventCallbacks.push(callback);
|
||||
}
|
||||
}
|
||||
),
|
||||
sendMessage: vi.fn().mockResolvedValue("$event-id-123"),
|
||||
sendEvent: vi.fn().mockResolvedValue("$event-id-456"),
|
||||
};
|
||||
|
||||
vi.mock("matrix-bot-sdk", () => {
|
||||
return {
|
||||
MatrixClient: class MockMatrixClient {
|
||||
start = mockClient.start;
|
||||
stop = mockClient.stop;
|
||||
on = mockClient.on;
|
||||
sendMessage = mockClient.sendMessage;
|
||||
sendEvent = mockClient.sendEvent;
|
||||
},
|
||||
SimpleFsStorageProvider: class MockStorageProvider {
|
||||
constructor(_filename: string) {
|
||||
// No-op for testing
|
||||
}
|
||||
},
|
||||
AutojoinRoomsMixin: {
|
||||
setupOnClient: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe("MatrixService", () => {
|
||||
let service: MatrixService;
|
||||
let stitcherService: StitcherService;
|
||||
let commandParser: CommandParserService;
|
||||
let matrixRoomService: MatrixRoomService;
|
||||
|
||||
const mockStitcherService = {
|
||||
dispatchJob: vi.fn().mockResolvedValue({
|
||||
jobId: "test-job-id",
|
||||
queueName: "main",
|
||||
status: "PENDING",
|
||||
}),
|
||||
trackJobEvent: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const mockMatrixRoomService = {
|
||||
getWorkspaceForRoom: vi.fn().mockResolvedValue(null),
|
||||
getRoomForWorkspace: vi.fn().mockResolvedValue(null),
|
||||
provisionRoom: vi.fn().mockResolvedValue(null),
|
||||
linkWorkspaceToRoom: vi.fn().mockResolvedValue(undefined),
|
||||
unlinkWorkspace: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Set environment variables for testing
|
||||
process.env.MATRIX_HOMESERVER_URL = "https://matrix.example.com";
|
||||
process.env.MATRIX_ACCESS_TOKEN = "test-access-token";
|
||||
process.env.MATRIX_BOT_USER_ID = "@mosaic-bot:example.com";
|
||||
process.env.MATRIX_CONTROL_ROOM_ID = "!test-room:example.com";
|
||||
process.env.MATRIX_WORKSPACE_ID = "test-workspace-id";
|
||||
|
||||
// Clear callbacks
|
||||
mockMessageCallbacks.length = 0;
|
||||
mockEventCallbacks.length = 0;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
service = module.get<MatrixService>(MatrixService);
|
||||
stitcherService = module.get<StitcherService>(StitcherService);
|
||||
commandParser = module.get<CommandParserService>(CommandParserService);
|
||||
matrixRoomService = module.get(MatrixRoomService) as MatrixRoomService;
|
||||
|
||||
// Clear all mocks
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Connection Management", () => {
|
||||
it("should connect to Matrix", async () => {
|
||||
await service.connect();
|
||||
|
||||
expect(mockClient.start).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should disconnect from Matrix", async () => {
|
||||
await service.connect();
|
||||
await service.disconnect();
|
||||
|
||||
expect(mockClient.stop).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should check connection status", async () => {
|
||||
expect(service.isConnected()).toBe(false);
|
||||
|
||||
await service.connect();
|
||||
expect(service.isConnected()).toBe(true);
|
||||
|
||||
await service.disconnect();
|
||||
expect(service.isConnected()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Message Handling", () => {
|
||||
it("should send a message to a room", async () => {
|
||||
await service.connect();
|
||||
await service.sendMessage("!test-room:example.com", "Hello, Matrix!");
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!test-room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "Hello, Matrix!",
|
||||
});
|
||||
});
|
||||
|
||||
it("should throw error if client is not connected", async () => {
|
||||
await expect(service.sendMessage("!room:example.com", "Test")).rejects.toThrow(
|
||||
"Matrix client is not connected"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Thread Management", () => {
|
||||
it("should create a thread by sending an initial message", async () => {
|
||||
await service.connect();
|
||||
const threadId = await service.createThread({
|
||||
channelId: "!test-room:example.com",
|
||||
name: "Job #42",
|
||||
message: "Starting job...",
|
||||
});
|
||||
|
||||
expect(threadId).toBe("$event-id-123");
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!test-room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "[Job #42] Starting job...",
|
||||
});
|
||||
});
|
||||
|
||||
it("should send a message to a thread with m.thread relation", async () => {
|
||||
await service.connect();
|
||||
await service.sendThreadMessage({
|
||||
threadId: "$root-event-id",
|
||||
channelId: "!test-room:example.com",
|
||||
content: "Step completed",
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!test-room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "Step completed",
|
||||
"m.relates_to": {
|
||||
rel_type: "m.thread",
|
||||
event_id: "$root-event-id",
|
||||
is_falling_back: true,
|
||||
"m.in_reply_to": {
|
||||
event_id: "$root-event-id",
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should fall back to controlRoomId when channelId is empty", async () => {
|
||||
await service.connect();
|
||||
await service.sendThreadMessage({
|
||||
threadId: "$root-event-id",
|
||||
channelId: "",
|
||||
content: "Fallback message",
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith("!test-room:example.com", {
|
||||
msgtype: "m.text",
|
||||
body: "Fallback message",
|
||||
"m.relates_to": {
|
||||
rel_type: "m.thread",
|
||||
event_id: "$root-event-id",
|
||||
is_falling_back: true,
|
||||
"m.in_reply_to": {
|
||||
event_id: "$root-event-id",
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should throw error when creating thread without connection", async () => {
|
||||
await expect(
|
||||
service.createThread({
|
||||
channelId: "!room:example.com",
|
||||
name: "Test",
|
||||
message: "Test",
|
||||
})
|
||||
).rejects.toThrow("Matrix client is not connected");
|
||||
});
|
||||
|
||||
it("should throw error when sending thread message without connection", async () => {
|
||||
await expect(
|
||||
service.sendThreadMessage({
|
||||
threadId: "$event-id",
|
||||
channelId: "!room:example.com",
|
||||
content: "Test",
|
||||
})
|
||||
).rejects.toThrow("Matrix client is not connected");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Command Parsing with shared CommandParserService", () => {
|
||||
it("should parse @mosaic fix #42 via shared parser", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix #42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).not.toBeNull();
|
||||
expect(command?.command).toBe("fix");
|
||||
expect(command?.args).toContain("#42");
|
||||
});
|
||||
|
||||
it("should parse !mosaic fix #42 by normalizing to @mosaic for the shared parser", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "!mosaic fix #42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).not.toBeNull();
|
||||
expect(command?.command).toBe("fix");
|
||||
expect(command?.args).toContain("#42");
|
||||
});
|
||||
|
||||
it("should parse @mosaic status command via shared parser", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-2",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic status job-123",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).not.toBeNull();
|
||||
expect(command?.command).toBe("status");
|
||||
expect(command?.args).toContain("job-123");
|
||||
});
|
||||
|
||||
it("should parse @mosaic cancel command via shared parser", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-3",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic cancel job-456",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).not.toBeNull();
|
||||
expect(command?.command).toBe("cancel");
|
||||
});
|
||||
|
||||
it("should parse @mosaic help command via shared parser", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-6",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic help",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).not.toBeNull();
|
||||
expect(command?.command).toBe("help");
|
||||
});
|
||||
|
||||
it("should return null for non-command messages", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-7",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "Just a regular message",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for messages without @mosaic or !mosaic mention", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-8",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "fix 42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).toBeNull();
|
||||
});
|
||||
|
||||
it("should return null for @mosaic mention without a command", () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-11",
|
||||
channelId: "!room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
const command = service.parseCommand(message);
|
||||
|
||||
expect(command).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Event-driven message reception", () => {
|
||||
it("should ignore messages from the bot itself", async () => {
|
||||
await service.connect();
|
||||
|
||||
const parseCommandSpy = vi.spyOn(commandParser, "parseCommand");
|
||||
|
||||
// Simulate a message from the bot
|
||||
expect(mockMessageCallbacks.length).toBeGreaterThan(0);
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@mosaic-bot:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic fix #42",
|
||||
},
|
||||
});
|
||||
|
||||
// Should not attempt to parse
|
||||
expect(parseCommandSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should ignore messages in unmapped rooms", async () => {
|
||||
// MatrixRoomService returns null for unknown rooms
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue(null);
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!unknown-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic fix #42",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should not dispatch to stitcher
|
||||
expect(stitcherService.dispatchJob).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should process commands in the control room (fallback workspace)", async () => {
|
||||
// MatrixRoomService returns null, but room matches controlRoomId
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue(null);
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic help",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should send help message
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Available commands:"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should process commands in rooms mapped via MatrixRoomService", async () => {
|
||||
// MatrixRoomService resolves the workspace
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("mapped-workspace-id");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!mapped-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic fix #42",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should dispatch with the mapped workspace ID
|
||||
expect(stitcherService.dispatchJob).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceId: "mapped-workspace-id",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle !mosaic prefix in incoming messages", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("test-workspace-id");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "!mosaic help",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should send help message (normalized !mosaic -> @mosaic for parser)
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Available commands:"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should send help text when user tries an unknown command", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("test-workspace-id");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic invalidcommand",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should send error/help message (CommandParserService returns help text for unknown actions)
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Available commands"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should ignore non-text messages", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("test-workspace-id");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.image",
|
||||
body: "photo.jpg",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should not attempt any message sending
|
||||
expect(mockClient.sendMessage).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Command Execution", () => {
|
||||
it("should forward fix command to stitcher and create a thread", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix 42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "fix",
|
||||
args: ["42"],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(stitcherService.dispatchJob).toHaveBeenCalledWith({
|
||||
workspaceId: "test-workspace-id",
|
||||
type: "code-task",
|
||||
priority: 10,
|
||||
metadata: {
|
||||
issueNumber: 42,
|
||||
command: "fix",
|
||||
channelId: "!test-room:example.com",
|
||||
threadId: "$event-id-123",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle fix with #-prefixed issue number", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix #42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "fix",
|
||||
args: ["#42"],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(stitcherService.dispatchJob).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
metadata: expect.objectContaining({
|
||||
issueNumber: 42,
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should respond with help message", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic help",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "help",
|
||||
args: [],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Available commands:"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should include retry command in help output", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic help",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "help",
|
||||
args: [],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("retry"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should send error for fix command without issue number", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "fix",
|
||||
args: [],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Usage:"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should send error for fix command with non-numeric issue", async () => {
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix abc",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await service.connect();
|
||||
await service.handleCommand({
|
||||
command: "fix",
|
||||
args: ["abc"],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(mockClient.sendMessage).toHaveBeenCalledWith(
|
||||
"!test-room:example.com",
|
||||
expect.objectContaining({
|
||||
body: expect.stringContaining("Invalid issue number"),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should dispatch fix command with workspace from MatrixRoomService", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("dynamic-workspace-id");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!mapped-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic fix #99",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
expect(stitcherService.dispatchJob).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceId: "dynamic-workspace-id",
|
||||
metadata: expect.objectContaining({
|
||||
issueNumber: 99,
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Configuration", () => {
|
||||
it("should throw error if MATRIX_HOMESERVER_URL is not set", async () => {
|
||||
delete process.env.MATRIX_HOMESERVER_URL;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const newService = module.get<MatrixService>(MatrixService);
|
||||
|
||||
await expect(newService.connect()).rejects.toThrow("MATRIX_HOMESERVER_URL is required");
|
||||
|
||||
// Restore for other tests
|
||||
process.env.MATRIX_HOMESERVER_URL = "https://matrix.example.com";
|
||||
});
|
||||
|
||||
it("should throw error if MATRIX_ACCESS_TOKEN is not set", async () => {
|
||||
delete process.env.MATRIX_ACCESS_TOKEN;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const newService = module.get<MatrixService>(MatrixService);
|
||||
|
||||
await expect(newService.connect()).rejects.toThrow("MATRIX_ACCESS_TOKEN is required");
|
||||
|
||||
// Restore for other tests
|
||||
process.env.MATRIX_ACCESS_TOKEN = "test-access-token";
|
||||
});
|
||||
|
||||
it("should throw error if MATRIX_BOT_USER_ID is not set", async () => {
|
||||
delete process.env.MATRIX_BOT_USER_ID;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const newService = module.get<MatrixService>(MatrixService);
|
||||
|
||||
await expect(newService.connect()).rejects.toThrow("MATRIX_BOT_USER_ID is required");
|
||||
|
||||
// Restore for other tests
|
||||
process.env.MATRIX_BOT_USER_ID = "@mosaic-bot:example.com";
|
||||
});
|
||||
|
||||
it("should throw error if MATRIX_WORKSPACE_ID is not set", async () => {
|
||||
delete process.env.MATRIX_WORKSPACE_ID;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const newService = module.get<MatrixService>(MatrixService);
|
||||
|
||||
await expect(newService.connect()).rejects.toThrow("MATRIX_WORKSPACE_ID is required");
|
||||
|
||||
// Restore for other tests
|
||||
process.env.MATRIX_WORKSPACE_ID = "test-workspace-id";
|
||||
});
|
||||
|
||||
it("should use configured workspace ID from environment", async () => {
|
||||
const testWorkspaceId = "configured-workspace-456";
|
||||
process.env.MATRIX_WORKSPACE_ID = testWorkspaceId;
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
MatrixService,
|
||||
CommandParserService,
|
||||
{
|
||||
provide: StitcherService,
|
||||
useValue: mockStitcherService,
|
||||
},
|
||||
{
|
||||
provide: MatrixRoomService,
|
||||
useValue: mockMatrixRoomService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
const newService = module.get<MatrixService>(MatrixService);
|
||||
|
||||
const message: ChatMessage = {
|
||||
id: "msg-1",
|
||||
channelId: "!test-room:example.com",
|
||||
authorId: "@user:example.com",
|
||||
authorName: "@user:example.com",
|
||||
content: "@mosaic fix 42",
|
||||
timestamp: new Date(),
|
||||
};
|
||||
|
||||
await newService.connect();
|
||||
await newService.handleCommand({
|
||||
command: "fix",
|
||||
args: ["42"],
|
||||
message,
|
||||
});
|
||||
|
||||
expect(mockStitcherService.dispatchJob).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceId: testWorkspaceId,
|
||||
})
|
||||
);
|
||||
|
||||
// Restore for other tests
|
||||
process.env.MATRIX_WORKSPACE_ID = "test-workspace-id";
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Logging Security", () => {
|
||||
it("should sanitize sensitive data in error logs", async () => {
|
||||
const loggerErrorSpy = vi.spyOn(
|
||||
(service as Record<string, unknown>)["logger"] as { error: (...args: unknown[]) => void },
|
||||
"error"
|
||||
);
|
||||
|
||||
await service.connect();
|
||||
|
||||
// Trigger room.event handler with null event to exercise error path
|
||||
expect(mockEventCallbacks.length).toBeGreaterThan(0);
|
||||
mockEventCallbacks[0]?.("!room:example.com", null as unknown as Record<string, unknown>);
|
||||
|
||||
// Verify error was logged
|
||||
expect(loggerErrorSpy).toHaveBeenCalled();
|
||||
|
||||
// Get the logged error
|
||||
const loggedArgs = loggerErrorSpy.mock.calls[0];
|
||||
const loggedError = loggedArgs?.[1] as Record<string, unknown>;
|
||||
|
||||
// Verify non-sensitive error info is preserved
|
||||
expect(loggedError).toBeDefined();
|
||||
expect((loggedError as { message: string }).message).toBe("Received null event from Matrix");
|
||||
});
|
||||
|
||||
it("should not include access token in error output", () => {
|
||||
// Verify the access token is stored privately and not exposed
|
||||
const serviceAsRecord = service as unknown as Record<string, unknown>;
|
||||
// The accessToken should exist but should not appear in any public-facing method output
|
||||
expect(serviceAsRecord["accessToken"]).toBe("test-access-token");
|
||||
|
||||
// Verify isConnected does not leak token
|
||||
const connected = service.isConnected();
|
||||
expect(String(connected)).not.toContain("test-access-token");
|
||||
});
|
||||
});
|
||||
|
||||
describe("MatrixRoomService reverse lookup", () => {
|
||||
it("should call getWorkspaceForRoom when processing messages", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue("resolved-workspace");
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
callback?.("!some-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic help",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
expect(matrixRoomService.getWorkspaceForRoom).toHaveBeenCalledWith("!some-room:example.com");
|
||||
});
|
||||
|
||||
it("should fall back to control room workspace when MatrixRoomService returns null", async () => {
|
||||
mockMatrixRoomService.getWorkspaceForRoom.mockResolvedValue(null);
|
||||
|
||||
await service.connect();
|
||||
|
||||
const callback = mockMessageCallbacks[0];
|
||||
// Send to the control room (fallback path)
|
||||
callback?.("!test-room:example.com", {
|
||||
event_id: "$msg-1",
|
||||
sender: "@user:example.com",
|
||||
origin_server_ts: Date.now(),
|
||||
content: {
|
||||
msgtype: "m.text",
|
||||
body: "@mosaic fix #10",
|
||||
},
|
||||
});
|
||||
|
||||
// Wait for async processing
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Should dispatch with the env-configured workspace
|
||||
expect(stitcherService.dispatchJob).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
workspaceId: "test-workspace-id",
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
649
apps/api/src/bridge/matrix/matrix.service.ts
Normal file
649
apps/api/src/bridge/matrix/matrix.service.ts
Normal file
@@ -0,0 +1,649 @@
|
||||
import { Injectable, Logger, Optional, Inject } from "@nestjs/common";
|
||||
import { MatrixClient, SimpleFsStorageProvider, AutojoinRoomsMixin } from "matrix-bot-sdk";
|
||||
import { StitcherService } from "../../stitcher/stitcher.service";
|
||||
import { CommandParserService } from "../parser/command-parser.service";
|
||||
import { CommandAction } from "../parser/command.interface";
|
||||
import type { ParsedCommand } from "../parser/command.interface";
|
||||
import { MatrixRoomService } from "./matrix-room.service";
|
||||
import { sanitizeForLogging } from "../../common/utils";
|
||||
import type {
|
||||
IChatProvider,
|
||||
ChatMessage,
|
||||
ChatCommand,
|
||||
ThreadCreateOptions,
|
||||
ThreadMessageOptions,
|
||||
} from "../interfaces";
|
||||
|
||||
/**
|
||||
* Matrix room message event content
|
||||
*/
|
||||
interface MatrixMessageContent {
|
||||
msgtype: string;
|
||||
body: string;
|
||||
"m.relates_to"?: MatrixRelatesTo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix relationship metadata for threads (MSC3440)
|
||||
*/
|
||||
interface MatrixRelatesTo {
|
||||
rel_type: string;
|
||||
event_id: string;
|
||||
is_falling_back?: boolean;
|
||||
"m.in_reply_to"?: {
|
||||
event_id: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix room event structure
|
||||
*/
|
||||
interface MatrixRoomEvent {
|
||||
event_id: string;
|
||||
sender: string;
|
||||
origin_server_ts: number;
|
||||
content: MatrixMessageContent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix Service - Matrix chat platform integration
|
||||
*
|
||||
* Responsibilities:
|
||||
* - Connect to Matrix via access token
|
||||
* - Listen for commands in mapped rooms (via MatrixRoomService)
|
||||
* - Parse commands using shared CommandParserService
|
||||
* - Forward commands to stitcher
|
||||
* - Receive status updates from herald
|
||||
* - Post updates to threads (MSC3440)
|
||||
*/
|
||||
@Injectable()
|
||||
export class MatrixService implements IChatProvider {
|
||||
private readonly logger = new Logger(MatrixService.name);
|
||||
private client: MatrixClient | null = null;
|
||||
private connected = false;
|
||||
private readonly homeserverUrl: string;
|
||||
private readonly accessToken: string;
|
||||
private readonly botUserId: string;
|
||||
private readonly controlRoomId: string;
|
||||
private readonly workspaceId: string;
|
||||
|
||||
constructor(
|
||||
private readonly stitcherService: StitcherService,
|
||||
@Optional()
|
||||
@Inject(CommandParserService)
|
||||
private readonly commandParser: CommandParserService | null,
|
||||
@Optional()
|
||||
@Inject(MatrixRoomService)
|
||||
private readonly matrixRoomService: MatrixRoomService | null
|
||||
) {
|
||||
this.homeserverUrl = process.env.MATRIX_HOMESERVER_URL ?? "";
|
||||
this.accessToken = process.env.MATRIX_ACCESS_TOKEN ?? "";
|
||||
this.botUserId = process.env.MATRIX_BOT_USER_ID ?? "";
|
||||
this.controlRoomId = process.env.MATRIX_CONTROL_ROOM_ID ?? "";
|
||||
this.workspaceId = process.env.MATRIX_WORKSPACE_ID ?? "";
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to Matrix homeserver
|
||||
*/
|
||||
async connect(): Promise<void> {
|
||||
if (!this.homeserverUrl) {
|
||||
throw new Error("MATRIX_HOMESERVER_URL is required");
|
||||
}
|
||||
|
||||
if (!this.accessToken) {
|
||||
throw new Error("MATRIX_ACCESS_TOKEN is required");
|
||||
}
|
||||
|
||||
if (!this.workspaceId) {
|
||||
throw new Error("MATRIX_WORKSPACE_ID is required");
|
||||
}
|
||||
|
||||
if (!this.botUserId) {
|
||||
throw new Error("MATRIX_BOT_USER_ID is required");
|
||||
}
|
||||
|
||||
this.logger.log("Connecting to Matrix...");
|
||||
|
||||
const storage = new SimpleFsStorageProvider("matrix-bot-storage.json");
|
||||
this.client = new MatrixClient(this.homeserverUrl, this.accessToken, storage);
|
||||
|
||||
// Auto-join rooms when invited
|
||||
AutojoinRoomsMixin.setupOnClient(this.client);
|
||||
|
||||
// Setup event handlers
|
||||
this.setupEventHandlers();
|
||||
|
||||
// Start syncing
|
||||
await this.client.start();
|
||||
this.connected = true;
|
||||
this.logger.log(`Matrix bot connected as ${this.botUserId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Setup event handlers for Matrix client
|
||||
*/
|
||||
private setupEventHandlers(): void {
|
||||
if (!this.client) return;
|
||||
|
||||
this.client.on("room.message", (roomId: string, event: MatrixRoomEvent) => {
|
||||
// Ignore messages from the bot itself
|
||||
if (event.sender === this.botUserId) return;
|
||||
|
||||
// Only handle text messages
|
||||
if (event.content.msgtype !== "m.text") return;
|
||||
|
||||
this.handleRoomMessage(roomId, event).catch((error: unknown) => {
|
||||
this.logger.error(
|
||||
`Error handling room message in ${roomId}:`,
|
||||
error instanceof Error ? error.message : error
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
this.client.on("room.event", (_roomId: string, event: MatrixRoomEvent | null) => {
|
||||
// Handle errors emitted as events
|
||||
if (!event) {
|
||||
const error = new Error("Received null event from Matrix");
|
||||
const sanitizedError = sanitizeForLogging(error);
|
||||
this.logger.error("Matrix client error:", sanitizedError);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle an incoming room message.
|
||||
*
|
||||
* Resolves the workspace for the room (via MatrixRoomService or fallback
|
||||
* to the control room), then delegates to the shared CommandParserService
|
||||
* for platform-agnostic command parsing and dispatches the result.
|
||||
*/
|
||||
private async handleRoomMessage(roomId: string, event: MatrixRoomEvent): Promise<void> {
|
||||
// Resolve workspace: try MatrixRoomService first, fall back to control room
|
||||
let resolvedWorkspaceId: string | null = null;
|
||||
|
||||
if (this.matrixRoomService) {
|
||||
resolvedWorkspaceId = await this.matrixRoomService.getWorkspaceForRoom(roomId);
|
||||
}
|
||||
|
||||
// Fallback: if the room is the configured control room, use the env workspace
|
||||
if (!resolvedWorkspaceId && roomId === this.controlRoomId) {
|
||||
resolvedWorkspaceId = this.workspaceId;
|
||||
}
|
||||
|
||||
// If room is not mapped to any workspace, ignore the message
|
||||
if (!resolvedWorkspaceId) {
|
||||
return;
|
||||
}
|
||||
|
||||
const messageContent = event.content.body;
|
||||
|
||||
// Build ChatMessage for interface compatibility
|
||||
const chatMessage: ChatMessage = {
|
||||
id: event.event_id,
|
||||
channelId: roomId,
|
||||
authorId: event.sender,
|
||||
authorName: event.sender,
|
||||
content: messageContent,
|
||||
timestamp: new Date(event.origin_server_ts),
|
||||
...(event.content["m.relates_to"]?.rel_type === "m.thread" && {
|
||||
threadId: event.content["m.relates_to"].event_id,
|
||||
}),
|
||||
};
|
||||
|
||||
// Use shared CommandParserService if available
|
||||
if (this.commandParser) {
|
||||
// Normalize !mosaic to @mosaic for the shared parser
|
||||
const normalizedContent = messageContent.replace(/^!mosaic/i, "@mosaic");
|
||||
|
||||
const result = this.commandParser.parseCommand(normalizedContent);
|
||||
|
||||
if (result.success) {
|
||||
await this.handleParsedCommand(result.command, chatMessage, resolvedWorkspaceId);
|
||||
} else if (normalizedContent.toLowerCase().startsWith("@mosaic")) {
|
||||
// The user tried to use a command but it failed to parse -- send help
|
||||
await this.sendMessage(roomId, result.error.help ?? result.error.message);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback: use the built-in parseCommand if CommandParserService not injected
|
||||
const command = this.parseCommand(chatMessage);
|
||||
if (command) {
|
||||
await this.handleCommand(command);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a command parsed by the shared CommandParserService.
|
||||
*
|
||||
* Routes the ParsedCommand to the appropriate handler, passing
|
||||
* along workspace context for job dispatch.
|
||||
*/
|
||||
private async handleParsedCommand(
|
||||
parsed: ParsedCommand,
|
||||
message: ChatMessage,
|
||||
workspaceId: string
|
||||
): Promise<void> {
|
||||
this.logger.log(
|
||||
`Handling command: ${parsed.action} from ${message.authorName} in workspace ${workspaceId}`
|
||||
);
|
||||
|
||||
switch (parsed.action) {
|
||||
case CommandAction.FIX:
|
||||
await this.handleFixCommand(parsed.rawArgs, message, workspaceId);
|
||||
break;
|
||||
case CommandAction.STATUS:
|
||||
await this.handleStatusCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
case CommandAction.CANCEL:
|
||||
await this.handleCancelCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
case CommandAction.VERBOSE:
|
||||
await this.handleVerboseCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
case CommandAction.QUIET:
|
||||
await this.handleQuietCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
case CommandAction.HELP:
|
||||
await this.handleHelpCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
case CommandAction.RETRY:
|
||||
await this.handleRetryCommand(parsed.rawArgs, message);
|
||||
break;
|
||||
default:
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
`Unknown command. Type \`@mosaic help\` or \`!mosaic help\` for available commands.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from Matrix
|
||||
*/
|
||||
disconnect(): Promise<void> {
|
||||
this.logger.log("Disconnecting from Matrix...");
|
||||
this.connected = false;
|
||||
if (this.client) {
|
||||
this.client.stop();
|
||||
}
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the provider is connected
|
||||
*/
|
||||
isConnected(): boolean {
|
||||
return this.connected;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the underlying MatrixClient instance.
|
||||
*
|
||||
* Used by MatrixStreamingService for low-level operations
|
||||
* (message edits, typing indicators) that require direct client access.
|
||||
*
|
||||
* @returns The MatrixClient instance, or null if not connected
|
||||
*/
|
||||
getClient(): MatrixClient | null {
|
||||
return this.client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to a room
|
||||
*/
|
||||
async sendMessage(roomId: string, content: string): Promise<void> {
|
||||
if (!this.client) {
|
||||
throw new Error("Matrix client is not connected");
|
||||
}
|
||||
|
||||
const messageContent: MatrixMessageContent = {
|
||||
msgtype: "m.text",
|
||||
body: content,
|
||||
};
|
||||
|
||||
await this.client.sendMessage(roomId, messageContent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a thread for job updates (MSC3440)
|
||||
*
|
||||
* Matrix threads are created by sending an initial message
|
||||
* and then replying with m.thread relation. The initial
|
||||
* message event ID becomes the thread root.
|
||||
*/
|
||||
async createThread(options: ThreadCreateOptions): Promise<string> {
|
||||
if (!this.client) {
|
||||
throw new Error("Matrix client is not connected");
|
||||
}
|
||||
|
||||
const { channelId, name, message } = options;
|
||||
|
||||
// Send the initial message that becomes the thread root
|
||||
const initialContent: MatrixMessageContent = {
|
||||
msgtype: "m.text",
|
||||
body: `[${name}] ${message}`,
|
||||
};
|
||||
|
||||
const eventId = await this.client.sendMessage(channelId, initialContent);
|
||||
|
||||
return eventId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to a thread (MSC3440)
|
||||
*
|
||||
* Uses m.thread relation to associate the message with the thread root event.
|
||||
*/
|
||||
async sendThreadMessage(options: ThreadMessageOptions): Promise<void> {
|
||||
if (!this.client) {
|
||||
throw new Error("Matrix client is not connected");
|
||||
}
|
||||
|
||||
const { threadId, channelId, content } = options;
|
||||
|
||||
// Use the channelId from options (threads are room-scoped), fall back to control room
|
||||
const roomId = channelId || this.controlRoomId;
|
||||
|
||||
const threadContent: MatrixMessageContent = {
|
||||
msgtype: "m.text",
|
||||
body: content,
|
||||
"m.relates_to": {
|
||||
rel_type: "m.thread",
|
||||
event_id: threadId,
|
||||
is_falling_back: true,
|
||||
"m.in_reply_to": {
|
||||
event_id: threadId,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
await this.client.sendMessage(roomId, threadContent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a command from a message (IChatProvider interface).
|
||||
*
|
||||
* Delegates to the shared CommandParserService when available,
|
||||
* falling back to built-in parsing for backwards compatibility.
|
||||
*/
|
||||
parseCommand(message: ChatMessage): ChatCommand | null {
|
||||
const { content } = message;
|
||||
|
||||
// Try shared parser first
|
||||
if (this.commandParser) {
|
||||
const normalizedContent = content.replace(/^!mosaic/i, "@mosaic");
|
||||
const result = this.commandParser.parseCommand(normalizedContent);
|
||||
|
||||
if (result.success) {
|
||||
return {
|
||||
command: result.command.action,
|
||||
args: result.command.rawArgs,
|
||||
message,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// Fallback: built-in parsing for when CommandParserService is not injected
|
||||
const lowerContent = content.toLowerCase();
|
||||
if (!lowerContent.includes("@mosaic") && !lowerContent.includes("!mosaic")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const parts = content.trim().split(/\s+/);
|
||||
const mosaicIndex = parts.findIndex(
|
||||
(part) => part.toLowerCase().includes("@mosaic") || part.toLowerCase().includes("!mosaic")
|
||||
);
|
||||
|
||||
if (mosaicIndex === -1 || mosaicIndex === parts.length - 1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const commandPart = parts[mosaicIndex + 1];
|
||||
if (!commandPart) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const command = commandPart.toLowerCase();
|
||||
const args = parts.slice(mosaicIndex + 2);
|
||||
|
||||
const validCommands = ["fix", "status", "cancel", "verbose", "quiet", "help"];
|
||||
|
||||
if (!validCommands.includes(command)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
command,
|
||||
args,
|
||||
message,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a parsed command (ChatCommand format, used by fallback path)
|
||||
*/
|
||||
async handleCommand(command: ChatCommand): Promise<void> {
|
||||
const { command: cmd, args, message } = command;
|
||||
|
||||
this.logger.log(
|
||||
`Handling command: ${cmd} with args: ${args.join(", ")} from ${message.authorName}`
|
||||
);
|
||||
|
||||
switch (cmd) {
|
||||
case "fix":
|
||||
await this.handleFixCommand(args, message, this.workspaceId);
|
||||
break;
|
||||
case "status":
|
||||
await this.handleStatusCommand(args, message);
|
||||
break;
|
||||
case "cancel":
|
||||
await this.handleCancelCommand(args, message);
|
||||
break;
|
||||
case "verbose":
|
||||
await this.handleVerboseCommand(args, message);
|
||||
break;
|
||||
case "quiet":
|
||||
await this.handleQuietCommand(args, message);
|
||||
break;
|
||||
case "help":
|
||||
await this.handleHelpCommand(args, message);
|
||||
break;
|
||||
default:
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
`Unknown command: ${cmd}. Type \`@mosaic help\` or \`!mosaic help\` for available commands.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle fix command - Start a job for an issue
|
||||
*/
|
||||
private async handleFixCommand(
|
||||
args: string[],
|
||||
message: ChatMessage,
|
||||
workspaceId?: string
|
||||
): Promise<void> {
|
||||
if (args.length === 0 || !args[0]) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Usage: `@mosaic fix <issue-number>` or `!mosaic fix <issue-number>`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse issue number: handle both "#42" and "42" formats
|
||||
const issueArg = args[0].replace(/^#/, "");
|
||||
const issueNumber = parseInt(issueArg, 10);
|
||||
|
||||
if (isNaN(issueNumber)) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Invalid issue number. Please provide a numeric issue number."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const targetWorkspaceId = workspaceId ?? this.workspaceId;
|
||||
|
||||
// Create thread for job updates
|
||||
const threadId = await this.createThread({
|
||||
channelId: message.channelId,
|
||||
name: `Job #${String(issueNumber)}`,
|
||||
message: `Starting job for issue #${String(issueNumber)}...`,
|
||||
});
|
||||
|
||||
// Dispatch job to stitcher
|
||||
try {
|
||||
const result = await this.stitcherService.dispatchJob({
|
||||
workspaceId: targetWorkspaceId,
|
||||
type: "code-task",
|
||||
priority: 10,
|
||||
metadata: {
|
||||
issueNumber,
|
||||
command: "fix",
|
||||
channelId: message.channelId,
|
||||
threadId: threadId,
|
||||
authorId: message.authorId,
|
||||
authorName: message.authorName,
|
||||
},
|
||||
});
|
||||
|
||||
// Send confirmation to thread
|
||||
await this.sendThreadMessage({
|
||||
threadId,
|
||||
channelId: message.channelId,
|
||||
content: `Job created: ${result.jobId}\nStatus: ${result.status}\nQueue: ${result.queueName}`,
|
||||
});
|
||||
} catch (error: unknown) {
|
||||
const errorMessage = error instanceof Error ? error.message : "Unknown error";
|
||||
this.logger.error(
|
||||
`Failed to dispatch job for issue #${String(issueNumber)}: ${errorMessage}`
|
||||
);
|
||||
await this.sendThreadMessage({
|
||||
threadId,
|
||||
channelId: message.channelId,
|
||||
content: `Failed to start job: ${errorMessage}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle status command - Get job status
|
||||
*/
|
||||
private async handleStatusCommand(args: string[], message: ChatMessage): Promise<void> {
|
||||
if (args.length === 0 || !args[0]) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Usage: `@mosaic status <job-id>` or `!mosaic status <job-id>`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const jobId = args[0];
|
||||
|
||||
// TODO: Implement job status retrieval from stitcher
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
`Status command not yet implemented for job: ${jobId}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle cancel command - Cancel a running job
|
||||
*/
|
||||
private async handleCancelCommand(args: string[], message: ChatMessage): Promise<void> {
|
||||
if (args.length === 0 || !args[0]) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Usage: `@mosaic cancel <job-id>` or `!mosaic cancel <job-id>`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const jobId = args[0];
|
||||
|
||||
// TODO: Implement job cancellation in stitcher
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
`Cancel command not yet implemented for job: ${jobId}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle retry command - Retry a failed job
|
||||
*/
|
||||
private async handleRetryCommand(args: string[], message: ChatMessage): Promise<void> {
|
||||
if (args.length === 0 || !args[0]) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Usage: `@mosaic retry <job-id>` or `!mosaic retry <job-id>`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const jobId = args[0];
|
||||
|
||||
// TODO: Implement job retry in stitcher
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
`Retry command not yet implemented for job: ${jobId}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle verbose command - Stream full logs to thread
|
||||
*/
|
||||
private async handleVerboseCommand(args: string[], message: ChatMessage): Promise<void> {
|
||||
if (args.length === 0 || !args[0]) {
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Usage: `@mosaic verbose <job-id>` or `!mosaic verbose <job-id>`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const jobId = args[0];
|
||||
|
||||
// TODO: Implement verbose logging
|
||||
await this.sendMessage(message.channelId, `Verbose mode not yet implemented for job: ${jobId}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle quiet command - Reduce notifications
|
||||
*/
|
||||
private async handleQuietCommand(_args: string[], message: ChatMessage): Promise<void> {
|
||||
// TODO: Implement quiet mode
|
||||
await this.sendMessage(
|
||||
message.channelId,
|
||||
"Quiet mode not yet implemented. Currently showing milestone updates only."
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle help command - Show available commands
|
||||
*/
|
||||
private async handleHelpCommand(_args: string[], message: ChatMessage): Promise<void> {
|
||||
const helpMessage = `
|
||||
**Available commands:**
|
||||
|
||||
\`@mosaic fix <issue>\` or \`!mosaic fix <issue>\` - Start job for issue
|
||||
\`@mosaic status <job>\` or \`!mosaic status <job>\` - Get job status
|
||||
\`@mosaic cancel <job>\` or \`!mosaic cancel <job>\` - Cancel running job
|
||||
\`@mosaic retry <job>\` or \`!mosaic retry <job>\` - Retry failed job
|
||||
\`@mosaic verbose <job>\` or \`!mosaic verbose <job>\` - Stream full logs to thread
|
||||
\`@mosaic quiet\` or \`!mosaic quiet\` - Reduce notifications
|
||||
\`@mosaic help\` or \`!mosaic help\` - Show this help message
|
||||
|
||||
**Noise Management:**
|
||||
- Main room: Low verbosity (milestones only)
|
||||
- Job threads: Medium verbosity (step completions)
|
||||
- DMs: Configurable per user
|
||||
`.trim();
|
||||
|
||||
await this.sendMessage(message.channelId, helpMessage);
|
||||
}
|
||||
}
|
||||
177
apps/api/src/common/controllers/csrf.controller.spec.ts
Normal file
177
apps/api/src/common/controllers/csrf.controller.spec.ts
Normal file
@@ -0,0 +1,177 @@
|
||||
/**
|
||||
* CSRF Controller Tests
|
||||
*
|
||||
* Tests CSRF token generation endpoint with session binding.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { Request, Response } from "express";
|
||||
import { CsrfController } from "./csrf.controller";
|
||||
import { CsrfService } from "../services/csrf.service";
|
||||
import type { AuthenticatedUser } from "../types/user.types";
|
||||
|
||||
interface AuthenticatedRequest extends Request {
|
||||
user?: AuthenticatedUser;
|
||||
}
|
||||
|
||||
describe("CsrfController", () => {
|
||||
let controller: CsrfController;
|
||||
let csrfService: CsrfService;
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env = { ...originalEnv };
|
||||
process.env.CSRF_SECRET = "test-secret-0123456789abcdef0123456789abcdef";
|
||||
csrfService = new CsrfService();
|
||||
csrfService.onModuleInit();
|
||||
controller = new CsrfController(csrfService);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
const createMockRequest = (userId?: string): AuthenticatedRequest => {
|
||||
return {
|
||||
user: userId ? { id: userId, email: "test@example.com", name: "Test User" } : undefined,
|
||||
} as AuthenticatedRequest;
|
||||
};
|
||||
|
||||
const createMockResponse = (): Response => {
|
||||
return {
|
||||
cookie: vi.fn(),
|
||||
} as unknown as Response;
|
||||
};
|
||||
|
||||
describe("getCsrfToken", () => {
|
||||
it("should generate and return a CSRF token with session binding", () => {
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
const result = controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(result).toHaveProperty("token");
|
||||
expect(typeof result.token).toBe("string");
|
||||
// Token format: random:hmac (64 hex chars : 64 hex chars)
|
||||
expect(result.token).toContain(":");
|
||||
const parts = result.token.split(":");
|
||||
expect(parts[0]).toHaveLength(64);
|
||||
expect(parts[1]).toHaveLength(64);
|
||||
});
|
||||
|
||||
it("should set CSRF token in httpOnly cookie", () => {
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
const result = controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(mockResponse.cookie).toHaveBeenCalledWith(
|
||||
"csrf-token",
|
||||
result.token,
|
||||
expect.objectContaining({
|
||||
httpOnly: true,
|
||||
sameSite: "strict",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should set secure flag in production", () => {
|
||||
process.env.NODE_ENV = "production";
|
||||
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(mockResponse.cookie).toHaveBeenCalledWith(
|
||||
"csrf-token",
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
secure: true,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should not set secure flag in development", () => {
|
||||
process.env.NODE_ENV = "development";
|
||||
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(mockResponse.cookie).toHaveBeenCalledWith(
|
||||
"csrf-token",
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
secure: false,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should generate unique tokens on each call", () => {
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
const result1 = controller.getCsrfToken(mockRequest, mockResponse);
|
||||
const result2 = controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(result1.token).not.toBe(result2.token);
|
||||
});
|
||||
|
||||
it("should set cookie with 24 hour expiry", () => {
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
expect(mockResponse.cookie).toHaveBeenCalledWith(
|
||||
"csrf-token",
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
maxAge: 24 * 60 * 60 * 1000, // 24 hours
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should throw error when user is not authenticated", () => {
|
||||
const mockRequest = createMockRequest(); // No user ID
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
expect(() => controller.getCsrfToken(mockRequest, mockResponse)).toThrow(
|
||||
"User ID not available after authentication"
|
||||
);
|
||||
});
|
||||
|
||||
it("should generate token bound to specific user session", () => {
|
||||
const mockRequest = createMockRequest("user-123");
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
const result = controller.getCsrfToken(mockRequest, mockResponse);
|
||||
|
||||
// Token should be valid for user-123
|
||||
expect(csrfService.validateToken(result.token, "user-123")).toBe(true);
|
||||
|
||||
// Token should be invalid for different user
|
||||
expect(csrfService.validateToken(result.token, "user-456")).toBe(false);
|
||||
});
|
||||
|
||||
it("should generate different tokens for different users", () => {
|
||||
const mockResponse = createMockResponse();
|
||||
|
||||
const request1 = createMockRequest("user-A");
|
||||
const request2 = createMockRequest("user-B");
|
||||
|
||||
const result1 = controller.getCsrfToken(request1, mockResponse);
|
||||
const result2 = controller.getCsrfToken(request2, mockResponse);
|
||||
|
||||
expect(result1.token).not.toBe(result2.token);
|
||||
|
||||
// Each token only valid for its user
|
||||
expect(csrfService.validateToken(result1.token, "user-A")).toBe(true);
|
||||
expect(csrfService.validateToken(result1.token, "user-B")).toBe(false);
|
||||
expect(csrfService.validateToken(result2.token, "user-B")).toBe(true);
|
||||
expect(csrfService.validateToken(result2.token, "user-A")).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
57
apps/api/src/common/controllers/csrf.controller.ts
Normal file
57
apps/api/src/common/controllers/csrf.controller.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
/**
|
||||
* CSRF Controller
|
||||
*
|
||||
* Provides CSRF token generation endpoint for client applications.
|
||||
* Tokens are cryptographically bound to the user session via HMAC.
|
||||
*/
|
||||
|
||||
import { Controller, Get, Res, Req, UseGuards } from "@nestjs/common";
|
||||
import { Response, Request } from "express";
|
||||
import { SkipCsrf } from "../decorators/skip-csrf.decorator";
|
||||
import { CsrfService } from "../services/csrf.service";
|
||||
import { AuthGuard } from "../../auth/guards/auth.guard";
|
||||
import type { AuthenticatedUser } from "../types/user.types";
|
||||
|
||||
interface AuthenticatedRequest extends Request {
|
||||
user?: AuthenticatedUser;
|
||||
}
|
||||
|
||||
@Controller("api/v1/csrf")
|
||||
export class CsrfController {
|
||||
constructor(private readonly csrfService: CsrfService) {}
|
||||
|
||||
/**
|
||||
* Generate and set CSRF token bound to user session
|
||||
* Requires authentication to bind token to session
|
||||
* Returns token to client and sets it in httpOnly cookie
|
||||
*/
|
||||
@Get("token")
|
||||
@UseGuards(AuthGuard)
|
||||
@SkipCsrf() // This endpoint itself doesn't need CSRF protection
|
||||
getCsrfToken(
|
||||
@Req() request: AuthenticatedRequest,
|
||||
@Res({ passthrough: true }) response: Response
|
||||
): { token: string } {
|
||||
// Get user ID from authenticated request
|
||||
const userId = request.user?.id;
|
||||
|
||||
if (!userId) {
|
||||
// This should not happen if AuthGuard is working correctly
|
||||
throw new Error("User ID not available after authentication");
|
||||
}
|
||||
|
||||
// Generate session-bound CSRF token
|
||||
const token = this.csrfService.generateToken(userId);
|
||||
|
||||
// Set token in httpOnly cookie
|
||||
response.cookie("csrf-token", token, {
|
||||
httpOnly: true,
|
||||
secure: process.env.NODE_ENV === "production",
|
||||
sameSite: "strict",
|
||||
maxAge: 24 * 60 * 60 * 1000, // 24 hours
|
||||
});
|
||||
|
||||
// Return token to client (so it can include in X-CSRF-Token header)
|
||||
return { token };
|
||||
}
|
||||
}
|
||||
52
apps/api/src/common/decorators/sanitize.decorator.ts
Normal file
52
apps/api/src/common/decorators/sanitize.decorator.ts
Normal file
@@ -0,0 +1,52 @@
|
||||
/**
|
||||
* Sanitize Decorator
|
||||
*
|
||||
* Custom class-validator decorator to sanitize string input and prevent XSS.
|
||||
*/
|
||||
|
||||
import { Transform } from "class-transformer";
|
||||
import { sanitizeString, sanitizeObject } from "../utils/sanitize.util";
|
||||
|
||||
/**
|
||||
* Sanitize decorator for DTO properties
|
||||
* Automatically sanitizes string values to prevent XSS attacks
|
||||
*
|
||||
* Usage:
|
||||
* ```typescript
|
||||
* class MyDto {
|
||||
* @Sanitize()
|
||||
* @IsString()
|
||||
* userInput!: string;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function Sanitize(): PropertyDecorator {
|
||||
return Transform(({ value }: { value: unknown }) => {
|
||||
if (typeof value === "string") {
|
||||
return sanitizeString(value);
|
||||
}
|
||||
return value;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* SanitizeObject decorator for nested objects
|
||||
* Recursively sanitizes all string values in an object
|
||||
*
|
||||
* Usage:
|
||||
* ```typescript
|
||||
* class MyDto {
|
||||
* @SanitizeObject()
|
||||
* @IsObject()
|
||||
* metadata?: Record<string, unknown>;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export function SanitizeObject(): PropertyDecorator {
|
||||
return Transform(({ value }: { value: unknown }) => {
|
||||
if (typeof value === "object" && value !== null) {
|
||||
return sanitizeObject(value as Record<string, unknown>);
|
||||
}
|
||||
return value;
|
||||
});
|
||||
}
|
||||
20
apps/api/src/common/decorators/skip-csrf.decorator.ts
Normal file
20
apps/api/src/common/decorators/skip-csrf.decorator.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
/**
|
||||
* Skip CSRF Decorator
|
||||
*
|
||||
* Marks an endpoint to skip CSRF protection.
|
||||
* Use for endpoints that have alternative authentication (e.g., signature verification).
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* @Post('incoming/connect')
|
||||
* @SkipCsrf()
|
||||
* async handleIncomingConnection() {
|
||||
* // Signature-based authentication, no CSRF needed
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
|
||||
import { SetMetadata } from "@nestjs/common";
|
||||
import { SKIP_CSRF_KEY } from "../guards/csrf.guard";
|
||||
|
||||
export const SkipCsrf = () => SetMetadata(SKIP_CSRF_KEY, true);
|
||||
@@ -113,34 +113,24 @@ describe("ApiKeyGuard", () => {
|
||||
const validApiKey = "test-api-key-12345";
|
||||
vi.mocked(mockConfigService.get).mockReturnValue(validApiKey);
|
||||
|
||||
const startTime = Date.now();
|
||||
const context1 = createMockExecutionContext({
|
||||
"x-api-key": "wrong-key-short",
|
||||
// Verify that same-length keys are compared properly (exercises timingSafeEqual path)
|
||||
// and different-length keys are rejected before comparison
|
||||
const sameLength = createMockExecutionContext({
|
||||
"x-api-key": "test-api-key-12344", // Same length, one char different
|
||||
});
|
||||
const differentLength = createMockExecutionContext({
|
||||
"x-api-key": "short", // Different length
|
||||
});
|
||||
|
||||
try {
|
||||
guard.canActivate(context1);
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
const shortKeyTime = Date.now() - startTime;
|
||||
// Both should throw, proving the comparison logic handles both cases
|
||||
expect(() => guard.canActivate(sameLength)).toThrow("Invalid API key");
|
||||
expect(() => guard.canActivate(differentLength)).toThrow("Invalid API key");
|
||||
|
||||
const startTime2 = Date.now();
|
||||
const context2 = createMockExecutionContext({
|
||||
"x-api-key": "test-api-key-12344", // Very close to correct key
|
||||
// Correct key should pass
|
||||
const correct = createMockExecutionContext({
|
||||
"x-api-key": validApiKey,
|
||||
});
|
||||
|
||||
try {
|
||||
guard.canActivate(context2);
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
const longKeyTime = Date.now() - startTime2;
|
||||
|
||||
// Times should be similar (within 10ms) to prevent timing attacks
|
||||
// Note: This is a simplified test; real timing attack prevention
|
||||
// is handled by crypto.timingSafeEqual
|
||||
expect(Math.abs(shortKeyTime - longKeyTime)).toBeLessThan(10);
|
||||
expect(guard.canActivate(correct)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
320
apps/api/src/common/guards/csrf.guard.spec.ts
Normal file
320
apps/api/src/common/guards/csrf.guard.spec.ts
Normal file
@@ -0,0 +1,320 @@
|
||||
/**
|
||||
* CSRF Guard Tests
|
||||
*
|
||||
* Tests CSRF protection using double-submit cookie pattern with session binding.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { ExecutionContext, ForbiddenException } from "@nestjs/common";
|
||||
import { Reflector } from "@nestjs/core";
|
||||
import { CsrfGuard } from "./csrf.guard";
|
||||
import { CsrfService } from "../services/csrf.service";
|
||||
|
||||
describe("CsrfGuard", () => {
|
||||
let guard: CsrfGuard;
|
||||
let reflector: Reflector;
|
||||
let csrfService: CsrfService;
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env = { ...originalEnv };
|
||||
process.env.CSRF_SECRET = "test-secret-0123456789abcdef0123456789abcdef";
|
||||
reflector = new Reflector();
|
||||
csrfService = new CsrfService();
|
||||
csrfService.onModuleInit();
|
||||
guard = new CsrfGuard(reflector, csrfService);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
const createContext = (
|
||||
method: string,
|
||||
cookies: Record<string, string> = {},
|
||||
headers: Record<string, string> = {},
|
||||
skipCsrf = false,
|
||||
userId?: string
|
||||
): ExecutionContext => {
|
||||
const request = {
|
||||
method,
|
||||
cookies,
|
||||
headers,
|
||||
path: "/api/test",
|
||||
user: userId ? { id: userId, email: "test@example.com", name: "Test" } : undefined,
|
||||
};
|
||||
|
||||
return {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => request,
|
||||
}),
|
||||
getHandler: () => ({}),
|
||||
getClass: () => ({}),
|
||||
getAllAndOverride: vi.fn().mockReturnValue(skipCsrf),
|
||||
} as unknown as ExecutionContext;
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper to generate a valid session-bound token
|
||||
*/
|
||||
const generateValidToken = (userId: string): string => {
|
||||
return csrfService.generateToken(userId);
|
||||
};
|
||||
|
||||
describe("Safe HTTP methods", () => {
|
||||
it("should allow GET requests without CSRF token", () => {
|
||||
const context = createContext("GET");
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
|
||||
it("should allow HEAD requests without CSRF token", () => {
|
||||
const context = createContext("HEAD");
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
|
||||
it("should allow OPTIONS requests without CSRF token", () => {
|
||||
const context = createContext("OPTIONS");
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Endpoints marked to skip CSRF", () => {
|
||||
it("should allow POST requests when @SkipCsrf() is applied", () => {
|
||||
vi.spyOn(reflector, "getAllAndOverride").mockReturnValue(true);
|
||||
const context = createContext("POST");
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("State-changing methods requiring CSRF", () => {
|
||||
it("should reject POST without CSRF token", () => {
|
||||
const context = createContext("POST", {}, {}, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token missing");
|
||||
});
|
||||
|
||||
it("should reject PUT without CSRF token", () => {
|
||||
const context = createContext("PUT", {}, {}, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
});
|
||||
|
||||
it("should reject PATCH without CSRF token", () => {
|
||||
const context = createContext("PATCH", {}, {}, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
});
|
||||
|
||||
it("should reject DELETE without CSRF token", () => {
|
||||
const context = createContext("DELETE", {}, {}, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
});
|
||||
|
||||
it("should reject when only cookie token is present", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext("POST", { "csrf-token": token }, {}, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token missing");
|
||||
});
|
||||
|
||||
it("should reject when only header token is present", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext("POST", {}, { "x-csrf-token": token }, false, "user-123");
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token missing");
|
||||
});
|
||||
|
||||
it("should reject when tokens do not match", () => {
|
||||
const token1 = generateValidToken("user-123");
|
||||
const token2 = generateValidToken("user-123");
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": token1 },
|
||||
{ "x-csrf-token": token2 },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token mismatch");
|
||||
});
|
||||
|
||||
it("should allow when tokens match and session is valid", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": token },
|
||||
{ "x-csrf-token": token },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
|
||||
it("should allow PATCH when tokens match and session is valid", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext(
|
||||
"PATCH",
|
||||
{ "csrf-token": token },
|
||||
{ "x-csrf-token": token },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
|
||||
it("should allow DELETE when tokens match and session is valid", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext(
|
||||
"DELETE",
|
||||
{ "csrf-token": token },
|
||||
{ "x-csrf-token": token },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
expect(guard.canActivate(context)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Session binding validation", () => {
|
||||
it("should reject when user is not authenticated", () => {
|
||||
const token = generateValidToken("user-123");
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": token },
|
||||
{ "x-csrf-token": token },
|
||||
false
|
||||
// No userId - unauthenticated
|
||||
);
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF validation requires authentication");
|
||||
});
|
||||
|
||||
it("should reject token from different session", () => {
|
||||
// Token generated for user-A
|
||||
const tokenForUserA = generateValidToken("user-A");
|
||||
|
||||
// But request is from user-B
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenForUserA },
|
||||
{ "x-csrf-token": tokenForUserA },
|
||||
false,
|
||||
"user-B" // Different user
|
||||
);
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token not bound to session");
|
||||
});
|
||||
|
||||
it("should reject token with invalid HMAC", () => {
|
||||
// Create a token with tampered HMAC
|
||||
const validToken = generateValidToken("user-123");
|
||||
const parts = validToken.split(":");
|
||||
const tamperedToken = `${parts[0]}:0000000000000000000000000000000000000000000000000000000000000000`;
|
||||
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tamperedToken },
|
||||
{ "x-csrf-token": tamperedToken },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token not bound to session");
|
||||
});
|
||||
|
||||
it("should reject token with invalid format", () => {
|
||||
const invalidToken = "not-a-valid-token";
|
||||
|
||||
const context = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": invalidToken },
|
||||
{ "x-csrf-token": invalidToken },
|
||||
false,
|
||||
"user-123"
|
||||
);
|
||||
|
||||
expect(() => guard.canActivate(context)).toThrow(ForbiddenException);
|
||||
expect(() => guard.canActivate(context)).toThrow("CSRF token not bound to session");
|
||||
});
|
||||
|
||||
it("should not allow token reuse across sessions", () => {
|
||||
// Generate token for user-A
|
||||
const tokenA = generateValidToken("user-A");
|
||||
|
||||
// Valid for user-A
|
||||
const contextA = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenA },
|
||||
{ "x-csrf-token": tokenA },
|
||||
false,
|
||||
"user-A"
|
||||
);
|
||||
expect(guard.canActivate(contextA)).toBe(true);
|
||||
|
||||
// Invalid for user-B
|
||||
const contextB = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenA },
|
||||
{ "x-csrf-token": tokenA },
|
||||
false,
|
||||
"user-B"
|
||||
);
|
||||
expect(() => guard.canActivate(contextB)).toThrow("CSRF token not bound to session");
|
||||
|
||||
// Invalid for user-C
|
||||
const contextC = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenA },
|
||||
{ "x-csrf-token": tokenA },
|
||||
false,
|
||||
"user-C"
|
||||
);
|
||||
expect(() => guard.canActivate(contextC)).toThrow("CSRF token not bound to session");
|
||||
});
|
||||
|
||||
it("should allow each user to use only their own token", () => {
|
||||
const tokenA = generateValidToken("user-A");
|
||||
const tokenB = generateValidToken("user-B");
|
||||
|
||||
// User A with token A - valid
|
||||
const contextAA = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenA },
|
||||
{ "x-csrf-token": tokenA },
|
||||
false,
|
||||
"user-A"
|
||||
);
|
||||
expect(guard.canActivate(contextAA)).toBe(true);
|
||||
|
||||
// User B with token B - valid
|
||||
const contextBB = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenB },
|
||||
{ "x-csrf-token": tokenB },
|
||||
false,
|
||||
"user-B"
|
||||
);
|
||||
expect(guard.canActivate(contextBB)).toBe(true);
|
||||
|
||||
// User A with token B - invalid (cross-session)
|
||||
const contextAB = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenB },
|
||||
{ "x-csrf-token": tokenB },
|
||||
false,
|
||||
"user-A"
|
||||
);
|
||||
expect(() => guard.canActivate(contextAB)).toThrow("CSRF token not bound to session");
|
||||
|
||||
// User B with token A - invalid (cross-session)
|
||||
const contextBA = createContext(
|
||||
"POST",
|
||||
{ "csrf-token": tokenA },
|
||||
{ "x-csrf-token": tokenA },
|
||||
false,
|
||||
"user-B"
|
||||
);
|
||||
expect(() => guard.canActivate(contextBA)).toThrow("CSRF token not bound to session");
|
||||
});
|
||||
});
|
||||
});
|
||||
120
apps/api/src/common/guards/csrf.guard.ts
Normal file
120
apps/api/src/common/guards/csrf.guard.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
/**
|
||||
* CSRF Guard
|
||||
*
|
||||
* Implements CSRF protection using double-submit cookie pattern with session binding.
|
||||
* Validates that:
|
||||
* 1. CSRF token in cookie matches token in header
|
||||
* 2. Token HMAC is valid for the current user session
|
||||
*
|
||||
* Usage:
|
||||
* - Apply to controllers handling state-changing operations
|
||||
* - Use @SkipCsrf() decorator to exempt specific endpoints
|
||||
* - Safe methods (GET, HEAD, OPTIONS) are automatically exempted
|
||||
*/
|
||||
|
||||
import {
|
||||
Injectable,
|
||||
CanActivate,
|
||||
ExecutionContext,
|
||||
ForbiddenException,
|
||||
Logger,
|
||||
} from "@nestjs/common";
|
||||
import { Reflector } from "@nestjs/core";
|
||||
import { Request } from "express";
|
||||
import { CsrfService } from "../services/csrf.service";
|
||||
import type { AuthenticatedUser } from "../types/user.types";
|
||||
|
||||
export const SKIP_CSRF_KEY = "skipCsrf";
|
||||
|
||||
interface RequestWithUser extends Request {
|
||||
user?: AuthenticatedUser;
|
||||
}
|
||||
|
||||
@Injectable()
|
||||
export class CsrfGuard implements CanActivate {
|
||||
private readonly logger = new Logger(CsrfGuard.name);
|
||||
|
||||
constructor(
|
||||
private reflector: Reflector,
|
||||
private csrfService: CsrfService
|
||||
) {}
|
||||
|
||||
canActivate(context: ExecutionContext): boolean {
|
||||
// Check if endpoint is marked to skip CSRF
|
||||
const skipCsrf = this.reflector.getAllAndOverride<boolean>(SKIP_CSRF_KEY, [
|
||||
context.getHandler(),
|
||||
context.getClass(),
|
||||
]);
|
||||
|
||||
if (skipCsrf) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const request = context.switchToHttp().getRequest<RequestWithUser>();
|
||||
|
||||
// Exempt safe HTTP methods (GET, HEAD, OPTIONS)
|
||||
if (["GET", "HEAD", "OPTIONS"].includes(request.method)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Get CSRF token from cookie and header
|
||||
const cookies = request.cookies as Record<string, string> | undefined;
|
||||
const cookieToken = cookies?.["csrf-token"];
|
||||
const headerToken = request.headers["x-csrf-token"] as string | undefined;
|
||||
|
||||
// Validate tokens exist and match
|
||||
if (!cookieToken || !headerToken) {
|
||||
this.logger.warn({
|
||||
event: "CSRF_TOKEN_MISSING",
|
||||
method: request.method,
|
||||
path: request.path,
|
||||
hasCookie: !!cookieToken,
|
||||
hasHeader: !!headerToken,
|
||||
securityEvent: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
throw new ForbiddenException("CSRF token missing");
|
||||
}
|
||||
|
||||
if (cookieToken !== headerToken) {
|
||||
this.logger.warn({
|
||||
event: "CSRF_TOKEN_MISMATCH",
|
||||
method: request.method,
|
||||
path: request.path,
|
||||
securityEvent: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
throw new ForbiddenException("CSRF token mismatch");
|
||||
}
|
||||
|
||||
// Validate session binding via HMAC
|
||||
const userId = request.user?.id;
|
||||
if (!userId) {
|
||||
this.logger.warn({
|
||||
event: "CSRF_NO_USER_CONTEXT",
|
||||
method: request.method,
|
||||
path: request.path,
|
||||
securityEvent: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
throw new ForbiddenException("CSRF validation requires authentication");
|
||||
}
|
||||
|
||||
if (!this.csrfService.validateToken(cookieToken, userId)) {
|
||||
this.logger.warn({
|
||||
event: "CSRF_SESSION_BINDING_INVALID",
|
||||
method: request.method,
|
||||
path: request.path,
|
||||
securityEvent: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
throw new ForbiddenException("CSRF token not bound to session");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { ExecutionContext, ForbiddenException } from "@nestjs/common";
|
||||
import { ExecutionContext, ForbiddenException, InternalServerErrorException } from "@nestjs/common";
|
||||
import { Reflector } from "@nestjs/core";
|
||||
import { Prisma, WorkspaceMemberRole } from "@prisma/client";
|
||||
import { PermissionGuard } from "./permission.guard";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { Permission } from "../decorators/permissions.decorator";
|
||||
import { WorkspaceMemberRole } from "@prisma/client";
|
||||
|
||||
describe("PermissionGuard", () => {
|
||||
let guard: PermissionGuard;
|
||||
@@ -208,13 +208,67 @@ describe("PermissionGuard", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle database errors gracefully", async () => {
|
||||
it("should throw InternalServerErrorException on database connection errors", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { id: workspaceId });
|
||||
|
||||
mockReflector.getAllAndOverride.mockReturnValue(Permission.WORKSPACE_MEMBER);
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(new Error("Database error"));
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(
|
||||
new Error("Database connection failed")
|
||||
);
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("Failed to verify permissions");
|
||||
});
|
||||
|
||||
it("should throw InternalServerErrorException on Prisma connection timeout", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { id: workspaceId });
|
||||
|
||||
mockReflector.getAllAndOverride.mockReturnValue(Permission.WORKSPACE_MEMBER);
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Connection timed out", {
|
||||
code: "P1001", // Authentication failed (connection error)
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
});
|
||||
|
||||
it("should return null role for Prisma not found error (P2025)", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { id: workspaceId });
|
||||
|
||||
mockReflector.getAllAndOverride.mockReturnValue(Permission.WORKSPACE_MEMBER);
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Record not found", {
|
||||
code: "P2025", // Record not found
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
// P2025 should be treated as "not a member" -> ForbiddenException
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(ForbiddenException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"You are not a member of this workspace"
|
||||
);
|
||||
});
|
||||
|
||||
it("should NOT mask database pool exhaustion as permission denied", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { id: workspaceId });
|
||||
|
||||
mockReflector.getAllAndOverride.mockReturnValue(Permission.WORKSPACE_MEMBER);
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Connection pool exhausted", {
|
||||
code: "P2024", // Connection pool timeout
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
// Should NOT throw ForbiddenException for DB errors
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
await expect(guard.canActivate(context)).rejects.not.toThrow(ForbiddenException);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,8 +3,10 @@ import {
|
||||
CanActivate,
|
||||
ExecutionContext,
|
||||
ForbiddenException,
|
||||
InternalServerErrorException,
|
||||
Logger,
|
||||
} from "@nestjs/common";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { Reflector } from "@nestjs/core";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { PERMISSION_KEY, Permission } from "../decorators/permissions.decorator";
|
||||
@@ -99,6 +101,10 @@ export class PermissionGuard implements CanActivate {
|
||||
|
||||
/**
|
||||
* Fetches the user's role in a workspace
|
||||
*
|
||||
* SEC-API-3 FIX: Database errors are no longer swallowed as null role.
|
||||
* Connection timeouts, pool exhaustion, and other infrastructure errors
|
||||
* are propagated as 500 errors to avoid masking operational issues.
|
||||
*/
|
||||
private async getUserWorkspaceRole(
|
||||
userId: string,
|
||||
@@ -119,11 +125,23 @@ export class PermissionGuard implements CanActivate {
|
||||
|
||||
return member?.role ?? null;
|
||||
} catch (error) {
|
||||
// Only handle Prisma "not found" errors (P2025) as expected cases
|
||||
// All other database errors (connection, timeout, pool) should propagate
|
||||
if (
|
||||
error instanceof Prisma.PrismaClientKnownRequestError &&
|
||||
error.code === "P2025" // Record not found
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Log the error before propagating
|
||||
this.logger.error(
|
||||
`Failed to fetch user role: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
`Database error during permission check: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
error instanceof Error ? error.stack : undefined
|
||||
);
|
||||
return null;
|
||||
|
||||
// Propagate infrastructure errors as 500s, not permission denied
|
||||
throw new InternalServerErrorException("Failed to verify permissions");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { ExecutionContext, ForbiddenException, BadRequestException } from "@nestjs/common";
|
||||
import {
|
||||
ExecutionContext,
|
||||
ForbiddenException,
|
||||
BadRequestException,
|
||||
InternalServerErrorException,
|
||||
} from "@nestjs/common";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { WorkspaceGuard } from "./workspace.guard";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
|
||||
@@ -37,13 +43,15 @@ describe("WorkspaceGuard", () => {
|
||||
user: any,
|
||||
headers: Record<string, string> = {},
|
||||
params: Record<string, string> = {},
|
||||
body: Record<string, any> = {}
|
||||
body: Record<string, any> = {},
|
||||
query: Record<string, string> = {}
|
||||
): ExecutionContext => {
|
||||
const mockRequest = {
|
||||
user,
|
||||
headers,
|
||||
params,
|
||||
body,
|
||||
query,
|
||||
};
|
||||
|
||||
return {
|
||||
@@ -111,16 +119,40 @@ describe("WorkspaceGuard", () => {
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it("should prioritize header over param and body", async () => {
|
||||
it("should allow access when user is a workspace member (via query string)", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, {}, {}, {}, { workspaceId });
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockResolvedValue({
|
||||
workspaceId,
|
||||
userId,
|
||||
role: "MEMBER",
|
||||
});
|
||||
|
||||
const result = await guard.canActivate(context);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockPrismaService.workspaceMember.findUnique).toHaveBeenCalledWith({
|
||||
where: {
|
||||
workspaceId_userId: {
|
||||
workspaceId,
|
||||
userId,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should prioritize header over param, body, and query", async () => {
|
||||
const headerWorkspaceId = "workspace-header";
|
||||
const paramWorkspaceId = "workspace-param";
|
||||
const bodyWorkspaceId = "workspace-body";
|
||||
const queryWorkspaceId = "workspace-query";
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{ id: userId },
|
||||
{ "x-workspace-id": headerWorkspaceId },
|
||||
{ workspaceId: paramWorkspaceId },
|
||||
{ workspaceId: bodyWorkspaceId }
|
||||
{ workspaceId: bodyWorkspaceId },
|
||||
{ workspaceId: queryWorkspaceId }
|
||||
);
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockResolvedValue({
|
||||
@@ -141,6 +173,67 @@ describe("WorkspaceGuard", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should prioritize param over body and query when header missing", async () => {
|
||||
const paramWorkspaceId = "workspace-param";
|
||||
const bodyWorkspaceId = "workspace-body";
|
||||
const queryWorkspaceId = "workspace-query";
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{ id: userId },
|
||||
{},
|
||||
{ workspaceId: paramWorkspaceId },
|
||||
{ workspaceId: bodyWorkspaceId },
|
||||
{ workspaceId: queryWorkspaceId }
|
||||
);
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockResolvedValue({
|
||||
workspaceId: paramWorkspaceId,
|
||||
userId,
|
||||
role: "MEMBER",
|
||||
});
|
||||
|
||||
await guard.canActivate(context);
|
||||
|
||||
expect(mockPrismaService.workspaceMember.findUnique).toHaveBeenCalledWith({
|
||||
where: {
|
||||
workspaceId_userId: {
|
||||
workspaceId: paramWorkspaceId,
|
||||
userId,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should prioritize body over query when header and param missing", async () => {
|
||||
const bodyWorkspaceId = "workspace-body";
|
||||
const queryWorkspaceId = "workspace-query";
|
||||
|
||||
const context = createMockExecutionContext(
|
||||
{ id: userId },
|
||||
{},
|
||||
{},
|
||||
{ workspaceId: bodyWorkspaceId },
|
||||
{ workspaceId: queryWorkspaceId }
|
||||
);
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockResolvedValue({
|
||||
workspaceId: bodyWorkspaceId,
|
||||
userId,
|
||||
role: "MEMBER",
|
||||
});
|
||||
|
||||
await guard.canActivate(context);
|
||||
|
||||
expect(mockPrismaService.workspaceMember.findUnique).toHaveBeenCalledWith({
|
||||
where: {
|
||||
workspaceId_userId: {
|
||||
workspaceId: bodyWorkspaceId,
|
||||
userId,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should throw ForbiddenException when user is not authenticated", async () => {
|
||||
const context = createMockExecutionContext(null, { "x-workspace-id": workspaceId });
|
||||
|
||||
@@ -166,14 +259,60 @@ describe("WorkspaceGuard", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle database errors gracefully", async () => {
|
||||
it("should throw InternalServerErrorException on database connection errors", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { "x-workspace-id": workspaceId });
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(
|
||||
new Error("Database connection failed")
|
||||
);
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow("Failed to verify workspace access");
|
||||
});
|
||||
|
||||
it("should throw InternalServerErrorException on Prisma connection timeout", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { "x-workspace-id": workspaceId });
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Connection timed out", {
|
||||
code: "P1001", // Authentication failed (connection error)
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
});
|
||||
|
||||
it("should return false for Prisma not found error (P2025)", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { "x-workspace-id": workspaceId });
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Record not found", {
|
||||
code: "P2025", // Record not found
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
// P2025 should be treated as "not a member" -> ForbiddenException
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(ForbiddenException);
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(
|
||||
"You do not have access to this workspace"
|
||||
);
|
||||
});
|
||||
|
||||
it("should NOT mask database pool exhaustion as access denied", async () => {
|
||||
const context = createMockExecutionContext({ id: userId }, { "x-workspace-id": workspaceId });
|
||||
|
||||
const prismaError = new Prisma.PrismaClientKnownRequestError("Connection pool exhausted", {
|
||||
code: "P2024", // Connection pool timeout
|
||||
clientVersion: "5.0.0",
|
||||
});
|
||||
|
||||
mockPrismaService.workspaceMember.findUnique.mockRejectedValue(prismaError);
|
||||
|
||||
// Should NOT throw ForbiddenException for DB errors
|
||||
await expect(guard.canActivate(context)).rejects.toThrow(InternalServerErrorException);
|
||||
await expect(guard.canActivate(context)).rejects.not.toThrow(ForbiddenException);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,8 +4,10 @@ import {
|
||||
ExecutionContext,
|
||||
ForbiddenException,
|
||||
BadRequestException,
|
||||
InternalServerErrorException,
|
||||
Logger,
|
||||
} from "@nestjs/common";
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import type { AuthenticatedRequest } from "../types/user.types";
|
||||
|
||||
@@ -30,11 +32,12 @@ import type { AuthenticatedRequest } from "../types/user.types";
|
||||
* ```
|
||||
*
|
||||
* The workspace ID can be provided via:
|
||||
* - Header: `X-Workspace-Id`
|
||||
* - Header: `X-Workspace-Id` (recommended)
|
||||
* - URL parameter: `:workspaceId`
|
||||
* - Request body: `workspaceId` field
|
||||
* - Query parameter: `?workspaceId=xxx` (backward compatibility)
|
||||
*
|
||||
* Priority: Header > Param > Body
|
||||
* Priority: Header > Param > Body > Query
|
||||
*
|
||||
* Note: RLS context must be set at the service layer using withUserContext()
|
||||
* or withUserTransaction() to ensure proper transaction scoping with connection pooling.
|
||||
@@ -58,7 +61,7 @@ export class WorkspaceGuard implements CanActivate {
|
||||
|
||||
if (!workspaceId) {
|
||||
throw new BadRequestException(
|
||||
"Workspace ID is required (via header X-Workspace-Id, URL parameter, or request body)"
|
||||
"Workspace ID is required (via header X-Workspace-Id, URL parameter, request body, or query string)"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -89,18 +92,19 @@ export class WorkspaceGuard implements CanActivate {
|
||||
|
||||
/**
|
||||
* Extracts workspace ID from request in order of priority:
|
||||
* 1. X-Workspace-Id header
|
||||
* 1. X-Workspace-Id header (recommended)
|
||||
* 2. :workspaceId URL parameter
|
||||
* 3. workspaceId in request body
|
||||
* 4. workspaceId query parameter (for backward compatibility)
|
||||
*/
|
||||
private extractWorkspaceId(request: AuthenticatedRequest): string | undefined {
|
||||
// 1. Check header
|
||||
// 1. Check header (recommended approach)
|
||||
const headerWorkspaceId = request.headers["x-workspace-id"];
|
||||
if (typeof headerWorkspaceId === "string") {
|
||||
return headerWorkspaceId;
|
||||
}
|
||||
|
||||
// 2. Check URL params
|
||||
// 2. Check URL params (:workspaceId in route)
|
||||
const paramWorkspaceId = request.params.workspaceId;
|
||||
if (paramWorkspaceId) {
|
||||
return paramWorkspaceId;
|
||||
@@ -112,11 +116,23 @@ export class WorkspaceGuard implements CanActivate {
|
||||
return bodyWorkspaceId;
|
||||
}
|
||||
|
||||
// 4. Check query string (backward compatibility for existing clients)
|
||||
// Access query property if it exists (may not be in all request types)
|
||||
const requestWithQuery = request as typeof request & { query?: Record<string, unknown> };
|
||||
const queryWorkspaceId = requestWithQuery.query?.workspaceId;
|
||||
if (typeof queryWorkspaceId === "string") {
|
||||
return queryWorkspaceId;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies that a user is a member of the specified workspace
|
||||
*
|
||||
* SEC-API-2 FIX: Database errors are no longer swallowed as "access denied".
|
||||
* Connection timeouts, pool exhaustion, and other infrastructure errors
|
||||
* are propagated as 500 errors to avoid masking operational issues.
|
||||
*/
|
||||
private async verifyWorkspaceMembership(userId: string, workspaceId: string): Promise<boolean> {
|
||||
try {
|
||||
@@ -131,11 +147,23 @@ export class WorkspaceGuard implements CanActivate {
|
||||
|
||||
return member !== null;
|
||||
} catch (error) {
|
||||
// Only handle Prisma "not found" errors (P2025) as expected cases
|
||||
// All other database errors (connection, timeout, pool) should propagate
|
||||
if (
|
||||
error instanceof Prisma.PrismaClientKnownRequestError &&
|
||||
error.code === "P2025" // Record not found
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Log the error before propagating
|
||||
this.logger.error(
|
||||
`Failed to verify workspace membership: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
`Database error during workspace membership check: ${error instanceof Error ? error.message : "Unknown error"}`,
|
||||
error instanceof Error ? error.stack : undefined
|
||||
);
|
||||
return false;
|
||||
|
||||
// Propagate infrastructure errors as 500s, not access denied
|
||||
throw new InternalServerErrorException("Failed to verify workspace access");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
198
apps/api/src/common/interceptors/rls-context.integration.spec.ts
Normal file
198
apps/api/src/common/interceptors/rls-context.integration.spec.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
/**
|
||||
* RLS Context Integration Tests
|
||||
*
|
||||
* Tests that the RlsContextInterceptor correctly sets RLS context
|
||||
* and that services can access the RLS-scoped client.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { Injectable, Controller, Get, UseGuards, UseInterceptors } from "@nestjs/common";
|
||||
import { of } from "rxjs";
|
||||
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { getRlsClient } from "../../prisma/rls-context.provider";
|
||||
|
||||
/**
|
||||
* Mock service that uses getRlsClient() pattern
|
||||
*/
|
||||
@Injectable()
|
||||
class TestService {
|
||||
private rlsClientUsed = false;
|
||||
private queriesExecuted: string[] = [];
|
||||
|
||||
constructor(private readonly prisma: PrismaService) {}
|
||||
|
||||
async findWithRls(): Promise<{ usedRlsClient: boolean; queries: string[] }> {
|
||||
const client = getRlsClient() ?? this.prisma;
|
||||
this.rlsClientUsed = client !== this.prisma;
|
||||
|
||||
// Track that we're using the client
|
||||
this.queriesExecuted.push("findMany");
|
||||
|
||||
return {
|
||||
usedRlsClient: this.rlsClientUsed,
|
||||
queries: this.queriesExecuted,
|
||||
};
|
||||
}
|
||||
|
||||
reset() {
|
||||
this.rlsClientUsed = false;
|
||||
this.queriesExecuted = [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock controller that uses the test service
|
||||
*/
|
||||
@Controller("test")
|
||||
class TestController {
|
||||
constructor(private readonly testService: TestService) {}
|
||||
|
||||
@Get()
|
||||
@UseInterceptors(RlsContextInterceptor)
|
||||
async test() {
|
||||
return this.testService.findWithRls();
|
||||
}
|
||||
}
|
||||
|
||||
describe("RLS Context Integration", () => {
|
||||
let testService: TestService;
|
||||
let prismaService: PrismaService;
|
||||
let mockTransactionClient: TransactionClient;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create mock transaction client (excludes $connect, $disconnect, etc.)
|
||||
mockTransactionClient = {
|
||||
$executeRaw: vi.fn().mockResolvedValue(undefined),
|
||||
} as unknown as TransactionClient;
|
||||
|
||||
// Create mock Prisma service
|
||||
const mockPrismaService = {
|
||||
$transaction: vi.fn(async (callback: (tx: TransactionClient) => Promise<unknown>) => {
|
||||
return callback(mockTransactionClient);
|
||||
}),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
controllers: [TestController],
|
||||
providers: [
|
||||
TestService,
|
||||
RlsContextInterceptor,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: mockPrismaService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
testService = module.get<TestService>(TestService);
|
||||
prismaService = module.get<PrismaService>(PrismaService);
|
||||
});
|
||||
|
||||
describe("Service queries with RLS context", () => {
|
||||
it("should provide RLS client to services when user is authenticated", async () => {
|
||||
const userId = "user-123";
|
||||
const workspaceId = "workspace-456";
|
||||
|
||||
// Create interceptor instance
|
||||
const interceptor = new RlsContextInterceptor(prismaService);
|
||||
|
||||
// Mock execution context
|
||||
const mockContext = {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => ({
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
workspaceId,
|
||||
},
|
||||
workspace: {
|
||||
id: workspaceId,
|
||||
},
|
||||
}),
|
||||
}),
|
||||
} as any;
|
||||
|
||||
// Mock call handler
|
||||
const mockNext = {
|
||||
handle: vi.fn(() => {
|
||||
// This simulates the controller calling the service
|
||||
// Must return an Observable, not a Promise
|
||||
const result = testService.findWithRls();
|
||||
return of(result);
|
||||
}),
|
||||
} as any;
|
||||
|
||||
const result = await new Promise((resolve) => {
|
||||
interceptor.intercept(mockContext, mockNext).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
// Verify RLS client was used
|
||||
expect(result).toMatchObject({
|
||||
usedRlsClient: true,
|
||||
queries: ["findMany"],
|
||||
});
|
||||
|
||||
// Verify SET LOCAL was called
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
|
||||
expect.arrayContaining(["SET LOCAL app.current_user_id = ", ""]),
|
||||
userId
|
||||
);
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
|
||||
expect.arrayContaining(["SET LOCAL app.current_workspace_id = ", ""]),
|
||||
workspaceId
|
||||
);
|
||||
});
|
||||
|
||||
it("should fall back to standard client when no RLS context", async () => {
|
||||
// Call service directly without going through interceptor
|
||||
testService.reset();
|
||||
const result = await testService.findWithRls();
|
||||
|
||||
expect(result).toMatchObject({
|
||||
usedRlsClient: false,
|
||||
queries: ["findMany"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("RLS context scoping", () => {
|
||||
it("should clear RLS context after request completes", async () => {
|
||||
const userId = "user-123";
|
||||
|
||||
const interceptor = new RlsContextInterceptor(prismaService);
|
||||
|
||||
const mockContext = {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => ({
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
}),
|
||||
}),
|
||||
} as any;
|
||||
|
||||
const mockNext = {
|
||||
handle: vi.fn(() => {
|
||||
return of({ data: "test" });
|
||||
}),
|
||||
} as any;
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockContext, mockNext).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
// After request completes, RLS context should be cleared
|
||||
const client = getRlsClient();
|
||||
expect(client).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
306
apps/api/src/common/interceptors/rls-context.interceptor.spec.ts
Normal file
306
apps/api/src/common/interceptors/rls-context.interceptor.spec.ts
Normal file
@@ -0,0 +1,306 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { ExecutionContext, CallHandler, InternalServerErrorException } from "@nestjs/common";
|
||||
import { of, throwError } from "rxjs";
|
||||
import { RlsContextInterceptor, type TransactionClient } from "./rls-context.interceptor";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { getRlsClient } from "../../prisma/rls-context.provider";
|
||||
import type { AuthenticatedRequest } from "../types/user.types";
|
||||
|
||||
describe("RlsContextInterceptor", () => {
|
||||
let interceptor: RlsContextInterceptor;
|
||||
let prismaService: PrismaService;
|
||||
let mockExecutionContext: ExecutionContext;
|
||||
let mockCallHandler: CallHandler;
|
||||
let mockTransactionClient: TransactionClient;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create mock transaction client (excludes $connect, $disconnect, etc.)
|
||||
mockTransactionClient = {
|
||||
$executeRaw: vi.fn().mockResolvedValue(undefined),
|
||||
} as unknown as TransactionClient;
|
||||
|
||||
// Create mock Prisma service
|
||||
const mockPrismaService = {
|
||||
$transaction: vi.fn(
|
||||
async (
|
||||
callback: (tx: TransactionClient) => Promise<unknown>,
|
||||
options?: { timeout?: number; maxWait?: number }
|
||||
) => {
|
||||
return callback(mockTransactionClient);
|
||||
}
|
||||
),
|
||||
};
|
||||
|
||||
const module: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
RlsContextInterceptor,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: mockPrismaService,
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
|
||||
interceptor = module.get<RlsContextInterceptor>(RlsContextInterceptor);
|
||||
prismaService = module.get<PrismaService>(PrismaService);
|
||||
|
||||
// Setup mock call handler
|
||||
mockCallHandler = {
|
||||
handle: vi.fn(() => of({ data: "test response" })),
|
||||
};
|
||||
});
|
||||
|
||||
const createMockExecutionContext = (request: Partial<AuthenticatedRequest>): ExecutionContext => {
|
||||
return {
|
||||
switchToHttp: () => ({
|
||||
getRequest: () => request,
|
||||
}),
|
||||
} as ExecutionContext;
|
||||
};
|
||||
|
||||
describe("intercept", () => {
|
||||
it("should set user context when user is authenticated", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
const result = await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toEqual({ data: "test response" });
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenCalledWith(
|
||||
expect.arrayContaining(["SET LOCAL app.current_user_id = ", ""]),
|
||||
userId
|
||||
);
|
||||
});
|
||||
|
||||
it("should set workspace context when workspace is present", async () => {
|
||||
const userId = "user-123";
|
||||
const workspaceId = "workspace-456";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
workspaceId,
|
||||
},
|
||||
workspace: {
|
||||
id: workspaceId,
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
// Check that user context was set
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.arrayContaining(["SET LOCAL app.current_user_id = ", ""]),
|
||||
userId
|
||||
);
|
||||
// Check that workspace context was set
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.arrayContaining(["SET LOCAL app.current_workspace_id = ", ""]),
|
||||
workspaceId
|
||||
);
|
||||
});
|
||||
|
||||
it("should not set context when user is not authenticated", async () => {
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: undefined,
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
|
||||
expect(mockCallHandler.handle).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should propagate RLS client via AsyncLocalStorage", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
// Override call handler to check if RLS client is available
|
||||
let capturedClient: PrismaClient | undefined;
|
||||
mockCallHandler = {
|
||||
handle: vi.fn(() => {
|
||||
capturedClient = getRlsClient();
|
||||
return of({ data: "test response" });
|
||||
}),
|
||||
};
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
expect(capturedClient).toBe(mockTransactionClient);
|
||||
});
|
||||
|
||||
it("should handle errors and still propagate them", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
const error = new Error("Test error");
|
||||
mockCallHandler = {
|
||||
handle: vi.fn(() => throwError(() => error)),
|
||||
};
|
||||
|
||||
await expect(
|
||||
new Promise((resolve, reject) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
error: reject,
|
||||
});
|
||||
})
|
||||
).rejects.toThrow(error);
|
||||
|
||||
// Context should still have been set before error
|
||||
expect(mockTransactionClient.$executeRaw).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should clear RLS context after request completes", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
// After the observable completes, RLS context should be cleared
|
||||
const client = getRlsClient();
|
||||
expect(client).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should handle missing user.id gracefully", async () => {
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: "",
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
expect(mockTransactionClient.$executeRaw).not.toHaveBeenCalled();
|
||||
expect(mockCallHandler.handle).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should configure transaction with timeout and maxWait", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
await new Promise((resolve) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
});
|
||||
});
|
||||
|
||||
// Verify transaction was called with timeout options
|
||||
expect(prismaService.$transaction).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
expect.objectContaining({
|
||||
timeout: 30000, // 30 seconds
|
||||
maxWait: 10000, // 10 seconds
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
it("should sanitize database errors before sending to client", async () => {
|
||||
const userId = "user-123";
|
||||
const request: Partial<AuthenticatedRequest> = {
|
||||
user: {
|
||||
id: userId,
|
||||
email: "test@example.com",
|
||||
name: "Test User",
|
||||
},
|
||||
};
|
||||
|
||||
mockExecutionContext = createMockExecutionContext(request);
|
||||
|
||||
// Mock transaction to throw a database error with sensitive information
|
||||
const databaseError = new Error(
|
||||
"PrismaClientKnownRequestError: Connection failed to database.internal.example.com:5432"
|
||||
);
|
||||
vi.spyOn(prismaService, "$transaction").mockRejectedValue(databaseError);
|
||||
|
||||
const errorPromise = new Promise((resolve, reject) => {
|
||||
interceptor.intercept(mockExecutionContext, mockCallHandler).subscribe({
|
||||
next: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
|
||||
await expect(errorPromise).rejects.toThrow(InternalServerErrorException);
|
||||
await expect(errorPromise).rejects.toThrow("Request processing failed");
|
||||
|
||||
// Verify the detailed error was NOT sent to the client
|
||||
await expect(errorPromise).rejects.not.toThrow("database.internal.example.com");
|
||||
});
|
||||
});
|
||||
});
|
||||
155
apps/api/src/common/interceptors/rls-context.interceptor.ts
Normal file
155
apps/api/src/common/interceptors/rls-context.interceptor.ts
Normal file
@@ -0,0 +1,155 @@
|
||||
import {
|
||||
Injectable,
|
||||
NestInterceptor,
|
||||
ExecutionContext,
|
||||
CallHandler,
|
||||
Logger,
|
||||
InternalServerErrorException,
|
||||
} from "@nestjs/common";
|
||||
import { Observable } from "rxjs";
|
||||
import { finalize } from "rxjs/operators";
|
||||
import type { PrismaClient } from "@prisma/client";
|
||||
import { PrismaService } from "../../prisma/prisma.service";
|
||||
import { runWithRlsClient } from "../../prisma/rls-context.provider";
|
||||
import type { AuthenticatedRequest } from "../types/user.types";
|
||||
|
||||
/**
|
||||
* Transaction-safe Prisma client type that excludes methods not available on transaction clients.
|
||||
* This prevents services from accidentally calling $connect, $disconnect, $transaction, etc.
|
||||
* on a transaction client, which would cause runtime errors.
|
||||
*/
|
||||
export type TransactionClient = Omit<
|
||||
PrismaClient,
|
||||
"$connect" | "$disconnect" | "$transaction" | "$on" | "$use"
|
||||
>;
|
||||
|
||||
/**
|
||||
* RlsContextInterceptor sets Row-Level Security (RLS) session variables for authenticated requests.
|
||||
*
|
||||
* This interceptor runs after AuthGuard and WorkspaceGuard, extracting the authenticated user
|
||||
* and workspace from the request and setting PostgreSQL session variables within a transaction:
|
||||
* - SET LOCAL app.current_user_id = '...'
|
||||
* - SET LOCAL app.current_workspace_id = '...'
|
||||
*
|
||||
* The transaction-scoped Prisma client is then propagated via AsyncLocalStorage, allowing
|
||||
* services to access it via getRlsClient() without explicit dependency injection.
|
||||
*
|
||||
* ## Security Design
|
||||
*
|
||||
* SET LOCAL is used instead of SET to ensure session variables are transaction-scoped.
|
||||
* This is critical for connection pooling safety - without transaction scoping, variables
|
||||
* would leak between requests that reuse the same connection from the pool.
|
||||
*
|
||||
* The entire request handler is executed within the transaction boundary, ensuring all
|
||||
* queries inherit the RLS context.
|
||||
*
|
||||
* ## Usage
|
||||
*
|
||||
* Registered globally as APP_INTERCEPTOR in AppModule (after TelemetryInterceptor).
|
||||
* Services access the RLS client via:
|
||||
*
|
||||
* ```typescript
|
||||
* const client = getRlsClient() ?? this.prisma;
|
||||
* return client.task.findMany(); // Filtered by RLS
|
||||
* ```
|
||||
*
|
||||
* ## Unauthenticated Routes
|
||||
*
|
||||
* Routes without AuthGuard (public endpoints) will not have request.user set.
|
||||
* The interceptor gracefully handles this by skipping RLS context setup.
|
||||
*
|
||||
* @see docs/design/credential-security.md for RLS architecture
|
||||
*/
|
||||
@Injectable()
|
||||
export class RlsContextInterceptor implements NestInterceptor {
|
||||
private readonly logger = new Logger(RlsContextInterceptor.name);
|
||||
|
||||
// Transaction timeout configuration
|
||||
// Longer timeout to support file uploads, complex queries, and bulk operations
|
||||
private readonly TRANSACTION_TIMEOUT_MS = 30000; // 30 seconds
|
||||
private readonly TRANSACTION_MAX_WAIT_MS = 10000; // 10 seconds to acquire connection
|
||||
|
||||
constructor(private readonly prisma: PrismaService) {}
|
||||
|
||||
/**
|
||||
* Intercept HTTP requests and set RLS context if user is authenticated.
|
||||
*
|
||||
* @param context - The execution context
|
||||
* @param next - The next call handler
|
||||
* @returns Observable of the response with RLS context applied
|
||||
*/
|
||||
intercept(context: ExecutionContext, next: CallHandler): Observable<unknown> {
|
||||
const request = context.switchToHttp().getRequest<AuthenticatedRequest>();
|
||||
const user = request.user;
|
||||
|
||||
// Skip RLS context setup for unauthenticated requests
|
||||
if (!user?.id) {
|
||||
this.logger.debug("Skipping RLS context: no authenticated user");
|
||||
return next.handle();
|
||||
}
|
||||
|
||||
const userId = user.id;
|
||||
const workspaceId = request.workspace?.id ?? user.workspaceId;
|
||||
|
||||
this.logger.debug(
|
||||
`Setting RLS context: user=${userId}${workspaceId ? `, workspace=${workspaceId}` : ""}`
|
||||
);
|
||||
|
||||
// Execute the entire request within a transaction with RLS context set
|
||||
return new Observable((subscriber) => {
|
||||
this.prisma
|
||||
.$transaction(
|
||||
async (tx) => {
|
||||
// Set user context (always present for authenticated requests)
|
||||
await tx.$executeRaw`SET LOCAL app.current_user_id = ${userId}`;
|
||||
|
||||
// Set workspace context (if present)
|
||||
if (workspaceId) {
|
||||
await tx.$executeRaw`SET LOCAL app.current_workspace_id = ${workspaceId}`;
|
||||
}
|
||||
|
||||
// Propagate the transaction client via AsyncLocalStorage
|
||||
// This allows services to access it via getRlsClient()
|
||||
// Use TransactionClient type to maintain type safety
|
||||
return runWithRlsClient(tx as TransactionClient, () => {
|
||||
return new Promise((resolve, reject) => {
|
||||
next
|
||||
.handle()
|
||||
.pipe(
|
||||
finalize(() => {
|
||||
this.logger.debug("RLS context cleared");
|
||||
})
|
||||
)
|
||||
.subscribe({
|
||||
next: (value) => {
|
||||
subscriber.next(value);
|
||||
resolve(value);
|
||||
},
|
||||
error: (error: unknown) => {
|
||||
const err = error instanceof Error ? error : new Error(String(error));
|
||||
subscriber.error(err);
|
||||
reject(err);
|
||||
},
|
||||
complete: () => {
|
||||
subscriber.complete();
|
||||
resolve(undefined);
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
},
|
||||
{
|
||||
timeout: this.TRANSACTION_TIMEOUT_MS,
|
||||
maxWait: this.TRANSACTION_MAX_WAIT_MS,
|
||||
}
|
||||
)
|
||||
.catch((error: unknown) => {
|
||||
const err = error instanceof Error ? error : new Error(String(error));
|
||||
this.logger.error(`Failed to set RLS context: ${err.message}`, err.stack);
|
||||
// Sanitize error before sending to client to prevent information disclosure
|
||||
// (schema info, internal variable names, connection details, etc.)
|
||||
subscriber.error(new InternalServerErrorException("Request processing failed"));
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
54
apps/api/src/common/providers/redis.provider.ts
Normal file
54
apps/api/src/common/providers/redis.provider.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Redis Provider
|
||||
*
|
||||
* Provides Redis/Valkey client instance for the application.
|
||||
*/
|
||||
|
||||
import { Logger } from "@nestjs/common";
|
||||
import type { Provider } from "@nestjs/common";
|
||||
import Redis from "ioredis";
|
||||
|
||||
/**
|
||||
* Factory function to create Redis client instance
|
||||
*/
|
||||
function createRedisClient(): Redis {
|
||||
const logger = new Logger("RedisProvider");
|
||||
const valkeyUrl = process.env.VALKEY_URL ?? "redis://localhost:6379";
|
||||
|
||||
logger.log(`Connecting to Valkey at ${valkeyUrl}`);
|
||||
|
||||
const client = new Redis(valkeyUrl, {
|
||||
maxRetriesPerRequest: 3,
|
||||
retryStrategy: (times) => {
|
||||
const delay = Math.min(times * 50, 2000);
|
||||
logger.warn(
|
||||
`Valkey connection retry attempt ${times.toString()}, waiting ${delay.toString()}ms`
|
||||
);
|
||||
return delay;
|
||||
},
|
||||
reconnectOnError: (err) => {
|
||||
logger.error("Valkey connection error:", err.message);
|
||||
return true;
|
||||
},
|
||||
});
|
||||
|
||||
client.on("connect", () => {
|
||||
logger.log("Connected to Valkey");
|
||||
});
|
||||
|
||||
client.on("error", (err) => {
|
||||
logger.error("Valkey error:", err.message);
|
||||
});
|
||||
|
||||
return client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Redis Client Provider
|
||||
*
|
||||
* Provides a singleton Redis client instance for dependency injection.
|
||||
*/
|
||||
export const RedisProvider: Provider = {
|
||||
provide: "REDIS_CLIENT",
|
||||
useFactory: createRedisClient,
|
||||
};
|
||||
209
apps/api/src/common/services/csrf.service.spec.ts
Normal file
209
apps/api/src/common/services/csrf.service.spec.ts
Normal file
@@ -0,0 +1,209 @@
|
||||
/**
|
||||
* CSRF Service Tests
|
||||
*
|
||||
* Tests CSRF token generation and validation with session binding.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
|
||||
import { CsrfService } from "./csrf.service";
|
||||
|
||||
describe("CsrfService", () => {
|
||||
let service: CsrfService;
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env = { ...originalEnv };
|
||||
// Set a consistent secret for tests
|
||||
process.env.CSRF_SECRET = "test-secret-key-0123456789abcdef0123456789abcdef";
|
||||
service = new CsrfService();
|
||||
service.onModuleInit();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe("onModuleInit", () => {
|
||||
it("should initialize with configured secret", () => {
|
||||
const testService = new CsrfService();
|
||||
process.env.CSRF_SECRET = "configured-secret";
|
||||
expect(() => testService.onModuleInit()).not.toThrow();
|
||||
});
|
||||
|
||||
it("should throw in production without CSRF_SECRET", () => {
|
||||
const testService = new CsrfService();
|
||||
process.env.NODE_ENV = "production";
|
||||
delete process.env.CSRF_SECRET;
|
||||
expect(() => testService.onModuleInit()).toThrow(
|
||||
"CSRF_SECRET environment variable is required in production"
|
||||
);
|
||||
});
|
||||
|
||||
it("should generate random secret in development without CSRF_SECRET", () => {
|
||||
const testService = new CsrfService();
|
||||
process.env.NODE_ENV = "development";
|
||||
delete process.env.CSRF_SECRET;
|
||||
expect(() => testService.onModuleInit()).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("generateToken", () => {
|
||||
it("should generate a token with random:hmac format", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
|
||||
expect(token).toContain(":");
|
||||
const parts = token.split(":");
|
||||
expect(parts).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("should generate 64-char hex random part (32 bytes)", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
const randomPart = token.split(":")[0];
|
||||
|
||||
expect(randomPart).toHaveLength(64);
|
||||
expect(/^[0-9a-f]{64}$/.test(randomPart as string)).toBe(true);
|
||||
});
|
||||
|
||||
it("should generate 64-char hex HMAC (SHA-256)", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
const hmacPart = token.split(":")[1];
|
||||
|
||||
expect(hmacPart).toHaveLength(64);
|
||||
expect(/^[0-9a-f]{64}$/.test(hmacPart as string)).toBe(true);
|
||||
});
|
||||
|
||||
it("should generate unique tokens on each call", () => {
|
||||
const token1 = service.generateToken("user-123");
|
||||
const token2 = service.generateToken("user-123");
|
||||
|
||||
expect(token1).not.toBe(token2);
|
||||
});
|
||||
|
||||
it("should generate different HMACs for different sessions", () => {
|
||||
const token1 = service.generateToken("user-123");
|
||||
const token2 = service.generateToken("user-456");
|
||||
|
||||
const hmac1 = token1.split(":")[1];
|
||||
const hmac2 = token2.split(":")[1];
|
||||
|
||||
// Even with same random part, HMACs would differ due to session binding
|
||||
// But since random parts differ, this just confirms they're different tokens
|
||||
expect(hmac1).not.toBe(hmac2);
|
||||
});
|
||||
});
|
||||
|
||||
describe("validateToken", () => {
|
||||
it("should validate a token for the correct session", () => {
|
||||
const sessionId = "user-123";
|
||||
const token = service.generateToken(sessionId);
|
||||
|
||||
expect(service.validateToken(token, sessionId)).toBe(true);
|
||||
});
|
||||
|
||||
it("should reject a token for a different session", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
|
||||
expect(service.validateToken(token, "user-456")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject empty token", () => {
|
||||
expect(service.validateToken("", "user-123")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject empty session ID", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
expect(service.validateToken(token, "")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token without colon separator", () => {
|
||||
expect(service.validateToken("invalidtoken", "user-123")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with empty random part", () => {
|
||||
expect(service.validateToken(":somehash", "user-123")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with empty HMAC part", () => {
|
||||
expect(service.validateToken("somerandom:", "user-123")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with invalid hex in random part", () => {
|
||||
expect(
|
||||
service.validateToken(
|
||||
"invalid-hex-here-not-64-chars:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
"user-123"
|
||||
)
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with invalid hex in HMAC part", () => {
|
||||
expect(
|
||||
service.validateToken(
|
||||
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef:not-valid-hex",
|
||||
"user-123"
|
||||
)
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with tampered HMAC", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
const parts = token.split(":");
|
||||
// Tamper with the HMAC
|
||||
const tamperedToken = `${parts[0]}:0000000000000000000000000000000000000000000000000000000000000000`;
|
||||
|
||||
expect(service.validateToken(tamperedToken, "user-123")).toBe(false);
|
||||
});
|
||||
|
||||
it("should reject token with tampered random part", () => {
|
||||
const token = service.generateToken("user-123");
|
||||
const parts = token.split(":");
|
||||
// Tamper with the random part
|
||||
const tamperedToken = `0000000000000000000000000000000000000000000000000000000000000000:${parts[1]}`;
|
||||
|
||||
expect(service.validateToken(tamperedToken, "user-123")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("session binding security", () => {
|
||||
it("should bind token to specific session", () => {
|
||||
const token = service.generateToken("session-A");
|
||||
|
||||
// Token valid for session-A
|
||||
expect(service.validateToken(token, "session-A")).toBe(true);
|
||||
|
||||
// Token invalid for any other session
|
||||
expect(service.validateToken(token, "session-B")).toBe(false);
|
||||
expect(service.validateToken(token, "session-C")).toBe(false);
|
||||
expect(service.validateToken(token, "")).toBe(false);
|
||||
});
|
||||
|
||||
it("should not allow token reuse across sessions", () => {
|
||||
const userAToken = service.generateToken("user-A");
|
||||
const userBToken = service.generateToken("user-B");
|
||||
|
||||
// Each token only valid for its own session
|
||||
expect(service.validateToken(userAToken, "user-A")).toBe(true);
|
||||
expect(service.validateToken(userAToken, "user-B")).toBe(false);
|
||||
|
||||
expect(service.validateToken(userBToken, "user-B")).toBe(true);
|
||||
expect(service.validateToken(userBToken, "user-A")).toBe(false);
|
||||
});
|
||||
|
||||
it("should use different secrets to generate different tokens", () => {
|
||||
// Generate token with current secret
|
||||
const token1 = service.generateToken("user-123");
|
||||
|
||||
// Create new service with different secret
|
||||
process.env.CSRF_SECRET = "different-secret-key-abcdef0123456789";
|
||||
const service2 = new CsrfService();
|
||||
service2.onModuleInit();
|
||||
|
||||
// Token from service1 should not validate with service2
|
||||
expect(service2.validateToken(token1, "user-123")).toBe(false);
|
||||
|
||||
// But service2's own tokens should validate
|
||||
const token2 = service2.generateToken("user-123");
|
||||
expect(service2.validateToken(token2, "user-123")).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
116
apps/api/src/common/services/csrf.service.ts
Normal file
116
apps/api/src/common/services/csrf.service.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
/**
|
||||
* CSRF Service
|
||||
*
|
||||
* Handles CSRF token generation and validation with session binding.
|
||||
* Tokens are cryptographically tied to the user session via HMAC.
|
||||
*
|
||||
* Token format: {random_part}:{hmac(random_part + session_id, secret)}
|
||||
*/
|
||||
|
||||
import { Injectable, Logger, OnModuleInit } from "@nestjs/common";
|
||||
import * as crypto from "crypto";
|
||||
|
||||
@Injectable()
|
||||
export class CsrfService implements OnModuleInit {
|
||||
private readonly logger = new Logger(CsrfService.name);
|
||||
private csrfSecret = "";
|
||||
|
||||
onModuleInit(): void {
|
||||
const secret = process.env.CSRF_SECRET;
|
||||
|
||||
if (process.env.NODE_ENV === "production" && !secret) {
|
||||
throw new Error(
|
||||
"CSRF_SECRET environment variable is required in production. " +
|
||||
"Generate with: node -e \"console.log(require('crypto').randomBytes(32).toString('hex'))\""
|
||||
);
|
||||
}
|
||||
|
||||
// Use provided secret or generate a random one for development
|
||||
if (secret) {
|
||||
this.csrfSecret = secret;
|
||||
this.logger.log("CSRF service initialized with configured secret");
|
||||
} else {
|
||||
this.csrfSecret = crypto.randomBytes(32).toString("hex");
|
||||
this.logger.warn(
|
||||
"CSRF service initialized with random secret (development mode). " +
|
||||
"Set CSRF_SECRET for persistent tokens across restarts."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a CSRF token bound to a session identifier
|
||||
* @param sessionId - User session identifier (e.g., user ID or session token)
|
||||
* @returns Token in format: {random}:{hmac}
|
||||
*/
|
||||
generateToken(sessionId: string): string {
|
||||
// Generate cryptographically secure random part (32 bytes = 64 hex chars)
|
||||
const randomPart = crypto.randomBytes(32).toString("hex");
|
||||
|
||||
// Create HMAC binding the random part to the session
|
||||
const hmac = this.createHmac(randomPart, sessionId);
|
||||
|
||||
return `${randomPart}:${hmac}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a CSRF token against a session identifier
|
||||
* @param token - The full CSRF token (random:hmac format)
|
||||
* @param sessionId - User session identifier to validate against
|
||||
* @returns true if token is valid and bound to the session
|
||||
*/
|
||||
validateToken(token: string, sessionId: string): boolean {
|
||||
if (!token || !sessionId) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Parse token parts
|
||||
const colonIndex = token.indexOf(":");
|
||||
if (colonIndex === -1) {
|
||||
this.logger.debug("Invalid token format: missing colon separator");
|
||||
return false;
|
||||
}
|
||||
|
||||
const randomPart = token.substring(0, colonIndex);
|
||||
const providedHmac = token.substring(colonIndex + 1);
|
||||
|
||||
if (!randomPart || !providedHmac) {
|
||||
this.logger.debug("Invalid token format: empty random part or HMAC");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify the random part is valid hex (64 characters for 32 bytes)
|
||||
if (!/^[0-9a-fA-F]{64}$/.test(randomPart)) {
|
||||
this.logger.debug("Invalid token format: random part is not valid hex");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute expected HMAC
|
||||
const expectedHmac = this.createHmac(randomPart, sessionId);
|
||||
|
||||
// Use timing-safe comparison to prevent timing attacks
|
||||
try {
|
||||
return crypto.timingSafeEqual(
|
||||
Buffer.from(providedHmac, "hex"),
|
||||
Buffer.from(expectedHmac, "hex")
|
||||
);
|
||||
} catch {
|
||||
// Buffer creation fails if providedHmac is not valid hex
|
||||
this.logger.debug("Invalid token format: HMAC is not valid hex");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create HMAC for token binding
|
||||
* @param randomPart - The random part of the token
|
||||
* @param sessionId - The session identifier
|
||||
* @returns Hex-encoded HMAC
|
||||
*/
|
||||
private createHmac(randomPart: string, sessionId: string): string {
|
||||
return crypto
|
||||
.createHmac("sha256", this.csrfSecret)
|
||||
.update(`${randomPart}:${sessionId}`)
|
||||
.digest("hex");
|
||||
}
|
||||
}
|
||||
1170
apps/api/src/common/tests/workspace-isolation.spec.ts
Normal file
1170
apps/api/src/common/tests/workspace-isolation.spec.ts
Normal file
File diff suppressed because it is too large
Load Diff
257
apps/api/src/common/throttler/throttler-storage.service.spec.ts
Normal file
257
apps/api/src/common/throttler/throttler-storage.service.spec.ts
Normal file
@@ -0,0 +1,257 @@
|
||||
import { describe, it, expect, beforeEach, vi, afterEach, Mock } from "vitest";
|
||||
import { ThrottlerValkeyStorageService } from "./throttler-storage.service";
|
||||
|
||||
// Create a mock Redis class
|
||||
const createMockRedis = (
|
||||
options: {
|
||||
shouldFailConnect?: boolean;
|
||||
error?: Error;
|
||||
} = {}
|
||||
): Record<string, Mock> => ({
|
||||
connect: vi.fn().mockImplementation(() => {
|
||||
if (options.shouldFailConnect) {
|
||||
return Promise.reject(options.error ?? new Error("Connection refused"));
|
||||
}
|
||||
return Promise.resolve();
|
||||
}),
|
||||
ping: vi.fn().mockResolvedValue("PONG"),
|
||||
quit: vi.fn().mockResolvedValue("OK"),
|
||||
multi: vi.fn().mockReturnThis(),
|
||||
incr: vi.fn().mockReturnThis(),
|
||||
pexpire: vi.fn().mockReturnThis(),
|
||||
exec: vi.fn().mockResolvedValue([
|
||||
[null, 1],
|
||||
[null, 1],
|
||||
]),
|
||||
get: vi.fn().mockResolvedValue("5"),
|
||||
});
|
||||
|
||||
// Mock ioredis module
|
||||
vi.mock("ioredis", () => {
|
||||
return {
|
||||
default: vi.fn().mockImplementation(() => createMockRedis({ shouldFailConnect: true })),
|
||||
};
|
||||
});
|
||||
|
||||
describe("ThrottlerValkeyStorageService", () => {
|
||||
let service: ThrottlerValkeyStorageService;
|
||||
let loggerErrorSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
service = new ThrottlerValkeyStorageService();
|
||||
|
||||
// Spy on logger methods - access the private logger
|
||||
const logger = (
|
||||
service as unknown as { logger: { error: () => void; log: () => void; warn: () => void } }
|
||||
).logger;
|
||||
loggerErrorSpy = vi.spyOn(logger, "error").mockImplementation(() => undefined);
|
||||
vi.spyOn(logger, "log").mockImplementation(() => undefined);
|
||||
vi.spyOn(logger, "warn").mockImplementation(() => undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("initialization and fallback behavior", () => {
|
||||
it("should start in fallback mode before initialization", () => {
|
||||
// Before onModuleInit is called, useRedis is false by default
|
||||
expect(service.isUsingFallback()).toBe(true);
|
||||
});
|
||||
|
||||
it("should log ERROR when Redis connection fails", async () => {
|
||||
const newService = new ThrottlerValkeyStorageService();
|
||||
const newLogger = (
|
||||
newService as unknown as { logger: { error: () => void; log: () => void } }
|
||||
).logger;
|
||||
const newErrorSpy = vi.spyOn(newLogger, "error").mockImplementation(() => undefined);
|
||||
vi.spyOn(newLogger, "log").mockImplementation(() => undefined);
|
||||
|
||||
await newService.onModuleInit();
|
||||
|
||||
// Verify ERROR was logged (not WARN)
|
||||
expect(newErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Failed to connect to Valkey for rate limiting")
|
||||
);
|
||||
expect(newErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("DEGRADED MODE: Falling back to in-memory rate limiting storage")
|
||||
);
|
||||
});
|
||||
|
||||
it("should log message indicating rate limits will not be shared", async () => {
|
||||
const newService = new ThrottlerValkeyStorageService();
|
||||
const newLogger = (
|
||||
newService as unknown as { logger: { error: () => void; log: () => void } }
|
||||
).logger;
|
||||
const newErrorSpy = vi.spyOn(newLogger, "error").mockImplementation(() => undefined);
|
||||
vi.spyOn(newLogger, "log").mockImplementation(() => undefined);
|
||||
|
||||
await newService.onModuleInit();
|
||||
|
||||
expect(newErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Rate limits will not be shared across API instances")
|
||||
);
|
||||
});
|
||||
|
||||
it("should be in fallback mode when Redis connection fails", async () => {
|
||||
const newService = new ThrottlerValkeyStorageService();
|
||||
const newLogger = (
|
||||
newService as unknown as { logger: { error: () => void; log: () => void } }
|
||||
).logger;
|
||||
vi.spyOn(newLogger, "error").mockImplementation(() => undefined);
|
||||
vi.spyOn(newLogger, "log").mockImplementation(() => undefined);
|
||||
|
||||
await newService.onModuleInit();
|
||||
|
||||
expect(newService.isUsingFallback()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isUsingFallback()", () => {
|
||||
it("should return true when in memory fallback mode", () => {
|
||||
// Default state is fallback mode
|
||||
expect(service.isUsingFallback()).toBe(true);
|
||||
});
|
||||
|
||||
it("should return boolean type", () => {
|
||||
const result = service.isUsingFallback();
|
||||
expect(typeof result).toBe("boolean");
|
||||
});
|
||||
});
|
||||
|
||||
describe("getHealthStatus()", () => {
|
||||
it("should return degraded status when in fallback mode", () => {
|
||||
// Default state is fallback mode
|
||||
const status = service.getHealthStatus();
|
||||
|
||||
expect(status).toEqual({
|
||||
healthy: true,
|
||||
mode: "memory",
|
||||
degraded: true,
|
||||
message: expect.stringContaining("in-memory fallback"),
|
||||
});
|
||||
});
|
||||
|
||||
it("should indicate degraded mode message includes lack of sharing", () => {
|
||||
const status = service.getHealthStatus();
|
||||
|
||||
expect(status.message).toContain("not shared across instances");
|
||||
});
|
||||
|
||||
it("should always report healthy even in degraded mode", () => {
|
||||
// In degraded mode, the service is still functional
|
||||
const status = service.getHealthStatus();
|
||||
expect(status.healthy).toBe(true);
|
||||
});
|
||||
|
||||
it("should have correct structure for health checks", () => {
|
||||
const status = service.getHealthStatus();
|
||||
|
||||
expect(status).toHaveProperty("healthy");
|
||||
expect(status).toHaveProperty("mode");
|
||||
expect(status).toHaveProperty("degraded");
|
||||
expect(status).toHaveProperty("message");
|
||||
});
|
||||
|
||||
it("should report mode as memory when in fallback", () => {
|
||||
const status = service.getHealthStatus();
|
||||
expect(status.mode).toBe("memory");
|
||||
});
|
||||
|
||||
it("should report degraded as true when in fallback", () => {
|
||||
const status = service.getHealthStatus();
|
||||
expect(status.degraded).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getHealthStatus() with Redis (unit test via internal state)", () => {
|
||||
it("should return non-degraded status when Redis is available", () => {
|
||||
// Manually set the internal state to simulate Redis being available
|
||||
// This tests the method logic without requiring actual Redis connection
|
||||
const testService = new ThrottlerValkeyStorageService();
|
||||
|
||||
// Access private property for testing (this is acceptable for unit testing)
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(testService as any).useRedis = true;
|
||||
|
||||
const status = testService.getHealthStatus();
|
||||
|
||||
expect(status).toEqual({
|
||||
healthy: true,
|
||||
mode: "redis",
|
||||
degraded: false,
|
||||
message: expect.stringContaining("Redis storage"),
|
||||
});
|
||||
});
|
||||
|
||||
it("should report distributed mode message when Redis is available", () => {
|
||||
const testService = new ThrottlerValkeyStorageService();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(testService as any).useRedis = true;
|
||||
|
||||
const status = testService.getHealthStatus();
|
||||
|
||||
expect(status.message).toContain("distributed mode");
|
||||
});
|
||||
|
||||
it("should report isUsingFallback as false when Redis is available", () => {
|
||||
const testService = new ThrottlerValkeyStorageService();
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(testService as any).useRedis = true;
|
||||
|
||||
expect(testService.isUsingFallback()).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("in-memory fallback operations", () => {
|
||||
it("should increment correctly in fallback mode", async () => {
|
||||
const result = await service.increment("test-key", 60000, 10, 0, "default");
|
||||
|
||||
expect(result.totalHits).toBe(1);
|
||||
expect(result.isBlocked).toBe(false);
|
||||
});
|
||||
|
||||
it("should accumulate hits in fallback mode", async () => {
|
||||
await service.increment("test-key", 60000, 10, 0, "default");
|
||||
await service.increment("test-key", 60000, 10, 0, "default");
|
||||
const result = await service.increment("test-key", 60000, 10, 0, "default");
|
||||
|
||||
expect(result.totalHits).toBe(3);
|
||||
});
|
||||
|
||||
it("should return correct blocked status when limit exceeded", async () => {
|
||||
// Make 3 requests with limit of 2
|
||||
await service.increment("test-key", 60000, 2, 1000, "default");
|
||||
await service.increment("test-key", 60000, 2, 1000, "default");
|
||||
const result = await service.increment("test-key", 60000, 2, 1000, "default");
|
||||
|
||||
expect(result.totalHits).toBe(3);
|
||||
expect(result.isBlocked).toBe(true);
|
||||
expect(result.timeToBlockExpire).toBe(1000);
|
||||
});
|
||||
|
||||
it("should return 0 for get on non-existent key in fallback mode", async () => {
|
||||
const result = await service.get("non-existent-key");
|
||||
expect(result).toBe(0);
|
||||
});
|
||||
|
||||
it("should return correct timeToExpire in response", async () => {
|
||||
const ttl = 30000;
|
||||
const result = await service.increment("test-key", ttl, 10, 0, "default");
|
||||
|
||||
expect(result.timeToExpire).toBe(ttl);
|
||||
});
|
||||
|
||||
it("should isolate different keys in fallback mode", async () => {
|
||||
await service.increment("key-1", 60000, 10, 0, "default");
|
||||
await service.increment("key-1", 60000, 10, 0, "default");
|
||||
const result1 = await service.increment("key-1", 60000, 10, 0, "default");
|
||||
|
||||
const result2 = await service.increment("key-2", 60000, 10, 0, "default");
|
||||
|
||||
expect(result1.totalHits).toBe(3);
|
||||
expect(result2.totalHits).toBe(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -16,11 +16,18 @@ interface ThrottlerStorageRecord {
|
||||
/**
|
||||
* Redis-based storage for rate limiting using Valkey
|
||||
*
|
||||
* This service uses Valkey (Redis-compatible) as the storage backend
|
||||
* for rate limiting. This allows rate limits to work across multiple
|
||||
* API instances in a distributed environment.
|
||||
* This service uses Valkey (Redis-compatible) as the primary storage backend
|
||||
* for rate limiting, which provides atomic operations and allows rate limits
|
||||
* to work correctly across multiple API instances in a distributed environment.
|
||||
*
|
||||
* If Redis is unavailable, falls back to in-memory storage.
|
||||
* **Fallback behavior:** If Valkey is unavailable (connection failure or command
|
||||
* error), the service falls back to in-memory storage. The in-memory mode is
|
||||
* **best-effort only** — it uses a non-atomic read-modify-write pattern that may
|
||||
* allow slightly more requests than the configured limit under high concurrency.
|
||||
* This is an acceptable trade-off because the fallback path is only used when
|
||||
* the primary distributed store is down, and adding mutex/locking complexity for
|
||||
* a degraded-mode code path provides minimal benefit. In-memory rate limits are
|
||||
* also not shared across API instances.
|
||||
*/
|
||||
@Injectable()
|
||||
export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModuleInit {
|
||||
@@ -53,8 +60,11 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
this.logger.log("Valkey connected successfully for rate limiting");
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.logger.warn(`Failed to connect to Valkey for rate limiting: ${errorMessage}`);
|
||||
this.logger.warn("Falling back to in-memory rate limiting storage");
|
||||
this.logger.error(`Failed to connect to Valkey for rate limiting: ${errorMessage}`);
|
||||
this.logger.error(
|
||||
"DEGRADED MODE: Falling back to in-memory rate limiting storage. " +
|
||||
"Rate limits will not be shared across API instances."
|
||||
);
|
||||
this.useRedis = false;
|
||||
this.client = undefined;
|
||||
}
|
||||
@@ -92,7 +102,10 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.logger.error(`Redis increment failed: ${errorMessage}`);
|
||||
// Fall through to in-memory
|
||||
this.logger.warn(
|
||||
"Falling back to in-memory rate limiting for this request. " +
|
||||
"In-memory mode is best-effort and may be slightly permissive under high concurrency."
|
||||
);
|
||||
totalHits = this.incrementMemory(throttleKey, ttl);
|
||||
}
|
||||
} else {
|
||||
@@ -126,7 +139,10 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
this.logger.error(`Redis get failed: ${errorMessage}`);
|
||||
// Fall through to in-memory
|
||||
this.logger.warn(
|
||||
"Falling back to in-memory rate limiting for this request. " +
|
||||
"In-memory mode is best-effort and may be slightly permissive under high concurrency."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,7 +151,26 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
}
|
||||
|
||||
/**
|
||||
* In-memory increment implementation
|
||||
* In-memory increment implementation (best-effort rate limiting).
|
||||
*
|
||||
* **Race condition note:** This method uses a non-atomic read-modify-write
|
||||
* pattern (read from Map -> filter -> push -> write to Map). Under high
|
||||
* concurrency, multiple async operations could read the same snapshot of
|
||||
* timestamps before any of them write back, causing some increments to be
|
||||
* lost. This means the rate limiter may allow slightly more requests than
|
||||
* the configured limit.
|
||||
*
|
||||
* This is intentionally left without a mutex/lock because:
|
||||
* 1. This is the **fallback** path, only used when Valkey is unavailable.
|
||||
* 2. The primary Valkey path uses atomic INCR operations and is race-free.
|
||||
* 3. Adding locking complexity to a rarely-used degraded code path provides
|
||||
* minimal benefit while increasing maintenance burden.
|
||||
* 4. In degraded mode, "slightly permissive" rate limiting is preferable
|
||||
* to added latency or deadlock risk from synchronization primitives.
|
||||
*
|
||||
* @param key - The throttle key to increment
|
||||
* @param ttl - Time-to-live in milliseconds for the sliding window
|
||||
* @returns The current hit count (may be slightly undercounted under concurrency)
|
||||
*/
|
||||
private incrementMemory(key: string, ttl: number): number {
|
||||
const now = Date.now();
|
||||
@@ -147,7 +182,8 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
// Add new timestamp
|
||||
validTimestamps.push(now);
|
||||
|
||||
// Store updated timestamps
|
||||
// NOTE: Non-atomic write — concurrent calls may overwrite each other's updates.
|
||||
// See method JSDoc for why this is acceptable in the fallback path.
|
||||
this.fallbackStorage.set(key, validTimestamps);
|
||||
|
||||
return validTimestamps.length;
|
||||
@@ -168,6 +204,46 @@ export class ThrottlerValkeyStorageService implements ThrottlerStorage, OnModule
|
||||
return `${this.THROTTLER_PREFIX}${key}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the service is using fallback in-memory storage
|
||||
*
|
||||
* This indicates a degraded state where rate limits are not shared
|
||||
* across API instances. Use this for health checks.
|
||||
*
|
||||
* @returns true if using in-memory fallback, false if using Redis
|
||||
*/
|
||||
isUsingFallback(): boolean {
|
||||
return !this.useRedis;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rate limiter health status for health check endpoints
|
||||
*
|
||||
* @returns Health status object with storage mode and details
|
||||
*/
|
||||
getHealthStatus(): {
|
||||
healthy: boolean;
|
||||
mode: "redis" | "memory";
|
||||
degraded: boolean;
|
||||
message: string;
|
||||
} {
|
||||
if (this.useRedis) {
|
||||
return {
|
||||
healthy: true,
|
||||
mode: "redis",
|
||||
degraded: false,
|
||||
message: "Rate limiter using Redis storage (distributed mode)",
|
||||
};
|
||||
}
|
||||
return {
|
||||
healthy: true, // Service is functional, but degraded
|
||||
mode: "memory",
|
||||
degraded: true,
|
||||
message:
|
||||
"Rate limiter using in-memory fallback (degraded mode - limits not shared across instances)",
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up on module destroy
|
||||
*/
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user