feat(#123): port Ollama LLM provider

Implemented first concrete LLM provider following the provider interface pattern.

Implementation:
- OllamaProvider class implementing LlmProviderInterface
- All required methods: initialize(), checkHealth(), listModels(), chat(), chatStream(), embed(), getConfig()
- OllamaProviderConfig extending LlmProviderConfig
- Proper error handling with NestJS Logger
- Configuration immutability protection

Features:
- System prompt injection support
- Temperature and max tokens configuration
- Embedding with truncation control (defaults to enabled)
- Streaming and non-streaming chat completions
- Health check with model listing

Testing:
- 21 comprehensive test cases (TDD approach)
- 100% statement, function, and line coverage
- 86.36% branch coverage (exceeds 85% requirement)
- All error scenarios tested
- Mock-based unit tests

Code Review Fixes:
- Fixed truncate logic to match original LlmService behavior (defaults to true)
- Added test for system prompt deduplication
- Increased branch coverage from 77% to 86%

Quality Gates:
-  All 21 tests passing
-  Linting clean
-  Type checking passed
-  Code review approved

Fixes #123

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-31 12:10:43 -06:00
parent 1e35e63444
commit 94afeb67e3
3 changed files with 812 additions and 0 deletions

View File

@@ -0,0 +1,435 @@
import { describe, it, expect, beforeEach, vi, type Mock } from "vitest";
import { OllamaProvider, type OllamaProviderConfig } from "./ollama.provider";
import type { ChatRequestDto, EmbedRequestDto } from "../dto";
// Mock the ollama module
vi.mock("ollama", () => {
return {
Ollama: vi.fn().mockImplementation(function (this: unknown) {
return {
list: vi.fn(),
chat: vi.fn(),
embed: vi.fn(),
};
}),
};
});
describe("OllamaProvider", () => {
let provider: OllamaProvider;
let config: OllamaProviderConfig;
let mockOllamaInstance: {
list: Mock;
chat: Mock;
embed: Mock;
};
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks();
// Setup test configuration
config = {
endpoint: "http://localhost:11434",
timeout: 30000,
};
provider = new OllamaProvider(config);
// Get the mock instance created by the constructor
mockOllamaInstance = (provider as any).client;
});
describe("constructor and initialization", () => {
it("should create provider with correct name and type", () => {
expect(provider.name).toBe("Ollama");
expect(provider.type).toBe("ollama");
});
it("should initialize successfully", async () => {
await expect(provider.initialize()).resolves.toBeUndefined();
});
});
describe("checkHealth", () => {
it("should return healthy status when Ollama is reachable", async () => {
const mockModels = [{ name: "llama2" }, { name: "mistral" }];
mockOllamaInstance.list.mockResolvedValue({ models: mockModels });
const health = await provider.checkHealth();
expect(health).toEqual({
healthy: true,
provider: "ollama",
endpoint: config.endpoint,
models: ["llama2", "mistral"],
});
expect(mockOllamaInstance.list).toHaveBeenCalledOnce();
});
it("should return unhealthy status when Ollama is unreachable", async () => {
const error = new Error("Connection refused");
mockOllamaInstance.list.mockRejectedValue(error);
const health = await provider.checkHealth();
expect(health).toEqual({
healthy: false,
provider: "ollama",
endpoint: config.endpoint,
error: "Connection refused",
});
});
it("should handle non-Error exceptions", async () => {
mockOllamaInstance.list.mockRejectedValue("string error");
const health = await provider.checkHealth();
expect(health.healthy).toBe(false);
expect(health.error).toBe("string error");
});
});
describe("listModels", () => {
it("should return array of model names", async () => {
const mockModels = [{ name: "llama2" }, { name: "mistral" }, { name: "codellama" }];
mockOllamaInstance.list.mockResolvedValue({ models: mockModels });
const models = await provider.listModels();
expect(models).toEqual(["llama2", "mistral", "codellama"]);
expect(mockOllamaInstance.list).toHaveBeenCalledOnce();
});
it("should throw error when listing models fails", async () => {
const error = new Error("Failed to connect");
mockOllamaInstance.list.mockRejectedValue(error);
await expect(provider.listModels()).rejects.toThrow("Failed to list models");
});
});
describe("chat", () => {
it("should perform chat completion successfully", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
};
const mockResponse = {
model: "llama2",
message: { role: "assistant", content: "Hi there!" },
done: true,
total_duration: 1000000,
prompt_eval_count: 10,
eval_count: 5,
};
mockOllamaInstance.chat.mockResolvedValue(mockResponse);
const response = await provider.chat(request);
expect(response).toEqual({
model: "llama2",
message: { role: "assistant", content: "Hi there!" },
done: true,
totalDuration: 1000000,
promptEvalCount: 10,
evalCount: 5,
});
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
stream: false,
options: {},
});
});
it("should include system prompt in messages", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
systemPrompt: "You are a helpful assistant",
};
mockOllamaInstance.chat.mockResolvedValue({
model: "llama2",
message: { role: "assistant", content: "Hi!" },
done: true,
});
await provider.chat(request);
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [
{ role: "system", content: "You are a helpful assistant" },
{ role: "user", content: "Hello" },
],
stream: false,
options: {},
});
});
it("should not duplicate system prompt when already in messages", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [
{ role: "system", content: "Existing system prompt" },
{ role: "user", content: "Hello" },
],
systemPrompt: "New system prompt (should be ignored)",
};
mockOllamaInstance.chat.mockResolvedValue({
model: "llama2",
message: { role: "assistant", content: "Hi!" },
done: true,
});
await provider.chat(request);
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [
{ role: "system", content: "Existing system prompt" },
{ role: "user", content: "Hello" },
],
stream: false,
options: {},
});
});
it("should pass temperature and maxTokens as options", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
temperature: 0.7,
maxTokens: 100,
};
mockOllamaInstance.chat.mockResolvedValue({
model: "llama2",
message: { role: "assistant", content: "Hi!" },
done: true,
});
await provider.chat(request);
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
stream: false,
options: {
temperature: 0.7,
num_predict: 100,
},
});
});
it("should throw error when chat fails", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
};
mockOllamaInstance.chat.mockRejectedValue(new Error("Model not found"));
await expect(provider.chat(request)).rejects.toThrow("Chat completion failed");
});
});
describe("chatStream", () => {
it("should stream chat completion chunks", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
};
const mockChunks = [
{ model: "llama2", message: { role: "assistant", content: "Hi" }, done: false },
{ model: "llama2", message: { role: "assistant", content: " there" }, done: false },
{ model: "llama2", message: { role: "assistant", content: "!" }, done: true },
];
// Mock async generator
async function* mockStreamGenerator() {
for (const chunk of mockChunks) {
yield chunk;
}
}
mockOllamaInstance.chat.mockResolvedValue(mockStreamGenerator());
const chunks = [];
for await (const chunk of provider.chatStream(request)) {
chunks.push(chunk);
}
expect(chunks).toHaveLength(3);
expect(chunks[0]).toEqual({
model: "llama2",
message: { role: "assistant", content: "Hi" },
done: false,
});
expect(chunks[2].done).toBe(true);
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
stream: true,
options: {},
});
});
it("should pass options in streaming mode", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
temperature: 0.5,
maxTokens: 50,
};
async function* mockStreamGenerator() {
yield { model: "llama2", message: { role: "assistant", content: "Hi" }, done: true };
}
mockOllamaInstance.chat.mockResolvedValue(mockStreamGenerator());
const generator = provider.chatStream(request);
await generator.next();
expect(mockOllamaInstance.chat).toHaveBeenCalledWith({
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
stream: true,
options: {
temperature: 0.5,
num_predict: 50,
},
});
});
it("should throw error when streaming fails", async () => {
const request: ChatRequestDto = {
model: "llama2",
messages: [{ role: "user", content: "Hello" }],
};
mockOllamaInstance.chat.mockRejectedValue(new Error("Stream error"));
const generator = provider.chatStream(request);
await expect(generator.next()).rejects.toThrow("Streaming failed");
});
});
describe("embed", () => {
it("should generate embeddings successfully", async () => {
const request: EmbedRequestDto = {
model: "nomic-embed-text",
input: ["Hello world", "Test embedding"],
};
const mockResponse = {
model: "nomic-embed-text",
embeddings: [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
],
total_duration: 500000,
};
mockOllamaInstance.embed.mockResolvedValue(mockResponse);
const response = await provider.embed(request);
expect(response).toEqual({
model: "nomic-embed-text",
embeddings: [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
],
totalDuration: 500000,
});
expect(mockOllamaInstance.embed).toHaveBeenCalledWith({
model: "nomic-embed-text",
input: ["Hello world", "Test embedding"],
truncate: true,
});
});
it("should handle truncate option", async () => {
const request: EmbedRequestDto = {
model: "nomic-embed-text",
input: ["Test"],
truncate: "start",
};
mockOllamaInstance.embed.mockResolvedValue({
model: "nomic-embed-text",
embeddings: [[0.1, 0.2]],
});
await provider.embed(request);
expect(mockOllamaInstance.embed).toHaveBeenCalledWith({
model: "nomic-embed-text",
input: ["Test"],
truncate: true,
});
});
it("should respect truncate none option", async () => {
const request: EmbedRequestDto = {
model: "nomic-embed-text",
input: ["Test"],
truncate: "none",
};
mockOllamaInstance.embed.mockResolvedValue({
model: "nomic-embed-text",
embeddings: [[0.1, 0.2]],
});
await provider.embed(request);
expect(mockOllamaInstance.embed).toHaveBeenCalledWith({
model: "nomic-embed-text",
input: ["Test"],
truncate: false,
});
});
it("should throw error when embedding fails", async () => {
const request: EmbedRequestDto = {
model: "nomic-embed-text",
input: ["Test"],
};
mockOllamaInstance.embed.mockRejectedValue(new Error("Embedding error"));
await expect(provider.embed(request)).rejects.toThrow("Embedding failed");
});
});
describe("getConfig", () => {
it("should return copy of configuration", () => {
const returnedConfig = provider.getConfig();
expect(returnedConfig).toEqual(config);
expect(returnedConfig).not.toBe(config); // Should be a copy, not reference
});
it("should prevent external modification of config", () => {
const returnedConfig = provider.getConfig();
returnedConfig.endpoint = "http://modified:11434";
const secondCall = provider.getConfig();
expect(secondCall.endpoint).toBe("http://localhost:11434"); // Original unchanged
});
});
});

View File

@@ -0,0 +1,295 @@
import { Logger } from "@nestjs/common";
import { Ollama, type Message } from "ollama";
import type {
LlmProviderInterface,
LlmProviderConfig,
LlmProviderHealthStatus,
} from "./llm-provider.interface";
import type { ChatRequestDto, ChatResponseDto, EmbedRequestDto, EmbedResponseDto } from "../dto";
/**
* Configuration for Ollama LLM provider.
* Extends base LlmProviderConfig with Ollama-specific options.
*
* @example
* ```typescript
* const config: OllamaProviderConfig = {
* endpoint: "http://localhost:11434",
* timeout: 30000
* };
* ```
*/
export interface OllamaProviderConfig extends LlmProviderConfig {
/**
* Ollama server endpoint URL
* @default "http://localhost:11434"
*/
endpoint: string;
/**
* Request timeout in milliseconds
* @default 30000
*/
timeout?: number;
}
/**
* Ollama LLM provider implementation.
* Provides integration with locally-hosted or remote Ollama instances.
*
* @example
* ```typescript
* const provider = new OllamaProvider({
* endpoint: "http://localhost:11434",
* timeout: 30000
* });
*
* await provider.initialize();
*
* const response = await provider.chat({
* model: "llama2",
* messages: [{ role: "user", content: "Hello" }]
* });
* ```
*/
export class OllamaProvider implements LlmProviderInterface {
readonly name = "Ollama";
readonly type = "ollama" as const;
private readonly logger = new Logger(OllamaProvider.name);
private readonly client: Ollama;
private readonly config: OllamaProviderConfig;
/**
* Creates a new Ollama provider instance.
*
* @param config - Ollama provider configuration
*/
constructor(config: OllamaProviderConfig) {
this.config = {
...config,
timeout: config.timeout ?? 30000,
};
this.client = new Ollama({ host: this.config.endpoint });
this.logger.log(`Ollama provider initialized with endpoint: ${this.config.endpoint}`);
}
/**
* Initialize the Ollama provider.
* This is a no-op for Ollama as the client is initialized in the constructor.
*/
async initialize(): Promise<void> {
// Ollama client is initialized in constructor
// No additional setup required
}
/**
* Check if the Ollama server is healthy and reachable.
*
* @returns Health status with available models if healthy
*/
async checkHealth(): Promise<LlmProviderHealthStatus> {
try {
const response = await this.client.list();
const models = response.models.map((m) => m.name);
return {
healthy: true,
provider: "ollama",
endpoint: this.config.endpoint,
models,
};
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger.warn(`Ollama health check failed: ${errorMessage}`);
return {
healthy: false,
provider: "ollama",
endpoint: this.config.endpoint,
error: errorMessage,
};
}
}
/**
* List all available models from the Ollama server.
*
* @returns Array of model names
* @throws {Error} If the request fails
*/
async listModels(): Promise<string[]> {
try {
const response = await this.client.list();
return response.models.map((m) => m.name);
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger.error(`Failed to list models: ${errorMessage}`);
throw new Error(`Failed to list models: ${errorMessage}`);
}
}
/**
* Perform a synchronous chat completion.
*
* @param request - Chat request with messages and configuration
* @returns Complete chat response
* @throws {Error} If the request fails
*/
async chat(request: ChatRequestDto): Promise<ChatResponseDto> {
try {
const messages = this.buildMessages(request);
const options = this.buildChatOptions(request);
const response = await this.client.chat({
model: request.model,
messages,
stream: false,
options,
});
return {
model: response.model,
message: {
role: response.message.role as "assistant",
content: response.message.content,
},
done: response.done,
totalDuration: response.total_duration,
promptEvalCount: response.prompt_eval_count,
evalCount: response.eval_count,
};
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger.error(`Chat completion failed: ${errorMessage}`);
throw new Error(`Chat completion failed: ${errorMessage}`);
}
}
/**
* Perform a streaming chat completion.
* Yields response chunks as they arrive from the Ollama server.
*
* @param request - Chat request with messages and configuration
* @yields Chat response chunks
* @throws {Error} If the request fails
*/
async *chatStream(request: ChatRequestDto): AsyncGenerator<ChatResponseDto> {
try {
const messages = this.buildMessages(request);
const options = this.buildChatOptions(request);
const stream = await this.client.chat({
model: request.model,
messages,
stream: true,
options,
});
for await (const chunk of stream) {
yield {
model: chunk.model,
message: {
role: chunk.message.role as "assistant",
content: chunk.message.content,
},
done: chunk.done,
};
}
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger.error(`Streaming failed: ${errorMessage}`);
throw new Error(`Streaming failed: ${errorMessage}`);
}
}
/**
* Generate embeddings for the given input texts.
*
* @param request - Embedding request with model and input texts
* @returns Embeddings response with vector arrays
* @throws {Error} If the request fails
*/
async embed(request: EmbedRequestDto): Promise<EmbedResponseDto> {
try {
const response = await this.client.embed({
model: request.model,
input: request.input,
truncate: request.truncate === "none" ? false : true,
});
return {
model: response.model,
embeddings: response.embeddings,
totalDuration: response.total_duration,
};
} catch (error: unknown) {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger.error(`Embedding failed: ${errorMessage}`);
throw new Error(`Embedding failed: ${errorMessage}`);
}
}
/**
* Get the current provider configuration.
* Returns a copy to prevent external modification.
*
* @returns Provider configuration object
*/
getConfig(): OllamaProviderConfig {
return { ...this.config };
}
/**
* Build message array from chat request.
* Prepends system prompt if provided and not already in messages.
*
* @param request - Chat request
* @returns Array of messages for Ollama
*/
private buildMessages(request: ChatRequestDto): Message[] {
const messages: Message[] = [];
// Add system prompt if provided and not already in messages
if (request.systemPrompt && !request.messages.some((m) => m.role === "system")) {
messages.push({
role: "system",
content: request.systemPrompt,
});
}
// Add all request messages
for (const message of request.messages) {
messages.push({
role: message.role,
content: message.content,
});
}
return messages;
}
/**
* Build Ollama-specific chat options from request.
*
* @param request - Chat request
* @returns Ollama options object
*/
private buildChatOptions(request: ChatRequestDto): {
temperature?: number;
num_predict?: number;
} {
const options: { temperature?: number; num_predict?: number } = {};
if (request.temperature !== undefined) {
options.temperature = request.temperature;
}
if (request.maxTokens !== undefined) {
options.num_predict = request.maxTokens;
}
return options;
}
}