Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,44 @@ client = Auth0(
)
```

### Custom Domains

If your Auth0 tenant uses multiple custom domains, you can specify which custom domain to use via the `Auth0-Custom-Domain` header. The SDK enforces a whitelist, the header is only sent on supported endpoints.

**Global (all whitelisted requests):**

```python
from auth0.management import ManagementClient

client = ManagementClient(
domain="your-tenant.auth0.com",
token="YOUR_TOKEN",
custom_domain="login.mycompany.com",
)
```

**Per-request override:**

```python
from auth0.management import ManagementClient, CustomDomainHeader

client = ManagementClient(
domain="your-tenant.auth0.com",
token="YOUR_TOKEN",
custom_domain="login.mycompany.com",
)

# Override the global custom domain for this specific request
client.users.create(
connection="Username-Password-Authentication",
email="user@example.com",
password="SecurePass123!",
request_options=CustomDomainHeader("other.mycompany.com"),
)
```

If both a global `custom_domain` and a per-request `CustomDomainHeader` are provided, the per-request value takes precedence.

## Feedback

### Contributing
Expand Down
4 changes: 3 additions & 1 deletion src/auth0/management/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@
from .client import AsyncAuth0, Auth0
from .environment import Auth0Environment
from .event_streams import EventStreamsCreateRequest
from .management_client import AsyncManagementClient, ManagementClient
from .management_client import AsyncManagementClient, CustomDomainHeader, ManagementClient
from .version import __version__
_dynamic_imports: typing.Dict[str, str] = {
"Action": ".types",
Expand Down Expand Up @@ -1458,6 +1458,7 @@
"CreatedAuthenticationMethodTypeEnum": ".types",
"CreatedUserAuthenticationMethodTypeEnum": ".types",
"CredentialId": ".types",
"CustomDomainHeader": ".management_client",
"CustomDomain": ".types",
"CustomDomainCustomClientIpHeader": ".types",
"CustomDomainCustomClientIpHeaderEnum": ".types",
Expand Down Expand Up @@ -2690,6 +2691,7 @@ def __dir__():
"CreatedUserAuthenticationMethodTypeEnum",
"CredentialId",
"CustomDomain",
"CustomDomainHeader",
"CustomDomainCustomClientIpHeader",
"CustomDomainCustomClientIpHeaderEnum",
"CustomDomainProvisioningTypeEnum",
Expand Down
79 changes: 78 additions & 1 deletion src/auth0/management/management_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,60 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import re
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import httpx
from .client import AsyncAuth0, Auth0
from .token_provider import TokenProvider

CUSTOM_DOMAIN_HEADER = "Auth0-Custom-Domain"

WHITELISTED_PATH_PATTERNS: List[re.Pattern[str]] = [
re.compile(r"^/api/v2/jobs/verification-email$"),
re.compile(r"^/api/v2/tickets/email-verification$"),
re.compile(r"^/api/v2/tickets/password-change$"),
re.compile(r"^/api/v2/organizations/[^/]+/invitations$"),
re.compile(r"^/api/v2/users$"),
re.compile(r"^/api/v2/users/[^/]+$"),
re.compile(r"^/api/v2/guardian/enrollments/ticket$"),
re.compile(r"^/api/v2/self-service-profiles/[^/]+/sso-ticket$"),
]


def _is_path_whitelisted(path: str) -> bool:
"""Check if the given path is whitelisted for the custom domain header."""
return any(p.match(path) for p in WHITELISTED_PATH_PATTERNS)


def _enforce_custom_domain_whitelist(request: httpx.Request) -> None:
"""httpx event hook that strips Auth0-Custom-Domain on non-whitelisted paths."""
if CUSTOM_DOMAIN_HEADER in request.headers and not _is_path_whitelisted(
request.url.path
):
del request.headers[CUSTOM_DOMAIN_HEADER]


def CustomDomainHeader(domain: str) -> Dict[str, Any]:
"""Create request options that set the Auth0-Custom-Domain header for a single request.

When both a global custom_domain (set at client init) and a per-request
custom_domain_header are provided, the per-request value takes precedence.
The header is only sent on whitelisted endpoints.

Usage::

from auth0.management import ManagementClient, CustomDomainHeader

client = ManagementClient(domain="tenant.auth0.com", token="TOKEN")
client.users.create(
connection="Username-Password-Authentication",
email="user@example.com",
password="...",
request_options=CustomDomainHeader("login.mycompany.com"),
)
"""
return {"additional_headers": {CUSTOM_DOMAIN_HEADER: domain}}

if TYPE_CHECKING:
from .actions.client import ActionsClient, AsyncActionsClient
from .anomaly.client import AnomalyClient, AsyncAnomalyClient
Expand Down Expand Up @@ -86,6 +135,10 @@ class ManagementClient:
The API audience. Defaults to https://{domain}/api/v2/
headers : Optional[Dict[str, str]]
Additional headers to send with requests.
custom_domain : Optional[str]
A custom domain to send via the Auth0-Custom-Domain header.
The header is only sent on whitelisted endpoints. Use
``CustomDomainHeader()`` for per-request overrides.
timeout : Optional[float]
Request timeout in seconds. Defaults to 60.
httpx_client : Optional[httpx.Client]
Expand All @@ -106,6 +159,7 @@ def __init__(
client_secret: Optional[str] = None,
audience: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
custom_domain: Optional[str] = None,
timeout: Optional[float] = None,
httpx_client: Optional[httpx.Client] = None,
):
Expand All @@ -128,6 +182,15 @@ def __init__(
else:
resolved_token = token # type: ignore[assignment]

# Set up custom domain header with whitelist enforcement
if custom_domain is not None:
headers = {**(headers or {}), CUSTOM_DOMAIN_HEADER: custom_domain}
if httpx_client is None:
httpx_client = httpx.Client(timeout=timeout or 60, follow_redirects=True)
httpx_client.event_hooks.setdefault("request", []).append(
_enforce_custom_domain_whitelist
)

# Create underlying client
self._api = Auth0(
base_url=f"https://{domain}/api/v2",
Expand Down Expand Up @@ -333,6 +396,10 @@ class AsyncManagementClient:
The API audience. Defaults to https://{domain}/api/v2/
headers : Optional[Dict[str, str]]
Additional headers to send with requests.
custom_domain : Optional[str]
A custom domain to send via the Auth0-Custom-Domain header.
The header is only sent on whitelisted endpoints. Use
``CustomDomainHeader()`` for per-request overrides.
timeout : Optional[float]
Request timeout in seconds. Defaults to 60.
httpx_client : Optional[httpx.AsyncClient]
Expand All @@ -353,6 +420,7 @@ def __init__(
client_secret: Optional[str] = None,
audience: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
custom_domain: Optional[str] = None,
timeout: Optional[float] = None,
httpx_client: Optional[httpx.AsyncClient] = None,
):
Expand All @@ -378,6 +446,15 @@ def __init__(
else:
resolved_token = token # type: ignore[assignment]

# Set up custom domain header with whitelist enforcement
if custom_domain is not None:
headers = {**(headers or {}), CUSTOM_DOMAIN_HEADER: custom_domain}
if httpx_client is None:
httpx_client = httpx.AsyncClient(timeout=timeout or 60, follow_redirects=True)
httpx_client.event_hooks.setdefault("request", []).append(
_enforce_custom_domain_whitelist
)

# Create underlying client
self._api = AsyncAuth0(
base_url=f"https://{domain}/api/v2",
Expand Down
127 changes: 126 additions & 1 deletion tests/management/test_management_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
import time
from unittest.mock import MagicMock, patch

import httpx
import pytest

from auth0.management import AsyncManagementClient, AsyncTokenProvider, ManagementClient, TokenProvider
from auth0.management import (
AsyncManagementClient,
AsyncTokenProvider,
CustomDomainHeader,
ManagementClient,
TokenProvider,
)
from auth0.management.management_client import (
CUSTOM_DOMAIN_HEADER as _CUSTOM_DOMAIN_HEADER,
)
from auth0.management.management_client import (
_enforce_custom_domain_whitelist,
_is_path_whitelisted,
)


class TestManagementClientInit:
Expand Down Expand Up @@ -337,6 +351,115 @@ def test_init_with_custom_audience(self):
assert provider._audience == "https://custom.api.com/"


class TestCustomDomainHeader:
"""Tests for Auth0-Custom-Domain header support."""

def test_init_with_custom_domain(self):
"""Should set Auth0-Custom-Domain header when custom_domain is provided."""
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
custom_domain="login.mycompany.com",
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com"

def test_init_custom_domain_with_existing_headers(self):
"""Should merge custom_domain with other custom headers."""
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
headers={"X-Custom": "value"},
custom_domain="login.mycompany.com",
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
assert custom_headers["X-Custom"] == "value"
assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com"

def test_init_without_custom_domain(self):
"""Should not set Auth0-Custom-Domain header when custom_domain is not provided."""
client = ManagementClient(
domain="test.auth0.com",
token="my-token",
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is None or _CUSTOM_DOMAIN_HEADER not in custom_headers

def test_custom_domain_header_helper(self):
"""Should return correct request options dict."""
result = CustomDomainHeader("login.mycompany.com")
assert result == {
"additional_headers": {
_CUSTOM_DOMAIN_HEADER: "login.mycompany.com",
}
}

def test_async_init_with_custom_domain(self):
"""Should set Auth0-Custom-Domain header on async client."""
client = AsyncManagementClient(
domain="test.auth0.com",
token="my-token",
custom_domain="login.mycompany.com",
)
custom_headers = client._api._client_wrapper.get_custom_headers()
assert custom_headers is not None
assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com"

def test_whitelist_strips_header_on_non_whitelisted_path(self):
"""Should strip Auth0-Custom-Domain header on non-whitelisted paths."""
request = httpx.Request(
"GET",
"https://test.auth0.com/api/v2/clients",
headers={_CUSTOM_DOMAIN_HEADER: "login.mycompany.com"},
)
_enforce_custom_domain_whitelist(request)
assert _CUSTOM_DOMAIN_HEADER not in request.headers

def test_whitelist_keeps_header_on_whitelisted_path(self):
"""Should keep Auth0-Custom-Domain header on whitelisted paths."""
request = httpx.Request(
"POST",
"https://test.auth0.com/api/v2/users",
headers={_CUSTOM_DOMAIN_HEADER: "login.mycompany.com"},
)
_enforce_custom_domain_whitelist(request)
assert request.headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com"

@pytest.mark.parametrize(
"path",
[
"/api/v2/jobs/verification-email",
"/api/v2/tickets/email-verification",
"/api/v2/tickets/password-change",
"/api/v2/organizations/org_abc123/invitations",
"/api/v2/users",
"/api/v2/users/auth0|abc123",
"/api/v2/guardian/enrollments/ticket",
"/api/v2/self-service-profiles/ssp_abc123/sso-ticket",
],
)
def test_whitelisted_paths_match(self, path):
"""Should match all 8 whitelisted path patterns."""
assert _is_path_whitelisted(path) is True

@pytest.mark.parametrize(
"path",
[
"/api/v2/clients",
"/api/v2/connections",
"/api/v2/roles",
"/api/v2/users/auth0|abc123/roles",
"/api/v2/jobs/users-imports",
"/api/v2/tenants/settings",
],
)
def test_non_whitelisted_paths_do_not_match(self, path):
"""Should not match non-whitelisted paths."""
assert _is_path_whitelisted(path) is False


class TestImports:
"""Tests for module imports."""

Expand All @@ -345,6 +468,7 @@ def test_import_from_management(self):
from auth0.management import (
AsyncManagementClient,
AsyncTokenProvider,
CustomDomainHeader,
ManagementClient,
TokenProvider,
)
Expand All @@ -353,3 +477,4 @@ def test_import_from_management(self):
assert AsyncManagementClient is not None
assert TokenProvider is not None
assert AsyncTokenProvider is not None
assert CustomDomainHeader is not None
Loading