Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions pathwaysutils/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
`pathwaysutils`'s compatibility window.
"""

import functools

import jax
import functools


class _FakeJaxFunction:
Expand All @@ -46,20 +45,6 @@ def __call__(self, *args, **kwargs):
raise ImportError(self.error_message)


try:
# jax>=0.7.1
from jax.extend import backend # pylint: disable=g-import-not-at-top

ifrt_proxy = backend.ifrt_proxy
del backend
except AttributeError:
# jax<0.7.1
from jax.lib import xla_extension # pylint: disable=g-import-not-at-top

ifrt_proxy = xla_extension.ifrt_proxy
del xla_extension


try:
# jax>=0.8.0
from jaxlib import _pathways # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -112,6 +97,5 @@ def ifrt_reshard_available() -> bool:
del jax


del jax
del _FakeJaxFunction
del functools
6 changes: 3 additions & 3 deletions pathwaysutils/proxy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import jax
from jax.extend import backend
from pathwaysutils import jax as pw_jax
from jax.extend.backend import ifrt_proxy


def register_backend_factory() -> None:
backend.register_backend_factory(
"proxy",
lambda: pw_jax.ifrt_proxy.get_client(
lambda: ifrt_proxy.get_client(
jax.config.read("jax_backend_target"),
pw_jax.ifrt_proxy.ClientConnectionOptions(),
ifrt_proxy.ClientConnectionOptions(),
),
priority=-1,
)
4 changes: 2 additions & 2 deletions pathwaysutils/test/proxy_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from absl.testing import absltest
import jax
from jax.extend import backend
from pathwaysutils import jax as pw_jax
from jax.extend.backend import ifrt_proxy
from pathwaysutils import proxy_backend


Expand Down Expand Up @@ -46,7 +46,7 @@ def test_no_proxy_backend_registration_raises_error(self):
def test_proxy_backend_registration(self):
self.enter_context(
mock.patch.object(
pw_jax.ifrt_proxy,
ifrt_proxy,
"get_client",
return_value=mock.MagicMock(),
)
Expand Down
Loading