Fix torch.split fails in to_edge with alias annotations#18700
Open
Lidang-Jiang wants to merge 2 commits intopytorch:mainfrom
Open
Fix torch.split fails in to_edge with alias annotations#18700Lidang-Jiang wants to merge 2 commits intopytorch:mainfrom
Lidang-Jiang wants to merge 2 commits intopytorch:mainfrom
Conversation
Fixes pytorch#11723 _remove_invalid_ops_for_not_decompose relied on torchgen's aliased_return_names() to detect ops with aliased returns, but it returns [None] for ops returning lists of aliased tensors (e.g., split.Tensor returns Tensor(a)[]). This let split.Tensor through into the EDGE_DO_NOT_DECOMP namespace where functionalization failed. Add a fallback check using op._schema.returns directly, which correctly reports alias_info on list return types. This also fixes the same latent issue for chunk and tensor_split. Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18700
Note: Links to docs will display an error until the docs builds have been completed.
|
- Change 'may fail' to 'does not detect' (torchgen structurally cannot handle ListType alias annotations) - Add split_with_sizes.default to test to document overlap with blocklist Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
Author
|
@pytorchbot label "release notes: exir" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #11723
Summary
torch.splitfails withRuntimeError: Found a custom (non-ATen) operator whose output has alias annotationswhen used withto_edge_transform_and_lowerand a partitioner that requests op preservation.Root cause:
_remove_invalid_ops_for_not_decomposerelies ontorchgen'saliased_return_names()to detect ops with aliased returns. However, for ops returning lists of aliased tensors (e.g.,split.TensorreturnsTensor(a)[]),aliased_return_names()returns[None], failing to detect the alias annotation. This letssplit.Tensorpass through into theEDGE_DO_NOT_DECOMPnamespace, where functionalization fails.Fix: Add a fallback check using
op._schema.returnsdirectly, which correctly reportsalias_infoon list return types. This also fixes the same latent issue forchunk.defaultandtensor_split.sections.Test plan
test_remove_invalid_ops_filters_aliased_list_returnsregression testpytest exir/tests/test_passes.py::TestPasses::test_remove_invalid_ops_filters_aliased_list_returns -xvstest_to_out_variant_singleon_tensor_listtest_compile_fix_broken_opsBefore fix
After fix
Unit test output
This PR was authored with the assistance of Claude.