fix(#411): remediate backend review findings — COOKIE_DOMAIN, TRUSTED_ORIGINS validation, verifySession

- Wire COOKIE_DOMAIN env var into BetterAuth cookie config
- Add URL validation for TRUSTED_ORIGINS (rejects non-HTTP, invalid URLs)
- Include original parse error in validateRedirectUri error message
- Distinguish infrastructure errors from auth errors in verifySession
  (Prisma/connection errors now propagate as 500 instead of masking as 401)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Jason Woltje
2026-02-16 12:31:53 -06:00
parent 3fbba135b9
commit 7ead8b1076
4 changed files with 166 additions and 7 deletions

View File

@@ -35,6 +35,7 @@ describe("auth.config", () => {
delete process.env.NEXT_PUBLIC_APP_URL;
delete process.env.NEXT_PUBLIC_API_URL;
delete process.env.TRUSTED_ORIGINS;
delete process.env.COOKIE_DOMAIN;
});
afterEach(() => {
@@ -185,6 +186,7 @@ describe("auth.config", () => {
expect(() => validateOidcConfig()).toThrow("OIDC_REDIRECT_URI must be a valid URL");
expect(() => validateOidcConfig()).toThrow("not-a-url");
expect(() => validateOidcConfig()).toThrow("Parse error:");
});
it("should throw when OIDC_REDIRECT_URI path does not start with /auth/callback", () => {
@@ -404,6 +406,40 @@ describe("auth.config", () => {
expect(origins).toContain("http://localhost:3001");
expect(origins).toHaveLength(5);
});
it("should reject invalid URLs in TRUSTED_ORIGINS with a warning", () => {
process.env.TRUSTED_ORIGINS = "not-a-url,https://valid.example.com";
process.env.NODE_ENV = "production";
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const origins = getTrustedOrigins();
expect(origins).toContain("https://valid.example.com");
expect(origins).not.toContain("not-a-url");
expect(warnSpy).toHaveBeenCalledWith(
expect.stringContaining('Ignoring invalid URL in TRUSTED_ORIGINS: "not-a-url"')
);
warnSpy.mockRestore();
});
it("should reject non-HTTP origins in TRUSTED_ORIGINS with a warning", () => {
process.env.TRUSTED_ORIGINS = "ftp://files.example.com,https://valid.example.com";
process.env.NODE_ENV = "production";
const warnSpy = vi.spyOn(console, "warn").mockImplementation(() => {});
const origins = getTrustedOrigins();
expect(origins).toContain("https://valid.example.com");
expect(origins).not.toContain("ftp://files.example.com");
expect(warnSpy).toHaveBeenCalledWith(
expect.stringContaining("Ignoring non-HTTP origin in TRUSTED_ORIGINS")
);
warnSpy.mockRestore();
});
});
describe("createAuth - session and cookie configuration", () => {
@@ -504,5 +540,43 @@ describe("auth.config", () => {
};
expect(config.advanced.defaultCookieAttributes.secure).toBe(false);
});
it("should set cookie domain when COOKIE_DOMAIN env var is present", () => {
process.env.COOKIE_DOMAIN = ".mosaicstack.dev";
const mockPrisma = {} as PrismaClient;
createAuth(mockPrisma);
expect(mockBetterAuth).toHaveBeenCalledOnce();
const config = mockBetterAuth.mock.calls[0][0] as {
advanced: {
defaultCookieAttributes: {
httpOnly: boolean;
secure: boolean;
sameSite: string;
domain?: string;
};
};
};
expect(config.advanced.defaultCookieAttributes.domain).toBe(".mosaicstack.dev");
});
it("should not set cookie domain when COOKIE_DOMAIN env var is absent", () => {
delete process.env.COOKIE_DOMAIN;
const mockPrisma = {} as PrismaClient;
createAuth(mockPrisma);
expect(mockBetterAuth).toHaveBeenCalledOnce();
const config = mockBetterAuth.mock.calls[0][0] as {
advanced: {
defaultCookieAttributes: {
httpOnly: boolean;
secure: boolean;
sameSite: string;
domain?: string;
};
};
};
expect(config.advanced.defaultCookieAttributes.domain).toBeUndefined();
});
});
});

View File

@@ -80,9 +80,11 @@ function validateRedirectUri(): void {
let parsed: URL;
try {
parsed = new URL(redirectUri);
} catch {
} catch (urlError: unknown) {
const detail = urlError instanceof Error ? urlError.message : String(urlError);
throw new Error(
`OIDC_REDIRECT_URI must be a valid URL. Current value: "${redirectUri}". ` +
`Parse error: ${detail}. ` +
`Example: "https://app.example.com/auth/callback/authentik".`
);
}
@@ -153,12 +155,23 @@ export function getTrustedOrigins(): string[] {
origins.push(process.env.NEXT_PUBLIC_API_URL);
}
// Comma-separated additional origins
// Comma-separated additional origins (validated)
if (process.env.TRUSTED_ORIGINS) {
const extra = process.env.TRUSTED_ORIGINS.split(",")
const rawOrigins = process.env.TRUSTED_ORIGINS.split(",")
.map((o) => o.trim())
.filter((o) => o !== "");
origins.push(...extra);
for (const origin of rawOrigins) {
try {
const parsed = new URL(origin);
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
console.warn(`[AUTH] Ignoring non-HTTP origin in TRUSTED_ORIGINS: "${origin}"`);
continue;
}
origins.push(origin);
} catch {
console.warn(`[AUTH] Ignoring invalid URL in TRUSTED_ORIGINS: "${origin}"`);
}
}
}
// Localhost fallbacks for development only
@@ -192,6 +205,7 @@ export function createAuth(prisma: PrismaClient) {
httpOnly: true,
secure: process.env.NODE_ENV === "production",
sameSite: "lax" as const,
...(process.env.COOKIE_DOMAIN ? { domain: process.env.COOKIE_DOMAIN } : {}),
},
},
trustedOrigins: getTrustedOrigins(),

View File

@@ -1,5 +1,26 @@
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import { Test, TestingModule } from "@nestjs/testing";
// Mock better-auth modules before importing AuthService
vi.mock("better-auth/node", () => ({
toNodeHandler: vi.fn().mockReturnValue(vi.fn()),
}));
vi.mock("better-auth", () => ({
betterAuth: vi.fn().mockReturnValue({
handler: vi.fn(),
api: { getSession: vi.fn() },
}),
}));
vi.mock("better-auth/adapters/prisma", () => ({
prismaAdapter: vi.fn().mockReturnValue({}),
}));
vi.mock("better-auth/plugins", () => ({
genericOAuth: vi.fn().mockReturnValue({ id: "generic-oauth" }),
}));
import { AuthService } from "./auth.service";
import { PrismaService } from "../prisma/prisma.service";
@@ -294,7 +315,7 @@ describe("AuthService", () => {
expect(result).toBeNull();
});
it("should return null and log error on verification failure", async () => {
it("should return null and log warning on auth verification failure", async () => {
const auth = service.getAuth();
const mockGetSession = vi.fn().mockRejectedValue(new Error("Verification failed"));
auth.api = { getSession: mockGetSession } as any;
@@ -303,5 +324,38 @@ describe("AuthService", () => {
expect(result).toBeNull();
});
it("should re-throw Prisma infrastructure errors", async () => {
const auth = service.getAuth();
const prismaError = new Error("connect ECONNREFUSED 127.0.0.1:5432");
const mockGetSession = vi.fn().mockRejectedValue(prismaError);
auth.api = { getSession: mockGetSession } as any;
await expect(service.verifySession("any-token")).rejects.toThrow("ECONNREFUSED");
});
it("should re-throw timeout errors as infrastructure errors", async () => {
const auth = service.getAuth();
const timeoutError = new Error("Connection timeout after 5000ms");
const mockGetSession = vi.fn().mockRejectedValue(timeoutError);
auth.api = { getSession: mockGetSession } as any;
await expect(service.verifySession("any-token")).rejects.toThrow("timeout");
});
it("should re-throw errors with Prisma-prefixed constructor name", async () => {
const auth = service.getAuth();
class PrismaClientKnownRequestError extends Error {
constructor(message: string) {
super(message);
this.name = "PrismaClientKnownRequestError";
}
}
const prismaError = new PrismaClientKnownRequestError("Database connection lost");
const mockGetSession = vi.fn().mockRejectedValue(prismaError);
auth.api = { getSession: mockGetSession } as any;
await expect(service.verifySession("any-token")).rejects.toThrow("Database connection lost");
});
});
});

View File

@@ -108,9 +108,26 @@ export class AuthService {
session: session.session as Record<string, unknown>,
};
} catch (error) {
this.logger.error(
// Infrastructure errors (database down, connection failures) should propagate
// so the global exception filter returns 500/503, not 401
if (
error instanceof Error &&
(error.constructor.name.startsWith("Prisma") ||
error.message.includes("connect") ||
error.message.includes("ECONNREFUSED") ||
error.message.includes("timeout"))
) {
this.logger.error(
"Session verification failed due to infrastructure error",
error.stack,
);
throw error;
}
// Expected auth errors (invalid/expired token) return null
this.logger.warn(
"Session verification failed",
error instanceof Error ? error.message : "Unknown error"
error instanceof Error ? error.message : "Unknown error",
);
return null;
}