diff --git a/README.md b/README.md index ab32bff1..d9e08abb 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,44 @@ client = Auth0( ) ``` +### Custom Domains + +If your Auth0 tenant uses multiple custom domains, you can specify which custom domain to use via the `Auth0-Custom-Domain` header. The SDK enforces a whitelist, the header is only sent on supported endpoints. + +**Global (all whitelisted requests):** + +```python +from auth0.management import ManagementClient + +client = ManagementClient( + domain="your-tenant.auth0.com", + token="YOUR_TOKEN", + custom_domain="login.mycompany.com", +) +``` + +**Per-request override:** + +```python +from auth0.management import ManagementClient, CustomDomainHeader + +client = ManagementClient( + domain="your-tenant.auth0.com", + token="YOUR_TOKEN", + custom_domain="login.mycompany.com", +) + +# Override the global custom domain for this specific request +client.users.create( + connection="Username-Password-Authentication", + email="user@example.com", + password="SecurePass123!", + request_options=CustomDomainHeader("other.mycompany.com"), +) +``` + +If both a global `custom_domain` and a per-request `CustomDomainHeader` are provided, the per-request value takes precedence. + ## Feedback ### Contributing diff --git a/src/auth0/management/__init__.py b/src/auth0/management/__init__.py index 0faf06f7..c89e28b3 100644 --- a/src/auth0/management/__init__.py +++ b/src/auth0/management/__init__.py @@ -1215,7 +1215,7 @@ from .client import AsyncAuth0, Auth0 from .environment import Auth0Environment from .event_streams import EventStreamsCreateRequest - from .management_client import AsyncManagementClient, ManagementClient + from .management_client import AsyncManagementClient, CustomDomainHeader, ManagementClient from .version import __version__ _dynamic_imports: typing.Dict[str, str] = { "Action": ".types", @@ -1458,6 +1458,7 @@ "CreatedAuthenticationMethodTypeEnum": ".types", "CreatedUserAuthenticationMethodTypeEnum": ".types", "CredentialId": ".types", + "CustomDomainHeader": ".management_client", "CustomDomain": ".types", "CustomDomainCustomClientIpHeader": ".types", "CustomDomainCustomClientIpHeaderEnum": ".types", @@ -2690,6 +2691,7 @@ def __dir__(): "CreatedUserAuthenticationMethodTypeEnum", "CredentialId", "CustomDomain", + "CustomDomainHeader", "CustomDomainCustomClientIpHeader", "CustomDomainCustomClientIpHeaderEnum", "CustomDomainProvisioningTypeEnum", diff --git a/src/auth0/management/management_client.py b/src/auth0/management/management_client.py index 43d4d1d7..90bac0c9 100644 --- a/src/auth0/management/management_client.py +++ b/src/auth0/management/management_client.py @@ -1,11 +1,60 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +import re +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import httpx from .client import AsyncAuth0, Auth0 from .token_provider import TokenProvider +CUSTOM_DOMAIN_HEADER = "Auth0-Custom-Domain" + +WHITELISTED_PATH_PATTERNS: List[re.Pattern[str]] = [ + re.compile(r"^/api/v2/jobs/verification-email$"), + re.compile(r"^/api/v2/tickets/email-verification$"), + re.compile(r"^/api/v2/tickets/password-change$"), + re.compile(r"^/api/v2/organizations/[^/]+/invitations$"), + re.compile(r"^/api/v2/users$"), + re.compile(r"^/api/v2/users/[^/]+$"), + re.compile(r"^/api/v2/guardian/enrollments/ticket$"), + re.compile(r"^/api/v2/self-service-profiles/[^/]+/sso-ticket$"), +] + + +def _is_path_whitelisted(path: str) -> bool: + """Check if the given path is whitelisted for the custom domain header.""" + return any(p.match(path) for p in WHITELISTED_PATH_PATTERNS) + + +def _enforce_custom_domain_whitelist(request: httpx.Request) -> None: + """httpx event hook that strips Auth0-Custom-Domain on non-whitelisted paths.""" + if CUSTOM_DOMAIN_HEADER in request.headers and not _is_path_whitelisted( + request.url.path + ): + del request.headers[CUSTOM_DOMAIN_HEADER] + + +def CustomDomainHeader(domain: str) -> Dict[str, Any]: + """Create request options that set the Auth0-Custom-Domain header for a single request. + + When both a global custom_domain (set at client init) and a per-request + custom_domain_header are provided, the per-request value takes precedence. + The header is only sent on whitelisted endpoints. + + Usage:: + + from auth0.management import ManagementClient, CustomDomainHeader + + client = ManagementClient(domain="tenant.auth0.com", token="TOKEN") + client.users.create( + connection="Username-Password-Authentication", + email="user@example.com", + password="...", + request_options=CustomDomainHeader("login.mycompany.com"), + ) + """ + return {"additional_headers": {CUSTOM_DOMAIN_HEADER: domain}} + if TYPE_CHECKING: from .actions.client import ActionsClient, AsyncActionsClient from .anomaly.client import AnomalyClient, AsyncAnomalyClient @@ -86,6 +135,10 @@ class ManagementClient: The API audience. Defaults to https://{domain}/api/v2/ headers : Optional[Dict[str, str]] Additional headers to send with requests. + custom_domain : Optional[str] + A custom domain to send via the Auth0-Custom-Domain header. + The header is only sent on whitelisted endpoints. Use + ``CustomDomainHeader()`` for per-request overrides. timeout : Optional[float] Request timeout in seconds. Defaults to 60. httpx_client : Optional[httpx.Client] @@ -106,6 +159,7 @@ def __init__( client_secret: Optional[str] = None, audience: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + custom_domain: Optional[str] = None, timeout: Optional[float] = None, httpx_client: Optional[httpx.Client] = None, ): @@ -128,6 +182,15 @@ def __init__( else: resolved_token = token # type: ignore[assignment] + # Set up custom domain header with whitelist enforcement + if custom_domain is not None: + headers = {**(headers or {}), CUSTOM_DOMAIN_HEADER: custom_domain} + if httpx_client is None: + httpx_client = httpx.Client(timeout=timeout or 60, follow_redirects=True) + httpx_client.event_hooks.setdefault("request", []).append( + _enforce_custom_domain_whitelist + ) + # Create underlying client self._api = Auth0( base_url=f"https://{domain}/api/v2", @@ -333,6 +396,10 @@ class AsyncManagementClient: The API audience. Defaults to https://{domain}/api/v2/ headers : Optional[Dict[str, str]] Additional headers to send with requests. + custom_domain : Optional[str] + A custom domain to send via the Auth0-Custom-Domain header. + The header is only sent on whitelisted endpoints. Use + ``CustomDomainHeader()`` for per-request overrides. timeout : Optional[float] Request timeout in seconds. Defaults to 60. httpx_client : Optional[httpx.AsyncClient] @@ -353,6 +420,7 @@ def __init__( client_secret: Optional[str] = None, audience: Optional[str] = None, headers: Optional[Dict[str, str]] = None, + custom_domain: Optional[str] = None, timeout: Optional[float] = None, httpx_client: Optional[httpx.AsyncClient] = None, ): @@ -378,6 +446,15 @@ def __init__( else: resolved_token = token # type: ignore[assignment] + # Set up custom domain header with whitelist enforcement + if custom_domain is not None: + headers = {**(headers or {}), CUSTOM_DOMAIN_HEADER: custom_domain} + if httpx_client is None: + httpx_client = httpx.AsyncClient(timeout=timeout or 60, follow_redirects=True) + httpx_client.event_hooks.setdefault("request", []).append( + _enforce_custom_domain_whitelist + ) + # Create underlying client self._api = AsyncAuth0( base_url=f"https://{domain}/api/v2", diff --git a/tests/management/test_management_client.py b/tests/management/test_management_client.py index c09c9c58..9813e98f 100644 --- a/tests/management/test_management_client.py +++ b/tests/management/test_management_client.py @@ -1,9 +1,23 @@ import time from unittest.mock import MagicMock, patch +import httpx import pytest -from auth0.management import AsyncManagementClient, AsyncTokenProvider, ManagementClient, TokenProvider +from auth0.management import ( + AsyncManagementClient, + AsyncTokenProvider, + CustomDomainHeader, + ManagementClient, + TokenProvider, +) +from auth0.management.management_client import ( + CUSTOM_DOMAIN_HEADER as _CUSTOM_DOMAIN_HEADER, +) +from auth0.management.management_client import ( + _enforce_custom_domain_whitelist, + _is_path_whitelisted, +) class TestManagementClientInit: @@ -337,6 +351,115 @@ def test_init_with_custom_audience(self): assert provider._audience == "https://custom.api.com/" +class TestCustomDomainHeader: + """Tests for Auth0-Custom-Domain header support.""" + + def test_init_with_custom_domain(self): + """Should set Auth0-Custom-Domain header when custom_domain is provided.""" + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + custom_domain="login.mycompany.com", + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com" + + def test_init_custom_domain_with_existing_headers(self): + """Should merge custom_domain with other custom headers.""" + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + headers={"X-Custom": "value"}, + custom_domain="login.mycompany.com", + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + assert custom_headers["X-Custom"] == "value" + assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com" + + def test_init_without_custom_domain(self): + """Should not set Auth0-Custom-Domain header when custom_domain is not provided.""" + client = ManagementClient( + domain="test.auth0.com", + token="my-token", + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is None or _CUSTOM_DOMAIN_HEADER not in custom_headers + + def test_custom_domain_header_helper(self): + """Should return correct request options dict.""" + result = CustomDomainHeader("login.mycompany.com") + assert result == { + "additional_headers": { + _CUSTOM_DOMAIN_HEADER: "login.mycompany.com", + } + } + + def test_async_init_with_custom_domain(self): + """Should set Auth0-Custom-Domain header on async client.""" + client = AsyncManagementClient( + domain="test.auth0.com", + token="my-token", + custom_domain="login.mycompany.com", + ) + custom_headers = client._api._client_wrapper.get_custom_headers() + assert custom_headers is not None + assert custom_headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com" + + def test_whitelist_strips_header_on_non_whitelisted_path(self): + """Should strip Auth0-Custom-Domain header on non-whitelisted paths.""" + request = httpx.Request( + "GET", + "https://test.auth0.com/api/v2/clients", + headers={_CUSTOM_DOMAIN_HEADER: "login.mycompany.com"}, + ) + _enforce_custom_domain_whitelist(request) + assert _CUSTOM_DOMAIN_HEADER not in request.headers + + def test_whitelist_keeps_header_on_whitelisted_path(self): + """Should keep Auth0-Custom-Domain header on whitelisted paths.""" + request = httpx.Request( + "POST", + "https://test.auth0.com/api/v2/users", + headers={_CUSTOM_DOMAIN_HEADER: "login.mycompany.com"}, + ) + _enforce_custom_domain_whitelist(request) + assert request.headers[_CUSTOM_DOMAIN_HEADER] == "login.mycompany.com" + + @pytest.mark.parametrize( + "path", + [ + "/api/v2/jobs/verification-email", + "/api/v2/tickets/email-verification", + "/api/v2/tickets/password-change", + "/api/v2/organizations/org_abc123/invitations", + "/api/v2/users", + "/api/v2/users/auth0|abc123", + "/api/v2/guardian/enrollments/ticket", + "/api/v2/self-service-profiles/ssp_abc123/sso-ticket", + ], + ) + def test_whitelisted_paths_match(self, path): + """Should match all 8 whitelisted path patterns.""" + assert _is_path_whitelisted(path) is True + + @pytest.mark.parametrize( + "path", + [ + "/api/v2/clients", + "/api/v2/connections", + "/api/v2/roles", + "/api/v2/users/auth0|abc123/roles", + "/api/v2/jobs/users-imports", + "/api/v2/tenants/settings", + ], + ) + def test_non_whitelisted_paths_do_not_match(self, path): + """Should not match non-whitelisted paths.""" + assert _is_path_whitelisted(path) is False + + class TestImports: """Tests for module imports.""" @@ -345,6 +468,7 @@ def test_import_from_management(self): from auth0.management import ( AsyncManagementClient, AsyncTokenProvider, + CustomDomainHeader, ManagementClient, TokenProvider, ) @@ -353,3 +477,4 @@ def test_import_from_management(self): assert AsyncManagementClient is not None assert TokenProvider is not None assert AsyncTokenProvider is not None + assert CustomDomainHeader is not None