diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index a4ab84c75cb5..161e6cc338b3 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -56,6 +56,7 @@
 )
 from nemo.utils import logging
 
+
 __all__ = ['ConformerEncoder', 'ConformerMultiLayerFeatureExtractor']
 
 
@@ -1065,15 +1066,27 @@ def setup_streaming_params(
         else:
             streaming_cfg.pre_encode_cache_size = 0
 
+        # Number of subsampled output frames produced from the pre-encode left-context
+        # cache. This is what we drop after subsampling so chunked inference matches a
+        # full pass. For convolutional subsampling with stride > 1, this is NOT a simple
+        # floor division — see ``ConvSubsampling.get_streaming_drop_size``.
         if isinstance(streaming_cfg.pre_encode_cache_size, list):
-            if streaming_cfg.pre_encode_cache_size[1] >= 1:
-                streaming_cfg.drop_extra_pre_encoded = (
-                    1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor
-                )
-            else:
-                streaming_cfg.drop_extra_pre_encoded = 0
+            pre_encode_cache = streaming_cfg.pre_encode_cache_size[1]
+        else:
+            pre_encode_cache = streaming_cfg.pre_encode_cache_size
+
+        if pre_encode_cache <= 0:
+            streaming_cfg.drop_extra_pre_encoded = 0
+        elif hasattr(self.pre_encode, "get_streaming_drop_size"):
+            streaming_cfg.drop_extra_pre_encoded = self.pre_encode.get_streaming_drop_size(pre_encode_cache)
         else:
-            streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor
+            # Legacy fallback for custom pre_encode modules that pre-date
+            # ``get_streaming_drop_size``. Keep the formula each cache-size form used before
+            # this helper existed, so such modules see no change in streaming output.
+            if isinstance(streaming_cfg.pre_encode_cache_size, list):
+                streaming_cfg.drop_extra_pre_encoded = 1 + (pre_encode_cache - 1) // self.subsampling_factor
+            else:
+                streaming_cfg.drop_extra_pre_encoded = pre_encode_cache // self.subsampling_factor
 
         for m in self.layers.modules():
             if hasattr(m, "_max_cache_len"):
diff --git a/nemo/collections/asr/parts/submodules/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py
index 7f9fc606991c..6df7c14e7fe0 100644
--- a/nemo/collections/asr/parts/submodules/subsampling.py
+++ b/nemo/collections/asr/parts/submodules/subsampling.py
@@ -46,6 +46,18 @@ def get_sampling_frames(self):
     def get_streaming_cache_size(self):
         return 0
 
+    def get_streaming_drop_size(self, cache_size: int) -> int:
+        """Number of leading output frames formed entirely from `cache_size` cached input frames.
+
+        Used by streaming encoders to know how many leading frames of the encoder output
+        correspond to the pre-encode left-context cache, so they can be dropped after
+        subsampling. Every `subsampling_factor` consecutive input frames stack into one
+        output frame, so only complete groups are counted (floor division).
+        """
+        if cache_size <= 0:
+            return 0
+        return cache_size // self.subsampling_factor
+
     def forward(self, x, lengths):
         b, t, h = x.size()
         pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor
@@ -382,6 +394,30 @@ def get_sampling_frames(self):
     def get_streaming_cache_size(self):
         return [0, self.subsampling_factor + 1]
 
+    def get_streaming_drop_size(self, cache_size: int) -> int:
+        """Number of subsampled output frames produced from `cache_size` input frames.
+
+        For convolutional subsampling with stride > 1, the length transformation through
+        each layer is not a simple floor division: it follows the recurrence
+        `L_next = floor((L + all_paddings - kernel_size) / stride) + 1` (or `ceil` when
+        `_ceil_mode` is set). Composed over `_sampling_num` layers, the result is what
+        `calc_length` already computes for the actual forward pass. Using the same helper
+        here keeps the streaming-drop count consistent with the encoder's own length
+        bookkeeping for arbitrary `cache_size`, instead of a divisor approximation that
+        only happens to match the default `subsampling_factor + 1` cache size.
+        """
+        if cache_size <= 0:
+            return 0
+        out = calc_length(
+            torch.tensor(cache_size, dtype=torch.float),
+            all_paddings=self._left_padding + self._right_padding,
+            kernel_size=self._kernel_size,
+            stride=self._stride,
+            ceil_mode=self._ceil_mode,
+            repeat_num=self._sampling_num,
+        )
+        return int(out.item())
+
     def forward(self, x, lengths):
         out_lengths = calc_length(
             lengths,
diff --git a/tests/collections/asr/test_asr_subsampling.py b/tests/collections/asr/test_asr_subsampling.py
index 8f638afb73c0..98339c56a96b 100644
--- a/tests/collections/asr/test_asr_subsampling.py
+++ b/tests/collections/asr/test_asr_subsampling.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import pytest
 import torch
-
 from nemo.collections.asr.models import ASRModel
 
 
@@ -59,3 +58,100 @@ def test_forward(self):
         assert diff <= 0.2
         diff = torch.mean(torch.abs(logprobs_batch4_split - logprobs_batch4_nosplit))
         assert diff <= 0.2
+
+
+class TestStreamingDropExtraPreEncoded:
+    """``ConvSubsampling.get_streaming_drop_size`` must match what the encoder actually
+    produces from a ``cache_size``-long input segment.
+
+    Regression test for the streaming/full-pass mismatch reported in
+    https://github.com/NVIDIA-NeMo/NeMo/issues/15482 — the old formula
+    ``1 + (cache_size - 1) // subsampling_factor`` diverges from the true convolutional
+    recurrence for arbitrary ``pre_encode_cache_size``.
+    """
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize(
+        "subsampling,subsampling_factor",
+        [
+            ("striding", 4),
+            ("striding", 8),
+            ("dw_striding", 4),
+            ("dw_striding", 8),
+        ],
+    )
+    @pytest.mark.parametrize("cache_size", [1, 4, 8, 9, 11, 16, 32])
+    def test_drop_size_matches_forward(self, subsampling, subsampling_factor, cache_size):
+        """For a causal conv subsampling, the number of output frames the actual
+        ``forward`` returns from a ``cache_size``-long input must equal
+        ``get_streaming_drop_size(cache_size)``.
+        """
+        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling
+
+        feat_in = 80
+        sub = ConvSubsampling(
+            subsampling=subsampling,
+            subsampling_factor=subsampling_factor,
+            feat_in=feat_in,
+            feat_out=16,
+            conv_channels=16,
+            subsampling_conv_chunking_factor=1,
+            is_causal=True,
+        )
+        sub.eval()
+        x = torch.zeros(1, cache_size, feat_in)
+        lengths = torch.tensor([cache_size], dtype=torch.int64)
+        with torch.no_grad():
+            _, out_lengths = sub(x, lengths)
+        expected = int(out_lengths[0].item())
+        assert sub.get_streaming_drop_size(cache_size) == expected
+
+    @pytest.mark.unit
+    def test_drop_size_zero_for_empty_cache(self):
+        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, StackingSubsampling
+
+        sub = ConvSubsampling(
+            subsampling="striding",
+            subsampling_factor=8,
+            feat_in=80,
+            feat_out=16,
+            conv_channels=16,
+            subsampling_conv_chunking_factor=1,
+            is_causal=True,
+        )
+        assert sub.get_streaming_drop_size(0) == 0
+
+        stack = StackingSubsampling(subsampling_factor=4, feat_in=80, feat_out=16)
+        assert stack.get_streaming_drop_size(0) == 0
+
+    @pytest.mark.unit
+    def test_drop_size_legacy_formula_diverges_for_non_default_cache(self):
+        """Document the bug being fixed: at the issue-reported case ``cache_size=11``
+        with ``subsampling_factor=8``, the old formula returns 2 but the true value is 3.
+        """
+        from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling
+
+        sub = ConvSubsampling(
+            subsampling="striding",
+            subsampling_factor=8,
+            feat_in=80,
+            feat_out=16,
+            conv_channels=16,
+            subsampling_conv_chunking_factor=1,
+            is_causal=True,
+        )
+        cache_size = 11
+        legacy = 1 + (cache_size - 1) // 8
+        assert legacy == 2  # old, wrong
+        assert sub.get_streaming_drop_size(cache_size) == 3  # new, matches the forward pass
+
+    @pytest.mark.unit
+    def test_stacking_drop_size(self):
+        from nemo.collections.asr.parts.submodules.subsampling import StackingSubsampling
+
+        stack = StackingSubsampling(subsampling_factor=4, feat_in=80, feat_out=16)
+        # StackingSubsampling.get_streaming_cache_size() returns 0 by default, but the
+        # helper should still answer sensibly for any positive cache_size.
+        assert stack.get_streaming_drop_size(4) == 1
+        assert stack.get_streaming_drop_size(7) == 1
+        assert stack.get_streaming_drop_size(8) == 2