feat(#391): add base TTS provider and factory classes
All checks were successful
ci/woodpecker/push/api Pipeline was successful
All checks were successful
ci/woodpecker/push/api Pipeline was successful
Add the BaseTTSProvider abstract class and TTS provider factory that were part of the tiered TTS architecture but missed from the previous commit. - BaseTTSProvider: abstract base with synthesize(), listVoices(), isHealthy() - tts-provider.factory: creates Kokoro/Chatterbox/Piper providers from config - 30 tests (22 base provider + 8 factory) Refs #391
This commit is contained in:
329
apps/api/src/speech/providers/base-tts.provider.spec.ts
Normal file
329
apps/api/src/speech/providers/base-tts.provider.spec.ts
Normal file
@@ -0,0 +1,329 @@
|
||||
/**
|
||||
* BaseTTSProvider Unit Tests
|
||||
*
|
||||
* Tests the abstract base class for OpenAI-compatible TTS providers.
|
||||
* Uses a concrete test implementation to exercise the base class logic.
|
||||
*
|
||||
* Issue #391
|
||||
*/
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it, vi, type Mock } from "vitest";
import { BaseTTSProvider } from "./base-tts.provider";
import type { SpeechTier, SynthesizeOptions, AudioFormat } from "../interfaces/speech-types";
|
||||
|
||||
// ==========================================
|
||||
// Mock OpenAI SDK
|
||||
// ==========================================
|
||||
|
||||
/** Shared spy standing in for `audio.speech.create()` on every mocked OpenAI client. */
const mockCreate = vi.fn();

vi.mock("openai", () => {
  /**
   * Minimal replacement for the OpenAI client's default export.
   *
   * `mockCreate` is deliberately read inside the constructor, i.e. only when
   * a client is instantiated, never while this factory itself is running.
   */
  class MockOpenAI {
    readonly audio: { speech: { create: typeof mockCreate } };

    constructor() {
      this.audio = { speech: { create: mockCreate } };
    }
  }

  return { default: MockOpenAI };
});
|
||||
|
||||
// ==========================================
|
||||
// Concrete test implementation
|
||||
// ==========================================
|
||||
|
||||
/**
 * Smallest possible concrete subclass of BaseTTSProvider.
 *
 * It only supplies the two abstract members; the
 * `(baseURL, defaultVoice?, defaultFormat?)` constructor is inherited
 * unchanged from the base class.
 */
class TestTTSProvider extends BaseTTSProvider {
  readonly tier: SpeechTier = "default";
  readonly name = "test-provider";
}
|
||||
|
||||
// ==========================================
|
||||
// Test helpers
|
||||
// ==========================================
|
||||
|
||||
/**
 * Build a stand-in for the value resolved by the OpenAI SDK's
 * `audio.speech.create()`: an object exposing only `arrayBuffer()`.
 *
 * The resolved ArrayBuffer contains exactly the bytes visible through
 * `audioData`. `Uint8Array#buffer` is the whole backing buffer, which is
 * larger than the view when `audioData` comes from `subarray()` or was
 * constructed with an offset, so the view's window is copied out instead.
 *
 * @param audioData - Bytes the fake response should yield
 * @returns Object whose `arrayBuffer` mock resolves to a copy of those bytes
 */
function createMockAudioResponse(audioData: Uint8Array): { arrayBuffer: Mock } {
  const { buffer, byteOffset, byteLength } = audioData;
  const exactBytes = buffer.slice(byteOffset, byteOffset + byteLength);

  return {
    arrayBuffer: vi.fn().mockResolvedValue(exactBytes),
  };
}
|
||||
|
||||
describe("BaseTTSProvider", () => {
|
||||
let provider: TestTTSProvider;
|
||||
|
||||
const testBaseURL = "http://localhost:8880/v1";
|
||||
const testVoice = "af_heart";
|
||||
const testFormat: AudioFormat = "mp3";
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
provider = new TestTTSProvider(testBaseURL, testVoice, testFormat);
|
||||
});
|
||||
|
||||
// ==========================================
|
||||
// Constructor
|
||||
// ==========================================
|
||||
|
||||
describe("constructor", () => {
  it("should create an instance with provided configuration", () => {
    expect(provider).toBeInstanceOf(BaseTTSProvider);
    expect(provider.name).toBe("test-provider");
    expect(provider.tier).toBe("default");
  });

  it("should use default voice 'alloy' when none provided", async () => {
    const defaultProvider = new TestTTSProvider(testBaseURL);

    // The configured default voice is observable through listVoices().
    const voices = await defaultProvider.listVoices();

    expect(voices.map((voice) => voice.id)).toContain("alloy");
  });

  it("should use default format 'mp3' when none provided", async () => {
    mockCreate.mockResolvedValue(createMockAudioResponse(new Uint8Array([0x01])));
    const defaultProvider = new TestTTSProvider(testBaseURL, "voice-1");

    // The configured default format is echoed back in the synthesis result.
    const result = await defaultProvider.synthesize("Hello");

    expect(result.format).toBe("mp3");
    expect(result.voice).toBe("voice-1");
  });
});
|
||||
|
||||
// ==========================================
|
||||
// synthesize()
|
||||
// ==========================================
|
||||
|
||||
describe("synthesize", () => {
  /** Make the mocked SDK call resolve with a response carrying `bytes`. */
  const respondWithAudio = (bytes: number[]): Uint8Array => {
    const payload = new Uint8Array(bytes);
    mockCreate.mockResolvedValue(createMockAudioResponse(payload));
    return payload;
  };

  it("should synthesize text and return a SynthesisResult with audio buffer", async () => {
    const payload = respondWithAudio([0x49, 0x44, 0x33, 0x04, 0x00]);

    const synthesized = await provider.synthesize("Hello, world!");

    expect(synthesized).toBeDefined();
    expect(synthesized.audio).toBeInstanceOf(Buffer);
    expect(synthesized.audio.length).toBe(payload.length);
    expect(synthesized.format).toBe("mp3");
    expect(synthesized.voice).toBe("af_heart");
    expect(synthesized.tier).toBe("default");
  });

  it("should pass correct parameters to OpenAI SDK", async () => {
    respondWithAudio([0x01, 0x02]);
    const expectedRequest = {
      model: "tts-1",
      input: "Test text",
      voice: "af_heart",
      response_format: "mp3",
      speed: 1.0,
    };

    await provider.synthesize("Test text");

    expect(mockCreate).toHaveBeenCalledWith(expectedRequest);
  });

  it("should use custom voice from options", async () => {
    respondWithAudio([0x01]);
    const options: SynthesizeOptions = { voice: "custom_voice" };

    const synthesized = await provider.synthesize("Hello", options);

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ voice: "custom_voice" }));
    expect(synthesized.voice).toBe("custom_voice");
  });

  it("should use custom format from options", async () => {
    respondWithAudio([0x01]);
    const options: SynthesizeOptions = { format: "wav" };

    const synthesized = await provider.synthesize("Hello", options);

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ response_format: "wav" }));
    expect(synthesized.format).toBe("wav");
  });

  it("should use custom speed from options", async () => {
    respondWithAudio([0x01]);
    const options: SynthesizeOptions = { speed: 1.5 };

    await provider.synthesize("Hello", options);

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ speed: 1.5 }));
  });

  it("should throw an error when synthesis fails", async () => {
    mockCreate.mockRejectedValue(new Error("Connection refused"));

    const attempt = provider.synthesize("Hello");

    await expect(attempt).rejects.toThrow(
      "TTS synthesis failed for test-provider: Connection refused"
    );
  });

  it("should throw an error when response arrayBuffer fails", async () => {
    // The SDK call itself succeeds; reading the body is what rejects.
    const unreadableResponse = {
      arrayBuffer: vi.fn().mockRejectedValue(new Error("Read error")),
    };
    mockCreate.mockResolvedValue(unreadableResponse);

    const attempt = provider.synthesize("Hello");

    await expect(attempt).rejects.toThrow("TTS synthesis failed for test-provider: Read error");
  });

  it("should handle empty text input gracefully", async () => {
    respondWithAudio([]);

    const synthesized = await provider.synthesize("");

    expect(synthesized.audio).toBeInstanceOf(Buffer);
    expect(synthesized.audio.length).toBe(0);
  });

  it("should handle non-Error exceptions", async () => {
    // A rejected plain string must still surface as a wrapped Error.
    mockCreate.mockRejectedValue("string error");

    const attempt = provider.synthesize("Hello");

    await expect(attempt).rejects.toThrow("TTS synthesis failed for test-provider: string error");
  });
});
|
||||
|
||||
// ==========================================
|
||||
// listVoices()
|
||||
// ==========================================
|
||||
|
||||
describe("listVoices", () => {
  it("should return default voice list with the configured default voice", async () => {
    const available = await provider.listVoices();

    expect(available).toBeInstanceOf(Array);
    expect(available.length).toBeGreaterThan(0);

    // Exactly which entry is flagged as default is what matters here.
    const flagged = available.find((entry) => entry.isDefault === true);
    expect(flagged).toBeDefined();
    expect(flagged?.id).toBe("af_heart");
    expect(flagged?.tier).toBe("default");
  });

  it("should set tier correctly on all returned voices", async () => {
    const available = await provider.listVoices();

    available.forEach((entry) => {
      expect(entry.tier).toBe("default");
    });
  });
});
|
||||
|
||||
// ==========================================
|
||||
// isHealthy()
|
||||
// ==========================================
|
||||
|
||||
describe("isHealthy", () => {
  // Restore the real global fetch after every test, including tests whose
  // assertions throw, so a failure here cannot leak a stub into later tests.
  afterEach(() => {
    vi.unstubAllGlobals();
  });

  it("should return true when the TTS server is reachable", async () => {
    const mockFetch = vi.fn().mockResolvedValue({ ok: true, status: 200 });
    vi.stubGlobal("fetch", mockFetch);

    const healthy = await provider.isHealthy();

    expect(healthy).toBe(true);
    expect(mockFetch).toHaveBeenCalled();
  });

  it("should return false when the TTS server is unreachable", async () => {
    const mockFetch = vi.fn().mockRejectedValue(new Error("ECONNREFUSED"));
    vi.stubGlobal("fetch", mockFetch);

    const healthy = await provider.isHealthy();

    expect(healthy).toBe(false);
  });

  it("should return false when the TTS server returns an error status", async () => {
    const mockFetch = vi.fn().mockResolvedValue({ ok: false, status: 503 });
    vi.stubGlobal("fetch", mockFetch);

    const healthy = await provider.isHealthy();

    expect(healthy).toBe(false);
  });

  it("should use the base URL for the health check", async () => {
    const mockFetch = vi.fn().mockResolvedValue({ ok: true, status: 200 });
    vi.stubGlobal("fetch", mockFetch);

    await provider.isHealthy();

    // The request must target the host and port of the configured base URL.
    expect(mockFetch).toHaveBeenCalledWith(
      expect.stringContaining("localhost:8880"),
      expect.anything()
    );
  });

  it("should set a timeout for the health check", async () => {
    const mockFetch = vi.fn().mockResolvedValue({ ok: true, status: 200 });
    vi.stubGlobal("fetch", mockFetch);

    await provider.isHealthy();

    // The timeout is implemented by handing fetch an AbortSignal.
    expect(mockFetch).toHaveBeenCalledWith(
      expect.any(String),
      expect.objectContaining({ signal: expect.any(AbortSignal) })
    );
  });
});
|
||||
|
||||
// ==========================================
|
||||
// Default values
|
||||
// ==========================================
|
||||
|
||||
describe("default values", () => {
  /** Arrange the mocked SDK call to resolve with a one-byte audio payload. */
  const stubOneByteAudio = (): void => {
    mockCreate.mockResolvedValue(createMockAudioResponse(new Uint8Array([0x01])));
  };

  it("should use 'alloy' as default voice when none specified", async () => {
    stubOneByteAudio();
    const bareProvider = new TestTTSProvider(testBaseURL);

    await bareProvider.synthesize("Hello");

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ voice: "alloy" }));
  });

  it("should use 'mp3' as default format when none specified", async () => {
    stubOneByteAudio();
    const bareProvider = new TestTTSProvider(testBaseURL);

    await bareProvider.synthesize("Hello");

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ response_format: "mp3" }));
  });

  it("should use speed 1.0 as default speed", async () => {
    stubOneByteAudio();

    // Uses the suite-level provider; no speed option is passed.
    await provider.synthesize("Hello");

    expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ speed: 1.0 }));
  });
});
|
||||
});
|
||||
189
apps/api/src/speech/providers/base-tts.provider.ts
Normal file
189
apps/api/src/speech/providers/base-tts.provider.ts
Normal file
@@ -0,0 +1,189 @@
|
||||
/**
|
||||
* Base TTS Provider
|
||||
*
|
||||
* Abstract base class implementing common OpenAI-compatible TTS logic.
|
||||
* All concrete TTS providers (Kokoro, Chatterbox, Piper) extend this class.
|
||||
*
|
||||
* Uses the OpenAI SDK with a configurable baseURL to communicate with
|
||||
* OpenAI-compatible speech synthesis endpoints.
|
||||
*
|
||||
* Issue #391
|
||||
*/
|
||||
|
||||
import { Logger } from "@nestjs/common";
|
||||
import OpenAI from "openai";
|
||||
import type { ITTSProvider } from "../interfaces/tts-provider.interface";
|
||||
import type {
|
||||
SpeechTier,
|
||||
SynthesizeOptions,
|
||||
SynthesisResult,
|
||||
VoiceInfo,
|
||||
AudioFormat,
|
||||
} from "../interfaces/speech-types";
|
||||
|
||||
/** Default TTS model identifier used for OpenAI-compatible APIs */
const DEFAULT_MODEL = "tts-1";

/** Default voice when none is configured */
const DEFAULT_VOICE = "alloy";

/** Default audio format */
const DEFAULT_FORMAT: AudioFormat = "mp3";

/** Default speech speed multiplier */
const DEFAULT_SPEED = 1.0;

/** Health check timeout in milliseconds */
const HEALTH_CHECK_TIMEOUT_MS = 5000;

/**
 * Abstract base class for OpenAI-compatible TTS providers.
 *
 * Provides common logic for:
 * - Synthesizing text to audio via OpenAI SDK's audio.speech.create()
 * - Listing available voices (with a default implementation)
 * - Health checking the TTS endpoint
 *
 * Subclasses must set `name` and `tier` properties and may override
 * `listVoices()` to provide provider-specific voice lists.
 *
 * @example
 * ```typescript
 * class KokoroProvider extends BaseTTSProvider {
 *   readonly name = "kokoro";
 *   readonly tier: SpeechTier = "default";
 *
 *   constructor(baseURL: string) {
 *     super(baseURL, "af_heart", "mp3");
 *   }
 * }
 * ```
 */
export abstract class BaseTTSProvider implements ITTSProvider {
  abstract readonly name: string;
  abstract readonly tier: SpeechTier;

  protected readonly logger: Logger;
  protected readonly client: OpenAI;
  protected readonly baseURL: string;
  protected readonly defaultVoice: string;
  protected readonly defaultFormat: AudioFormat;

  /**
   * Create a new BaseTTSProvider.
   *
   * @param baseURL - The base URL for the OpenAI-compatible TTS endpoint
   * @param defaultVoice - Default voice ID to use when none is specified in options
   * @param defaultFormat - Default audio format to use when none is specified in options
   */
  constructor(
    baseURL: string,
    defaultVoice: string = DEFAULT_VOICE,
    defaultFormat: AudioFormat = DEFAULT_FORMAT
  ) {
    this.baseURL = baseURL;
    this.defaultVoice = defaultVoice;
    this.defaultFormat = defaultFormat;
    // Log under the concrete subclass name rather than "BaseTTSProvider".
    this.logger = new Logger(this.constructor.name);

    this.client = new OpenAI({
      baseURL,
      apiKey: "not-needed", // Self-hosted services don't require an API key
    });
  }

  /**
   * Synthesize text to audio using the OpenAI-compatible TTS endpoint.
   *
   * Calls `client.audio.speech.create()` with the provided text and options,
   * then converts the response body to a Buffer. Voice and format fall back
   * to the values given at construction; speed falls back to 1.0.
   *
   * @param text - Text to convert to speech
   * @param options - Optional synthesis parameters (voice, format, speed)
   * @returns Synthesis result with audio buffer and the voice/format/tier used
   * @throws {Error} If the request or reading the response body fails; the
   *   message is `TTS synthesis failed for <name>: <cause>`
   */
  async synthesize(text: string, options?: SynthesizeOptions): Promise<SynthesisResult> {
    const voice = options?.voice ?? this.defaultVoice;
    const format = options?.format ?? this.defaultFormat;
    const speed = options?.speed ?? DEFAULT_SPEED;

    try {
      const response = await this.client.audio.speech.create({
        model: DEFAULT_MODEL,
        input: text,
        voice,
        response_format: format,
        speed,
      });

      const arrayBuffer = await response.arrayBuffer();
      const audio = Buffer.from(arrayBuffer);

      return {
        audio,
        format,
        voice,
        tier: this.tier,
      };
    } catch (error: unknown) {
      // Rejections are not guaranteed to be Error instances; stringify the rest.
      const message = error instanceof Error ? error.message : String(error);
      this.logger.error(`TTS synthesis failed: ${message}`);
      throw new Error(`TTS synthesis failed for ${this.name}: ${message}`);
    }
  }

  /**
   * List available voices for this provider.
   *
   * Default implementation returns only the configured default voice.
   * Subclasses should override this to provide a full voice list
   * from their specific TTS engine.
   *
   * @returns Array of voice information objects
   */
  listVoices(): Promise<VoiceInfo[]> {
    return Promise.resolve([
      {
        id: this.defaultVoice,
        name: this.defaultVoice,
        tier: this.tier,
        isDefault: true,
      },
    ]);
  }

  /**
   * Check if the TTS server is reachable and healthy.
   *
   * Sends a GET request to the models endpoint under the configured base
   * URL (`<baseURL>/models`) and aborts it after HEALTH_CHECK_TIMEOUT_MS.
   * Never throws: network errors, timeouts and non-2xx responses all yield
   * `false`.
   *
   * @returns true if the server answered with a 2xx status, false otherwise
   */
  async isHealthy(): Promise<boolean> {
    // Endpoints are addressed relative to baseURL, so the model list lives at
    // `<baseURL>/models`. Trailing slashes are trimmed first so that
    // ".../v1" and ".../v1/" both produce ".../v1/models", and a base URL
    // without a "/v1" suffix still probes its models endpoint instead of the
    // bare base URL.
    const healthUrl = `${this.baseURL.replace(/\/+$/, "")}/models`;

    const controller = new AbortController();
    const timeoutId = setTimeout(() => {
      controller.abort();
    }, HEALTH_CHECK_TIMEOUT_MS);

    try {
      const response = await fetch(healthUrl, {
        method: "GET",
        signal: controller.signal,
      });

      return response.ok;
    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      this.logger.warn(`Health check failed for ${this.name}: ${message}`);
      return false;
    } finally {
      // Always release the timer so it cannot fire after the check is over.
      clearTimeout(timeoutId);
    }
  }
}
|
||||
279
apps/api/src/speech/providers/tts-provider.factory.spec.ts
Normal file
279
apps/api/src/speech/providers/tts-provider.factory.spec.ts
Normal file
@@ -0,0 +1,279 @@
|
||||
/**
|
||||
* TTS Provider Factory Unit Tests
|
||||
*
|
||||
* Tests the factory that creates and registers TTS providers based on config.
|
||||
*
|
||||
* Issue #391
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { createTTSProviders } from "./tts-provider.factory";
|
||||
import type { SpeechConfig } from "../speech.config";
|
||||
import type { SpeechTier } from "../interfaces/speech-types";
|
||||
|
||||
// ==========================================
|
||||
// Mock OpenAI SDK
|
||||
// ==========================================
|
||||
|
||||
// Replace the OpenAI client with a stub exposing only `audio.speech.create`;
// every instance gets its own fresh spy.
vi.mock("openai", () => ({
  default: class MockOpenAI {
    audio = { speech: { create: vi.fn() } };
  },
}));
|
||||
|
||||
// ==========================================
|
||||
// Test helpers
|
||||
// ==========================================
|
||||
|
||||
/**
 * Build a SpeechConfig in which STT and every TTS tier are disabled.
 *
 * `overrides` is merged shallowly: a supplied top-level section (`stt`,
 * `tts`, `limits`) replaces the baseline section wholesale.
 *
 * @param overrides - Top-level sections to substitute into the baseline
 * @returns A fresh config object
 */
function createTestConfig(overrides?: Partial<SpeechConfig>): SpeechConfig {
  const baseline: SpeechConfig = {
    stt: {
      enabled: false,
      baseUrl: "http://speaches:8000/v1",
      model: "whisper",
      language: "en",
    },
    tts: {
      default: {
        enabled: false,
        url: "http://kokoro-tts:8880/v1",
        voice: "af_heart",
        format: "mp3",
      },
      premium: {
        enabled: false,
        url: "http://chatterbox-tts:8881/v1",
      },
      fallback: {
        enabled: false,
        url: "http://openedai-speech:8000/v1",
      },
    },
    limits: {
      maxUploadSize: 25_000_000,
      maxDurationSeconds: 600,
      maxTextLength: 4096,
    },
  };

  return { ...baseline, ...overrides };
}
|
||||
|
||||
describe("createTTSProviders", () => {
  type TtsConfig = SpeechConfig["tts"];

  // Tier settings reused by the tests below.
  const kokoroEnabled: TtsConfig["default"] = {
    enabled: true,
    url: "http://kokoro-tts:8880/v1",
    voice: "af_heart",
    format: "mp3",
  };
  const kokoroDisabled: TtsConfig["default"] = { enabled: false, url: "", voice: "", format: "" };
  const chatterboxEnabled: TtsConfig["premium"] = {
    enabled: true,
    url: "http://chatterbox-tts:8881/v1",
  };
  const premiumDisabled: TtsConfig["premium"] = { enabled: false, url: "" };
  const piperEnabled: TtsConfig["fallback"] = {
    enabled: true,
    url: "http://openedai-speech:8000/v1",
  };
  const fallbackDisabled: TtsConfig["fallback"] = { enabled: false, url: "" };

  /** Run the factory against a test config whose `tts` section is `tts`. */
  const buildProviders = (tts: TtsConfig) => createTTSProviders(createTestConfig({ tts }));

  // ==========================================
  // Empty map when nothing enabled
  // ==========================================

  describe("when no TTS tiers are enabled", () => {
    it("should return an empty map", () => {
      const providers = createTTSProviders(createTestConfig());

      expect(providers).toBeInstanceOf(Map);
      expect(providers.size).toBe(0);
    });
  });

  // ==========================================
  // Default tier
  // ==========================================

  describe("when default tier is enabled", () => {
    it("should create a provider for the default tier", () => {
      const providers = buildProviders({
        default: kokoroEnabled,
        premium: premiumDisabled,
        fallback: fallbackDisabled,
      });

      expect(providers.size).toBe(1);
      expect(providers.has("default")).toBe(true);

      const registered = providers.get("default");
      expect(registered).toBeDefined();
      expect(registered?.tier).toBe("default");
      expect(registered?.name).toBe("kokoro");
    });
  });

  // ==========================================
  // Premium tier
  // ==========================================

  describe("when premium tier is enabled", () => {
    it("should create a provider for the premium tier", () => {
      const providers = buildProviders({
        default: kokoroDisabled,
        premium: chatterboxEnabled,
        fallback: fallbackDisabled,
      });

      expect(providers.size).toBe(1);
      expect(providers.has("premium")).toBe(true);

      const registered = providers.get("premium");
      expect(registered).toBeDefined();
      expect(registered?.tier).toBe("premium");
      expect(registered?.name).toBe("chatterbox");
    });
  });

  // ==========================================
  // Fallback tier
  // ==========================================

  describe("when fallback tier is enabled", () => {
    it("should create a provider for the fallback tier", () => {
      const providers = buildProviders({
        default: kokoroDisabled,
        premium: premiumDisabled,
        fallback: piperEnabled,
      });

      expect(providers.size).toBe(1);
      expect(providers.has("fallback")).toBe(true);

      const registered = providers.get("fallback");
      expect(registered).toBeDefined();
      expect(registered?.tier).toBe("fallback");
      expect(registered?.name).toBe("piper");
    });
  });

  // ==========================================
  // Multiple tiers
  // ==========================================

  describe("when multiple tiers are enabled", () => {
    it("should create providers for all enabled tiers", () => {
      const providers = buildProviders({
        default: kokoroEnabled,
        premium: chatterboxEnabled,
        fallback: piperEnabled,
      });

      expect(providers.size).toBe(3);
      expect(providers.has("default")).toBe(true);
      expect(providers.has("premium")).toBe(true);
      expect(providers.has("fallback")).toBe(true);
    });

    it("should create providers only for enabled tiers", () => {
      const providers = buildProviders({
        default: kokoroEnabled,
        premium: premiumDisabled,
        fallback: piperEnabled,
      });

      expect(providers.size).toBe(2);
      expect(providers.has("default")).toBe(true);
      expect(providers.has("premium")).toBe(false);
      expect(providers.has("fallback")).toBe(true);
    });
  });

  // ==========================================
  // Provider properties
  // ==========================================

  describe("provider properties", () => {
    it("should implement ITTSProvider interface methods", () => {
      const registered = buildProviders({
        default: kokoroEnabled,
        premium: premiumDisabled,
        fallback: fallbackDisabled,
      }).get("default");

      expect(registered).toBeDefined();
      expect(typeof registered?.synthesize).toBe("function");
      expect(typeof registered?.listVoices).toBe("function");
      expect(typeof registered?.isHealthy).toBe("function");
    });

    it("should return providers as a Map<SpeechTier, ITTSProvider>", () => {
      const providers = buildProviders({
        default: kokoroEnabled,
        premium: premiumDisabled,
        fallback: fallbackDisabled,
      });

      // Verify the map keys are valid SpeechTier values
      const knownTiers = ["default", "premium", "fallback"];
      for (const [tier] of providers) {
        expect(knownTiers).toContain(tier as SpeechTier);
      }
    });
  });
});
|
||||
112
apps/api/src/speech/providers/tts-provider.factory.ts
Normal file
112
apps/api/src/speech/providers/tts-provider.factory.ts
Normal file
@@ -0,0 +1,112 @@
|
||||
/**
|
||||
* TTS Provider Factory
|
||||
*
|
||||
* Creates and registers TTS providers based on speech configuration.
|
||||
* Reads enabled flags and URLs from config and instantiates the appropriate
|
||||
* provider for each tier.
|
||||
*
|
||||
* Each tier maps to a specific TTS engine:
|
||||
* - default: Kokoro-FastAPI (CPU, always available)
|
||||
* - premium: Chatterbox (GPU, voice cloning)
|
||||
* - fallback: Piper via OpenedAI Speech (ultra-lightweight CPU)
|
||||
*
|
||||
* Issue #391
|
||||
*/
|
||||
|
||||
import { Logger } from "@nestjs/common";
|
||||
import { BaseTTSProvider } from "./base-tts.provider";
|
||||
import type { ITTSProvider } from "../interfaces/tts-provider.interface";
|
||||
import type { SpeechTier, AudioFormat } from "../interfaces/speech-types";
|
||||
import type { SpeechConfig } from "../speech.config";
|
||||
|
||||
// ==========================================
|
||||
// Concrete provider classes
|
||||
// ==========================================
|
||||
|
||||
/**
 * Default-tier TTS provider backed by Kokoro.
 * CPU-based, always available, Apache 2.0 license.
 *
 * Declares no constructor of its own, so the base URL, default voice and
 * default format are passed straight through to BaseTTSProvider.
 */
class KokoroProvider extends BaseTTSProvider {
  readonly tier: SpeechTier = "default";
  readonly name = "kokoro";
}
|
||||
|
||||
/**
 * Premium-tier TTS provider backed by Chatterbox.
 * GPU required, voice cloning capable, MIT license.
 *
 * Only the base URL is configurable; the default voice is fixed to
 * "default" and the default format to "mp3".
 */
class ChatterboxProvider extends BaseTTSProvider {
  readonly tier: SpeechTier = "premium";
  readonly name = "chatterbox";

  constructor(baseURL: string) {
    super(baseURL, "default", "mp3");
  }
}
|
||||
|
||||
/**
 * Fallback-tier TTS provider: Piper served through OpenedAI Speech.
 * Ultra-lightweight CPU, GPL license.
 *
 * Only the base URL is configurable; the default voice is fixed to
 * "alloy" and the default format to "mp3".
 */
class PiperProvider extends BaseTTSProvider {
  readonly tier: SpeechTier = "fallback";
  readonly name = "piper";

  constructor(baseURL: string) {
    super(baseURL, "alloy", "mp3");
  }
}
|
||||
|
||||
// ==========================================
|
||||
// Factory function
|
||||
// ==========================================
|
||||
|
||||
/** Module-level logger shared by every factory invocation. */
const logger = new Logger("TTSProviderFactory");

/**
 * Instantiate one TTS provider per enabled tier in the speech configuration.
 *
 * Tiers are examined in the order default, premium, fallback; a disabled
 * tier is skipped and gets no map entry. Each registration is logged, and a
 * warning is logged when no tier is enabled at all.
 * The returned Map is keyed by SpeechTier for use with the TTS_PROVIDERS
 * injection token.
 *
 * @param config - Speech configuration with TTS tier settings
 * @returns Map of enabled TTS providers keyed by tier (possibly empty)
 */
export function createTTSProviders(config: SpeechConfig): Map<SpeechTier, ITTSProvider> {
  const { default: defaultTier, premium: premiumTier, fallback: fallbackTier } = config.tts;
  const registry = new Map<SpeechTier, ITTSProvider>();

  // Default tier: Kokoro (the only tier with a configurable voice and format)
  if (defaultTier.enabled) {
    registry.set(
      "default",
      new KokoroProvider(defaultTier.url, defaultTier.voice, defaultTier.format as AudioFormat)
    );
    logger.log(`Registered default TTS provider: kokoro at ${defaultTier.url}`);
  }

  // Premium tier: Chatterbox
  if (premiumTier.enabled) {
    registry.set("premium", new ChatterboxProvider(premiumTier.url));
    logger.log(`Registered premium TTS provider: chatterbox at ${premiumTier.url}`);
  }

  // Fallback tier: Piper
  if (fallbackTier.enabled) {
    registry.set("fallback", new PiperProvider(fallbackTier.url));
    logger.log(`Registered fallback TTS provider: piper at ${fallbackTier.url}`);
  }

  if (registry.size === 0) {
    logger.warn("No TTS providers are enabled. TTS synthesis will not be available.");
    return registry;
  }

  const tierNames = Array.from(registry.keys()).join(", ");
  logger.log(`TTS providers ready: ${tierNames} (${String(registry.size)} total)`);
  return registry;
}
|
||||
Reference in New Issue
Block a user