Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions fast_llm/layers/decoder/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,21 @@ class StochasticMixerConfig(MixerConfig):
hint=FieldHint.feature,
)

predefined_layouts: list[list[str]] | None = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list[list[str]] with default_factory=list would be enough

default=None,
desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. "
"Mixer names must match keys in the mixers dict.",
hint=FieldHint.feature,
)

predefined_layout_probability: float = Field(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be a list so each layout has its own probability?

default=0.0,
desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. "
"Must be in [0, 1]. Only used when predefined_layouts is provided.",
hint=FieldHint.feature,
valid=check_field(Assert.in_range_incl, 0.0, 1.0),
)

seed_shift: int = Field(
default=_BIG_PRIMES[11],
desc="Seed shift for mixer sampling reproducibility.",
Expand Down Expand Up @@ -191,6 +206,23 @@ def _validate(self) -> None:
normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))

# Validate predefined layouts
if self.predefined_layouts is not None:
if len(self.predefined_layouts) == 0:
raise ValueError("predefined_layouts must be non-empty if provided")
mixer_names = set(self.mixers.keys())
for i, layout in enumerate(self.predefined_layouts):
unknown = set(layout) - mixer_names
if unknown:
raise ValueError(
f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. "
f"Valid names: {mixer_names}"
)
if self.predefined_layout_probability <= 0:
warnings.warn("predefined_layouts provided but predefined_layout_probability is 0")
elif self.predefined_layout_probability > 0:
raise ValueError("predefined_layout_probability > 0 but predefined_layouts is not provided")

@property
def layer_class(self) -> "type[StochasticMixer]":
from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
Expand Down
30 changes: 26 additions & 4 deletions fast_llm/layers/decoder/stochastic_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from fast_llm.logging import get_model_debug_level
from fast_llm.tensor import TensorMeta
from fast_llm.utils import safe_merge_dicts
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,7 +111,8 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
if not self.training:
return self._config.main_mixer_name

if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
# Layout-based selection (full_layout strategy or predefined layout override)
if StochasticMixerKwargs.layout in kwargs:
layout = kwargs[StochasticMixerKwargs.layout]
counter = kwargs[StochasticMixerKwargs.layout_counter]
idx = counter[0]
Expand Down Expand Up @@ -192,6 +193,21 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch
perm = torch.randperm(num_layers, generator=generator)
return [layout[i] for i in perm.tolist()]

def _sample_predefined_layout(self, num_layers: int, generator: torch.Generator) -> list[str] | None:
"""
With probability `predefined_layout_probability`, pick a predefined layout uniformly.
Returns None if we should use the normal sampling strategy instead.
"""
if not self._config.predefined_layouts or self._config.predefined_layout_probability <= 0:
return None
coin = torch.rand(1, generator=generator).item()
if coin >= self._config.predefined_layout_probability:
return None
idx = torch.randint(len(self._config.predefined_layouts), (1,), generator=generator).item()
layout = list(self._config.predefined_layouts[idx])
Assert.eq(len(layout), num_layers)
return layout

def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
from fast_llm.engine.distributed.config import MAX_SEED
from fast_llm.layers.block.config import BlockKwargs
Expand All @@ -202,8 +218,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
generator.manual_seed(seed)
kwargs[StochasticMixerKwargs.generator] = generator

if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
predefined = self._sample_predefined_layout(num_layers, generator)

if predefined is not None:
# Use predefined layout (overrides any sampling strategy)
kwargs[StochasticMixerKwargs.layout] = predefined
kwargs[StochasticMixerKwargs.layout_counter] = [0]
elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
counts = self._sample_allocation(num_layers, generator)
layout = self._sample_placement(counts, num_layers, generator)
kwargs[StochasticMixerKwargs.layout] = layout
Expand Down
Loading