feat(routing): routing_rules schema + types — M4-001/002/003 (#315)
Some checks failed
ci/woodpecker/push/ci Pipeline failed
Some checks failed
ci/woodpecker/push/ci Pipeline failed
Co-authored-by: Jason Woltje <jason@diversecanvas.com> Co-committed-by: Jason Woltje <jason@diversecanvas.com>
This commit was merged in pull request #315.
This commit is contained in:
@@ -2,3 +2,4 @@ export { OllamaAdapter } from './ollama.adapter.js';
|
||||
export { AnthropicAdapter } from './anthropic.adapter.js';
|
||||
export { OpenAIAdapter } from './openai.adapter.js';
|
||||
export { OpenRouterAdapter } from './openrouter.adapter.js';
|
||||
export { ZaiAdapter } from './zai.adapter.js';
|
||||
|
||||
187
apps/gateway/src/agent/adapters/zai.adapter.ts
Normal file
187
apps/gateway/src/agent/adapters/zai.adapter.ts
Normal file
@@ -0,0 +1,187 @@
|
||||
import { Logger } from '@nestjs/common';
|
||||
import OpenAI from 'openai';
|
||||
import type {
|
||||
CompletionEvent,
|
||||
CompletionParams,
|
||||
IProviderAdapter,
|
||||
ModelInfo,
|
||||
ProviderHealth,
|
||||
} from '@mosaic/types';
|
||||
import { getModelCapability } from '../model-capabilities.js';
|
||||
|
||||
/**
 * Base URL used for Z.ai requests when the ZAI_BASE_URL environment variable
 * is not set. Z.ai (BigModel / Zhipu AI) serves an OpenAI-compatible API here.
 */
const DEFAULT_ZAI_BASE_URL = 'https://open.bigmodel.cn/api/paas/v4';

/** Identifier of the GLM-5 model on the Z.ai platform. */
const GLM5_MODEL_ID = 'glm-5';
|
||||
|
||||
/**
 * Z.ai (Zhipu AI / BigModel) provider adapter.
 *
 * Z.ai exposes an OpenAI-compatible REST API, so this adapter drives it with
 * the `openai` SDK pointed at a custom base URL.
 *
 * Configuration:
 *   ZAI_API_KEY  — required; Z.ai API key
 *   ZAI_BASE_URL — optional; override the default API base URL
 */
export class ZaiAdapter implements IProviderAdapter {
  readonly name = 'zai';

  private readonly logger = new Logger(ZaiAdapter.name);
  /** SDK client; stays `null` until `register()` succeeds. */
  private client: OpenAI | null = null;
  /** Models advertised by this adapter; empty until `register()` succeeds. */
  private registeredModels: ModelInfo[] = [];

  /**
   * Create the SDK client and build the model list.
   * Does nothing (beyond a debug log) when ZAI_API_KEY is not set.
   */
  async register(): Promise<void> {
    const apiKey = process.env['ZAI_API_KEY'];
    if (!apiKey) {
      this.logger.debug('Skipping Z.ai provider registration: ZAI_API_KEY not set');
      return;
    }

    const baseURL = process.env['ZAI_BASE_URL'] ?? DEFAULT_ZAI_BASE_URL;

    this.client = new OpenAI({ apiKey, baseURL });

    this.registeredModels = this.buildModelList();
    this.logger.log(`Z.ai provider registered with ${this.registeredModels.length} model(s)`);
  }

  /** @returns the models registered by the last successful `register()` call. */
  listModels(): ModelInfo[] {
    return this.registeredModels;
  }

  /**
   * Probe `GET {baseURL}/models` with a 5 second timeout.
   *
   * @returns `down` when no API key is configured or the request throws,
   *          `degraded` on a non-2xx response, `healthy` otherwise.
   */
  async healthCheck(): Promise<ProviderHealth> {
    const apiKey = process.env['ZAI_API_KEY'];
    if (!apiKey) {
      return {
        status: 'down',
        lastChecked: new Date().toISOString(),
        error: 'ZAI_API_KEY not configured',
      };
    }

    // Strip trailing slashes so an override such as "https://host/v4/" does not
    // produce a "//models" path.
    const baseURL = (process.env['ZAI_BASE_URL'] ?? DEFAULT_ZAI_BASE_URL).replace(/\/+$/, '');
    const start = Date.now();

    try {
      const res = await fetch(`${baseURL}/models`, {
        method: 'GET',
        headers: {
          Authorization: `Bearer ${apiKey}`,
          Accept: 'application/json',
        },
        signal: AbortSignal.timeout(5000),
      });
      const latencyMs = Date.now() - start;

      if (!res.ok) {
        return {
          status: 'degraded',
          latencyMs,
          lastChecked: new Date().toISOString(),
          error: `HTTP ${res.status}`,
        };
      }

      return { status: 'healthy', latencyMs, lastChecked: new Date().toISOString() };
    } catch (err) {
      const latencyMs = Date.now() - start;
      const error = err instanceof Error ? err.message : String(err);
      return { status: 'down', latencyMs, lastChecked: new Date().toISOString(), error };
    }
  }

  /**
   * Stream a completion through Z.ai's OpenAI-compatible API.
   *
   * Yields one `text_delta` event per non-empty content delta, then a single
   * `done` event carrying the token usage (zeros when the stream reported none).
   *
   * @throws Error when the adapter has not been registered (no client).
   */
  async *createCompletion(params: CompletionParams): AsyncIterable<CompletionEvent> {
    if (!this.client) {
      throw new Error('ZaiAdapter is not initialized. Ensure ZAI_API_KEY is set.');
    }

    const stream = await this.client.chat.completions.create({
      model: params.model,
      messages: params.messages.map((m) => ({ role: m.role, content: m.content })),
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      stream: true,
    });

    let inputTokens = 0;
    let outputTokens = 0;

    for await (const chunk of stream) {
      // Read usage from whichever chunk carries it. OpenAI-compatible streams
      // may report it on a trailing chunk with no choices, or next to a
      // finish_reason other than 'stop' (e.g. 'length'), so it must not be
      // gated on the choice.
      const usage = (
        chunk as { usage?: { prompt_tokens?: number; completion_tokens?: number } | null }
      ).usage;
      if (usage) {
        inputTokens = usage.prompt_tokens ?? inputTokens;
        outputTokens = usage.completion_tokens ?? outputTokens;
      }

      // Defensive access: a usage-only chunk may have an empty or absent choices list.
      const choice = chunk.choices?.[0];
      if (!choice) continue;

      const content = choice.delta?.content;
      if (content) {
        yield { type: 'text_delta', content };
      }
    }

    yield {
      type: 'done',
      usage: { inputTokens, outputTokens },
    };
  }

  // ---------------------------------------------------------------------------
  // Private helpers
  // ---------------------------------------------------------------------------

  /**
   * Build the single-entry model list for GLM-5 from the capability table,
   * falling back to hard-coded defaults (with a warning) when no entry exists.
   */
  private buildModelList(): ModelInfo[] {
    const capability = getModelCapability(GLM5_MODEL_ID);

    if (!capability) {
      this.logger.warn(`Model capability entry not found for '${GLM5_MODEL_ID}'; using defaults`);
      return [
        {
          id: GLM5_MODEL_ID,
          provider: 'zai',
          name: 'GLM-5',
          reasoning: false,
          contextWindow: 128000,
          maxTokens: 8192,
          inputTypes: ['text'],
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        },
      ];
    }

    return [
      {
        id: capability.id,
        provider: 'zai',
        name: capability.displayName,
        reasoning: capability.capabilities.reasoning,
        contextWindow: capability.contextWindow,
        maxTokens: capability.maxOutputTokens,
        inputTypes: capability.capabilities.vision ? ['text', 'image'] : ['text'],
        cost: {
          input: capability.costPer1kInput ?? 0,
          output: capability.costPer1kOutput ?? 0,
          cacheRead: 0,
          cacheWrite: 0,
        },
      },
    ];
  }
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Global, Module } from '@nestjs/common';
|
||||
import { AgentService } from './agent.service.js';
|
||||
import { ProviderService } from './provider.service.js';
|
||||
import { ProviderCredentialsService } from './provider-credentials.service.js';
|
||||
import { RoutingService } from './routing.service.js';
|
||||
import { SkillLoaderService } from './skill-loader.service.js';
|
||||
import { ProvidersController } from './providers.controller.js';
|
||||
@@ -14,8 +15,20 @@ import { GCModule } from '../gc/gc.module.js';
|
||||
/**
 * Global module wiring the agent layer: provider management, per-user provider
 * credentials, routing, skill loading and the agent service itself.
 *
 * Each decorator key appears exactly once; `ProviderCredentialsService` is both
 * provided and exported so other modules can inject it.
 */
@Global()
@Module({
  imports: [CoordModule, McpClientModule, SkillsModule, GCModule],
  providers: [
    ProviderService,
    ProviderCredentialsService,
    RoutingService,
    SkillLoaderService,
    AgentService,
  ],
  controllers: [ProvidersController, SessionsController, AgentConfigsController],
  exports: [
    AgentService,
    ProviderService,
    ProviderCredentialsService,
    RoutingService,
    SkillLoaderService,
  ],
})
export class AgentModule {}
|
||||
|
||||
23
apps/gateway/src/agent/provider-credentials.dto.ts
Normal file
23
apps/gateway/src/agent/provider-credentials.dto.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
 * Request payload for saving a provider credential.
 */
export interface StoreCredentialDto {
  /** Provider identifier, for example 'anthropic', 'openai', 'openrouter' or 'zai'. */
  provider: string;

  /** Kind of credential being stored. */
  type: 'api_key' | 'oauth_token';

  /** Credential in plain text; it is encrypted before being persisted. */
  value: string;

  /** Optional additional configuration, such as base URL overrides. */
  metadata?: Record<string, unknown>;
}
|
||||
|
||||
/**
 * Summary of a stored credential as returned by list/existence responses.
 * It carries no credential value, encrypted or decrypted.
 */
export interface ProviderCredentialSummaryDto {
  /** Provider identifier the credential belongs to. */
  provider: string;

  /** Kind of credential that is stored. */
  credentialType: 'api_key' | 'oauth_token';

  /** Whether a credential is stored for this provider. */
  exists: boolean;

  /** Expiry timestamp as a string, or null/absent when none is recorded. */
  expiresAt?: string | null;

  /** Extra configuration stored with the credential, if any. */
  metadata?: Record<string, unknown> | null;

  /** Creation timestamp as a string. */
  createdAt: string;

  /** Last-update timestamp as a string. */
  updatedAt: string;
}
|
||||
175
apps/gateway/src/agent/provider-credentials.service.ts
Normal file
175
apps/gateway/src/agent/provider-credentials.service.ts
Normal file
@@ -0,0 +1,175 @@
|
||||
import { Inject, Injectable, Logger } from '@nestjs/common';
|
||||
import { createCipheriv, createDecipheriv, createHash, randomBytes } from 'node:crypto';
|
||||
import type { Db } from '@mosaic/db';
|
||||
import { providerCredentials, eq, and } from '@mosaic/db';
|
||||
import { DB } from '../database/database.module.js';
|
||||
import type { ProviderCredentialSummaryDto } from './provider-credentials.dto.js';
|
||||
|
||||
/** Cipher used for credential encryption: AES-256 in GCM mode. */
const ALGORITHM = 'aes-256-gcm';

/** Initialisation vector size in bytes (96-bit IV for GCM). */
const IV_LENGTH = 12;

/** Authentication tag size in bytes (128-bit tag). */
const TAG_LENGTH = 16;
|
||||
|
||||
/**
 * Produce the 32-byte AES-256 key by hashing BETTER_AUTH_SECRET with SHA-256.
 *
 * @returns the raw SHA-256 digest of the secret.
 * @throws Error when BETTER_AUTH_SECRET is unset or empty.
 */
function deriveEncryptionKey(): Buffer {
  const { BETTER_AUTH_SECRET: authSecret } = process.env;

  if (authSecret === undefined || authSecret === '') {
    throw new Error('BETTER_AUTH_SECRET is not set — cannot derive encryption key');
  }

  const hasher = createHash('sha256');
  hasher.update(authSecret);
  return hasher.digest();
}
|
||||
|
||||
/**
 * Encrypt a UTF-8 string with AES-256-GCM under the derived key.
 *
 * A fresh random IV is generated for every call.
 *
 * @param plaintext text to protect.
 * @returns base64 of `iv || authTag || ciphertext`.
 */
function encrypt(plaintext: string): string {
  const secretKey = deriveEncryptionKey();
  const nonce = randomBytes(IV_LENGTH);
  const gcm = createCipheriv(ALGORITHM, secretKey, nonce);

  const head = gcm.update(plaintext, 'utf8');
  const tail = gcm.final();
  // The tag is only available once the cipher has been finalised.
  const tag = gcm.getAuthTag();

  // Layout: iv (12 bytes) | auth tag (16 bytes) | ciphertext.
  return Buffer.concat([nonce, tag, head, tail]).toString('base64');
}
|
||||
|
||||
/**
 * Decrypt a value produced by `encrypt()`.
 *
 * The input is base64 of `iv || authTag || ciphertext`, with a 12-byte IV and a
 * 16-byte authentication tag.
 *
 * @param encoded base64 payload to decrypt.
 * @returns the decrypted UTF-8 plaintext.
 * @throws Error when the payload is too short to hold an IV and a full tag,
 *         or when authentication fails (tampered data or wrong key).
 */
function decrypt(encoded: string): string {
  const key = deriveEncryptionKey();
  const combined = Buffer.from(encoded, 'base64');

  // A truncated payload would otherwise be sliced into a shorter auth tag,
  // which weakens the integrity check; reject it before decrypting.
  if (combined.length < IV_LENGTH + TAG_LENGTH) {
    throw new Error('Encrypted credential payload is too short');
  }

  const iv = combined.subarray(0, IV_LENGTH);
  const authTag = combined.subarray(IV_LENGTH, IV_LENGTH + TAG_LENGTH);
  const ciphertext = combined.subarray(IV_LENGTH + TAG_LENGTH);

  // Pin the tag length so only full 128-bit tags are ever accepted.
  const decipher = createDecipheriv(ALGORITHM, key, iv, { authTagLength: TAG_LENGTH });
  decipher.setAuthTag(authTag);

  // final() throws if the tag does not authenticate the ciphertext.
  const decrypted = Buffer.concat([decipher.update(ciphertext), decipher.final()]);
  return decrypted.toString('utf8');
}
|
||||
|
||||
/**
 * Stores, reads, lists and deletes per-user provider credentials.
 * Values are encrypted with `encrypt()` before they are written and are only
 * decrypted by `retrieve()`.
 */
@Injectable()
export class ProviderCredentialsService {
  private readonly logger = new Logger(ProviderCredentialsService.name);

  constructor(@Inject(DB) private readonly db: Db) {}

  /**
   * Encrypt a credential and upsert it for the (userId, provider) pair.
   * On conflict the type, encrypted value, metadata and `updatedAt` are overwritten.
   *
   * @param userId owner of the credential.
   * @param provider provider identifier.
   * @param type credential kind.
   * @param value plain-text credential; encrypted before it is written.
   * @param metadata optional extra configuration; stored as null when omitted.
   */
  async store(
    userId: string,
    provider: string,
    type: 'api_key' | 'oauth_token',
    value: string,
    metadata?: Record<string, unknown>,
  ): Promise<void> {
    const encryptedValue = encrypt(value);
    const storedMetadata = metadata ?? null;

    const insertQuery = this.db.insert(providerCredentials).values({
      userId,
      provider,
      credentialType: type,
      encryptedValue,
      metadata: storedMetadata,
    });

    await insertQuery.onConflictDoUpdate({
      target: [providerCredentials.userId, providerCredentials.provider],
      set: {
        credentialType: type,
        encryptedValue,
        metadata: storedMetadata,
        updatedAt: new Date(),
      },
    });

    this.logger.log(`Credential stored for user=${userId} provider=${provider}`);
  }

  /**
   * Load and decrypt the credential for the given user and provider.
   *
   * @returns the plain-text value, or null when no row exists, the row has
   *          expired, or decryption fails (the failure is logged).
   */
  async retrieve(userId: string, provider: string): Promise<string | null> {
    const [row] = await this.db
      .select()
      .from(providerCredentials)
      .where(this.ownedBy(userId, provider))
      .limit(1);

    if (row === undefined) {
      return null;
    }

    // Rows with an expiry in the past are treated as absent.
    const expiry = row.expiresAt;
    if (expiry && expiry < new Date()) {
      this.logger.warn(`Credential for user=${userId} provider=${provider} has expired`);
      return null;
    }

    try {
      return decrypt(row.encryptedValue);
    } catch (err) {
      const detail = err instanceof Error ? err.message : String(err);
      this.logger.error(
        `Failed to decrypt credential for user=${userId} provider=${provider}`,
        detail,
      );
      return null;
    }
  }

  /**
   * Delete the credential stored for the given user and provider, if any.
   */
  async remove(userId: string, provider: string): Promise<void> {
    await this.db.delete(providerCredentials).where(this.ownedBy(userId, provider));

    this.logger.log(`Credential removed for user=${userId} provider=${provider}`);
  }

  /**
   * List a summary of every credential the user has stored.
   * The encrypted value column is not selected, so no secret leaves this method.
   */
  async listProviders(userId: string): Promise<ProviderCredentialSummaryDto[]> {
    const selectedColumns = {
      provider: providerCredentials.provider,
      credentialType: providerCredentials.credentialType,
      expiresAt: providerCredentials.expiresAt,
      metadata: providerCredentials.metadata,
      createdAt: providerCredentials.createdAt,
      updatedAt: providerCredentials.updatedAt,
    };

    const records = await this.db
      .select(selectedColumns)
      .from(providerCredentials)
      .where(eq(providerCredentials.userId, userId));

    return records.map((record) => {
      const expiresAt = record.expiresAt ? record.expiresAt.toISOString() : null;
      return {
        provider: record.provider,
        credentialType: record.credentialType,
        exists: true,
        expiresAt,
        metadata: record.metadata as Record<string, unknown> | null,
        createdAt: record.createdAt.toISOString(),
        updatedAt: record.updatedAt.toISOString(),
      };
    });
  }

  /** Filter matching the single row owned by `userId` for `provider`. */
  private ownedBy(userId: string, provider: string) {
    return and(eq(providerCredentials.userId, userId), eq(providerCredentials.provider, provider));
  }
}
|
||||
@@ -1,4 +1,11 @@
|
||||
import { Injectable, Logger, type OnModuleDestroy, type OnModuleInit } from '@nestjs/common';
|
||||
import {
|
||||
Inject,
|
||||
Injectable,
|
||||
Logger,
|
||||
Optional,
|
||||
type OnModuleDestroy,
|
||||
type OnModuleInit,
|
||||
} from '@nestjs/common';
|
||||
import { ModelRegistry, AuthStorage } from '@mariozechner/pi-coding-agent';
|
||||
import { getModel, type Model, type Api } from '@mariozechner/pi-ai';
|
||||
import type {
|
||||
@@ -13,8 +20,10 @@ import {
|
||||
OllamaAdapter,
|
||||
OpenAIAdapter,
|
||||
OpenRouterAdapter,
|
||||
ZaiAdapter,
|
||||
} from './adapters/index.js';
|
||||
import type { TestConnectionResultDto } from './provider.dto.js';
|
||||
import { ProviderCredentialsService } from './provider-credentials.service.js';
|
||||
|
||||
/** Default health check interval in seconds */
|
||||
const DEFAULT_HEALTH_INTERVAL_SECS = 60;
|
||||
@@ -22,11 +31,25 @@ const DEFAULT_HEALTH_INTERVAL_SECS = 60;
|
||||
/** DI injection token for the provider adapter array. */
|
||||
export const PROVIDER_ADAPTERS = Symbol('PROVIDER_ADAPTERS');
|
||||
|
||||
/**
 * Maps each well-known provider identifier to the name of the environment
 * variable that holds its API key.
 */
const PROVIDER_ENV_KEYS: Record<string, string> = {
  anthropic: 'ANTHROPIC_API_KEY',
  openai: 'OPENAI_API_KEY',
  openrouter: 'OPENROUTER_API_KEY',
  zai: 'ZAI_API_KEY',
};
|
||||
|
||||
@Injectable()
|
||||
export class ProviderService implements OnModuleInit, OnModuleDestroy {
|
||||
private readonly logger = new Logger(ProviderService.name);
|
||||
private registry!: ModelRegistry;
|
||||
|
||||
/**
 * @param credentialsService per-user credential store. The injection is marked
 *   optional, so callers in this class check it for presence before use.
 */
constructor(
  @Optional()
  @Inject(ProviderCredentialsService)
  private readonly credentialsService: ProviderCredentialsService | null,
) {}
|
||||
|
||||
/**
|
||||
* Adapters registered with this service.
|
||||
* Built-in adapters (Ollama) are always present; additional adapters can be
|
||||
@@ -52,14 +75,13 @@ export class ProviderService implements OnModuleInit, OnModuleDestroy {
|
||||
new AnthropicAdapter(this.registry),
|
||||
new OpenAIAdapter(this.registry),
|
||||
new OpenRouterAdapter(),
|
||||
new ZaiAdapter(),
|
||||
];
|
||||
|
||||
// Run all adapter registrations first (Ollama, Anthropic, and any future adapters)
|
||||
// Run all adapter registrations first (Ollama, Anthropic, OpenAI, OpenRouter, Z.ai)
|
||||
await this.registerAll();
|
||||
|
||||
// Register API-key providers directly (Z.ai, custom)
|
||||
// OpenAI now has a dedicated adapter (M3-003).
|
||||
this.registerZaiProvider();
|
||||
// Register API-key providers directly (custom)
|
||||
this.registerCustomProviders();
|
||||
|
||||
const available = this.registry.getAvailable();
|
||||
@@ -340,30 +362,9 @@ export class ProviderService implements OnModuleInit, OnModuleDestroy {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Private helpers — direct registry registration for providers without adapters yet
|
||||
// (Z.ai will move to an adapter in M3-005)
|
||||
// Private helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Register the Z.ai provider directly on the model registry using three
 * cloned built-in GLM-4.5 models. Skipped when ZAI_API_KEY is not set.
 */
private registerZaiProvider(): void {
  const zaiKey = process.env['ZAI_API_KEY'];

  if (!zaiKey) {
    this.logger.debug('Skipping Z.ai provider registration: ZAI_API_KEY not set');
    return;
  }

  const modelIds = ['glm-4.5', 'glm-4.5-air', 'glm-4.5-flash'];
  const models = modelIds.map((modelId) => this.cloneBuiltInModel('zai', modelId));

  this.registry.registerProvider('zai', {
    apiKey: zaiKey,
    baseUrl: 'https://open.bigmodel.cn/api/paas/v4',
    models,
  });

  this.logger.log('Z.ai provider registered with 3 models');
}
|
||||
|
||||
private registerCustomProviders(): void {
|
||||
const customJson = process.env['MOSAIC_CUSTOM_PROVIDERS'];
|
||||
if (!customJson) return;
|
||||
@@ -378,6 +379,29 @@ export class ProviderService implements OnModuleInit, OnModuleDestroy {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Resolve an API key for a provider, scoped to a specific user.
 *
 * A credential stored by the user takes precedence; otherwise the provider's
 * well-known environment variable is consulted.
 *
 * @param userId user whose stored credential should be tried first.
 * @param provider provider identifier, e.g. 'anthropic' or 'zai'.
 * @returns the key, or null when neither source has a non-empty value.
 */
async resolveApiKey(userId: string, provider: string): Promise<string | null> {
  if (this.credentialsService) {
    const userKey = await this.credentialsService.retrieve(userId, provider);
    if (userKey) {
      this.logger.debug(`Using user-scoped credential for user=${userId} provider=${provider}`);
      return userKey;
    }
  }

  // Fall back to environment variable. Only own keys of the map count, so a
  // provider name such as 'constructor' cannot resolve to an inherited member.
  const envVar = Object.prototype.hasOwnProperty.call(PROVIDER_ENV_KEYS, provider)
    ? PROVIDER_ENV_KEYS[provider]
    : undefined;

  // `||` is deliberate: a variable that is set but empty means "no key" and
  // must yield null, matching the documented contract.
  const envKey = envVar ? process.env[envVar] || null : null;
  if (envKey) {
    this.logger.debug(`Using env-var credential for provider=${provider}`);
  }
  return envKey;
}
|
||||
|
||||
private cloneBuiltInModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
import { Body, Controller, Get, Inject, Post, UseGuards } from '@nestjs/common';
|
||||
import {
  BadRequestException,
  Body,
  Controller,
  Delete,
  Get,
  Inject,
  Param,
  Post,
  UseGuards,
} from '@nestjs/common';
|
||||
import type { RoutingCriteria } from '@mosaic/types';
|
||||
import { AuthGuard } from '../auth/auth.guard.js';
|
||||
import { CurrentUser } from '../auth/current-user.decorator.js';
|
||||
import { ProviderService } from './provider.service.js';
|
||||
import { ProviderCredentialsService } from './provider-credentials.service.js';
|
||||
import { RoutingService } from './routing.service.js';
|
||||
import type { TestConnectionDto, TestConnectionResultDto } from './provider.dto.js';
|
||||
import type {
|
||||
StoreCredentialDto,
|
||||
ProviderCredentialSummaryDto,
|
||||
} from './provider-credentials.dto.js';
|
||||
|
||||
@Controller('api/providers')
|
||||
@UseGuards(AuthGuard)
|
||||
export class ProvidersController {
|
||||
/**
 * @param providerService provider registry service.
 * @param credentialsService per-user credential store backing the credential routes.
 * @param routingService routing service used by the ranking route.
 */
constructor(
  @Inject(ProviderService)
  private readonly providerService: ProviderService,
  @Inject(ProviderCredentialsService)
  private readonly credentialsService: ProviderCredentialsService,
  @Inject(RoutingService)
  private readonly routingService: RoutingService,
) {}
|
||||
|
||||
@@ -42,4 +50,49 @@ export class ProvidersController {
|
||||
rank(@Body() criteria: RoutingCriteria) {
|
||||
return this.routingService.rank(criteria);
|
||||
}
|
||||
|
||||
// ── Credential CRUD ──────────────────────────────────────────────────────
|
||||
|
||||
/**
 * GET /api/providers/credentials
 *
 * Returns a summary (provider, type, metadata, timestamps) of every credential
 * stored by the authenticated user. Credential values are never included.
 */
@Get('credentials')
listCredentials(@CurrentUser() user: { id: string }): Promise<ProviderCredentialSummaryDto[]> {
  const { id: userId } = user;
  return this.credentialsService.listProviders(userId);
}
|
||||
|
||||
/**
 * POST /api/providers/credentials
 *
 * Store or update a provider credential for the authenticated user.
 * The value is encrypted before storage and never returned.
 *
 * @throws BadRequestException when `provider` or `value` is not a non-empty
 *         string, `type` is not a supported credential type, or `metadata`
 *         is present but not a plain object.
 */
@Post('credentials')
async storeCredential(
  @CurrentUser() user: { id: string },
  @Body() body: StoreCredentialDto,
): Promise<{ success: boolean; provider: string }> {
  // StoreCredentialDto is an interface, so nothing checks the payload shape at
  // runtime; validate here so malformed input is a 400 rather than a 500.
  const payload: Partial<Record<keyof StoreCredentialDto, unknown>> = body ?? {};
  const { provider, type, value, metadata } = payload;

  if (typeof provider !== 'string' || provider.trim() === '') {
    throw new BadRequestException('provider must be a non-empty string');
  }
  if (type !== 'api_key' && type !== 'oauth_token') {
    throw new BadRequestException("type must be 'api_key' or 'oauth_token'");
  }
  if (typeof value !== 'string' || value === '') {
    throw new BadRequestException('value must be a non-empty string');
  }

  // null is tolerated and treated like an omitted field.
  let extra: Record<string, unknown> | undefined;
  if (metadata !== undefined && metadata !== null) {
    if (typeof metadata !== 'object' || Array.isArray(metadata)) {
      throw new BadRequestException('metadata must be an object');
    }
    extra = metadata as Record<string, unknown>;
  }

  await this.credentialsService.store(user.id, provider, type, value, extra);
  return { success: true, provider };
}
|
||||
|
||||
/**
 * DELETE /api/providers/credentials/:provider
 *
 * Deletes the authenticated user's stored credential for the given provider
 * and echoes the provider name back.
 */
@Delete('credentials/:provider')
async removeCredential(
  @CurrentUser() user: { id: string },
  @Param('provider') provider: string,
): Promise<{ success: boolean; provider: string }> {
  const { id: userId } = user;
  await this.credentialsService.remove(userId, provider);

  const outcome = { success: true, provider };
  return outcome;
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ const COST_TIER_THRESHOLDS: Record<CostTier, { maxInput: number }> = {
|
||||
cheap: { maxInput: 1 },
|
||||
standard: { maxInput: 10 },
|
||||
premium: { maxInput: Infinity },
|
||||
// local = self-hosted; treat as cheapest tier for cost scoring purposes
|
||||
local: { maxInput: 0 },
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
|
||||
118
apps/gateway/src/agent/routing/routing.types.ts
Normal file
118
apps/gateway/src/agent/routing/routing.types.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
/**
|
||||
* Routing engine types — M4-002 (condition types) and M4-003 (action types).
|
||||
*
|
||||
* These types are re-exported from `@mosaic/types` for shared use across packages.
|
||||
*/
|
||||
|
||||
// ─── Classification primitives ───────────────────────────────────────────────
|
||||
|
||||
/** Kind of work the agent has been asked to do. */
export type TaskType = 'coding' | 'research' | 'summarization' | 'conversation' | 'analysis' | 'creative';
|
||||
|
||||
/**
 * Estimated difficulty of a task; used to bias selection toward cheaper or
 * more capable models.
 */
export type Complexity =
  | 'simple'
  | 'moderate'
  | 'complex';
|
||||
|
||||
/** Main knowledge area a task belongs to. */
export type Domain =
  | 'frontend'
  | 'backend'
  | 'devops'
  | 'docs'
  | 'general';
|
||||
|
||||
/**
 * Cost tier used when selecting a model.
 *
 * Adds `local` (self-hosted models) to the tiers of the existing `CostTier`
 * in `@mosaic/types`.
 */
export type CostTier =
  | 'cheap'
  | 'standard'
  | 'premium'
  | 'local';
|
||||
|
||||
/** A special model capability that a task may require. */
export type Capability =
  | 'tools'
  | 'vision'
  | 'long-context'
  | 'reasoning'
  | 'embedding';
|
||||
|
||||
// ─── Condition types ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * One predicate of a routing rule; a rule matches only when all of its
 * predicates are satisfied.
 *
 * Operators:
 * - `eq`       — scalar equality: `field === value`
 * - `in`       — set membership: `value` contains `field`
 * - `includes` — array containment: `field` (an array) includes `value`
 */
export interface RoutingCondition {
  /** Task-classification field the predicate tests. */
  field: 'taskType' | 'complexity' | 'domain' | 'costTier' | 'requiredCapabilities';

  /** How `field` is compared with `value`. */
  operator: 'eq' | 'in' | 'includes';

  /** The expected value, or the set of accepted values. */
  value: string | string[];
}
|
||||
|
||||
// ─── Action types ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * What a routing rule does once every one of its conditions is satisfied.
 */
export interface RoutingAction {
  /** LLM provider identifier, e.g. `'anthropic'`, `'openai'`, `'ollama'`. */
  provider: string;

  /** Model identifier, e.g. `'claude-opus-4-6'`, `'gpt-4o'`. */
  model: string;

  /** Optional: a specific pre-configured agent config from the agent registry. */
  agentConfigId?: string;

  /** Optional: replacement for the agent's default system prompt on this route. */
  systemPromptOverride?: string;

  /** Optional: the only tools the agent may use on this route. */
  toolAllowlist?: string[];
}
|
||||
|
||||
// ─── Rule and decision types ─────────────────────────────────────────────────
|
||||
|
||||
/**
 * A complete routing rule, in the shape stored in the database and used at runtime.
 */
export interface RoutingRule {
  /** UUID primary key. */
  id: string;

  /** Human-readable name of the rule. */
  name: string;

  /** Evaluation order: lower numbers are evaluated first; meant to be unique per scope. */
  priority: number;

  /** `'system'` rules apply globally; `'user'` rules override for a specific user. */
  scope: 'system' | 'user';

  /** Set only on `'user'`-scoped rules. */
  userId?: string;

  /** Predicates that must all match for the rule to fire. */
  conditions: RoutingCondition[];

  /** What to do once every condition is met. */
  action: RoutingAction;

  /** Whether the rule is active. */
  enabled: boolean;
}
|
||||
|
||||
/**
 * Structured description of what an agent was asked to do. It is produced by
 * the task classifier and consumed by the routing engine.
 */
export interface TaskClassification {
  /** Kind of work requested. */
  taskType: TaskType;

  /** Estimated difficulty of the task. */
  complexity: Complexity;

  /** Main knowledge area of the task. */
  domain: Domain;

  /** Special model capabilities the task needs. */
  requiredCapabilities: Capability[];
}
|
||||
|
||||
/**
 * Result produced by the routing engine: the model to use and the reason it
 * was chosen.
 */
export interface RoutingDecision {
  /** LLM provider identifier. */
  provider: string;

  /** Model identifier. */
  model: string;

  /** Agent config to apply, when the matched rule names one. */
  agentConfigId?: string;

  /** Name of the matching rule, kept for observability. */
  ruleName: string;

  /** Human-readable explanation of why this rule was selected. */
  reason: string;
}
|
||||
17
packages/db/drizzle/0004_bumpy_miracleman.sql
Normal file
17
packages/db/drizzle/0004_bumpy_miracleman.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
CREATE TABLE "routing_rules" (
|
||||
"id" uuid PRIMARY KEY DEFAULT gen_random_uuid() NOT NULL,
|
||||
"name" text NOT NULL,
|
||||
"priority" integer NOT NULL,
|
||||
"scope" text DEFAULT 'system' NOT NULL,
|
||||
"user_id" text,
|
||||
"conditions" jsonb NOT NULL,
|
||||
"action" jsonb NOT NULL,
|
||||
"enabled" boolean DEFAULT true NOT NULL,
|
||||
"created_at" timestamp with time zone DEFAULT now() NOT NULL,
|
||||
"updated_at" timestamp with time zone DEFAULT now() NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
ALTER TABLE "routing_rules" ADD CONSTRAINT "routing_rules_user_id_users_id_fk" FOREIGN KEY ("user_id") REFERENCES "public"."users"("id") ON DELETE cascade ON UPDATE no action;--> statement-breakpoint
|
||||
CREATE INDEX "routing_rules_scope_priority_idx" ON "routing_rules" USING btree ("scope","priority");--> statement-breakpoint
|
||||
CREATE INDEX "routing_rules_user_id_idx" ON "routing_rules" USING btree ("user_id");--> statement-breakpoint
|
||||
CREATE INDEX "routing_rules_enabled_idx" ON "routing_rules" USING btree ("enabled");
|
||||
2635
packages/db/drizzle/meta/0004_snapshot.json
Normal file
2635
packages/db/drizzle/meta/0004_snapshot.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -29,6 +29,13 @@
|
||||
"when": 1773887085247,
|
||||
"tag": "0003_p8003_perf_indexes",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 4,
|
||||
"version": "7",
|
||||
"when": 1774224004898,
|
||||
"tag": "0004_bumpy_miracleman",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -479,6 +479,66 @@ export const skills = pgTable(
|
||||
(t) => [index('skills_enabled_idx').on(t.enabled)],
|
||||
);
|
||||
|
||||
// ─── Routing Rules ──────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * `routing_rules` table: prioritised rules that map task-classification
 * conditions to a routing action, scoped either system-wide or to one user.
 */
export const routingRules = pgTable(
  'routing_rules',
  {
    id: uuid('id').primaryKey().defaultRandom(),

    /** Human-readable rule name. */
    name: text('name').notNull(),

    /**
     * Lower number = higher priority. Meant to be unique per scope; the index
     * declared below on (scope, priority) is a plain index, not a unique one.
     */
    priority: integer('priority').notNull(),

    /** 'system' rules apply globally; 'user' rules belong to a specific user. */
    scope: text('scope', { enum: ['system', 'user'] })
      .notNull()
      .default('system'),

    /** Null for system-scoped rules; references users.id (cascade delete) otherwise. */
    userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),

    /** Array of condition objects; all must match for the rule to fire. */
    conditions: jsonb('conditions').notNull().$type<Record<string, unknown>[]>(),

    /** Routing action applied once every condition is satisfied. */
    action: jsonb('action').notNull().$type<Record<string, unknown>>(),

    /** Whether the rule is active. */
    enabled: boolean('enabled').notNull().default(true),

    createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(),
    updatedAt: timestamp('updated_at', { withTimezone: true }).notNull().defaultNow(),
  },
  (table) => [
    // Ordered rule evaluation: look up by scope, then priority.
    index('routing_rules_scope_priority_idx').on(table.scope, table.priority),
    // Fetch the rules belonging to one user.
    index('routing_rules_user_id_idx').on(table.userId),
    // Filter on the enabled flag.
    index('routing_rules_enabled_idx').on(table.enabled),
  ],
);
|
||||
|
||||
// ─── Provider Credentials ────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Table definition for `provider_credentials`: per-user credential entries
 * keyed by provider name, with at most one row per (user, provider) pair.
 */
export const providerCredentials = pgTable(
|
||||
'provider_credentials',
|
||||
{
|
||||
/** UUID primary key; a random value is generated by default */
id: uuid('id').primaryKey().defaultRandom(),
|
||||
/** Owning user (required); deleting the referenced user cascades to this row */
userId: text('user_id')
|
||||
.notNull()
|
||||
.references(() => users.id, { onDelete: 'cascade' }),
|
||||
/** Provider name (required, free-form text — no enum is declared here) */
provider: text('provider').notNull(),
|
||||
/** Kind of credential stored: 'api_key' or 'oauth_token' */
credentialType: text('credential_type', { enum: ['api_key', 'oauth_token'] }).notNull(),
|
||||
/** Credential value; presumably ciphertext produced by the caller — encryption is not performed by this definition */
encryptedValue: text('encrypted_value').notNull(),
|
||||
/**
 * Optional refresh token.
 * NOTE(review): unlike encrypted_value, the column name does not indicate
 * encryption — confirm refresh tokens are encrypted before being stored.
 */
refreshToken: text('refresh_token'),
|
||||
/** Optional expiry timestamp (with time zone) */
expiresAt: timestamp('expires_at', { withTimezone: true }),
|
||||
/** Optional untyped JSON metadata (no $type is declared) */
metadata: jsonb('metadata'),
|
||||
/** Creation timestamp (with time zone); defaults to now on insert */
createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(),
|
||||
/** Last-update timestamp (with time zone); defaults to now on insert. NOTE(review): no on-update hook is declared here — confirm writers maintain it. */
updatedAt: timestamp('updated_at', { withTimezone: true }).notNull().defaultNow(),
|
||||
},
|
||||
(t) => [
|
||||
// Unique index: one credential entry per user per provider
|
||||
uniqueIndex('provider_credentials_user_provider_idx').on(t.userId, t.provider),
|
||||
// NOTE(review): the unique index above already leads with user_id, so this
// single-column index may be redundant for user_id lookups — confirm before keeping.
index('provider_credentials_user_id_idx').on(t.userId),
|
||||
],
|
||||
);
|
||||
|
||||
// ─── Summarization Jobs ─────────────────────────────────────────────────────
|
||||
|
||||
export const summarizationJobs = pgTable(
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
/** Cost tier for model selection */
|
||||
export type CostTier = 'cheap' | 'standard' | 'premium';
|
||||
// ─── Legacy simple-routing types (kept for backward compatibility) ────────────
|
||||
|
||||
/** Task type hint for routing */
|
||||
export type TaskType = 'chat' | 'coding' | 'analysis' | 'summarization' | 'general';
|
||||
/**
 * Result of a simple scoring-based routing decision.
 * NOTE(review): field meanings below are read from the property names only —
 * confirm against the scoring router that produces this object.
 */
|
||||
export interface RoutingResult {
|
||||
/** Provider of the selected model (presumably a provider identifier) */
provider: string;
|
||||
/** Identifier of the selected model */
modelId: string;
|
||||
/** Name of the selected model (presumably the display name) */
modelName: string;
|
||||
/** Numeric score attached to the selection; scale and range are not defined here */
score: number;
|
||||
/** Free-text reasoning attached to the selection */
reasoning: string;
|
||||
}
|
||||
|
||||
/** Routing criteria for model selection */
|
||||
/** Routing criteria for score-based model selection */
|
||||
export interface RoutingCriteria {
|
||||
taskType?: TaskType;
|
||||
costTier?: CostTier;
|
||||
@@ -15,11 +20,115 @@ export interface RoutingCriteria {
|
||||
preferredModel?: string;
|
||||
}
|
||||
|
||||
/** Result of a routing decision */
|
||||
export interface RoutingResult {
|
||||
provider: string;
|
||||
modelId: string;
|
||||
modelName: string;
|
||||
score: number;
|
||||
reasoning: string;
|
||||
// ─── Classification primitives (M4-002) ──────────────────────────────────────
|
||||
|
||||
/**
 * Category of work the agent is being asked to perform.
 * NOTE(review): both `'chat'` and `'conversation'` are members; the intended
 * distinction between them is not documented here — confirm with the classifier.
 */
|
||||
export type TaskType =
|
||||
| 'chat'
|
||||
| 'coding'
|
||||
| 'research'
|
||||
| 'summarization'
|
||||
| 'conversation'
|
||||
| 'analysis'
|
||||
| 'creative'
|
||||
| 'general';
|
||||
|
||||
/**
 * Estimated complexity of the task, from least to most demanding as listed.
 * Used to bias model selection toward cheaper or more capable models.
 */
|
||||
export type Complexity = 'simple' | 'moderate' | 'complex';
|
||||
|
||||
/** Primary knowledge domain of the task; `'general'` is the non-specific member of the union. */
|
||||
export type Domain = 'frontend' | 'backend' | 'devops' | 'docs' | 'general';
|
||||
|
||||
/**
 * Cost tier for model selection.
 *
 * `'local'` targets self-hosted/on-premises models; the other members are
 * `'cheap'`, `'standard'` and `'premium'`.
 */
|
||||
export type CostTier = 'cheap' | 'standard' | 'premium' | 'local';
|
||||
|
||||
/** Special model capability a task may require (see `TaskClassification.requiredCapabilities`). */
|
||||
export type Capability = 'tools' | 'vision' | 'long-context' | 'reasoning' | 'embedding';
|
||||
|
||||
// ─── Condition types (M4-002) ─────────────────────────────────────────────────
|
||||
|
||||
/**
 * A single predicate that must be satisfied for a routing rule to match.
 *
 * Operators:
 * - `eq`       — scalar equality: `field === value`
 * - `in`       — set membership: `value` (array) contains `field`
 * - `includes` — array containment: `field` (array) includes `value`
 *
 * NOTE(review): the type does not tie `operator` to the shape of `value`
 * (e.g. `in` with a plain string still type-checks), so evaluators presumably
 * need to validate the pairing at runtime.
 */
|
||||
export interface RoutingCondition {
|
||||
/**
 * The task-classification field to test.
 * NOTE(review): `'costTier'` is accepted here although `TaskClassification`
 * declares no `costTier` property — confirm where that value comes from.
 */
|
||||
field: 'taskType' | 'complexity' | 'domain' | 'costTier' | 'requiredCapabilities';
|
||||
/** Comparison operator; see the operator list on this interface */
|
||||
operator: 'eq' | 'in' | 'includes';
|
||||
/** Expected value (string) or set of values (string array) */
|
||||
value: string | string[];
|
||||
}
|
||||
|
||||
// ─── Action types (M4-003) ────────────────────────────────────────────────────
|
||||
|
||||
/**
 * The routing action to execute when all conditions in a rule are satisfied.
 * `provider` and `model` are required; the remaining fields are optional
 * per-route overrides.
 */
|
||||
export interface RoutingAction {
|
||||
/** LLM provider identifier, e.g. `'anthropic'`, `'openai'`, `'ollama'` */
|
||||
provider: string;
|
||||
/** Model identifier, e.g. `'claude-opus-4-6'`, `'gpt-4o'` */
|
||||
model: string;
|
||||
/** Optional: identifier of a pre-configured agent config from the agent registry to use */
|
||||
agentConfigId?: string;
|
||||
/** Optional: replacement for the agent's default system prompt on this route */
|
||||
systemPromptOverride?: string;
|
||||
/** Optional: tool names the agent is restricted to on this route */
|
||||
toolAllowlist?: string[];
|
||||
}
|
||||
|
||||
/**
 * Full routing rule as stored in the database and used at runtime.
 *
 * NOTE(review): the `routing_rules` table schema also declares `created_at`
 * and `updated_at` columns, which have no counterpart on this interface.
 */
|
||||
export interface RoutingRule {
|
||||
/** UUID primary key */
|
||||
id: string;
|
||||
/** Human-readable rule name */
|
||||
name: string;
|
||||
/**
 * Lower number = evaluated first.
 * NOTE(review): documented as "unique per scope", but the `routing_rules`
 * schema only declares a non-unique index on (scope, priority) — confirm
 * where uniqueness is enforced.
 */
|
||||
priority: number;
|
||||
/** `'system'` rules apply globally; `'user'` rules override for a specific user */
|
||||
scope: 'system' | 'user';
|
||||
/**
 * Present only for `'user'`-scoped rules.
 * NOTE(review): the `user_id` column is nullable, so raw rows presumably carry
 * `null` rather than `undefined` — confirm the mapping layer normalizes this.
 */
|
||||
userId?: string;
|
||||
/** All conditions must match for the rule to fire */
|
||||
conditions: RoutingCondition[];
|
||||
/** Action to take when all conditions are met */
|
||||
action: RoutingAction;
|
||||
/** Whether this rule is active */
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
/**
 * Structured representation of what an agent has been asked to do,
 * produced by the task classifier and consumed by the routing engine.
 * All four fields are required.
 */
|
||||
export interface TaskClassification {
|
||||
/** Category of work being requested */
taskType: TaskType;
|
||||
/** Estimated complexity of the task */
complexity: Complexity;
|
||||
/** Primary knowledge domain of the task */
domain: Domain;
|
||||
/** Model capabilities the task requires */
requiredCapabilities: Capability[];
|
||||
}
|
||||
|
||||
/**
 * Output of the routing engine — which model to use and why.
 * `provider`, `model` and `agentConfigId` mirror the same-named fields of
 * `RoutingAction`; `ruleName` and `reason` are for observability.
 */
|
||||
export interface RoutingDecision {
|
||||
/** LLM provider identifier */
|
||||
provider: string;
|
||||
/** Model identifier */
|
||||
model: string;
|
||||
/** Optional agent config to apply */
|
||||
agentConfigId?: string;
|
||||
/** Name of the rule that matched, for observability */
|
||||
ruleName: string;
|
||||
/** Human-readable explanation of why this rule was selected */
|
||||
reason: string;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user