Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from cupy.linalg import * # noqa: F403
# cupy.linalg doesn't have __all__. If it is added, replace this with

# https://github.com/cupy/cupy/issues/9749
from cupy.linalg import lstsq # noqa: F401

# cupy.linalg doesn't have __all__ in cupy<14. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n: dict[str, object] = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
linalg_all = list(_n) + ['lstsq']
del _n

try:
# cupy 14 exports it, cupy 13 does not
from cupy.linalg import annotations # noqa: F401
linalg_all += ['annotations']
except ImportError:
pass


from ..common import _linalg
from .._internal import get_xp

Expand Down Expand Up @@ -43,5 +55,8 @@

__all__ = linalg_all + _linalg.__all__

# cupy 13 does not have __all__, cupy 14 has it: remove duplicates
__all__ = sorted(list(set(__all__)))

def __dir__() -> list[str]:
return __all__
8 changes: 7 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
# Manipulation Functions
"broadcast_arrays",
"broadcast_to",
"broadcast_shapes",
"concat",
"expand_dims",
"flip",
Expand All @@ -164,6 +165,7 @@
"unique_counts",
"unique_inverse",
"unique_values",
"isin",
# Sorting Functions
"argsort",
"sort",
Expand Down Expand Up @@ -205,6 +207,8 @@
"diagonal",
"eigh",
"eigvalsh",
"eig",
"eigvals",
"inv",
"matmul",
"matrix_norm",
Expand All @@ -227,12 +231,14 @@

XFAILS = {
("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [],
("dask.array", ""): ["from_dlpack", "take_along_axis"],
("dask.array", ""): ["from_dlpack", "take_along_axis", "broadcast_shapes"],
("dask.array", "linalg"): [
"cross",
"det",
"eigh",
"eigvalsh",
"eig",
"eigvals",
"matrix_power",
"pinv",
"slogdet",
Expand Down
Loading