fix(#410): use toNodeHandler for BetterAuth Express compatibility
Some checks failed
ci/woodpecker/push/api Pipeline failed
Some checks failed
ci/woodpecker/push/api Pipeline failed
BetterAuth expects Web API Request objects (Fetch API standard) with headers.get(), but NestJS/Express passes IncomingMessage objects with headers[] property access. Use better-auth/node's toNodeHandler to properly convert between Express req/res and BetterAuth's Web API handler. Also fixes vitest SWC config to read the correct tsconfig for NestJS decorator metadata emission, which was causing DI injection failures in tests. Fixes #410 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
import { describe, it, expect, beforeEach, vi } from "vitest";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import type { AuthUser, AuthSession } from "@mosaic/shared";
|
||||
import type { Request as ExpressRequest, Response as ExpressResponse } from "express";
|
||||
import { AuthController } from "./auth.controller";
|
||||
import { AuthService } from "./auth.service";
|
||||
|
||||
describe("AuthController", () => {
|
||||
let controller: AuthController;
|
||||
let authService: AuthService;
|
||||
|
||||
const mockNodeHandler = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const mockAuthService = {
|
||||
getAuth: vi.fn(),
|
||||
getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -24,25 +27,30 @@ describe("AuthController", () => {
|
||||
}).compile();
|
||||
|
||||
controller = module.get<AuthController>(AuthController);
|
||||
authService = module.get<AuthService>(AuthService);
|
||||
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Restore mock implementations after clearAllMocks
|
||||
mockAuthService.getNodeHandler.mockReturnValue(mockNodeHandler);
|
||||
mockNodeHandler.mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
describe("handleAuth", () => {
|
||||
it("should call BetterAuth handler", async () => {
|
||||
const mockHandler = vi.fn().mockResolvedValue({ status: 200 });
|
||||
mockAuthService.getAuth.mockReturnValue({ handler: mockHandler });
|
||||
|
||||
it("should delegate to BetterAuth node handler with Express req/res", async () => {
|
||||
const mockRequest = {
|
||||
method: "GET",
|
||||
url: "/auth/session",
|
||||
};
|
||||
headers: {},
|
||||
ip: "127.0.0.1",
|
||||
socket: { remoteAddress: "127.0.0.1" },
|
||||
} as unknown as ExpressRequest;
|
||||
|
||||
await controller.handleAuth(mockRequest as unknown as Request);
|
||||
const mockResponse = {} as unknown as ExpressResponse;
|
||||
|
||||
expect(mockAuthService.getAuth).toHaveBeenCalled();
|
||||
expect(mockHandler).toHaveBeenCalledWith(mockRequest);
|
||||
await controller.handleAuth(mockRequest, mockResponse);
|
||||
|
||||
expect(mockAuthService.getNodeHandler).toHaveBeenCalled();
|
||||
expect(mockNodeHandler).toHaveBeenCalledWith(mockRequest, mockResponse);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Controller, All, Req, Get, UseGuards, Request, Logger } from "@nestjs/common";
|
||||
import { Controller, All, Req, Res, Get, UseGuards, Request, Logger } from "@nestjs/common";
|
||||
import { Throttle } from "@nestjs/throttler";
|
||||
import type { Request as ExpressRequest, Response as ExpressResponse } from "express";
|
||||
import type { AuthUser, AuthSession } from "@mosaic/shared";
|
||||
import { AuthService } from "./auth.service";
|
||||
import { AuthGuard } from "./guards/auth.guard";
|
||||
@@ -88,37 +89,29 @@ export class AuthController {
|
||||
*/
|
||||
@All("*")
|
||||
@Throttle({ strict: { limit: 10, ttl: 60000 } })
|
||||
async handleAuth(@Req() req: Request): Promise<unknown> {
|
||||
async handleAuth(@Req() req: ExpressRequest, @Res() res: ExpressResponse): Promise<void> {
|
||||
// Extract client IP for logging
|
||||
const clientIp = this.getClientIp(req);
|
||||
const requestPath = (req as unknown as { url?: string }).url ?? "unknown";
|
||||
const method = (req as unknown as { method?: string }).method ?? "UNKNOWN";
|
||||
|
||||
// Log auth catch-all hits for monitoring and debugging
|
||||
this.logger.debug(`Auth catch-all: ${method} ${requestPath} from ${clientIp}`);
|
||||
this.logger.debug(`Auth catch-all: ${req.method} ${req.url} from ${clientIp}`);
|
||||
|
||||
const auth = this.authService.getAuth();
|
||||
return auth.handler(req);
|
||||
const handler = this.authService.getNodeHandler();
|
||||
return handler(req, res);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract client IP from request, handling proxies
|
||||
*/
|
||||
private getClientIp(req: Request): string {
|
||||
const reqWithHeaders = req as unknown as {
|
||||
headers?: Record<string, string | string[] | undefined>;
|
||||
ip?: string;
|
||||
socket?: { remoteAddress?: string };
|
||||
};
|
||||
|
||||
private getClientIp(req: ExpressRequest): string {
|
||||
// Check X-Forwarded-For header (for reverse proxy setups)
|
||||
const forwardedFor = reqWithHeaders.headers?.["x-forwarded-for"];
|
||||
const forwardedFor = req.headers["x-forwarded-for"];
|
||||
if (forwardedFor) {
|
||||
const ips = Array.isArray(forwardedFor) ? forwardedFor[0] : forwardedFor;
|
||||
return ips?.split(",")[0]?.trim() ?? "unknown";
|
||||
}
|
||||
|
||||
// Fall back to direct IP
|
||||
return reqWithHeaders.ip ?? reqWithHeaders.socket?.remoteAddress ?? "unknown";
|
||||
return req.ip ?? req.socket.remoteAddress ?? "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,10 +23,17 @@ describe("AuthController - Rate Limiting", () => {
|
||||
let app: INestApplication;
|
||||
let loggerSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
const mockNodeHandler = vi.fn(
|
||||
(_req: unknown, res: { statusCode: number; end: (body: string) => void }) => {
|
||||
res.statusCode = 200;
|
||||
res.end(JSON.stringify({}));
|
||||
return Promise.resolve();
|
||||
}
|
||||
);
|
||||
|
||||
const mockAuthService = {
|
||||
getAuth: vi.fn().mockReturnValue({
|
||||
handler: vi.fn().mockResolvedValue({ status: 200, body: {} }),
|
||||
}),
|
||||
getAuth: vi.fn(),
|
||||
getNodeHandler: vi.fn().mockReturnValue(mockNodeHandler),
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -76,7 +83,7 @@ describe("AuthController - Rate Limiting", () => {
|
||||
expect(response.status).not.toBe(HttpStatus.TOO_MANY_REQUESTS);
|
||||
}
|
||||
|
||||
expect(mockAuthService.getAuth).toHaveBeenCalledTimes(3);
|
||||
expect(mockAuthService.getNodeHandler).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it("should return 429 when rate limit is exceeded", async () => {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { Injectable, Logger } from "@nestjs/common";
|
||||
import type { PrismaClient } from "@prisma/client";
|
||||
import type { IncomingMessage, ServerResponse } from "http";
|
||||
import { toNodeHandler } from "better-auth/node";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { createAuth, type Auth } from "./auth.config";
|
||||
|
||||
@@ -7,11 +9,13 @@ import { createAuth, type Auth } from "./auth.config";
|
||||
export class AuthService {
|
||||
private readonly logger = new Logger(AuthService.name);
|
||||
private readonly auth: Auth;
|
||||
private readonly nodeHandler: (req: IncomingMessage, res: ServerResponse) => Promise<void>;
|
||||
|
||||
constructor(private readonly prisma: PrismaService) {
|
||||
// PrismaService extends PrismaClient and is compatible with BetterAuth's adapter
|
||||
// Cast is safe as PrismaService provides all required PrismaClient methods
|
||||
this.auth = createAuth(this.prisma as unknown as PrismaClient);
|
||||
this.nodeHandler = toNodeHandler(this.auth);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -21,6 +25,14 @@ export class AuthService {
|
||||
return this.auth;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Node.js-compatible request handler for BetterAuth.
|
||||
* Wraps BetterAuth's Web API handler to work with Express/Node.js req/res.
|
||||
*/
|
||||
getNodeHandler(): (req: IncomingMessage, res: ServerResponse) => Promise<void> {
|
||||
return this.nodeHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user by ID
|
||||
*/
|
||||
|
||||
@@ -26,7 +26,7 @@ export default defineConfig({
|
||||
},
|
||||
plugins: [
|
||||
swc.vite({
|
||||
module: { type: "es6" },
|
||||
tsconfigFile: path.resolve(__dirname, "tsconfig.json"),
|
||||
}),
|
||||
],
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user