Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/apps/detection/transforms/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def convert_box_to_mask(
boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
# apply to global mask
slicing = [b]
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type:ignore
slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims)) # type: ignore
boxes_mask_np[tuple(slicing)] = boxes_only_mask
return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]

Expand Down
2 changes: 1 addition & 1 deletion monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
root: str
child_key: str
(root, _, child_key) = keys
root, _, child_key = keys
if root not in self.ops:
self.ops[root] = [{}]
self.ops[root][0].update({child_key: None})
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ def create_workflow(

"""
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
(workflow_name, config_file) = _pop_args(
workflow_name, config_file = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
) # the default workflow name is "ConfigWorkflow"
if isinstance(workflow_name, str):
Expand Down
4 changes: 2 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class DatasetFunc(Dataset):
"""

def __init__(self, data: Any, func: Callable, **kwargs) -> None:
super().__init__(data=None, transform=None) # type:ignore
super().__init__(data=None, transform=None) # type: ignore
self.src = data
self.func = func
self.kwargs = kwargs
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def _cachecheck(self, item_transformed):
return (_data, _meta)
return _data
else:
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type:ignore
item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))] # type: ignore
for i, _item in enumerate(item_transformed):
for k in _item:
meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
Expand Down
2 changes: 1 addition & 1 deletion monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def stopping_fn_from_loss() -> Callable[[Engine], Any]:
"""

def stopping_fn(engine: Engine) -> Any:
return -engine.state.output # type:ignore
return -engine.state.output # type: ignore

return stopping_fn

Expand Down
2 changes: 1 addition & 1 deletion monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def get_edge_surface_distance(
edges_spacing = None
if use_subvoxels:
edges_spacing = spacing if spacing is not None else ([1] * len(y_pred.shape))
(edges_pred, edges_gt, *areas) = get_mask_edges(
edges_pred, edges_gt, *areas = get_mask_edges(
y_pred, y, crop=True, spacing=edges_spacing, always_return_as_numpy=False
)
if not edges_gt.any():
Expand Down
84 changes: 84 additions & 0 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,90 @@ def __init__(
self.set_random_state(seed=get_seed())
self.overrides = overrides

# Automatically assign group ID to child transforms for inversion tracking
self._set_transform_groups()

def _set_transform_groups(self):
    """
    Automatically set group IDs on child transforms for inversion tracking.

    This allows ``Invertd`` to identify which transforms belong to this
    ``Compose`` instance, including wrapped transforms (for example,
    array transforms inside dictionary transforms).

    Args:
        None.

    Returns:
        None.
    """
    from monai.transforms.inverse import TraceableTransform

    group_id = str(id(self))
    # ids of objects already processed; guards against reference cycles
    visited = set()

    def set_group_recursive(obj, gid, allow_compose: bool = False):
        """
        Recursively set a group ID on a transform and its wrapped transforms.

        Args:
            obj: Transform instance to process.
            gid: Group identifier to assign.
            allow_compose: Whether to set group on ``Compose`` instances.
                ``Compose`` internals are not traversed to preserve nested
                pipeline boundaries.

        Returns:
            None.
        """
        # primitives carry no transforms and are interned/shared, so skip them
        if obj is None or isinstance(obj, (bool, int, float, str, bytes)):
            return

        # avoid infinite recursion on cyclic object graphs
        obj_id = id(obj)
        if obj_id in visited:
            return
        visited.add(obj_id)

        if isinstance(obj, Compose):
            if allow_compose:
                obj._group = gid
            # never descend into a nested Compose: it manages its own group
            return

        if isinstance(obj, TraceableTransform):
            obj._group = gid

        if isinstance(obj, Mapping):
            for attr in obj.values():
                set_group_recursive(attr, gid)
            return

        if isinstance(obj, (list, tuple, set)):
            for attr in obj:
                set_group_recursive(attr, gid)
            return

        # generic object: traverse instance attributes to find wrapped transforms
        attrs: list[Any] = []
        if hasattr(obj, "__dict__"):
            attrs.extend(vars(obj).values())

        # Collect __slots__ from every class in the MRO. A plain
        # getattr(type(obj), "__slots__", ...) lookup returns only the
        # nearest declaration, silently missing slots defined on base
        # classes, so inherited slot attributes would never be traversed.
        for klass in type(obj).__mro__:
            slots = klass.__dict__.get("__slots__", ())
            if isinstance(slots, str):
                slots = (slots,)
            for slot in slots:
                if slot.startswith("__"):
                    continue
                try:
                    attrs.append(getattr(obj, slot))
                except AttributeError:
                    # slot declared but never assigned on this instance
                    continue

        for attr in attrs:
            set_group_recursive(attr, gid)

    for transform in self.transforms:
        set_group_recursive(transform, group_id, allow_compose=True)

@LazyTransform.lazy.setter # type: ignore
def lazy(self, val: bool):
self._lazy = val
Expand Down
15 changes: 14 additions & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def _init_trace_threadlocal(self):
if not hasattr(self._tracing, "value"):
self._tracing.value = MONAIEnvVars.trace_transform() != "0"

# Initialize group identifier (set by Compose for automatic group tracking)
if not hasattr(self, "_group"):
self._group: str | None = None

def __getstate__(self):
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
Expand Down Expand Up @@ -119,13 +123,22 @@ def get_transform_info(self) -> dict:
"""
Return a dictionary with the relevant information pertaining to an applied transform.
"""
# Ensure _group is initialized
self._init_trace_threadlocal()

vals = (
self.__class__.__name__,
id(self),
self.tracing,
self._do_transform if hasattr(self, "_do_transform") else True,
)
return dict(zip(self.transform_info_keys(), vals))
info = dict(zip(self.transform_info_keys(), vals))

# Add group if set (automatically set by Compose)
if self._group is not None:
info[TraceKeys.GROUP] = self._group

return info

def push_transform(self, data, *args, **kwargs):
"""
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
A collection of "vanilla" transforms for IO functions.
"""

from __future__ import annotations

import inspect
Expand Down
30 changes: 27 additions & 3 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from monai.transforms.transform import MapTransform
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
from monai.utils import PostFix, TraceKeys, convert_to_tensor, ensure_tuple, ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand Down Expand Up @@ -859,6 +859,27 @@ def __init__(
self.post_func = ensure_tuple_rep(post_func, len(self.keys))
self._totensor = ToTensor()

def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]:
    """Restrict applied operations to the transforms recorded by the target pipeline.

    Relies on automatic group tracking, in which ``Compose`` stamps its own
    ID onto each child transform it owns.

    Args:
        all_transforms: Full list of applied transform metadata dictionaries.

    Returns:
        The entries whose ``TraceKeys.GROUP`` equals ``str(id(self.transform))``.
        When nothing matches, the input list is returned unchanged so that
        pipelines recorded without group tracking keep working (backward
        compatibility).
    """
    # the owning Compose pipeline stamps str(id(compose)) on its children
    expected_group = str(id(self.transform))

    matching = [entry for entry in all_transforms if entry.get(TraceKeys.GROUP) == expected_group]
    if matching:
        return matching

    # no entry carries this pipeline's ID -> fall back to the full history
    return all_transforms

def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
d = dict(data)
for (
Expand Down Expand Up @@ -894,10 +915,13 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:

orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
if orig_key in d and isinstance(d[orig_key], MetaTensor):
transform_info = d[orig_key].applied_operations
all_transforms = d[orig_key].applied_operations
meta_info = d[orig_key].meta

# Automatically filter by Compose instance group ID
transform_info = self._filter_transforms_by_group(all_transforms)
else:
transform_info = d[InvertibleTransform.trace_key(orig_key)]
transform_info = self._filter_transforms_by_group(d[InvertibleTransform.trace_key(orig_key)])
meta_info = d.get(orig_meta_key, {})
if nearest_interp:
transform_info = convert_applied_interp_mode(
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def __init__(
# if the root log level is higher than INFO, set a separate stream handler to record
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
console.is_data_stats_handler = True # type:ignore[attr-defined]
console.is_data_stats_handler = True # type: ignore[attr-defined]
_logger.addHandler(console)

def __call__(
Expand Down
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ class TraceKeys(StrEnum):
TRACING: str = "tracing"
STATUSES: str = "statuses"
LAZY: str = "lazy"
GROUP: str = "group"


class TraceStatusKeys(StrEnum):
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_loader_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
"""this test should not generate errors or
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""

from __future__ import annotations

import multiprocessing as mp
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
1 change: 1 addition & 0 deletions tests/profile_subclass/pyspy_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""

from __future__ import annotations

import argparse
Expand Down
37 changes: 37 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,43 @@ def test_data_loader_2(self):
self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
set_determinism(None)

def test_set_transform_groups_on_wrapped_transform_attributes(self):
    """Compose should stamp its group ID onto transforms reachable through wrapper attributes."""

    class _NoOpInvertible(mt.InvertibleTransform):
        # identity transform: forward and inverse both return the input unchanged
        def __call__(self, data):
            return data

        def inverse(self, data):
            return data

    class _SingleWrapper:
        # exposes one wrapped transform through a ``transform`` attribute
        def __init__(self):
            self.transform = _NoOpInvertible()

        def __call__(self, data):
            return self.transform(data)

    class _MultiWrapper:
        # exposes wrapped transforms through a ``transforms`` list, one nested in a dict
        def __init__(self):
            self.transforms = [_NoOpInvertible(), {"inner": _NoOpInvertible()}]

        def __call__(self, data):
            for entry in self.transforms:
                if isinstance(entry, dict):
                    for inner in entry.values():
                        data = inner(data)
                else:
                    data = entry(data)
            return data

    single = _SingleWrapper()
    multi = _MultiWrapper()
    pipeline = mt.Compose([single, multi])
    expected_group = str(id(pipeline))

    for wrapped in (single.transform, multi.transforms[0], multi.transforms[1]["inner"]):
        self.assertEqual(getattr(wrapped, "_group", None), expected_group)

def test_flatten_and_len(self):
x = mt.EnsureChannelFirst(channel_dim="no_channel")
t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/croppad/test_pad_nd_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Tests for pad_nd dtype support and backend selection.
Validates PyTorch padding preference and NumPy fallback behavior.
"""

from __future__ import annotations

import unittest
Expand Down
Loading
Loading