Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions pathwaysutils/elastic/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,14 @@ def get_active_slice_indices(
A set of integers representing the indices of the active slices.
"""
if slice_to_devices is None:
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
slice_to_devices = get_slice_to_devices(tuple(jax.devices()))

_logger.debug(
"Getting active slice indices for slices: %s",
sorted(list(slice_to_devices.keys())),
)

active_slice_indices = set()

results = {
Expand All @@ -116,17 +122,19 @@ def get_active_slice_indices(
}

for slice_index, x in results.items():
_logger.info("Checking slice_index=%s", slice_index)
_logger.debug("Checking slice_index=%s", slice_index)
expected = (
np.zeros(len(slice_to_devices[slice_index]), dtype=float)
+ _SIMPLE_EXECUTION_TEST_VALUE
)
try:
with timing.Timer(f"Checking {slice_index=}"):
_logger.debug("Blocking until ready for slice_index=%s", slice_index)
jax.block_until_ready(x)
_logger.debug("Execution finished for slice_index=%s", slice_index)
if np.allclose(x, expected):
active_slice_indices.add(slice_index)
_logger.info("slice_index=%s active", slice_index)
_logger.debug("slice_index=%s active", slice_index)
else:
_logger.error(
"Error with _simple_execution for slice_index=%s. "
Expand All @@ -139,11 +147,15 @@ def get_active_slice_indices(
f"Error with _simple_execution for slice_index={slice_index}."
)
except jax.errors.JaxRuntimeError as error:
_logger.debug(
"Caught JaxRuntimeError for slice_index=%s: %s", slice_index, error
)
if not is_error_due_to_slice_down(error):
_logger.info("Re-raising error for slice_index=%s", slice_index)
raise
_logger.info("slice_index=%s bad", slice_index)
_logger.debug("slice_index=%s bad", slice_index)

_logger.info("active_slice_indices=%s", active_slice_indices)
_logger.debug("active_slice_indices=%s", active_slice_indices)

return active_slice_indices

Expand Down Expand Up @@ -174,22 +186,36 @@ def wait_for_slices(
active.
"""
if slice_to_devices is None:
_logger.debug("slice_to_devices is None. Getting from jax.devices().")
slice_to_devices = get_slice_to_devices(jax.devices())

_logger.info(
"Waiting for %s slices. Poll interval: %s, Timeout: %s",
slice_count,
poll_interval,
timeout,
)
start_time = time.time()

while True:
check_start_time = time.time()

_logger.debug("Checking active slices...")
active_slice_indices = get_active_slice_indices(slice_to_devices)
if len(active_slice_indices) >= slice_count:
_logger.info("%s slices active.", len(active_slice_indices))
_logger.info(
"Sufficient slices active: %s >= %s. Active indices: %s",
len(active_slice_indices),
slice_count,
active_slice_indices,
)
return active_slice_indices

_logger.info(
"%s slices active. Wanting at least %s.",
"%s slices active. Wanting at least %s. Active indices: %s",
len(active_slice_indices),
slice_count,
active_slice_indices,
)

time_to_sleep = max(0, poll_interval - (time.time() - check_start_time))
Expand All @@ -206,7 +232,7 @@ def wait_for_slices(
)

if time_to_sleep > 0:
_logger.info("Sleeping for %.2f seconds.", time_to_sleep)
_logger.debug("Sleeping for %.2f seconds.", time_to_sleep)

time.sleep(time_to_sleep)

Expand All @@ -228,10 +254,14 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
traceback_logging_level = logging.DEBUG

if isinstance(error, jax.errors.JaxRuntimeError):
_logger.debug("Checking if JaxRuntimeError is due to slice down: %s", error)
if any(
error_type in str(error) for error_type in _ELASTIC_DOWN_ERROR_TYPES
):
_logger.info("Caught an error due to slice down")
_logger.debug(
"Caught an error due to slice down (matched"
" _ELASTIC_DOWN_ERROR_TYPES)"
)

error_due_to_slice_down = True

Expand All @@ -240,15 +270,16 @@ def is_error_due_to_slice_down(error: Exception) -> bool:
for error_type in _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES
):
_logger.warning(
"Caught an error due that may or may not be due to slice down. This"
" error will be treated as due to slice down."
"Caught an error that may or may not be due to slice down (matched"
" _ELASTIC_DOWN_ADDITIONAL_ERROR_TYPES). This error will be treated"
" as due to slice down."
)
traceback_logging_level = logging.WARNING

error_due_to_slice_down = True

if not error_due_to_slice_down:
_logger.info("Caught an error not due to slice down")
_logger.debug("Caught an error not due to slice down")

_logger.log(traceback_logging_level, "Error details:", exc_info=True)

Expand Down
174 changes: 164 additions & 10 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
events. It also provides a utility for waiting for slices to become active.
"""

import _thread
from collections.abc import Callable, Mapping, Sequence
import functools
import logging
import threading
from typing import Any, TypeVar

import jax
Expand All @@ -34,6 +36,17 @@ class ElasticRuntimeError(RuntimeError):
"""Error raised when elasticity cannot continue."""


class ScaleUpError(RuntimeError):
  """Raised by user code to request a restart on newly available hardware.

  When a workload detects that additional hardware has come online and wants
  to restart its computation to take advantage of it, it raises this
  exception. The elasticity manager catches it, interrupts the in-flight
  computation, and retries with an updated slice configuration that includes
  the new hardware.
  """


_F = TypeVar("_F", bound=Callable[..., Any])


Expand All @@ -54,11 +67,21 @@ def _elastic_event_cleanup() -> None:


class Manager:
"""Utility class for elastic training."""
"""Utility class for elastic training.

Attributes:
slice_to_devices: A mapping from slice index to a sequence of `jax.Device`
objects for that slice.
all_slice_indices: A set of all possible slice indices.
active_slice_indices: A set of indices of the currently active slices.
new_slice_event: A `threading.Event` that is set when new slices become
available during replica/resize mode.
"""

_total_slice_count: int | None = None
slice_to_devices: Mapping[int, Sequence[jax.Device]]
all_slice_indices: set[int]
active_slice_indices: set[int]
new_slice_event: threading.Event

def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
"""Initializes the manager.
Expand All @@ -70,20 +93,21 @@ def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
devices = jax.devices()
self.slice_to_devices = elastic.get_slice_to_devices(devices)

self.all_slice_indices = set(self.slice_to_devices.keys())

self.active_slice_indices = elastic.get_active_slice_indices(
slice_to_devices=self.slice_to_devices
)
self.new_slice_event = threading.Event()

@property
@functools.cached_property
def total_slice_count(self) -> int:
"""Returns the total number of slices."""
if self._total_slice_count is None:
self._total_slice_count = len(self.slice_to_devices)
return self._total_slice_count
"""The total number of slices."""
return len(self.slice_to_devices)

@property
def default_device(self) -> jax.Device:
"""Returns the device that should be set to the default device.
"""The device that should be set to the default device.

This will be from one of the slices in `active_slice_indices`.
"""
Expand All @@ -94,9 +118,14 @@ def default_device(self) -> jax.Device:

@property
def active_slice_count(self) -> int:
"""Returns the number of slices."""
"""The number of active slices."""
return len(self.active_slice_indices)

@property
def inactive_slice_indices(self) -> set[int]:
"""The set of inactive slice indices."""
return self.all_slice_indices - self.active_slice_indices

def scale_by_active_slices(self, x: int | float) -> int | float:
"""Scale x by the number of active slices."""
if isinstance(x, int):
Expand All @@ -114,6 +143,20 @@ def scale_by_active_slices(self, x: int | float) -> int | float:
else:
raise ValueError(f"Unsupported type: {type(x)=}")

  def _cleanup_on_retry(self):
    """Cleans up JAX caches and traces on retry.

    Stops any profiler trace that may still be running, clears JAX's caches,
    and deletes all live arrays so the retried attempt starts from a clean
    device state.
    """
    try:
      _logger.debug("Cleaning up any ongoing traces")
      # stop_trace() fails when no trace is active; the except clauses below
      # treat that as a no-op.
      jax.profiler.stop_trace()
    except (RuntimeError, ValueError):
      _logger.debug("No ongoing traces to clean up")
    except Exception:  # pylint: disable=broad-exception-caught
      # Best-effort cleanup: log unexpected failures but do not abort the
      # retry because of them.
      _logger.exception("Error cleaning up ongoing traces")

    # Drop cached compilations and release device memory held by live arrays
    # before re-running, since the device set may have changed.
    jax.clear_caches()
    for array in jax.live_arrays():
      array.delete()

def _elasticity_retry_decorator(
self,
max_retries: int,
Expand Down Expand Up @@ -148,10 +191,23 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:

with jax.default_device(self.default_device):
return func(*args, **kwargs)
except ScaleUpError:
_logger.info("Scale up requested. Retrying.")
_elastic_event_cleanup()

if on_elastic_event_callback is not None:
on_elastic_event_callback()
except jax.errors.JaxRuntimeError as error:
if not elastic.is_error_due_to_slice_down(error):
raise

if self.new_slice_event.is_set():
_logger.info(
"Slice down event and new slice available detected. Retrying."
)
else:
_logger.info("Slice down event detected. Retrying.")

_elastic_event_cleanup()

if on_elastic_event_callback is not None:
Expand All @@ -162,6 +218,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
)

return wrapper

return decorator

def pause_resume(
Expand Down Expand Up @@ -197,7 +254,7 @@ def pause_resume(
occurs.

Returns:
The result of the wrapped function.
A decorator that retries the wrapped function.

Raises:
ElasticRuntimeError: If all retry attempts fail.
Expand All @@ -219,3 +276,100 @@ def internal_pre_callback():
pre_callback=internal_pre_callback,
on_elastic_event_callback=on_elastic_event_callback,
)

  def _monitor_new_slices(
      self, stop_event: threading.Event, poll_interval: float | int
  ):
    """Monitors for new slices and sets the `new_slice_event` if found.

    Intended to run in a background thread. Every `poll_interval` seconds it
    probes the currently inactive slices; as soon as any of them becomes
    active it sets `new_slice_event` and exits. It also exits promptly once
    `stop_event` is set.

    Args:
      stop_event: Event the owner sets to request shutdown of this loop.
      poll_interval: Seconds to wait between probes of the inactive slices.
    """
    # Event.wait doubles as the sleep: it returns True (ending the loop) as
    # soon as stop_event is set, or False after poll_interval seconds.
    while not stop_event.wait(poll_interval):
      try:
        if not self.inactive_slice_indices:
          _logger.debug("No inactive slices to check.")
          continue

        _logger.debug(
            "Checking inactive slices: %s", self.inactive_slice_indices
        )
        # Probe only the inactive slices; active ones need no re-checking.
        inactive_slice_to_devices = {
            i: self.slice_to_devices[i] for i in self.inactive_slice_indices
        }
        newly_active_indices = elastic.get_active_slice_indices(
            inactive_slice_to_devices
        )

        if newly_active_indices:
          _logger.info(
              "New slices found: %s. Setting new slice event.",
              newly_active_indices,
          )
          # Signal once and stop monitoring; the retry machinery reacts to
          # the event and restarts this thread on the next attempt.
          self.new_slice_event.set()
          return

        _logger.debug("No new slices found.")
      except Exception:  # pylint: disable=broad-exception-caught
        # Never let a transient probe failure kill the monitor thread; log
        # and try again on the next poll.
        _logger.exception("Error in monitor thread")

  def replica_resize(
      self,
      max_resizes: int,
      poll_interval: float = 10,
      pre_callback: Callable[..., Any] | None = None,
      on_elastic_event_callback: Callable[..., Any] | None = None,
  ) -> Callable[[_F], _F]:
    """Retries a function with replica/resize fault tolerance.

    Each attempt first waits for at least one slice to be active, then runs
    the wrapped function while a background thread watches the inactive
    slices; when new slices come online the thread sets `new_slice_event`
    so the workload can react (e.g. by raising `ScaleUpError` to restart on
    the larger configuration).

    Args:
      max_resizes: The maximum number of times to retry the function after
        resizing the replica count.
      poll_interval: The number of seconds to wait between active slice checks.
        Defaults to 10 seconds.
      pre_callback: A callback to call before the function is attempted.
      on_elastic_event_callback: A callback to call after an elastic failure
        occurs.

    Returns:
      A decorator that retries the wrapped function.

    Raises:
      ElasticRuntimeError: If all retry attempts fail.
      Exception: Any other exception raised by the wrapped function that is not
        due to a slice down event.
    """

    def internal_pre_callback():
      # Before each attempt, block until at least one slice is usable and
      # record the refreshed set of active slices; only then run the
      # caller-supplied pre_callback.
      self.active_slice_indices = elastic.wait_for_slices(
          slice_count=1,
          slice_to_devices=self.slice_to_devices,
          poll_interval=poll_interval,
      )

      if pre_callback is not None:
        pre_callback()

    retry_decorator = self._elasticity_retry_decorator(
        max_retries=max_resizes,
        pre_callback=internal_pre_callback,
        on_elastic_event_callback=on_elastic_event_callback,
    )

    def decorator(func):
      @functools.wraps(func)
      def wrapper(*args, **kwargs):
        # Fresh event per attempt so a signal from a previous attempt does
        # not leak into this one.
        self.new_slice_event.clear()
        stop_event = threading.Event()

        # Daemon thread: must not keep the process alive if the workload
        # exits without reaching the finally block.
        monitor_thread = threading.Thread(
            target=self._monitor_new_slices,
            args=(stop_event, poll_interval),
            daemon=True,
        )
        monitor_thread.start()
        try:
          return func(*args, **kwargs)
        finally:
          # Always stop and join the monitor, whether the attempt returned,
          # raised ScaleUpError, or failed for any other reason.
          stop_event.set()
          monitor_thread.join()

      # The retry decorator wraps the monitored wrapper, so every retry gets
      # its own monitor thread lifecycle.
      return retry_decorator(wrapper)

    return decorator
Loading