Skip to content

Merge the batch and sequence dimensions#469

Merged
jlamypoirier merged 24 commits intomainfrom
jlp_token_dim
Mar 17, 2026
Merged

Merge the batch and sequence dimensions#469
jlamypoirier merged 24 commits intomainfrom
jlp_token_dim

Conversation

@jlamypoirier
Copy link
Collaborator

@jlamypoirier jlamypoirier commented Feb 11, 2026

✨ Description

Combine the batch and sequence dimensions in most of the model. This simplifies various sequence-based, and in particular removes the need for the sequence_first format.

Other changes

  • Fix various issues with distillation.
  • Fix the runner kwargs not propagating to namespaces.
  • Fix AuxiliaryLoss for eval mode.
  • Make the model head use the _debug util for returning logits.

Known issue: MTP has a different name for logits which causes incompatibility issues

@jlamypoirier jlamypoirier changed the title Token dim Merge the batch and sequence dimensions Feb 11, 2026
@jlamypoirier jlamypoirier marked this pull request as ready for review February 11, 2026 21:12
Base automatically changed from jlp_triton_loss to main March 17, 2026 23:43
@jlamypoirier jlamypoirier merged commit a1cdc55 into main Mar 17, 2026
1 of 2 checks passed
@jlamypoirier jlamypoirier deleted the jlp_token_dim branch March 17, 2026 23:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant