26 changes: 26 additions & 0 deletions docs/source/speechlm2/configs.rst
@@ -229,6 +229,32 @@ Defaults come from Automodel's ``BackendConfig`` and auto-select TransformerEngi
DeepEP when available; override here to pin a specific backend (for example,
``attn: sdpa`` to bypass TE).

**Packed sequences (THD):**

.. code-block:: yaml

    model:
      packed_sequences: true  # default false (right-padded BSHD path)
      automodel_backend:
        attn: te  # THD path dispatches TE varlen FlashAttention

When ``packed_sequences`` is true, ``SALMAutomodel.prepare_inputs`` packs
each minibatch into a single flat ``[T_total, H]`` sequence with a
``cu_seqlens`` index instead of right-padding to ``[B, T_max, H]``.
``SALMAutomodel`` then forwards the THD metadata (``qkv_format``,
``cu_seqlens``, ``position_ids``, ``max_seqlen``) through ``forward()`` to
the LLM. The TE attention preprocessor splits the singular ``max_seqlen``
into the ``max_seqlen_q`` / ``max_seqlen_kv`` pair that
``DotProductAttention`` requires for ``qkv_format="thd"``. The packing also
rounds each utterance's flat length up to a multiple of ``2 * cp_size`` so
the same THD batch satisfies TE's CP DualChunkSwap contract — see the
"Context Parallelism (CP)" subsection in
:doc:`training_and_scaling` for the recommended pairing with ``cp_size > 1``.

Padding overhead drops from ``B * (T_max - T_avg)`` token positions per
minibatch to at most ``2*cp_size - 1`` positions per utterance (the rounding
up to a multiple of ``2*cp_size``). The throughput gain grows with the
variance of utterance lengths within your buckets.
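
As a rough illustration of the layout described above, the following sketch computes the padded per-utterance lengths and the resulting ``cu_seqlens`` offsets. The helper names ``round_up`` and ``build_cu_seqlens`` are hypothetical and not part of NeMo; the actual packing is done by ``prepare_packed_llm_inputs`` and may differ in detail.

```python
from itertools import accumulate


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple (hypothetical helper)."""
    return -(-n // multiple) * multiple


def build_cu_seqlens(lengths: list[int], cp_size: int = 1) -> list[int]:
    """Cumulative start offsets of each padded utterance in the flat buffer.

    Assumption: every utterance is padded to a multiple of 2 * cp_size,
    as the text above describes.
    """
    multiple = 2 * cp_size
    padded = [round_up(n, multiple) for n in lengths]
    return [0, *accumulate(padded)]


# Three utterances of 5, 9 and 3 embedding frames, packed for cp_size=2.
cu_seqlens = build_cu_seqlens([5, 9, 3], cp_size=2)
t_total = cu_seqlens[-1]  # number of rows in the packed [T_total, H] tensor
max_seqlen = max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))
```

Here ``cu_seqlens`` has one more entry than there are utterances, and consecutive entries delimit one attention segment each.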

DuplexS2SModel Configuration
-----------------------------

79 changes: 77 additions & 2 deletions docs/source/speechlm2/training_and_scaling.rst
@@ -183,8 +183,83 @@ For distributed inference, launch with ``torchrun``:
inputs=path/to/manifest \
ep_size=2

Configuration
^^^^^^^^^^^^^
Packed Sequences (THD)
""""""""""""""""""""""

``SALMAutomodel`` supports an opt-in packed-sequence (``THD``) training and
validation path that concatenates per-utterance text + audio embeddings into
a single flat ``[T_total, H]`` sequence with a ``cu_seqlens`` index, instead
of right-padding into the standard ``[B, T_max, H]`` (``BSHD``) layout. TE's
varlen FlashAttention then operates segment-by-segment without ever attending
across utterances, and Mamba's ``seq_idx`` is derived from the same
``cu_seqlens`` so SSM state resets at document boundaries.
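
The relationship between ``cu_seqlens``, the per-token ``seq_idx``, and segment-local causal attention can be sketched in plain Python. This assumes ``seq_idx`` simply labels each packed token with the index of its segment; the function names are illustrative and do not correspond to NeMo or TE APIs.

```python
def seq_idx_from_cu_seqlens(cu_seqlens: list[int]) -> list[int]:
    """Label every packed token with the index of the segment it belongs to."""
    seq_idx: list[int] = []
    for segment, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        seq_idx.extend([segment] * (end - start))
    return seq_idx


def can_attend(query: int, key: int, seq_idx: list[int]) -> bool:
    """Causal attention restricted to tokens of the same segment."""
    return seq_idx[query] == seq_idx[key] and key <= query


# Two utterances of 3 and 2 tokens packed back to back.
idx = seq_idx_from_cu_seqlens([0, 3, 5])
```

With this labelling, token 3 (the first token of the second utterance) cannot attend to token 2 (the last token of the first), and a change in ``seq_idx`` marks where recurrent state would be reset.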

For variable-length speech batches the padding overhead is substantial: the
``BSHD`` layout spends ``B * (T_max - T_avg)`` token positions of wasted
compute per minibatch, whereas ``THD`` pays only for rounding each utterance
up to a multiple of ``2*cp_size`` (needed for TE's CP DualChunkSwap pattern).
The throughput gain grows with the variance of utterance lengths.
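
A small worked example makes the difference concrete. The utterance lengths below are made up for illustration, and both helpers are hypothetical.

```python
def bshd_padding(lengths: list[int]) -> int:
    """Padded positions when every utterance is padded to T_max."""
    return len(lengths) * max(lengths) - sum(lengths)


def thd_padding(lengths: list[int], cp_size: int = 1) -> int:
    """Padded positions when each utterance is rounded up to 2 * cp_size."""
    multiple = 2 * cp_size
    return sum(-n % multiple for n in lengths)


lengths = [50, 120, 400, 80]  # illustrative frame counts, not measured data
bshd_waste = bshd_padding(lengths)            # 4 * 400 - 650
thd_waste = thd_padding(lengths, cp_size=2)   # only 50 -> 52 needs padding
```

For this batch the padded layout wastes 950 of 1600 positions, while the packed layout wastes 2 of 652.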

Enable it in the model config:

.. code-block:: yaml

    model:
      packed_sequences: true  # opt-in; default false (BSHD)
      automodel_backend:
        attn: te  # THD path requires TE attention

When ``packed_sequences`` is unset or false, the existing BSHD path is used
unchanged. Generation and inference always use BSHD because they do not go
through ``prepare_inputs``.

Context Parallelism (CP)
""""""""""""""""""""""""

``SALMAutomodel`` supports context parallelism for long-audio training on
hybrid Mamba/attention LLMs (e.g. Nemotron-V3). CP shards the sequence
dimension across GPUs so per-rank activations and KV-cache memory scale as
``T / cp_size`` instead of ``T``; attention layers go through TE's
DualChunkSwap pattern and Mamba mixers go through hidden-parallel
all-to-all (``MambaContextParallel`` in NeMo Automodel).

Enable via the strategy:

.. code-block:: yaml

    trainer:
      strategy:
        _target_: nemo.collections.speechlm2.parts.parallel.AutomodelParallelStrategy
        cp_size: 2  # context parallel size; must divide num_heads of every Mamba block
        ep_size: 2  # may share the same ranks as CP

**The THD packed-sequence path is the only supported configuration under
CP.** Each utterance is its own attention segment and the per-utterance
sequence rounding aligns naturally with CP's ``2*cp_size`` requirement.
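
The ``2*cp_size`` requirement is easiest to see from how load-balanced causal CP is commonly laid out: each segment is cut into ``2*cp_size`` equal chunks, and rank ``r`` holds chunk ``r`` together with its mirror chunk ``2*cp_size - 1 - r``. The sketch below illustrates that assignment under this assumption; ``cp_token_indices`` is a hypothetical helper, not a TE or NeMo API.

```python
def cp_token_indices(seg_len: int, cp_size: int, rank: int) -> list[int]:
    """Token positions of one segment held by a CP rank.

    Assumption: the segment is split into 2 * cp_size chunks and each rank
    takes an early chunk plus the mirrored late chunk, which balances the
    causal-attention work across ranks.
    """
    num_chunks = 2 * cp_size
    if seg_len % num_chunks != 0:
        raise ValueError(f"segment length {seg_len} is not a multiple of {num_chunks}")
    chunk = seg_len // num_chunks
    first, mirror = rank, num_chunks - 1 - rank
    return [
        *range(first * chunk, (first + 1) * chunk),
        *range(mirror * chunk, (mirror + 1) * chunk),
    ]


rank0 = cp_token_indices(seg_len=8, cp_size=2, rank=0)
rank1 = cp_token_indices(seg_len=8, cp_size=2, rank=1)
```

A segment whose length is not a multiple of ``2*cp_size`` cannot be split this way, which is why the packed path rounds every utterance up first.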

.. warning::
    **BSHD + CP is not supported.** TE's fused-attention CP path supports
    ``causal`` but not ``padding_causal``, so the right-pad mask must be
    dropped before the LLM. With the mask dropped, pad K/V leak into
    real-token attention through the causal mask and the gradient through
    the LoRA / projection parameters becomes ``NaN`` after the first
    optimizer step (validated empirically: BSHD + CP=2 + EP=2 on a 2-GPU
    run produces ``loss=4.62`` at step 1 then ``loss=nan`` from step 2
    onwards). This is independent of the TE/cuDNN backward issue
    documented below; setting ``NVTE_FUSED_ATTN=0`` does not fix it.
    Set ``model.packed_sequences: true`` to use the THD path instead.
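
The two constraints stated in this section (CP needs the THD path, and the THD path needs TE attention) could be checked up front along the following lines. This is a simplified, hypothetical sketch; ``check_cp_config`` is not the library's validator, and the real checks in NeMo may take more inputs and apply different rules.

```python
def check_cp_config(packed_sequences: bool, cp_size: int, attn_backend: str) -> None:
    """Raise ValueError for combinations this section describes as unsupported.

    Hypothetical helper written for illustration only.
    """
    if cp_size > 1 and not packed_sequences:
        raise ValueError(
            "BSHD + CP is not supported; set model.packed_sequences: true"
        )
    if packed_sequences and attn_backend != "te":
        raise ValueError(
            "the THD path requires TE attention; set automodel_backend.attn: te"
        )


check_cp_config(packed_sequences=True, cp_size=2, attn_backend="te")     # accepted
check_cp_config(packed_sequences=False, cp_size=1, attn_backend="sdpa")  # accepted
```

Failing fast at configuration time is preferable to discovering the problem as ``loss=nan`` a few steps into training.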

.. note::
    **TE/THD exploding-gradients workaround on some GPUs.** On certain GPU
    architectures (notably Blackwell ``sm_120``), the cuDNN backend that
    TransformerEngine 2.14 picks for ``qkv_format="thd"`` with
    ``attn_mask_type="padding_causal"`` returns correct forward activations
    but gradients amplified 8×–960× per layer. Compounded across the LLM's
    attention stack, this drives gradients to magnitudes around ``1e22`` at
    step 0, gradient clipping by norm computes ``1.0 / inf = 0``, and Adam's
    moments eventually become NaN. Force TE to dispatch FlashAttention
    instead of cuDNN by setting ``NVTE_FUSED_ATTN=0`` in the launcher
    environment (requires ``flash-attn`` to be installed for your GPU arch).
    The FlashAttention THD/``padding_causal`` backward is gradient-correct
    on the same shapes.
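
To see why an infinite gradient norm zeroes the update, the sketch below mirrors the usual clip-by-norm scaling factor, ``min(1, max_norm / (total_norm + eps))``. ``clip_coefficient`` is an illustrative helper and not the trainer's implementation; the ``eps`` value is an assumption.

```python
import math


def clip_coefficient(total_norm: float, max_norm: float = 1.0, eps: float = 1e-6) -> float:
    """Factor every gradient is multiplied by under clip-by-norm."""
    return min(1.0, max_norm / (total_norm + eps))


healthy = clip_coefficient(0.5)        # norm below the threshold: no scaling
clipped = clip_coefficient(4.0)        # norm above the threshold: scaled down
exploded = clip_coefficient(math.inf)  # overflowed norm: gradients scaled to zero
```

Squaring values around ``1e22`` exceeds the float32 range (about ``3.4e38``), so a float32 norm over such gradients can overflow to ``inf``, which is the situation the last line models.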

To configure parallelism, modify the ``trainer.strategy`` section in your YAML config:

138 changes: 117 additions & 21 deletions nemo/collections/speechlm2/models/salm_automodel.py
@@ -149,21 +149,29 @@
input_embeds: Tensor,
attention_mask: Tensor = None,
cache=None,
**llm_kwargs,
) -> dict[str, Tensor]:
"""
Implements a fully offline forward pass through the entire model.
The flow is the following:

|speech and text embeddings| -> |llm| -> |lm_head| -> |token ids|

``llm_kwargs`` carries optional THD/packed-sequence metadata
(``qkv_format``, ``cu_seqlens``, ``position_ids``, ``max_seqlen``);
it is empty for the BSHD path.
"""
# input_embeds and out: (B, T, H)
# input_embeds: (B, T, H) for BSHD or (T_total, H) for THD packed
# (the THD shape mirrors Automodel's _shard_thd_chunk_for_te output —
# the model squeezes 3D inputs internally when qkv_format=="thd", so
# passing 2D directly skips that hop)
out = self.llm(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
past_key_values=cache,
use_cache=cache is not None,
return_dict=True,
**llm_kwargs,
)
if not isinstance(out, dict):
# NeMo Automodel doesn't respect return_dict=True yet
@@ -186,62 +194,139 @@
* Take care of any necessary slicing to align the shapes of source audio,
target audio, and target token ids.
"""
# Source audio encoding.
# Input audio: (B, T_samples)
# Audio embeddings: (B, T, H)
audio_embs = encode_audio_with_optional_chunking(
from nemo.collections.speechlm2.parts.cp_helpers import (
encode_audio_with_cp_distribution,
get_cp_mesh,
shard_bshd_for_cp,
)

cp_mesh, cp_size, _ = get_cp_mesh(getattr(self, "_device_mesh", None))

# Source audio encoding (distributed across CP ranks when CP is active).
# Input audio: (B_aud, T_samples) → list of (L_i, H) embeddings.
audio_embs = encode_audio_with_cp_distribution(
self.perception,
batch["audios"],
batch["audio_lens"],
chunk_size_seconds=self.cfg.get("encoder_chunk_size_seconds", None),
sampling_rate=self.sampling_rate,
cp_mesh=cp_mesh,
)
input_ids_to_embed = torch.where(batch["input_ids"] == self.audio_locator_tag_id, 0, batch["input_ids"])
text_embs = self._embed_tokens(input_ids_to_embed)
target_ids_full = batch["input_ids"].where(batch["loss_mask"], -100) # CrossEntropyLoss().ignore_index

# Packed-sequence (THD) path — used for both training and validation when enabled.
# Generate stays on the BSHD path (it doesn't go through prepare_inputs).
if self.cfg.get("packed_sequences", False):
from nemo.collections.speechlm2.parts.packed_sequences import prepare_packed_llm_inputs

return prepare_packed_llm_inputs(
input_ids=batch["input_ids"],
text_embs=text_embs,
audio_embs=audio_embs,
target_ids=target_ids_full,
padding_id=self.text_pad_id,
placeholder_id=self.audio_locator_tag_id,
device_mesh=getattr(self, "_device_mesh", None),
)

input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets(
input_ids=batch["input_ids"],
embeds=text_embs,
padding_id=self.text_pad_id,
placeholder_id=self.audio_locator_tag_id,
replacements=audio_embs,
target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index
target_ids=target_ids_full,
)
input_embs = input_embs[:, :-1]
attention_mask = attention_mask[:, :-1]
target_ids = target_ids[:, 1:]

# Combine target audio and text into a single tensor to slice them together.
# It will also help us truncate the sequence lengths to be divisible by TP world size,
# when TP is enabled.
# Input ids: (B, T, K+1)
if self._use_tp:
tp_world_size = self.device_mesh["tp"].size()
if (remainder := (input_embs.shape[1] - 1) % tp_world_size) != 0:
# Sequence-length divisibility for sequence/context parallelism.
# CP path: pad to 2*cp_size*tp_size and partition along the seq dim
# (the existing TP truncation is folded into the CP padding). BSHD-only
# path keeps the original TP-truncation behavior.
tp_size = self.device_mesh["tp"].size() if self._use_tp else 1
if cp_size > 1:
sharded = shard_bshd_for_cp(input_embs, attention_mask, target_ids, cp_mesh, tp_size=tp_size)
input_embs = sharded["input_embs"]
attention_mask = sharded["attention_mask"]

target_ids = sharded["target_ids"]
elif self._use_tp:
if (remainder := (input_embs.shape[1] - 1) % tp_size) != 0:
# Truncate some tokens from the end to make the sequence length divisible by the tensor parallelism
# world size. Otherwise, sequence parallelism will change the input shape, leading to mismatches.
input_embs = input_embs[:, :-remainder]
attention_mask = attention_mask[:, :-remainder]
target_ids = target_ids[:, :-remainder]

# TE's fused-attention CP path rejects ``padding_causal``; only ``causal``
# is supported. BSHD batches are left-padded so dropping the padding mask
# lets pad K/V leak into real-token attention — empirically this drives
# the loss to NaN at step 2 (the gradient through the LoRA / projection
# parameters is corrupted by the leak after one optimizer step). BSHD +
# CP is therefore not a supported configuration; set
# ``model.packed_sequences: true`` to use the THD path under CP, which
# uses cu_seqlens-aware attention and has no equivalent issue.
llm_attention_mask = None if cp_size > 1 else attention_mask

return {
"input_embeds": input_embs,
"attention_mask": attention_mask,
"attention_mask": llm_attention_mask,
"target_ids": target_ids,
"llm_kwargs": {},
}

def on_fit_start(self) -> None:
"""Configure the MoE aux-loss backward scaler to cancel FSDP's gradient
averaging (see ``_configure_moe_aux_loss_scaler``)."""
self._validate_parallelism_compatibility()
self._configure_moe_aux_loss_scaler()

def _validate_parallelism_compatibility(self) -> None:
"""Raise on known-incompatible THD/CP/backend configurations.

Delegates to :func:`nemo.collections.speechlm2.parts.parallel.validate_parallelism_compatibility`
with the runtime-derived values from this model's config and device mesh.
"""
import os

from nemo.collections.speechlm2.parts.parallel import validate_parallelism_compatibility

cp_size = 1
device_mesh = getattr(self, "_device_mesh", None)
if device_mesh is not None:
names = device_mesh.mesh_dim_names or ()
if "cp" in names:
cp_size = device_mesh["cp"].size()

attn_backend = self.cfg.get("automodel_backend", {}).get("attn", "te")
nvte_fused_attn = os.environ.get("NVTE_FUSED_ATTN")
device_capability = (
torch.cuda.get_device_capability() if torch.cuda.is_available() else None
)

validate_parallelism_compatibility(
packed_sequences=bool(self.cfg.get("packed_sequences", False)),
cp_size=cp_size,
attn_backend=attn_backend,
nvte_fused_attn=nvte_fused_attn,
device_capability=device_capability,
)

def training_step(self, batch: dict, batch_idx: int):
self._current_batch_idx = batch_idx
for m in (self.perception.preprocessor, self.perception.encoder, self.llm):
if is_frozen(m):
m.eval()

inputs = self.prepare_inputs(batch)
forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"])
forward_outputs = self(
inputs["input_embeds"],
attention_mask=inputs["attention_mask"],
**inputs.get("llm_kwargs", {}),
)
num_frames = (inputs["target_ids"] != -100).long().sum()

# Match Automodel's training recipe: normalize CE by the *global* token count across
@@ -260,9 +345,10 @@
num_frames_global = num_frames_global.clamp(min=1)

with loss_parallel():
logits = forward_outputs["logits"]
loss_sum = torch.nn.functional.cross_entropy(
forward_outputs["logits"].flatten(0, 1), # (B, T, Vt) -> (*, Vt)
inputs["target_ids"].flatten(0, 1),
logits.reshape(-1, logits.size(-1)), # BSHD (B,T,V) or THD (1,T,V) -> (*, V)
inputs["target_ids"].reshape(-1), # BSHD (B,T) or THD (T,) -> (*,)
reduction="sum",
ignore_index=-100,
)
@@ -273,7 +359,12 @@
with torch.no_grad():
loss_display = loss_sum.detach() / num_frames.clamp(min=1)

B, T = inputs["input_embeds"].shape[:2]
# Input embeds shape is (B, T, H) for BSHD or (T, H) for THD packed.
input_embeds = inputs["input_embeds"]
if input_embeds.dim() == 2:
B, T = 1, input_embeds.shape[0]
else:
B, T = input_embeds.shape[:2]
ans = {
"loss": loss,
"learning_rate": (
@@ -318,13 +409,18 @@
if dataset_batch is None:
continue # some dataset is exhausted
inputs = self.prepare_inputs(dataset_batch)
forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"])
forward_outputs = self(
inputs["input_embeds"],
attention_mask=inputs["attention_mask"],
**inputs.get("llm_kwargs", {}),
)
num_frames = (inputs["target_ids"] != -100).long().sum()
with loss_parallel():
logits = forward_outputs["logits"]
loss = (
torch.nn.functional.cross_entropy(
forward_outputs["logits"].flatten(0, 1),
inputs["target_ids"].flatten(0, 1),
logits.reshape(-1, logits.size(-1)),
inputs["target_ids"].reshape(-1),
reduction="sum",
ignore_index=-100,
)