diff --git a/cuda_core/cuda/core/typing.py b/cuda_core/cuda/core/typing.py new file mode 100644 index 0000000000..a66ab1881f --- /dev/null +++ b/cuda_core/cuda/core/typing.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Public type aliases and protocols used in cuda.core API signatures.""" + +from cuda.core._memory._buffer import DevicePointerT +from cuda.core._stream import IsStreamT + +__all__ = [ + "DevicePointerT", + "IsStreamT", +] diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index 0aa88d1d64..3ee9aa7c7b 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -16,7 +16,7 @@ CUDA runtime .. autosummary:: :toctree: generated/ - _memory._buffer.DevicePointerT + typing.DevicePointerT _memory._virtual_memory_resource.VirtualMemoryAllocationTypeT _memory._virtual_memory_resource.VirtualMemoryLocationTypeT _memory._virtual_memory_resource.VirtualMemoryGranularityT @@ -41,4 +41,4 @@ CUDA protocols :toctree: generated/ :template: protocol.rst - _stream.IsStreamT + typing.IsStreamT diff --git a/cuda_core/tests/test_typing_imports.py b/cuda_core/tests/test_typing_imports.py new file mode 100644 index 0000000000..c05e3ae3b3 --- /dev/null +++ b/cuda_core/tests/test_typing_imports.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for cuda.core.typing public type aliases and protocols.""" + + +def test_typing_module_imports(): + """All type aliases and protocols are importable from cuda.core.typing.""" + from cuda.core.typing import ( + DevicePointerT, + IsStreamT, + ) + + assert DevicePointerT is not None + assert IsStreamT is not None + + +def test_typing_matches_private_definitions(): + """cuda.core.typing re-exports match the original private definitions.""" + from cuda.core._memory._buffer import DevicePointerT as _DevicePointerT + from cuda.core._stream import IsStreamT as _IsStreamT + from cuda.core.typing import ( + DevicePointerT, + IsStreamT, + ) + + assert DevicePointerT is _DevicePointerT + assert IsStreamT is _IsStreamT