From f86e34673d35023bf082b8b5c521abc728a5005e Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 May 2026 08:55:55 -0700 Subject: [PATCH 1/5] Adding Conformer encoder style Transformer encoder Signed-off-by: taejinp --- .../asr/modules/transformer_encoder.py | 318 +++++++++++++++--- .../asr/parts/submodules/subsampling.py | 42 +++ .../asr/test_transformer_encoder.py | 230 ++++++++++++- 3 files changed, 525 insertions(+), 65 deletions(-) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index f2af64cb8974..07b0b3113256 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +from collections import OrderedDict from dataclasses import dataclass import torch import torch.nn as nn from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding +from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, FeatureStacking, StackingSubsampling +from nemo.collections.asr.parts.utils.regularization_utils import compute_stochastic_depth_drop_probs +from nemo.core.classes.common import typecheck +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.mixins import AccessMixin +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, BoolType, LengthsType, NeuralType, SpectrogramType + flex_attention_compiled = torch.compile(flex_attention, dynamic=True) @@ -47,48 +58,6 @@ def pad_mask(b, h, q_idx, kv_idx): return pad_mask -class FeatureStacking(nn.Module): - """Stacks consecutive input frames and projects to model dimension. - - Reduces the temporal resolution by ``subsampling_factor`` while increasing - the feature dimension proportionally, then linearly projects back to - ``feat_out``. - - Args: - subsampling_factor: Number of consecutive frames to stack. - feat_in: Input feature dimension (e.g. number of mel bins). - feat_out: Output feature dimension (model hidden size). - """ - - def __init__(self, subsampling_factor: int, feat_in: int, feat_out: int): - super().__init__() - self.subsampling_factor = subsampling_factor - self.proj = nn.Linear(subsampling_factor * feat_in, feat_out, bias=False) - - def compute_num_out_frames(self, in_frames): - return (in_frames + self.subsampling_factor - 1) // self.subsampling_factor - - def forward(self, x, lengths): - """ - Args: - x: (B, C, T) — input features (channels-first from preprocessor). - lengths: (B,) — valid lengths per sample. - Returns: - x: (B, T', feat_out) — stacked and projected features. - lengths: (B,) — updated lengths after subsampling. - """ - x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) - b, t, c = x.size() - pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor - if pad_size > 0: - x = nn.functional.pad(x, (0, 0, 0, pad_size)) - t_new = (t + pad_size) // self.subsampling_factor - x = x.reshape(b, t_new, c * self.subsampling_factor) - x = self.proj(x) - lengths = self.compute_num_out_frames(lengths) - return x, lengths - - class FeedForward(nn.Module): def __init__(self, cfg: TransformerEncoderConfig): super().__init__() @@ -131,7 +100,8 @@ def forward(self, x, block_mask=None): q = self.q_norm(q).to(v.dtype) k = self.k_norm(k).to(v.dtype) - out = flex_attention_compiled(q, k, v, block_mask=block_mask) + attn_fn = flex_attention_compiled if q.is_cuda else flex_attention + out = attn_fn(q, k, v, block_mask=block_mask) out = out.transpose(1, 2).contiguous().view(B, T, self.d_model) return self.out_proj(out) @@ -151,10 +121,10 @@ def forward(self, x, block_mask=None): return x -class TransformerEncoder(nn.Module): +class TransformerEncoder(NeuralModule, Exportable, AccessMixin): """Pre-norm Transformer encoder for ASR. - Architecture: FeatureStacking -> LayerNorm -> N x TransformerBlock -> FinalNorm + Architecture: PreEncode -> PositionalEncoding -> LayerNorm -> N x TransformerBlock -> FinalNorm Uses PyTorch FlexAttention for attention computation. On CUDA, mask functions are compiled into fused Triton kernels with block-sparse optimization. On CPU, @@ -165,7 +135,18 @@ class TransformerEncoder(nn.Module): d_model: Transformer hidden dimension. n_heads: Number of attention heads. n_layers: Number of transformer blocks. + feat_out: Output feature dimension. Defaults to ``d_model``. + subsampling: Subsampling method. Supports ``feature_stacking`` for the + Transformer-native ``FeatureStacking`` module, plus Conformer-style + ``stacking``, ``stacking_norm``, ``vggnet``, ``striding``, + ``dw_striding``/``dw-striding``, ``striding_conv1d``, and ``dw_striding_conv1d``. + subsampling_factor: Subsampling factor for the pre-encoder. + subsampling_conv_chunking_factor: Optional input chunking factor for convolutional subsampling. + subsampling_conv_channels: Hidden channels for convolutional subsampling. + causal_downsampling: Whether convolutional subsampling should be causal. drop_rate: Dropout probability. + dropout_pre_encoder: Dropout probability after positional encoding. Defaults to ``drop_rate``. + dropout_emb: Dropout probability for positional embeddings. qkv_bias: Whether to use bias in Q/K/V projections. qk_norm: Whether to apply per-head LayerNorm to Q and K before the dot product. ff_expansion: Feed-forward expansion factor (float to support sub-1x for MoE). @@ -173,29 +154,99 @@ class TransformerEncoder(nn.Module): transformer block (BERT/ViT-style). Set False to match pre-norm transformers such as Whisper or GPT-2 — required when loading pretrained weights from those checkpoints. - subsampling_factor: Frame stacking factor for the pre-encoder. + pos_emb_max_len: Initial maximum length for sinusoidal positional embeddings. + xscaling: Whether to scale embeddings by ``sqrt(d_model)`` before adding positions. + stochastic_depth_drop_prob: Final-layer stochastic depth drop probability. + stochastic_depth_mode: Stochastic depth schedule, ``linear`` or ``uniform``. + stochastic_depth_start_layer: First 1-based layer index eligible for stochastic depth. attn_mode: Attention pattern — currently only "full" (bidirectional) is supported. + sync_max_audio_length: When true, sync positional encoding allocation length across distributed ranks. """ + def input_example(self, max_batch=1, max_dim=256): + """Generates input examples for tracing and export.""" + dev = next(self.parameters()).device + input_example = torch.randn(max_batch, self._feat_in, max_dim, device=dev) + input_example_length = torch.randint(max_dim // 4, max_dim, (max_batch,), device=dev, dtype=torch.int64) + return tuple([input_example, input_example_length]) + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return OrderedDict( + { + "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "length": NeuralType(tuple('B'), LengthsType()), + "bypass_pre_encode": NeuralType(tuple(), BoolType(), optional=True), + } + ) + + @property + def input_types_for_export(self): + """Returns definitions of module input ports for export.""" + return self.input_types + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return OrderedDict( + { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + ) + + @property + def output_types_for_export(self): + """Returns definitions of module output ports for export.""" + return self.output_types + + @property + def disabled_deployment_input_names(self): + return set() + + @property + def disabled_deployment_output_names(self): + return set() + def __init__( self, feat_in: int = 80, d_model: int = 512, n_heads: int = 8, n_layers: int = 17, + feat_out: int = -1, + causal_downsampling: bool = False, + subsampling: str = 'feature_stacking', + subsampling_factor: int = 4, + subsampling_conv_chunking_factor: int = 1, + subsampling_conv_channels: int = -1, drop_rate: float = 0.1, + dropout_pre_encoder: float = None, + dropout_emb: float = 0.0, qkv_bias: bool = False, qk_norm: bool = False, ff_expansion: float = 4.0, pre_block_norm: bool = True, - subsampling_factor: int = 4, + pos_emb_max_len: int = 5000, + xscaling: bool = True, + stochastic_depth_drop_prob: float = 0.0, + stochastic_depth_mode: str = "linear", + stochastic_depth_start_layer: int = 1, attn_mode: str = "full", + sync_max_audio_length: bool = True, ): super().__init__() if d_model % n_heads != 0: raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads}).") if attn_mode != "full": raise ValueError(f"attn_mode='{attn_mode}' is not yet supported. Currently only 'full' is available.") + if dropout_pre_encoder is None: + dropout_pre_encoder = drop_rate + if subsampling == 'feature-stacking': + subsampling = 'feature_stacking' + if subsampling == 'dw-striding': + subsampling = 'dw_striding' cfg = TransformerEncoderConfig( feat_in=feat_in, @@ -211,31 +262,188 @@ def __init__( attn_mode=attn_mode, ) self.d_model = d_model - - self.pre_encode = FeatureStacking(subsampling_factor, feat_in, d_model) + self.n_layers = n_layers + self._feat_in = feat_in + self.subsampling = subsampling + self.subsampling_factor = subsampling_factor + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + self.sync_max_audio_length = sync_max_audio_length + + if subsampling_conv_channels == -1: + subsampling_conv_channels = d_model + if subsampling == 'feature_stacking': + self.pre_encode = FeatureStacking(subsampling_factor, feat_in, d_model) + elif subsampling and subsampling_factor > 1: + if subsampling in ['stacking', 'stacking_norm']: + self.pre_encode = StackingSubsampling( + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + norm=True if subsampling == 'stacking_norm' else False, + ) + else: + self.pre_encode = ConvSubsampling( + subsampling=subsampling, + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + conv_channels=subsampling_conv_channels, + subsampling_conv_chunking_factor=subsampling_conv_chunking_factor, + activation=nn.ReLU(True), + is_causal=causal_downsampling, + ) + else: + self.pre_encode = nn.Linear(feat_in, d_model) + + self._feat_out = d_model + if xscaling: + self.xscale = math.sqrt(d_model) + else: + self.xscale = None + self.pos_emb_max_len = pos_emb_max_len + self.pos_enc = PositionalEncoding( + d_model=d_model, + dropout_rate=dropout_pre_encoder, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=dropout_emb, + ) self.embed_norm = nn.LayerNorm(d_model) if pre_block_norm else nn.Identity() self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(n_layers)]) self.final_norm = nn.LayerNorm(d_model) + if feat_out > 0 and feat_out != self._feat_out: + self.out_proj = nn.Linear(self._feat_out, feat_out) + self._feat_out = feat_out + else: + self.out_proj = None + + self.set_max_audio_length(self.pos_emb_max_len) + self.use_pad_mask = True + self.layer_drop_probs = compute_stochastic_depth_drop_probs( + len(self.layers), stochastic_depth_drop_prob, stochastic_depth_mode, stochastic_depth_start_layer + ) + self.interctc_capture_at_layers = None + + def forward_for_export(self, audio_signal, length): + """Forward function for model export. Please see ``forward()`` for details.""" + return self.forward_internal(audio_signal=audio_signal, length=length) - def forward(self, audio_signal, length): + @typecheck() + def forward(self, audio_signal, length, bypass_pre_encode=False): """ Args: - audio_signal: (B, C, T) — mel spectrogram from preprocessor. + audio_signal: ``(B, C, T)`` mel spectrogram when ``bypass_pre_encode=False``, + or ``(B, T, D)`` pre-encoded embeddings when ``bypass_pre_encode=True``. length: (B,) — valid frame counts per sample. + bypass_pre_encode: If true, skip the pre-encoder and consume frame-level embeddings. Returns: x: (B, D, T') — encoded representation (channels-first). length: (B,) — output lengths after subsampling. """ - x, length = self.pre_encode(audio_signal, length) - + if not bypass_pre_encode and audio_signal.shape[-2] != self._feat_in: + raise ValueError( + f"If bypass_pre_encode is False, audio_signal should have shape " + f"(batch, {self._feat_in}, n_frame) but got last dimension {audio_signal.shape[-2]}." + ) + if bypass_pre_encode and audio_signal.shape[-1] != self.d_model: + raise ValueError( + f"If bypass_pre_encode is True, audio_signal should have shape " + f"(batch, n_frame, {self.d_model}) but got last dimension {audio_signal.shape[-1]}." + ) + + if bypass_pre_encode: + self.update_max_seq_length(seq_length=audio_signal.size(1), device=audio_signal.device) + else: + self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) + return self.forward_internal(audio_signal, length, bypass_pre_encode=bypass_pre_encode) + + def forward_internal(self, audio_signal, length, bypass_pre_encode=False): + if length is None: + length = audio_signal.new_full( + (audio_signal.size(0),), + audio_signal.size(1) if bypass_pre_encode else audio_signal.size(-1), + dtype=torch.int64, + device=audio_signal.device, + ) + + if not bypass_pre_encode: + if isinstance(self.pre_encode, FeatureStacking): + x, length = self.pre_encode(audio_signal, length) + else: + x = torch.transpose(audio_signal, 1, 2) + if isinstance(self.pre_encode, nn.Linear): + x = self.pre_encode(x) + elif not isinstance(self.pre_encode, FeatureStacking): + x, length = self.pre_encode(x=x, lengths=length) + length = length.to(torch.int64) + else: + x = audio_signal + length = length.to(torch.int64) + + x, _ = self.pos_enc(x=x) x = self.embed_norm(x) B, T, _ = x.shape - block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device) + if self.use_pad_mask: + block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device) + else: + block_mask = None - for layer in self.layers: + for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): + original_signal = x x = layer(x, block_mask=block_mask) + if self.training and drop_prob > 0.0: + should_drop = torch.rand(1, device=x.device) < drop_prob + if should_drop: + x = x * 0.0 + original_signal + else: + x = (x - original_signal) / (1.0 - drop_prob) + original_signal + + if self.is_access_enabled(getattr(self, "model_guid", None)): + if self.interctc_capture_at_layers is None: + self.interctc_capture_at_layers = self.access_cfg.get('interctc', {}).get('capture_layers', []) + if lth in self.interctc_capture_at_layers: + lth_audio_signal = x + if self.out_proj is not None: + lth_audio_signal = self.out_proj(lth_audio_signal) + self.register_accessible_tensor( + name=f'interctc/layer_output_{lth}', tensor=torch.transpose(lth_audio_signal, 1, 2) + ) + self.register_accessible_tensor(name=f'interctc/layer_length_{lth}', tensor=length) + x = self.final_norm(x) + if self.out_proj is not None: + x = self.out_proj(x) x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) + length = length.to(dtype=torch.int64) return x, length + + def update_max_seq_length(self, seq_length: int, device): + """ + Updates the maximum sequence length for positional encodings. + + Args: + seq_length: New maximum sequence length. + device: Device to use for computations. + """ + if self.sync_max_audio_length and torch.distributed.is_initialized(): + global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) + torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX) + seq_length = global_max_len.int().item() + + if seq_length > self.max_audio_length: + self.set_max_audio_length(seq_length) + + def set_max_audio_length(self, max_audio_length): + """Sets maximum input length and extends positional encodings if needed.""" + self.max_audio_length = max_audio_length + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + self.pos_enc.extend_pe(max_audio_length, device, dtype) + + def enable_pad_mask(self, on=True): + """Enables or disables pad masking and returns the previous state.""" + mask = self.use_pad_mask + self.use_pad_mask = on + return mask diff --git a/nemo/collections/asr/parts/submodules/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py index 7f9fc606991c..d3a21aae53c4 100644 --- a/nemo/collections/asr/parts/submodules/subsampling.py +++ b/nemo/collections/asr/parts/submodules/subsampling.py @@ -22,6 +22,48 @@ from nemo.utils import logging +class FeatureStacking(nn.Module): + """Stacks consecutive input frames and projects to model dimension. + + Reduces the temporal resolution by ``subsampling_factor`` while increasing + the feature dimension proportionally, then linearly projects back to + ``feat_out``. + + Args: + subsampling_factor: Number of consecutive frames to stack. + feat_in: Input feature dimension. + feat_out: Output feature dimension. + """ + + def __init__(self, subsampling_factor: int, feat_in: int, feat_out: int): + super().__init__() + self.subsampling_factor = subsampling_factor + self.proj = nn.Linear(subsampling_factor * feat_in, feat_out, bias=False) + + def compute_num_out_frames(self, in_frames): + return (in_frames + self.subsampling_factor - 1) // self.subsampling_factor + + def forward(self, x, lengths): + """ + Args: + x: (B, C, T) input features. + lengths: (B,) valid lengths per sample. + Returns: + x: (B, T', feat_out) stacked and projected features. + lengths: (B,) updated lengths after subsampling. + """ + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + b, t, c = x.size() + pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor + if pad_size > 0: + x = nn.functional.pad(x, (0, 0, 0, pad_size)) + t_new = (t + pad_size) // self.subsampling_factor + x = x.reshape(b, t_new, c * self.subsampling_factor) + x = self.proj(x) + lengths = self.compute_num_out_frames(lengths) + return x, lengths + + class StackingSubsampling(torch.nn.Module): """Stacking subsampling which simply stacks consecutive frames to reduce the sampling rate Args: diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index 0cc2f174a1e5..1307f030c832 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import pytest import torch @@ -73,6 +74,7 @@ def test_padding_when_not_divisible(self): out, out_lengths = stacking(x, lengths) expected_t = stacking.compute_num_out_frames(T) assert out.shape == (B, expected_t, 256) + assert out_lengths[0].item() == expected_t @pytest.mark.unit def test_length_shorter_than_batch(self): @@ -99,6 +101,213 @@ def test_no_padding_when_divisible(self): assert out_lengths[0].item() == stacking.compute_num_out_frames(T) +class TestStochasticDepth: + """Testing stochastic depth functionality.""" + + def test_stochastic_depth_model_creation(self): + """Testing basic model creation and the drop probs are correctly assigned.""" + n_layers = 4 + model = TransformerEncoder(feat_in=10, n_layers=n_layers, d_model=4, n_heads=2, feat_out=8) + + # checking that by default SD is disabled + assert model.layer_drop_probs == [0.0] * n_layers + + # linear mode + for drop_prob in [0.3, 0.5, 0.9]: + for start_layer in [1, 3]: + model = TransformerEncoder( + feat_in=10, + n_layers=n_layers, + d_model=4, + n_heads=2, + feat_out=8, + stochastic_depth_drop_prob=drop_prob, + stochastic_depth_start_layer=start_layer, + ) + L = n_layers - start_layer + assert model.layer_drop_probs == [0.0] * start_layer + [drop_prob * l / L for l in range(1, L + 1)] + + # uniform mode + for drop_prob in [0.3, 0.5, 0.9]: + model = TransformerEncoder( + feat_in=10, + n_layers=n_layers, + d_model=4, + n_heads=2, + feat_out=8, + stochastic_depth_drop_prob=drop_prob, + stochastic_depth_mode="uniform", + stochastic_depth_start_layer=start_layer, + ) + L = n_layers - start_layer + assert model.layer_drop_probs == [0.0] * start_layer + [drop_prob] * L + + # checking for errors + for drop_prob in [-1.0, 1.0]: + with pytest.raises(ValueError, match="stochastic_depth_drop_prob has to be in"): + TransformerEncoder( + feat_in=10, + n_layers=n_layers, + d_model=4, + n_heads=2, + feat_out=8, + stochastic_depth_drop_prob=drop_prob, + stochastic_depth_mode="uniform", + ) + + with pytest.raises(ValueError, match="stochastic_depth_mode has to be one of"): + TransformerEncoder(feat_in=10, n_layers=n_layers, d_model=4, n_heads=2, feat_out=8, stochastic_depth_mode="weird") + + for start_layer in [-1, 0, 5]: + with pytest.raises(ValueError, match="stochastic_depth_start_layer has to be in"): + TransformerEncoder( + feat_in=10, + n_layers=n_layers, + d_model=4, + n_heads=2, + feat_out=8, + stochastic_depth_start_layer=start_layer, + ) + + @pytest.mark.pleasefixme + def test_stochastic_depth_forward(self): + """Testing that forward works and we get randomness during training, but not during eval.""" + random_input = torch.rand((1, 2, 16)) + random_length = torch.tensor([16], dtype=torch.int64) + + model = TransformerEncoder( + feat_in=2, + n_layers=3, + d_model=4, + n_heads=2, + feat_out=4, + stochastic_depth_drop_prob=0.8, + drop_rate=0.0, + dropout_pre_encoder=0.0, + dropout_emb=0.0, + ) + model.train() + outputs = [None] * 5 + for i in range(5): + outputs[i] = model(audio_signal=random_input, length=random_length)[0] + # checking that not all outputs are the same + num_diff = 0 + for i in range(1, 5): + if not torch.allclose(outputs[i], outputs[0]): + num_diff += 1 + assert num_diff > 0 + + model.eval() + outputs = [None] * 5 + for i in range(5): + outputs[i] = model(audio_signal=random_input, length=random_length)[0] + # checking that not all outputs are the same + num_diff = 0 + for i in range(1, 5): + if not torch.allclose(outputs[i], outputs[0]): + num_diff += 1 + assert num_diff == 0 + + +class TestBypassPreEncode: + """Testing bypass pre-encode functionality.""" + + def test_bypass_pre_encode_forward(self): + """Testing that forward works with "bypass pre-encode" mode.""" + # For pre-encoded embeddings, the shape is (batch_size, n_frames, emb_dim) + batch_size = 2 + n_frames, emb_dim, feat_out = 17, 16, 8 + random_input = torch.rand((batch_size, n_frames, emb_dim)) + random_length = torch.tensor([n_frames] * batch_size, dtype=torch.int64) + + model = TransformerEncoder( + feat_in=10, + n_layers=3, + d_model=emb_dim, + n_heads=4, + feat_out=feat_out, + stochastic_depth_drop_prob=0.0, + drop_rate=0.0, + dropout_pre_encoder=0.0, + dropout_emb=0.0, + ) + model.train() + fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] + assert fwd_outputs.shape == (batch_size, feat_out, n_frames) + + model.eval() + fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] + assert fwd_outputs.shape == (batch_size, feat_out, n_frames) + + def test_error_shape_invalid_bypass_pre_encode_forward(self): + """ + Testing that error messages are correctly triggered regarding "bypass pre-encode" mode. + Both correct samples and wrongs samples are tested. + + (1) bypass_pre_encode = False (default): + `audio_signal` must be a tensor containing audio features. + Shape: (batch, self._feat_in, n_frames) + (2) bypass_pre_encode = True: + `audio_signal` must be a tensor containing pre-encoded embeddings. + Shape: (batch, n_frame, self.d_model) + """ + batch_size = 2 + n_frames, emb_dim, feat_in, feat_out = 17, 16, 10, 8 + + pre_encode_input = torch.rand((batch_size, n_frames, emb_dim)) + feat_input = torch.rand((batch_size, feat_in, n_frames)) + input_length = torch.tensor([n_frames] * batch_size, dtype=torch.int64) + + model = TransformerEncoder( + feat_in=feat_in, + n_layers=3, + d_model=emb_dim, + n_heads=4, + feat_out=feat_out, + stochastic_depth_drop_prob=0.0, + drop_rate=0.0, + dropout_pre_encoder=0.0, + dropout_emb=0.0, + ) + sub_sampled_n_frames = np.ceil(n_frames / model.subsampling_factor) + + # Test with bypass_pre_encode = True, should be pre_encode_input but given feat_input. + model.train() + with pytest.raises(ValueError): + model(audio_signal=feat_input, length=input_length, bypass_pre_encode=True) + + model.eval() + with pytest.raises(ValueError): + model(audio_signal=feat_input, length=input_length, bypass_pre_encode=True) + + # Test with bypass_pre_encode = True, given the correct input pre_encode_input. + model.train() + fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] + assert fwd_outputs.shape == (batch_size, feat_out, n_frames) + + model.eval() + fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] + assert fwd_outputs.shape == (batch_size, feat_out, n_frames) + + # Test with bypass_pre_encode = False, should be feat_input but given pre_encode_input. + model.train() + with pytest.raises(ValueError): + model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=False) + + model.eval() + with pytest.raises(ValueError): + model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=False) + + # Test with bypass_pre_encode = False, given the correct input feat_input. + model.train() + fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] + assert fwd_outputs.shape == (batch_size, feat_out, sub_sampled_n_frames) + + model.eval() + fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] + assert fwd_outputs.shape == (batch_size, feat_out, sub_sampled_n_frames) + + class TestTransformerEncoder: @pytest.mark.unit def test_model_creation(self): @@ -137,7 +346,7 @@ def test_forward_cpu(self): lengths = torch.tensor([400, 300]) with torch.no_grad(): - out, out_lengths = model(x, lengths) + out, out_lengths = model(audio_signal=x, length=lengths) assert out.shape == (B, 64, T // 4) assert out_lengths[0].item() == T // 4 @@ -153,7 +362,7 @@ def test_forward_cpu_with_qk_norm(self): lengths = torch.tensor([200]) with torch.no_grad(): - out, _ = model(x, lengths) + out, _ = model(audio_signal=x, length=lengths) assert out.shape == (1, 64, 50) assert not torch.isnan(out).any() @@ -169,7 +378,7 @@ def test_forward_basic(self): model.eval() with torch.no_grad(): - out, out_lengths = model(x, lengths) + out, out_lengths = model(audio_signal=x, length=lengths) assert out.shape == (B, 64, T // 4) assert out_lengths[0].item() == T // 4 @@ -189,9 +398,10 @@ def test_forward_with_qk_norm(self): model.eval() with torch.no_grad(): - out, out_lengths = model(x, lengths) + out, out_lengths = model(audio_signal=x, length=lengths) assert out.shape == (B, 128, T // 8) + assert out_lengths[1].item() == 640 // 8 assert not torch.isnan(out).any() @pytest.mark.run_only_on('GPU') @@ -205,7 +415,7 @@ def test_forward_output_channels_first(self): model.eval() with torch.no_grad(): - out, _ = model(x, lengths) + out, _ = model(audio_signal=x, length=lengths) assert out.shape[1] == 64 # D dimension assert out.shape[2] == 200 // 4 # T dimension @@ -220,8 +430,8 @@ def test_eval_deterministic(self): lengths = torch.tensor([200], device='cuda') with torch.no_grad(): - out1, _ = model(x, lengths) - out2, _ = model(x, lengths) + out1, _ = model(audio_signal=x, length=lengths) + out2, _ = model(audio_signal=x, length=lengths) assert torch.allclose(out1, out2, atol=1e-6) @@ -241,8 +451,8 @@ def test_padding_does_not_affect_valid_output(self): lengths_long = torch.tensor([T_valid], device='cuda') with torch.no_grad(): - out_short, len_short = model(x_short, lengths_short) - out_long, len_long = model(x_long, lengths_long) + out_short, len_short = model(audio_signal=x_short, length=lengths_short) + out_long, len_long = model(audio_signal=x_long, length=lengths_long) assert len_short[0].item() == len_long[0].item() valid_t = len_short[0].item() @@ -257,7 +467,7 @@ def test_backward_pass(self): x = torch.randn(2, 80, 200, device='cuda', dtype=torch.bfloat16) lengths = torch.tensor([200, 160], device='cuda') - out, out_lengths = model(x, lengths) + out, _ = model(audio_signal=x, length=lengths) loss = out.sum() loss.backward() From ba9d40da8270703eaec01b408a5b4edec7e8739c Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 May 2026 15:38:58 -0700 Subject: [PATCH 2/5] Adding final touch up Signed-off-by: taejinp --- .../asr/modules/transformer_encoder.py | 73 ++++++++++++++++--- .../asr/test_transformer_encoder.py | 44 +++++------ 2 files changed, 84 insertions(+), 33 deletions(-) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index 07b0b3113256..3debb3341260 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -34,7 +34,37 @@ @dataclass class TransformerEncoderConfig: - feat_in: int = 80 + """Configuration for ``TransformerEncoder`` and its sub-blocks. + + Args: + feat_in: Input feature dimension (e.g. number of mel bins). + d_model: Transformer encoder state dimension, i.e. the size of the residual stream that flows + through every block (token/frame embedding size, attention input/output size, and + feed-forward input/output size). Also known as ``hidden_size`` in HuggingFace + ``transformers`` configs and ``embed_dim``/``d_model`` in PyTorch's + ``nn.TransformerEncoderLayer``. + n_heads: Number of attention heads. + n_layers: Number of Transformer blocks. + drop_rate: Dropout probability applied inside attention and feed-forward sublayers. + qkv_bias: If True, add a learnable bias to the fused Q/K/V projection. Many modern ASR/LM + Transformers (e.g. HuggingFace Whisper) drop the bias on the K projection because a + constant K bias adds the same scalar to every key and is wiped out by softmax's + shift-invariance, making it a redundant parameter. Default ``False`` matches that style. + qk_norm: If True, apply per-head ``LayerNorm`` to Q and K before the dot product. Stabilizes + training by preventing exponential Q/K-norm growth and "attention entropy collapse" + (Henry et al. 2020; used in OLMo 2, Gemma 3, Qwen 3). Cheap, ~no-op for inference. + ff_expansion: Multiplier for the per-block FFN inner hidden size: + ``ffn_hidden_size = int(ff_expansion * d_model)``. Only widens the intermediate FFN + projection; FFN input/output stays at ``d_model``. Typical value ``4.0``; ``float`` + allows sub-1x experts for MoE. Equivalent to ``intermediate_size / hidden_size`` in + HuggingFace and ``dim_feedforward / d_model`` in PyTorch's ``nn.TransformerEncoderLayer``. + pre_block_norm: If True, apply ``LayerNorm`` to embeddings before the first Transformer block + (BERT/ViT-style). Set False to match pre-norm Transformers such as Whisper or GPT-2. + subsampling_factor: Frame-level subsampling factor performed by the pre-encoder. + attn_mode: Attention pattern. Currently only ``"full"`` (bidirectional) is supported. + Future modes: ``"causal"``, ``"lookahead"``, ``"local"``, ``"sliding_window"``. + """ + feat_in: int = 128 d_model: int = 512 n_heads: int = 8 n_layers: int = 17 @@ -44,8 +74,6 @@ class TransformerEncoderConfig: ff_expansion: float = 4.0 pre_block_norm: bool = True subsampling_factor: int = 4 - # Attention mode — currently only "full" is supported. - # Future: "causal", "lookahead", "local", "sliding_window" attn_mode: str = "full" @@ -100,6 +128,7 @@ def forward(self, x, block_mask=None): q = self.q_norm(q).to(v.dtype) k = self.k_norm(k).to(v.dtype) + # Use compiled FlexAttention for CUDA, fallback to unfused for CPU. attn_fn = flex_attention_compiled if q.is_cuda else flex_attention out = attn_fn(q, k, v, block_mask=block_mask) out = out.transpose(1, 2).contiguous().view(B, T, self.d_model) @@ -132,9 +161,13 @@ class TransformerEncoder(NeuralModule, Exportable, AccessMixin): Args: feat_in: Input feature dimension (number of mel bins). - d_model: Transformer hidden dimension. + d_model: Transformer encoder state dimension, i.e. the size of the residual stream that flows + through every block (token/frame embedding size, attention input/output size, and + feed-forward input/output size). Also known as ``hidden_size`` in HuggingFace + ``transformers`` configs and ``embed_dim``/``d_model`` in PyTorch's + ``nn.TransformerEncoderLayer``. n_heads: Number of attention heads. - n_layers: Number of transformer blocks. + n_layers: Number of Transformer blocks. feat_out: Output feature dimension. Defaults to ``d_model``. subsampling: Subsampling method. Supports ``feature_stacking`` for the Transformer-native ``FeatureStacking`` module, plus Conformer-style @@ -147,15 +180,30 @@ class TransformerEncoder(NeuralModule, Exportable, AccessMixin): drop_rate: Dropout probability. dropout_pre_encoder: Dropout probability after positional encoding. Defaults to ``drop_rate``. dropout_emb: Dropout probability for positional embeddings. - qkv_bias: Whether to use bias in Q/K/V projections. - qk_norm: Whether to apply per-head LayerNorm to Q and K before the dot product. - ff_expansion: Feed-forward expansion factor (float to support sub-1x for MoE). + qkv_bias: If True, add a learnable bias to the fused Q/K/V projection. Many modern ASR/LM + Transformers (e.g. HuggingFace Whisper) drop the bias on the K projection because a + constant K bias adds the same scalar to every key and is wiped out by softmax's + shift-invariance, making it a redundant parameter. Default ``False`` matches that style. + qk_norm: If True, apply per-head ``LayerNorm`` to Q and K before the dot product. Stabilizes + training by preventing exponential Q/K-norm growth and "attention entropy collapse" + (Henry et al. 2020; used in OLMo 2, Gemma 3, Qwen 3). Cheap, ~no-op for inference. + ff_expansion: Multiplier for the per-block FFN inner hidden size: + ``ffn_hidden_size = int(ff_expansion * d_model)``. Only widens the intermediate FFN + projection; FFN input/output stays at ``d_model``. Typical value ``4.0``; ``float`` + allows sub-1x experts for MoE. Equivalent to ``intermediate_size / hidden_size`` in + HuggingFace and ``dim_feedforward / d_model`` in PyTorch's ``nn.TransformerEncoderLayer``. pre_block_norm: If True (default), apply LayerNorm to embeddings before the first - transformer block (BERT/ViT-style). Set False to match pre-norm transformers + Transformer block (BERT/ViT-style). Set False to match pre-norm Transformers such as Whisper or GPT-2 — required when loading pretrained weights from those checkpoints. pos_emb_max_len: Initial maximum length for sinusoidal positional embeddings. - xscaling: Whether to scale embeddings by ``sqrt(d_model)`` before adding positions. + xscaling: If True, scale embeddings by ``sqrt(d_model)`` before adding positional encodings, + following "Attention Is All You Need" article. Originally intended to balance the magnitude + of small-variance token embeddings against unit-bounded sinusoidal positions and to keep + tied input/pre-softmax logits well-scaled. With modern unit-variance ``nn.Linear`` + pre-encoders and the LayerNorm directly after the positional sum, this scaling is + largely a no-op for activation magnitudes. Only meaningful when ``pre_block_norm=False`` + or when matching pretrained checkpoints that expect this scaling. stochastic_depth_drop_prob: Final-layer stochastic depth drop probability. stochastic_depth_mode: Stochastic depth schedule, ``linear`` or ``uniform``. stochastic_depth_start_layer: First 1-based layer index eligible for stochastic depth. @@ -211,7 +259,7 @@ def disabled_deployment_output_names(self): def __init__( self, - feat_in: int = 80, + feat_in: int = 128, d_model: int = 512, n_heads: int = 8, n_layers: int = 17, @@ -229,7 +277,7 @@ def __init__( ff_expansion: float = 4.0, pre_block_norm: bool = True, pos_emb_max_len: int = 5000, - xscaling: bool = True, + xscaling: bool = False, stochastic_depth_drop_prob: float = 0.0, stochastic_depth_mode: str = "linear", stochastic_depth_start_layer: int = 1, @@ -336,6 +384,7 @@ def forward(self, audio_signal, length, bypass_pre_encode=False): or ``(B, T, D)`` pre-encoded embeddings when ``bypass_pre_encode=True``. length: (B,) — valid frame counts per sample. bypass_pre_encode: If true, skip the pre-encoder and consume frame-level embeddings. + Returns: x: (B, D, T') — encoded representation (channels-first). length: (B,) — output lengths after subsampling. diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index 1307f030c832..fe19eaa9b94b 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -27,7 +27,7 @@ class TestTransformerEncoderConfig: @pytest.mark.unit def test_default_config(self): cfg = TransformerEncoderConfig() - assert cfg.feat_in == 80 + assert cfg.feat_in == 128 assert cfg.d_model == 512 assert cfg.n_heads == 8 assert cfg.n_layers == 17 @@ -156,7 +156,9 @@ def test_stochastic_depth_model_creation(self): ) with pytest.raises(ValueError, match="stochastic_depth_mode has to be one of"): - TransformerEncoder(feat_in=10, n_layers=n_layers, d_model=4, n_heads=2, feat_out=8, stochastic_depth_mode="weird") + TransformerEncoder( + feat_in=10, n_layers=n_layers, d_model=4, n_heads=2, feat_out=8, stochastic_depth_mode="weird" + ) for start_layer in [-1, 0, 5]: with pytest.raises(ValueError, match="stochastic_depth_start_layer has to be in"): @@ -311,21 +313,21 @@ def test_error_shape_invalid_bypass_pre_encode_forward(self): class TestTransformerEncoder: @pytest.mark.unit def test_model_creation(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2) total_params = sum(p.numel() for p in model.parameters()) assert total_params > 0 assert len(model.layers) == 2 @pytest.mark.unit def test_model_creation_with_qk_norm(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, qk_norm=True) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, qk_norm=True) attn = model.layers[0].attn assert hasattr(attn, 'q_norm') assert hasattr(attn, 'k_norm') @pytest.mark.unit def test_model_creation_without_qk_norm(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, qk_norm=False) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, qk_norm=False) attn = model.layers[0].attn assert not hasattr(attn, 'q_norm') assert not hasattr(attn, 'k_norm') @@ -333,15 +335,15 @@ def test_model_creation_without_qk_norm(self): @pytest.mark.unit def test_invalid_attn_mode(self): with pytest.raises(ValueError, match="not yet supported"): - TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal") + TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, attn_mode="causal") @pytest.mark.unit def test_forward_cpu(self): """Forward pass on CPU uses unfused FlexAttention fallback.""" - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4) model.eval() - B, C, T = 2, 80, 400 + B, C, T = 2, 128, 400 x = torch.randn(B, C, T) lengths = torch.tensor([400, 300]) @@ -355,10 +357,10 @@ def test_forward_cpu(self): @pytest.mark.unit def test_forward_cpu_with_qk_norm(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, qk_norm=True) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, qk_norm=True) model.eval() - x = torch.randn(1, 80, 200) + x = torch.randn(1, 128, 200) lengths = torch.tensor([200]) with torch.no_grad(): @@ -369,10 +371,10 @@ def test_forward_cpu_with_qk_norm(self): @pytest.mark.run_only_on('GPU') def test_forward_basic(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4) model = model.cuda().to(torch.bfloat16) - B, C, T = 2, 80, 400 + B, C, T = 2, 128, 400 x = torch.randn(B, C, T, device='cuda', dtype=torch.bfloat16) lengths = torch.tensor([400, 300], device='cuda') @@ -407,10 +409,10 @@ def test_forward_with_qk_norm(self): @pytest.mark.run_only_on('GPU') def test_forward_output_channels_first(self): """Verify output is (B, D, T) channels-first as expected by downstream decoders.""" - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=1, drop_rate=0.0) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=1, drop_rate=0.0) model = model.cuda().to(torch.bfloat16) - x = torch.randn(1, 80, 200, device='cuda', dtype=torch.bfloat16) + x = torch.randn(1, 128, 200, device='cuda', dtype=torch.bfloat16) lengths = torch.tensor([200], device='cuda') model.eval() @@ -423,10 +425,10 @@ def test_forward_output_channels_first(self): @pytest.mark.run_only_on('GPU') def test_eval_deterministic(self): """In eval mode with no dropout, repeated forward passes should produce identical output.""" - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) model = model.cuda().to(torch.bfloat16).eval() - x = torch.randn(1, 80, 200, device='cuda', dtype=torch.bfloat16) + x = torch.randn(1, 128, 200, device='cuda', dtype=torch.bfloat16) lengths = torch.tensor([200], device='cuda') with torch.no_grad(): @@ -438,15 +440,15 @@ def test_eval_deterministic(self): @pytest.mark.run_only_on('GPU') def test_padding_does_not_affect_valid_output(self): """Padding frames should not change the encoded output at valid positions.""" - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) model = model.cuda().to(torch.bfloat16).eval() T_valid = 200 - x_short = torch.randn(1, 80, T_valid, device='cuda', dtype=torch.bfloat16) + x_short = torch.randn(1, 128, T_valid, device='cuda', dtype=torch.bfloat16) lengths_short = torch.tensor([T_valid], device='cuda') T_padded = 400 - x_long = torch.zeros(1, 80, T_padded, device='cuda', dtype=torch.bfloat16) + x_long = torch.zeros(1, 128, T_padded, device='cuda', dtype=torch.bfloat16) x_long[:, :, :T_valid] = x_short lengths_long = torch.tensor([T_valid], device='cuda') @@ -461,10 +463,10 @@ def test_padding_does_not_affect_valid_output(self): @pytest.mark.run_only_on('GPU') def test_backward_pass(self): - model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0) model = model.cuda().to(torch.bfloat16).train() - x = torch.randn(2, 80, 200, device='cuda', dtype=torch.bfloat16) + x = torch.randn(2, 128, 200, device='cuda', dtype=torch.bfloat16) lengths = torch.tensor([200, 160], device='cuda') out, _ = model(audio_signal=x, length=lengths) From 539860402976a7d094e0f8c4b8ccabd01cd6da1f Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 May 2026 15:49:21 -0700 Subject: [PATCH 3/5] Fixing Black issue Signed-off-by: taejinp --- nemo/collections/asr/modules/transformer_encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index 3debb3341260..147ae70044cb 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -64,6 +64,7 @@ class TransformerEncoderConfig: attn_mode: Attention pattern. Currently only ``"full"`` (bidirectional) is supported. Future modes: ``"causal"``, ``"lookahead"``, ``"local"``, ``"sliding_window"``. """ + feat_in: int = 128 d_model: int = 512 n_heads: int = 8 From c19ca033559c4b3c2a1e49c2e18bb03f23e34cf5 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 20 May 2026 22:59:51 -0700 Subject: [PATCH 4/5] Adding relative position encoding and transformer-ctc yaml Signed-off-by: Taejin Park --- .../conf/transformer/transformer_ctc_bpe.yaml | 216 ++++++++++++++++++ .../asr/modules/transformer_encoder.py | 190 +++++++++++++-- .../asr/test_transformer_encoder.py | 182 +++++++++++++-- 3 files changed, 559 insertions(+), 29 deletions(-) create mode 100644 examples/asr/conf/transformer/transformer_ctc_bpe.yaml diff --git a/examples/asr/conf/transformer/transformer_ctc_bpe.yaml b/examples/asr/conf/transformer/transformer_ctc_bpe.yaml new file mode 100644 index 000000000000..ccbb7f774d3e --- /dev/null +++ b/examples/asr/conf/transformer/transformer_ctc_bpe.yaml @@ -0,0 +1,216 @@ +# It contains the default values for training a Transformer-CTC ASR model with CTC loss and sub-word encoding. +# +# This config is the Transformer counterpart of ``fast-conformer_ctc_bpe.yaml``: same +# preprocessor / spec-augment / decoder / optimiser / trainer / exp_manager sections, but the +# encoder is the FlexAttention-based ``TransformerEncoder`` defined in +# ``nemo/collections/asr/modules/transformer_encoder.py``. By default it uses +# ``self_attention_model: rel_pos`` (Transformer-XL relative positional encoding wired into +# FlexAttention via a ``score_mod`` closure and a ``Q + pos_bias_u`` query rewrite). +# +# Use trainer.precision=bf16 on GPUs that support it; the FlexAttention kernel is compiled +# with ``torch.compile(dynamic=True)`` and works on CUDA out of the box. On CPU it falls back +# to the un-fused FlexAttention path. + +name: "Transformer-CTC-BPE" + +model: + sample_rate: 16000 + log_prediction: true # enables logging sample predictions in the output during training + ctc_reduction: 'mean_volume' + skip_nan_grad: false + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + use_start_end_token: false + num_workers: 8 + pin_memory: true + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + use_start_end_token: false + num_workers: 8 + pin_memory: true + + # recommend vocab size of 128 or 256 when training on ~1k hr datasets and 1k vocab size on 10+k hr datasets + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.TransformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + # n_layers=31 chosen so the encoder has ~108.1M params, matching the FastConformer baseline + # (109.5M at d=512, L=17, conv_kernel=9, ff_x4) to within ~1.4%. A Conformer layer carries + # an extra convolution module and a sandwich-pair of FFNs, so the post-norm Transformer + # needs more layers (not more heads — heads only partition d_model and add no parameters) + # to reach the same capacity. + n_layers: 31 + d_model: 512 + n_heads: 8 + + # Sub-sampling params (Conformer-style options are supported; ``feature_stacking`` is the + # Transformer-native default in the module itself, but for parity with the Canary-flash + # baseline we keep dw_striding x8 here.) + subsampling: dw_striding # feature_stacking, stacking, stacking_norm, vggnet, striding, dw_striding, striding_conv1d, dw_striding_conv1d + subsampling_factor: 8 # must be power of 2 for striding / vggnet variants + subsampling_conv_channels: 256 # -1 sets it to d_model + subsampling_conv_chunking_factor: 1 # 1 = auto-chunking, -1 = no chunking, otherwise power-of-2 + causal_downsampling: false + + # Feed-forward module's params + ff_expansion: 4.0 # FFN hidden = ff_expansion * d_model + + # Self-attention / positional encoding + # - ``rel_pos`` (default): Transformer-XL relative positional encoding, wired into + # FlexAttention via a ``score_mod`` closure plus a ``Q + pos_bias_u`` query rewrite. + # - ``abs_pos``: sinusoidal absolute positional encoding added before the first block. + # - ``no_pos`` (or ``null``): no positional encoding; pre-encoder output flows directly + # into ``embed_norm`` and the Transformer blocks. + self_attention_model: rel_pos + pos_emb_max_len: 5000 + xscaling: false # scale embeddings by sqrt(d_model); mostly a no-op when pre_block_norm=true + + # Attention/FFN block options + qkv_bias: false # add a learnable bias to the fused Q/K/V projection (Whisper-style: false) + qk_norm: false # per-head LayerNorm on Q and K before the dot product (OLMo 2 / Gemma 3 style) + pre_block_norm: true # BERT/ViT-style: LayerNorm on embeddings before the first block + attn_mode: full # currently only "full" (bidirectional) is supported + + # Regularization + drop_rate: 0.1 # dropout inside attention/FFN sublayers (corresponds to conformer's ``dropout``) + dropout_pre_encoder: 0.1 # dropout applied after positional encoding (unused when self_attention_model=no_pos) + dropout_emb: 0.0 # dropout for the positional embeddings (unused when self_attention_model=no_pos) + + # Set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + # When true, sync max-audio-length across distributed ranks before extending positional buffers + sync_max_audio_length: true + + decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: null + num_classes: -1 + vocabulary: [] + + # config for InterCTC loss: https://arxiv.org/abs/2102.03216 + # specify loss weights and which layers to use for InterCTC + # e.g., to reproduce the paper results, set loss_weights: [0.3] + # and apply_at_layers: [8] (assuming 18 layers). Note that final + # layer loss coefficient is automatically adjusted (to 0.7 in above example) + interctc: + loss_weights: [] + apply_at_layers: [] + + optim: + name: adamw + lr: 1e-3 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + # weight decay of 0.0 with lr of 2.0 also works fine + weight_decay: 1e-3 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 15000 + warmup_ratio: null + min_lr: 1e-4 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 1000 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index 147ae70044cb..61f1a4d21e14 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -15,12 +15,17 @@ import math from collections import OrderedDict from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding +from nemo.collections.asr.parts.submodules.multi_head_attention import ( + PositionalEncoding, + RelPositionalEncoding, + RelPositionMultiHeadAttention, +) from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, FeatureStacking, StackingSubsampling from nemo.collections.asr.parts.utils.regularization_utils import compute_stochastic_depth_drop_probs from nemo.core.classes.common import typecheck @@ -63,6 +68,18 @@ class TransformerEncoderConfig: subsampling_factor: Frame-level subsampling factor performed by the pre-encoder. attn_mode: Attention pattern. Currently only ``"full"`` (bidirectional) is supported. Future modes: ``"causal"``, ``"lookahead"``, ``"local"``, ``"sliding_window"``. + self_attention_model: Positional encoding / attention scoring scheme. + + - ``"rel_pos"`` (default): Transformer-XL relative positional encoding + (https://arxiv.org/abs/1901.02860). The (b)+(d) cross/positional bias is computed + from the relative-position embedding and injected into FlexAttention via a + ``score_mod`` closure; the (c) global-content bias is folded into the query as + ``Q + pos_bias_u``. + - ``"abs_pos"``: sinusoidal absolute positional encoding added to embeddings + before the first block; standard scaled dot-product attention. + - ``"no_pos"`` (or ``None``): no positional encoding at all. The pre-encoder output + is consumed directly by the Transformer blocks. ``xscaling``, ``pos_emb_max_len``, + ``dropout_pre_encoder`` and ``dropout_emb`` are unused in this mode. """ feat_in: int = 128 @@ -76,6 +93,7 @@ class TransformerEncoderConfig: pre_block_norm: bool = True subsampling_factor: int = 4 attn_mode: str = "full" + self_attention_model: str = "rel_pos" def _make_padding_mod(lengths): @@ -109,6 +127,7 @@ def __init__(self, cfg: TransformerEncoderConfig): self.n_heads = cfg.n_heads self.head_dim = cfg.d_model // cfg.n_heads self.d_model = cfg.d_model + self.self_attention_model = cfg.self_attention_model self.w_qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=cfg.qkv_bias) self.out_proj = nn.Linear(cfg.d_model, cfg.d_model) @@ -118,7 +137,86 @@ def __init__(self, cfg: TransformerEncoderConfig): self.q_norm = nn.LayerNorm(self.head_dim) self.k_norm = nn.LayerNorm(self.head_dim) - def forward(self, x, block_mask=None): + # Transformer-XL relative-position parameters (matrix b and matrix d from + # https://arxiv.org/abs/1901.02860 Section 3.3). The "matrix c" term `u @ K^T` is + # absorbed by passing `Q + pos_bias_u` as the query to FlexAttention. + if self.self_attention_model == "rel_pos": + self.linear_pos = nn.Linear(cfg.d_model, cfg.d_model, bias=False) + self.pos_bias_u = nn.Parameter(torch.zeros(self.n_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.n_heads, self.head_dim)) + else: + self.linear_pos = None + self.pos_bias_u = None + self.pos_bias_v = None + + # Per-forward Transformer-XL (b)+(d) bias of shape (B, H, T, T), set by ``forward`` and + # read by ``_rel_pos_score_mod`` while FlexAttention is executing. + self._rel_pos_bias = None + + def _rel_pos_score_mod(self, score, b, h, q_idx, kv_idx): + """FlexAttention ``score_mod`` adding the Transformer-XL (b)+(d) bias. + + FlexAttention's ``score_mod`` API expects a callable with a fixed signature, so the + per-forward bias tensor is passed in via ``self._rel_pos_bias`` rather than as an + explicit argument; ``forward`` populates that attribute immediately before invoking + ``flex_attention``. + """ + return score + self._rel_pos_bias[b, h, q_idx, kv_idx] + + def _rel_shift(self, x): + """Transformer-XL relative-position shift. + + Delegates to ``RelPositionMultiHeadAttention.rel_shift`` (which does not reference + ``self``) so the logic lives in a single place — NeMo's existing reference + implementation in ``parts/submodules/multi_head_attention.py``. + """ + return RelPositionMultiHeadAttention.rel_shift(None, x) + + def _build_rel_pos_score_mod(self, q, pos_emb): + """Build the FlexAttention inputs that realize Transformer-XL relative attention. + + Implements the (b), (c), (d) terms of Transformer-XL Section 3.3 + (https://arxiv.org/abs/1901.02860) on top of FlexAttention: + + - Matrices (b) + (d) — the position-dependent score bias ``(Q + v) @ R^T`` rel- + shifted into ``(q_idx, kv_idx)`` coordinates — are precomputed into a + ``(B, H, T, T)`` tensor, scaled by ``1/sqrt(D)`` (to match FlexAttention's + already-scaled ``QK^T`` scores), and stashed on ``self._rel_pos_bias``. The + bound closure ``self._rel_pos_score_mod`` reads that buffer while FlexAttention + is executing. The state-passing detour is necessary because FlexAttention's + ``score_mod`` API fixes the callable signature, so the per-forward tensor + cannot be threaded through as an explicit argument. + - Matrix (c) — the global-content bias ``u @ K^T`` — is folded into FlexAttention + by rewriting the query as ``Q + pos_bias_u``, which is returned. + + Args: + q: Query tensor with shape ``(B, H, T, D)``. + pos_emb: Relative positional embedding ``(1, 2T - 1, d_model)`` produced by + ``RelPositionalEncoding``. + + Returns: + score_mod: Callable to pass as ``flex_attention(..., score_mod=...)``. + q_with_bias_u: ``Q + pos_bias_u`` — the (c) "matrix c" query rewrite. + """ + H, D = self.n_heads, self.head_dim + T = q.size(-2) + # pos_emb: (1, 2T - 1, d_model) -> p: (1, H, 2T - 1, D) + p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, H, D).transpose(1, 2) + # pos_bias_{u,v}: (H, D) -> (1, H, 1, D) so they broadcast over the (B, H, T, D) + # Q tensor against the head/depth axes rather than (incorrectly) against time. + bias_u = self.pos_bias_u.view(1, H, 1, D) + bias_v = self.pos_bias_v.view(1, H, 1, D) + # Matrix b + d: ((Q + v) @ R^T) shifted into (q_idx, kv_idx) space, then scaled + # by 1/sqrt(D) so it can be added directly to FlexAttention's already-scaled scores. + q_with_bias_v = q + bias_v + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (B, H, T, 2T - 1) + # rel_shift converts absolute-relative-position columns into (query, key) columns; + # keep the first T to land in (B, H, T, T) bias space. + self._rel_pos_bias = self._rel_shift(matrix_bd)[..., :T] * (D ** -0.5) + # Matrix c: fold u @ K^T into FlexAttention by rewriting Q as (Q + u). + return self._rel_pos_score_mod, q + bias_u + + def forward(self, x, block_mask=None, pos_emb=None): B, T, _ = x.shape H, D = self.n_heads, self.head_dim @@ -129,9 +227,20 @@ def forward(self, x, block_mask=None): q = self.q_norm(q).to(v.dtype) k = self.k_norm(k).to(v.dtype) - # Use compiled FlexAttention for CUDA, fallback to unfused for CPU. + score_mod = None + if self.self_attention_model == "rel_pos": + if pos_emb is None: + raise ValueError("MultiHeadAttention with self_attention_model='rel_pos' requires pos_emb.") + score_mod, q = self._build_rel_pos_score_mod(q, pos_emb) + + if q.is_cuda and D < 16: + raise ValueError( + "PyTorch FlexAttention CUDA backend requires per-head embedding dimension >= 16, " + f"but got head_dim={D} from d_model={self.d_model}, n_heads={self.n_heads}." + ) + attn_fn = flex_attention_compiled if q.is_cuda else flex_attention - out = attn_fn(q, k, v, block_mask=block_mask) + out = attn_fn(q, k, v, block_mask=block_mask, score_mod=score_mod) out = out.transpose(1, 2).contiguous().view(B, T, self.d_model) return self.out_proj(out) @@ -145,8 +254,8 @@ def __init__(self, cfg: TransformerEncoderConfig): self.norm2 = nn.LayerNorm(cfg.d_model) self.ffn = FeedForward(cfg) - def forward(self, x, block_mask=None): - x = x + self.drop(self.attn(self.norm1(x), block_mask=block_mask)) + def forward(self, x, block_mask=None, pos_emb=None): + x = x + self.drop(self.attn(self.norm1(x), block_mask=block_mask, pos_emb=pos_emb)) x = x + self.drop(self.ffn(self.norm2(x))) return x @@ -167,6 +276,7 @@ class TransformerEncoder(NeuralModule, Exportable, AccessMixin): feed-forward input/output size). Also known as ``hidden_size`` in HuggingFace ``transformers`` configs and ``embed_dim``/``d_model`` in PyTorch's ``nn.TransformerEncoderLayer``. + n_heads: Number of attention heads. n_layers: Number of Transformer blocks. feat_out: Output feature dimension. Defaults to ``d_model``. @@ -197,6 +307,23 @@ class TransformerEncoder(NeuralModule, Exportable, AccessMixin): Transformer block (BERT/ViT-style). Set False to match pre-norm Transformers such as Whisper or GPT-2 — required when loading pretrained weights from those checkpoints. + self_attention_model: Type of positional encoding and attention scoring scheme. Mirrors + the Conformer encoder's ``self_attention_model`` choices, plus a ``"no_pos"`` option: + + - ``"rel_pos"`` (default): Transformer-XL relative positional encoding + (https://arxiv.org/abs/1901.02860). The relative-position bias is computed in each + layer and injected into FlexAttention via a ``score_mod`` closure (the (b)+(d) + terms) plus a ``Q + pos_bias_u`` query rewrite (the (c) term), so the kernel stays + FlexAttention. + - ``"abs_pos"``: sinusoidal absolute positional encoding added to the embeddings + before the first block; standard ``Q @ K^T`` attention via FlexAttention. + - ``"no_pos"`` (or ``None``): no positional encoding at all — pre-encoder output + flows straight into ``embed_norm`` and the Transformer blocks. ``xscaling``, + ``pos_emb_max_len``, ``dropout_pre_encoder`` and ``dropout_emb`` have no effect + in this mode. ``None`` is accepted as a YAML-friendly alias for ``"no_pos"`` + (an unset field in a config maps to ``None``). + + ``"rel_pos_local_attn"`` is not implemented yet. pos_emb_max_len: Initial maximum length for sinusoidal positional embeddings. xscaling: If True, scale embeddings by ``sqrt(d_model)`` before adding positional encodings, following "Attention Is All You Need" article. Originally intended to balance the magnitude @@ -277,6 +404,7 @@ def __init__( qk_norm: bool = False, ff_expansion: float = 4.0, pre_block_norm: bool = True, + self_attention_model: Optional[str] = "rel_pos", pos_emb_max_len: int = 5000, xscaling: bool = False, stochastic_depth_drop_prob: float = 0.0, @@ -290,6 +418,16 @@ def __init__( raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads}).") if attn_mode != "full": raise ValueError(f"attn_mode='{attn_mode}' is not yet supported. Currently only 'full' is available.") + # ``None`` is accepted as a YAML-friendly alias for ``"no_pos"`` (an unset field in a + # config simply maps to None) — normalize here so the rest of the module only deals with + # the string form. + if self_attention_model is None: + self_attention_model = "no_pos" + if self_attention_model not in ("abs_pos", "rel_pos", "no_pos"): + raise ValueError( + f"self_attention_model='{self_attention_model}' is not supported. " + "Currently only 'abs_pos', 'rel_pos', and 'no_pos' (or None) are available." + ) if dropout_pre_encoder is None: dropout_pre_encoder = drop_rate if subsampling == 'feature-stacking': @@ -309,6 +447,7 @@ def __init__( pre_block_norm=pre_block_norm, subsampling_factor=subsampling_factor, attn_mode=attn_mode, + self_attention_model=self_attention_model, ) self.d_model = d_model self.n_layers = n_layers @@ -317,6 +456,7 @@ def __init__( self.subsampling_factor = subsampling_factor self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor self.sync_max_audio_length = sync_max_audio_length + self.self_attention_model = self_attention_model if subsampling_conv_channels == -1: subsampling_conv_channels = d_model @@ -350,13 +490,24 @@ def __init__( else: self.xscale = None self.pos_emb_max_len = pos_emb_max_len - self.pos_enc = PositionalEncoding( - d_model=d_model, - dropout_rate=dropout_pre_encoder, - max_len=pos_emb_max_len, - xscale=self.xscale, - dropout_rate_emb=dropout_emb, - ) + if self_attention_model == "rel_pos": + self.pos_enc = RelPositionalEncoding( + d_model=d_model, + dropout_rate=dropout_pre_encoder, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=dropout_emb, + ) + elif self_attention_model == "abs_pos": + self.pos_enc = PositionalEncoding( + d_model=d_model, + dropout_rate=dropout_pre_encoder, + max_len=pos_emb_max_len, + xscale=self.xscale, + dropout_rate_emb=dropout_emb, + ) + else: # "no_pos" + self.pos_enc = None self.embed_norm = nn.LayerNorm(d_model) if pre_block_norm else nn.Identity() self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(n_layers)]) self.final_norm = nn.LayerNorm(d_model) @@ -430,7 +581,10 @@ def forward_internal(self, audio_signal, length, bypass_pre_encode=False): x = audio_signal length = length.to(torch.int64) - x, _ = self.pos_enc(x=x) + if self.pos_enc is not None: + x, pos_emb = self.pos_enc(x=x) + else: # "no_pos": pre-encoder output flows in unchanged + pos_emb = None x = self.embed_norm(x) B, T, _ = x.shape @@ -439,9 +593,13 @@ def forward_internal(self, audio_signal, length, bypass_pre_encode=False): else: block_mask = None + # For ``abs_pos`` the positional information is already baked into ``x``, so we don't + # need to thread ``pos_emb`` through each layer; only ``rel_pos`` consumes it. + layer_pos_emb = pos_emb if self.self_attention_model == "rel_pos" else None + for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): original_signal = x - x = layer(x, block_mask=block_mask) + x = layer(x, block_mask=block_mask, pos_emb=layer_pos_emb) if self.training and drop_prob > 0.0: should_drop = torch.rand(1, device=x.device) < drop_prob @@ -488,6 +646,8 @@ def update_max_seq_length(self, seq_length: int, device): def set_max_audio_length(self, max_audio_length): """Sets maximum input length and extends positional encodings if needed.""" self.max_audio_length = max_audio_length + if self.pos_enc is None: # "no_pos" mode has no buffer to extend + return device = next(self.parameters()).device dtype = next(self.parameters()).dtype self.pos_enc.extend_pe(max_audio_length, device, dtype) diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index fe19eaa9b94b..4b0d4cb3b3a1 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -38,15 +38,19 @@ def test_default_config(self): assert cfg.pre_block_norm is True assert cfg.subsampling_factor == 4 assert cfg.attn_mode == "full" + assert cfg.self_attention_model == "rel_pos" @pytest.mark.unit def test_custom_config(self): - cfg = TransformerEncoderConfig(feat_in=128, d_model=1280, n_heads=16, n_layers=32, qk_norm=True) + cfg = TransformerEncoderConfig( + feat_in=128, d_model=1280, n_heads=16, n_layers=32, qk_norm=True, self_attention_model="abs_pos" + ) assert cfg.feat_in == 128 assert cfg.d_model == 1280 assert cfg.n_heads == 16 assert cfg.n_layers == 32 assert cfg.qk_norm is True + assert cfg.self_attention_model == "abs_pos" class TestFeatureStacking: @@ -171,9 +175,14 @@ def test_stochastic_depth_model_creation(self): stochastic_depth_start_layer=start_layer, ) - @pytest.mark.pleasefixme def test_stochastic_depth_forward(self): - """Testing that forward works and we get randomness during training, but not during eval.""" + """Testing that forward works and we get randomness during training, but not during eval. + + The forwards are wrapped in ``torch.no_grad()`` because FlexAttention's CPU path raises + ``NotImplementedError`` if any input requires gradients. ``torch.no_grad()`` does not + touch ``model.training``, so the stochastic-depth Bernoulli branch (driven by + ``torch.rand(1) < drop_prob``, not autograd) still fires in train mode. + """ random_input = torch.rand((1, 2, 16)) random_length = torch.tensor([16], dtype=torch.int64) @@ -190,8 +199,9 @@ def test_stochastic_depth_forward(self): ) model.train() outputs = [None] * 5 - for i in range(5): - outputs[i] = model(audio_signal=random_input, length=random_length)[0] + with torch.no_grad(): + for i in range(5): + outputs[i] = model(audio_signal=random_input, length=random_length)[0] # checking that not all outputs are the same num_diff = 0 for i in range(1, 5): @@ -201,8 +211,9 @@ def test_stochastic_depth_forward(self): model.eval() outputs = [None] * 5 - for i in range(5): - outputs[i] = model(audio_signal=random_input, length=random_length)[0] + with torch.no_grad(): + for i in range(5): + outputs[i] = model(audio_signal=random_input, length=random_length)[0] # checking that not all outputs are the same num_diff = 0 for i in range(1, 5): @@ -215,7 +226,13 @@ class TestBypassPreEncode: """Testing bypass pre-encode functionality.""" def test_bypass_pre_encode_forward(self): - """Testing that forward works with "bypass pre-encode" mode.""" + """Testing that forward works with "bypass pre-encode" mode. + + Forwards are wrapped in ``torch.no_grad()`` so the test runs on CPU as well as GPU: + FlexAttention's CPU path refuses to run when any input requires gradients (parameters + of an ``nn.Module`` do by default), and we are only checking output shapes here, never + calling ``.backward()``. + """ # For pre-encoded embeddings, the shape is (batch_size, n_frames, emb_dim) batch_size = 2 n_frames, emb_dim, feat_out = 17, 16, 8 @@ -234,11 +251,13 @@ def test_bypass_pre_encode_forward(self): dropout_emb=0.0, ) model.train() - fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] assert fwd_outputs.shape == (batch_size, feat_out, n_frames) model.eval() - fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=random_input, length=random_length, bypass_pre_encode=True)[0] assert fwd_outputs.shape == (batch_size, feat_out, n_frames) def test_error_shape_invalid_bypass_pre_encode_forward(self): @@ -283,12 +302,19 @@ def test_error_shape_invalid_bypass_pre_encode_forward(self): model(audio_signal=feat_input, length=input_length, bypass_pre_encode=True) # Test with bypass_pre_encode = True, given the correct input pre_encode_input. + # NB: forwards that actually reach FlexAttention are wrapped in ``torch.no_grad()`` so + # the test passes on CPU (FlexAttention's CPU path refuses inputs that require grad). + # The ``pytest.raises(ValueError)`` blocks above/below intentionally do *not* need this + # wrapper because the shape check in ``TransformerEncoder.forward()`` raises before any + # attention computation. model.train() - fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] assert fwd_outputs.shape == (batch_size, feat_out, n_frames) model.eval() - fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=pre_encode_input, length=input_length, bypass_pre_encode=True)[0] assert fwd_outputs.shape == (batch_size, feat_out, n_frames) # Test with bypass_pre_encode = False, should be feat_input but given pre_encode_input. @@ -302,11 +328,13 @@ def test_error_shape_invalid_bypass_pre_encode_forward(self): # Test with bypass_pre_encode = False, given the correct input feat_input. model.train() - fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] assert fwd_outputs.shape == (batch_size, feat_out, sub_sampled_n_frames) model.eval() - fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] + with torch.no_grad(): + fwd_outputs = model(audio_signal=feat_input, length=input_length, bypass_pre_encode=False)[0] assert fwd_outputs.shape == (batch_size, feat_out, sub_sampled_n_frames) @@ -476,3 +504,129 @@ def test_backward_pass(self): for name, param in model.named_parameters(): assert param.grad is not None, f"No gradient for {name}" assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}" + + +class TestSelfAttentionModel: + """Tests for the ``self_attention_model`` positional encoding option.""" + + @pytest.mark.unit + def test_default_is_rel_pos(self): + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2) + assert model.self_attention_model == "rel_pos" + + @pytest.mark.unit + @pytest.mark.parametrize("mode", ["abs_pos", "rel_pos", "no_pos"]) + def test_valid_modes_are_accepted(self, mode): + model = TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode + ) + assert model.self_attention_model == mode + + @pytest.mark.unit + def test_none_aliases_no_pos(self): + """Passing ``self_attention_model=None`` must be equivalent to ``"no_pos"``.""" + model = TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=None + ) + assert model.self_attention_model == "no_pos" + assert model.pos_enc is None + + @pytest.mark.unit + def test_invalid_mode_raises(self): + with pytest.raises(ValueError, match="not supported"): + TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model="rel_pos_local_attn" + ) + + @pytest.mark.unit + def test_rel_pos_attention_params_allocated(self): + """rel_pos mode allocates the Transformer-XL bias parameters per attention layer.""" + d_model, n_heads, n_layers = 64, 4, 2 + model = TransformerEncoder( + feat_in=128, d_model=d_model, n_heads=n_heads, n_layers=n_layers, self_attention_model="rel_pos" + ) + head_dim = d_model // n_heads + assert model.pos_enc is not None + for layer in model.layers: + attn = layer.attn + assert attn.linear_pos is not None + assert attn.pos_bias_u is not None + assert attn.pos_bias_v is not None + assert attn.pos_bias_u.shape == (n_heads, head_dim) + assert attn.pos_bias_v.shape == (n_heads, head_dim) + + @pytest.mark.unit + @pytest.mark.parametrize("mode", ["abs_pos", "no_pos"]) + def test_non_rel_pos_modes_have_no_rel_params(self, mode): + """abs_pos and no_pos modes must not allocate the rel-pos parameters.""" + model = TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode + ) + for layer in model.layers: + attn = layer.attn + assert attn.linear_pos is None + assert attn.pos_bias_u is None + assert attn.pos_bias_v is None + + @pytest.mark.unit + def test_no_pos_has_no_positional_encoding_module(self): + model = TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model="no_pos" + ) + assert model.pos_enc is None + # set_max_audio_length is invoked in __init__; it must not crash for no_pos and must + # still record the requested max length so update_max_seq_length works normally. + assert model.max_audio_length == model.pos_emb_max_len + + @pytest.mark.unit + @pytest.mark.parametrize("mode", ["abs_pos", "rel_pos", "no_pos", None]) + def test_forward_each_mode_cpu(self, mode): + """Each ``self_attention_model`` choice (including ``None``) must produce a valid forward.""" + model = TransformerEncoder( + feat_in=128, + d_model=64, + n_heads=4, + n_layers=2, + drop_rate=0.0, + subsampling_factor=4, + self_attention_model=mode, + ) + model.eval() + + B, C, T = 2, 128, 200 + x = torch.randn(B, C, T) + lengths = torch.tensor([T, 160]) + + with torch.no_grad(): + out, out_lengths = model(audio_signal=x, length=lengths) + + assert out.shape == (B, 64, T // 4) + assert out_lengths[0].item() == T // 4 + assert out_lengths[1].item() == 160 // 4 + assert not torch.isnan(out).any() + + @pytest.mark.unit + def test_rel_pos_broadcasts_when_T_differs_from_n_heads(self): + """Regression test for the Transformer-XL bias broadcasting. + + ``pos_bias_{u,v}`` has shape ``(H, D)`` and must broadcast against the head axis of + ``q`` which has shape ``(B, H, T, D)``. A naive add would right-align ``H`` against + ``T`` and either crash (``T != H``) or silently apply the bias on the wrong axis + (``T == H``). This test exercises a configuration where ``T_attn != n_heads`` so the + broken broadcast would surface as an error. + """ + # 200 input frames / subsampling_factor=4 -> 50 attention frames; n_heads=4 -> T != H. + model = TransformerEncoder( + feat_in=128, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, self_attention_model="rel_pos" + ) + model.eval() + + B, C, T = 2, 128, 200 + x = torch.randn(B, C, T) + lengths = torch.tensor([T, 160]) + + with torch.no_grad(): + out, _ = model(audio_signal=x, length=lengths) + + assert out.shape == (B, 64, T // 4) + assert not torch.isnan(out).any() From 67259308afc6c5aa35f729e4e774f2cb822741e5 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 20 May 2026 23:13:35 -0700 Subject: [PATCH 5/5] Apply black formatting Signed-off-by: Taejin Park --- .../asr/modules/transformer_encoder.py | 2 +- .../collections/asr/test_transformer_encoder.py | 16 ++++------------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index 61f1a4d21e14..23f066bf3b3d 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -212,7 +212,7 @@ def _build_rel_pos_score_mod(self, q, pos_emb): matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) # (B, H, T, 2T - 1) # rel_shift converts absolute-relative-position columns into (query, key) columns; # keep the first T to land in (B, H, T, T) bias space. - self._rel_pos_bias = self._rel_shift(matrix_bd)[..., :T] * (D ** -0.5) + self._rel_pos_bias = self._rel_shift(matrix_bd)[..., :T] * (D**-0.5) # Matrix c: fold u @ K^T into FlexAttention by rewriting Q as (Q + u). return self._rel_pos_score_mod, q + bias_u diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index 4b0d4cb3b3a1..61cb1a63470f 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -517,17 +517,13 @@ def test_default_is_rel_pos(self): @pytest.mark.unit @pytest.mark.parametrize("mode", ["abs_pos", "rel_pos", "no_pos"]) def test_valid_modes_are_accepted(self, mode): - model = TransformerEncoder( - feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode - ) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode) assert model.self_attention_model == mode @pytest.mark.unit def test_none_aliases_no_pos(self): """Passing ``self_attention_model=None`` must be equivalent to ``"no_pos"``.""" - model = TransformerEncoder( - feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=None - ) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=None) assert model.self_attention_model == "no_pos" assert model.pos_enc is None @@ -559,9 +555,7 @@ def test_rel_pos_attention_params_allocated(self): @pytest.mark.parametrize("mode", ["abs_pos", "no_pos"]) def test_non_rel_pos_modes_have_no_rel_params(self, mode): """abs_pos and no_pos modes must not allocate the rel-pos parameters.""" - model = TransformerEncoder( - feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode - ) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model=mode) for layer in model.layers: attn = layer.attn assert attn.linear_pos is None @@ -570,9 +564,7 @@ def test_non_rel_pos_modes_have_no_rel_params(self, mode): @pytest.mark.unit def test_no_pos_has_no_positional_encoding_module(self): - model = TransformerEncoder( - feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model="no_pos" - ) + model = TransformerEncoder(feat_in=128, d_model=64, n_heads=4, n_layers=2, self_attention_model="no_pos") assert model.pos_enc is None # set_max_audio_length is invoked in __init__; it must not crash for no_pos and must # still record the requested max length so update_max_seq_length works normally.