Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions backends/aoti/aoti_delegate_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Opaque types for AOTI constant management.
// AtenTensorOpaque wraps at::Tensor* in the AOTI runtime — distinct from
// AOTITensorHandle which wraps executorch::runtime::etensor::Tensor*.
// Handles of this type are opaque on this side of the boundary: never
// dereference them, only pass them back into AOTI runtime entry points.
struct AtenTensorOpaque;
using AtenTensorHandle = AtenTensorOpaque*;

// Opaque handle to a container constants map. Callers pass a pointer to a
// std::unordered_map<std::string, AtenTensorHandle> (see
// AOTInductorModelContainerExtractConstantsMapFunc) cast to this type.
// NOTE(review): that layout is only ABI-safe when caller and .so share the
// same C++ standard library — confirm for the supported toolchains.
struct AOTInductorConstantMap;
using AOTInductorConstantMapHandle = AOTInductorConstantMap*;

// One (name, tensor) pair in a flat, C-ABI-friendly constant list, consumed
// by AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc.
// `name` is the AOTI codegen constant name (underscored), not the original
// FQN; the pointed-to string must stay alive for the duration of the call.
struct AOTInductorConstantMapEntry {
const char* name;
AtenTensorHandle handle;
};

// Function pointer types for AOT Inductor model container operations
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle* container_handle,
Expand Down Expand Up @@ -77,6 +91,37 @@ using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr);

// Retrieves a constant's AOTI internal (codegen) name by index.
// `idx` ranges over [0, num_constants) as reported by
// AOTInductorModelContainerGetNumConstantsFunc. On success `*name` receives
// a C string — presumably owned by the container (callers never free it);
// TODO confirm lifetime against the AOTI runtime implementation.
using AOTInductorModelContainerGetConstantNameFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t idx,
const char** name);

// Retrieves a constant's original fully-qualified name by index — the
// eager-model parameter/buffer FQN, as opposed to the codegen-internal name
// returned by AOTInductorModelContainerGetConstantNameFunc.
// NOTE(review): the FQN may be an empty string; observed callers skip
// entries whose FQN is empty before keying maps on it.
using AOTInductorModelContainerGetConstantOriginalFQNFunc =
AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t idx,
const char** original_fqn);

// Extracts the constants map from the container — the active buffer when
// use_inactive is false, the inactive one otherwise.
// constant_map_handle should point to a
// std::unordered_map<std::string, AtenTensorHandle> that the callee fills
// in; the stored handles remain owned by the container.
// NOTE(review): unlike the flat-array pairs API below, passing a
// std::unordered_map through this boundary is NOT DLL-boundary safe — it
// requires both sides to be built against the same STL/ABI. Confirm this
// holds for every supported torch/.so combination.
using AOTInductorModelContainerExtractConstantsMapFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle,
bool use_inactive);

// Updates the container's constants with user-managed tensor handles.
// DLL-boundary safe — uses a flat C array instead of std::unordered_map.
// "User-managed" means the container does not copy the tensors or take
// ownership: the caller (e.g. the source container the handles were
// extracted from) must keep them alive as long as this container may run.
// Entry names are matched against the codegen constant name, not the
// original FQN. NOTE(review): presumably validate_full_update=true makes
// the call fail unless every constant in the container is covered by
// `pairs` — TODO confirm against the AOTI runtime.
using AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc =
AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
const AOTInductorConstantMapEntry* pairs,
size_t num_pairs,
bool use_inactive,
bool validate_full_update);

} // extern "C"

// AOTI Delegate Handle structure
Expand All @@ -93,6 +138,14 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
AOTInductorModelContainerRunFunc run;
AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;

// Constant management function pointers (for cross-method buffer sharing)
AOTInductorModelContainerGetNumConstantsFunc get_num_constants;
AOTInductorModelContainerGetConstantNameFunc get_constant_name;
AOTInductorModelContainerGetConstantOriginalFQNFunc get_constant_original_fqn;
AOTInductorModelContainerExtractConstantsMapFunc extract_constants_map;
AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc
update_user_managed_constant_buffer_pairs;
};

} // namespace aoti
Expand Down
139 changes: 139 additions & 0 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,30 @@ class ET_EXPERIMENTAL CudaBackend final
Info,
"Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)");
}

// Load constant management symbols (optional — needed for cross-method
// buffer sharing). These are available in torch >= 2.6.
#define LOAD_OPTIONAL_SYMBOL(member, name) \
do { \
auto res = get_function(so_handle, #name); \
handle->member = \
res.ok() ? reinterpret_cast<name##Func>(res.get()) : nullptr; \
} while (0)

LOAD_OPTIONAL_SYMBOL(
get_num_constants, AOTInductorModelContainerGetNumConstants);
LOAD_OPTIONAL_SYMBOL(
get_constant_name, AOTInductorModelContainerGetConstantName);
LOAD_OPTIONAL_SYMBOL(
get_constant_original_fqn,
AOTInductorModelContainerGetConstantOriginalFQN);
LOAD_OPTIONAL_SYMBOL(
extract_constants_map, AOTInductorModelContainerExtractConstantsMap);
LOAD_OPTIONAL_SYMBOL(
update_user_managed_constant_buffer_pairs,
AOTInductorModelContainerUpdateUserManagedConstantBufferPairs);
#undef LOAD_OPTIONAL_SYMBOL

return Error::Ok;
}

Expand Down Expand Up @@ -378,6 +402,105 @@ class ET_EXPERIMENTAL CudaBackend final
method_name.c_str());
}

// ---------------------------------------------------------------
// Cross-method constant sharing (e.g., KV cache between prefill/decode).
//
// The first container to initialize extracts its constants (keyed by
// original FQN) and stores the AtenTensorHandle's. Subsequent containers
// with matching FQNs are updated to point to the same GPU tensors via
// UpdateUserManagedConstantBufferPairs (user_managed = true → no copy,
// the source container retains ownership).
// ---------------------------------------------------------------
if (handle->get_num_constants && handle->get_constant_name &&
handle->get_constant_original_fqn && handle->extract_constants_map &&
handle->update_user_managed_constant_buffer_pairs) {
size_t num_constants = 0;
handle->get_num_constants(handle->container_handle, &num_constants);

if (num_constants > 0) {
// Build FQN → internal_name mapping for this container.
std::unordered_map<std::string, std::string> fqn_to_name;
for (size_t i = 0; i < num_constants; i++) {
const char* name = nullptr;
const char* fqn = nullptr;
handle->get_constant_name(handle->container_handle, i, &name);
handle->get_constant_original_fqn(handle->container_handle, i, &fqn);
if (name && fqn && fqn[0] != '\0') {
fqn_to_name[fqn] = name;
}
}

std::lock_guard<std::mutex> guard(shared_constants_mutex_);

if (!constants_extracted_) {
// First container: extract its constants and store by FQN.
std::unordered_map<std::string, AtenTensorHandle> extracted_map;
auto extract_err = handle->extract_constants_map(
handle->container_handle,
reinterpret_cast<AOTInductorConstantMapHandle>(&extracted_map),
/*use_inactive=*/false);

if (extract_err == Error::Ok) {
for (const auto& [fqn, internal_name] : fqn_to_name) {
auto it = extracted_map.find(fqn);
if (it != extracted_map.end()) {
shared_constant_tensors_[fqn] = it->second;
}
}
constants_extracted_ = true;
ET_LOG(
Info,
"Extracted %zu shared constants from method '%s'",
shared_constant_tensors_.size(),
method_name.c_str());
} else {
ET_LOG(
Error,
"Failed to extract constants from '%s'",
method_name.c_str());
}
} else {
// Subsequent container: share matching constants from the first.
std::vector<AOTInductorConstantMapEntry> pairs;
for (const auto& [fqn, internal_name] : fqn_to_name) {
auto it = shared_constant_tensors_.find(fqn);
if (it != shared_constant_tensors_.end()) {
// UpdateUserManagedConstantBufferPairs matches against the
// codegen constant name (underscored), not the original FQN.
pairs.push_back({internal_name.c_str(), it->second});
}
}

if (!pairs.empty()) {
auto update_err = handle->update_user_managed_constant_buffer_pairs(
handle->container_handle,
pairs.data(),
pairs.size(),
/*use_inactive=*/false,
/*validate_full_update=*/false);

if (update_err == Error::Ok) {
ET_LOG(
Info,
"Shared %zu constants into method '%s'",
pairs.size(),
method_name.c_str());
} else {
ET_LOG(
Error,
"Failed to share constants into '%s'",
method_name.c_str());
}
}
}
}
} else {
ET_LOG(
Info,
"Constant sharing APIs not available for method '%s'",
method_name.c_str());
}

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand Down Expand Up @@ -623,6 +746,22 @@ class ET_EXPERIMENTAL CudaBackend final
mutable std::
unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
cached_outputs_;

// Cross-method constant sharing state.
// When multiple AOTI containers share mutable buffers (e.g., KV cache),
// the first container's constants are extracted and stored here. Subsequent
// containers with matching FQNs share the same GPU tensors via
// UpdateUserManagedConstantBufferPairs.
mutable std::mutex shared_constants_mutex_;

// FQN → AtenTensorHandle from the source (first) container.
// The tensor handles are owned by the source container (which is never
// explicitly deleted — see destroy() comment).
mutable std::unordered_map<std::string, AtenTensorHandle>
shared_constant_tensors_;

// Whether we've already extracted constants from a source container.
mutable bool constants_extracted_ = false;
};

} // namespace executorch::backends::cuda
Expand Down
83 changes: 63 additions & 20 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
# Any missing weight key indicates a version mismatch between the
# checkpoint and the model (e.g., unfused vs fused projections).
runtime_prefixes = (
".mask",
".inv_freq",
".kv_cache.",
".conv_state",
".recurrent_state",
Expand Down Expand Up @@ -312,10 +314,11 @@ def _materialize_buffers(model, config):
"""Materialize meta-device buffers before torch.export.

Replaces meta buffers with real tensors on CPU, recomputes RoPE
inv_freq and causal masks.
inv_freq and causal masks. State buffers (KV cache, conv/recurrent
state) are zero-initialized registered buffers that will be shared
across methods via share_mutable_buffers.
"""
# State buffers (KV cache, conv/recurrent state) are bf16 to match
# compute dtype. Masks stay bool, inv_freq stays float32.
# Masks stay bool, inv_freq stays float32.
for fqn, buf in list(model.named_buffers()):
if buf.device.type == "meta":
dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
Expand Down Expand Up @@ -378,7 +381,18 @@ def _apply_turboquant(model, config):


def export_and_lower(model, config, args):
"""Export model to .pte via torch.export + CUDA backend."""
"""Export model to .pte via torch.export + CUDA backend.

Exports two methods:
- "decode": decode path (T=1), uses native PyTorch recurrent FLA
so AOTI can fuse with surrounding ops for maximum decode throughput.
- "prefill": prefill path (T>=2), uses chunked FLA triton_op with
dynamic sequence length.

Both methods share mutable state buffers (KV cache, conv_state,
recurrent_state) via share_mutable_buffers=True. The model uses
registered buffers with in-place updates — no state in/out args.
"""
import torch._inductor.config as inductor_config

from executorch.backends.cuda.cuda_backend import CudaBackend
Expand All @@ -398,25 +412,39 @@ def export_and_lower(model, config, args):
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"

# Dynamic shapes
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})

print("Exporting with torch.export...")
# --- Decode method (T=1, static shape) ---
print("Exporting decode method...")
decode_tokens = torch.tensor([[0]], dtype=torch.long)
decode_pos = torch.tensor([0], dtype=torch.long)
with torch.no_grad():
decode_ep = export(
model,
(decode_tokens, decode_pos),
strict=True,
)
print("Decode export successful!")

# --- Prefill method (T>=2, dynamic shape) ---
print("Exporting prefill method...")
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
prefill_dynamic_shapes = (
{1: seq_dim}, # tokens
{0: seq_dim}, # input_pos
)
with torch.no_grad():
exported = export(
prefill_ep = export(
model,
(example_tokens, example_input_pos),
dynamic_shapes=dynamic_shapes,
(prefill_tokens, prefill_pos),
dynamic_shapes=prefill_dynamic_shapes,
strict=True,
)
print("Export successful!")
print("Prefill export successful!")

# Lower with CUDA backend
# Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
print("Lowering to ExecuTorch with CUDA...")
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

metadata = {
"get_max_seq_len": config.max_seq_len,
"get_vocab_size": config.vocab_size,
Expand All @@ -426,8 +454,19 @@ def export_and_lower(model, config, args):
"enable_dynamic_shape": True,
}
et_prog = to_edge_transform_and_lower(
exported,
partitioner=[CudaPartitioner(compile_specs)],
{"decode": decode_ep, "prefill": prefill_ep},
partitioner={
"decode": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("decode")]
)
],
"prefill": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("prefill")]
)
],
},
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
Expand All @@ -438,7 +477,11 @@ def export_and_lower(model, config, args):
config=ExecutorchBackendConfig(
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
memory_planning_pass=MemoryPlanningPass(
alloc_graph_input=False,
share_mutable_buffers=True,
),
emit_mutable_buffer_names=True,
),
)

Expand Down
5 changes: 2 additions & 3 deletions examples/models/qwen3_5_moe/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ def generate(
):
"""Generate text autoregressively with KV cache.

Prefills one token at a time (the chunk_gated_delta_rule kernel's chunked
path has numerical issues with T>1 in eager mode; token-by-token uses the
stable recurrent path).
Prefills one token at a time (the recurrent path; chunked FLA via
@triton_op is used for T>1 prefill in the exported PTE).
"""
if eos_token_ids is None:
eos_token_ids = set()
Expand Down
Loading
Loading