[Relax][Frontend][Torch] Fix squeeze to return no-op for non-unit dims #18907
Aharrypotter wants to merge 1 commit into apache:main
Conversation
PyTorch's `squeeze(dim)` silently returns the original tensor when the specified dimension is not size 1. The TVM Relax frontend was not replicating this behavior: it forwarded `dim` unconditionally to `relax.op.squeeze`, which could raise an `InternalError` (TVM <= 0.23) or emit a redundant identity squeeze op on the current main branch.

Two bugs fixed in `BaseFXGraphImporter._squeeze`:

1. Single-dim case: no check on the actual dimension size before passing to `relax.op.squeeze`.
2. List/tuple dims case: the existing filter only checked array bounds, not whether each axis is actually size 1.

Fix: for statically-known non-unit dimensions, return the input tensor directly (no op emitted). Dynamic/symbolic dimensions are passed through conservatively as before.

Updated tests:

- `Squeeze2` (no-arg `squeeze()`): expected axis filter corrected to `[1, 3]`
- `Squeeze3` (`squeeze(2)` on a non-unit dim): now expects no squeeze op
- `Squeeze4` (new): exact repro of issue apache#18442: `squeeze(1)` on shape `(32, 10, 5)`

Fixes apache#18442

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
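The PyTorch semantics the fix mirrors can be sketched in plain Python, without torch. `squeeze_shape` is a hypothetical helper written here for illustration; it models only the shape rule, not the real operator:

```python
def squeeze_shape(shape, dim=None):
    """Model of torch.squeeze's shape rule: only size-1 axes are removed;
    squeezing a non-unit dim is a silent no-op; out-of-bounds dims raise."""
    ndim = len(shape)
    if dim is None:
        # No-arg squeeze drops every size-1 axis
        return tuple(s for s in shape if s != 1)
    axis = dim if dim >= 0 else ndim + dim
    if not (0 <= axis < ndim):
        raise IndexError(f"dim {dim} out of range for {ndim}-D shape")
    if shape[axis] != 1:
        return tuple(shape)  # non-unit dim: no-op, matching PyTorch
    return tuple(s for i, s in enumerate(shape) if i != axis)
```

For example, `squeeze_shape((32, 10, 5), 1)` returns the shape unchanged, which is exactly the behavior the frontend previously failed to replicate.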
Activity
Code Review
This pull request fixes an issue in the PyTorch frontend where squeeze(dim) on a non-unit dimension was not handled correctly, leading to errors or redundant operations. The fix aligns the behavior with PyTorch by making such operations a no-op. The changes involve adding checks for dimension sizes before emitting a squeeze operation. The accompanying test updates are thorough and correctly verify the new behavior, including a new test case that reproduces the original issue.
My review focuses on the implementation of the fix. I have one suggestion to refactor the code to reduce duplication and improve maintainability. This suggestion also points out a minor discrepancy with PyTorch's error handling for out-of-bounds dimensions in squeeze, which is currently handled as a silent no-op.
The changed hunk in `BaseFXGraphImporter._squeeze` (post-fix; `shape` and `ndim` are computed earlier in the method):

```python
if isinstance(dim, list | tuple) and len(dim) > 0:
    # PyTorch decomposition may pass all axes. Filter to only include axes
    # where the dimension is statically known to be 1, or is symbolic (dynamic).
    # Axes where the dimension is statically known to be != 1 are silently
    # ignored, matching PyTorch's behavior.
    valid_dims = []
    for d in dim:
        axis = d if d >= 0 else ndim + d
        if 0 <= axis < ndim:
            dim_val = shape[axis]
            # Include axis if it's dynamic (not a static integer) or is size 1
            if not isinstance(dim_val, tir.IntImm) or dim_val.value == 1:
                valid_dims.append(d)
    # If no axes will actually be squeezed, return the original tensor (PyTorch no-op)
    if not valid_dims:
        return x
    dim = valid_dims
elif dim is not None:
    # Single-dim case: mimic PyTorch behavior. If the dimension is statically
    # known to not be 1, squeezing is a no-op, so return the original tensor.
    axis = dim if dim >= 0 else ndim + dim
    if 0 <= axis < ndim:
        dim_val = shape[axis]
        if isinstance(dim_val, tir.IntImm) and dim_val.value != 1:
            return x
```
The logic for checking if a dimension is squeezable is duplicated for the single-dimension and multi-dimension cases. This could be refactored into a helper function to improve maintainability and reduce redundancy.

Additionally, the condition `len(dim) > 0` prevents this logic from handling an empty list or tuple for `dim`. If `dim` is `[]`, this `if` block is skipped and a `relax.op.squeeze(x, [])` op is emitted, which is a redundant no-op. The suggested refactoring handles this case by returning `x` directly.

Finally, PyTorch's `squeeze` operator raises an `IndexError` for out-of-bounds dimensions, whereas this implementation silently ignores them. To fully match PyTorch's behavior, you might consider raising an error for out-of-bounds axes. The suggested refactoring includes a comment on this.
```python
def _is_squeezable(d: int) -> bool:
    axis = d if d >= 0 else ndim + d
    if not (0 <= axis < ndim):
        # To fully match PyTorch, an IndexError should be raised here.
        # For now, returning False maintains the no-op behavior for out-of-bounds dims.
        return False
    dim_val = shape[axis]
    # Squeezable if the dimension is dynamic or its size is 1.
    return not isinstance(dim_val, tir.IntImm) or dim_val.value == 1

if isinstance(dim, list | tuple):
    if not dim:
        # An empty list of dims is a no-op.
        return x
    valid_dims = [d for d in dim if _is_squeezable(d)]
    if not valid_dims:
        return x
    dim = valid_dims
elif dim is not None:
    if not _is_squeezable(dim):
        return x
```
What problem does this PR solve?
Fixes #18442
When using `squeeze(dim)` with a dimension that is not size 1, PyTorch silently returns the original tensor unchanged (a no-op). The TVM Relax PyTorch frontend was not replicating this behavior: it forwarded `dim` unconditionally to `relax.op.squeeze`, which raised an `InternalError` in TVM ≤ 0.23 and still emits a redundant no-op squeeze op on the current main branch.

Reproduction:
```python
import torch, torch.nn as nn
from tvm.relax.frontend.torch import from_exported_program
class SqueezeModel(nn.Module):
def forward(self, x):
return x.squeeze(1) # dim=1 has size 10, not 1
x = torch.randn(32, 10, 5)
ep = torch.export.export(SqueezeModel().eval(), (x,))
mod = from_exported_program(ep) # previously: InternalError
```
Two bugs fixed in `BaseFXGraphImporter._squeeze`:

1. Single-dim case: no check on the actual dimension size before passing to `relax.op.squeeze`.
2. List/tuple dims case: the existing filter only checked array bounds, not whether each axis is actually size 1.

Fix: for statically-known non-unit dimensions, return the input tensor directly. Dynamic/symbolic dimensions pass through conservatively as before.
Tests

- `Squeeze2` (no-arg `squeeze()`): expected axis filter corrected from `[0, 1, 2, 3]` to `[1, 3]`
- `Squeeze3` (`squeeze(2)` on a non-unit dim): now expects no squeeze op emitted
- `Squeeze4` (new): exact repro of [Bug] InternalError: Squeeze dimension check too strict compared to PyTorch behavior #18442: `squeeze(1)` on shape `(32, 10, 5)`
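The updated test expectations can be checked against a small shape-only model of the corrected filter. This is a sketch with static sizes only, not the frontend code, and the `(2, 1, 3, 1)` input shape is a hypothetical stand-in for the `Squeeze2` test tensor:

```python
def filter_squeeze_axes(shape, dims):
    """Keep only in-bounds axes whose size is 1, i.e. the axes the fixed
    importer would actually pass on to relax.op.squeeze."""
    ndim = len(shape)
    kept = []
    for d in dims:
        axis = d if d >= 0 else ndim + d
        if 0 <= axis < ndim and shape[axis] == 1:
            kept.append(d)
    return kept

# Squeeze2-style: no-arg squeeze on a (2, 1, 3, 1) input keeps only axes [1, 3]
# Squeeze3-style: squeeze(2) on (32, 10, 5) keeps nothing, so no op is emitted
```

When the returned list is empty, the fixed importer returns the input tensor directly instead of emitting a squeeze op.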