diff --git a/apps/api/src/auth/auth.controller.spec.ts b/apps/api/src/auth/auth.controller.spec.ts index 86d66fa..4ce8095 100644 --- a/apps/api/src/auth/auth.controller.spec.ts +++ b/apps/api/src/auth/auth.controller.spec.ts @@ -1,15 +1,18 @@ import { describe, it, expect, beforeEach, vi } from "vitest"; import { Test, TestingModule } from "@nestjs/testing"; import type { AuthUser, AuthSession } from "@mosaic/shared"; +import type { Request as ExpressRequest, Response as ExpressResponse } from "express"; import { AuthController } from "./auth.controller"; import { AuthService } from "./auth.service"; describe("AuthController", () => { let controller: AuthController; - let authService: AuthService; + + const mockNodeHandler = vi.fn().mockResolvedValue(undefined); const mockAuthService = { getAuth: vi.fn(), + getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler), }; beforeEach(async () => { @@ -24,25 +27,30 @@ describe("AuthController", () => { }).compile(); controller = module.get(AuthController); - authService = module.get(AuthService); vi.clearAllMocks(); + + // Restore mock implementations after clearAllMocks + mockAuthService.getNodeHandler.mockReturnValue(mockNodeHandler); + mockNodeHandler.mockResolvedValue(undefined); }); describe("handleAuth", () => { - it("should call BetterAuth handler", async () => { - const mockHandler = vi.fn().mockResolvedValue({ status: 200 }); - mockAuthService.getAuth.mockReturnValue({ handler: mockHandler }); - + it("should delegate to BetterAuth node handler with Express req/res", async () => { const mockRequest = { method: "GET", url: "/auth/session", - }; + headers: {}, + ip: "127.0.0.1", + socket: { remoteAddress: "127.0.0.1" }, + } as unknown as ExpressRequest; - await controller.handleAuth(mockRequest as unknown as Request); + const mockResponse = {} as unknown as ExpressResponse; - expect(mockAuthService.getAuth).toHaveBeenCalled(); - expect(mockHandler).toHaveBeenCalledWith(mockRequest); + await controller.handleAuth(mockRequest, mockResponse); + + expect(mockAuthService.getNodeHandler).toHaveBeenCalled(); + expect(mockNodeHandler).toHaveBeenCalledWith(mockRequest, mockResponse); }); }); diff --git a/apps/api/src/auth/auth.controller.ts b/apps/api/src/auth/auth.controller.ts index 8b8f8d9..c632bbc 100644 --- a/apps/api/src/auth/auth.controller.ts +++ b/apps/api/src/auth/auth.controller.ts @@ -1,5 +1,6 @@ -import { Controller, All, Req, Get, UseGuards, Request, Logger } from "@nestjs/common"; +import { Controller, All, Req, Res, Get, UseGuards, Request, Logger } from "@nestjs/common"; import { Throttle } from "@nestjs/throttler"; +import type { Request as ExpressRequest, Response as ExpressResponse } from "express"; import type { AuthUser, AuthSession } from "@mosaic/shared"; import { AuthService } from "./auth.service"; import { AuthGuard } from "./guards/auth.guard"; @@ -88,37 +89,29 @@ export class AuthController { */ @All("*") @Throttle({ strict: { limit: 10, ttl: 60000 } }) - async handleAuth(@Req() req: Request): Promise { + async handleAuth(@Req() req: ExpressRequest, @Res() res: ExpressResponse): Promise { // Extract client IP for logging const clientIp = this.getClientIp(req); - const requestPath = (req as unknown as { url?: string }).url ?? "unknown"; - const method = (req as unknown as { method?: string }).method ?? "UNKNOWN"; // Log auth catch-all hits for monitoring and debugging - this.logger.debug(`Auth catch-all: ${method} ${requestPath} from ${clientIp}`); + this.logger.debug(`Auth catch-all: ${req.method} ${req.url} from ${clientIp}`); - const auth = this.authService.getAuth(); - return auth.handler(req); + const handler = this.authService.getNodeHandler(); + return handler(req, res); } /** * Extract client IP from request, handling proxies */ - private getClientIp(req: Request): string { - const reqWithHeaders = req as unknown as { - headers?: Record; - ip?: string; - socket?: { remoteAddress?: string }; - }; - + private getClientIp(req: ExpressRequest): string { // Check X-Forwarded-For header (for reverse proxy setups) - const forwardedFor = reqWithHeaders.headers?.["x-forwarded-for"]; + const forwardedFor = req.headers["x-forwarded-for"]; if (forwardedFor) { const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor; return ips?.split(",")[0]?.trim() ?? "unknown"; } // Fall back to direct IP - return reqWithHeaders.ip ?? reqWithHeaders.socket?.remoteAddress ?? "unknown"; + return req.ip ?? req.socket.remoteAddress ?? "unknown"; } } diff --git a/apps/api/src/auth/auth.rate-limit.spec.ts b/apps/api/src/auth/auth.rate-limit.spec.ts index 89da36f..07bafb1 100644 --- a/apps/api/src/auth/auth.rate-limit.spec.ts +++ b/apps/api/src/auth/auth.rate-limit.spec.ts @@ -23,10 +23,17 @@ describe("AuthController - Rate Limiting", () => { let app: INestApplication; let loggerSpy: ReturnType; + const mockNodeHandler = vi.fn( + (_req: unknown, res: { statusCode: number; end: (body: string) => void }) => { + res.statusCode = 200; + res.end(JSON.stringify({})); + return Promise.resolve(); + } + ); + const mockAuthService = { - getAuth: vi.fn().mockReturnValue({ - handler: vi.fn().mockResolvedValue({ status: 200, body: {} }), - }), + getAuth: vi.fn(), + getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler), }; beforeEach(async () => { @@ -76,7 +83,7 @@ describe("AuthController - Rate Limiting", () => { expect(response.status).not.toBe(HttpStatus.TOO_MANY_REQUESTS); } - expect(mockAuthService.getAuth).toHaveBeenCalledTimes(3); + expect(mockAuthService.getNodeHandler).toHaveBeenCalledTimes(3); }); it("should return 429 when rate limit is exceeded", async () => { diff --git a/apps/api/src/auth/auth.service.ts b/apps/api/src/auth/auth.service.ts index c960766..d97553f 100644 --- a/apps/api/src/auth/auth.service.ts +++ b/apps/api/src/auth/auth.service.ts @@ -1,5 +1,7 @@ import { Injectable, Logger } from "@nestjs/common"; import type { PrismaClient } from "@prisma/client"; +import type { IncomingMessage, ServerResponse } from "http"; +import { toNodeHandler } from "better-auth/node"; import { PrismaService } from "../prisma/prisma.service"; import { createAuth, type Auth } from "./auth.config"; @@ -7,11 +9,13 @@ import { createAuth, type Auth } from "./auth.config"; export class AuthService { private readonly logger = new Logger(AuthService.name); private readonly auth: Auth; + private readonly nodeHandler: (req: IncomingMessage, res: ServerResponse) => Promise; constructor(private readonly prisma: PrismaService) { // PrismaService extends PrismaClient and is compatible with BetterAuth's adapter // Cast is safe as PrismaService provides all required PrismaClient methods this.auth = createAuth(this.prisma as unknown as PrismaClient); + this.nodeHandler = toNodeHandler(this.auth); } /** @@ -21,6 +25,14 @@ export class AuthService { return this.auth; } + /** + * Get Node.js-compatible request handler for BetterAuth. + * Wraps BetterAuth's Web API handler to work with Express/Node.js req/res. + */ + getNodeHandler(): (req: IncomingMessage, res: ServerResponse) => Promise { + return this.nodeHandler; + } + /** * Get user by ID */ diff --git a/apps/api/vitest.config.ts b/apps/api/vitest.config.ts index 83f3f78..0e687b2 100644 --- a/apps/api/vitest.config.ts +++ b/apps/api/vitest.config.ts @@ -26,7 +26,7 @@ export default defineConfig({ }, plugins: [ swc.vite({ - module: { type: "es6" }, + tsconfigFile: path.resolve(__dirname, "tsconfig.json"), }), ], });