diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py
index 4cab2d39b..390a8b05f 100644
--- a/fast_llm/layers/decoder/config.py
+++ b/fast_llm/layers/decoder/config.py
@@ -157,6 +157,21 @@ class StochasticMixerConfig(MixerConfig):
         hint=FieldHint.feature,
     )
 
+    predefined_layouts: list[list[str]] | None = Field(
+        default=None,
+        desc="List of predefined layouts to oversample. Each layout is a list of mixer names, one per layer. "
+        "Mixer names must match keys in the mixers dict.",
+        hint=FieldHint.feature,
+    )
+
+    predefined_layout_probability: float = Field(
+        default=0.0,
+        desc="Probability of sampling from predefined_layouts instead of using the sampling_strategy. "
+        "Must be in [0, 1]. Only used when predefined_layouts is provided.",
+        hint=FieldHint.feature,
+        valid=check_field(Assert.in_range_incl, 0.0, 1.0),
+    )
+
     seed_shift: int = Field(
         default=_BIG_PRIMES[11],
         desc="Seed shift for mixer sampling reproducibility.",
@@ -191,6 +206,23 @@ def _validate(self) -> None:
             normalized_values = normalize_probabilities(list(self.sampling_weights.values()))
             self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values))
 
+        # Validate predefined layouts
+        if self.predefined_layouts is not None:
+            if len(self.predefined_layouts) == 0:
+                raise ValueError("predefined_layouts must be non-empty if provided")
+            mixer_names = set(self.mixers.keys())
+            for i, layout in enumerate(self.predefined_layouts):
+                unknown = set(layout) - mixer_names
+                if unknown:
+                    raise ValueError(
+                        f"predefined_layouts[{i}] contains unknown mixer names: {unknown}. "
+                        f"Valid names: {mixer_names}"
+                    )
+            if self.predefined_layout_probability <= 0:
+                warnings.warn("predefined_layouts provided but predefined_layout_probability is 0")
+        elif self.predefined_layout_probability > 0:
+            raise ValueError("predefined_layout_probability > 0 but predefined_layouts is not provided")
+
     @property
     def layer_class(self) -> "type[StochasticMixer]":
         from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer
diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py
index 9def3895c..814b0592f 100644
--- a/fast_llm/layers/decoder/stochastic_mixer.py
+++ b/fast_llm/layers/decoder/stochastic_mixer.py
@@ -16,7 +16,7 @@
 )
 from fast_llm.logging import get_model_debug_level
 from fast_llm.tensor import TensorMeta
-from fast_llm.utils import safe_merge_dicts
+from fast_llm.utils import Assert, safe_merge_dicts
 
 logger = logging.getLogger(__name__)
 
@@ -111,7 +111,8 @@ def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str:
         if not self.training:
             return self._config.main_mixer_name
 
-        if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
+        # Layout-based selection (full_layout strategy or predefined layout override)
+        if StochasticMixerKwargs.layout in kwargs:
             layout = kwargs[StochasticMixerKwargs.layout]
             counter = kwargs[StochasticMixerKwargs.layout_counter]
             idx = counter[0]
@@ -192,6 +193,21 @@ def _sample_placement(self, counts: list[int], num_layers: int, generator: torch
         perm = torch.randperm(num_layers, generator=generator)
         return [layout[i] for i in perm.tolist()]
 
+    def _sample_predefined_layout(self, num_layers: int, generator: torch.Generator) -> list[str] | None:
+        """
+        With probability `predefined_layout_probability`, pick a predefined layout uniformly.
+        Returns None if we should use the normal sampling strategy instead.
+        """
+        if not self._config.predefined_layouts or self._config.predefined_layout_probability <= 0:
+            return None
+        coin = torch.rand(1, generator=generator).item()
+        if coin >= self._config.predefined_layout_probability:
+            return None
+        idx = torch.randint(len(self._config.predefined_layouts), (1,), generator=generator).item()
+        layout = list(self._config.predefined_layouts[idx])
+        Assert.eq(len(layout), num_layers)
+        return layout
+
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         from fast_llm.engine.distributed.config import MAX_SEED
         from fast_llm.layers.block.config import BlockKwargs
@@ -202,8 +218,14 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         generator.manual_seed(seed)
         kwargs[StochasticMixerKwargs.generator] = generator
 
-        if self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
-            num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
+        num_layers = kwargs[BlockKwargs.num_blocks_in_sequence]
+        predefined = self._sample_predefined_layout(num_layers, generator)
+
+        if predefined is not None:
+            # Use predefined layout (overrides any sampling strategy)
+            kwargs[StochasticMixerKwargs.layout] = predefined
+            kwargs[StochasticMixerKwargs.layout_counter] = [0]
+        elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.full_layout:
             counts = self._sample_allocation(num_layers, generator)
             layout = self._sample_placement(counts, num_layers, generator)
             kwargs[StochasticMixerKwargs.layout] = layout