
[Relax][Frontend][Torch] Fix squeeze to return no-op for non-unit dims#18907

Open
Aharrypotter wants to merge 1 commit into apache:main from Aharrypotter:fix/squeeze-non-unit-dim-18442

Conversation

@Aharrypotter

What problem does this PR solve?

Fixes #18442

When squeeze(dim) is called with a dimension that is not size 1, PyTorch silently returns the original tensor unchanged (a no-op). The TVM Relax PyTorch frontend did not replicate this behavior: it forwarded dim unconditionally to relax.op.squeeze, which raised an InternalError in TVM ≤ 0.23 and still emits a redundant no-op squeeze op on the current main branch.

Reproduction:
```python
import torch
import torch.nn as nn
from tvm.relax.frontend.torch import from_exported_program

class SqueezeModel(nn.Module):
    def forward(self, x):
        return x.squeeze(1)  # dim=1 has size 10, not 1

x = torch.randn(32, 10, 5)
ep = torch.export.export(SqueezeModel().eval(), (x,))
mod = from_exported_program(ep)  # previously: InternalError
```

Two bugs fixed in BaseFXGraphImporter._squeeze:

  1. Single-dim case: no check on the actual dimension size before passing to relax.op.squeeze
  2. List/tuple dims case: existing filter only checked array bounds, not whether each axis is actually size 1

Fix: for statically-known non-unit dimensions, return the input tensor directly. Dynamic/symbolic dimensions pass through conservatively as before.
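For reference, the PyTorch semantics the frontend must match can be modeled in a few lines of plain Python (a sketch with no torch/tvm dependency; shapes are tuples of static ints, so the dynamic-dimension case is out of scope here):

```python
# Minimal model of PyTorch's squeeze(dim) shape semantics: a non-unit
# dimension is a silent no-op, not an error.
def squeeze_shape(shape, dim=None):
    if dim is None:
        # No argument: drop every size-1 dimension.
        return tuple(s for s in shape if s != 1)
    axis = dim if dim >= 0 else len(shape) + dim
    if shape[axis] != 1:
        return shape  # non-unit dim: return the shape unchanged
    return shape[:axis] + shape[axis + 1:]

assert squeeze_shape((32, 10, 5), 1) == (32, 10, 5)  # the issue-18442 case
assert squeeze_shape((32, 1, 5), 1) == (32, 5)
assert squeeze_shape((1, 3, 1)) == (3,)
```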



Updated tests:
- Squeeze2 (squeeze() no-arg): expected axis filter corrected to [1,3]
- Squeeze3 (squeeze(2) on non-unit dim): now expects no squeeze op
- Squeeze4 (new): exact repro of issue apache#18442 — squeeze(1) on (32,10,5)

Fixes apache#18442

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical behavioral discrepancy in the TVM Relax PyTorch frontend concerning the torch.squeeze operation. It ensures that when squeeze is applied to dimensions that are not of size 1, the frontend now correctly interprets this as a no-operation, mirroring PyTorch's behavior. This change prevents errors and unnecessary computational graph nodes, leading to more accurate and efficient translation of PyTorch models.

Highlights

  • Corrected torch.squeeze behavior: The TVM Relax PyTorch frontend now correctly handles torch.squeeze operations on dimensions that are not of size 1, treating them as no-operations to match PyTorch's native behavior.
  • Prevented InternalError and redundant ops: This fix resolves an issue where previously such squeeze calls could lead to an InternalError in older TVM versions or emit redundant relax.op.squeeze operations.
  • Enhanced dimension handling: The _squeeze method was improved to filter dimensions based on whether they are statically known to be size 1 or are dynamic, ensuring only relevant dimensions are squeezed.
  • New and updated test cases: The Squeeze2 test was updated for accurate axis filtering, and a new Squeeze4 test was added to specifically reproduce and verify the fix for the reported issue.


Changelog
  • python/tvm/relax/frontend/torch/base_fx_graph_translator.py
    • Implemented logic within _squeeze to return the original tensor directly if a single specified dimension is statically known to not be 1.
    • Enhanced the handling of list/tuple dimensions to filter out axes that are statically known to not be 1, only passing dynamic or size-1 dimensions to relax.op.squeeze.
  • tests/python/relax/test_frontend_from_exported_program.py
    • Updated the Squeeze2 test case to correctly reflect the expected axis filtering for squeeze with multiple dimensions.
    • Modified the Expected3 test case to show no squeeze op emitted when the dimension is not size 1.
    • Added a new Squeeze4 test case and its Expected4 IRModule to specifically validate the no-op behavior for squeeze on a non-unit dimension, directly addressing issue [Bug] InternalError: Squeeze dimension check too strict compared to PyTorch behavior #18442.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes an issue in the PyTorch frontend where squeeze(dim) on a non-unit dimension was not handled correctly, leading to errors or redundant operations. The fix aligns the behavior with PyTorch by making such operations a no-op. The changes involve adding checks for dimension sizes before emitting a squeeze operation. The accompanying test updates are thorough and correctly verify the new behavior, including a new test case that reproduces the original issue.

My review focuses on the implementation of the fix. I have one suggestion to refactor the code to reduce duplication and improve maintainability. This suggestion also points out a minor discrepancy with PyTorch's error handling for out-of-bounds dimensions in squeeze, which is currently handled as a silent no-op.

Comment on lines 2139 to +2163:

```diff
 if isinstance(dim, list | tuple) and len(dim) > 0:
     shape = self.shape_of(x)
-    # Filter to only include axes where the dimension is 1
+    # PyTorch decomposition may pass all axes. Filter to only include axes
+    # where the dimension is statically known to be 1, or is symbolic (dynamic).
+    # Axes where the dimension is statically known to be != 1 are silently
+    # ignored, matching PyTorch's behavior.
     valid_dims = []
     for d in dim:
-        axis = d if d >= 0 else len(shape) + d
-        if axis < len(shape):
-            valid_dims.append(d)
-    # If no valid dims, use None to squeeze all size-1 dimensions
-    dim = valid_dims if valid_dims else None
+        axis = d if d >= 0 else ndim + d
+        if 0 <= axis < ndim:
+            dim_val = shape[axis]
+            # Include axis if it's dynamic (not a static integer) or is size 1
+            if not isinstance(dim_val, tir.IntImm) or dim_val.value == 1:
+                valid_dims.append(d)
+    # If no axes will actually be squeezed, return the original tensor (PyTorch no-op)
+    if not valid_dims:
+        return x
+    dim = valid_dims
+elif dim is not None:
+    # Single-dim case: mimic PyTorch behavior — if the dimension is statically
+    # known to not be 1, squeezing is a no-op, so return the original tensor.
+    axis = dim if dim >= 0 else ndim + dim
+    if 0 <= axis < ndim:
+        dim_val = shape[axis]
+        if isinstance(dim_val, tir.IntImm) and dim_val.value != 1:
+            return x
```

Severity: medium

The logic for checking if a dimension is squeezable is duplicated for the single-dimension and multi-dimension cases. This could be refactored into a helper function to improve maintainability and reduce redundancy.

Additionally, the condition len(dim) > 0 prevents this logic from handling an empty list or tuple for dim. If dim is [], this if block is skipped, and a relax.op.squeeze(x, []) op is emitted, which is a redundant no-op. The suggested refactoring handles this case by returning x directly.

Finally, PyTorch's squeeze operator raises an IndexError for out-of-bounds dimensions, whereas this implementation silently ignores them. To fully match PyTorch's behavior, you might consider raising an error for out-of-bounds axes. The suggested refactoring includes a comment on this.

```python
def _is_squeezable(d: int) -> bool:
    axis = d if d >= 0 else ndim + d
    if not (0 <= axis < ndim):
        # To fully match PyTorch, an IndexError should be raised here.
        # For now, returning False maintains the no-op behavior for out-of-bounds dims.
        return False
    dim_val = shape[axis]
    # Squeezable if the dimension is dynamic or its size is 1.
    return not isinstance(dim_val, tir.IntImm) or dim_val.value == 1

if isinstance(dim, list | tuple):
    if not dim:
        # An empty list of dims is a no-op.
        return x

    valid_dims = [d for d in dim if _is_squeezable(d)]
    if not valid_dims:
        return x
    dim = valid_dims
elif dim is not None:
    if not _is_squeezable(dim):
        return x
```
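The suggested predicate can be exercised outside TVM with a plain-Python stand-in for a symbolic dimension (a sketch; `SymDim` below is hypothetical, and `isinstance(dim_val, int)` stands in for the real `isinstance(dim_val, tir.IntImm)` check on the Relax shape):

```python
class SymDim:
    """Hypothetical stand-in for a symbolic (dynamic) dimension."""

def is_squeezable(shape, d):
    ndim = len(shape)
    axis = d if d >= 0 else ndim + d
    if not (0 <= axis < ndim):
        # PyTorch raises IndexError here; this sketch mirrors the
        # current frontend's silent no-op instead.
        return False
    dim_val = shape[axis]
    # Squeezable if dynamic (not a concrete int) or statically size 1.
    return not isinstance(dim_val, int) or dim_val == 1

assert not is_squeezable((32, 10, 5), 1)    # static non-unit dim: no-op
assert is_squeezable((32, 1, 5), -2)        # negative axis normalizes to 1
assert is_squeezable((32, SymDim(), 5), 1)  # dynamic dim passes through
assert not is_squeezable((32, 10, 5), 7)    # out of bounds
```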
