diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index da301574..1943cb15 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,13 +1,25 @@ from cupy.linalg import * # noqa: F403 -# cupy.linalg doesn't have __all__. If it is added, replace this with + +# https://github.com/cupy/cupy/issues/9749 +from cupy.linalg import lstsq # noqa: F401 + +# cupy.linalg doesn't have __all__ in cupy<14. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all _n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] -linalg_all = list(_n) +linalg_all = list(_n) + ['lstsq'] del _n +try: + # cupy 14 exports it, cupy 13 does not + from cupy.linalg import annotations # noqa: F401 + linalg_all += ['annotations'] +except ImportError: + pass + + from ..common import _linalg from .._internal import get_xp @@ -43,5 +55,8 @@ __all__ = linalg_all + _linalg.__all__ +# cupy 13 does not have __all__, cupy 14 has it: remove duplicates +__all__ = sorted(list(set(__all__))) + def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index c36aef67..d9350ce7 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -140,6 +140,7 @@ # Manipulation Functions "broadcast_arrays", "broadcast_to", + "broadcast_shapes", "concat", "expand_dims", "flip", @@ -164,6 +165,7 @@ "unique_counts", "unique_inverse", "unique_values", + "isin", # Sorting Functions "argsort", "sort", @@ -205,6 +207,8 @@ "diagonal", "eigh", "eigvalsh", + "eig", + "eigvals", "inv", "matmul", "matrix_norm", @@ -227,12 +231,14 @@ XFAILS = { ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], - ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", ""): ["from_dlpack", "take_along_axis", "broadcast_shapes"], ("dask.array", "linalg"): [ "cross", "det", "eigh", "eigvalsh", + "eig", + "eigvals", "matrix_power", "pinv", "slogdet",