Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
* Gemma 2 (2B, 9B, 27B)
* Gemma 1 (2B, 7B)
* Alibaba
* Qwen 2.5 (7B, 14B)
* Qwen 3 MoE 2507 (235B, 480B)
* Qwen 3 MoE (30B, 235B)
* Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The following models are supported:
| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ |
| **Gemma3** (Multimodal) | 4B, 12B, 27B | √ | √ | √ | √ |
| **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ |
| **Qwen2.5** | 7B, 14B | √ | √ | √ | √ |
| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ |
| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ |
| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
Expand Down
37 changes: 37 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,41 @@
query_pre_attn_scalar=144,
)

qwen25_7b_config = transformers.Qwen2Config(
vocab_size=152064,
hidden_size=3584,
intermediate_size=18944,
num_hidden_layers=28,
num_attention_heads=28,
num_key_value_heads=4,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-06,
use_cache=True,
rope_theta=1000000.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
attention_bias=True,
)

qwen25_14b_config = transformers.Qwen2Config(
vocab_size=152064,
hidden_size=5120,
intermediate_size=13824,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
rms_norm_eps=1e-06,
rope_theta=1000000.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
attention_bias=True,
)


qwen3_0_6b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=1024,
Expand Down Expand Up @@ -815,6 +850,8 @@
"gemma3-4b": gemma3_4b_config,
"gemma3-12b": gemma3_12b_config,
"gemma3-27b": gemma3_27b_config,
"qwen2.5-7b": qwen25_7b_config,
"qwen2.5-14b": qwen25_14b_config,
"qwen3-0.6b": qwen3_0_6b_config,
"qwen3-4b": qwen3_4b_config,
"qwen3-4b-thinking-2507": qwen3_4b_config,
Expand Down
34 changes: 23 additions & 11 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,8 @@ def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
return mapping


def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace Qwen3 weights path and the HuggingFace weights shape.
def QWEN_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace Qwen weights path and the HuggingFace weights shape.

To check this mapping, dump the huggingface model shapes:
from transformers import AutoModelForCausalLM
Expand All @@ -555,6 +555,7 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
head_dim = config.get(
"head_dim", config["hidden_size"] // config["num_attention_heads"]
) # head_dim might not always be present
attention_bias = config.get("attention_bias", False)

mapping = {
"model.embed_tokens.weight": [config["vocab_size"], hidden_size],
Expand All @@ -580,6 +581,15 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config):
f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
}

if attention_bias:
layer_mapping.update(
{
f"{layer_prefix}.self_attn.q_proj.bias": [num_attention_heads * head_dim],
f"{layer_prefix}.self_attn.k_proj.bias": [num_key_value_heads * head_dim],
f"{layer_prefix}.self_attn.v_proj.bias": [num_key_value_heads * head_dim],
}
)

if num_experts > 1:
# MoE MLP layers
moe_ffn_intermediate_size = config.get("moe_intermediate_size")
Expand Down Expand Up @@ -756,18 +766,20 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
"gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-14b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-32b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-0.6b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-4b-thinking-2507": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE,
"llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"llama3.1-70b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"llama3.1-405b": LLAMA31_HF_WEIGHTS_TO_SHAPE,
"qwen3-30b-a3b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-235b-a22b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-480b-a35b": QWEN3_HF_WEIGHTS_TO_SHAPE,
"qwen3-30b-a3b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-235b-a22b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen3-480b-a35b": QWEN_HF_WEIGHTS_TO_SHAPE,
"deepseek3-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE,
"gpt-oss-20b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,
"gpt-oss-120b": GPT_OSS_HF_WEIGHTS_TO_SHAPE,
Expand Down
90 changes: 65 additions & 25 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,11 @@ def scale_query_layer(input_tensor, target_shape):
return mapping


def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping from MaxText to HuggingFace Qwen weight paths.

This function generates a dictionary that maps parameter names from a MaxText
Qwen3 checkpoint to their corresponding names in the Hugging Face format.
Qwen checkpoint to their corresponding names in the Hugging Face format.
It handles both dense and Mixture-of-Experts (MoE) model variants.

Args:
Expand Down Expand Up @@ -631,6 +631,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
"params-decoder-layers-self_attention-value-kernel": [
f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers)
],
"params-decoder-layers-self_attention-query-bias": [
f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-key-bias": [
f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-value-bias": [
f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers)
],
"params-decoder-layers-self_attention-out-kernel": [
f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers)
],
Expand Down Expand Up @@ -688,6 +697,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias",
f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias",
f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias",
f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight",
Expand Down Expand Up @@ -721,8 +733,8 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False)
return mapping


def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""Creates parameter transformation functions for Qwen3.
def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""Creates parameter transformation functions for Qwen.

This function provides a dictionary of transformation functions (hooks) for
converting Qwen3 model parameters between MaxText and Hugging Face formats.
Expand Down Expand Up @@ -766,6 +778,15 @@ def reshape_kernel(input_tensor, target_shape):
else:
return input_tensor.T.reshape(target_shape)

def reshape_bias(input_tensor, target_shape=None):
"""Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
if saving_to_hf:
# MaxText [heads, head_dim] -> HF [hidden_dim] (flatten)
return input_tensor.reshape(target_shape)
else:
# HF [hidden_dim] -> MaxText [heads, head_dim]
return input_tensor.reshape(target_shape)

mapping = {
"params-token_embedder-embedding": pad_embedding_layer,
"params-decoder-logits_dense-kernel": reshape_kernel,
Expand All @@ -780,6 +801,11 @@ def reshape_kernel(input_tensor, target_shape):
"mlp-wi_1-kernel",
"mlp-wo-kernel",
]
bias_hooks = [
"self_attention-query-bias",
"self_attention-key-bias",
"self_attention-value-bias",
]
moe_kernel_hooks = [
"moe_block-gate-kernel",
"moe_block-wi_0-kernel",
Expand All @@ -793,13 +819,17 @@ def reshape_kernel(input_tensor, target_shape):
if scan_layers:
for key in kernel_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
for key in bias_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_bias
if num_experts > 1:
for key in moe_kernel_hooks:
mapping[f"params-decoder-layers-{key}"] = reshape_kernel
else:
for i in range(n_layers):
for key in kernel_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
for key in bias_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias
if num_experts > 1:
for key in moe_kernel_hooks:
mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel
Expand Down Expand Up @@ -1376,7 +1406,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye
# Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function
num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(
text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(
config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
maxtext_config=maxtext_config,
scan_layers=scan_layers,
Expand Down Expand Up @@ -1544,7 +1574,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye
# Text hooks, reusing QWEN3-MOE hook function
num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0)
n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"]
text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(
text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(
config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text},
maxtext_config=maxtext_config,
scan_layers=scan_layers,
Expand Down Expand Up @@ -2332,18 +2362,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING,
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING,
"gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
"gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
Expand All @@ -2364,18 +2399,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
"gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
"gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum):
GEMMA = "gemma"
GEMMA2 = "gemma2"
GEMMA3 = "gemma3"
QWEN2 = "qwen2"
QWEN3 = "qwen3"
QWEN3_MOE = "qwen3_moe"
QWEN3_NEXT = "qwen3_next"
Expand Down
34 changes: 34 additions & 0 deletions src/maxtext/configs/models/qwen2.5-14b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Qwen 2.5 14B Instruct Configuration
# https://huggingface.co/Qwen/Qwen2.5-14B-Instruct

base_emb_dim: 5120
base_num_query_heads: 40
base_num_kv_heads: 8
base_mlp_dim: 13824
base_num_decoder_layers: 48
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 152064
decoder_block: "qwen2"
normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000.0
use_qk_norm: False
# Bias for q, k, v proj.
attention_bias: True
logits_via_embedding: False
normalize_embedding_logits: False
tokenizer_type: "huggingface"
34 changes: 34 additions & 0 deletions src/maxtext/configs/models/qwen2.5-7b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Qwen 2.5 7B Instruct Configuration
# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct

base_emb_dim: 3584
base_num_query_heads: 28
base_num_kv_heads: 4
base_mlp_dim: 18944
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"]
vocab_size: 152064
decoder_block: "qwen2"
normalization_layer_epsilon: 1e-06
rope_max_timescale: 1000000.0
use_qk_norm: False
# Bias for q, k, v proj.
attention_bias: True
logits_via_embedding: False
normalize_embedding_logits: False
tokenizer_type: "huggingface"
Loading
Loading