diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index c74d9a8..e5bc106 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -17,9 +17,8 @@ `pathwaysutils`'s compatibility window. """ -import functools -import jax +import functools class _FakeJaxFunction: @@ -46,20 +45,6 @@ def __call__(self, *args, **kwargs): raise ImportError(self.error_message) -try: - # jax>=0.7.1 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - ifrt_proxy = backend.ifrt_proxy - del backend -except AttributeError: - # jax<0.7.1 - from jax.lib import xla_extension # pylint: disable=g-import-not-at-top - - ifrt_proxy = xla_extension.ifrt_proxy - del xla_extension - - try: # jax>=0.8.0 from jaxlib import _pathways # pylint: disable=g-import-not-at-top @@ -112,6 +97,5 @@ def ifrt_reshard_available() -> bool: del jax -del jax del _FakeJaxFunction del functools diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index d599f21..cf1f806 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -15,15 +15,15 @@ import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy def register_backend_factory() -> None: backend.register_backend_factory( "proxy", - lambda: pw_jax.ifrt_proxy.get_client( + lambda: ifrt_proxy.get_client( jax.config.read("jax_backend_target"), - pw_jax.ifrt_proxy.ClientConnectionOptions(), + ifrt_proxy.ClientConnectionOptions(), ), priority=-1, ) diff --git a/pathwaysutils/test/proxy_backend_test.py b/pathwaysutils/test/proxy_backend_test.py index 2d6b613..fb8ad8c 100644 --- a/pathwaysutils/test/proxy_backend_test.py +++ b/pathwaysutils/test/proxy_backend_test.py @@ -18,7 +18,7 @@ from absl.testing import absltest import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy from pathwaysutils import proxy_backend @@ -46,7 +46,7 @@ def test_no_proxy_backend_registration_raises_error(self): def test_proxy_backend_registration(self): self.enter_context( mock.patch.object( - pw_jax.ifrt_proxy, + ifrt_proxy, "get_client", return_value=mock.MagicMock(), )