diff --git a/backends/aoti/aoti_delegate_handle.h b/backends/aoti/aoti_delegate_handle.h index 2bc6abf9bd1..2d1a3146ae5 100644 --- a/backends/aoti/aoti_delegate_handle.h +++ b/backends/aoti/aoti_delegate_handle.h @@ -31,6 +31,20 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; using AOTInductorStreamHandle = void*; using AOTIProxyExecutorHandle = void*; +// Opaque types for AOTI constant management. +// AtenTensorOpaque wraps at::Tensor* in the AOTI runtime — distinct from +// AOTITensorHandle which wraps executorch::runtime::etensor::Tensor*. +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque*; + +struct AOTInductorConstantMap; +using AOTInductorConstantMapHandle = AOTInductorConstantMap*; + +struct AOTInductorConstantMapEntry { + const char* name; + AtenTensorHandle handle; +}; + // Function pointer types for AOT Inductor model container operations using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)( AOTInductorModelContainerHandle* container_handle, @@ -77,6 +91,37 @@ using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)( AOTInductorModelContainerHandle container_handle, const uint8_t* weight_blob_ptr); +// Retrieves a constant's AOTI internal name by index. +using AOTInductorModelContainerGetConstantNameFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name); + +// Retrieves a constant's original fully-qualified name by index. +using AOTInductorModelContainerGetConstantOriginalFQNFunc = + AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn); + +// Extracts the constants map from the container (active or inactive buffer). +// constant_map_handle should point to a +// std::unordered_map. 
+using AOTInductorModelContainerExtractConstantsMapFunc = AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive); + +// Updates the container's constants with user-managed tensor handles. +// DLL-boundary safe — uses a flat C array instead of std::unordered_map. +using AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc = + AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + const AOTInductorConstantMapEntry* pairs, + size_t num_pairs, + bool use_inactive, + bool validate_full_update); + } // extern "C" // AOTI Delegate Handle structure @@ -93,6 +138,14 @@ struct AOTIDelegateHandle { AOTInductorModelContainerGetNumOutputsFunc get_num_outputs; AOTInductorModelContainerRunFunc run; AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob; + + // Constant management function pointers (for cross-method buffer sharing) + AOTInductorModelContainerGetNumConstantsFunc get_num_constants; + AOTInductorModelContainerGetConstantNameFunc get_constant_name; + AOTInductorModelContainerGetConstantOriginalFQNFunc get_constant_original_fqn; + AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map; + AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc + update_user_managed_constant_buffer_pairs; }; } // namespace aoti diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index c94d56b796f..1b3d9323c0c 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -207,6 +207,30 @@ class ET_EXPERIMENTAL CudaBackend final Info, "Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)"); } + + // Load constant management symbols (optional — needed for cross-method + // buffer sharing). These are available in torch >= 2.6. 
+#define LOAD_OPTIONAL_SYMBOL(member, name)                                   \
+  do {                                                                       \
+    auto res = get_function(so_handle, #name);                               \
+    handle->member =                                                         \
+        res.ok() ? reinterpret_cast<decltype(handle->member)>(res.get()) : nullptr; \
+  } while (0)
+
+  LOAD_OPTIONAL_SYMBOL(
+      get_num_constants, AOTInductorModelContainerGetNumConstants);
+  LOAD_OPTIONAL_SYMBOL(
+      get_constant_name, AOTInductorModelContainerGetConstantName);
+  LOAD_OPTIONAL_SYMBOL(
+      get_constant_original_fqn,
+      AOTInductorModelContainerGetConstantOriginalFQN);
+  LOAD_OPTIONAL_SYMBOL(
+      extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
+  LOAD_OPTIONAL_SYMBOL(
+      update_user_managed_constant_buffer_pairs,
+      AOTInductorModelContainerUpdateUserManagedConstantBufferPairs);
+#undef LOAD_OPTIONAL_SYMBOL
+
   return Error::Ok;
 }
 
@@ -378,6 +402,105 @@ class ET_EXPERIMENTAL CudaBackend final
       method_name.c_str());
   }
 
+  // ---------------------------------------------------------------
+  // Cross-method constant sharing (e.g., KV cache between prefill/decode).
+  //
+  // The first container to initialize extracts its constants (keyed by
+  // original FQN) and stores the AtenTensorHandle's. Subsequent containers
+  // with matching FQNs are updated to point to the same GPU tensors via
+  // UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
+  // the source container retains ownership).
+  // ---------------------------------------------------------------
+  if (handle->get_num_constants && handle->get_constant_name &&
+      handle->get_constant_original_fqn && handle->extract_constants_map &&
+      handle->update_user_managed_constant_buffer_pairs) {
+    size_t num_constants = 0;
+    handle->get_num_constants(handle->container_handle, &num_constants);
+
+    if (num_constants > 0) {
+      // Build FQN → internal_name mapping for this container. 
+      std::unordered_map<std::string, std::string> fqn_to_name;
+      for (size_t i = 0; i < num_constants; i++) {
+        const char* name = nullptr;
+        const char* fqn = nullptr;
+        handle->get_constant_name(handle->container_handle, i, &name);
+        handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
+        if (name && fqn && fqn[0] != '\0') {
+          fqn_to_name[fqn] = name;
+        }
+      }
+
+      std::lock_guard<std::mutex> guard(shared_constants_mutex_);
+
+      if (!constants_extracted_) {
+        // First container: extract its constants and store by FQN.
+        std::unordered_map<std::string, AtenTensorHandle> extracted_map;
+        auto extract_err = handle->extract_constants_map(
+            handle->container_handle,
+            reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
+            /*use_inactive=*/false);
+
+        if (extract_err == Error::Ok) {
+          for (const auto& [fqn, internal_name] : fqn_to_name) {
+            auto it = extracted_map.find(fqn);
+            if (it != extracted_map.end()) {
+              shared_constant_tensors_[fqn] = it->second;
+            }
+          }
+          constants_extracted_ = true;
+          ET_LOG(
+              Info,
+              "Extracted %zu shared constants from method '%s'",
+              shared_constant_tensors_.size(),
+              method_name.c_str());
+        } else {
+          ET_LOG(
+              Error,
+              "Failed to extract constants from '%s'",
+              method_name.c_str());
+        }
+      } else {
+        // Subsequent container: share matching constants from the first.
+        std::vector<AOTInductorConstantMapEntry> pairs;
+        for (const auto& [fqn, internal_name] : fqn_to_name) {
+          auto it = shared_constant_tensors_.find(fqn);
+          if (it != shared_constant_tensors_.end()) {
+            // UpdateUserManagedConstantBufferPairs matches against the
+            // codegen constant name (underscored), not the original FQN. 
+            pairs.push_back({internal_name.c_str(), it->second});
+          }
+        }
+
+        if (!pairs.empty()) {
+          auto update_err = handle->update_user_managed_constant_buffer_pairs(
+              handle->container_handle,
+              pairs.data(),
+              pairs.size(),
+              /*use_inactive=*/false,
+              /*validate_full_update=*/false);
+
+          if (update_err == Error::Ok) {
+            ET_LOG(
+                Info,
+                "Shared %zu constants into method '%s'",
+                pairs.size(),
+                method_name.c_str());
+          } else {
+            ET_LOG(
+                Error,
+                "Failed to share constants into '%s'",
+                method_name.c_str());
+          }
+        }
+      }
+    }
+  } else {
+    ET_LOG(
+        Info,
+        "Constant sharing APIs not available for method '%s'",
+        method_name.c_str());
+  }
+
   return (DelegateHandle*)handle; // Return the handle post-processing
 }
 
@@ -623,6 +746,22 @@ class ET_EXPERIMENTAL CudaBackend final
   mutable std::
       unordered_map>
       cached_outputs_;
+
+  // Cross-method constant sharing state.
+  // When multiple AOTI containers share mutable buffers (e.g., KV cache),
+  // the first container's constants are extracted and stored here. Subsequent
+  // containers with matching FQNs share the same GPU tensors via
+  // UpdateUserManagedConstantBufferPairs.
+  mutable std::mutex shared_constants_mutex_;
+
+  // FQN → AtenTensorHandle from the source (first) container.
+  // The tensor handles are owned by the source container (which is never
+  // explicitly deleted — see destroy() comment).
+  mutable std::unordered_map<std::string, AtenTensorHandle>
+      shared_constant_tensors_;
+
+  // Whether we've already extracted constants from a source container. 
+ mutable bool constants_extracted_ = false; }; } // namespace executorch::backends::cuda diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 7437bc5f461..19a720a2e79 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -107,6 +107,8 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096): # Any missing weight key indicates a version mismatch between the # checkpoint and the model (e.g., unfused vs fused projections). runtime_prefixes = ( + ".mask", + ".inv_freq", ".kv_cache.", ".conv_state", ".recurrent_state", @@ -312,10 +314,11 @@ def _materialize_buffers(model, config): """Materialize meta-device buffers before torch.export. Replaces meta buffers with real tensors on CPU, recomputes RoPE - inv_freq and causal masks. + inv_freq and causal masks. State buffers (KV cache, conv/recurrent + state) are zero-initialized registered buffers that will be shared + across methods via share_mutable_buffers. """ - # State buffers (KV cache, conv/recurrent state) are bf16 to match - # compute dtype. Masks stay bool, inv_freq stays float32. + # Masks stay bool, inv_freq stays float32. for fqn, buf in list(model.named_buffers()): if buf.device.type == "meta": dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool @@ -378,7 +381,18 @@ def _apply_turboquant(model, config): def export_and_lower(model, config, args): - """Export model to .pte via torch.export + CUDA backend.""" + """Export model to .pte via torch.export + CUDA backend. + + Exports two methods: + - "decode": decode path (T=1), uses native PyTorch recurrent FLA + so AOTI can fuse with surrounding ops for maximum decode throughput. + - "prefill": prefill path (T>=2), uses chunked FLA triton_op with + dynamic sequence length. + + Both methods share mutable state buffers (KV cache, conv_state, + recurrent_state) via share_mutable_buffers=True. 
The model uses + registered buffers with in-place updates — no state in/out args. + """ import torch._inductor.config as inductor_config from executorch.backends.cuda.cuda_backend import CudaBackend @@ -398,25 +412,39 @@ def export_and_lower(model, config, args): # -O0 compiles ~8x faster than -O1 with no measurable runtime impact. inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" - # Dynamic shapes - example_tokens = torch.tensor([[0, 1]], dtype=torch.long) - example_input_pos = torch.tensor([0, 1], dtype=torch.long) - seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1) - dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) - - print("Exporting with torch.export...") + # --- Decode method (T=1, static shape) --- + print("Exporting decode method...") + decode_tokens = torch.tensor([[0]], dtype=torch.long) + decode_pos = torch.tensor([0], dtype=torch.long) + with torch.no_grad(): + decode_ep = export( + model, + (decode_tokens, decode_pos), + strict=True, + ) + print("Decode export successful!") + + # --- Prefill method (T>=2, dynamic shape) --- + print("Exporting prefill method...") + prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long) + prefill_pos = torch.tensor([0, 1], dtype=torch.long) + seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) + prefill_dynamic_shapes = ( + {1: seq_dim}, # tokens + {0: seq_dim}, # input_pos + ) with torch.no_grad(): - exported = export( + prefill_ep = export( model, - (example_tokens, example_input_pos), - dynamic_shapes=dynamic_shapes, + (prefill_tokens, prefill_pos), + dynamic_shapes=prefill_dynamic_shapes, strict=True, ) - print("Export successful!") + print("Prefill export successful!") - # Lower with CUDA backend + # Lower with CUDA backend (per-method partitioners to avoid so_blob collision) print("Lowering to ExecuTorch with CUDA...") - compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + metadata = { "get_max_seq_len": config.max_seq_len, "get_vocab_size": config.vocab_size, 
@@ -426,8 +454,19 @@ def export_and_lower(model, config, args): "enable_dynamic_shape": True, } et_prog = to_edge_transform_and_lower( - exported, - partitioner=[CudaPartitioner(compile_specs)], + {"decode": decode_ep, "prefill": prefill_ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + "prefill": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("prefill")] + ) + ], + }, compile_config=EdgeCompileConfig( _check_ir_validity=False, _skip_dim_order=True, @@ -438,7 +477,11 @@ def export_and_lower(model, config, args): config=ExecutorchBackendConfig( extract_delegate_segments=True, do_quant_fusion_and_const_prop=True, - memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, ), ) diff --git a/examples/models/qwen3_5_moe/inference.py b/examples/models/qwen3_5_moe/inference.py index ce9de9230af..c824f6a6444 100644 --- a/examples/models/qwen3_5_moe/inference.py +++ b/examples/models/qwen3_5_moe/inference.py @@ -75,9 +75,8 @@ def generate( ): """Generate text autoregressively with KV cache. - Prefills one token at a time (the chunk_gated_delta_rule kernel's chunked - path has numerical issues with T>1 in eager mode; token-by-token uses the - stable recurrent path). + Prefills one token at a time (the recurrent path; chunked FLA via + @triton_op is used for T>1 prefill in the exported PTE). 
""" if eos_token_ids is None: eos_token_ids = set() diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 266d0e65419..d0a1a93169e 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -8,14 +8,20 @@ #include -#include +#include +#include +#include #include +#include #include #include +#include #include #include +#include + DEFINE_string(model_path, "", "Model .pte file path."); DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); @@ -24,6 +30,13 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); namespace llm = ::executorch::extension::llm; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SizesType = executorch::aten::SizesType; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -37,11 +50,6 @@ int main(int argc, char** argv) { return 1; } - std::vector data_files; - if (!FLAGS_data_path.empty()) { - data_files.push_back(FLAGS_data_path); - } - // Load tokenizer auto tokenizer = std::make_unique(); auto tok_status = tokenizer->load(FLAGS_tokenizer_path); @@ -53,25 +61,177 @@ int main(int argc, char** argv) { return 1; } - // Create LLM runner - auto runner = llm::create_text_llm_runner( - FLAGS_model_path, std::move(tokenizer), data_files, FLAGS_temperature); + // Create Module with share_memory_arenas=true so prefill and forward + // share mutable buffers (KV cache, conv_state, recurrent_state). 
+ std::vector data_files; + if (!FLAGS_data_path.empty()) { + data_files.push_back(FLAGS_data_path); + } + auto module = std::make_unique( + FLAGS_model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + // Get metadata + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to get metadata from model"); + return 1; + } + auto metadata = metadata_result.get(); + + printf("Loading methods...\n"); + + // Load both methods + auto err = module->load_method("prefill"); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to load prefill method"); + return 1; + } + err = module->load_method("decode"); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); + return 1; + } + + // Get EOS ids + auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); - if (runner == nullptr) { - ET_LOG(Error, "Failed to create runner"); + // Encode prompt + auto encode_result = tokenizer->encode(FLAGS_prompt); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); return 1; } + auto prompt_tokens = std::move(*encode_result); + int64_t num_prompt_tokens = prompt_tokens.size(); + printf("Prompt tokens: %ld\n", num_prompt_tokens); - // Generate - llm::GenerationConfig config; - config.temperature = FLAGS_temperature; - config.max_new_tokens = FLAGS_max_new_tokens; + // --------------------------------------------------------------- + // Prefill or decode-only + // --------------------------------------------------------------- + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; - auto error = runner->generate(FLAGS_prompt.c_str(), config); - if (error != executorch::runtime::Error::Ok) { - ET_LOG(Error, "Generation failed"); + uint64_t cur_token = 0; + auto prefill_start = std::chrono::steady_clock::now(); + + // Chunked 
prefill + std::vector pos_data(num_prompt_tokens); + for (int64_t i = 0; i < num_prompt_tokens; i++) { + pos_data[i] = i; + } + std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + + std::vector prefill_inputs; + prefill_inputs.push_back(tokens_tensor); + prefill_inputs.push_back(pos_tensor); + + auto prefill_result = module->execute("prefill", prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG(Error, "Prefill failed"); return 1; } + auto& prefill_outputs = prefill_result.get(); + + auto logits_tensor = prefill_outputs[0].toTensor(); + auto logits_ptr = + std::make_shared(std::move(logits_tensor)); + cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature); + + auto prefill_end = std::chrono::steady_clock::now(); + double prefill_ms = + std::chrono::duration(prefill_end - prefill_start) + .count(); + printf( + "Prefill: %ld tokens in %.1f ms (%.1f tok/s)\n", + num_prompt_tokens, + prefill_ms, + num_prompt_tokens * 1000.0 / prefill_ms); + + // Synchronize CUDA device to ensure prefill's writes to shared mutable + // buffers (KV cache, conv_state, recurrent_state) are visible to the + // decode method, which may run on a different CUDA stream. 
+ cudaDeviceSynchronize(); + + // --------------------------------------------------------------- + // Decode — generate tokens one at a time + // --------------------------------------------------------------- + llm::Stats stats; + int64_t pos = num_prompt_tokens; + uint64_t prev_token; + + std::vector decode_token_data = {static_cast(cur_token)}; + std::vector decode_pos_data = {pos}; + auto decode_tokens = from_blob( + decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); + auto decode_pos = from_blob( + decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); + + auto decode_start = std::chrono::steady_clock::now(); + + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + decode_token_data[0] = static_cast(cur_token); + decode_pos_data[0] = pos; + + std::vector decode_inputs; + decode_inputs.push_back(EValue(decode_tokens)); + decode_inputs.push_back(EValue(decode_pos)); + + auto decode_result = module->execute("decode", decode_inputs); + if (decode_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + auto& decode_outputs = decode_result.get(); + + auto step_logits = decode_outputs[0].toTensor(); + auto step_logits_ptr = + std::make_shared(std::move(step_logits)); + + prev_token = cur_token; + stats.on_sampling_begin(); + cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature); + stats.on_sampling_end(); + + pos++; + + auto decode_str = tokenizer->decode(prev_token, cur_token); + if (decode_str.ok()) { + printf("%s", decode_str->c_str()); + fflush(stdout); + } + + if (eos_ids.find(cur_token) != eos_ids.end()) { + printf("\n"); + break; + } + } + + auto decode_end = std::chrono::steady_clock::now(); + + printf("\n"); + int64_t num_generated = pos - num_prompt_tokens; + double decode_ms = + std::chrono::duration(decode_end - decode_start) + .count(); + printf( + "Decode: %ld tokens in %.1f ms (%.1f tok/s)\n", + num_generated, + decode_ms, + num_generated * 1000.0 / 
decode_ms); + printf("Prompt tokens: %ld\n", num_prompt_tokens); return 0; } diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index d9f127d9ed1..751915fb123 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -350,6 +350,12 @@ def __init__(self, config): ) def forward(self, x, input_pos): + """GatedDeltaNet with trace-time dispatch. + + When traced with T=1: uses native PyTorch recurrent delta rule + (AOTI fuses with surrounding ops for maximum decode throughput). + When traced with T>1: uses chunked FLA via triton_op. + """ B, T, _ = x.size() # Reset state at position 0 @@ -406,13 +412,43 @@ def forward(self, x, input_pos): beta = b.sigmoid() g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) - # FLA Triton kernel (returns final_state separately, does not mutate initial_state) - output, state = torch.ops.triton.chunk_gated_delta_rule( - q, k, v, g, beta, self.recurrent_state[:B] - ) + if T == 1: + # Native recurrent delta rule — AOTI fuses with surrounding ops + scale = self.head_k_dim**-0.5 - with torch.no_grad(): - self.recurrent_state[:B].copy_(state) + q_s = q[:, 0].float() # [B, H, K] + k_s = k[:, 0].float() # [B, H, K] + v_s = v[:, 0].float() # [B, H, V] + g_s = g[:, 0] # [B, H] + beta_s = beta[:, 0] # [B, H] + + state = self.recurrent_state[:B].float() # [B, H, K, V] + + # Decay state by exp(g) + decay = torch.exp(g_s).unsqueeze(-1).unsqueeze(-1) # [B, H, 1, 1] + state = state * decay + + # Sk = state @ k (project state by key) + Sk = torch.einsum("bhkv,bhk->bhv", state, k_s) + + # Delta rule state update + delta = beta_s.unsqueeze(-1) * (v_s - Sk) # [B, H, V] + state = state + torch.einsum("bhk,bhv->bhkv", k_s, delta) + + # Output = state @ q * scale + output = torch.einsum("bhkv,bhk->bhv", state, q_s) * scale + output = output.unsqueeze(1).to(q.dtype) # [B, 1, H, V] + + with torch.no_grad(): + 
self.recurrent_state[:B].copy_(state.to(self.recurrent_state.dtype)) + else: + # Chunked FLA triton_op for prefill + output, new_state = torch.ops.triton.chunk_gated_delta_rule( + q, k, v, g, beta, self.recurrent_state[:B] + ) + + with torch.no_grad(): + self.recurrent_state[:B].copy_(new_state) # Output: RMSNorm(output) * silu(z) output = output.reshape(-1, self.head_v_dim)