Implementation for soft offline distillation using saved top-k teacher logits #3382
ajkv-google wants to merge 6 commits into main
Conversation
```python
def __init__(self, data_dir: str, epochs: int = 100):
  # Check if the user passed a directory or a direct file path
  if tf.io.gfile.isdir(data_dir):
    self.filepath = os.path.join(data_dir, "teacher_top_k_global.array_record")
```
is it ok to hardcode this file as teacher_top_k_global.array_record?
In the save_top_k_teacher_logits file (from this PR), we write a single ArrayRecord file from one host rather than having multiple hosts write their own chunks of data, so I just named the file "teacher_top_k_global.array_record". But not everyone running offline distillation will use the same file to save top-k teacher logits, so I will add this as a field to the config, letting users specify the filename of the saved top-k teacher logits dynamically.
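A rough sketch of that plan (all names are hypothetical, using a plain dataclass rather than the repo's actual config machinery, and `os.path` standing in for `tf.io.gfile` so the example runs on a local filesystem):

```python
import os
from dataclasses import dataclass

@dataclass
class OfflineLogitsConfig:
  # Hypothetical field: lets users override the previously hardcoded name.
  teacher_logits_filename: str = "teacher_top_k_global.array_record"

def resolve_logits_path(data_dir: str, cfg: OfflineLogitsConfig) -> str:
  # Directory -> append the configured filename; file path -> use as-is.
  if os.path.isdir(data_dir):
    return os.path.join(data_dir, cfg.teacher_logits_filename)
  return data_dir
```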
```python
if __name__ == "__main__":
  app.run(main)

parser = argparse.ArgumentParser()
```
I think these should go inside types.py to add them as part of the config.
That makes sense; keeping these in the config makes the command less complex and things more organized. I moved them to types.py and verified that training ran successfully after the change.
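For illustration, a minimal sketch of that kind of consolidation. The field name mirrors the `offline_data_dir` flag used in this PR's test command, but the actual types.py definition may differ:

```python
from dataclasses import dataclass

@dataclass
class DistillationRunConfig:
  # Hypothetical config field replacing the argparse flag;
  # an empty value means online distillation.
  offline_data_dir: str = ""

  @property
  def is_offline(self) -> bool:
    # A set offline_data_dir is the signal for offline mode.
    return bool(self.offline_data_dir)

cfg = DistillationRunConfig(offline_data_dir="/tmp/teacher_logits")
```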
vlad-karp left a comment:
LGTM overall; need a new unit test for this specific path.
src/maxtext/configs/types.py
```python
)

# --- Offline Distillation Fields ---
offline_distillation: bool = Field(
```
Looks redundant: if you specify the offline_data_dir parameter, that is already a direct sign to switch to offline processing.
That's true. I removed the offline_distillation field from the config and now rely on offline_data_dir in the train_distill.py file to decide when to perform offline vs. online distillation.
```python
# Scatter the offline arrays into a dense tensor of -10000s
dense_shape = batch.input_tokens.shape + (student_config.vocab_size,)
dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32)
dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False)
```
JAX arrays are immutable, so jnp.put_along_axis creates a new array with the updated values. In standard NumPy you can modify arrays in place, but you cannot with JAX, so you need to set inplace=False so that JAX knows to return the newly created array instead of throwing an error for trying to modify the original one. Here is the documentation explaining how this argument must be set to False: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.put_along_axis.html
Yes, but is there a specific need to use a JAX initial array instead of a NumPy one?
Yes: the code runs inside a jitted function, and NumPy functions cannot handle JAX tracers. If we used NumPy, the code would crash and not trace properly, so we have to stick with jnp.
Wondering if it would be faster to do this logic on CPU, in place, in parallel with train-step execution on the TPUs.
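For reference, a sketch of that CPU alternative with plain NumPy, where the scatter genuinely is in place (same toy shapes as above; whether this overlaps usefully with TPU train steps would need profiling):

```python
import numpy as np

idx = np.array([[[0, 3], [1, 2]]])                             # (batch=1, seq=2, k=2)
vals = np.array([[[5.0, 7.0], [2.0, 4.0]]], dtype=np.float32)  # top-k logits
dense = np.full((1, 2, 4), -10000.0, dtype=np.float32)
# NumPy mutates `dense` in place; no inplace flag is needed.
np.put_along_axis(dense, idx, vals, axis=-1)
```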
…a_dir to know when to run offline vs online distillation
Description
This PR introduces an end-to-end offline distillation training pipeline. Previously, the distillation loop executed in an "online" mode, which required both the frozen teacher model and the learning student model to be loaded and executed simultaneously during training. This change allows the trainer to load pre-computed top-k teacher logits from .array_record files, which lets us bypass the teacher model's forward pass during the training loop.
Tests
Tested this code change by running the following command:
python3 src/maxtext/trainers/post_train/distillation/train_distill.py src/maxtext/configs/post_train/distillation.yml steps=100 tokenizer_path="/mnt/ajkv/disks/codebase/maxtext/src/maxtext/assets/tokenizers/tokenizer_llama3.tiktoken" --offline_distillation --offline_data_dir="/mnt/ajkv/disks/teacher_logits_output/teacher_top_k_global.array_record"

Truncated output showing the successful run: https://paste.googleplex.com/6342987127848960#l=8.
Verified that the training completed successfully and the distillation run finished.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.