import { Test, TestingModule } from "@nestjs/testing";
import { MatrixStreamingService } from "./matrix-streaming.service";
import { MatrixService } from "./matrix.service";
import { vi, describe, it, expect, beforeEach, afterEach } from "vitest";
import type { StreamResponseOptions } from "./matrix-streaming.service";

// Mock matrix-bot-sdk to prevent native module loading
vi.mock("matrix-bot-sdk", () => {
  return {
    MatrixClient: class MockMatrixClient {},
    SimpleFsStorageProvider: class MockStorageProvider {
      constructor(_filename: string) {
        // No-op for testing
      }
    },
    AutojoinRoomsMixin: {
      setupOnClient: vi.fn(),
    },
  };
});

// Mock MatrixClient: only the three client methods these tests assert on.
const mockClient = {
  sendMessage: vi.fn().mockResolvedValue("$initial-event-id"),
  sendEvent: vi.fn().mockResolvedValue("$edit-event-id"),
  setTyping: vi.fn().mockResolvedValue(undefined),
};

// Mock MatrixService: reports "connected" and hands out the mock client by default.
const mockMatrixService = {
  isConnected: vi.fn().mockReturnValue(true),
  getClient: vi.fn().mockReturnValue(mockClient),
};

/**
 * Helper: create an async iterable from an array of strings with optional delays.
 *
 * @param tokens - Strings to yield, in order.
 * @param delayMs - Milliseconds to wait *before* each token; `0` (default) yields
 *   without scheduling any timer.
 * @returns An async generator that yields every entry of `tokens` and then completes.
 */
async function* createTokenStream(
  tokens: string[],
  delayMs = 0
): AsyncGenerator<string, void, undefined> {
  for (const token of tokens) {
    if (delayMs > 0) {
      await new Promise((resolve) => setTimeout(resolve, delayMs));
    }
    yield token;
  }
}

/**
 * Helper: create a token stream that throws an error mid-stream.
 *
 * @param tokens - Strings to yield, in order.
 * @param errorAfter - Number of tokens to yield successfully before throwing.
 *   If it is greater than or equal to `tokens.length` the stream never throws.
 * @returns An async generator that yields the first `errorAfter` tokens.
 * @throws Error with message "LLM provider connection lost" when the consumer
 *   requests the token at index `errorAfter`.
 */
async function* createErrorStream(
  tokens: string[],
  errorAfter: number
): AsyncGenerator<string, void, undefined> {
  let count = 0;
  for (const token of tokens) {
    if (count >= errorAfter) {
      throw new Error("LLM provider connection lost");
    }
    yield token;
    count++;
  }
}

describe("MatrixStreamingService", () => {
  let service: MatrixStreamingService;

  beforeEach(async () => {
    // Fake timers are the default; the streamResponse tests below opt back into
    // real timers individually.
    vi.useFakeTimers({ shouldAdvanceTime: true });

    const module: TestingModule = await Test.createTestingModule({
      providers: [
        MatrixStreamingService,
        {
          provide: MatrixService,
          useValue: mockMatrixService,
        },
      ],
    }).compile();

    service = module.get<MatrixStreamingService>(MatrixStreamingService);

    // Clear all mocks
    vi.clearAllMocks();

    // Re-apply default mock returns after clearing, so overrides made by one
    // test (e.g. isConnected -> false) never leak into the next.
    mockMatrixService.isConnected.mockReturnValue(true);
    mockMatrixService.getClient.mockReturnValue(mockClient);
    mockClient.sendMessage.mockResolvedValue("$initial-event-id");
    mockClient.sendEvent.mockResolvedValue("$edit-event-id");
    mockClient.setTyping.mockResolvedValue(undefined);
  });

  afterEach(() => {
    vi.useRealTimers();
  });

  describe("editMessage", () => {
    it("should send a m.replace event to edit an existing message", async () => {
      await service.editMessage("!room:example.com", "$original-event-id", "Updated content");

      expect(mockClient.sendEvent).toHaveBeenCalledWith("!room:example.com", "m.room.message", {
        "m.new_content": {
          msgtype: "m.text",
          body: "Updated content",
        },
        "m.relates_to": {
          rel_type: "m.replace",
          event_id: "$original-event-id",
        },
        // Fallback for clients that don't support edits
        msgtype: "m.text",
        body: "* Updated content",
      });
    });

    it("should throw error when client is not connected", async () => {
      mockMatrixService.isConnected.mockReturnValue(false);

      await expect(
        service.editMessage("!room:example.com", "$event-id", "content")
      ).rejects.toThrow("Matrix client is not connected");
    });

    it("should throw error when client is null", async () => {
      mockMatrixService.getClient.mockReturnValue(null);

      await expect(
        service.editMessage("!room:example.com", "$event-id", "content")
      ).rejects.toThrow("Matrix client is not connected");
    });
  });

  describe("setTypingIndicator", () => {
    it("should call client.setTyping with true and timeout", async () => {
      await service.setTypingIndicator("!room:example.com", true);

      expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", true, 30000);
    });

    it("should call client.setTyping with false to clear indicator", async () => {
      await service.setTypingIndicator("!room:example.com", false);

      expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", false, undefined);
    });

    it("should throw error when client is not connected", async () => {
      mockMatrixService.isConnected.mockReturnValue(false);

      await expect(service.setTypingIndicator("!room:example.com", true)).rejects.toThrow(
        "Matrix client is not connected"
      );
    });
  });

  describe("sendStreamingMessage", () => {
    it("should send an initial message and return the event ID", async () => {
      const eventId = await service.sendStreamingMessage("!room:example.com", "Thinking...");

      expect(eventId).toBe("$initial-event-id");
      expect(mockClient.sendMessage).toHaveBeenCalledWith("!room:example.com", {
        msgtype: "m.text",
        body: "Thinking...",
      });
    });

    it("should send a thread message when threadId is provided", async () => {
      const eventId = await service.sendStreamingMessage(
        "!room:example.com",
        "Thinking...",
        "$thread-root-id"
      );

      expect(eventId).toBe("$initial-event-id");
      expect(mockClient.sendMessage).toHaveBeenCalledWith("!room:example.com", {
        msgtype: "m.text",
        body: "Thinking...",
        "m.relates_to": {
          rel_type: "m.thread",
          event_id: "$thread-root-id",
          is_falling_back: true,
          "m.in_reply_to": {
            event_id: "$thread-root-id",
          },
        },
      });
    });

    it("should throw error when client is not connected", async () => {
      mockMatrixService.isConnected.mockReturnValue(false);

      await expect(service.sendStreamingMessage("!room:example.com", "Test")).rejects.toThrow(
        "Matrix client is not connected"
      );
    });
  });

  describe("streamResponse", () => {
    // NOTE(review): most tests here switch to real timers before streaming,
    // presumably because the service's rate limiting depends on real elapsed
    // time — confirm against the service implementation.

    it("should send initial 'Thinking...' message and start typing indicator", async () => {
      vi.useRealTimers();

      const tokens = ["Hello", " world"];
      const stream = createTokenStream(tokens);

      await service.streamResponse("!room:example.com", stream);

      // Should have sent initial message
      expect(mockClient.sendMessage).toHaveBeenCalledWith(
        "!room:example.com",
        expect.objectContaining({
          msgtype: "m.text",
          body: "Thinking...",
        })
      );

      // Should have started typing indicator
      expect(mockClient.setTyping).toHaveBeenCalledWith("!room:example.com", true, 30000);
    });

    it("should use custom initial message when provided", async () => {
      vi.useRealTimers();

      const tokens = ["Hi"];
      const stream = createTokenStream(tokens);
      const options: StreamResponseOptions = { initialMessage: "Processing..." };

      await service.streamResponse("!room:example.com", stream, options);

      expect(mockClient.sendMessage).toHaveBeenCalledWith(
        "!room:example.com",
        expect.objectContaining({
          body: "Processing...",
        })
      );
    });

    it("should edit message with accumulated tokens on completion", async () => {
      vi.useRealTimers();

      const tokens = ["Hello", " ", "world", "!"];
      const stream = createTokenStream(tokens);

      await service.streamResponse("!room:example.com", stream);

      // The final edit should contain the full accumulated text
      const sendEventCalls = mockClient.sendEvent.mock.calls;
      const lastEditCall = sendEventCalls[sendEventCalls.length - 1];

      expect(lastEditCall).toBeDefined();
      // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
      expect(lastEditCall[2]["m.new_content"].body).toBe("Hello world!");
    });

    it("should clear typing indicator on completion", async () => {
      vi.useRealTimers();

      const tokens = ["Done"];
      const stream = createTokenStream(tokens);

      await service.streamResponse("!room:example.com", stream);

      // Last setTyping call should be false
      const typingCalls = mockClient.setTyping.mock.calls;
      const lastTypingCall = typingCalls[typingCalls.length - 1];

      expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
    });

    it("should rate-limit edits to at most one every 500ms", async () => {
      vi.useRealTimers();

      // Send tokens with small delays - all within one 500ms window
      const tokens = ["a", "b", "c", "d", "e"];
      const stream = createTokenStream(tokens, 50); // 50ms between tokens = 250ms total

      await service.streamResponse("!room:example.com", stream);

      // With 250ms total streaming time (5 tokens * 50ms), all tokens arrive
      // within one 500ms window. We expect at most 1 intermediate edit + 1 final edit,
      // or just the final edit. The key point is that there should NOT be 5 separate edits.
      const editCalls = mockClient.sendEvent.mock.calls.filter(
        (call) => call[1] === "m.room.message"
      );

      // Should have fewer edits than tokens (rate limiting in effect)
      expect(editCalls.length).toBeLessThanOrEqual(2);
      // Should have at least the final edit
      expect(editCalls.length).toBeGreaterThanOrEqual(1);
    });

    it("should handle errors gracefully and edit message with error notice", async () => {
      vi.useRealTimers();

      // Yields "Hello" and " " and then throws before "world".
      const stream = createErrorStream(["Hello", " ", "world"], 2);

      await service.streamResponse("!room:example.com", stream);

      // Should edit message with error content
      const sendEventCalls = mockClient.sendEvent.mock.calls;
      const lastEditCall = sendEventCalls[sendEventCalls.length - 1];

      expect(lastEditCall).toBeDefined();
      // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
      const finalBody = lastEditCall[2]["m.new_content"].body as string;
      expect(finalBody).toContain("error");

      // Should clear typing on error
      const typingCalls = mockClient.setTyping.mock.calls;
      const lastTypingCall = typingCalls[typingCalls.length - 1];

      expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
    });

    it("should include token usage in final message when provided", async () => {
      vi.useRealTimers();

      const tokens = ["Hello"];
      const stream = createTokenStream(tokens);
      const options: StreamResponseOptions = {
        showTokenUsage: true,
        tokenUsage: { prompt: 10, completion: 5, total: 15 },
      };

      await service.streamResponse("!room:example.com", stream, options);

      const sendEventCalls = mockClient.sendEvent.mock.calls;
      const lastEditCall = sendEventCalls[sendEventCalls.length - 1];

      expect(lastEditCall).toBeDefined();
      // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
      const finalBody = lastEditCall[2]["m.new_content"].body as string;
      // Only the total (15) is asserted; the exact usage format is left to the service.
      expect(finalBody).toContain("15");
    });

    it("should throw error when client is not connected", async () => {
      mockMatrixService.isConnected.mockReturnValue(false);

      const stream = createTokenStream(["test"]);

      await expect(service.streamResponse("!room:example.com", stream)).rejects.toThrow(
        "Matrix client is not connected"
      );
    });

    it("should handle empty token stream", async () => {
      vi.useRealTimers();

      const stream = createTokenStream([]);

      await service.streamResponse("!room:example.com", stream);

      // Should still send initial message
      expect(mockClient.sendMessage).toHaveBeenCalled();

      // Should edit with empty/no-content message
      const sendEventCalls = mockClient.sendEvent.mock.calls;
      expect(sendEventCalls.length).toBeGreaterThanOrEqual(1);

      // Should clear typing
      const typingCalls = mockClient.setTyping.mock.calls;
      const lastTypingCall = typingCalls[typingCalls.length - 1];

      expect(lastTypingCall).toEqual(["!room:example.com", false, undefined]);
    });

    it("should support thread context in streamResponse", async () => {
      vi.useRealTimers();

      const tokens = ["Reply"];
      const stream = createTokenStream(tokens);
      const options: StreamResponseOptions = { threadId: "$thread-root" };

      await service.streamResponse("!room:example.com", stream, options);

      // Initial message should include thread relation
      expect(mockClient.sendMessage).toHaveBeenCalledWith(
        "!room:example.com",
        expect.objectContaining({
          "m.relates_to": expect.objectContaining({
            rel_type: "m.thread",
            event_id: "$thread-root",
          }),
        })
      );
    });

    it("should perform multiple edits for long-running streams", async () => {
      vi.useRealTimers();

      // Create tokens with 200ms delays - total ~2000ms, should get multiple edit windows.
      // This test really waits ~2s of wall-clock time because real timers are in use.
      const tokens = Array.from({ length: 10 }, (_, i) => `token${String(i)} `);
      const stream = createTokenStream(tokens, 200);

      await service.streamResponse("!room:example.com", stream);

      // With 10 tokens at 200ms each = 2000ms total, at 500ms intervals
      // we expect roughly 3-4 intermediate edits + 1 final = 4-5 total
      const editCalls = mockClient.sendEvent.mock.calls.filter(
        (call) => call[1] === "m.room.message"
      );

      // Should have multiple edits (at least 2) but far fewer than 10
      expect(editCalls.length).toBeGreaterThanOrEqual(2);
      expect(editCalls.length).toBeLessThanOrEqual(8);
    });
  });
});