From b6ad0feddf600506f5b64f03598e6387bfbe5552 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 14 Jan 2026 19:05:28 +0400 Subject: [PATCH 01/13] add streaming beam search Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 152 +++++++++++-- .../parts/submodules/rnnt_beam_decoding.py | 2 +- .../submodules/rnnt_maes_batched_computer.py | 108 ++++++--- .../submodules/rnnt_malsd_batched_computer.py | 204 ++++++++++++----- .../transducer_decoding/label_looping_base.py | 2 + .../utils/batched_beam_decoding_utils.py | 207 ++++++++++++++++++ .../asr/parts/utils/streaming_utils.py | 11 +- 7 files changed, 587 insertions(+), 99 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 88831aaed00c..686eb8bbea25 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -75,9 +75,12 @@ from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import ( GreedyBatchedLabelLoopingComputerBase, ) +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses @@ -232,9 +235,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Change Decoding Config with open_dict(cfg.decoding): - if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True: + if cfg.decoding.strategy == "greedy_batch": + if cfg.decoding.greedy.loop_labels is not True: + raise NotImplementedError( + "This script supports `greedy_batch` strategy only with Label-Looping algorithm" + ) + cfg.decoding.greedy.preserve_alignments = False + elif cfg.decoding.strategy == "malsd_batch": + # MALSD beam search is supported for streaming + pass + elif cfg.decoding.strategy == "maes_batch": + # MAES beam search is supported for streaming + pass + else: raise NotImplementedError( - "This script currently supports only `greedy_batch` strategy with Label-Looping algorithm" + "This script currently supports only `greedy_batch` with Label-Looping or `malsd_batch` strategy with MALSD" ) cfg.decoding.tdt_include_token_duration = cfg.timestamps cfg.decoding.greedy.preserve_alignments = False @@ -278,7 +293,17 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: asr_model.preprocessor.featurizer.pad_to = 0 asr_model.eval() - decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer + # Get decoding computer based on strategy + if cfg.decoding.strategy == "greedy_batch": + decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer + elif cfg.decoding.strategy == "malsd_batch": + # Beam search strategies use _decoding_computer (private attribute) + decoding_computer: ModifiedALSDBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer + elif cfg.decoding.strategy == "maes_batch": + # MAES beam search returns BatchedBeamHyps + decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer + else: + raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}") audio_sample_rate = model_cfg.preprocessor['sample_rate'] @@ -394,7 +419,49 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) rest_audio_lengths = audio_batch_lengths.clone() - # iterate over audio samples + # For MALSD: batched_hyps is stored in state and reused (no merge needed) + # For greedy: fresh BatchedHyps created each chunk, needs merging + is_beam_search = isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer) or isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) + + # ============================================================================ + # ENCODER PROCESSING MODES: + # + # MODE 1: FULL ENCODER PASS (Non-streaming simulation) + # - Runs encoder once on entire audio upfront + # - Extracts chunks from pre-computed output + # - Use this to verify consistency with non-streaming scripts + # - TO ENABLE: Uncomment all "MODE 1" sections below + # + # MODE 2: STREAMING CHUNKED ENCODER (Default/Original) + # - Runs encoder separately for each chunk with context + # - True streaming behavior with left-chunk-right context windows + # - Currently ACTIVE + # - TO KEEP: Leave "MODE 2" sections uncommented + # ============================================================================ + + # import pdb; pdb.set_trace() + # ============================================================================ + # MODE 1: FULL ENCODER PASS (for testing consistency with non-streaming) + # TO ENABLE: Uncomment this section and comment out MODE 2 sections below + # ============================================================================ + # # TESTING: Run encoder once on full audio (not chunked) + # full_encoder_output, full_encoder_output_len = asr_model( + # input_signal=audio_batch, + # input_signal_length=audio_batch_lengths, + # ) + # full_encoder_output = full_encoder_output.transpose(1, 2) # [B, T, C] + + # # do not recalculate joint projection, project only once + # full_encoder_output_projected = asr_model.joint.project_encoder(full_encoder_output) + # full_encoder_output_projected_len = full_encoder_output_len + + # # Track per-sample frame positions in the full encoder output + # # Different samples may have different lengths, so we need per-sample tracking + # encoder_frame_positions = torch.zeros([batch_size], dtype=torch.long, device=device) + # ============================================================================ + + # import pdb; pdb.set_trace() + # iterate over audio samples (but only for decoding chunks) while left_sample < audio_batch.shape[1]: # add samples to buffer chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample @@ -412,6 +479,50 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: is_last_chunk_batch=is_last_chunk_batch, ) + # ======================================================================== + # MODE 1: FULL ENCODER PASS - Extract pre-computed chunks + # TO ENABLE: Uncomment this section and comment out MODE 2 section below + # ======================================================================== + # Use buffer's context to know the actual chunk sizes (handles variable lengths and last chunks) + # encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples) + + # # Extract chunks for each sample from their current position + # # Since samples can be at different positions, we need to handle each separately + # max_chunk_size_this_iter = encoder_context_batch.chunk.max().item() + # encoder_output_chunk = torch.zeros( + # [batch_size, max_chunk_size_this_iter, full_encoder_output_projected.shape[2]], + # dtype=full_encoder_output_projected.dtype, + # device=device + # ) + + # # Extract the appropriate chunk for each sample + # for b_idx in range(batch_size): + # start_pos = encoder_frame_positions[b_idx].item() + # chunk_len = encoder_context_batch.chunk[b_idx].item() + # end_pos = min(start_pos + chunk_len, full_encoder_output_projected.shape[1]) + # actual_len = end_pos - start_pos + # if actual_len > 0: + # encoder_output_chunk[b_idx, :actual_len] = full_encoder_output_projected[b_idx, start_pos:end_pos] + + # # Use the buffer's chunk size calculations (which properly handle per-sample lengths) + # encoder_out_len_chunk = encoder_context_batch.chunk + + # # decode only chunk frames (using pre-computed encoder output) + # chunk_batched_hyps, _, state = decoding_computer( + # x=encoder_output_chunk, + # out_len=encoder_out_len_chunk, + # prev_batched_state=state, + # ) + + # # Update per-sample positions + # encoder_frame_positions += encoder_context_batch.chunk + # ======================================================================== + + # import pdb; pdb.set_trace() + # ======================================================================== + # MODE 2: STREAMING CHUNKED ENCODER (ORIGINAL/DEFAULT) + # TO DISABLE: Comment out this section when using MODE 1 + # ======================================================================== # get encoder output using full buffer [left-chunk-right] encoder_output, encoder_output_len = asr_model( input_signal=buffer.samples, @@ -435,24 +546,37 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: prev_batched_state=state, multi_biasing_ids=multi_biasing_ids, ) - # merge hyps with previous hyps - if current_batched_hyps is None: + # ======================================================================== + + # Handle hypothesis accumulation differently for beam search vs greedy + if is_beam_search: + # For beam search: same object reused across chunks (stored in state) current_batched_hyps = chunk_batched_hyps else: - current_batched_hyps.merge_(chunk_batched_hyps) + # For greedy: merge chunks using merge_ + if current_batched_hyps is None: + current_batched_hyps = chunk_batched_hyps + else: + current_batched_hyps.merge_(chunk_batched_hyps) # move to next sample rest_audio_lengths -= chunk_lengths_batch left_sample = right_sample right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk - # remove biasing requests from the decoder - if use_per_stream_biasing and audio_data.biasing_requests is not None: - for request in audio_data.biasing_requests: - if request is not None and request.multi_model_id is not None: - decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) - request.multi_model_id = None - all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size)) + # Convert batched hypotheses to list + if isinstance(current_batched_hyps, BatchedBeamHyps): + # MALSD beam search returns BatchedBeamHyps + all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True)) + else: + # Greedy batch returns BatchedHyps + # remove biasing requests from the decoder + if use_per_stream_biasing and audio_data.biasing_requests is not None: + for request in audio_data.biasing_requests: + if request is not None and request.multi_model_id is not None: + decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) + request.multi_model_id = None + all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size)) timer.stop(device=map_location) # convert text diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index d0b02a73daaf..1909212c81f2 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -1690,7 +1690,7 @@ def forward( self.joint.eval() inseq = encoder_output # [B, T, D] - batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen) + batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=inseq, out_len=logitlen) batch_size = encoder_output.shape[0] if self.return_best_hypothesis: diff --git a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py index 1af34f11ac8e..ef5458a5a762 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py @@ -16,6 +16,8 @@ import torch +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -121,6 +123,7 @@ def batched_modified_adaptive_expansion_search_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedBeamHyps] = None, ) -> BatchedBeamHyps: """ Pure PyTorch implementation @@ -134,22 +137,11 @@ def batched_modified_adaptive_expansion_search_torch( encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # init empty batched hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time, - device=device, - float_dtype=float_dtype, - store_prefix_hashes=True, - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) - + + # import pdb; pdb.set_trace() + # encoder_output_projected = encoder_output + # float_dtype = encoder_output.dtype + batch_indices = ( torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.beam_size).clone() ) # size: batch_size x beam_size @@ -162,6 +154,28 @@ def batched_modified_adaptive_expansion_search_torch( .clone() ) # size: batch_size x beam_size x beam_size + maes_expansion_beta + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps + time_indices = torch.zeros_like(beam_indices) + last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices) + safe_time_indices = torch.minimum(time_indices, last_timesteps) + active_mask = time_indices <= last_timesteps + else: + # init empty batched hypotheses + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time, + device=device, + float_dtype=float_dtype, + store_prefix_hashes=True, + ) + time_indices = torch.zeros_like(beam_indices) + safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len + last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices) + active_mask = time_indices <= last_timesteps + time_indices = torch.zeros_like(batch_indices) safe_time_indices = torch.zeros_like(time_indices) last_timesteps = (encoder_output_length - 1)[:, None].expand(batch_size, self.beam_size) @@ -176,14 +190,34 @@ def batched_modified_adaptive_expansion_search_torch( ) # vocab_size_no_blank lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha - decoder_output, decoder_state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty( + [ + batch_size * self.beam_size, + ], + dtype=float_dtype, + device=device, + ) + ) + + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size + ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states while active_mask.any(): # frames loop to_update = active_mask.clone() # mask for expansions loop + # import pdb; pdb.set_trace() # step 1: get joint output logits = ( @@ -235,6 +269,7 @@ def batched_modified_adaptive_expansion_search_torch( next_labels = next_labels.view(batch_size, -1)[batch_indices, idx] hyp_indices = expansion_beam_indices.view(batch_size, -1)[batch_indices, idx] + # import pdb; pdb.set_trace() # step 3.3: update batched beam hypotheses structure batched_hyps.add_results_(hyp_indices, next_labels, next_hyps_probs) @@ -311,6 +346,7 @@ def batched_modified_adaptive_expansion_search_torch( expansion_steps += 1 if to_update.any(): + # import pdb; pdb.set_trace() # step 4: force blank to active hypotheses next_hyps_probs = torch.where(to_update, batched_hyps.scores + logps[..., -1], batched_hyps.scores) next_labels = torch.where(to_update, self._blank_index, -1) @@ -320,8 +356,29 @@ def batched_modified_adaptive_expansion_search_torch( time_indices += 1 active_mask = time_indices <= last_timesteps safe_time_indices = torch.where(active_mask, time_indices, last_timesteps) - - return batched_hyps + + # import pdb; pdb.set_trace() + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + batched_hyps.next_timestamp.fill_(0) + decoding_state = BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=( + torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + if prev_batched_state is not None + else last_labels + ), + decoded_lengths=( + encoder_output_length.clone() + if prev_batched_state is None + else encoder_output_length + prev_batched_state.decoded_lengths + ), + # fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + time_jumps=None, + batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + ) + + return batched_hyps, None, decoding_state def combine_scores(self, log_probs, lm_scores): """ @@ -438,5 +495,6 @@ def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: - return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len) + prev_batched_state: Optional[BatchedBeamHyps] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index f1ef93c64f5d..ef9a08e1cea3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -19,6 +19,8 @@ import torch.nn.functional as F from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -350,7 +352,8 @@ def modified_alsd_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Pytorch implementation of the batched ALSD algorithm for RNN-T. Args: @@ -367,23 +370,14 @@ def modified_alsd_torch( if torch.is_autocast_enabled(): encoder_output = encoder_output.to(torch.get_autocast_gpu_dtype()) - # do not recalculate joint projection, project only once + # # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # init empty batched beam hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, - device=device, - float_dtype=float_dtype, - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) + + # encoder_output_projected = encoder_output + # float_dtype = encoder_output.dtype + + # import pdb; pdb.set_trace() batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] @@ -396,6 +390,19 @@ def modified_alsd_torch( .clone() ) # size: batch_size x beam_size x beam_size + # Reuse batched beam hypotheses from state if continuing, otherwise create new + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps + else: + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, + device=device, + float_dtype=float_dtype, + ) + time_indices = torch.zeros_like(batch_beam_indices) safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) @@ -403,43 +410,73 @@ def modified_alsd_torch( # setup fusion models if available if self.fusion_models is not None: - fusion_states_list = [] - fusion_states_candidates_list = [] - fusion_scores_list = [] - for fusion_model_idx, fusion_model in enumerate(self.fusion_models): - fusion_model.to(device) - fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - fusion_scores, fusion_states_candidates = fusion_model.advance( - states=fusion_states - ) # vocab_size_no_blank + if prev_batched_state is None or prev_batched_state.fusion_states_list is None: + fusion_states_list = [] + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) # vocab_size_no_blank - fusion_scores = ( - fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) - * self.fusion_models_alpha[fusion_model_idx] - ) - fusion_states_list.append(fusion_states) - fusion_states_candidates_list.append(fusion_states_candidates) - fusion_scores_list.append(fusion_scores) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = prev_batched_state.fusion_states_list + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states_list[fusion_model_idx] + ) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = None - decoder_state = self.decoder.initialize_state( - torch.empty( - [ - batch_size * self.beam_size, - ], - dtype=float_dtype, - device=device, + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty( + [ + batch_size * self.beam_size, + ], + dtype=float_dtype, + device=device, + ) ) - ) - - decoder_output, state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] - self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size + ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states + + step1=0 while active_mask.any(): + # import pdb; pdb.set_trace() + # if step1 >= 0: # step 1: get joint output + fuse with fusion models (if present) + # print(f"Encoder output length: {encoder_output_length}") logits = ( self.joint.joint_after_projection( encoder_output_projected[batch_beam_indices.view(-1), safe_time_indices.view(-1)].unsqueeze(1), @@ -513,11 +550,16 @@ def modified_alsd_torch( labels_top_k.reshape(batch_size, -1), dim=-1, index=hyps_candidates_indices ) # labels for extended hypotheses + # import pdb; pdb.set_trace() + batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) # step 3: store results - if self.max_symbols is None: - batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) - else: - batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob) + + # if step1 == 37: + # print(f"Step {step1}") + # if self.max_symbols is None: + # batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) + # else: + # batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. batched_hyps.recombine_hyps_() @@ -587,12 +629,51 @@ def modified_alsd_torch( fusion_states_candidates_list[fusion_model_idx] = fusion_states_candidates fusion_scores_list[fusion_model_idx] = fusion_scores + # import pdb; pdb.set_trace() # step 6: update time indices + active mask - time_indices = batched_hyps.next_timestamp + time_indices = torch.gather(time_indices, dim=-1, index=hyps_indices) + (next_labels == self._blank_index) torch.minimum(time_indices, last_timesteps, out=safe_time_indices) active_mask = time_indices <= last_timesteps + # if step1 == 24: + # import pdb; pdb.set_trace() + # print(f"Step {step1}") + # print(f"Time indices: {time_indices}") + # print(F"Scores: {batched_hyps.scores}") + # print(F"Trancripts: {batched_hyps.transcript_wb[..., :batched_hyps.current_lengths_wb.max()]}") + # print(F"Trancript ptrs: {batched_hyps.transcript_wb_prev_ptr[..., :batched_hyps.current_lengths_wb.max()]}") + # print(F"Trancript_lengths: {batched_hyps.current_lengths_wb}") + + step1 += 1 + + # # fix timestamps for iterative decoding + # if not is_beam_search: + # if prev_batched_state is not None: + # batched_hyps.timestamps += prev_batched_state.decoded_lengths.unsqueeze(1).unsqueeze(1) + + # NB: last labels can not exist (nothing decoded on this step). + # return the last labels from the previous state in this case + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + batched_hyps.next_timestamp.fill_(0) + decoding_state = BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=( + torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + if prev_batched_state is not None + else last_labels + ), + decoded_lengths=( + encoder_output_length.clone() + if prev_batched_state is None + else encoder_output_length + prev_batched_state.decoded_lengths + ), + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + time_jumps=None, + batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + ) - return batched_hyps + # import pdb; pdb.set_trace() + return batched_hyps, None, decoding_state def topk_fusion_model(self, fusion_scores_list, log_probs, eps=1e-2): """ @@ -1176,9 +1257,20 @@ def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS if self.cuda_graphs_mode is not None and x.device.type == "cuda": + # CUDA graphs don't support streaming yet, fall back to torch implementation + if prev_batched_state is not None: + return self.modified_alsd_torch( + encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state + ) with torch.amp.autocast(device_type="cuda", enabled=False): - return self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) + batched_hyps = self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) + # Return empty state for non-streaming case + return batched_hyps, None, None - return self.modified_alsd_torch(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_torch( + encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state + ) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py index 30a4b97366cd..c71e21997326 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py @@ -47,6 +47,7 @@ class BatchedLabelLoopingState: decoded_lengths: torch.Tensor fusion_states_list: list[torch.Tensor] = field(default_factory=list) time_jumps: torch.Tensor | None = None + batched_hyps: Any = None # For beam search: BatchedBeamHyps object to continue across chunks @dataclass @@ -290,6 +291,7 @@ def __call__( prev_batched_state: previous batched decoding state multi_biasing_ids: optional tensor [Batch] with ids of fused biasing models """ + self.cuda_graphs_mode = None if self.cuda_graphs_mode is not None and x.device.type == "cuda": # disable CUDA graphs if Mixed Precision is used due to incorrect behavior with torch.amp.autocast(device_type="cuda", enabled=False): diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index e0af3b7623d7..3fe2e8db18b7 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -199,6 +199,24 @@ def clear_(self): self.next_timestamp.fill_(0) self.last_timestamp_lasts.fill_(0) + def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: + """ + Get last labels for each hypothesis in the beam. + + Args: + pad_id: Value to use for padding (for hypotheses without labels). Defaults to -1. + + Returns: + Tensor of shape [batch_size, beam_size] with the last label for each hypothesis. + """ + # last_label already contains the last label for each beam + # Replace NON_EXISTENT_LABEL_VALUE with pad_id + return torch.where( + self.last_label != NON_EXISTENT_LABEL_VALUE, + self.last_label, + pad_id + ) + def _allocate_more(self): """ Dynamically allocates more memory for the internal buffers. @@ -275,6 +293,19 @@ def add_results_no_checks_( is_extended = next_labels >= 0 extended_with_blank = next_labels == self.blank_index extended_with_label = (is_extended) & (~extended_with_blank) + + # TODO: uncomment + # last_labels = torch.gather(self.last_label, dim=-1, index=next_indices) + # self.transcript_wb.scatter_( + # dim=-1, + # index=self.current_lengths_wb.unsqueeze(-1), + # src=torch.where(is_extended, next_labels, NON_EXISTENT_LABEL_VALUE).unsqueeze(-1) + # ) + # self.transcript_wb_prev_ptr.scatter_( + # dim=-1, + # index=self.current_lengths_wb.unsqueeze(-1), + # src=torch.where(is_extended, next_indices, INIT_POINTER_VALUE).unsqueeze(-1) + # ) if self.model_type == ASRModelTypeEnum.CTC: # for CTC last non-blank and non-repeated label extended_with_label = (extended_with_label) & (next_labels != last_labels) # non-repeated non-blank label @@ -314,6 +345,12 @@ def add_results_no_checks_( ) torch.add(self.current_lengths_wb, 1, out=self.current_lengths_wb) self.scores.copy_(next_hyps_prob) + + # TODO: uncomment + # self.current_lengths_wb.copy_( + # torch.gather(self.current_lengths_wb, dim=-1, index=next_indices) + is_extended + # ) + # self.scores.copy_(torch.where(is_extended, next_hyps_prob, torch.gather(self.scores, dim=-1, index=next_indices))) prev_transcript_hash = torch.gather(self.transcript_hash, dim=-1, index=next_indices) # update hashes and prefix hashes @@ -529,6 +566,86 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]: ] return hypotheses + # def flatten_sort_(self, score_norm: bool = True): + # """ + # Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process. + # This is a SERIALIZED version that processes each batch element sequentially to avoid + # issues with pointer chasing in the batched version. + + # Args: + # score_norm (bool, optional): If True, normalizes the scores by dividing + # them by the current lengths of the hypotheses plus one. Defaults to True. + # This method performs the following steps: + # 1. Normalizes the scores if `score_norm` is True. + # 2. Sorts the normalized scores in descending order and retrieves the corresponding indices. + # 3. Iteratively reconstructs the tokens and timestamps for each hypothesis in reverse order. + # 4. Updates the internal state of the object, including transcripts, timestamps, scores, + # lengths, labels, and other metadata, based on the sorted order. + # """ + + # # add one for consistency with non-batched decodings, that use SOS. + # normalized_scores = ( + # self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores + # ) + # normalized_scores, indices = torch.sort(normalized_scores, dim=-1, descending=True) + + # # Create temporary buffers to hold the sorted results + # new_transcript_wb = self.transcript_wb.clone() + # new_timestamps = self.timestamps.clone() if (self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT) else None + + # # Process each batch element sequentially + # for batch_idx in range(self.batch_size): + # max_idx = self.current_lengths_wb[batch_idx].max() - 1 + # if max_idx < 0: + # continue + + # batch_indices_local = indices[batch_idx] # [beam_size] + + # # For each beam in the sorted order, reconstruct the path by following pointers + # for beam_idx in range(self.beam_size): + # src_beam = batch_indices_local[beam_idx].item() + # ptr = src_beam + + # # Reconstruct the path from max_idx down to 0 + # for idx in range(max_idx, -1, -1): + # # Copy the token at this position + # new_transcript_wb[batch_idx, beam_idx, idx] = self.transcript_wb[batch_idx, ptr, idx] + + # # Copy timestamp if applicable + # if new_timestamps is not None: + # new_timestamps[batch_idx, beam_idx, idx] = self.timestamps[batch_idx, ptr, idx] + + # # Follow the pointer to the previous beam + # next_ptr = self.transcript_wb_prev_ptr[batch_idx, ptr, idx].item() + # if next_ptr != INIT_POINTER_VALUE: + # # Only update pointer if it's valid; otherwise keep current ptr + # ptr = next_ptr + + # # Copy reconstructed paths back to main buffers + # self.transcript_wb.copy_(new_transcript_wb) + # if new_timestamps is not None: + # self.timestamps.copy_(new_timestamps) + + # # Reset pointers to simple sequential structure + # max_idx = self.current_lengths_wb.max() - 1 + # if max_idx >= 0: + # self.transcript_wb_prev_ptr[..., : max_idx + 1].copy_(self.beam_indices.unsqueeze(0).unsqueeze(-1)) + + # # Sort all other state tensors according to indices + # self.scores.copy_(torch.gather(self.scores, dim=-1, index=indices)) + # self.current_lengths_nb.copy_(torch.gather(self.current_lengths_nb, dim=-1, index=indices)) + # self.current_lengths_wb.copy_(torch.gather(self.current_lengths_wb, dim=-1, index=indices)) + + # self.last_label.copy_(torch.gather(self.last_label, dim=-1, index=indices)) + + # if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT: + # self.next_timestamp.copy_(torch.gather(self.next_timestamp, dim=-1, index=indices)) + # self.last_timestamp_lasts.copy_(torch.gather(self.last_timestamp_lasts, dim=-1, index=indices)) + + # self.transcript_hash.copy_(torch.gather(self.transcript_hash, dim=-1, index=indices)) + # if self.store_prefix_hashes: + # self.transcript_prefix_hash.copy_(torch.gather(self.transcript_prefix_hash, dim=-1, index=indices)) + def flatten_sort_(self, score_norm: bool = True): """ Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process. @@ -621,3 +738,93 @@ def _create_transcripts_mask(self, transcripts: torch.Tensor): return self._create_fold_consecutive_mask(transcripts) else: return (transcripts >= 0) & (transcripts != self.blank_index) + + def merge_(self, other: "BatchedBeamHyps") -> "BatchedBeamHyps": + """ + Merge two batched beam hypotheses structures by concatenating transcripts. + Used for streaming/chunked inference where results from multiple chunks need to be combined. + + Prerequisites: + - Both self and other should have been processed with flatten_sort_() before merging, + so that each beam contains an independent flattened hypothesis. + - Beam indices should correspond across chunks (beam i in self matches beam i in other). + + Notes: + - Timestamps in 'other' should already be cumulative (adjusted for time offset). + - The transcript_hash values are copied from 'other' and won't reflect the full + merged transcript. This means recombine_hyps_() should NOT be called on merged + results without recomputing hashes. This is acceptable for output-only use. + + Args: + other: BatchedBeamHyps from the next chunk to merge + + Returns: + Self (modified in-place) + """ + max_other_len = other.current_lengths_wb.max().item() + + # Early return if other has nothing to merge + if max_other_len == 0: + return self + + # Check if we need more storage (using allocated buffer size, not current shape) + # Compute max needed length: current max + other max + max_needed = self.current_lengths_wb.max().item() + max_other_len + + # Expand storage if needed - use existing _allocate_more() method + while max_needed > self._max_length: + self._allocate_more() + + # Create a range tensor: [0, 1, 2, ..., max_other_len-1] + other_indices = torch.arange(max_other_len, device=self.device, dtype=torch.long) + + # Create shifted indices: current_lengths + [0, 1, 2, ...] + # Shape: [batch_size, beam_size, max_other_len] + shifted_indices = self.current_lengths_wb.unsqueeze(-1) + other_indices.unsqueeze(0).unsqueeze(0) + + # Scatter other's transcripts into self at shifted positions + self.transcript_wb.scatter_( + dim=-1, + index=shifted_indices, + src=other.transcript_wb[..., :max_other_len], + ) + + # Update pointers: for the merged portion, we set pointers to point to the same beam + # (since after flatten_sort_, each beam is independent) + self.transcript_wb_prev_ptr.scatter_( + dim=-1, + index=shifted_indices, + src=self.beam_indices.view(1, self.beam_size, 1).expand(self.batch_size, -1, max_other_len), + ) + + # Scatter timestamps + self.timestamps.scatter_( + dim=-1, + index=shifted_indices, + src=other.timestamps[..., :max_other_len], + ) + + # Update lengths + self.current_lengths_wb += other.current_lengths_wb + self.current_lengths_nb += other.current_lengths_nb + + # Update scores (log probabilities, so we add them) + self.scores += other.scores + + # Update transcript hash by combining hashes + # The hash of the merged transcript should account for all non-blank labels + self.transcript_hash.copy_(other.transcript_hash) + + # Update prefix hashes if used + if self.store_prefix_hashes: + self.transcript_prefix_hash.copy_(other.transcript_prefix_hash) + + # Update tracking fields from other (they reflect the end state after other chunk) + self.last_label.copy_(other.last_label) + + # Only update timestamp tracking fields for transducer models + if self.model_type != ASRModelTypeEnum.CTC: + self.next_timestamp.copy_(other.next_timestamp) + self.last_timestamp_lasts.copy_(other.last_timestamp_lasts) + + return self diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index d657c56a67b6..6795be5b6046 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -2242,10 +2242,15 @@ def subsample(self, factor: int) -> "ContextSizeBatch": factor: subsampling factor """ return ContextSizeBatch( - left=torch.div(self.left, factor, rounding_mode="floor"), - chunk=torch.div(self.chunk, factor, rounding_mode="floor"), - right=torch.div(self.right, factor, rounding_mode="floor"), + left=torch.ceil(self.left / factor).to(dtype=torch.long), + chunk=torch.ceil(self.chunk / factor).to(dtype=torch.long), + right=torch.ceil(self.right / factor).to(dtype=torch.long), ) + # return ContextSizeBatch( + # left=torch.div(self.left, factor).round().to(dtype=torch.long), + # chunk=torch.div(self.chunk, factor).round().to(dtype=torch.long), + # right=torch.div(self.right, factor).round().to(dtype=torch.long), + # ) def add_frames_( self, num_frames_batch: torch.Tensor, is_last_chunk_batch: torch.Tensor, expected_context: "ContextSize" From d4295258c93ca1f2af7cdbc1f3299da83f1325f5 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 19 Jan 2026 11:40:16 +0400 Subject: [PATCH 02/13] restore cuda graphs Signed-off-by: lilithgrigoryan --- .../asr/parts/submodules/rnnt_malsd_batched_computer.py | 2 +- .../parts/submodules/transducer_decoding/label_looping_base.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index ef9a08e1cea3..bc187628801a 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1259,7 +1259,7 @@ def __call__( out_len: torch.Tensor, prev_batched_state: Optional[BatchedLabelLoopingState] = None, ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: - self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + # self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS if self.cuda_graphs_mode is not None and x.device.type == "cuda": # CUDA graphs don't support streaming yet, fall back to torch implementation if prev_batched_state is not None: diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py index c71e21997326..b53531218698 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py @@ -291,7 +291,6 @@ def __call__( prev_batched_state: previous batched decoding state multi_biasing_ids: optional tensor [Batch] with ids of fused biasing models """ - self.cuda_graphs_mode = None if self.cuda_graphs_mode is not None and x.device.type == "cuda": # disable CUDA graphs if Mixed Precision is used due to incorrect behavior with torch.amp.autocast(device_type="cuda", enabled=False): From 0f686f72949b659f83cb2a71f1e49587b7994ecb Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 29 Jan 2026 12:52:36 +0400 Subject: [PATCH 03/13] add streaming beam search final Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 17 +- .../submodules/rnnt_malsd_batched_computer.py | 257 ++++++++++- .../asr/parts/submodules/tdt_beam_decoding.py | 11 +- .../submodules/tdt_malsd_batched_computer.py | 436 +++++++++++++++--- .../utils/batched_beam_decoding_utils.py | 98 +++- 5 files changed, 723 insertions(+), 96 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 686eb8bbea25..9c04fbac1a11 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -77,6 +77,7 @@ from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import ( GreedyBatchedLabelLoopingComputerBase, ) @@ -298,7 +299,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer elif cfg.decoding.strategy == "malsd_batch": # Beam search strategies use _decoding_computer (private attribute) - decoding_computer: ModifiedALSDBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer + is_tdt = True + print(f"is_tdt: {is_tdt}") + if is_tdt: + decoding_computer: ModifiedALSDBatchedTDTComputer = asr_model.decoding.decoding._decoding_computer + else: + decoding_computer: ModifiedALSDBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer elif cfg.decoding.strategy == "maes_batch": # MAES beam search returns BatchedBeamHyps decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer @@ -419,9 +425,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) rest_audio_lengths = audio_batch_lengths.clone() + # print(f"decoding_computer: {type(decoding_computer)}") # For MALSD: batched_hyps is stored in state and reused (no merge needed) # For greedy: fresh BatchedHyps created each chunk, needs merging - is_beam_search = isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer) or isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) + is_beam_search = ( + isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer) or \ + isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) or \ + isinstance(decoding_computer, ModifiedALSDBatchedTDTComputer) + ) # ============================================================================ # ENCODER PROCESSING MODES: @@ -444,7 +455,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # MODE 1: FULL ENCODER PASS (for testing consistency with non-streaming) # TO ENABLE: Uncomment this section and comment out MODE 2 sections below # ============================================================================ - # # TESTING: Run encoder once on full audio (not chunked) + # TESTING: Run encoder once on full audio (not chunked) # full_encoder_output, full_encoder_output_len = asr_model( # input_signal=audio_batch, # input_signal_length=audio_batch_lengths, diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index bc187628801a..985b88041fc6 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -103,6 +103,10 @@ class MALSDState: init_fusion_states_candidates_list: Optional[List[torch.Tensor]] = None # list of initial fusion states candidates init_fusion_scores_list: Optional[List[torch.Tensor]] = None # list of initial fusion scores + # Streaming state fields + is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk + decoded_lengths: torch.Tensor # accumulated decoded lengths across chunks + def __init__( self, batch_size: int, @@ -126,11 +130,12 @@ def __init__( blank_index: index of the blank symbol """ + max_time = 375 self.device = device self.float_dtype = float_dtype self.batch_size = batch_size self.beam_size = beam_size - self.max_time = max_time + self.max_time = max_time self.blank_index = blank_index self.NON_EXISTENT_LABEL = torch.tensor(NON_EXISTENT_LABEL_VALUE, device=self.device, dtype=torch.long) @@ -186,6 +191,10 @@ def __init__( float_dtype=float_dtype, ) + # Streaming state fields + self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) + self.decoded_lengths = torch.zeros([self.batch_size], dtype=torch.long, device=self.device) + def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: """Check if need to reinit state: larger batch_size/max_time, or new device""" return ( @@ -200,6 +209,7 @@ class SeparateGraphsMALSD: """Class to store Cuda graphs for decoding when separate graphs are used""" before_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_loop_continuation: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_body: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_update_decoder: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) @@ -222,6 +232,7 @@ class CudaGraphsMode(PrettyStrEnum): separate_graphs: Optional[SeparateGraphsMALSD] full_graph: Optional[torch.cuda.CUDAGraph] + full_graph_continuation: Optional[torch.cuda.CUDAGraph] cuda_graphs_mode: Optional[CudaGraphsMode] state: Optional[MALSDState] fusion_models: Optional[List[NGramGPULanguageModel]] @@ -272,6 +283,7 @@ def __init__( self.state = None self.full_graph = None + self.full_graph_continuation = None self.separate_graphs = None self.cuda_graphs_mode = None @@ -346,6 +358,7 @@ def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" self.state = None self.full_graph = None + self.full_graph_continuation = None self.separate_graphs = None def modified_alsd_torch( @@ -749,7 +762,8 @@ def modified_alsd_cuda_graphs( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Cuda-Graphs implementation of the batched ALSD algorithm. Args: @@ -757,8 +771,9 @@ def modified_alsd_cuda_graphs( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state. Returns: - BathcedBeamHyps: Batched beam hypotheses. + tuple: (BatchedBeamHyps, None, BatchedLabelLoopingState) """ assert self.cuda_graphs_mode is not None @@ -775,29 +790,54 @@ def modified_alsd_cuda_graphs( if self.state is None or self.state.need_reinit(encoder_output): self._graph_reinitialize(encoder_output, encoder_output_length) + # Set continuation flag and restore state from previous chunk if provided + is_continuation = prev_batched_state is not None + self.state.is_continuation.fill_(is_continuation) + + if is_continuation: + # Restore state from previous chunk + self._restore_state_from_prev(prev_batched_state, current_batch_size) + # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) - # copy (projected) encoder output and lenghts - self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output) + # copy (projected) encoder output and lengths + # print("encoder_output.shape: ", encoder_output.shape) + # print("current_batch_size: ", current_batch_size) + # print("current_max_time: ", current_max_time) + self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output[:current_batch_size, :current_max_time, ...]) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: - self.full_graph.replay() + # Use continuation graph if continuing from previous chunk, otherwise first chunk graph + if is_continuation: + self.full_graph_continuation.replay() + else: + self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: - self.separate_graphs.before_loop.replay() + # Use continuation before_loop graph if continuing from previous chunk + if is_continuation: + self.separate_graphs.before_loop_continuation.replay() + else: + self.separate_graphs.before_loop.replay() while self.state.active_mask_any.item(): self.separate_graphs.loop_body.replay() self.separate_graphs.loop_update_decoder.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: - # this mode is only for testing purposes # manual loop instead of using graphs - self._before_loop() + if is_continuation: + self._before_loop_continuation() + else: + self._before_loop() while self.state.active_mask_any.item(): self._loop_body() self._loop_update_decoder() else: raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") - return self.state.batched_hyps + # Create and return decoding state for next chunk + decoding_state = self._create_decoding_state(encoder_output_length, prev_batched_state) + + return self.state.batched_hyps, None, decoding_state @classmethod def _create_loop_body_kernel(cls): @@ -945,6 +985,8 @@ def _partial_graphs_compile(self): stream_for_graph = torch.cuda.Stream(self.state.device) stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.separate_graphs = SeparateGraphsMALSD() + + # Compile before_loop graph for first chunk with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -954,6 +996,16 @@ def _partial_graphs_compile(self): ): self._before_loop() + # Compile before_loop_continuation graph for streaming + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs.before_loop_continuation, stream=stream_for_graph, capture_error_mode="thread_local" + ), + ): + self._before_loop_continuation() + with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -1005,6 +1057,34 @@ def _full_graph_compile(self): self._loop_body() self._loop_update_decoder() + # Compile continuation graph for streaming + self.full_graph_continuation = torch.cuda.CUDAGraph() + + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.full_graph_continuation, stream=stream_for_graph, capture_error_mode="thread_local"), + ): + self._before_loop_continuation() + capture_status, _, graph, _, _, _ = cu_call( + cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) + ) + + assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive + + # capture: while self.active_mask_any: + (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + loop_kernel = self._create_loop_body_kernel() + active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) + loop_args = np.array( + [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + dtype=np.uint64, + ) + # loop while there are active utterances + with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + self._loop_body() + self._loop_update_decoder() + def _before_loop(self): """ Clears state and compute initial active mask @@ -1022,13 +1102,38 @@ def _before_loop(self): self.state.fusion_scores_list[fusion_idx].copy_(self.state.init_fusion_scores_list[fusion_idx]) self.state.fusion_states_prev_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + # set decoder state and output to initial values + self.state.decoder_output.copy_(self.state.init_decoder_output) + self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) + self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) + # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) + + self._before_loop_common() + + def _before_loop_continuation(self): + """ + Prepares state for continuation chunk without clearing batched_hyps. + Decoder state and fusion states are already restored from previous chunk + via _restore_state_from_prev before this method is called. + """ + # Don't clear batched_hyps - it's already restored from previous state + # Don't reset decoder state - it's already restored from previous state + # Don't reset fusion states - they're already restored from previous state + + self._before_loop_common() + + def _before_loop_common(self): + """ + Common initialization for both first chunk and continuation. + Resets temporary variables and computes active mask. + """ self.state.next_scores.fill_(0.0) self.state.next_labels.fill_(0.0) self.state.next_idx.fill_(0.0) - # time indices + # time indices - reset for current chunk self.state.time_indices.fill_(0) self.state.safe_time_indices.fill_(0) # safe time indices: guaranteed to be < encoder_output_length @@ -1041,11 +1146,6 @@ def _before_loop(self): # same as: self.active_mask_any = active_mask.any() torch.any(self.state.active_mask, out=self.state.active_mask_any) - # set decoder state and output to initial values - self.state.decoder_output.copy_(self.state.init_decoder_output) - self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) - self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) - # set previous decoder state and output to initial values self.state.prev_decoder_output.fill_(0) self.state.prev_decoder_state[0].fill_(0) @@ -1253,6 +1353,117 @@ def _loop_update_decoder(self): torch.less_equal(self.state.time_indices, self.state.last_timesteps, out=self.state.active_mask) torch.any(self.state.active_mask, out=self.state.active_mask_any) + def _restore_state_from_prev( + self, prev_batched_state: BatchedLabelLoopingState, current_batch_size: int + ): + """ + Restore decoder state, fusion states, and batched_hyps from previous chunk's state. + Used for streaming/chunked decoding. + + Args: + prev_batched_state: State from previous chunk + current_batch_size: Current batch size + """ + # Restore decoder output and state + if prev_batched_state.predictor_outputs is not None: + self.state.decoder_output[:current_batch_size * self.beam_size].copy_( + prev_batched_state.predictor_outputs.view(-1, 1, prev_batched_state.predictor_outputs.shape[-1]) + ) + + if prev_batched_state.predictor_states is not None: + # Copy decoder states (assuming tuple of tensors) + for i, state_tensor in enumerate(prev_batched_state.predictor_states): + if state_tensor is not None: + self.state.decoder_state[i][:, :current_batch_size * self.beam_size].copy_( + state_tensor[:, :current_batch_size * self.beam_size] + ) + + # Restore fusion states if present + if prev_batched_state.fusion_states_list is not None and self.fusion_models is not None: + for fusion_idx, fusion_state in enumerate(prev_batched_state.fusion_states_list): + if fusion_state is not None: + self.state.fusion_states_list[fusion_idx][:current_batch_size].copy_( + fusion_state[:current_batch_size] + ) + # Recompute fusion scores and candidates from restored states + fusion_scores, fusion_states_candidates = self.fusion_models[fusion_idx].advance( + states=self.state.fusion_states_list[fusion_idx][:current_batch_size].reshape(-1) + ) + self.state.fusion_scores_list[fusion_idx][:current_batch_size].copy_( + fusion_scores.to(dtype=self.state.float_dtype).view(current_batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx][:current_batch_size].copy_( + fusion_states_candidates.view(current_batch_size, self.beam_size, -1) + ) + + # Restore batched_hyps from previous state + if prev_batched_state.batched_hyps is not None: + self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) + + # Restore decoded_lengths + if prev_batched_state.decoded_lengths is not None: + self.state.decoded_lengths[:current_batch_size].copy_( + prev_batched_state.decoded_lengths[:current_batch_size] + ) + + def _create_decoding_state( + self, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Create BatchedLabelLoopingState for the next chunk. + + Args: + encoder_output_length: Length of current encoder output + prev_batched_state: State from previous chunk (if any) + + Returns: + BatchedLabelLoopingState containing current decoding state + """ + current_batch_size = encoder_output_length.shape[0] + + # Get last labels from batched_hyps + last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS) + + # Reset next_timestamp for next chunk + self.state.batched_hyps.next_timestamp.fill_(0) + + # Calculate accumulated decoded lengths + if prev_batched_state is None: + decoded_lengths = encoder_output_length.clone() + else: + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths[:current_batch_size] + + # Handle labels - if nothing decoded this chunk, use previous labels + if prev_batched_state is not None: + last_labels = torch.where( + last_labels == self._SOS, + prev_batched_state.labels, + last_labels + ) + + # Get fusion states if present + fusion_states_list = None + if self.fusion_models is not None and self.state.fusion_states_list is not None: + fusion_states_list = [ + state[:current_batch_size].clone() for state in self.state.fusion_states_list + ] + + return BatchedLabelLoopingState( + predictor_states=( + self.state.decoder_state[0][:, :current_batch_size * self.beam_size].clone(), + self.state.decoder_state[1][:, :current_batch_size * self.beam_size].clone(), + ), + predictor_outputs=self.state.decoder_output[:current_batch_size * self.beam_size].clone(), + labels=last_labels, + decoded_lengths=decoded_lengths, + fusion_states_list=fusion_states_list, + time_jumps=None, + batched_hyps=self.state.batched_hyps.clone(), + ) + def __call__( self, x: torch.Tensor, @@ -1261,15 +1472,13 @@ def __call__( ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: # self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS if self.cuda_graphs_mode is not None and x.device.type == "cuda": - # CUDA graphs don't support streaming yet, fall back to torch implementation - if prev_batched_state is not None: - return self.modified_alsd_torch( - encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state - ) with torch.amp.autocast(device_type="cuda", enabled=False): - batched_hyps = self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) - # Return empty state for non-streaming case - return batched_hyps, None, None + # print("Using CUDA graphs mode: NO_WHILE_LOOPS") + return self.modified_alsd_cuda_graphs( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state, + ) return self.modified_alsd_torch( encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index 65255b85849d..e7b4edfb63ff 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -961,14 +961,17 @@ def forward( self.joint.eval() inseq = encoder_output # [B, T, D] - batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen) + + # import pdb; pdb.set_trace() + encoder_output_projected = self.joint.project_encoder(encoder_output) + encoder_output_projected_len = encoded_lengths + batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=encoder_output_projected, out_len=encoder_output_projected_len) - # Ensures the correct number of hypotheses (batch_size) for CUDA Graphs compatibility batch_size = encoder_output.shape[0] if self.return_best_hypothesis: - hyps = batched_beam_hyps.to_hyps_list(score_norm=self.score_norm)[:batch_size] + hyps = batched_beam_hyps.to_hyps_list(score_norm=self.score_norm)[:batch_size] # type: ignore else: - hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=self.score_norm)[:batch_size] + hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=self.score_norm)[:batch_size] # type: ignore self.decoder.train(decoder_training_state) self.joint.train(joint_training_state) diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 45ecfed299a5..dfd6c07305fe 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -19,6 +19,8 @@ import torch.nn.functional as F from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -92,6 +94,10 @@ class MALSDState: batched_hyps: BatchedBeamHyps # batched hypotheses - decoding result + # Streaming state fields + is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk + decoded_lengths: torch.Tensor # accumulated decoded lengths across chunks + # fusion models related fields fusion_models: Optional[List[NGramGPULanguageModel]] = None fusion_models_alpha: Optional[List[float]] = None @@ -184,6 +190,10 @@ def __init__( self.blank_mask = torch.zeros_like(self.active_mask, dtype=torch.bool) self.active_mask_any = torch.tensor(True, device=self.device, dtype=torch.bool) + # Streaming state fields + self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) + self.decoded_lengths = torch.zeros([self.batch_size], dtype=torch.long, device=self.device) + self.batched_hyps = BatchedBeamHyps( batch_size=batch_size, beam_size=self.beam_size, @@ -208,6 +218,7 @@ class SeparateGraphsMALSD: """Class to store Cuda graphs for decoding when separate graphs are used""" before_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_loop_continuation: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_body: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_update_decoder: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) @@ -230,6 +241,7 @@ class CudaGraphsMode(PrettyStrEnum): separate_graphs: Optional[SeparateGraphsMALSD] full_graph: Optional[torch.cuda.CUDAGraph] + full_graph_continuation: Optional[torch.cuda.CUDAGraph] cuda_graphs_mode: Optional[CudaGraphsMode] state: Optional[MALSDState] fusion_models: Optional[List[NGramGPULanguageModel]] @@ -357,13 +369,15 @@ def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" self.state = None self.full_graph = None + self.full_graph_continuation = None self.separate_graphs = None def modified_alsd_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Pytorch implementation of the batched ALSD algorithm for TDT models. Args: @@ -371,8 +385,10 @@ def modified_alsd_torch( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state for streaming. Returns: - BatchedBeamHyps: Batched beam hypotheses. + tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + Batched beam hypotheses, alignments (None), and decoding state. """ batch_size, max_time, _ = encoder_output.shape @@ -384,21 +400,9 @@ def modified_alsd_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # init empty batched beam hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, - device=device, - float_dtype=float_dtype, - model_type='tdt', - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) + + # encoder_output_projected = encoder_output + # float_dtype = encoder_output.dtype batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] @@ -411,47 +415,111 @@ def modified_alsd_torch( .clone() ) # size: batch_size x beam_size x beam_size - time_indices = torch.zeros_like(batch_beam_indices) + # Reuse batched beam hypotheses from state if continuing, otherwise create new + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps + # For continuation, initialize time_indices from batched_hyps.next_timestamp + # This represents where we left off in the previous chunk + # Subtract decoded_lengths to convert from absolute to relative time + time_indices = batched_hyps.next_timestamp.clone() - prev_batched_state.decoded_lengths[:, None].expand_as(batched_hyps.next_timestamp) + assert time_indices.min() >= 0, "Time indices should be non-negative" + else: + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, + device=device, + float_dtype=float_dtype, + model_type='tdt', + ) + time_indices = torch.zeros_like(batch_beam_indices) + safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) + # Clamp time_indices to valid range for continuation + torch.minimum(time_indices, last_timesteps, out=safe_time_indices) active_mask = time_indices <= last_timesteps # setup fusion models if available if self.fusion_models is not None: - fusion_states_list = [] - fusion_states_candidates_list = [] - fusion_scores_list = [] - for fusion_model_idx, fusion_model in enumerate(self.fusion_models): - fusion_model.to(device) - fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - fusion_scores, fusion_states_candidates = fusion_model.advance(states=fusion_states) + if prev_batched_state is None or prev_batched_state.fusion_states_list is None: + fusion_states_list = [] + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) # vocab_size_no_blank - fusion_scores = ( - fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) - * self.fusion_models_alpha[fusion_model_idx] + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = prev_batched_state.fusion_states_list + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states_list[fusion_model_idx] + ) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = None + + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty( + [ + batch_size * self.beam_size, + ], + dtype=float_dtype, + device=device, ) - fusion_states_list.append(fusion_states) - fusion_states_candidates_list.append(fusion_states_candidates) - fusion_scores_list.append(fusion_scores) + ) - decoder_state = self.decoder.initialize_state( - torch.empty( - [ - batch_size * self.beam_size, - ], - dtype=float_dtype, - device=device, + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size ) - ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states - decoder_output, state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] - self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + # import pdb; pdb.set_trace() + step1 = 0 while active_mask.any(): + # import pdb; pdb.set_trace() + # print(f"Step {step1}") + # print(f"Time indices: {safe_time_indices}") + # print(f"Active mask: {active_mask}") + # print(f"Encoder output length: {encoder_output_length}") + # print(f"Decoder output: {decoder_output.shape}") + # print(f"Safe time indices: {safe_time_indices}") + # print(f"encoder_output_projected: {encoder_output_projected.shape}") + # print(f"Decoder state: {decoder_state.shape}") + + + # step 1: get joint output + fuse with fusion models (if present) logits = ( self.joint.joint_after_projection( @@ -549,11 +617,14 @@ def modified_alsd_torch( durations_top_k.reshape(batch_size, -1), dim=-1, index=hyps_candidates_indices ) # durations for extended hypotheses + # import pdb; pdb.set_trace() # step 3: store results - if self.max_symbols is None: - batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) - else: - batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) + # if self.max_symbols is None: + # batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) + # else: + # batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) + # print(f"DEBUG: Adding results to batched_hyps") + batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. batched_hyps.recombine_hyps_() @@ -624,11 +695,43 @@ def modified_alsd_torch( fusion_scores_list[fusion_model_idx] = fusion_scores # step 6: update time indices + active mask - time_indices.copy_(batched_hyps.next_timestamp) + time_indices = torch.gather(time_indices, dim=-1, index=hyps_indices) + next_label_durations torch.minimum(time_indices, last_timesteps, out=safe_time_indices) torch.less_equal(time_indices, last_timesteps, out=active_mask) + + step1 += 1 + + # fix timestamps for iterative decoding + # Add offset to timestamps so they are cumulative across chunks + if prev_batched_state is not None: + batched_hyps.timestamps += prev_batched_state.decoded_lengths[:, None, None].expand_as(batched_hyps.timestamps) + # Also update next_timestamp for proper continuation + # batched_hyps.next_timestamp += prev_batched_state.decoded_lengths[:, None].expand_as(batched_hyps.next_timestamp) + + # NB: last labels can not exist (nothing decoded on this step). + # return the last labels from the previous state in this case + # import pdb; pdb.set_trace() + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + # batched_hyps.next_timestamp.copy_(batched_hyps.next_timestamp - encoder_output_length.unsqueeze(-1)) + decoding_state = BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=( + torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + if prev_batched_state is not None + else last_labels + ), + decoded_lengths=( + encoder_output_length.clone() + if prev_batched_state is None + else encoder_output_length + prev_batched_state.decoded_lengths + ), + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + time_jumps=None, # Not needed for beam search since we use batched_hyps.next_timestamp + batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + ) - return batched_hyps + return batched_hyps, None, decoding_state def topk_fusion_model(self, fusion_scores_list, log_probs, duration_log_probs, eps=1e-2): """ @@ -747,7 +850,8 @@ def modified_alsd_cuda_graphs( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Cuda-Graphs implementation of the batched ALSD algorithm. Args: @@ -755,8 +859,9 @@ def modified_alsd_cuda_graphs( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state. Returns: - BathedBeamHyps: Batched beam hypotheses. + tuple: (BatchedBeamHyps, None, BatchedLabelLoopingState) """ assert self.cuda_graphs_mode is not None @@ -773,29 +878,52 @@ def modified_alsd_cuda_graphs( if self.state is None or self.state.need_reinit(encoder_output): self._graph_reinitialize(encoder_output, encoder_output_length) + # Set continuation flag and restore state from previous chunk if provided + is_continuation = prev_batched_state is not None + self.state.is_continuation.fill_(is_continuation) + + if is_continuation: + # Restore state from previous chunk + self._restore_state_from_prev(prev_batched_state, current_batch_size) + # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) # copy (projected) encoder output and lenghts self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: - self.full_graph.replay() + # Use continuation graph if continuing from previous chunk, otherwise first chunk graph + if is_continuation: + self.full_graph_continuation.replay() + else: + self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: - self.separate_graphs.before_loop.replay() + # Use continuation before_loop graph if continuing from previous chunk + if is_continuation: + self.separate_graphs.before_loop_continuation.replay() + else: + self.separate_graphs.before_loop.replay() while self.state.active_mask_any.item(): self.separate_graphs.loop_body.replay() self.separate_graphs.loop_update_decoder.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: # this mode is only for testing purposes # manual loop instead of using graphs - self._before_loop() + if is_continuation: + self._before_loop_continuation() + else: + self._before_loop() while self.state.active_mask_any.item(): self._loop_body() self._loop_update_decoder() else: raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") - return self.state.batched_hyps + # Create and return decoding state for next chunk + decoding_state = self._create_decoding_state(encoder_output_length, prev_batched_state) + + return self.state.batched_hyps, None, decoding_state @classmethod def _create_loop_body_kernel(cls): @@ -954,6 +1082,17 @@ def _partial_graphs_compile(self): ): self._before_loop() + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs.before_loop_continuation, + stream=stream_for_graph, + capture_error_mode="thread_local", + ), + ): + self._before_loop_continuation() + with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -1005,6 +1144,34 @@ def _full_graph_compile(self): self._loop_body() self._loop_update_decoder() + # Compile continuation graph for streaming + self.full_graph_continuation = torch.cuda.CUDAGraph() + + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph(self.full_graph_continuation, stream=stream_for_graph, capture_error_mode="thread_local"), + ): + self._before_loop_continuation() + capture_status, _, graph, _, _, _ = cu_call( + cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) + ) + + assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive + + # capture: while self.active_mask_any: + (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + loop_kernel = self._create_loop_body_kernel() + active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) + loop_args = np.array( + [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + dtype=np.uint64, + ) + # loop while there are active utterances + with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + self._loop_body() + self._loop_update_decoder() + def _before_loop(self): """ Clears state and compute initial active mask @@ -1024,6 +1191,31 @@ def _before_loop(self): # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) + + # set decoder state and output to initial values + self.state.decoder_output.copy_(self.state.init_decoder_output) + self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) + self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) + + self._before_loop_common() + + def _before_loop_continuation(self): + """ + Prepares state for continuation chunk without clearing batched_hyps. + Decoder state and fusion states are already restored from previous chunk + via _restore_state_from_prev before this method is called. + """ + # Don't clear batched_hyps - it's already restored from previous state + # Don't reset decoder state - it's already restored from previous state + # Don't reset fusion states - they're already restored from previous state + + self._before_loop_common() + + def _before_loop_common(self): + """ + Common initialization for both first chunk and continuation. + Resets temporary variables and computes active mask. + """ self.state.next_scores.fill_(0.0) self.state.next_labels.fill_(0.0) self.state.next_idx.fill_(0.0) @@ -1042,10 +1234,6 @@ def _before_loop(self): # same as: self.active_mask_any = active_mask.any() torch.any(self.state.active_mask, out=self.state.active_mask_any) - self.state.decoder_output.copy_(self.state.init_decoder_output) - self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) - self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) - self.state.prev_decoder_output.fill_(0) self.state.prev_decoder_state[0].fill_(0) self.state.prev_decoder_state[1].fill_(0) @@ -1295,13 +1483,133 @@ def _loop_update_decoder(self): torch.less_equal(self.state.time_indices, self.state.last_timestamps, out=self.state.active_mask) torch.any(self.state.active_mask, out=self.state.active_mask_any) + def _restore_state_from_prev( + self, prev_batched_state: BatchedLabelLoopingState, current_batch_size: int + ): + """ + Restore decoder state, fusion states, and batched_hyps from previous chunk's state. + Used for streaming/chunked decoding. + + Args: + prev_batched_state: State from previous chunk + current_batch_size: Current batch size + """ + # Restore decoder output and state + if prev_batched_state.predictor_outputs is not None: + self.state.decoder_output[:current_batch_size * self.beam_size].copy_( + prev_batched_state.predictor_outputs.view(-1, 1, prev_batched_state.predictor_outputs.shape[-1]) + ) + + if prev_batched_state.predictor_states is not None: + # Copy decoder states (assuming tuple of tensors) + for i, state_tensor in enumerate(prev_batched_state.predictor_states): + if state_tensor is not None: + self.state.decoder_state[i][:, :current_batch_size * self.beam_size].copy_( + state_tensor[:, :current_batch_size * self.beam_size] + ) + + # Restore fusion states if present + if prev_batched_state.fusion_states_list is not None and self.fusion_models is not None: + for fusion_idx, fusion_state in enumerate(prev_batched_state.fusion_states_list): + if fusion_state is not None: + self.state.fusion_states_list[fusion_idx][:current_batch_size].copy_( + fusion_state[:current_batch_size] + ) + # Recompute fusion scores and candidates from restored states + fusion_scores, fusion_states_candidates = self.fusion_models[fusion_idx].advance( + states=self.state.fusion_states_list[fusion_idx][:current_batch_size].reshape(-1) + ) + self.state.fusion_scores_list[fusion_idx][:current_batch_size].copy_( + fusion_scores.to(dtype=self.state.float_dtype).view(current_batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx][:current_batch_size].copy_( + fusion_states_candidates.view(current_batch_size, self.beam_size, -1) + ) + + # Restore batched_hyps from previous state + if prev_batched_state.batched_hyps is not None: + self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) + + # Restore decoded_lengths + if prev_batched_state.decoded_lengths is not None: + self.state.decoded_lengths[:current_batch_size].copy_( + prev_batched_state.decoded_lengths[:current_batch_size] + ) + + def _create_decoding_state( + self, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Create BatchedLabelLoopingState for the next chunk. + + Args: + encoder_output_length: Length of current encoder output + prev_batched_state: State from previous chunk (if any) + + Returns: + BatchedLabelLoopingState containing current decoding state + """ + current_batch_size = encoder_output_length.shape[0] + + # Get last labels from batched_hyps + last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS) + + # Reset next_timestamp for next chunk + self.state.batched_hyps.next_timestamp.fill_(0) + + # Calculate accumulated decoded lengths + if prev_batched_state is None: + decoded_lengths = encoder_output_length.clone() + else: + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths[:current_batch_size] + + # Handle labels - if nothing decoded this chunk, use previous labels + if prev_batched_state is not None: + last_labels = torch.where( + last_labels == self._SOS, + prev_batched_state.labels, + last_labels + ) + + # Get fusion states if present + fusion_states_list = None + if self.fusion_models is not None and self.state.fusion_states_list is not None: + fusion_states_list = [ + state[:current_batch_size].clone() for state in self.state.fusion_states_list + ] + + return BatchedLabelLoopingState( + predictor_states=( + self.state.decoder_state[0][:, :current_batch_size * self.beam_size].clone(), + self.state.decoder_state[1][:, :current_batch_size * self.beam_size].clone(), + ), + predictor_outputs=self.state.decoder_output[:current_batch_size * self.beam_size].clone(), + labels=last_labels, + decoded_lengths=decoded_lengths, + fusion_states_list=fusion_states_list, + time_jumps=None, + batched_hyps=self.state.batched_hyps.clone(), + ) + def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + # self.cuda_graphs_mode = None if self.cuda_graphs_mode is not None and x.device.type == "cuda": with torch.amp.autocast(device_type="cuda", enabled=False): - return self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) + batched_hyps, alignments, decoding_state = self.modified_alsd_cuda_graphs( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state + ) + return batched_hyps, alignments, decoding_state - return self.modified_alsd_torch(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_torch( + encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state + ) diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 3fe2e8db18b7..ccfe24adfa27 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -199,6 +199,78 @@ def clear_(self): self.next_timestamp.fill_(0) self.last_timestamp_lasts.fill_(0) + def copy_from_(self, other: "BatchedBeamHyps"): + """ + Copy state from another BatchedBeamHyps object (in-place). + Used for streaming/chunked decoding to restore state from previous chunk. + + Args: + other: Source BatchedBeamHyps to copy from + """ + batch_size = min(self.batch_size, other.batch_size) + beam_size = min(self.beam_size, other.beam_size) + max_length = min(self._max_length, other._max_length) + + self.current_lengths_nb[:batch_size, :beam_size].copy_( + other.current_lengths_nb[:batch_size, :beam_size] + ) + self.current_lengths_wb[:batch_size, :beam_size].copy_( + other.current_lengths_wb[:batch_size, :beam_size] + ) + self.transcript_wb[:batch_size, :beam_size, :max_length].copy_( + other.transcript_wb[:batch_size, :beam_size, :max_length] + ) + self.transcript_wb_prev_ptr[:batch_size, :beam_size, :max_length].copy_( + other.transcript_wb_prev_ptr[:batch_size, :beam_size, :max_length] + ) + self.scores[:batch_size, :beam_size].copy_( + other.scores[:batch_size, :beam_size] + ) + self.last_label[:batch_size, :beam_size].copy_( + other.last_label[:batch_size, :beam_size] + ) + self.transcript_hash[:batch_size, :beam_size].copy_( + other.transcript_hash[:batch_size, :beam_size] + ) + + if self.store_prefix_hashes and other.store_prefix_hashes: + self.transcript_prefix_hash[:batch_size, :beam_size].copy_( + other.transcript_prefix_hash[:batch_size, :beam_size] + ) + + self.timestamps[:batch_size, :beam_size, :max_length].copy_( + other.timestamps[:batch_size, :beam_size, :max_length] + ) + + if self.model_type != ASRModelTypeEnum.CTC: + self.next_timestamp[:batch_size, :beam_size].copy_( + other.next_timestamp[:batch_size, :beam_size] + ) + self.last_timestamp_lasts[:batch_size, :beam_size].copy_( + other.last_timestamp_lasts[:batch_size, :beam_size] + ) + + def clone(self) -> "BatchedBeamHyps": + """ + Create a deep copy of this BatchedBeamHyps object. + Used for streaming/chunked decoding to save state for next chunk. + + Returns: + New BatchedBeamHyps with copied state + """ + new_hyps = BatchedBeamHyps( + batch_size=self.batch_size, + beam_size=self.beam_size, + init_length=self._max_length, + blank_index=self.blank_index, + device=self.device, + float_dtype=self.scores.dtype, + store_prefix_hashes=self.store_prefix_hashes, + model_type=self.model_type, + ) + new_hyps.copy_from_(self) + return new_hyps + def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: """ Get last labels for each hypothesis in the beam. @@ -385,9 +457,15 @@ def recombine_hyps_(self): Note: The method modifies the `self.scores` tensor in place to reflect the recombined hypotheses. """ + # print(f"DEBUG: Entering recombine_hyps_. batch_size={self.batch_size}, beam_size={self.beam_size}") + if self.beam_size <= 1: return + # print(f"DEBUG: transcript_hash shape: {self.transcript_hash.shape}") + # print(f"DEBUG: last_label shape: {self.last_label.shape}") + # print(f"DEBUG: current_lengths_nb shape: {self.current_lengths_nb.shape}") + hyps_equal = ( (self.transcript_hash[:, :, None] == self.transcript_hash[:, None, :]) & (self.last_label[:, :, None] == self.last_label[:, None, :]) @@ -395,14 +473,24 @@ def recombine_hyps_(self): ) if self.model_type == ASRModelTypeEnum.TDT: + # print(f"DEBUG: TDT model type. next_timestamp shape: {self.next_timestamp.shape}") hyps_equal &= self.next_timestamp[:, :, None] == self.next_timestamp[:, None, :] + # print(f"DEBUG: hyps_equal shape: {hyps_equal.shape}") + + # print(f"DEBUG: self.scores : {self.scores}") + scores_matrix = torch.where( hyps_equal, self.scores[:, None, :].expand(self.batch_size, self.beam_size, self.beam_size), self.INACTIVE_SCORE_TENSOR, ) + # print(f"DEBUG: scores_matrix : {scores_matrix}") + # print(f"DEBUG: scores_matrix shape: {scores_matrix.shape}") + # print(f"DEBUG: scores_matrix : {scores_matrix}") + scores_argmax = scores_matrix.argmax(-1, keepdim=False) + # print(f"DEBUG: scores_argmax shape: {scores_argmax.shape}, min: {scores_argmax.min()}, max: {scores_argmax.max()}") scores_to_keep = ( torch.arange(self.beam_size, device=scores_argmax.device, dtype=torch.long)[None, :] == scores_argmax ) @@ -410,8 +498,15 @@ def recombine_hyps_(self): new_scores = torch.max(scores_matrix, dim=-1, keepdim=False).values else: new_scores = torch.logsumexp(scores_matrix, dim=-1, keepdim=False) + + # print(f"DEBUG: new_scores shape: {new_scores.shape}") + # print(f"DEBUG: scores_to_keep shape: {scores_to_keep.shape}") + # print(f"DEBUG: self.scores shape: {self.scores.shape}") + torch.where(scores_to_keep, new_scores.to(self.scores.dtype), self.INACTIVE_SCORE_TENSOR, out=self.scores) + # print("DEBUG: Exiting recombine_hyps_") + def remove_duplicates(self, labels: torch.Tensor, total_logps: torch.Tensor): """ Removes duplicate hypotheses that may arise after updating beam hypotheses with labels during the beam search process. @@ -659,7 +754,8 @@ def flatten_sort_(self, score_norm: bool = True): 4. Updates the internal state of the object, including transcripts, timestamps, scores, lengths, labels, and other metadata, based on the sorted order. """ - + + # import pdb; pdb.set_trace() # add one for consistency with non-batched decodings, that use SOS. normalized_scores = ( self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores From 06273e38c87640f006ac50515e1098795231cc28 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 14 May 2026 13:15:42 +0400 Subject: [PATCH 04/13] save before refactor Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 9c04fbac1a11..95ae0f9dfa2f 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -547,16 +547,27 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: encoder_output = encoder_output[:, encoder_context.left :] # decode only chunk frames - chunk_batched_hyps, _, state = decoding_computer( - x=encoder_output, - out_len=torch.where( - is_last_chunk_batch, - encoder_output_len - encoder_context_batch.left, - encoder_context_batch.chunk, - ), - prev_batched_state=state, - multi_biasing_ids=multi_biasing_ids, - ) + if isinstance(decoding_computer, ModifiedALSDBatchedTDTComputer) or isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) or isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer): + chunk_batched_hyps, _, state = decoding_computer( + x=encoder_output, + out_len=torch.where( + is_last_chunk_batch, + encoder_output_len - encoder_context_batch.left, + encoder_context_batch.chunk, + ), + prev_batched_state=state, + ) + else: + chunk_batched_hyps, _, state = decoding_computer( + x=encoder_output, + out_len=torch.where( + is_last_chunk_batch, + encoder_output_len - encoder_context_batch.left, + encoder_context_batch.chunk, + ), + prev_batched_state=state, + multi_biasing_ids=multi_biasing_ids, + ) # ======================================================================== # Handle hypothesis accumulation differently for beam search vs greedy From 0e2092a3e0dc399da98790c7e3cb584e351e5053 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 14 May 2026 23:10:49 +0400 Subject: [PATCH 05/13] working without full graphs Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 40 ++++- .../submodules/rnnt_malsd_batched_computer.py | 37 +++-- .../utils/batched_beam_decoding_utils.py | 148 +++++++++++++++--- 3 files changed, 193 insertions(+), 32 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 95ae0f9dfa2f..acfb9b778452 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -299,7 +299,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer elif cfg.decoding.strategy == "malsd_batch": # Beam search strategies use _decoding_computer (private attribute) - is_tdt = True + is_tdt = False print(f"is_tdt: {is_tdt}") if is_tdt: decoding_computer: ModifiedALSDBatchedTDTComputer = asr_model.decoding.decoding._decoding_computer @@ -570,10 +570,42 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) # ======================================================================== - # Handle hypothesis accumulation differently for beam search vs greedy + # Handle hypothesis accumulation differently for beam search vs greedy. if is_beam_search: - # For beam search: same object reused across chunks (stored in state) - current_batched_hyps = chunk_batched_hyps + # For beam search the chunk-local transcript buffers inside + # ``decoding_computer.state.batched_hyps`` are reused across chunks + # (and reset at the start of every continuation chunk), so we need to + # snapshot the per-chunk transcripts before the next call overwrites + # them and merge them into an external accumulator. + # + # ``flatten_`` resolves the chunk-local prefix tree without sorting + # (preserving the chunk's beam ordering) and returns ``root_ptrs``: for + # each chunk-end beam ``i``, the beam index at the chunk's start that + # this hypothesis ultimately descends from. Because the chunk's loop + # body can permute beams at any step via the top-K gather, beam ``i`` + # at the end of chunk ``N`` is generally a different logical + # hypothesis from beam ``i`` at the end of chunk ``N-1``: it is the + # descendant of beam ``root_ptrs[i]`` from chunk ``N-1``. Threading + # ``root_ptrs`` into the accumulator's ``transcript_wb_prev_ptr`` at + # the chunk boundary (via ``merge_(..., boundary_prev_ptr=...)``) + # encodes this redirection in the prefix tree so the final + # ``flatten_sort_`` inside ``to_hyps_list`` walks back through the + # right beam history at every chunk boundary. + # + # The cross-chunk per-beam state (``scores``, ``current_lengths_nb``, + # ...) is already cumulative on each chunk's hyps; pass + # ``is_chunk_continuation=True`` so ``merge_`` replaces (not sums) + # those fields. + chunk_snapshot = chunk_batched_hyps.clone() + chunk_root_ptrs = chunk_snapshot.flatten_() + if current_batched_hyps is None: + current_batched_hyps = chunk_snapshot + else: + current_batched_hyps.merge_( + chunk_snapshot, + is_chunk_continuation=True, + boundary_prev_ptr=chunk_root_ptrs, + ) else: # For greedy: merge chunks using merge_ if current_batched_hyps is None: diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 985b88041fc6..d1132f558be5 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -777,6 +777,7 @@ def modified_alsd_cuda_graphs( """ assert self.cuda_graphs_mode is not None + self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) @@ -962,6 +963,7 @@ def _graph_reinitialize( if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() + print("full_graph_compile") except NeMoCUDAPythonException as e: if not self.cuda_graphs_allow_fallback: raise RuntimeError("Full CUDA graph decoding failed. Mode is forced, raising exception") from e @@ -1066,7 +1068,7 @@ def _full_graph_compile(self): torch.cuda.graph(self.full_graph_continuation, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_loop_continuation() - capture_status, _, graph, _, _, _ = cu_call( + capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) ) @@ -1114,14 +1116,32 @@ def _before_loop(self): def _before_loop_continuation(self): """ - Prepares state for continuation chunk without clearing batched_hyps. - Decoder state and fusion states are already restored from previous chunk - via _restore_state_from_prev before this method is called. + Prepares state for continuation chunk without clearing the cross-chunk per-beam + state of ``batched_hyps``. + + Decoder state and fusion states are already restored from the previous chunk via + ``_restore_state_from_prev`` before this method is called. ``batched_hyps`` is + likewise restored from the previous chunk so that ``scores``, ``last_label``, + ``transcript_hash``, ``current_lengths_nb`` and ``last_timestamp_lasts`` continue + the beam search across the chunk boundary. + + However, the per-chunk transcript prefix tree (``transcript_wb`` / + ``transcript_wb_prev_ptr`` / ``timestamps``) and the write cursor into it + (``current_lengths_wb``) must be reset for the new chunk: their buffers are sized + for one chunk's worth of decoding only, and the captured ``add_results_no_checks_`` + scatters into them at ``current_lengths_wb`` without bounds checking. If we kept + the previous chunk's cursor we would index out of bounds within a few chunks, + producing a CUDA illegal-memory-access from inside the captured graph. The caller + is responsible for snapshotting / merging the per-chunk transcripts before this + method runs at the start of the next chunk (see :meth:`BatchedBeamHyps.merge_` + with ``is_chunk_continuation=True``). """ - # Don't clear batched_hyps - it's already restored from previous state - # Don't reset decoder state - it's already restored from previous state - # Don't reset fusion states - they're already restored from previous state - + # Reset chunk-local storage so the captured loop body writes into freshly zeroed + # buffers; cross-chunk per-beam state is preserved. + self.state.batched_hyps.clear_chunk_local_() + + # Decoder state and fusion states are already restored from previous state. + self._before_loop_common() def _before_loop_common(self): @@ -1250,7 +1270,6 @@ def _loop_update_decoder(self): Updates the decoder state, decoder output, and optionally the fusion models state for the next iteration of the decoding loop in a batched RNNT (Recurrent Neural Network Transducer) setup. """ - # step 5: update decoder state + decoder output (+ fusion models state/scores) # step 5.1: mask invalid value labels with blank to avoid errors (refer to step 2.2) torch.where( diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index ccfe24adfa27..180c74ec91b4 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -199,6 +199,36 @@ def clear_(self): self.next_timestamp.fill_(0) self.last_timestamp_lasts.fill_(0) + def clear_chunk_local_(self): + """ + Reset only the chunk-local storage so the per-chunk transcript buffers can be + reused for the next chunk in streaming/chunked beam decoding under CUDA graphs. + + The chunk-local storage is the transcript prefix tree (``transcript_wb``, + ``transcript_wb_prev_ptr``), the per-step timestamps array (``timestamps``), the + write cursor into those buffers (``current_lengths_wb``) and, for transducer + models, the per-chunk ``next_timestamp`` counter. + + Cross-chunk per-beam state - ``scores``, ``last_label``, ``transcript_hash``, + ``current_lengths_nb`` and ``last_timestamp_lasts`` - is intentionally left + untouched so the beam-search loop can continue from the previous chunk's beam + state without re-seeding. + + This method does only in-place ``fill_`` / ``copy_`` writes into the existing + tensors, so it is safe to call from inside a captured CUDA-graph region (the + captured pointers remain valid). + """ + self.current_lengths_wb.fill_(0) + + self.transcript_wb.fill_(NON_EXISTENT_LABEL_VALUE) + self.transcript_wb_prev_ptr.fill_(INIT_POINTER_VALUE) + + if self.model_type == ASRModelTypeEnum.CTC: + self.timestamps.copy_(self._create_timestamps_tensor(self._max_length)) + else: + self.timestamps.fill_(0) + self.next_timestamp.fill_(0) + def copy_from_(self, other: "BatchedBeamHyps"): """ Copy state from another BatchedBeamHyps object (in-place). @@ -754,14 +784,54 @@ def flatten_sort_(self, score_norm: bool = True): 4. Updates the internal state of the object, including transcripts, timestamps, scores, lengths, labels, and other metadata, based on the sorted order. """ - - # import pdb; pdb.set_trace() + # add one for consistency with non-batched decodings, that use SOS. normalized_scores = ( self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores ) - normalized_scores, indices = torch.sort(normalized_scores, dim=-1, descending=True) + _, indices = torch.sort(normalized_scores, dim=-1, descending=True) + self._flatten_with_permutation_(indices) + + def flatten_(self) -> torch.Tensor: + """ + Flatten the tree structure of hypotheses without changing beam order. + + Like :meth:`flatten_sort_` but uses the identity permutation, so beam ``i`` keeps + its identity (its decoded prefix and its cross-chunk per-beam state stay aligned + with the corresponding beam in any other ``BatchedBeamHyps`` constructed under the + same decoding run). Required for inter-chunk :meth:`merge_` calls in streaming + beam decoding where beam indices must correspond across chunks. + + Returns: + ``root_ptrs`` of shape ``[batch_size, beam_size]``: the beam index at the + chunk's *start* (i.e. before the first ``add_results_*`` write) from which + each output beam ultimately descends. For chunked streaming beam search, this + tells the caller how to permute the previous chunks' accumulated per-beam + transcripts so they align with this chunk's beam ordering before merging. + + If the prefix tree is empty (``current_lengths_wb.max() == 0``) the identity + permutation is returned. + """ + identity = self.beam_indices.unsqueeze(0).expand(self.batch_size, self.beam_size).contiguous() + return self._flatten_with_permutation_(identity) + + def _flatten_with_permutation_(self, indices: torch.Tensor) -> torch.Tensor: + """ + In-place flatten of the prefix tree using ``indices`` as the new beam permutation. + + Walks ``transcript_wb_prev_ptr`` from the most recent step back to step 0, + gathering tokens and timestamps for each output beam from the source beam given + by ``indices``. Updates all per-beam metadata to match the new ordering. + Args: + indices: ``[batch_size, beam_size]`` long tensor giving the source beam index + for each output beam (e.g. ``arange(beam_size)`` for no permutation). + + Returns: + ``root_ptrs`` of shape ``[batch_size, beam_size]``: the beam index *before* + step 0 of the prefix tree from which each output beam descends. If the prefix + tree is empty (``max_idx < 0``) this equals ``indices``. + """ max_idx = self.current_lengths_wb.max() - 1 ptrs = indices @@ -786,6 +856,8 @@ def flatten_sort_(self, score_norm: bool = True): if self.store_prefix_hashes: self.transcript_prefix_hash.copy_(torch.gather(self.transcript_prefix_hash, dim=-1, index=indices)) + return ptrs + def _create_fold_consecutive_mask(self, transcript): """ Creates a mask to filter consecutive duplicates, blanks, and invalid tokens in a transcript. @@ -835,25 +907,46 @@ def _create_transcripts_mask(self, transcripts: torch.Tensor): else: return (transcripts >= 0) & (transcripts != self.blank_index) - def merge_(self, other: "BatchedBeamHyps") -> "BatchedBeamHyps": + def merge_( + self, + other: "BatchedBeamHyps", + is_chunk_continuation: bool = False, + boundary_prev_ptr: Optional[torch.Tensor] = None, + ) -> "BatchedBeamHyps": """ Merge two batched beam hypotheses structures by concatenating transcripts. Used for streaming/chunked inference where results from multiple chunks need to be combined. - + Prerequisites: - Both self and other should have been processed with flatten_sort_() before merging, so that each beam contains an independent flattened hypothesis. - Beam indices should correspond across chunks (beam i in self matches beam i in other). - + Notes: - Timestamps in 'other' should already be cumulative (adjusted for time offset). - The transcript_hash values are copied from 'other' and won't reflect the full merged transcript. This means recombine_hyps_() should NOT be called on merged results without recomputing hashes. This is acceptable for output-only use. - + Args: - other: BatchedBeamHyps from the next chunk to merge - + other: BatchedBeamHyps from the next chunk to merge. + is_chunk_continuation: If True, treat ``other`` as a beam-search continuation + chunk in which the cross-chunk per-beam fields (``scores``, + ``current_lengths_nb``) already hold cumulative across-chunks values rather + than chunk-local deltas. In that case those fields are *replaced* with the + values from ``other`` instead of summed, to avoid double-counting. The + default (False) preserves the original "deltas" semantics used by greedy + streaming-style merges. + boundary_prev_ptr: Optional ``[batch_size, beam_size]`` long tensor. When + provided, written into ``transcript_wb_prev_ptr`` at the very first + position of the merged region (i.e. at ``self.current_lengths_wb`` before + the update). All other positions of the merged region still receive + ``beam_indices`` (identity) pointers. This is how chunked streaming beam + search threads the cross-chunk beam permutation (the "root ptrs" returned + by :meth:`flatten_` on ``other``) into the accumulator's prefix tree so + that the final :meth:`flatten_sort_` walk redirects from beam ``i`` in + ``other``'s region back to its source beam in ``self``'s region. + Returns: Self (modified in-place) """ @@ -885,12 +978,22 @@ def merge_(self, other: "BatchedBeamHyps") -> "BatchedBeamHyps": src=other.transcript_wb[..., :max_other_len], ) - # Update pointers: for the merged portion, we set pointers to point to the same beam - # (since after flatten_sort_, each beam is independent) + # Update pointers: in the merged region every position points to its own beam + # (identity), except the *first* merged position which optionally encodes the + # cross-chunk root permutation so the final flatten walk redirects from the new + # region back to the right beam in the old region. + identity_src = self.beam_indices.view(1, self.beam_size, 1).expand( + self.batch_size, -1, max_other_len + ) + if boundary_prev_ptr is not None: + ptr_src = identity_src.clone() + ptr_src[..., 0] = boundary_prev_ptr + else: + ptr_src = identity_src self.transcript_wb_prev_ptr.scatter_( dim=-1, index=shifted_indices, - src=self.beam_indices.view(1, self.beam_size, 1).expand(self.batch_size, -1, max_other_len), + src=ptr_src, ) # Scatter timestamps @@ -899,14 +1002,21 @@ def merge_(self, other: "BatchedBeamHyps") -> "BatchedBeamHyps": index=shifted_indices, src=other.timestamps[..., :max_other_len], ) - - # Update lengths + + # Lengths in the chunk-local write cursor are always additive (``other`` always + # reports a chunk-local ``current_lengths_wb``). self.current_lengths_wb += other.current_lengths_wb - self.current_lengths_nb += other.current_lengths_nb - - # Update scores (log probabilities, so we add them) - self.scores += other.scores - + + if is_chunk_continuation: + # Beam-search streaming: ``other`` carries cumulative cross-chunk state in + # these fields, so replace rather than accumulate. + self.current_lengths_nb.copy_(other.current_lengths_nb) + self.scores.copy_(other.scores) + else: + # Original ("deltas") semantics. + self.current_lengths_nb += other.current_lengths_nb + self.scores += other.scores + # Update transcript hash by combining hashes # The hash of the merged transcript should account for all non-blank labels self.transcript_hash.copy_(other.transcript_hash) From 2003efa149585e734400cac139904e07af04e936 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 14 May 2026 23:53:12 +0400 Subject: [PATCH 06/13] unify full graph into single graph, working Signed-off-by: lilithgrigoryan --- .../submodules/rnnt_malsd_batched_computer.py | 125 +++++++++++++----- .../transducer_decoding/rnnt_label_looping.py | 2 + nemo/core/utils/cuda_python_utils.py | 27 +++- 3 files changed, 113 insertions(+), 41 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index d1132f558be5..319ed63e9f92 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -193,6 +193,9 @@ def __init__( # Streaming state fields self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) + # Inverse flag used by the captured graph to route the prologue. Maintained in + # lockstep with ``is_continuation`` outside the graph (see ``_set_continuation``). + self.is_first_chunk = torch.tensor(True, device=self.device, dtype=torch.bool) self.decoded_lengths = torch.zeros([self.batch_size], dtype=torch.long, device=self.device) def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: @@ -777,7 +780,7 @@ def modified_alsd_cuda_graphs( """ assert self.cuda_graphs_mode is not None - self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS + self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) @@ -794,6 +797,10 @@ def modified_alsd_cuda_graphs( # Set continuation flag and restore state from previous chunk if provided is_continuation = prev_batched_state is not None self.state.is_continuation.fill_(is_continuation) + # Mirror into the inverse flag so the captured graph's IF nodes can route to + # the right prologue. Both tensors are read by ``loop_conditional``-style + # condition kernels baked into the graph at capture time. + self.state.is_first_chunk.fill_(not is_continuation) if is_continuation: # Restore state from previous chunk @@ -809,11 +816,9 @@ def modified_alsd_cuda_graphs( self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: - # Use continuation graph if continuing from previous chunk, otherwise first chunk graph - if is_continuation: - self.full_graph_continuation.replay() - else: - self.full_graph.replay() + # Single graph dispatches between first-chunk and continuation prologues internally + # via captured IF nodes that read ``is_first_chunk`` / ``is_continuation``. + self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: # Use continuation before_loop graph if continuing from previous chunk if is_continuation: @@ -960,6 +965,10 @@ def _graph_reinitialize( self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) self.state.fusion_states_prev_list.append(init_fusion_states.clone()) + # warmup before graph compilation + if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: + self._warmup_for_cuda_graphs() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() @@ -981,6 +990,37 @@ def _graph_reinitialize( else: raise NotImplementedError + def _warmup_for_cuda_graphs(self): + """Warmup before compiling CUDA graphs. + + Runs a few eager iterations of both the first-chunk and continuation paths so that + cuBLAS / cuDNN handles and workspaces are allocated and stable before any graph + capture begins. Mirrors the warmup pattern used by the greedy label-looping decoder. + """ + is_ddp = torch.distributed.is_available() and torch.distributed.is_initialized() + # 11 warmup steps required in DDP mode + # see https://pytorch.org/docs/stable/notes/cuda.html#usage-with-distributeddataparallel + num_runs = 11 if is_ddp else 3 + self.state.encoder_output_projected.fill_(0.0) + self.state.encoder_output_length.fill_(1) + s = torch.cuda.Stream(self.state.device) + s.wait_stream(torch.cuda.current_stream(device=self.state.device)) + with torch.cuda.stream(s), torch.inference_mode(): + # Warm up the first-chunk path. + for _ in range(num_runs): + self._before_loop() + self._loop_body() + self._loop_update_decoder() + # Warm up the continuation path so its prologue and any kernels it touches + # are primed too. Both captures share a mempool, so any allocator activity + # they trigger needs to settle before either is captured. + for _ in range(num_runs): + self._before_loop_continuation() + self._loop_body() + self._loop_update_decoder() + torch.cuda.current_stream(device=self.state.device).wait_stream(s) + self.state.encoder_output_length.fill_(0) + def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. @@ -1028,9 +1068,23 @@ def _partial_graphs_compile(self): @cuda_python_required def _full_graph_compile(self): - """Compile full graph for decoding""" + """Compile a single CUDA graph that handles both first-chunk and continuation paths. + + The graph contains three conditional sub-graphs in order: + 1. IF (``is_first_chunk``) → ``_before_loop()`` + 2. IF (``is_continuation``) → ``_before_loop_continuation()`` + 3. WHILE (``active_mask_any``) → ``_loop_body()`` + ``_loop_update_decoder()`` + + At replay time the caller toggles ``is_first_chunk`` / ``is_continuation`` so + exactly one prologue executes. This avoids needing two coexisting CUDAGraph + objects (which observed to cause cudaErrorIllegalAddress on replay due to + mempool interaction between the two captures). + """ # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) + # Drain any work pending on the default stream (e.g. the warmup that ran just above in + # ``_graph_reinitialize``) before we start capturing. + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() with ( @@ -1038,52 +1092,50 @@ def _full_graph_compile(self): torch.inference_mode(), torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): - self._before_loop() + # The condition-setter kernel (created lazily by ``_create_loop_body_kernel``) is + # signature-compatible with any 0-d bool*; we reuse it for all three conditional nodes. + cond_kernel = self._create_loop_body_kernel() + # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) ) - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - # capture: while self.active_mask_any: - (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() - active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) - loop_args = np.array( - [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + # --- IF (is_first_chunk): run first-chunk prologue --- + (first_chunk_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_first_chunk_ptr = np.array([self.state.is_first_chunk.data_ptr()], dtype=np.uint64) + first_chunk_args = np.array( + [first_chunk_handle.getPtr(), is_first_chunk_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): - self._loop_body() - self._loop_update_decoder() - - # Compile continuation graph for streaming - self.full_graph_continuation = torch.cuda.CUDAGraph() + with with_conditional_node( + cond_kernel, first_chunk_args, first_chunk_handle, device=self.state.device, cond_type="if" + ): + self._before_loop() - with ( - torch.cuda.stream(stream_for_graph), - torch.inference_mode(), - torch.cuda.graph(self.full_graph_continuation, stream=stream_for_graph, capture_error_mode="thread_local"), - ): - self._before_loop_continuation() - capture_status, _, graph, *_ = cu_call( - cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) + # --- IF (is_continuation): run continuation prologue --- + (continuation_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_continuation_ptr = np.array([self.state.is_continuation.data_ptr()], dtype=np.uint64) + continuation_args = np.array( + [continuation_handle.getPtr(), is_continuation_ptr.ctypes.data], + dtype=np.uint64, ) + with with_conditional_node( + cond_kernel, continuation_args, continuation_handle, device=self.state.device, cond_type="if" + ): + self._before_loop_continuation() - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - - # capture: while self.active_mask_any: + # --- WHILE (active_mask_any): main decoding loop --- (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) loop_args = np.array( [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + with with_conditional_node( + cond_kernel, loop_args, loop_conditional_handle, device=self.state.device, cond_type="while" + ): self._loop_body() self._loop_update_decoder() @@ -1180,6 +1232,7 @@ def _loop_body(self): ].unsqueeze(1), self.state.decoder_output, ).squeeze() + # logits=torch.zeros(self.state.batch_size*self.beam_size, 1025, dtype=self.state.float_dtype, device=self.state.device) log_probs = F.log_softmax(logits, dim=-1, dtype=self.state.float_dtype).view( self.state.batch_size, self.beam_size, -1 ) # [(B x Beam), V] diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index a732f45e2721..88f09c8c772e 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -881,6 +881,7 @@ def _graph_reinitialize( f"Full CUDA graph compilation failed: {e}. " "Falling back to native PyTorch CUDA graphs. Decoding will be slower." ) + print("[DEBUG greedy] FULL_GRAPH fallback to NO_WHILE_LOOPS triggered", flush=True) self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS self._partial_graphs_compile() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: @@ -1004,6 +1005,7 @@ def _full_graph_compile(self): ): self._inner_loop_step_find_next_non_blank() self._after_inner_loop_step() + print("[DEBUG greedy] _full_graph_compile completed", flush=True) def _init_decoding_state( self, current_batch_size: int, prev_batched_state: Optional[BatchedLabelLoopingState] = None diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index f67ac8b478d4..8bbdbd53abc0 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -117,7 +117,9 @@ def cu_call(f_call_out): @contextlib.contextmanager @cuda_python_required -def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device): +def with_conditional_node( + while_loop_kernel, while_loop_args, while_loop_conditional_handle, device, cond_type="while" +): """ Even though we add a conditional node only once, we need to capture the kernel that calls cudaGraphSetConditional() both @@ -125,7 +127,15 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi and after the rest of the while loop body graph (because we need to decide both whether to enter the loop, and also whether to execute the next iteration of the loop). + + Args: + cond_type: Either "while" (default, original behavior) or "if". For "if" the + condition kernel is launched once (before the body) and the body is + executed at most once; for "while" the kernel is launched both before + entering the body and again at the end of each iteration so the loop + can re-evaluate its condition. """ + assert cond_type in ("while", "if"), f"cond_type must be 'while' or 'if', got {cond_type!r}" # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream) @@ -155,7 +165,10 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi driver_params = cuda.CUgraphNodeParams() driver_params.type = cuda.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL driver_params.conditional.handle = while_loop_conditional_handle - driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE + if cond_type == "while": + driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE + else: + driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF driver_params.conditional.size = 1 if Version(cuda_python_version) == Version("12.3.0"): # Work around for https://github.com/NVIDIA/cuda-python/issues/55 @@ -215,9 +228,13 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi yield body_stream, body_graph - cuda.cuLaunchKernel( - while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, body_stream.cuda_stream, while_loop_args.ctypes.data, 0 - ) + # For a WHILE node we re-launch the condition kernel at the end of the body so the + # graph can decide whether to execute another iteration. An IF node is one-shot so + # the body simply ends here. + if cond_type == "while": + cuda.cuLaunchKernel( + while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, body_stream.cuda_stream, while_loop_args.ctypes.data, 0 + ) cudart.cudaStreamEndCapture(body_stream.cuda_stream) From 7e119044c059b71aa8f8ec8f8f6b7f47b2c4070f Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 13:33:54 +0400 Subject: [PATCH 07/13] clean up Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 179 ++++-------------- .../submodules/rnnt_malsd_batched_computer.py | 72 ++----- .../transducer_decoding/rnnt_label_looping.py | 2 - 3 files changed, 51 insertions(+), 202 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index acfb9b778452..7278d5854725 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -294,19 +294,13 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: asr_model.preprocessor.featurizer.pad_to = 0 asr_model.eval() - # Get decoding computer based on strategy + # Get decoding computer based on strategy. Beam-search strategies expose the + # underlying computer via the private ``_decoding_computer`` attribute. if cfg.decoding.strategy == "greedy_batch": decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer elif cfg.decoding.strategy == "malsd_batch": - # Beam search strategies use _decoding_computer (private attribute) - is_tdt = False - print(f"is_tdt: {is_tdt}") - if is_tdt: - decoding_computer: ModifiedALSDBatchedTDTComputer = asr_model.decoding.decoding._decoding_computer - else: - decoding_computer: ModifiedALSDBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer + decoding_computer = asr_model.decoding.decoding._decoding_computer elif cfg.decoding.strategy == "maes_batch": - # MAES beam search returns BatchedBeamHyps decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer else: raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}") @@ -425,53 +419,14 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) rest_audio_lengths = audio_batch_lengths.clone() - # print(f"decoding_computer: {type(decoding_computer)}") - # For MALSD: batched_hyps is stored in state and reused (no merge needed) - # For greedy: fresh BatchedHyps created each chunk, needs merging - is_beam_search = ( - isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer) or \ - isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) or \ - isinstance(decoding_computer, ModifiedALSDBatchedTDTComputer) + # Beam-search strategies (MALSD/MAES/TDT-MALSD) keep their batched_hyps inside + # the decoder state and reset chunk-local buffers per chunk; greedy returns a + # fresh BatchedHyps each call which we must merge externally. + is_beam_search = isinstance( + decoding_computer, + (ModifiedALSDBatchedRNNTComputer, ModifiedAESBatchedRNNTComputer, ModifiedALSDBatchedTDTComputer), ) - # ============================================================================ - # ENCODER PROCESSING MODES: - # - # MODE 1: FULL ENCODER PASS (Non-streaming simulation) - # - Runs encoder once on entire audio upfront - # - Extracts chunks from pre-computed output - # - Use this to verify consistency with non-streaming scripts - # - TO ENABLE: Uncomment all "MODE 1" sections below - # - # MODE 2: STREAMING CHUNKED ENCODER (Default/Original) - # - Runs encoder separately for each chunk with context - # - True streaming behavior with left-chunk-right context windows - # - Currently ACTIVE - # - TO KEEP: Leave "MODE 2" sections uncommented - # ============================================================================ - - # import pdb; pdb.set_trace() - # ============================================================================ - # MODE 1: FULL ENCODER PASS (for testing consistency with non-streaming) - # TO ENABLE: Uncomment this section and comment out MODE 2 sections below - # ============================================================================ - # TESTING: Run encoder once on full audio (not chunked) - # full_encoder_output, full_encoder_output_len = asr_model( - # input_signal=audio_batch, - # input_signal_length=audio_batch_lengths, - # ) - # full_encoder_output = full_encoder_output.transpose(1, 2) # [B, T, C] - - # # do not recalculate joint projection, project only once - # full_encoder_output_projected = asr_model.joint.project_encoder(full_encoder_output) - # full_encoder_output_projected_len = full_encoder_output_len - - # # Track per-sample frame positions in the full encoder output - # # Different samples may have different lengths, so we need per-sample tracking - # encoder_frame_positions = torch.zeros([batch_size], dtype=torch.long, device=device) - # ============================================================================ - - # import pdb; pdb.set_trace() # iterate over audio samples (but only for decoding chunks) while left_sample < audio_batch.shape[1]: # add samples to buffer @@ -490,50 +445,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: is_last_chunk_batch=is_last_chunk_batch, ) - # ======================================================================== - # MODE 1: FULL ENCODER PASS - Extract pre-computed chunks - # TO ENABLE: Uncomment this section and comment out MODE 2 section below - # ======================================================================== - # Use buffer's context to know the actual chunk sizes (handles variable lengths and last chunks) - # encoder_context_batch = buffer.context_size_batch.subsample(factor=encoder_frame2audio_samples) - - # # Extract chunks for each sample from their current position - # # Since samples can be at different positions, we need to handle each separately - # max_chunk_size_this_iter = encoder_context_batch.chunk.max().item() - # encoder_output_chunk = torch.zeros( - # [batch_size, max_chunk_size_this_iter, full_encoder_output_projected.shape[2]], - # dtype=full_encoder_output_projected.dtype, - # device=device - # ) - - # # Extract the appropriate chunk for each sample - # for b_idx in range(batch_size): - # start_pos = encoder_frame_positions[b_idx].item() - # chunk_len = encoder_context_batch.chunk[b_idx].item() - # end_pos = min(start_pos + chunk_len, full_encoder_output_projected.shape[1]) - # actual_len = end_pos - start_pos - # if actual_len > 0: - # encoder_output_chunk[b_idx, :actual_len] = full_encoder_output_projected[b_idx, start_pos:end_pos] - - # # Use the buffer's chunk size calculations (which properly handle per-sample lengths) - # encoder_out_len_chunk = encoder_context_batch.chunk - - # # decode only chunk frames (using pre-computed encoder output) - # chunk_batched_hyps, _, state = decoding_computer( - # x=encoder_output_chunk, - # out_len=encoder_out_len_chunk, - # prev_batched_state=state, - # ) - - # # Update per-sample positions - # encoder_frame_positions += encoder_context_batch.chunk - # ======================================================================== - - # import pdb; pdb.set_trace() - # ======================================================================== - # MODE 2: STREAMING CHUNKED ENCODER (ORIGINAL/DEFAULT) - # TO DISABLE: Comment out this section when using MODE 1 - # ======================================================================== # get encoder output using full buffer [left-chunk-right] encoder_output, encoder_output_len = asr_model( input_signal=buffer.samples, @@ -547,55 +458,41 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: encoder_output = encoder_output[:, encoder_context.left :] # decode only chunk frames - if isinstance(decoding_computer, ModifiedALSDBatchedTDTComputer) or isinstance(decoding_computer, ModifiedAESBatchedRNNTComputer) or isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer): + out_len = torch.where( + is_last_chunk_batch, + encoder_output_len - encoder_context_batch.left, + encoder_context_batch.chunk, + ) + if is_beam_search: + # Beam-search computers don't accept ``multi_biasing_ids`` yet. chunk_batched_hyps, _, state = decoding_computer( - x=encoder_output, - out_len=torch.where( - is_last_chunk_batch, - encoder_output_len - encoder_context_batch.left, - encoder_context_batch.chunk, - ), - prev_batched_state=state, + x=encoder_output, out_len=out_len, prev_batched_state=state ) else: chunk_batched_hyps, _, state = decoding_computer( x=encoder_output, - out_len=torch.where( - is_last_chunk_batch, - encoder_output_len - encoder_context_batch.left, - encoder_context_batch.chunk, - ), + out_len=out_len, prev_batched_state=state, multi_biasing_ids=multi_biasing_ids, ) - # ======================================================================== - # Handle hypothesis accumulation differently for beam search vs greedy. + # Accumulate hypotheses across chunks. if is_beam_search: - # For beam search the chunk-local transcript buffers inside - # ``decoding_computer.state.batched_hyps`` are reused across chunks - # (and reset at the start of every continuation chunk), so we need to - # snapshot the per-chunk transcripts before the next call overwrites - # them and merge them into an external accumulator. - # - # ``flatten_`` resolves the chunk-local prefix tree without sorting - # (preserving the chunk's beam ordering) and returns ``root_ptrs``: for - # each chunk-end beam ``i``, the beam index at the chunk's start that - # this hypothesis ultimately descends from. Because the chunk's loop - # body can permute beams at any step via the top-K gather, beam ``i`` - # at the end of chunk ``N`` is generally a different logical - # hypothesis from beam ``i`` at the end of chunk ``N-1``: it is the - # descendant of beam ``root_ptrs[i]`` from chunk ``N-1``. Threading - # ``root_ptrs`` into the accumulator's ``transcript_wb_prev_ptr`` at - # the chunk boundary (via ``merge_(..., boundary_prev_ptr=...)``) - # encodes this redirection in the prefix tree so the final - # ``flatten_sort_`` inside ``to_hyps_list`` walks back through the - # right beam history at every chunk boundary. - # - # The cross-chunk per-beam state (``scores``, ``current_lengths_nb``, - # ...) is already cumulative on each chunk's hyps; pass - # ``is_chunk_continuation=True`` so ``merge_`` replaces (not sums) - # those fields. + # Beam-search reuses chunk-local transcript buffers inside + # ``decoding_computer.state.batched_hyps`` (reset at every continuation + # chunk), so we must snapshot before the next call overwrites them. + # ``flatten_`` resolves the chunk-local prefix tree without sorting and + # returns ``root_ptrs``: for each chunk-end beam ``i``, the beam index + # at the chunk's start that this hypothesis descends from. The loop + # body can permute beams via top-K, so beam ``i`` at the end of chunk + # ``N`` is generally a different logical hypothesis from beam ``i`` at + # the end of chunk ``N-1`` - it is the descendant of beam + # ``root_ptrs[i]``. Threading ``root_ptrs`` into the accumulator's + # ``transcript_wb_prev_ptr`` via ``boundary_prev_ptr`` encodes this + # redirection so the final ``flatten_sort_`` in ``to_hyps_list`` walks + # back through the right beam history. ``is_chunk_continuation=True`` + # makes ``merge_`` replace (not sum) the cumulative per-beam fields + # (``scores``, ``current_lengths_nb``, ...). chunk_snapshot = chunk_batched_hyps.clone() chunk_root_ptrs = chunk_snapshot.flatten_() if current_batched_hyps is None: @@ -606,12 +503,10 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: is_chunk_continuation=True, boundary_prev_ptr=chunk_root_ptrs, ) + elif current_batched_hyps is None: + current_batched_hyps = chunk_batched_hyps else: - # For greedy: merge chunks using merge_ - if current_batched_hyps is None: - current_batched_hyps = chunk_batched_hyps - else: - current_batched_hyps.merge_(chunk_batched_hyps) + current_batched_hyps.merge_(chunk_batched_hyps) # move to next sample rest_audio_lengths -= chunk_lengths_batch diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 319ed63e9f92..e1d09aa560c3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -105,7 +105,7 @@ class MALSDState: # Streaming state fields is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk - decoded_lengths: torch.Tensor # accumulated decoded lengths across chunks + is_first_chunk: torch.Tensor # complement of ``is_continuation``; both feed the captured graph's IF nodes def __init__( self, @@ -129,13 +129,11 @@ def __init__( float_dtype: default float dtype for tensors (should match projected encoder output) blank_index: index of the blank symbol """ - - max_time = 375 self.device = device self.float_dtype = float_dtype self.batch_size = batch_size self.beam_size = beam_size - self.max_time = max_time + self.max_time = max_time self.blank_index = blank_index self.NON_EXISTENT_LABEL = torch.tensor(NON_EXISTENT_LABEL_VALUE, device=self.device, dtype=torch.long) @@ -191,12 +189,12 @@ def __init__( float_dtype=float_dtype, ) - # Streaming state fields + # Streaming state fields. The captured FULL_GRAPH reads ``is_first_chunk`` and + # ``is_continuation`` to route to the first-chunk vs. continuation prologue at replay + # time. The two flags are kept in lockstep by ``modified_alsd_cuda_graphs`` before + # each replay (they are mutually exclusive). self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) - # Inverse flag used by the captured graph to route the prologue. Maintained in - # lockstep with ``is_continuation`` outside the graph (see ``_set_continuation``). self.is_first_chunk = torch.tensor(True, device=self.device, dtype=torch.bool) - self.decoded_lengths = torch.zeros([self.batch_size], dtype=torch.long, device=self.device) def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: """Check if need to reinit state: larger batch_size/max_time, or new device""" @@ -235,7 +233,6 @@ class CudaGraphsMode(PrettyStrEnum): separate_graphs: Optional[SeparateGraphsMALSD] full_graph: Optional[torch.cuda.CUDAGraph] - full_graph_continuation: Optional[torch.cuda.CUDAGraph] cuda_graphs_mode: Optional[CudaGraphsMode] state: Optional[MALSDState] fusion_models: Optional[List[NGramGPULanguageModel]] @@ -286,7 +283,6 @@ def __init__( self.state = None self.full_graph = None - self.full_graph_continuation = None self.separate_graphs = None self.cuda_graphs_mode = None @@ -361,7 +357,6 @@ def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" self.state = None self.full_graph = None - self.full_graph_continuation = None self.separate_graphs = None def modified_alsd_torch( @@ -386,14 +381,9 @@ def modified_alsd_torch( if torch.is_autocast_enabled(): encoder_output = encoder_output.to(torch.get_autocast_gpu_dtype()) - # # do not recalculate joint projection, project only once + # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # encoder_output_projected = encoder_output - # float_dtype = encoder_output.dtype - - # import pdb; pdb.set_trace() batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] @@ -489,10 +479,7 @@ def modified_alsd_torch( step1=0 while active_mask.any(): - # import pdb; pdb.set_trace() - # if step1 >= 0: # step 1: get joint output + fuse with fusion models (if present) - # print(f"Encoder output length: {encoder_output_length}") logits = ( self.joint.joint_after_projection( encoder_output_projected[batch_beam_indices.view(-1), safe_time_indices.view(-1)].unsqueeze(1), @@ -566,16 +553,11 @@ def modified_alsd_torch( labels_top_k.reshape(batch_size, -1), dim=-1, index=hyps_candidates_indices ) # labels for extended hypotheses - # import pdb; pdb.set_trace() - batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) # step 3: store results - - # if step1 == 37: - # print(f"Step {step1}") - # if self.max_symbols is None: - # batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) - # else: - # batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob) + if self.max_symbols is None: + batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob) + else: + batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. batched_hyps.recombine_hyps_() @@ -645,27 +627,13 @@ def modified_alsd_torch( fusion_states_candidates_list[fusion_model_idx] = fusion_states_candidates fusion_scores_list[fusion_model_idx] = fusion_scores - # import pdb; pdb.set_trace() # step 6: update time indices + active mask time_indices = torch.gather(time_indices, dim=-1, index=hyps_indices) + (next_labels == self._blank_index) torch.minimum(time_indices, last_timesteps, out=safe_time_indices) active_mask = time_indices <= last_timesteps - # if step1 == 24: - # import pdb; pdb.set_trace() - # print(f"Step {step1}") - # print(f"Time indices: {time_indices}") - # print(F"Scores: {batched_hyps.scores}") - # print(F"Trancripts: {batched_hyps.transcript_wb[..., :batched_hyps.current_lengths_wb.max()]}") - # print(F"Trancript ptrs: {batched_hyps.transcript_wb_prev_ptr[..., :batched_hyps.current_lengths_wb.max()]}") - # print(F"Trancript_lengths: {batched_hyps.current_lengths_wb}") step1 += 1 - # # fix timestamps for iterative decoding - # if not is_beam_search: - # if prev_batched_state is not None: - # batched_hyps.timestamps += prev_batched_state.decoded_lengths.unsqueeze(1).unsqueeze(1) - # NB: last labels can not exist (nothing decoded on this step). # return the last labels from the previous state in this case last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) @@ -780,7 +748,6 @@ def modified_alsd_cuda_graphs( """ assert self.cuda_graphs_mode is not None - self.cuda_graphs_mode = self.CudaGraphsMode.FULL_GRAPH # do not recalculate joint projection, project only once encoder_output = self.joint.project_encoder(encoder_output) @@ -809,10 +776,9 @@ def modified_alsd_cuda_graphs( # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) # copy (projected) encoder output and lengths - # print("encoder_output.shape: ", encoder_output.shape) - # print("current_batch_size: ", current_batch_size) - # print("current_max_time: ", current_max_time) - self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output[:current_batch_size, :current_max_time, ...]) + self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_( + encoder_output[:current_batch_size, :current_max_time, ...] + ) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: @@ -972,7 +938,6 @@ def _graph_reinitialize( if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() - print("full_graph_compile") except NeMoCUDAPythonException as e: if not self.cuda_graphs_allow_fallback: raise RuntimeError("Full CUDA graph decoding failed. Mode is forced, raising exception") from e @@ -1232,7 +1197,6 @@ def _loop_body(self): ].unsqueeze(1), self.state.decoder_output, ).squeeze() - # logits=torch.zeros(self.state.batch_size*self.beam_size, 1025, dtype=self.state.float_dtype, device=self.state.device) log_probs = F.log_softmax(logits, dim=-1, dtype=self.state.float_dtype).view( self.state.batch_size, self.beam_size, -1 ) # [(B x Beam), V] @@ -1472,12 +1436,6 @@ def _restore_state_from_prev( # Restore batched_hyps from previous state if prev_batched_state.batched_hyps is not None: self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) - - # Restore decoded_lengths - if prev_batched_state.decoded_lengths is not None: - self.state.decoded_lengths[:current_batch_size].copy_( - prev_batched_state.decoded_lengths[:current_batch_size] - ) def _create_decoding_state( self, @@ -1542,10 +1500,8 @@ def __call__( out_len: torch.Tensor, prev_batched_state: Optional[BatchedLabelLoopingState] = None, ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: - # self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS if self.cuda_graphs_mode is not None and x.device.type == "cuda": with torch.amp.autocast(device_type="cuda", enabled=False): - # print("Using CUDA graphs mode: NO_WHILE_LOOPS") return self.modified_alsd_cuda_graphs( encoder_output=x, encoder_output_length=out_len, diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index 88f09c8c772e..a732f45e2721 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -881,7 +881,6 @@ def _graph_reinitialize( f"Full CUDA graph compilation failed: {e}. " "Falling back to native PyTorch CUDA graphs. Decoding will be slower." ) - print("[DEBUG greedy] FULL_GRAPH fallback to NO_WHILE_LOOPS triggered", flush=True) self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS self._partial_graphs_compile() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: @@ -1005,7 +1004,6 @@ def _full_graph_compile(self): ): self._inner_loop_step_find_next_non_blank() self._after_inner_loop_step() - print("[DEBUG greedy] _full_graph_compile completed", flush=True) def _init_decoding_state( self, current_batch_size: int, prev_batched_state: Optional[BatchedLabelLoopingState] = None From b48978aeb57751e37cbda8b5517442756da60a5e Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 15:56:29 +0400 Subject: [PATCH 08/13] add tdt Signed-off-by: lilithgrigoryan --- .../speech_to_text_streaming_infer_rnnt.py | 29 ++- .../submodules/tdt_malsd_batched_computer.py | 181 ++++++++++++------ 2 files changed, 132 insertions(+), 78 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 7278d5854725..36f0c8454098 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -478,28 +478,23 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Accumulate hypotheses across chunks. if is_beam_search: - # Beam-search reuses chunk-local transcript buffers inside - # ``decoding_computer.state.batched_hyps`` (reset at every continuation - # chunk), so we must snapshot before the next call overwrites them. # ``flatten_`` resolves the chunk-local prefix tree without sorting and # returns ``root_ptrs``: for each chunk-end beam ``i``, the beam index - # at the chunk's start that this hypothesis descends from. The loop - # body can permute beams via top-K, so beam ``i`` at the end of chunk - # ``N`` is generally a different logical hypothesis from beam ``i`` at - # the end of chunk ``N-1`` - it is the descendant of beam - # ``root_ptrs[i]``. Threading ``root_ptrs`` into the accumulator's - # ``transcript_wb_prev_ptr`` via ``boundary_prev_ptr`` encodes this - # redirection so the final ``flatten_sort_`` in ``to_hyps_list`` walks - # back through the right beam history. ``is_chunk_continuation=True`` - # makes ``merge_`` replace (not sum) the cumulative per-beam fields - # (``scores``, ``current_lengths_nb``, ...). - chunk_snapshot = chunk_batched_hyps.clone() - chunk_root_ptrs = chunk_snapshot.flatten_() + # at the chunk's start it descends from. The loop body permutes beams + # via top-K, so beam ``i`` at the end of chunk ``N`` is generally a + # different logical hypothesis from beam ``i`` at the end of chunk + # ``N-1`` - it descends from beam ``root_ptrs[i]``. Threading these + # pointers via ``boundary_prev_ptr`` encodes the redirection so the + # final ``flatten_sort_`` in ``to_hyps_list`` walks back through the + # right beam history. ``is_chunk_continuation=True`` makes ``merge_`` + # replace (not sum) the cumulative per-beam fields (``scores``, + # ``current_lengths_nb``, ...). + chunk_root_ptrs = state.batched_hyps.flatten_() if current_batched_hyps is None: - current_batched_hyps = chunk_snapshot + current_batched_hyps = state.batched_hyps else: current_batched_hyps.merge_( - chunk_snapshot, + state.batched_hyps, is_chunk_continuation=True, boundary_prev_ptr=chunk_root_ptrs, ) diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index dfd6c07305fe..2ed099e57f2f 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -96,7 +96,7 @@ class MALSDState: # Streaming state fields is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk - decoded_lengths: torch.Tensor # accumulated decoded lengths across chunks + is_first_chunk: torch.Tensor # complement of ``is_continuation``; both feed the captured graph's IF nodes # fusion models related fields fusion_models: Optional[List[NGramGPULanguageModel]] = None @@ -190,9 +190,12 @@ def __init__( self.blank_mask = torch.zeros_like(self.active_mask, dtype=torch.bool) self.active_mask_any = torch.tensor(True, device=self.device, dtype=torch.bool) - # Streaming state fields + # Streaming state fields. The captured FULL_GRAPH reads ``is_first_chunk`` and + # ``is_continuation`` to route to the first-chunk vs. continuation prologue at replay + # time. The two flags are kept in lockstep by ``modified_alsd_cuda_graphs`` before + # each replay (they are mutually exclusive). self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) - self.decoded_lengths = torch.zeros([self.batch_size], dtype=torch.long, device=self.device) + self.is_first_chunk = torch.tensor(True, device=self.device, dtype=torch.bool) self.batched_hyps = BatchedBeamHyps( batch_size=batch_size, @@ -241,7 +244,6 @@ class CudaGraphsMode(PrettyStrEnum): separate_graphs: Optional[SeparateGraphsMALSD] full_graph: Optional[torch.cuda.CUDAGraph] - full_graph_continuation: Optional[torch.cuda.CUDAGraph] cuda_graphs_mode: Optional[CudaGraphsMode] state: Optional[MALSDState] fusion_models: Optional[List[NGramGPULanguageModel]] @@ -369,7 +371,6 @@ def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" self.state = None self.full_graph = None - self.full_graph_continuation = None self.separate_graphs = None def modified_alsd_torch( @@ -881,6 +882,10 @@ def modified_alsd_cuda_graphs( # Set continuation flag and restore state from previous chunk if provided is_continuation = prev_batched_state is not None self.state.is_continuation.fill_(is_continuation) + # Mirror into the inverse flag so the captured graph's IF nodes can route to + # the right prologue. Both tensors are read by ``loop_conditional``-style + # condition kernels baked into the graph at capture time. + self.state.is_first_chunk.fill_(not is_continuation) if is_continuation: # Restore state from previous chunk @@ -888,16 +893,14 @@ def modified_alsd_cuda_graphs( # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) - # copy (projected) encoder output and lenghts + # copy (projected) encoder output and lengths self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) - + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: - # Use continuation graph if continuing from previous chunk, otherwise first chunk graph - if is_continuation: - self.full_graph_continuation.replay() - else: - self.full_graph.replay() + # Single graph dispatches between first-chunk and continuation prologues internally + # via captured IF nodes that read ``is_first_chunk`` / ``is_continuation``. + self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: # Use continuation before_loop graph if continuing from previous chunk if is_continuation: @@ -908,7 +911,6 @@ def modified_alsd_cuda_graphs( self.separate_graphs.loop_body.replay() self.separate_graphs.loop_update_decoder.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: - # this mode is only for testing purposes # manual loop instead of using graphs if is_continuation: self._before_loop_continuation() @@ -1047,6 +1049,10 @@ def _graph_reinitialize( self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) self.state.fusion_states_prev_list.append(init_fusion_states.clone()) + # warmup before graph compilation + if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: + self._warmup_for_cuda_graphs() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() @@ -1067,6 +1073,37 @@ def _graph_reinitialize( else: raise NotImplementedError + def _warmup_for_cuda_graphs(self): + """Warmup before compiling CUDA graphs. + + Runs a few eager iterations of both the first-chunk and continuation paths so that + cuBLAS / cuDNN handles and workspaces are allocated and stable before any graph + capture begins. Mirrors the warmup pattern used by the greedy label-looping decoder. + """ + is_ddp = torch.distributed.is_available() and torch.distributed.is_initialized() + # 11 warmup steps required in DDP mode + # see https://pytorch.org/docs/stable/notes/cuda.html#usage-with-distributeddataparallel + num_runs = 11 if is_ddp else 3 + self.state.encoder_output_projected.fill_(0.0) + self.state.encoder_output_length.fill_(1) + s = torch.cuda.Stream(self.state.device) + s.wait_stream(torch.cuda.current_stream(device=self.state.device)) + with torch.cuda.stream(s), torch.inference_mode(): + # Warm up the first-chunk path. + for _ in range(num_runs): + self._before_loop() + self._loop_body() + self._loop_update_decoder() + # Warm up the continuation path so its prologue and any kernels it touches + # are primed too. Both captures share a mempool, so any allocator activity + # they trigger needs to settle before either is captured. + for _ in range(num_runs): + self._before_loop_continuation() + self._loop_body() + self._loop_update_decoder() + torch.cuda.current_stream(device=self.state.device).wait_stream(s) + self.state.encoder_output_length.fill_(0) + def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. @@ -1113,9 +1150,23 @@ def _partial_graphs_compile(self): @cuda_python_required def _full_graph_compile(self): - """Compile full graph for decoding""" + """Compile a single CUDA graph that handles both first-chunk and continuation paths. + + The graph contains three conditional sub-graphs in order: + 1. IF (``is_first_chunk``) → ``_before_loop()`` + 2. IF (``is_continuation``) → ``_before_loop_continuation()`` + 3. WHILE (``active_mask_any``) → ``_loop_body()`` + ``_loop_update_decoder()`` + + At replay time the caller toggles ``is_first_chunk`` / ``is_continuation`` so + exactly one prologue executes. This avoids needing two coexisting CUDAGraph + objects (which observed to cause cudaErrorIllegalAddress on replay due to + mempool interaction between the two captures). + """ # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) + # Drain any work pending on the default stream (e.g. the warmup that ran just above in + # ``_graph_reinitialize``) before we start capturing. + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() with ( @@ -1123,52 +1174,50 @@ def _full_graph_compile(self): torch.inference_mode(), torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): - self._before_loop() + # The condition-setter kernel (created lazily by ``_create_loop_body_kernel``) is + # signature-compatible with any 0-d bool*; we reuse it for all three conditional nodes. + cond_kernel = self._create_loop_body_kernel() + # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) ) - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - # capture: while self.active_mask_any: - (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() - active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) - loop_args = np.array( - [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], + # --- IF (is_first_chunk): run first-chunk prologue --- + (first_chunk_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_first_chunk_ptr = np.array([self.state.is_first_chunk.data_ptr()], dtype=np.uint64) + first_chunk_args = np.array( + [first_chunk_handle.getPtr(), is_first_chunk_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): - self._loop_body() - self._loop_update_decoder() - - # Compile continuation graph for streaming - self.full_graph_continuation = torch.cuda.CUDAGraph() + with with_conditional_node( + cond_kernel, first_chunk_args, first_chunk_handle, device=self.state.device, cond_type="if" + ): + self._before_loop() - with ( - torch.cuda.stream(stream_for_graph), - torch.inference_mode(), - torch.cuda.graph(self.full_graph_continuation, stream=stream_for_graph, capture_error_mode="thread_local"), - ): - self._before_loop_continuation() - capture_status, _, graph, _, _, _ = cu_call( - cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) + # --- IF (is_continuation): run continuation prologue --- + (continuation_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_continuation_ptr = np.array([self.state.is_continuation.data_ptr()], dtype=np.uint64) + continuation_args = np.array( + [continuation_handle.getPtr(), is_continuation_ptr.ctypes.data], + dtype=np.uint64, ) + with with_conditional_node( + cond_kernel, continuation_args, continuation_handle, device=self.state.device, cond_type="if" + ): + self._before_loop_continuation() - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - - # capture: while self.active_mask_any: + # --- WHILE (active_mask_any): main decoding loop --- (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) loop_args = np.array( [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + with with_conditional_node( + cond_kernel, loop_args, loop_conditional_handle, device=self.state.device, cond_type="while" + ): self._loop_body() self._loop_update_decoder() @@ -1201,14 +1250,32 @@ def _before_loop(self): def _before_loop_continuation(self): """ - Prepares state for continuation chunk without clearing batched_hyps. - Decoder state and fusion states are already restored from previous chunk - via _restore_state_from_prev before this method is called. + Prepares state for continuation chunk without clearing the cross-chunk per-beam + state of ``batched_hyps``. + + Decoder state and fusion states are already restored from the previous chunk via + ``_restore_state_from_prev`` before this method is called. ``batched_hyps`` is + likewise restored from the previous chunk so that ``scores``, ``last_label``, + ``transcript_hash``, ``current_lengths_nb`` and ``last_timestamp_lasts`` continue + the beam search across the chunk boundary. + + However, the per-chunk transcript prefix tree (``transcript_wb`` / + ``transcript_wb_prev_ptr`` / ``timestamps``) and the write cursor into it + (``current_lengths_wb``) must be reset for the new chunk: their buffers are sized + for one chunk's worth of decoding only, and the captured ``add_results_no_checks_`` + scatters into them at ``current_lengths_wb`` without bounds checking. If we kept + the previous chunk's cursor we would index out of bounds within a few chunks, + producing a CUDA illegal-memory-access from inside the captured graph. The caller + is responsible for snapshotting / merging the per-chunk transcripts before this + method runs at the start of the next chunk (see :meth:`BatchedBeamHyps.merge_` + with ``is_chunk_continuation=True``). """ - # Don't clear batched_hyps - it's already restored from previous state - # Don't reset decoder state - it's already restored from previous state - # Don't reset fusion states - they're already restored from previous state - + # Reset chunk-local storage so the captured loop body writes into freshly zeroed + # buffers; cross-chunk per-beam state is preserved. + self.state.batched_hyps.clear_chunk_local_() + + # Decoder state and fusion states are already restored from previous state. + self._before_loop_common() def _before_loop_common(self): @@ -1530,12 +1597,6 @@ def _restore_state_from_prev( # Restore batched_hyps from previous state if prev_batched_state.batched_hyps is not None: self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) - - # Restore decoded_lengths - if prev_batched_state.decoded_lengths is not None: - self.state.decoded_lengths[:current_batch_size].copy_( - prev_batched_state.decoded_lengths[:current_batch_size] - ) def _create_decoding_state( self, @@ -1600,15 +1661,13 @@ def __call__( out_len: torch.Tensor, prev_batched_state: Optional[BatchedLabelLoopingState] = None, ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: - # self.cuda_graphs_mode = None if self.cuda_graphs_mode is not None and x.device.type == "cuda": with torch.amp.autocast(device_type="cuda", enabled=False): - batched_hyps, alignments, decoding_state = self.modified_alsd_cuda_graphs( - encoder_output=x, + return self.modified_alsd_cuda_graphs( + encoder_output=x, encoder_output_length=out_len, - prev_batched_state=prev_batched_state + prev_batched_state=prev_batched_state, ) - return batched_hyps, alignments, decoding_state return self.modified_alsd_torch( encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state From 1c842e1a1679e12d788b490395bf077c47c39ab8 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 16:10:16 +0400 Subject: [PATCH 09/13] clean uo Signed-off-by: lilithgrigoryan --- nemo/collections/asr/parts/submodules/tdt_beam_decoding.py | 6 +++--- nemo/collections/asr/parts/utils/streaming_utils.py | 5 ----- nemo/core/utils/cuda_python_utils.py | 7 ------- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index e7b4edfb63ff..a0154b1f80e9 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -962,16 +962,16 @@ def forward( inseq = encoder_output # [B, T, D] - # import pdb; pdb.set_trace() encoder_output_projected = self.joint.project_encoder(encoder_output) encoder_output_projected_len = encoded_lengths batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=encoder_output_projected, out_len=encoder_output_projected_len) + # Ensures the correct number of hypotheses (batch_size) for CUDA Graphs compatibility batch_size = encoder_output.shape[0] if self.return_best_hypothesis: - hyps = batched_beam_hyps.to_hyps_list(score_norm=self.score_norm)[:batch_size] # type: ignore + hyps = batched_beam_hyps.to_hyps_list(score_norm=self.score_norm)[:batch_size] else: - hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=self.score_norm)[:batch_size] # type: ignore + hyps = batched_beam_hyps.to_nbest_hyps_list(score_norm=self.score_norm)[:batch_size] self.decoder.train(decoder_training_state) self.joint.train(joint_training_state) diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 6795be5b6046..b2ebb715bd11 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -2246,11 +2246,6 @@ def subsample(self, factor: int) -> "ContextSizeBatch": chunk=torch.ceil(self.chunk / factor).to(dtype=torch.long), right=torch.ceil(self.right / factor).to(dtype=torch.long), ) - # return ContextSizeBatch( - # left=torch.div(self.left, factor).round().to(dtype=torch.long), - # chunk=torch.div(self.chunk, factor).round().to(dtype=torch.long), - # right=torch.div(self.right, factor).round().to(dtype=torch.long), - # ) def add_frames_( self, num_frames_batch: torch.Tensor, is_last_chunk_batch: torch.Tensor, expected_context: "ContextSize" diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index 8bbdbd53abc0..c40c567765c7 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -127,13 +127,6 @@ def with_conditional_node( and after the rest of the while loop body graph (because we need to decide both whether to enter the loop, and also whether to execute the next iteration of the loop). - - Args: - cond_type: Either "while" (default, original behavior) or "if". For "if" the - condition kernel is launched once (before the body) and the body is - executed at most once; for "while" the kernel is launched both before - entering the body and again at the end of each iteration so the loop - can re-evaluate its condition. """ assert cond_type in ("while", "if"), f"cond_type must be 'while' or 'if', got {cond_type!r}" # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements From ea601a805cd6c4b463b59f95ad877cf00d052c1f Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 16:13:30 +0400 Subject: [PATCH 10/13] clean up Signed-off-by: lilithgrigoryan --- .../submodules/tdt_malsd_batched_computer.py | 31 +++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 2ed099e57f2f..4eeaad0031c8 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -401,9 +401,6 @@ def modified_alsd_torch( # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # encoder_output_projected = encoder_output - # float_dtype = encoder_output.dtype batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] @@ -508,19 +505,7 @@ def modified_alsd_torch( # import pdb; pdb.set_trace() step1 = 0 - while active_mask.any(): - # import pdb; pdb.set_trace() - # print(f"Step {step1}") - # print(f"Time indices: {safe_time_indices}") - # print(f"Active mask: {active_mask}") - # print(f"Encoder output length: {encoder_output_length}") - # print(f"Decoder output: {decoder_output.shape}") - # print(f"Safe time indices: {safe_time_indices}") - # print(f"encoder_output_projected: {encoder_output_projected.shape}") - # print(f"Decoder state: {decoder_state.shape}") - - - + while active_mask.any(): # step 1: get joint output + fuse with fusion models (if present) logits = ( self.joint.joint_after_projection( @@ -618,14 +603,11 @@ def modified_alsd_torch( durations_top_k.reshape(batch_size, -1), dim=-1, index=hyps_candidates_indices ) # durations for extended hypotheses - # import pdb; pdb.set_trace() # step 3: store results - # if self.max_symbols is None: - # batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) - # else: - # batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) - # print(f"DEBUG: Adding results to batched_hyps") - batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) + if self.max_symbols is None: + batched_hyps.add_results_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) + else: + batched_hyps.add_results_no_checks_(hyps_indices, next_labels, next_hyps_prob, next_label_durations) # step 4: recombine hypotheses: sum probabilities of identical hypotheses. batched_hyps.recombine_hyps_() @@ -707,13 +689,10 @@ def modified_alsd_torch( if prev_batched_state is not None: batched_hyps.timestamps += prev_batched_state.decoded_lengths[:, None, None].expand_as(batched_hyps.timestamps) # Also update next_timestamp for proper continuation - # batched_hyps.next_timestamp += prev_batched_state.decoded_lengths[:, None].expand_as(batched_hyps.next_timestamp) # NB: last labels can not exist (nothing decoded on this step). # return the last labels from the previous state in this case - # import pdb; pdb.set_trace() last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) - # batched_hyps.next_timestamp.copy_(batched_hyps.next_timestamp - encoder_output_length.unsqueeze(-1)) decoding_state = BatchedLabelLoopingState( predictor_states=decoder_state, predictor_outputs=decoder_output, From 0d56d81021df99c3d08bdc66b3e0713f6ef37398 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 16:41:46 +0400 Subject: [PATCH 11/13] add working maes Signed-off-by: lilithgrigoryan --- .../submodules/rnnt_maes_batched_computer.py | 193 ++++++++++-------- .../submodules/rnnt_malsd_batched_computer.py | 34 +-- .../submodules/tdt_malsd_batched_computer.py | 34 +-- 3 files changed, 133 insertions(+), 128 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py index ef5458a5a762..c326307b9d1f 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py @@ -15,6 +15,7 @@ from typing import Optional import torch +import torch.nn.functional as F from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState from nemo.collections.asr.parts.utils import rnnt_utils @@ -35,6 +36,10 @@ class ModifiedAESBatchedRNNTComputer(ConfidenceMethodMixin): Based on https://ieeexplore.ieee.org/document/9250505 with the following modficiations: - does not support prediction network caching - supports prefix search with only longest prefix + + Note: RNN-T only. TDT is not supported by this computer (use ``ModifiedALSDBatchedTDTComputer`` + instead). Unlike :class:`ModifiedALSDBatchedRNNTComputer` this is a pure-PyTorch decoder; + CUDA graphs are not implemented. """ def __init__( @@ -46,7 +51,7 @@ def __init__( maes_num_steps: int, maes_expansion_beta: int, maes_expansion_gamma: int, - preserve_alignments=False, + preserve_alignments: bool = False, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, blank_lm_score_mode: Optional[str | BlankLMScoreMode] = BlankLMScoreMode.NO_SCORE, @@ -77,37 +82,35 @@ def __init__( ngram_lm_alpha: weight for the n-gram LM scores blank_lm_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM - allow_cuda_graphs: whether to allow CUDA graphs + allow_cuda_graphs: accepted for API parity with :class:`ModifiedALSDBatchedRNNTComputer`, + but ignored - this computer does not implement CUDA graphs. """ super().__init__() self.decoder = decoder self.joint = joint self._blank_index = blank_index + self._SOS = self._blank_index self.beam_size = beam_size self.maes_num_steps = maes_num_steps self.maes_expansion_beta = maes_expansion_beta self.maes_expansion_gamma = maes_expansion_gamma - self.preserve_alignments = preserve_alignments - self._SOS = self._blank_index - self.pruning_mode = pruning_mode - self.blank_lm_score_mode = blank_lm_score_mode - self.maes_num_expansions = self.beam_size + self.maes_expansion_beta + self.preserve_alignments = preserve_alignments if self.preserve_alignments: raise NotImplementedError("Preserve alignments is not supported") if allow_cuda_graphs: - logging.info("CUDA Graphs are unsupported for `maes_batch`; preceeding pure pytorch decoding") + logging.info("`allow_cuda_graphs=True` is accepted for API parity, but `maes_batch` runs in pure PyTorch.") + # n-gram LM fusion setup if ngram_lm_model is not None: expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 if self._blank_index != expected_blank_index: raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") self.ngram_lm_batch = ngram_lm_model - self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) self.blank_lm_score_mode = ( BlankLMScoreMode.LM_WEIGHTED_FULL @@ -116,6 +119,7 @@ def __init__( ) else: self.ngram_lm_batch = None + self.pruning_mode = pruning_mode self.blank_lm_score_mode = None self.ngram_lm_alpha = ngram_lm_alpha @@ -123,45 +127,59 @@ def batched_modified_adaptive_expansion_search_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - prev_batched_state: Optional[BatchedBeamHyps] = None, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ - Pure PyTorch implementation + Pure PyTorch implementation of batched modified adaptive expansion search for RNN-T. Args: - encoder_output: output from the encoder - encoder_output_length: lengths of the utterances in `encoder_output` + encoder_output: output from the encoder, shape [batch_size, max_time, encoder_dim]. + encoder_output_length: lengths of the utterances in ``encoder_output``, shape [batch_size]. + prev_batched_state: optional state from a previous chunk (streaming / chunked decoding). + When provided, ``predictor_states``, ``predictor_outputs`` and ``batched_hyps`` are + reused so that beam search continues across the chunk boundary. + + Returns: + tuple of (batched_hyps, None, decoding_state) where ``decoding_state`` is the state to + pass back as ``prev_batched_state`` on the next chunk. """ - batch_size, max_time, _unused = encoder_output.shape + batch_size, max_time, _ = encoder_output.shape device = encoder_output.device + # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - - # import pdb; pdb.set_trace() - # encoder_output_projected = encoder_output - # float_dtype = encoder_output.dtype - + batch_indices = ( - torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.beam_size).clone() + torch.arange(batch_size, dtype=torch.long, device=device)[:, None] + .expand(batch_size, self.beam_size) + .clone() ) # size: batch_size x beam_size beam_indices = ( - torch.arange(self.beam_size, device=device)[None, :].expand(batch_size, self.beam_size).clone() + torch.arange(self.beam_size, dtype=torch.long, device=device)[None, :] + .expand(batch_size, self.beam_size) + .clone() ) # size: batch_size x beam_size expansion_beam_indices = ( - torch.arange(self.beam_size, device=device)[None, :, None] + torch.arange(self.beam_size, dtype=torch.long, device=device)[None, :, None] .expand(batch_size, self.beam_size, self.maes_num_expansions) .clone() - ) # size: batch_size x beam_size x beam_size + maes_expansion_beta - + ) # size: batch_size x beam_size x (beam_size + maes_expansion_beta) + + # On continuation, work on a fresh clone of the previous chunk's hypotheses so we + # don't mutate the state object the caller may also be using as its streaming + # accumulator (``current_batched_hyps`` in the streaming script aliases + # ``state.batched_hyps`` on the first chunk - mutating it here would corrupt the + # accumulator). Cross-chunk per-beam fields (``scores``, ``last_label``, + # ``transcript_hash``, ``current_lengths_nb``, ``last_timestamp_lasts``) are + # preserved by ``clone()``; the chunk-local prefix-tree cursor and buffers are + # reset by ``clear_chunk_local_()`` so this chunk's ``add_results_`` scatters + # start at offset zero (cf. ``_before_loop_continuation`` / + # ``_create_decoding_state`` clone in :class:`ModifiedALSDBatchedRNNTComputer`). if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: - batched_hyps = prev_batched_state.batched_hyps - time_indices = torch.zeros_like(beam_indices) - last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices) - safe_time_indices = torch.minimum(time_indices, last_timesteps) - active_mask = time_indices <= last_timesteps + batched_hyps = prev_batched_state.batched_hyps.clone() + batched_hyps.clear_chunk_local_() else: - # init empty batched hypotheses batched_hyps = BatchedBeamHyps( batch_size=batch_size, beam_size=self.beam_size, @@ -171,53 +189,42 @@ def batched_modified_adaptive_expansion_search_torch( float_dtype=float_dtype, store_prefix_hashes=True, ) - time_indices = torch.zeros_like(beam_indices) - safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len - last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices) - active_mask = time_indices <= last_timesteps time_indices = torch.zeros_like(batch_indices) - safe_time_indices = torch.zeros_like(time_indices) - last_timesteps = (encoder_output_length - 1)[:, None].expand(batch_size, self.beam_size) + safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len + last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_indices) active_mask = time_indices <= last_timesteps # setup N-gram LM if available + # TODO: when ``prev_batched_state`` is provided, the n-gram LM state is currently + # re-seeded from BOS instead of being restored - see ``ModifiedALSDBatchedRNNTComputer`` + # for how fusion-model states are threaded through ``BatchedLabelLoopingState``. if self.ngram_lm_batch is not None: self.ngram_lm_batch.to(device) batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha - if prev_batched_state is None: + if prev_batched_state is None: last_labels_wb = torch.full( [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long ) decoder_state = self.decoder.initialize_state( - torch.empty( - [ - batch_size * self.beam_size, - ], - dtype=float_dtype, - device=device, - ) + torch.empty([batch_size * self.beam_size], dtype=float_dtype, device=device) ) - decoder_output, state, *_ = self.decoder.predict( last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size ) # do not recalculate joint projection decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] self.decoder.batch_replace_states_all(state, dst_states=decoder_state) - else: + else: # Continuing from previous chunk - batched_hyps already contains all state decoder_output = prev_batched_state.predictor_outputs decoder_state = prev_batched_state.predictor_states while active_mask.any(): # frames loop to_update = active_mask.clone() # mask for expansions loop - # import pdb; pdb.set_trace() # step 1: get joint output logits = ( @@ -228,14 +235,14 @@ def batched_modified_adaptive_expansion_search_torch( .squeeze(1) .squeeze(1) ) - logps = torch.log_softmax(logits, dim=-1).view(batch_size, self.beam_size, -1) + logps = F.log_softmax(logits, dim=-1).view(batch_size, self.beam_size, -1) # step 2: perform prefix search updated_logps = self.combine_scores(logps, lm_scores) if self.ngram_lm_batch is not None else logps batched_hyps.recombine_prefixes(updated_logps, active_mask) - expansion_steps = 0 # step 3: performs `maes_num_steps` non-blank expansions + expansion_steps = 0 while to_update.any() and expansion_steps < self.maes_num_steps: # expansions loop # step 3.1: get `maes_num_expansion` best expansions (in total beam x maes_num_expansion expansions) if self.ngram_lm_batch is None: @@ -269,11 +276,10 @@ def batched_modified_adaptive_expansion_search_torch( next_labels = next_labels.view(batch_size, -1)[batch_indices, idx] hyp_indices = expansion_beam_indices.view(batch_size, -1)[batch_indices, idx] - # import pdb; pdb.set_trace() # step 3.3: update batched beam hypotheses structure batched_hyps.add_results_(hyp_indices, next_labels, next_hyps_probs) - # step 3.4: update + # step 3.4: update last labels (mask invalid with blank to avoid decoder errors) last_labels_wb = torch.where(next_labels >= 0, next_labels, self._blank_index) preserve_state = last_labels_wb == self._blank_index @@ -329,9 +335,7 @@ def batched_modified_adaptive_expansion_search_torch( ).squeeze(-1) batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) lm_scores = ( lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha ) @@ -341,13 +345,13 @@ def batched_modified_adaptive_expansion_search_torch( encoder_output_projected[batch_indices.flatten(), safe_time_indices.flatten()].unsqueeze(1), decoder_output, ) - logps = torch.log_softmax(logits, dim=-1).squeeze(1).squeeze(1).view(batch_size, self.beam_size, -1) + logps = F.log_softmax(logits, dim=-1).squeeze(1).squeeze(1).view(batch_size, self.beam_size, -1) to_update = torch.logical_and(to_update, last_labels_wb != self._blank_index) expansion_steps += 1 + + # step 4: force blank to active hypotheses still waiting for one this frame if to_update.any(): - # import pdb; pdb.set_trace() - # step 4: force blank to active hypotheses next_hyps_probs = torch.where(to_update, batched_hyps.scores + logps[..., -1], batched_hyps.scores) next_labels = torch.where(to_update, self._blank_index, -1) batched_hyps.add_results_(beam_indices, next_labels, next_hyps_probs) @@ -356,29 +360,54 @@ def batched_modified_adaptive_expansion_search_torch( time_indices += 1 active_mask = time_indices <= last_timesteps safe_time_indices = torch.where(active_mask, time_indices, last_timesteps) - - # import pdb; pdb.set_trace() + + return ( + batched_hyps, + None, + self._create_decoding_state( + batched_hyps=batched_hyps, + decoder_state=decoder_state, + decoder_output=decoder_output, + encoder_output_length=encoder_output_length, + prev_batched_state=prev_batched_state, + ), + ) + + def _create_decoding_state( + self, + batched_hyps: BatchedBeamHyps, + decoder_state, + decoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Build the :class:`BatchedLabelLoopingState` returned for the next chunk. + + Mirrors the trailing block of :meth:`ModifiedALSDBatchedRNNTComputer.modified_alsd_torch`: + accumulate ``decoded_lengths`` across chunks, fall back to previous-chunk labels for + beams that emitted nothing in this chunk, and reset the chunk-local ``next_timestamp`` + write cursor on ``batched_hyps``. + """ last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) batched_hyps.next_timestamp.fill_(0) - decoding_state = BatchedLabelLoopingState( + + if prev_batched_state is not None: + # Beams that emitted nothing this chunk return SOS; carry over the previous label. + labels = torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths + else: + labels = last_labels + decoded_lengths = encoder_output_length.clone() + + return BatchedLabelLoopingState( predictor_states=decoder_state, predictor_outputs=decoder_output, - labels=( - torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) - if prev_batched_state is not None - else last_labels - ), - decoded_lengths=( - encoder_output_length.clone() - if prev_batched_state is None - else encoder_output_length + prev_batched_state.decoded_lengths - ), - # fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + labels=labels, + decoded_lengths=decoded_lengths, time_jumps=None, - batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + batched_hyps=batched_hyps, ) - - return batched_hyps, None, decoding_state def combine_scores(self, log_probs, lm_scores): """ @@ -495,6 +524,10 @@ def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - prev_batched_state: Optional[BatchedBeamHyps] = None, + prev_batched_state: Optional[BatchedLabelLoopingState] = None, ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: - return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state) + return self.batched_modified_adaptive_expansion_search_torch( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state, + ) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index e1d09aa560c3..be66c07df1f2 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1133,32 +1133,18 @@ def _before_loop(self): def _before_loop_continuation(self): """ - Prepares state for continuation chunk without clearing the cross-chunk per-beam - state of ``batched_hyps``. - - Decoder state and fusion states are already restored from the previous chunk via - ``_restore_state_from_prev`` before this method is called. ``batched_hyps`` is - likewise restored from the previous chunk so that ``scores``, ``last_label``, - ``transcript_hash``, ``current_lengths_nb`` and ``last_timestamp_lasts`` continue - the beam search across the chunk boundary. - - However, the per-chunk transcript prefix tree (``transcript_wb`` / - ``transcript_wb_prev_ptr`` / ``timestamps``) and the write cursor into it - (``current_lengths_wb``) must be reset for the new chunk: their buffers are sized - for one chunk's worth of decoding only, and the captured ``add_results_no_checks_`` - scatters into them at ``current_lengths_wb`` without bounds checking. If we kept - the previous chunk's cursor we would index out of bounds within a few chunks, - producing a CUDA illegal-memory-access from inside the captured graph. The caller - is responsible for snapshotting / merging the per-chunk transcripts before this - method runs at the start of the next chunk (see :meth:`BatchedBeamHyps.merge_` - with ``is_chunk_continuation=True``). + Prologue for a continuation chunk: preserves cross-chunk per-beam state on + ``batched_hyps`` (scores, last_label, transcript_hash, current_lengths_nb, + last_timestamp_lasts) and resets only the chunk-local prefix-tree buffers + (transcript_wb / transcript_wb_prev_ptr / timestamps / current_lengths_wb) + that the captured loop body would otherwise overflow. + + Decoder and fusion states are already restored by ``_restore_state_from_prev``. + The caller is responsible for snapshotting / merging the chunk-local transcripts + before the next chunk (see :meth:`BatchedBeamHyps.merge_` with + ``is_chunk_continuation=True``). """ - # Reset chunk-local storage so the captured loop body writes into freshly zeroed - # buffers; cross-chunk per-beam state is preserved. self.state.batched_hyps.clear_chunk_local_() - - # Decoder state and fusion states are already restored from previous state. - self._before_loop_common() def _before_loop_common(self): diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 4eeaad0031c8..38352a1b64ae 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -1229,32 +1229,18 @@ def _before_loop(self): def _before_loop_continuation(self): """ - Prepares state for continuation chunk without clearing the cross-chunk per-beam - state of ``batched_hyps``. - - Decoder state and fusion states are already restored from the previous chunk via - ``_restore_state_from_prev`` before this method is called. ``batched_hyps`` is - likewise restored from the previous chunk so that ``scores``, ``last_label``, - ``transcript_hash``, ``current_lengths_nb`` and ``last_timestamp_lasts`` continue - the beam search across the chunk boundary. - - However, the per-chunk transcript prefix tree (``transcript_wb`` / - ``transcript_wb_prev_ptr`` / ``timestamps``) and the write cursor into it - (``current_lengths_wb``) must be reset for the new chunk: their buffers are sized - for one chunk's worth of decoding only, and the captured ``add_results_no_checks_`` - scatters into them at ``current_lengths_wb`` without bounds checking. If we kept - the previous chunk's cursor we would index out of bounds within a few chunks, - producing a CUDA illegal-memory-access from inside the captured graph. The caller - is responsible for snapshotting / merging the per-chunk transcripts before this - method runs at the start of the next chunk (see :meth:`BatchedBeamHyps.merge_` - with ``is_chunk_continuation=True``). + Prologue for a continuation chunk: preserves cross-chunk per-beam state on + ``batched_hyps`` (scores, last_label, transcript_hash, current_lengths_nb, + last_timestamp_lasts) and resets only the chunk-local prefix-tree buffers + (transcript_wb / transcript_wb_prev_ptr / timestamps / current_lengths_wb) + that the captured loop body would otherwise overflow. + + Decoder and fusion states are already restored by ``_restore_state_from_prev``. + The caller is responsible for snapshotting / merging the chunk-local transcripts + before the next chunk (see :meth:`BatchedBeamHyps.merge_` with + ``is_chunk_continuation=True``). """ - # Reset chunk-local storage so the captured loop body writes into freshly zeroed - # buffers; cross-chunk per-beam state is preserved. self.state.batched_hyps.clear_chunk_local_() - - # Decoder state and fusion states are already restored from previous state. - self._before_loop_common() def _before_loop_common(self): From 539bdcfd14a7fae10d7cdcc7381eafcec7273fff Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 16:51:10 +0400 Subject: [PATCH 12/13] clean up Signed-off-by: lilithgrigoryan --- .../utils/batched_beam_decoding_utils.py | 119 ------------------ 1 file changed, 119 deletions(-) diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 180c74ec91b4..29ff3c605bc0 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -396,18 +396,6 @@ def add_results_no_checks_( extended_with_blank = next_labels == self.blank_index extended_with_label = (is_extended) & (~extended_with_blank) - # TODO: uncomment - # last_labels = torch.gather(self.last_label, dim=-1, index=next_indices) - # self.transcript_wb.scatter_( - # dim=-1, - # index=self.current_lengths_wb.unsqueeze(-1), - # src=torch.where(is_extended, next_labels, NON_EXISTENT_LABEL_VALUE).unsqueeze(-1) - # ) - # self.transcript_wb_prev_ptr.scatter_( - # dim=-1, - # index=self.current_lengths_wb.unsqueeze(-1), - # src=torch.where(is_extended, next_indices, INIT_POINTER_VALUE).unsqueeze(-1) - # ) if self.model_type == ASRModelTypeEnum.CTC: # for CTC last non-blank and non-repeated label extended_with_label = (extended_with_label) & (next_labels != last_labels) # non-repeated non-blank label @@ -447,12 +435,6 @@ def add_results_no_checks_( ) torch.add(self.current_lengths_wb, 1, out=self.current_lengths_wb) self.scores.copy_(next_hyps_prob) - - # TODO: uncomment - # self.current_lengths_wb.copy_( - # torch.gather(self.current_lengths_wb, dim=-1, index=next_indices) + is_extended - # ) - # self.scores.copy_(torch.where(is_extended, next_hyps_prob, torch.gather(self.scores, dim=-1, index=next_indices))) prev_transcript_hash = torch.gather(self.transcript_hash, dim=-1, index=next_indices) # update hashes and prefix hashes @@ -487,15 +469,9 @@ def recombine_hyps_(self): Note: The method modifies the `self.scores` tensor in place to reflect the recombined hypotheses. """ - # print(f"DEBUG: Entering recombine_hyps_. batch_size={self.batch_size}, beam_size={self.beam_size}") - if self.beam_size <= 1: return - # print(f"DEBUG: transcript_hash shape: {self.transcript_hash.shape}") - # print(f"DEBUG: last_label shape: {self.last_label.shape}") - # print(f"DEBUG: current_lengths_nb shape: {self.current_lengths_nb.shape}") - hyps_equal = ( (self.transcript_hash[:, :, None] == self.transcript_hash[:, None, :]) & (self.last_label[:, :, None] == self.last_label[:, None, :]) @@ -503,24 +479,15 @@ def recombine_hyps_(self): ) if self.model_type == ASRModelTypeEnum.TDT: - # print(f"DEBUG: TDT model type. next_timestamp shape: {self.next_timestamp.shape}") hyps_equal &= self.next_timestamp[:, :, None] == self.next_timestamp[:, None, :] - # print(f"DEBUG: hyps_equal shape: {hyps_equal.shape}") - - # print(f"DEBUG: self.scores : {self.scores}") - scores_matrix = torch.where( hyps_equal, self.scores[:, None, :].expand(self.batch_size, self.beam_size, self.beam_size), self.INACTIVE_SCORE_TENSOR, ) - # print(f"DEBUG: scores_matrix : {scores_matrix}") - # print(f"DEBUG: scores_matrix shape: {scores_matrix.shape}") - # print(f"DEBUG: scores_matrix : {scores_matrix}") scores_argmax = scores_matrix.argmax(-1, keepdim=False) - # print(f"DEBUG: scores_argmax shape: {scores_argmax.shape}, min: {scores_argmax.min()}, max: {scores_argmax.max()}") scores_to_keep = ( torch.arange(self.beam_size, device=scores_argmax.device, dtype=torch.long)[None, :] == scores_argmax ) @@ -529,14 +496,8 @@ def recombine_hyps_(self): else: new_scores = torch.logsumexp(scores_matrix, dim=-1, keepdim=False) - # print(f"DEBUG: new_scores shape: {new_scores.shape}") - # print(f"DEBUG: scores_to_keep shape: {scores_to_keep.shape}") - # print(f"DEBUG: self.scores shape: {self.scores.shape}") - torch.where(scores_to_keep, new_scores.to(self.scores.dtype), self.INACTIVE_SCORE_TENSOR, out=self.scores) - # print("DEBUG: Exiting recombine_hyps_") - def remove_duplicates(self, labels: torch.Tensor, total_logps: torch.Tensor): """ Removes duplicate hypotheses that may arise after updating beam hypotheses with labels during the beam search process. @@ -691,86 +652,6 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]: ] return hypotheses - # def flatten_sort_(self, score_norm: bool = True): - # """ - # Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process. - # This is a SERIALIZED version that processes each batch element sequentially to avoid - # issues with pointer chasing in the batched version. - - # Args: - # score_norm (bool, optional): If True, normalizes the scores by dividing - # them by the current lengths of the hypotheses plus one. Defaults to True. - # This method performs the following steps: - # 1. Normalizes the scores if `score_norm` is True. - # 2. Sorts the normalized scores in descending order and retrieves the corresponding indices. - # 3. Iteratively reconstructs the tokens and timestamps for each hypothesis in reverse order. - # 4. Updates the internal state of the object, including transcripts, timestamps, scores, - # lengths, labels, and other metadata, based on the sorted order. - # """ - - # # add one for consistency with non-batched decodings, that use SOS. - # normalized_scores = ( - # self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores - # ) - # normalized_scores, indices = torch.sort(normalized_scores, dim=-1, descending=True) - - # # Create temporary buffers to hold the sorted results - # new_transcript_wb = self.transcript_wb.clone() - # new_timestamps = self.timestamps.clone() if (self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT) else None - - # # Process each batch element sequentially - # for batch_idx in range(self.batch_size): - # max_idx = self.current_lengths_wb[batch_idx].max() - 1 - # if max_idx < 0: - # continue - - # batch_indices_local = indices[batch_idx] # [beam_size] - - # # For each beam in the sorted order, reconstruct the path by following pointers - # for beam_idx in range(self.beam_size): - # src_beam = batch_indices_local[beam_idx].item() - # ptr = src_beam - - # # Reconstruct the path from max_idx down to 0 - # for idx in range(max_idx, -1, -1): - # # Copy the token at this position - # new_transcript_wb[batch_idx, beam_idx, idx] = self.transcript_wb[batch_idx, ptr, idx] - - # # Copy timestamp if applicable - # if new_timestamps is not None: - # new_timestamps[batch_idx, beam_idx, idx] = self.timestamps[batch_idx, ptr, idx] - - # # Follow the pointer to the previous beam - # next_ptr = self.transcript_wb_prev_ptr[batch_idx, ptr, idx].item() - # if next_ptr != INIT_POINTER_VALUE: - # # Only update pointer if it's valid; otherwise keep current ptr - # ptr = next_ptr - - # # Copy reconstructed paths back to main buffers - # self.transcript_wb.copy_(new_transcript_wb) - # if new_timestamps is not None: - # self.timestamps.copy_(new_timestamps) - - # # Reset pointers to simple sequential structure - # max_idx = self.current_lengths_wb.max() - 1 - # if max_idx >= 0: - # self.transcript_wb_prev_ptr[..., : max_idx + 1].copy_(self.beam_indices.unsqueeze(0).unsqueeze(-1)) - - # # Sort all other state tensors according to indices - # self.scores.copy_(torch.gather(self.scores, dim=-1, index=indices)) - # self.current_lengths_nb.copy_(torch.gather(self.current_lengths_nb, dim=-1, index=indices)) - # self.current_lengths_wb.copy_(torch.gather(self.current_lengths_wb, dim=-1, index=indices)) - - # self.last_label.copy_(torch.gather(self.last_label, dim=-1, index=indices)) - - # if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT: - # self.next_timestamp.copy_(torch.gather(self.next_timestamp, dim=-1, index=indices)) - # self.last_timestamp_lasts.copy_(torch.gather(self.last_timestamp_lasts, dim=-1, index=indices)) - - # self.transcript_hash.copy_(torch.gather(self.transcript_hash, dim=-1, index=indices)) - # if self.store_prefix_hashes: - # self.transcript_prefix_hash.copy_(torch.gather(self.transcript_prefix_hash, dim=-1, index=indices)) - def flatten_sort_(self, score_norm: bool = True): """ Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process. From 8adcfd27580de9ef798d5a5baef162a997782483 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 15 May 2026 16:54:40 +0400 Subject: [PATCH 13/13] clean up Signed-off-by: lilithgrigoryan --- nemo/collections/asr/parts/submodules/tdt_beam_decoding.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index a0154b1f80e9..fea925dd219c 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -955,13 +955,10 @@ def forward( with torch.inference_mode(): # Apply optional preprocessing encoder_output = encoder_output.transpose(1, 2) # (B, T, D) - logitlen = encoded_lengths self.decoder.eval() self.joint.eval() - inseq = encoder_output # [B, T, D] - encoder_output_projected = self.joint.project_encoder(encoder_output) encoder_output_projected_len = encoded_lengths batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=encoder_output_projected, out_len=encoder_output_projected_len)