diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 88831aaed00c..36f0c8454098 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -75,9 +75,13 @@ from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import ( GreedyBatchedLabelLoopingComputerBase, ) +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses @@ -232,9 +236,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # Change Decoding Config with open_dict(cfg.decoding): - if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True: + if cfg.decoding.strategy == "greedy_batch": + if cfg.decoding.greedy.loop_labels is not True: + raise NotImplementedError( + "This script supports `greedy_batch` strategy only with Label-Looping algorithm" + ) + cfg.decoding.greedy.preserve_alignments = False + elif cfg.decoding.strategy == "malsd_batch": + # MALSD beam search is supported for streaming + pass + elif cfg.decoding.strategy == "maes_batch": + # MAES beam search is supported for streaming + pass + else: raise NotImplementedError( - "This script currently supports only `greedy_batch` strategy with Label-Looping algorithm" + "This script currently supports only `greedy_batch` with Label-Looping or `malsd_batch` strategy with MALSD" ) cfg.decoding.tdt_include_token_duration = cfg.timestamps cfg.decoding.greedy.preserve_alignments = False @@ -278,7 +294,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: asr_model.preprocessor.featurizer.pad_to = 0 asr_model.eval() - decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer + # Get decoding computer based on strategy. Beam-search strategies expose the + # underlying computer via the private ``_decoding_computer`` attribute. + if cfg.decoding.strategy == "greedy_batch": + decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer + elif cfg.decoding.strategy == "malsd_batch": + decoding_computer = asr_model.decoding.decoding._decoding_computer + elif cfg.decoding.strategy == "maes_batch": + decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer + else: + raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}") audio_sample_rate = model_cfg.preprocessor['sample_rate'] @@ -394,7 +419,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) rest_audio_lengths = audio_batch_lengths.clone() - # iterate over audio samples + # Beam-search strategies (MALSD/MAES/TDT-MALSD) keep their batched_hyps inside + # the decoder state and reset chunk-local buffers per chunk; greedy returns a + # fresh BatchedHyps each call which we must merge externally. + is_beam_search = isinstance( + decoding_computer, + (ModifiedALSDBatchedRNNTComputer, ModifiedAESBatchedRNNTComputer, ModifiedALSDBatchedTDTComputer), + ) + + # iterate over audio samples (but only for decoding chunks) while left_sample < audio_batch.shape[1]: # add samples to buffer chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample @@ -425,18 +458,47 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: encoder_output = encoder_output[:, encoder_context.left :] # decode only chunk frames - chunk_batched_hyps, _, state = decoding_computer( - x=encoder_output, - out_len=torch.where( - is_last_chunk_batch, - encoder_output_len - encoder_context_batch.left, - encoder_context_batch.chunk, - ), - prev_batched_state=state, - multi_biasing_ids=multi_biasing_ids, + out_len = torch.where( + is_last_chunk_batch, + encoder_output_len - encoder_context_batch.left, + encoder_context_batch.chunk, ) - # merge hyps with previous hyps - if current_batched_hyps is None: + if is_beam_search: + # Beam-search computers don't accept ``multi_biasing_ids`` yet. + chunk_batched_hyps, _, state = decoding_computer( + x=encoder_output, out_len=out_len, prev_batched_state=state + ) + else: + chunk_batched_hyps, _, state = decoding_computer( + x=encoder_output, + out_len=out_len, + prev_batched_state=state, + multi_biasing_ids=multi_biasing_ids, + ) + + # Accumulate hypotheses across chunks. + if is_beam_search: + # ``flatten_`` resolves the chunk-local prefix tree without sorting and + # returns ``root_ptrs``: for each chunk-end beam ``i``, the beam index + # at the chunk's start it descends from. The loop body permutes beams + # via top-K, so beam ``i`` at the end of chunk ``N`` is generally a + # different logical hypothesis from beam ``i`` at the end of chunk + # ``N-1`` - it descends from beam ``root_ptrs[i]``. Threading these + # pointers via ``boundary_prev_ptr`` encodes the redirection so the + # final ``flatten_sort_`` in ``to_hyps_list`` walks back through the + # right beam history. ``is_chunk_continuation=True`` makes ``merge_`` + # replace (not sum) the cumulative per-beam fields (``scores``, + # ``current_lengths_nb``, ...). + chunk_root_ptrs = state.batched_hyps.flatten_() + if current_batched_hyps is None: + current_batched_hyps = state.batched_hyps + else: + current_batched_hyps.merge_( + state.batched_hyps, + is_chunk_continuation=True, + boundary_prev_ptr=chunk_root_ptrs, + ) + elif current_batched_hyps is None: current_batched_hyps = chunk_batched_hyps else: current_batched_hyps.merge_(chunk_batched_hyps) @@ -446,13 +508,19 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: left_sample = right_sample right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk - # remove biasing requests from the decoder - if use_per_stream_biasing and audio_data.biasing_requests is not None: - for request in audio_data.biasing_requests: - if request is not None and request.multi_model_id is not None: - decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) - request.multi_model_id = None - all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size)) + # Convert batched hypotheses to list + if isinstance(current_batched_hyps, BatchedBeamHyps): + # MALSD beam search returns BatchedBeamHyps + all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True)) + else: + # Greedy batch returns BatchedHyps + # remove biasing requests from the decoder + if use_per_stream_biasing and audio_data.biasing_requests is not None: + for request in audio_data.biasing_requests: + if request is not None and request.multi_model_id is not None: + decoding_computer.biasing_multi_model.remove_model(request.multi_model_id) + request.multi_model_id = None + all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size)) timer.stop(device=map_location) # convert text diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index d0b02a73daaf..1909212c81f2 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -1690,7 +1690,7 @@ def forward( self.joint.eval() inseq = encoder_output # [B, T, D] - batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen) + batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=inseq, out_len=logitlen) batch_size = encoder_output.shape[0] if self.return_best_hypothesis: diff --git a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py index 1af34f11ac8e..c326307b9d1f 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py @@ -15,7 +15,10 @@ from typing import Optional import torch +import torch.nn.functional as F +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -33,6 +36,10 @@ class ModifiedAESBatchedRNNTComputer(ConfidenceMethodMixin): Based on https://ieeexplore.ieee.org/document/9250505 with the following modficiations: - does not support prediction network caching - supports prefix search with only longest prefix + + Note: RNN-T only. TDT is not supported by this computer (use ``ModifiedALSDBatchedTDTComputer`` + instead). Unlike :class:`ModifiedALSDBatchedRNNTComputer` this is a pure-PyTorch decoder; + CUDA graphs are not implemented. """ def __init__( @@ -44,7 +51,7 @@ def __init__( maes_num_steps: int, maes_expansion_beta: int, maes_expansion_gamma: int, - preserve_alignments=False, + preserve_alignments: bool = False, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, blank_lm_score_mode: Optional[str | BlankLMScoreMode] = BlankLMScoreMode.NO_SCORE, @@ -75,37 +82,35 @@ def __init__( ngram_lm_alpha: weight for the n-gram LM scores blank_lm_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM - allow_cuda_graphs: whether to allow CUDA graphs + allow_cuda_graphs: accepted for API parity with :class:`ModifiedALSDBatchedRNNTComputer`, + but ignored - this computer does not implement CUDA graphs. """ super().__init__() self.decoder = decoder self.joint = joint self._blank_index = blank_index + self._SOS = self._blank_index self.beam_size = beam_size self.maes_num_steps = maes_num_steps self.maes_expansion_beta = maes_expansion_beta self.maes_expansion_gamma = maes_expansion_gamma - self.preserve_alignments = preserve_alignments - self._SOS = self._blank_index - self.pruning_mode = pruning_mode - self.blank_lm_score_mode = blank_lm_score_mode - self.maes_num_expansions = self.beam_size + self.maes_expansion_beta + self.preserve_alignments = preserve_alignments if self.preserve_alignments: raise NotImplementedError("Preserve alignments is not supported") if allow_cuda_graphs: - logging.info("CUDA Graphs are unsupported for `maes_batch`; preceeding pure pytorch decoding") + logging.info("`allow_cuda_graphs=True` is accepted for API parity, but `maes_batch` runs in pure PyTorch.") + # n-gram LM fusion setup if ngram_lm_model is not None: expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 if self._blank_index != expected_blank_index: raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") self.ngram_lm_batch = ngram_lm_model - self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) self.blank_lm_score_mode = ( BlankLMScoreMode.LM_WEIGHTED_FULL @@ -114,6 +119,7 @@ def __init__( ) else: self.ngram_lm_batch = None + self.pruning_mode = pruning_mode self.blank_lm_score_mode = None self.ngram_lm_alpha = ngram_lm_alpha @@ -121,66 +127,101 @@ def batched_modified_adaptive_expansion_search_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ - Pure PyTorch implementation + Pure PyTorch implementation of batched modified adaptive expansion search for RNN-T. Args: - encoder_output: output from the encoder - encoder_output_length: lengths of the utterances in `encoder_output` + encoder_output: output from the encoder, shape [batch_size, max_time, encoder_dim]. + encoder_output_length: lengths of the utterances in ``encoder_output``, shape [batch_size]. + prev_batched_state: optional state from a previous chunk (streaming / chunked decoding). + When provided, ``predictor_states``, ``predictor_outputs`` and ``batched_hyps`` are + reused so that beam search continues across the chunk boundary. + + Returns: + tuple of (batched_hyps, None, decoding_state) where ``decoding_state`` is the state to + pass back as ``prev_batched_state`` on the next chunk. """ - batch_size, max_time, _unused = encoder_output.shape + batch_size, max_time, _ = encoder_output.shape device = encoder_output.device + # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - # init empty batched hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time, - device=device, - float_dtype=float_dtype, - store_prefix_hashes=True, - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) - batch_indices = ( - torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.beam_size).clone() + torch.arange(batch_size, dtype=torch.long, device=device)[:, None] + .expand(batch_size, self.beam_size) + .clone() ) # size: batch_size x beam_size beam_indices = ( - torch.arange(self.beam_size, device=device)[None, :].expand(batch_size, self.beam_size).clone() + torch.arange(self.beam_size, dtype=torch.long, device=device)[None, :] + .expand(batch_size, self.beam_size) + .clone() ) # size: batch_size x beam_size expansion_beam_indices = ( - torch.arange(self.beam_size, device=device)[None, :, None] + torch.arange(self.beam_size, dtype=torch.long, device=device)[None, :, None] .expand(batch_size, self.beam_size, self.maes_num_expansions) .clone() - ) # size: batch_size x beam_size x beam_size + maes_expansion_beta + ) # size: batch_size x beam_size x (beam_size + maes_expansion_beta) + + # On continuation, work on a fresh clone of the previous chunk's hypotheses so we + # don't mutate the state object the caller may also be using as its streaming + # accumulator (``current_batched_hyps`` in the streaming script aliases + # ``state.batched_hyps`` on the first chunk - mutating it here would corrupt the + # accumulator). Cross-chunk per-beam fields (``scores``, ``last_label``, + # ``transcript_hash``, ``current_lengths_nb``, ``last_timestamp_lasts``) are + # preserved by ``clone()``; the chunk-local prefix-tree cursor and buffers are + # reset by ``clear_chunk_local_()`` so this chunk's ``add_results_`` scatters + # start at offset zero (cf. ``_before_loop_continuation`` / + # ``_create_decoding_state`` clone in :class:`ModifiedALSDBatchedRNNTComputer`). + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps.clone() + batched_hyps.clear_chunk_local_() + else: + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time, + device=device, + float_dtype=float_dtype, + store_prefix_hashes=True, + ) time_indices = torch.zeros_like(batch_indices) - safe_time_indices = torch.zeros_like(time_indices) - last_timesteps = (encoder_output_length - 1)[:, None].expand(batch_size, self.beam_size) + safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len + last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_indices) active_mask = time_indices <= last_timesteps # setup N-gram LM if available + # TODO: when ``prev_batched_state`` is provided, the n-gram LM state is currently + # re-seeded from BOS instead of being restored - see ``ModifiedALSDBatchedRNNTComputer`` + # for how fusion-model states are threaded through ``BatchedLabelLoopingState``. if self.ngram_lm_batch is not None: self.ngram_lm_batch.to(device) batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha - decoder_output, decoder_state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty([batch_size * self.beam_size], dtype=float_dtype, device=device) + ) + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size + ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states while active_mask.any(): # frames loop to_update = active_mask.clone() # mask for expansions loop @@ -194,14 +235,14 @@ def batched_modified_adaptive_expansion_search_torch( .squeeze(1) .squeeze(1) ) - logps = torch.log_softmax(logits, dim=-1).view(batch_size, self.beam_size, -1) + logps = F.log_softmax(logits, dim=-1).view(batch_size, self.beam_size, -1) # step 2: perform prefix search updated_logps = self.combine_scores(logps, lm_scores) if self.ngram_lm_batch is not None else logps batched_hyps.recombine_prefixes(updated_logps, active_mask) - expansion_steps = 0 # step 3: performs `maes_num_steps` non-blank expansions + expansion_steps = 0 while to_update.any() and expansion_steps < self.maes_num_steps: # expansions loop # step 3.1: get `maes_num_expansion` best expansions (in total beam x maes_num_expansion expansions) if self.ngram_lm_batch is None: @@ -238,7 +279,7 @@ def batched_modified_adaptive_expansion_search_torch( # step 3.3: update batched beam hypotheses structure batched_hyps.add_results_(hyp_indices, next_labels, next_hyps_probs) - # step 3.4: update + # step 3.4: update last labels (mask invalid with blank to avoid decoder errors) last_labels_wb = torch.where(next_labels >= 0, next_labels, self._blank_index) preserve_state = last_labels_wb == self._blank_index @@ -294,9 +335,7 @@ def batched_modified_adaptive_expansion_search_torch( ).squeeze(-1) batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank + lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) lm_scores = ( lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha ) @@ -306,12 +345,13 @@ def batched_modified_adaptive_expansion_search_torch( encoder_output_projected[batch_indices.flatten(), safe_time_indices.flatten()].unsqueeze(1), decoder_output, ) - logps = torch.log_softmax(logits, dim=-1).squeeze(1).squeeze(1).view(batch_size, self.beam_size, -1) + logps = F.log_softmax(logits, dim=-1).squeeze(1).squeeze(1).view(batch_size, self.beam_size, -1) to_update = torch.logical_and(to_update, last_labels_wb != self._blank_index) expansion_steps += 1 + + # step 4: force blank to active hypotheses still waiting for one this frame if to_update.any(): - # step 4: force blank to active hypotheses next_hyps_probs = torch.where(to_update, batched_hyps.scores + logps[..., -1], batched_hyps.scores) next_labels = torch.where(to_update, self._blank_index, -1) batched_hyps.add_results_(beam_indices, next_labels, next_hyps_probs) @@ -321,7 +361,53 @@ def batched_modified_adaptive_expansion_search_torch( active_mask = time_indices <= last_timesteps safe_time_indices = torch.where(active_mask, time_indices, last_timesteps) - return batched_hyps + return ( + batched_hyps, + None, + self._create_decoding_state( + batched_hyps=batched_hyps, + decoder_state=decoder_state, + decoder_output=decoder_output, + encoder_output_length=encoder_output_length, + prev_batched_state=prev_batched_state, + ), + ) + + def _create_decoding_state( + self, + batched_hyps: BatchedBeamHyps, + decoder_state, + decoder_output: torch.Tensor, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Build the :class:`BatchedLabelLoopingState` returned for the next chunk. + + Mirrors the trailing block of :meth:`ModifiedALSDBatchedRNNTComputer.modified_alsd_torch`: + accumulate ``decoded_lengths`` across chunks, fall back to previous-chunk labels for + beams that emitted nothing in this chunk, and reset the chunk-local ``next_timestamp`` + write cursor on ``batched_hyps``. + """ + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + batched_hyps.next_timestamp.fill_(0) + + if prev_batched_state is not None: + # Beams that emitted nothing this chunk return SOS; carry over the previous label. + labels = torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths + else: + labels = last_labels + decoded_lengths = encoder_output_length.clone() + + return BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=labels, + decoded_lengths=decoded_lengths, + time_jumps=None, + batched_hyps=batched_hyps, + ) def combine_scores(self, log_probs, lm_scores): """ @@ -438,5 +524,10 @@ def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: - return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len) + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + return self.batched_modified_adaptive_expansion_search_torch( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state, + ) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index f1ef93c64f5d..be66c07df1f2 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -19,6 +19,8 @@ import torch.nn.functional as F from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -101,6 +103,10 @@ class MALSDState: init_fusion_states_candidates_list: Optional[List[torch.Tensor]] = None # list of initial fusion states candidates init_fusion_scores_list: Optional[List[torch.Tensor]] = None # list of initial fusion scores + # Streaming state fields + is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk + is_first_chunk: torch.Tensor # complement of ``is_continuation``; both feed the captured graph's IF nodes + def __init__( self, batch_size: int, @@ -123,7 +129,6 @@ def __init__( float_dtype: default float dtype for tensors (should match projected encoder output) blank_index: index of the blank symbol """ - self.device = device self.float_dtype = float_dtype self.batch_size = batch_size @@ -184,6 +189,13 @@ def __init__( float_dtype=float_dtype, ) + # Streaming state fields. The captured FULL_GRAPH reads ``is_first_chunk`` and + # ``is_continuation`` to route to the first-chunk vs. continuation prologue at replay + # time. The two flags are kept in lockstep by ``modified_alsd_cuda_graphs`` before + # each replay (they are mutually exclusive). + self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) + self.is_first_chunk = torch.tensor(True, device=self.device, dtype=torch.bool) + def need_reinit(self, encoder_output_projected: torch.Tensor) -> bool: """Check if need to reinit state: larger batch_size/max_time, or new device""" return ( @@ -198,6 +210,7 @@ class SeparateGraphsMALSD: """Class to store Cuda graphs for decoding when separate graphs are used""" before_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_loop_continuation: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_body: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_update_decoder: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) @@ -350,7 +363,8 @@ def modified_alsd_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Pytorch implementation of the batched ALSD algorithm for RNN-T. Args: @@ -371,20 +385,6 @@ def modified_alsd_torch( encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - # init empty batched beam hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, - device=device, - float_dtype=float_dtype, - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) - batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] .expand(batch_size, self.beam_size) @@ -396,6 +396,19 @@ def modified_alsd_torch( .clone() ) # size: batch_size x beam_size x beam_size + # Reuse batched beam hypotheses from state if continuing, otherwise create new + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps + else: + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, + device=device, + float_dtype=float_dtype, + ) + time_indices = torch.zeros_like(batch_beam_indices) safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) @@ -403,41 +416,68 @@ def modified_alsd_torch( # setup fusion models if available if self.fusion_models is not None: - fusion_states_list = [] - fusion_states_candidates_list = [] - fusion_scores_list = [] - for fusion_model_idx, fusion_model in enumerate(self.fusion_models): - fusion_model.to(device) - fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - fusion_scores, fusion_states_candidates = fusion_model.advance( - states=fusion_states - ) # vocab_size_no_blank + if prev_batched_state is None or prev_batched_state.fusion_states_list is None: + fusion_states_list = [] + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) # vocab_size_no_blank - fusion_scores = ( - fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) - * self.fusion_models_alpha[fusion_model_idx] - ) - fusion_states_list.append(fusion_states) - fusion_states_candidates_list.append(fusion_states_candidates) - fusion_scores_list.append(fusion_scores) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = prev_batched_state.fusion_states_list + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states_list[fusion_model_idx] + ) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = None - decoder_state = self.decoder.initialize_state( - torch.empty( - [ - batch_size * self.beam_size, - ], - dtype=float_dtype, - device=device, + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty( + [ + batch_size * self.beam_size, + ], + dtype=float_dtype, + device=device, + ) ) - ) - - decoder_output, state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] - self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size + ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states + + step1=0 while active_mask.any(): # step 1: get joint output + fuse with fusion models (if present) logits = ( @@ -588,11 +628,36 @@ def modified_alsd_torch( fusion_scores_list[fusion_model_idx] = fusion_scores # step 6: update time indices + active mask - time_indices = batched_hyps.next_timestamp + time_indices = torch.gather(time_indices, dim=-1, index=hyps_indices) + (next_labels == self._blank_index) torch.minimum(time_indices, last_timesteps, out=safe_time_indices) active_mask = time_indices <= last_timesteps + + step1 += 1 + + # NB: last labels can not exist (nothing decoded on this step). + # return the last labels from the previous state in this case + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + batched_hyps.next_timestamp.fill_(0) + decoding_state = BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=( + torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + if prev_batched_state is not None + else last_labels + ), + decoded_lengths=( + encoder_output_length.clone() + if prev_batched_state is None + else encoder_output_length + prev_batched_state.decoded_lengths + ), + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + time_jumps=None, + batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + ) - return batched_hyps + # import pdb; pdb.set_trace() + return batched_hyps, None, decoding_state def topk_fusion_model(self, fusion_scores_list, log_probs, eps=1e-2): """ @@ -668,7 +733,8 @@ def modified_alsd_cuda_graphs( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Cuda-Graphs implementation of the batched ALSD algorithm. Args: @@ -676,8 +742,9 @@ def modified_alsd_cuda_graphs( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state. Returns: - BathcedBeamHyps: Batched beam hypotheses. + tuple: (BatchedBeamHyps, None, BatchedLabelLoopingState) """ assert self.cuda_graphs_mode is not None @@ -694,29 +761,55 @@ def modified_alsd_cuda_graphs( if self.state is None or self.state.need_reinit(encoder_output): self._graph_reinitialize(encoder_output, encoder_output_length) + # Set continuation flag and restore state from previous chunk if provided + is_continuation = prev_batched_state is not None + self.state.is_continuation.fill_(is_continuation) + # Mirror into the inverse flag so the captured graph's IF nodes can route to + # the right prologue. Both tensors are read by ``loop_conditional``-style + # condition kernels baked into the graph at capture time. + self.state.is_first_chunk.fill_(not is_continuation) + + if is_continuation: + # Restore state from previous chunk + self._restore_state_from_prev(prev_batched_state, current_batch_size) + # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) - # copy (projected) encoder output and lenghts - self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output) + # copy (projected) encoder output and lengths + self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_( + encoder_output[:current_batch_size, :current_max_time, ...] + ) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + # Single graph dispatches between first-chunk and continuation prologues internally + # via captured IF nodes that read ``is_first_chunk`` / ``is_continuation``. self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: - self.separate_graphs.before_loop.replay() + # Use continuation before_loop graph if continuing from previous chunk + if is_continuation: + self.separate_graphs.before_loop_continuation.replay() + else: + self.separate_graphs.before_loop.replay() while self.state.active_mask_any.item(): self.separate_graphs.loop_body.replay() self.separate_graphs.loop_update_decoder.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: - # this mode is only for testing purposes # manual loop instead of using graphs - self._before_loop() + if is_continuation: + self._before_loop_continuation() + else: + self._before_loop() while self.state.active_mask_any.item(): self._loop_body() self._loop_update_decoder() else: raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") - return self.state.batched_hyps + # Create and return decoding state for next chunk + decoding_state = self._create_decoding_state(encoder_output_length, prev_batched_state) + + return self.state.batched_hyps, None, decoding_state @classmethod def _create_loop_body_kernel(cls): @@ -838,6 +931,10 @@ def _graph_reinitialize( self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) self.state.fusion_states_prev_list.append(init_fusion_states.clone()) + # warmup before graph compilation + if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: + self._warmup_for_cuda_graphs() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() @@ -858,12 +955,45 @@ def _graph_reinitialize( else: raise NotImplementedError + def _warmup_for_cuda_graphs(self): + """Warmup before compiling CUDA graphs. + + Runs a few eager iterations of both the first-chunk and continuation paths so that + cuBLAS / cuDNN handles and workspaces are allocated and stable before any graph + capture begins. Mirrors the warmup pattern used by the greedy label-looping decoder. + """ + is_ddp = torch.distributed.is_available() and torch.distributed.is_initialized() + # 11 warmup steps required in DDP mode + # see https://pytorch.org/docs/stable/notes/cuda.html#usage-with-distributeddataparallel + num_runs = 11 if is_ddp else 3 + self.state.encoder_output_projected.fill_(0.0) + self.state.encoder_output_length.fill_(1) + s = torch.cuda.Stream(self.state.device) + s.wait_stream(torch.cuda.current_stream(device=self.state.device)) + with torch.cuda.stream(s), torch.inference_mode(): + # Warm up the first-chunk path. + for _ in range(num_runs): + self._before_loop() + self._loop_body() + self._loop_update_decoder() + # Warm up the continuation path so its prologue and any kernels it touches + # are primed too. Both captures share a mempool, so any allocator activity + # they trigger needs to settle before either is captured. + for _ in range(num_runs): + self._before_loop_continuation() + self._loop_body() + self._loop_update_decoder() + torch.cuda.current_stream(device=self.state.device).wait_stream(s) + self.state.encoder_output_length.fill_(0) + def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.separate_graphs = SeparateGraphsMALSD() + + # Compile before_loop graph for first chunk with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -873,6 +1003,16 @@ def _partial_graphs_compile(self): ): self._before_loop() + # Compile before_loop_continuation graph for streaming + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs.before_loop_continuation, stream=stream_for_graph, capture_error_mode="thread_local" + ), + ): + self._before_loop_continuation() + with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -893,9 +1033,23 @@ def _partial_graphs_compile(self): @cuda_python_required def _full_graph_compile(self): - """Compile full graph for decoding""" + """Compile a single CUDA graph that handles both first-chunk and continuation paths. + + The graph contains three conditional sub-graphs in order: + 1. IF (``is_first_chunk``) → ``_before_loop()`` + 2. IF (``is_continuation``) → ``_before_loop_continuation()`` + 3. WHILE (``active_mask_any``) → ``_loop_body()`` + ``_loop_update_decoder()`` + + At replay time the caller toggles ``is_first_chunk`` / ``is_continuation`` so + exactly one prologue executes. This avoids needing two coexisting CUDAGraph + objects (which observed to cause cudaErrorIllegalAddress on replay due to + mempool interaction between the two captures). + """ # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) + # Drain any work pending on the default stream (e.g. the warmup that ran just above in + # ``_graph_reinitialize``) before we start capturing. + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() with ( @@ -903,24 +1057,50 @@ def _full_graph_compile(self): torch.inference_mode(), torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): - self._before_loop() + # The condition-setter kernel (created lazily by ``_create_loop_body_kernel``) is + # signature-compatible with any 0-d bool*; we reuse it for all three conditional nodes. + cond_kernel = self._create_loop_body_kernel() + # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) ) - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - # capture: while self.active_mask_any: + # --- IF (is_first_chunk): run first-chunk prologue --- + (first_chunk_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_first_chunk_ptr = np.array([self.state.is_first_chunk.data_ptr()], dtype=np.uint64) + first_chunk_args = np.array( + [first_chunk_handle.getPtr(), is_first_chunk_ptr.ctypes.data], + dtype=np.uint64, + ) + with with_conditional_node( + cond_kernel, first_chunk_args, first_chunk_handle, device=self.state.device, cond_type="if" + ): + self._before_loop() + + # --- IF (is_continuation): run continuation prologue --- + (continuation_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_continuation_ptr = np.array([self.state.is_continuation.data_ptr()], dtype=np.uint64) + continuation_args = np.array( + [continuation_handle.getPtr(), is_continuation_ptr.ctypes.data], + dtype=np.uint64, + ) + with with_conditional_node( + cond_kernel, continuation_args, continuation_handle, device=self.state.device, cond_type="if" + ): + self._before_loop_continuation() + + # --- WHILE (active_mask_any): main decoding loop --- (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) loop_args = np.array( [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + with with_conditional_node( + cond_kernel, loop_args, loop_conditional_handle, device=self.state.device, cond_type="while" + ): self._loop_body() self._loop_update_decoder() @@ -941,13 +1121,42 @@ def _before_loop(self): self.state.fusion_scores_list[fusion_idx].copy_(self.state.init_fusion_scores_list[fusion_idx]) self.state.fusion_states_prev_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + # set decoder state and output to initial values + self.state.decoder_output.copy_(self.state.init_decoder_output) + self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) + self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) + # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) + + self._before_loop_common() + + def _before_loop_continuation(self): + """ + Prologue for a continuation chunk: preserves cross-chunk per-beam state on + ``batched_hyps`` (scores, last_label, transcript_hash, current_lengths_nb, + last_timestamp_lasts) and resets only the chunk-local prefix-tree buffers + (transcript_wb / transcript_wb_prev_ptr / timestamps / current_lengths_wb) + that the captured loop body would otherwise overflow. + + Decoder and fusion states are already restored by ``_restore_state_from_prev``. + The caller is responsible for snapshotting / merging the chunk-local transcripts + before the next chunk (see :meth:`BatchedBeamHyps.merge_` with + ``is_chunk_continuation=True``). + """ + self.state.batched_hyps.clear_chunk_local_() + self._before_loop_common() + + def _before_loop_common(self): + """ + Common initialization for both first chunk and continuation. + Resets temporary variables and computes active mask. + """ self.state.next_scores.fill_(0.0) self.state.next_labels.fill_(0.0) self.state.next_idx.fill_(0.0) - # time indices + # time indices - reset for current chunk self.state.time_indices.fill_(0) self.state.safe_time_indices.fill_(0) # safe time indices: guaranteed to be < encoder_output_length @@ -960,11 +1169,6 @@ def _before_loop(self): # same as: self.active_mask_any = active_mask.any() torch.any(self.state.active_mask, out=self.state.active_mask_any) - # set decoder state and output to initial values - self.state.decoder_output.copy_(self.state.init_decoder_output) - self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) - self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) - # set previous decoder state and output to initial values self.state.prev_decoder_output.fill_(0) self.state.prev_decoder_state[0].fill_(0) @@ -1069,7 +1273,6 @@ def _loop_update_decoder(self): Updates the decoder state, decoder output, and optionally the fusion models state for the next iteration of the decoding loop in a batched RNNT (Recurrent Neural Network Transducer) setup. """ - # step 5: update decoder state + decoder output (+ fusion models state/scores) # step 5.1: mask invalid value labels with blank to avoid errors (refer to step 2.2) torch.where( @@ -1172,13 +1375,125 @@ def _loop_update_decoder(self): torch.less_equal(self.state.time_indices, self.state.last_timesteps, out=self.state.active_mask) torch.any(self.state.active_mask, out=self.state.active_mask_any) + def _restore_state_from_prev( + self, prev_batched_state: BatchedLabelLoopingState, current_batch_size: int + ): + """ + Restore decoder state, fusion states, and batched_hyps from previous chunk's state. + Used for streaming/chunked decoding. + + Args: + prev_batched_state: State from previous chunk + current_batch_size: Current batch size + """ + # Restore decoder output and state + if prev_batched_state.predictor_outputs is not None: + self.state.decoder_output[:current_batch_size * self.beam_size].copy_( + prev_batched_state.predictor_outputs.view(-1, 1, prev_batched_state.predictor_outputs.shape[-1]) + ) + + if prev_batched_state.predictor_states is not None: + # Copy decoder states (assuming tuple of tensors) + for i, state_tensor in enumerate(prev_batched_state.predictor_states): + if state_tensor is not None: + self.state.decoder_state[i][:, :current_batch_size * self.beam_size].copy_( + state_tensor[:, :current_batch_size * self.beam_size] + ) + + # Restore fusion states if present + if prev_batched_state.fusion_states_list is not None and self.fusion_models is not None: + for fusion_idx, fusion_state in enumerate(prev_batched_state.fusion_states_list): + if fusion_state is not None: + self.state.fusion_states_list[fusion_idx][:current_batch_size].copy_( + fusion_state[:current_batch_size] + ) + # Recompute fusion scores and candidates from restored states + fusion_scores, fusion_states_candidates = self.fusion_models[fusion_idx].advance( + states=self.state.fusion_states_list[fusion_idx][:current_batch_size].reshape(-1) + ) + self.state.fusion_scores_list[fusion_idx][:current_batch_size].copy_( + fusion_scores.to(dtype=self.state.float_dtype).view(current_batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx][:current_batch_size].copy_( + fusion_states_candidates.view(current_batch_size, self.beam_size, -1) + ) + + # Restore batched_hyps from previous state + if prev_batched_state.batched_hyps is not None: + self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) + + def _create_decoding_state( + self, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Create BatchedLabelLoopingState for the next chunk. + + Args: + encoder_output_length: Length of current encoder output + prev_batched_state: State from previous chunk (if any) + + Returns: + BatchedLabelLoopingState containing current decoding state + """ + current_batch_size = encoder_output_length.shape[0] + + # Get last labels from batched_hyps + last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS) + + # Reset next_timestamp for next chunk + self.state.batched_hyps.next_timestamp.fill_(0) + + # Calculate accumulated decoded lengths + if prev_batched_state is None: + decoded_lengths = encoder_output_length.clone() + else: + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths[:current_batch_size] + + # Handle labels - if nothing decoded this chunk, use previous labels + if prev_batched_state is not None: + last_labels = torch.where( + last_labels == self._SOS, + prev_batched_state.labels, + last_labels + ) + + # Get fusion states if present + fusion_states_list = None + if self.fusion_models is not None and self.state.fusion_states_list is not None: + fusion_states_list = [ + state[:current_batch_size].clone() for state in self.state.fusion_states_list + ] + + return BatchedLabelLoopingState( + predictor_states=( + self.state.decoder_state[0][:, :current_batch_size * self.beam_size].clone(), + self.state.decoder_state[1][:, :current_batch_size * self.beam_size].clone(), + ), + predictor_outputs=self.state.decoder_output[:current_batch_size * self.beam_size].clone(), + labels=last_labels, + decoded_lengths=decoded_lengths, + fusion_states_list=fusion_states_list, + time_jumps=None, + batched_hyps=self.state.batched_hyps.clone(), + ) + def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: if self.cuda_graphs_mode is not None and x.device.type == "cuda": with torch.amp.autocast(device_type="cuda", enabled=False): - return self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_cuda_graphs( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state, + ) - return self.modified_alsd_torch(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_torch( + encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state + ) diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index 65255b85849d..fea925dd219c 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -955,13 +955,13 @@ def forward( with torch.inference_mode(): # Apply optional preprocessing encoder_output = encoder_output.transpose(1, 2) # (B, T, D) - logitlen = encoded_lengths self.decoder.eval() self.joint.eval() - inseq = encoder_output # [B, T, D] - batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen) + encoder_output_projected = self.joint.project_encoder(encoder_output) + encoder_output_projected_len = encoded_lengths + batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=encoder_output_projected, out_len=encoder_output_projected_len) # Ensures the correct number of hypotheses (batch_size) for CUDA Graphs compatibility batch_size = encoder_output.shape[0] diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 45ecfed299a5..38352a1b64ae 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -19,6 +19,8 @@ import torch.nn.functional as F from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState +from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( INACTIVE_SCORE, @@ -92,6 +94,10 @@ class MALSDState: batched_hyps: BatchedBeamHyps # batched hypotheses - decoding result + # Streaming state fields + is_continuation: torch.Tensor # flag indicating if this is a continuation from previous chunk + is_first_chunk: torch.Tensor # complement of ``is_continuation``; both feed the captured graph's IF nodes + # fusion models related fields fusion_models: Optional[List[NGramGPULanguageModel]] = None fusion_models_alpha: Optional[List[float]] = None @@ -184,6 +190,13 @@ def __init__( self.blank_mask = torch.zeros_like(self.active_mask, dtype=torch.bool) self.active_mask_any = torch.tensor(True, device=self.device, dtype=torch.bool) + # Streaming state fields. The captured FULL_GRAPH reads ``is_first_chunk`` and + # ``is_continuation`` to route to the first-chunk vs. continuation prologue at replay + # time. The two flags are kept in lockstep by ``modified_alsd_cuda_graphs`` before + # each replay (they are mutually exclusive). + self.is_continuation = torch.tensor(False, device=self.device, dtype=torch.bool) + self.is_first_chunk = torch.tensor(True, device=self.device, dtype=torch.bool) + self.batched_hyps = BatchedBeamHyps( batch_size=batch_size, beam_size=self.beam_size, @@ -208,6 +221,7 @@ class SeparateGraphsMALSD: """Class to store Cuda graphs for decoding when separate graphs are used""" before_loop: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) + before_loop_continuation: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_body: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) loop_update_decoder: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) @@ -363,7 +377,8 @@ def modified_alsd_torch( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Pytorch implementation of the batched ALSD algorithm for TDT models. Args: @@ -371,8 +386,10 @@ def modified_alsd_torch( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state for streaming. Returns: - BatchedBeamHyps: Batched beam hypotheses. + tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: + Batched beam hypotheses, alignments (None), and decoding state. """ batch_size, max_time, _ = encoder_output.shape @@ -385,21 +402,6 @@ def modified_alsd_torch( encoder_output_projected = self.joint.project_encoder(encoder_output) float_dtype = encoder_output_projected.dtype - # init empty batched beam hypotheses - batched_hyps = BatchedBeamHyps( - batch_size=batch_size, - beam_size=self.beam_size, - blank_index=self._blank_index, - init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, - device=device, - float_dtype=float_dtype, - model_type='tdt', - ) - - last_labels_wb = torch.full( - [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long - ) - batch_beam_indices = ( torch.arange(batch_size, dtype=torch.long, device=device)[:, None] .expand(batch_size, self.beam_size) @@ -411,47 +413,99 @@ def modified_alsd_torch( .clone() ) # size: batch_size x beam_size x beam_size - time_indices = torch.zeros_like(batch_beam_indices) + # Reuse batched beam hypotheses from state if continuing, otherwise create new + if prev_batched_state is not None and prev_batched_state.batched_hyps is not None: + batched_hyps = prev_batched_state.batched_hyps + # For continuation, initialize time_indices from batched_hyps.next_timestamp + # This represents where we left off in the previous chunk + # Subtract decoded_lengths to convert from absolute to relative time + time_indices = batched_hyps.next_timestamp.clone() - prev_batched_state.decoded_lengths[:, None].expand_as(batched_hyps.next_timestamp) + assert time_indices.min() >= 0, "Time indices should be non-negative" + else: + batched_hyps = BatchedBeamHyps( + batch_size=batch_size, + beam_size=self.beam_size, + blank_index=self._blank_index, + init_length=max_time * (self.max_symbols + 1) if self.max_symbols is not None else max_time, + device=device, + float_dtype=float_dtype, + model_type='tdt', + ) + time_indices = torch.zeros_like(batch_beam_indices) + safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) + # Clamp time_indices to valid range for continuation + torch.minimum(time_indices, last_timesteps, out=safe_time_indices) active_mask = time_indices <= last_timesteps # setup fusion models if available if self.fusion_models is not None: - fusion_states_list = [] - fusion_states_candidates_list = [] - fusion_scores_list = [] - for fusion_model_idx, fusion_model in enumerate(self.fusion_models): - fusion_model.to(device) - fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - fusion_scores, fusion_states_candidates = fusion_model.advance(states=fusion_states) + if prev_batched_state is None or prev_batched_state.fusion_states_list is None: + fusion_states_list = [] + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) # vocab_size_no_blank - fusion_scores = ( - fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) - * self.fusion_models_alpha[fusion_model_idx] + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = prev_batched_state.fusion_states_list + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states_list[fusion_model_idx] + ) + fusion_scores = ( + fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + else: + fusion_states_list = None + + if prev_batched_state is None: + last_labels_wb = torch.full( + [batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long + ) + decoder_state = self.decoder.initialize_state( + torch.empty( + [ + batch_size * self.beam_size, + ], + dtype=float_dtype, + device=device, ) - fusion_states_list.append(fusion_states) - fusion_states_candidates_list.append(fusion_states_candidates) - fusion_scores_list.append(fusion_scores) + ) - decoder_state = self.decoder.initialize_state( - torch.empty( - [ - batch_size * self.beam_size, - ], - dtype=float_dtype, - device=device, + decoder_output, state, *_ = self.decoder.predict( + last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size ) - ) + # do not recalculate joint projection + decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] + self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + else: + # Continuing from previous chunk - batched_hyps already contains all state + decoder_output = prev_batched_state.predictor_outputs + decoder_state = prev_batched_state.predictor_states - decoder_output, state, *_ = self.decoder.predict( - last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size - ) - # do not recalculate joint projection - decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim] - self.decoder.batch_replace_states_all(state, dst_states=decoder_state) + # import pdb; pdb.set_trace() + step1 = 0 - while active_mask.any(): + while active_mask.any(): # step 1: get joint output + fuse with fusion models (if present) logits = ( self.joint.joint_after_projection( @@ -624,11 +678,40 @@ def modified_alsd_torch( fusion_scores_list[fusion_model_idx] = fusion_scores # step 6: update time indices + active mask - time_indices.copy_(batched_hyps.next_timestamp) + time_indices = torch.gather(time_indices, dim=-1, index=hyps_indices) + next_label_durations torch.minimum(time_indices, last_timesteps, out=safe_time_indices) torch.less_equal(time_indices, last_timesteps, out=active_mask) + + step1 += 1 + + # fix timestamps for iterative decoding + # Add offset to timestamps so they are cumulative across chunks + if prev_batched_state is not None: + batched_hyps.timestamps += prev_batched_state.decoded_lengths[:, None, None].expand_as(batched_hyps.timestamps) + # Also update next_timestamp for proper continuation + + # NB: last labels can not exist (nothing decoded on this step). + # return the last labels from the previous state in this case + last_labels = batched_hyps.get_last_labels(pad_id=self._SOS) + decoding_state = BatchedLabelLoopingState( + predictor_states=decoder_state, + predictor_outputs=decoder_output, + labels=( + torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels) + if prev_batched_state is not None + else last_labels + ), + decoded_lengths=( + encoder_output_length.clone() + if prev_batched_state is None + else encoder_output_length + prev_batched_state.decoded_lengths + ), + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + time_jumps=None, # Not needed for beam search since we use batched_hyps.next_timestamp + batched_hyps=batched_hyps, # Save batched_hyps object for next chunk + ) - return batched_hyps + return batched_hyps, None, decoding_state def topk_fusion_model(self, fusion_scores_list, log_probs, duration_log_probs, eps=1e-2): """ @@ -747,7 +830,8 @@ def modified_alsd_cuda_graphs( self, encoder_output: torch.Tensor, encoder_output_length: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: """ Cuda-Graphs implementation of the batched ALSD algorithm. Args: @@ -755,8 +839,9 @@ def modified_alsd_cuda_graphs( [batch_size, max_time, encoder_dim]. encoder_output_length (torch.Tensor): The lengths of the encoder outputs for each batch with shape [batch_size]. + prev_batched_state (Optional[BatchedLabelLoopingState]): The previous batched state. Returns: - BathedBeamHyps: Batched beam hypotheses. + tuple: (BatchedBeamHyps, None, BatchedLabelLoopingState) """ assert self.cuda_graphs_mode is not None @@ -773,29 +858,53 @@ def modified_alsd_cuda_graphs( if self.state is None or self.state.need_reinit(encoder_output): self._graph_reinitialize(encoder_output, encoder_output_length) + # Set continuation flag and restore state from previous chunk if provided + is_continuation = prev_batched_state is not None + self.state.is_continuation.fill_(is_continuation) + # Mirror into the inverse flag so the captured graph's IF nodes can route to + # the right prologue. Both tensors are read by ``loop_conditional``-style + # condition kernels baked into the graph at capture time. + self.state.is_first_chunk.fill_(not is_continuation) + + if is_continuation: + # Restore state from previous chunk + self._restore_state_from_prev(prev_batched_state, current_batch_size) + # set length to zero for elements outside the current batch self.state.encoder_output_length.fill_(0) - # copy (projected) encoder output and lenghts + # copy (projected) encoder output and lengths self.state.encoder_output_projected[:current_batch_size, :current_max_time, ...].copy_(encoder_output) self.state.encoder_output_length[:current_batch_size].copy_(encoder_output_length.unsqueeze(-1)) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: + # Single graph dispatches between first-chunk and continuation prologues internally + # via captured IF nodes that read ``is_first_chunk`` / ``is_continuation``. self.full_graph.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_WHILE_LOOPS: - self.separate_graphs.before_loop.replay() + # Use continuation before_loop graph if continuing from previous chunk + if is_continuation: + self.separate_graphs.before_loop_continuation.replay() + else: + self.separate_graphs.before_loop.replay() while self.state.active_mask_any.item(): self.separate_graphs.loop_body.replay() self.separate_graphs.loop_update_decoder.replay() elif self.cuda_graphs_mode is self.CudaGraphsMode.NO_GRAPHS: - # this mode is only for testing purposes # manual loop instead of using graphs - self._before_loop() + if is_continuation: + self._before_loop_continuation() + else: + self._before_loop() while self.state.active_mask_any.item(): self._loop_body() self._loop_update_decoder() else: raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") - return self.state.batched_hyps + # Create and return decoding state for next chunk + decoding_state = self._create_decoding_state(encoder_output_length, prev_batched_state) + + return self.state.batched_hyps, None, decoding_state @classmethod def _create_loop_body_kernel(cls): @@ -919,6 +1028,10 @@ def _graph_reinitialize( self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) self.state.fusion_states_prev_list.append(init_fusion_states.clone()) + # warmup before graph compilation + if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: + self._warmup_for_cuda_graphs() + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: try: self._full_graph_compile() @@ -939,6 +1052,37 @@ def _graph_reinitialize( else: raise NotImplementedError + def _warmup_for_cuda_graphs(self): + """Warmup before compiling CUDA graphs. + + Runs a few eager iterations of both the first-chunk and continuation paths so that + cuBLAS / cuDNN handles and workspaces are allocated and stable before any graph + capture begins. Mirrors the warmup pattern used by the greedy label-looping decoder. + """ + is_ddp = torch.distributed.is_available() and torch.distributed.is_initialized() + # 11 warmup steps required in DDP mode + # see https://pytorch.org/docs/stable/notes/cuda.html#usage-with-distributeddataparallel + num_runs = 11 if is_ddp else 3 + self.state.encoder_output_projected.fill_(0.0) + self.state.encoder_output_length.fill_(1) + s = torch.cuda.Stream(self.state.device) + s.wait_stream(torch.cuda.current_stream(device=self.state.device)) + with torch.cuda.stream(s), torch.inference_mode(): + # Warm up the first-chunk path. + for _ in range(num_runs): + self._before_loop() + self._loop_body() + self._loop_update_decoder() + # Warm up the continuation path so its prologue and any kernels it touches + # are primed too. Both captures share a mempool, so any allocator activity + # they trigger needs to settle before either is captured. + for _ in range(num_runs): + self._before_loop_continuation() + self._loop_body() + self._loop_update_decoder() + torch.cuda.current_stream(device=self.state.device).wait_stream(s) + self.state.encoder_output_length.fill_(0) + def _partial_graphs_compile(self): """Compile decoding by parts""" # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. @@ -954,6 +1098,17 @@ def _partial_graphs_compile(self): ): self._before_loop() + with ( + torch.cuda.stream(stream_for_graph), + torch.inference_mode(), + torch.cuda.graph( + self.separate_graphs.before_loop_continuation, + stream=stream_for_graph, + capture_error_mode="thread_local", + ), + ): + self._before_loop_continuation() + with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), @@ -974,9 +1129,23 @@ def _partial_graphs_compile(self): @cuda_python_required def _full_graph_compile(self): - """Compile full graph for decoding""" + """Compile a single CUDA graph that handles both first-chunk and continuation paths. + + The graph contains three conditional sub-graphs in order: + 1. IF (``is_first_chunk``) → ``_before_loop()`` + 2. IF (``is_continuation``) → ``_before_loop_continuation()`` + 3. WHILE (``active_mask_any``) → ``_loop_body()`` + ``_loop_update_decoder()`` + + At replay time the caller toggles ``is_first_chunk`` / ``is_continuation`` so + exactly one prologue executes. This avoids needing two coexisting CUDAGraph + objects (which observed to cause cudaErrorIllegalAddress on replay due to + mempool interaction between the two captures). + """ # Always create a new stream, because the per-thread default stream disallows stream capture to a graph. stream_for_graph = torch.cuda.Stream(self.state.device) + # Drain any work pending on the default stream (e.g. the warmup that ran just above in + # ``_graph_reinitialize``) before we start capturing. + stream_for_graph.wait_stream(torch.cuda.default_stream(self.state.device)) self.full_graph = torch.cuda.CUDAGraph() with ( @@ -984,24 +1153,50 @@ def _full_graph_compile(self): torch.inference_mode(), torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): - self._before_loop() + # The condition-setter kernel (created lazily by ``_create_loop_body_kernel``) is + # signature-compatible with any 0-d bool*; we reuse it for all three conditional nodes. + cond_kernel = self._create_loop_body_kernel() + # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.state.device).cuda_stream) ) - assert capture_status == cudart.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive - # capture: while self.active_mask_any: + # --- IF (is_first_chunk): run first-chunk prologue --- + (first_chunk_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_first_chunk_ptr = np.array([self.state.is_first_chunk.data_ptr()], dtype=np.uint64) + first_chunk_args = np.array( + [first_chunk_handle.getPtr(), is_first_chunk_ptr.ctypes.data], + dtype=np.uint64, + ) + with with_conditional_node( + cond_kernel, first_chunk_args, first_chunk_handle, device=self.state.device, cond_type="if" + ): + self._before_loop() + + # --- IF (is_continuation): run continuation prologue --- + (continuation_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) + is_continuation_ptr = np.array([self.state.is_continuation.data_ptr()], dtype=np.uint64) + continuation_args = np.array( + [continuation_handle.getPtr(), is_continuation_ptr.ctypes.data], + dtype=np.uint64, + ) + with with_conditional_node( + cond_kernel, continuation_args, continuation_handle, device=self.state.device, cond_type="if" + ): + self._before_loop_continuation() + + # --- WHILE (active_mask_any): main decoding loop --- (loop_conditional_handle,) = cu_call(cudart.cudaGraphConditionalHandleCreate(graph, 0, 0)) - loop_kernel = self._create_loop_body_kernel() active_mask_any_ptr = np.array([self.state.active_mask_any.data_ptr()], dtype=np.uint64) loop_args = np.array( [loop_conditional_handle.getPtr(), active_mask_any_ptr.ctypes.data], dtype=np.uint64, ) - # loop while there are active utterances - with with_conditional_node(loop_kernel, loop_args, loop_conditional_handle, device=self.state.device): + with with_conditional_node( + cond_kernel, loop_args, loop_conditional_handle, device=self.state.device, cond_type="while" + ): self._loop_body() self._loop_update_decoder() @@ -1024,6 +1219,35 @@ def _before_loop(self): # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) + + # set decoder state and output to initial values + self.state.decoder_output.copy_(self.state.init_decoder_output) + self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) + self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) + + self._before_loop_common() + + def _before_loop_continuation(self): + """ + Prologue for a continuation chunk: preserves cross-chunk per-beam state on + ``batched_hyps`` (scores, last_label, transcript_hash, current_lengths_nb, + last_timestamp_lasts) and resets only the chunk-local prefix-tree buffers + (transcript_wb / transcript_wb_prev_ptr / timestamps / current_lengths_wb) + that the captured loop body would otherwise overflow. + + Decoder and fusion states are already restored by ``_restore_state_from_prev``. + The caller is responsible for snapshotting / merging the chunk-local transcripts + before the next chunk (see :meth:`BatchedBeamHyps.merge_` with + ``is_chunk_continuation=True``). + """ + self.state.batched_hyps.clear_chunk_local_() + self._before_loop_common() + + def _before_loop_common(self): + """ + Common initialization for both first chunk and continuation. + Resets temporary variables and computes active mask. + """ self.state.next_scores.fill_(0.0) self.state.next_labels.fill_(0.0) self.state.next_idx.fill_(0.0) @@ -1042,10 +1266,6 @@ def _before_loop(self): # same as: self.active_mask_any = active_mask.any() torch.any(self.state.active_mask, out=self.state.active_mask_any) - self.state.decoder_output.copy_(self.state.init_decoder_output) - self.state.decoder_state[0].copy_(self.state.init_decoder_state[0]) - self.state.decoder_state[1].copy_(self.state.init_decoder_state[1]) - self.state.prev_decoder_output.fill_(0) self.state.prev_decoder_state[0].fill_(0) self.state.prev_decoder_state[1].fill_(0) @@ -1295,13 +1515,125 @@ def _loop_update_decoder(self): torch.less_equal(self.state.time_indices, self.state.last_timestamps, out=self.state.active_mask) torch.any(self.state.active_mask, out=self.state.active_mask_any) + def _restore_state_from_prev( + self, prev_batched_state: BatchedLabelLoopingState, current_batch_size: int + ): + """ + Restore decoder state, fusion states, and batched_hyps from previous chunk's state. + Used for streaming/chunked decoding. + + Args: + prev_batched_state: State from previous chunk + current_batch_size: Current batch size + """ + # Restore decoder output and state + if prev_batched_state.predictor_outputs is not None: + self.state.decoder_output[:current_batch_size * self.beam_size].copy_( + prev_batched_state.predictor_outputs.view(-1, 1, prev_batched_state.predictor_outputs.shape[-1]) + ) + + if prev_batched_state.predictor_states is not None: + # Copy decoder states (assuming tuple of tensors) + for i, state_tensor in enumerate(prev_batched_state.predictor_states): + if state_tensor is not None: + self.state.decoder_state[i][:, :current_batch_size * self.beam_size].copy_( + state_tensor[:, :current_batch_size * self.beam_size] + ) + + # Restore fusion states if present + if prev_batched_state.fusion_states_list is not None and self.fusion_models is not None: + for fusion_idx, fusion_state in enumerate(prev_batched_state.fusion_states_list): + if fusion_state is not None: + self.state.fusion_states_list[fusion_idx][:current_batch_size].copy_( + fusion_state[:current_batch_size] + ) + # Recompute fusion scores and candidates from restored states + fusion_scores, fusion_states_candidates = self.fusion_models[fusion_idx].advance( + states=self.state.fusion_states_list[fusion_idx][:current_batch_size].reshape(-1) + ) + self.state.fusion_scores_list[fusion_idx][:current_batch_size].copy_( + fusion_scores.to(dtype=self.state.float_dtype).view(current_batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx][:current_batch_size].copy_( + fusion_states_candidates.view(current_batch_size, self.beam_size, -1) + ) + + # Restore batched_hyps from previous state + if prev_batched_state.batched_hyps is not None: + self.state.batched_hyps.copy_from_(prev_batched_state.batched_hyps) + + def _create_decoding_state( + self, + encoder_output_length: torch.Tensor, + prev_batched_state: Optional[BatchedLabelLoopingState], + ) -> BatchedLabelLoopingState: + """ + Create BatchedLabelLoopingState for the next chunk. + + Args: + encoder_output_length: Length of current encoder output + prev_batched_state: State from previous chunk (if any) + + Returns: + BatchedLabelLoopingState containing current decoding state + """ + current_batch_size = encoder_output_length.shape[0] + + # Get last labels from batched_hyps + last_labels = self.state.batched_hyps.get_last_labels(pad_id=self._SOS) + + # Reset next_timestamp for next chunk + self.state.batched_hyps.next_timestamp.fill_(0) + + # Calculate accumulated decoded lengths + if prev_batched_state is None: + decoded_lengths = encoder_output_length.clone() + else: + decoded_lengths = encoder_output_length + prev_batched_state.decoded_lengths[:current_batch_size] + + # Handle labels - if nothing decoded this chunk, use previous labels + if prev_batched_state is not None: + last_labels = torch.where( + last_labels == self._SOS, + prev_batched_state.labels, + last_labels + ) + + # Get fusion states if present + fusion_states_list = None + if self.fusion_models is not None and self.state.fusion_states_list is not None: + fusion_states_list = [ + state[:current_batch_size].clone() for state in self.state.fusion_states_list + ] + + return BatchedLabelLoopingState( + predictor_states=( + self.state.decoder_state[0][:, :current_batch_size * self.beam_size].clone(), + self.state.decoder_state[1][:, :current_batch_size * self.beam_size].clone(), + ), + predictor_outputs=self.state.decoder_output[:current_batch_size * self.beam_size].clone(), + labels=last_labels, + decoded_lengths=decoded_lengths, + fusion_states_list=fusion_states_list, + time_jumps=None, + batched_hyps=self.state.batched_hyps.clone(), + ) + def __call__( self, x: torch.Tensor, out_len: torch.Tensor, - ) -> BatchedBeamHyps: + prev_batched_state: Optional[BatchedLabelLoopingState] = None, + ) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]: if self.cuda_graphs_mode is not None and x.device.type == "cuda": with torch.amp.autocast(device_type="cuda", enabled=False): - return self.modified_alsd_cuda_graphs(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_cuda_graphs( + encoder_output=x, + encoder_output_length=out_len, + prev_batched_state=prev_batched_state, + ) - return self.modified_alsd_torch(encoder_output=x, encoder_output_length=out_len) + return self.modified_alsd_torch( + encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state + ) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py index 30a4b97366cd..b53531218698 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py @@ -47,6 +47,7 @@ class BatchedLabelLoopingState: decoded_lengths: torch.Tensor fusion_states_list: list[torch.Tensor] = field(default_factory=list) time_jumps: torch.Tensor | None = None + batched_hyps: Any = None # For beam search: BatchedBeamHyps object to continue across chunks @dataclass diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index e0af3b7623d7..29ff3c605bc0 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -199,6 +199,126 @@ def clear_(self): self.next_timestamp.fill_(0) self.last_timestamp_lasts.fill_(0) + def clear_chunk_local_(self): + """ + Reset only the chunk-local storage so the per-chunk transcript buffers can be + reused for the next chunk in streaming/chunked beam decoding under CUDA graphs. + + The chunk-local storage is the transcript prefix tree (``transcript_wb``, + ``transcript_wb_prev_ptr``), the per-step timestamps array (``timestamps``), the + write cursor into those buffers (``current_lengths_wb``) and, for transducer + models, the per-chunk ``next_timestamp`` counter. + + Cross-chunk per-beam state - ``scores``, ``last_label``, ``transcript_hash``, + ``current_lengths_nb`` and ``last_timestamp_lasts`` - is intentionally left + untouched so the beam-search loop can continue from the previous chunk's beam + state without re-seeding. + + This method does only in-place ``fill_`` / ``copy_`` writes into the existing + tensors, so it is safe to call from inside a captured CUDA-graph region (the + captured pointers remain valid). + """ + self.current_lengths_wb.fill_(0) + + self.transcript_wb.fill_(NON_EXISTENT_LABEL_VALUE) + self.transcript_wb_prev_ptr.fill_(INIT_POINTER_VALUE) + + if self.model_type == ASRModelTypeEnum.CTC: + self.timestamps.copy_(self._create_timestamps_tensor(self._max_length)) + else: + self.timestamps.fill_(0) + self.next_timestamp.fill_(0) + + def copy_from_(self, other: "BatchedBeamHyps"): + """ + Copy state from another BatchedBeamHyps object (in-place). + Used for streaming/chunked decoding to restore state from previous chunk. + + Args: + other: Source BatchedBeamHyps to copy from + """ + batch_size = min(self.batch_size, other.batch_size) + beam_size = min(self.beam_size, other.beam_size) + max_length = min(self._max_length, other._max_length) + + self.current_lengths_nb[:batch_size, :beam_size].copy_( + other.current_lengths_nb[:batch_size, :beam_size] + ) + self.current_lengths_wb[:batch_size, :beam_size].copy_( + other.current_lengths_wb[:batch_size, :beam_size] + ) + self.transcript_wb[:batch_size, :beam_size, :max_length].copy_( + other.transcript_wb[:batch_size, :beam_size, :max_length] + ) + self.transcript_wb_prev_ptr[:batch_size, :beam_size, :max_length].copy_( + other.transcript_wb_prev_ptr[:batch_size, :beam_size, :max_length] + ) + self.scores[:batch_size, :beam_size].copy_( + other.scores[:batch_size, :beam_size] + ) + self.last_label[:batch_size, :beam_size].copy_( + other.last_label[:batch_size, :beam_size] + ) + self.transcript_hash[:batch_size, :beam_size].copy_( + other.transcript_hash[:batch_size, :beam_size] + ) + + if self.store_prefix_hashes and other.store_prefix_hashes: + self.transcript_prefix_hash[:batch_size, :beam_size].copy_( + other.transcript_prefix_hash[:batch_size, :beam_size] + ) + + self.timestamps[:batch_size, :beam_size, :max_length].copy_( + other.timestamps[:batch_size, :beam_size, :max_length] + ) + + if self.model_type != ASRModelTypeEnum.CTC: + self.next_timestamp[:batch_size, :beam_size].copy_( + other.next_timestamp[:batch_size, :beam_size] + ) + self.last_timestamp_lasts[:batch_size, :beam_size].copy_( + other.last_timestamp_lasts[:batch_size, :beam_size] + ) + + def clone(self) -> "BatchedBeamHyps": + """ + Create a deep copy of this BatchedBeamHyps object. + Used for streaming/chunked decoding to save state for next chunk. + + Returns: + New BatchedBeamHyps with copied state + """ + new_hyps = BatchedBeamHyps( + batch_size=self.batch_size, + beam_size=self.beam_size, + init_length=self._max_length, + blank_index=self.blank_index, + device=self.device, + float_dtype=self.scores.dtype, + store_prefix_hashes=self.store_prefix_hashes, + model_type=self.model_type, + ) + new_hyps.copy_from_(self) + return new_hyps + + def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: + """ + Get last labels for each hypothesis in the beam. + + Args: + pad_id: Value to use for padding (for hypotheses without labels). Defaults to -1. + + Returns: + Tensor of shape [batch_size, beam_size] with the last label for each hypothesis. + """ + # last_label already contains the last label for each beam + # Replace NON_EXISTENT_LABEL_VALUE with pad_id + return torch.where( + self.last_label != NON_EXISTENT_LABEL_VALUE, + self.last_label, + pad_id + ) + def _allocate_more(self): """ Dynamically allocates more memory for the internal buffers. @@ -275,6 +395,7 @@ def add_results_no_checks_( is_extended = next_labels >= 0 extended_with_blank = next_labels == self.blank_index extended_with_label = (is_extended) & (~extended_with_blank) + if self.model_type == ASRModelTypeEnum.CTC: # for CTC last non-blank and non-repeated label extended_with_label = (extended_with_label) & (next_labels != last_labels) # non-repeated non-blank label @@ -365,6 +486,7 @@ def recombine_hyps_(self): self.scores[:, None, :].expand(self.batch_size, self.beam_size, self.beam_size), self.INACTIVE_SCORE_TENSOR, ) + scores_argmax = scores_matrix.argmax(-1, keepdim=False) scores_to_keep = ( torch.arange(self.beam_size, device=scores_argmax.device, dtype=torch.long)[None, :] == scores_argmax @@ -373,6 +495,7 @@ def recombine_hyps_(self): new_scores = torch.max(scores_matrix, dim=-1, keepdim=False).values else: new_scores = torch.logsumexp(scores_matrix, dim=-1, keepdim=False) + torch.where(scores_to_keep, new_scores.to(self.scores.dtype), self.INACTIVE_SCORE_TENSOR, out=self.scores) def remove_duplicates(self, labels: torch.Tensor, total_logps: torch.Tensor): @@ -547,8 +670,49 @@ def flatten_sort_(self, score_norm: bool = True): normalized_scores = ( self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores ) - normalized_scores, indices = torch.sort(normalized_scores, dim=-1, descending=True) + _, indices = torch.sort(normalized_scores, dim=-1, descending=True) + self._flatten_with_permutation_(indices) + + def flatten_(self) -> torch.Tensor: + """ + Flatten the tree structure of hypotheses without changing beam order. + + Like :meth:`flatten_sort_` but uses the identity permutation, so beam ``i`` keeps + its identity (its decoded prefix and its cross-chunk per-beam state stay aligned + with the corresponding beam in any other ``BatchedBeamHyps`` constructed under the + same decoding run). Required for inter-chunk :meth:`merge_` calls in streaming + beam decoding where beam indices must correspond across chunks. + + Returns: + ``root_ptrs`` of shape ``[batch_size, beam_size]``: the beam index at the + chunk's *start* (i.e. before the first ``add_results_*`` write) from which + each output beam ultimately descends. For chunked streaming beam search, this + tells the caller how to permute the previous chunks' accumulated per-beam + transcripts so they align with this chunk's beam ordering before merging. + + If the prefix tree is empty (``current_lengths_wb.max() == 0``) the identity + permutation is returned. + """ + identity = self.beam_indices.unsqueeze(0).expand(self.batch_size, self.beam_size).contiguous() + return self._flatten_with_permutation_(identity) + + def _flatten_with_permutation_(self, indices: torch.Tensor) -> torch.Tensor: + """ + In-place flatten of the prefix tree using ``indices`` as the new beam permutation. + + Walks ``transcript_wb_prev_ptr`` from the most recent step back to step 0, + gathering tokens and timestamps for each output beam from the source beam given + by ``indices``. Updates all per-beam metadata to match the new ordering. + Args: + indices: ``[batch_size, beam_size]`` long tensor giving the source beam index + for each output beam (e.g. ``arange(beam_size)`` for no permutation). + + Returns: + ``root_ptrs`` of shape ``[batch_size, beam_size]``: the beam index *before* + step 0 of the prefix tree from which each output beam descends. If the prefix + tree is empty (``max_idx < 0``) this equals ``indices``. + """ max_idx = self.current_lengths_wb.max() - 1 ptrs = indices @@ -573,6 +737,8 @@ def flatten_sort_(self, score_norm: bool = True): if self.store_prefix_hashes: self.transcript_prefix_hash.copy_(torch.gather(self.transcript_prefix_hash, dim=-1, index=indices)) + return ptrs + def _create_fold_consecutive_mask(self, transcript): """ Creates a mask to filter consecutive duplicates, blanks, and invalid tokens in a transcript. @@ -621,3 +787,131 @@ def _create_transcripts_mask(self, transcripts: torch.Tensor): return self._create_fold_consecutive_mask(transcripts) else: return (transcripts >= 0) & (transcripts != self.blank_index) + + def merge_( + self, + other: "BatchedBeamHyps", + is_chunk_continuation: bool = False, + boundary_prev_ptr: Optional[torch.Tensor] = None, + ) -> "BatchedBeamHyps": + """ + Merge two batched beam hypotheses structures by concatenating transcripts. + Used for streaming/chunked inference where results from multiple chunks need to be combined. + + Prerequisites: + - Both self and other should have been processed with flatten_sort_() before merging, + so that each beam contains an independent flattened hypothesis. + - Beam indices should correspond across chunks (beam i in self matches beam i in other). + + Notes: + - Timestamps in 'other' should already be cumulative (adjusted for time offset). + - The transcript_hash values are copied from 'other' and won't reflect the full + merged transcript. This means recombine_hyps_() should NOT be called on merged + results without recomputing hashes. This is acceptable for output-only use. + + Args: + other: BatchedBeamHyps from the next chunk to merge. + is_chunk_continuation: If True, treat ``other`` as a beam-search continuation + chunk in which the cross-chunk per-beam fields (``scores``, + ``current_lengths_nb``) already hold cumulative across-chunks values rather + than chunk-local deltas. In that case those fields are *replaced* with the + values from ``other`` instead of summed, to avoid double-counting. The + default (False) preserves the original "deltas" semantics used by greedy + streaming-style merges. + boundary_prev_ptr: Optional ``[batch_size, beam_size]`` long tensor. When + provided, written into ``transcript_wb_prev_ptr`` at the very first + position of the merged region (i.e. at ``self.current_lengths_wb`` before + the update). All other positions of the merged region still receive + ``beam_indices`` (identity) pointers. This is how chunked streaming beam + search threads the cross-chunk beam permutation (the "root ptrs" returned + by :meth:`flatten_` on ``other``) into the accumulator's prefix tree so + that the final :meth:`flatten_sort_` walk redirects from beam ``i`` in + ``other``'s region back to its source beam in ``self``'s region. + + Returns: + Self (modified in-place) + """ + max_other_len = other.current_lengths_wb.max().item() + + # Early return if other has nothing to merge + if max_other_len == 0: + return self + + # Check if we need more storage (using allocated buffer size, not current shape) + # Compute max needed length: current max + other max + max_needed = self.current_lengths_wb.max().item() + max_other_len + + # Expand storage if needed - use existing _allocate_more() method + while max_needed > self._max_length: + self._allocate_more() + + # Create a range tensor: [0, 1, 2, ..., max_other_len-1] + other_indices = torch.arange(max_other_len, device=self.device, dtype=torch.long) + + # Create shifted indices: current_lengths + [0, 1, 2, ...] + # Shape: [batch_size, beam_size, max_other_len] + shifted_indices = self.current_lengths_wb.unsqueeze(-1) + other_indices.unsqueeze(0).unsqueeze(0) + + # Scatter other's transcripts into self at shifted positions + self.transcript_wb.scatter_( + dim=-1, + index=shifted_indices, + src=other.transcript_wb[..., :max_other_len], + ) + + # Update pointers: in the merged region every position points to its own beam + # (identity), except the *first* merged position which optionally encodes the + # cross-chunk root permutation so the final flatten walk redirects from the new + # region back to the right beam in the old region. + identity_src = self.beam_indices.view(1, self.beam_size, 1).expand( + self.batch_size, -1, max_other_len + ) + if boundary_prev_ptr is not None: + ptr_src = identity_src.clone() + ptr_src[..., 0] = boundary_prev_ptr + else: + ptr_src = identity_src + self.transcript_wb_prev_ptr.scatter_( + dim=-1, + index=shifted_indices, + src=ptr_src, + ) + + # Scatter timestamps + self.timestamps.scatter_( + dim=-1, + index=shifted_indices, + src=other.timestamps[..., :max_other_len], + ) + + # Lengths in the chunk-local write cursor are always additive (``other`` always + # reports a chunk-local ``current_lengths_wb``). + self.current_lengths_wb += other.current_lengths_wb + + if is_chunk_continuation: + # Beam-search streaming: ``other`` carries cumulative cross-chunk state in + # these fields, so replace rather than accumulate. + self.current_lengths_nb.copy_(other.current_lengths_nb) + self.scores.copy_(other.scores) + else: + # Original ("deltas") semantics. + self.current_lengths_nb += other.current_lengths_nb + self.scores += other.scores + + # Update transcript hash by combining hashes + # The hash of the merged transcript should account for all non-blank labels + self.transcript_hash.copy_(other.transcript_hash) + + # Update prefix hashes if used + if self.store_prefix_hashes: + self.transcript_prefix_hash.copy_(other.transcript_prefix_hash) + + # Update tracking fields from other (they reflect the end state after other chunk) + self.last_label.copy_(other.last_label) + + # Only update timestamp tracking fields for transducer models + if self.model_type != ASRModelTypeEnum.CTC: + self.next_timestamp.copy_(other.next_timestamp) + self.last_timestamp_lasts.copy_(other.last_timestamp_lasts) + + return self diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index d657c56a67b6..b2ebb715bd11 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -2242,9 +2242,9 @@ def subsample(self, factor: int) -> "ContextSizeBatch": factor: subsampling factor """ return ContextSizeBatch( - left=torch.div(self.left, factor, rounding_mode="floor"), - chunk=torch.div(self.chunk, factor, rounding_mode="floor"), - right=torch.div(self.right, factor, rounding_mode="floor"), + left=torch.ceil(self.left / factor).to(dtype=torch.long), + chunk=torch.ceil(self.chunk / factor).to(dtype=torch.long), + right=torch.ceil(self.right / factor).to(dtype=torch.long), ) def add_frames_( diff --git a/nemo/core/utils/cuda_python_utils.py b/nemo/core/utils/cuda_python_utils.py index f67ac8b478d4..c40c567765c7 100644 --- a/nemo/core/utils/cuda_python_utils.py +++ b/nemo/core/utils/cuda_python_utils.py @@ -117,7 +117,9 @@ def cu_call(f_call_out): @contextlib.contextmanager @cuda_python_required -def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device): +def with_conditional_node( + while_loop_kernel, while_loop_args, while_loop_conditional_handle, device, cond_type="while" +): """ Even though we add a conditional node only once, we need to capture the kernel that calls cudaGraphSetConditional() both @@ -126,6 +128,7 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi to decide both whether to enter the loop, and also whether to execute the next iteration of the loop). """ + assert cond_type in ("while", "if"), f"cond_type must be 'while' or 'if', got {cond_type!r}" # NB: depending on cuda-python version, cudaStreamGetCaptureInfo can return either 5 or 6 elements capture_status, _, graph, *_ = cu_call( cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream) @@ -155,7 +158,10 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi driver_params = cuda.CUgraphNodeParams() driver_params.type = cuda.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL driver_params.conditional.handle = while_loop_conditional_handle - driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE + if cond_type == "while": + driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE + else: + driver_params.conditional.type = cuda.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF driver_params.conditional.size = 1 if Version(cuda_python_version) == Version("12.3.0"): # Work around for https://github.com/NVIDIA/cuda-python/issues/55 @@ -215,9 +221,13 @@ def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditi yield body_stream, body_graph - cuda.cuLaunchKernel( - while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, body_stream.cuda_stream, while_loop_args.ctypes.data, 0 - ) + # For a WHILE node we re-launch the condition kernel at the end of the body so the + # graph can decide whether to execute another iteration. An IF node is one-shot so + # the body simply ends here. + if cond_type == "while": + cuda.cuLaunchKernel( + while_loop_kernel, 1, 1, 1, 1, 1, 1, 0, body_stream.cuda_stream, while_loop_args.ctypes.data, 0 + ) cudart.cudaStreamEndCapture(body_stream.cuda_stream)