
Implementation for soft offline distillation using saved top-k teacher logits #3382

Open
ajkv-google wants to merge 6 commits into main from ajkv/offline-distillation-soft

Conversation


@ajkv-google (Collaborator) commented Mar 11, 2026

Description

This PR introduces an end-to-end offline distillation training pipeline. Previously, the distillation loop ran only in an "online" mode, which required both the frozen teacher model and the learning student model to be loaded and executed simultaneously during training. With this change, the trainer can load pre-computed top-K teacher logits from .array_record files, letting it bypass the teacher's forward pass entirely during the training loop.
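
For orientation, the control-flow change amounts to roughly the following sketch (names like `get_teacher_targets` and `teacher_apply` are illustrative, not the PR's actual identifiers; only `batch.top_k_logits` / `batch.top_k_indices` come from the diff below):

```python
import jax

def get_teacher_targets(batch, k, teacher_apply=None, offline=False):
  """Hypothetical helper: returns (top_k_logits, top_k_indices) for the loss."""
  if offline:
    # Pre-computed during the save-top-k pass; no teacher model on device.
    return batch.top_k_logits, batch.top_k_indices
  logits = teacher_apply(batch.input_tokens)  # full frozen-teacher forward pass
  return jax.lax.top_k(logits, k)             # (values, indices)
```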

Tests

Tested this code change by running the following command:

```bash
python3 src/maxtext/trainers/post_train/distillation/train_distill.py src/maxtext/configs/post_train/distillation.yml steps=100 tokenizer_path="/mnt/ajkv/disks/codebase/maxtext/src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken" --offline_distillation --offline_data_dir="/mnt/ajkv/disks/teacher_logits_output/teacher_top_k_global.array_record"
```

Truncated output showing the successful run: https://paste.googleplex.com/6342987127848960#l=8.

Verified that the training ran successfully and the distillation run finished.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Mar 11, 2026

Codecov Report

❌ Patch coverage is 18.75000% with 39 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...ners/post_train/distillation/distillation_utils.py | 28.12% | 23 Missing ⚠️ |
| .../trainers/post_train/distillation/train_distill.py | 0.00% | 16 Missing ⚠️ |


```python
def __init__(self, data_dir: str, epochs: int = 100):
  # Check if the user passed a directory or a direct file path
  if tf.io.gfile.isdir(data_dir):
    self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record")
```
Collaborator:

Is it OK to hardcode this file as teacher_top_k_global.array_record?

Collaborator Author:

In the save_top_k_teacher logits script (from this PR), we write a single .array_record file from one host rather than having multiple hosts write their own chunks of data, so I named the file "teacher_top_k_global.array_record". But not everyone running offline distillation will save their top-k teacher logits under that name, so I will add a config field letting users specify the filename of the saved logits to keep it dynamic.
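
For illustration, the configurable filename could look roughly like this (the field name and default are hypothetical, not the PR's final keys; a pydantic-style `Field` is assumed since types.py already uses it):

```python
# Hypothetical config field in types.py; the name is illustrative.
offline_data_filename: str = Field(
    default="teacher_top_k_global.array_record",
    description="Filename of the saved top-k teacher logits inside offline_data_dir.",
)
```

with the loader then resolving `os.path.join(data_dir, config.offline_data_filename)` instead of the hardcoded name.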


```python
if __name__ == "__main__":
  app.run(main)

parser = argparse.ArgumentParser()
```
Collaborator:

I think these should go inside types.py to add them as part of the config.

Collaborator Author:

That makes sense; it would make the command less complex and keep things more organized if it is in the config. I moved these to types.py and verified the training ran successfully after the change.

Collaborator @vlad-karp left a comment:

LGTM overall. Need a new unit test for this specific path.

```python
)

# --- Offline Distillation Fields ---
offline_distillation: bool = Field(
```
Collaborator:

Looks redundant: if you specify the offline_data_dir parameter, that by itself can be a direct signal to switch to offline processing.

Collaborator Author:

That's true. I removed the offline_distillation field from the config and now rely on offline_data_dir in train_distill.py to decide when to perform offline vs. online distillation.
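
A sketch of what that presence check could look like (helper names are hypothetical; the final diff isn't quoted in this thread):

```python
# Hypothetical: a non-empty offline_data_dir switches the trainer to offline mode.
offline = bool(config.offline_data_dir)
if offline:
  logits_source = OfflineLogitsDataset(config.offline_data_dir)  # saved top-k logits
else:
  teacher_params = load_teacher(config)  # online mode keeps the teacher on device
```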

```python
# Scatter the offline arrays into a dense tensor of -10000s
dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False)
```
Collaborator:

Why inplace=False?

Collaborator Author:

JAX arrays are immutable, so jnp.put_along_axis creates a new array with the updated values. In standard NumPy you can modify arrays in place, but you cannot with JAX, so you need to set inplace=False so that JAX knows to return the newly created array instead of throwing an error for trying to modify the original one. Here is the documentation explaining that this arg must be set to False: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.put_along_axis.html
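
A self-contained demo of that behavior (toy shapes and values, not the PR's):

```python
import jax.numpy as jnp

# Two rows of a (batch, vocab) floor tensor, with top-k logits scattered in.
dense = jnp.full((2, 5), -10000.0)
idx = jnp.array([[0, 2], [1, 4]])            # top-k indices per row
vals = jnp.array([[1.5, 3.2], [0.7, 2.1]])   # matching top-k logits
out = jnp.put_along_axis(dense, idx, vals, axis=-1, inplace=False)
# out[0] -> [1.5, -10000., 3.2, -10000., -10000.]
```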

Collaborator:

Yes, but is there a specific need to use a JAX initial array instead of a NumPy one?

Collaborator Author (@ajkv-google, Mar 13, 2026):

Yes. The code runs inside of a jitted function, and NumPy functions cannot handle JAX tracers. If we used NumPy, the code would crash and not trace properly, so we have to stick with jnp.
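
A minimal illustration of that failure mode (not code from the PR):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def double_np(x):
  # np.asarray forces the tracer to a concrete array, which JAX forbids.
  return np.asarray(x) * 2.0

@jax.jit
def double_jnp(x):
  return jnp.asarray(x) * 2.0  # stays a tracer; traces and compiles fine

double_jnp(jnp.ones(4))   # works
# double_np(jnp.ones(4))  # raises jax.errors.TracerArrayConversionError
```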

Collaborator:

Wondering if it would be faster to do this logic on CPU, in place, in parallel with train-step execution on TPUs.

