Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
GreedyBatchedLabelLoopingComputerBase,
)
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
Expand Down Expand Up @@ -232,9 +236,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

# Change Decoding Config
with open_dict(cfg.decoding):
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
if cfg.decoding.strategy == "greedy_batch":
if cfg.decoding.greedy.loop_labels is not True:
raise NotImplementedError(
"This script supports `greedy_batch` strategy only with Label-Looping algorithm"
)
cfg.decoding.greedy.preserve_alignments = False
elif cfg.decoding.strategy == "malsd_batch":
# MALSD beam search is supported for streaming
pass
elif cfg.decoding.strategy == "maes_batch":
# MAES beam search is supported for streaming
pass
else:
raise NotImplementedError(
"This script currently supports only `greedy_batch` strategy with Label-Looping algorithm"
"This script currently supports only `greedy_batch` with Label-Looping or `malsd_batch` strategy with MALSD"
)
cfg.decoding.tdt_include_token_duration = cfg.timestamps
cfg.decoding.greedy.preserve_alignments = False
Expand Down Expand Up @@ -278,7 +294,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()

decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
# Get decoding computer based on strategy. Beam-search strategies expose the
# underlying computer via the private ``_decoding_computer`` attribute.
if cfg.decoding.strategy == "greedy_batch":
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
elif cfg.decoding.strategy == "malsd_batch":
decoding_computer = asr_model.decoding.decoding._decoding_computer
elif cfg.decoding.strategy == "maes_batch":
decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer
else:
raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}")

audio_sample_rate = model_cfg.preprocessor['sample_rate']

Expand Down Expand Up @@ -394,7 +419,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
)
rest_audio_lengths = audio_batch_lengths.clone()

# iterate over audio samples
# Beam-search strategies (MALSD/MAES/TDT-MALSD) keep their batched_hyps inside
# the decoder state and reset chunk-local buffers per chunk; greedy returns a
# fresh BatchedHyps each call which we must merge externally.
is_beam_search = isinstance(
decoding_computer,
(ModifiedALSDBatchedRNNTComputer, ModifiedAESBatchedRNNTComputer, ModifiedALSDBatchedTDTComputer),
)

# iterate over audio samples (but only for decoding chunks)
while left_sample < audio_batch.shape[1]:
# add samples to buffer
chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample
Expand Down Expand Up @@ -425,18 +458,52 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
encoder_output = encoder_output[:, encoder_context.left :]

# decode only chunk frames
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output,
out_len=torch.where(
is_last_chunk_batch,
encoder_output_len - encoder_context_batch.left,
encoder_context_batch.chunk,
),
prev_batched_state=state,
multi_biasing_ids=multi_biasing_ids,
out_len = torch.where(
is_last_chunk_batch,
encoder_output_len - encoder_context_batch.left,
encoder_context_batch.chunk,
)
# merge hyps with previous hyps
if current_batched_hyps is None:
if is_beam_search:
# Beam-search computers don't accept ``multi_biasing_ids`` yet.
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output, out_len=out_len, prev_batched_state=state
)
else:
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output,
out_len=out_len,
prev_batched_state=state,
multi_biasing_ids=multi_biasing_ids,
)

# Accumulate hypotheses across chunks.
if is_beam_search:
# Beam-search reuses chunk-local transcript buffers inside
# ``decoding_computer.state.batched_hyps`` (reset at every continuation
# chunk), so we must snapshot before the next call overwrites them.
# ``flatten_`` resolves the chunk-local prefix tree without sorting and
# returns ``root_ptrs``: for each chunk-end beam ``i``, the beam index
# at the chunk's start that this hypothesis descends from. The loop
# body can permute beams via top-K, so beam ``i`` at the end of chunk
# ``N`` is generally a different logical hypothesis from beam ``i`` at
# the end of chunk ``N-1`` - it is the descendant of beam
# ``root_ptrs[i]``. Threading ``root_ptrs`` into the accumulator's
# ``transcript_wb_prev_ptr`` via ``boundary_prev_ptr`` encodes this
# redirection so the final ``flatten_sort_`` in ``to_hyps_list`` walks
# back through the right beam history. ``is_chunk_continuation=True``
# makes ``merge_`` replace (not sum) the cumulative per-beam fields
# (``scores``, ``current_lengths_nb``, ...).
chunk_snapshot = chunk_batched_hyps.clone()
chunk_root_ptrs = chunk_snapshot.flatten_()
if current_batched_hyps is None:
current_batched_hyps = chunk_snapshot
else:
current_batched_hyps.merge_(
chunk_snapshot,
is_chunk_continuation=True,
boundary_prev_ptr=chunk_root_ptrs,
)
elif current_batched_hyps is None:
current_batched_hyps = chunk_batched_hyps
else:
current_batched_hyps.merge_(chunk_batched_hyps)
Expand All @@ -446,13 +513,19 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
left_sample = right_sample
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk

# remove biasing requests from the decoder
if use_per_stream_biasing and audio_data.biasing_requests is not None:
for request in audio_data.biasing_requests:
if request is not None and request.multi_model_id is not None:
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
request.multi_model_id = None
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
# Convert batched hypotheses to list
if isinstance(current_batched_hyps, BatchedBeamHyps):
# Beam-search strategies (MALSD / MAES) return BatchedBeamHyps
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
else:
# Greedy batch returns BatchedHyps
# remove biasing requests from the decoder
if use_per_stream_biasing and audio_data.biasing_requests is not None:
for request in audio_data.biasing_requests:
if request is not None and request.multi_model_id is not None:
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
request.multi_model_id = None
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
timer.stop(device=map_location)

# convert text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1690,7 +1690,7 @@ def forward(
self.joint.eval()

inseq = encoder_output # [B, T, D]
batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen)
batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=inseq, out_len=logitlen)

batch_size = encoder_output.shape[0]
if self.return_best_hypothesis:
Expand Down
108 changes: 83 additions & 25 deletions nemo/collections/asr/parts/submodules/rnnt_maes_batched_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import torch

from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import BatchedLabelLoopingState
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import (
INACTIVE_SCORE,
Expand Down Expand Up @@ -121,6 +123,7 @@ def batched_modified_adaptive_expansion_search_torch(
self,
encoder_output: torch.Tensor,
encoder_output_length: torch.Tensor,
prev_batched_state: Optional[BatchedBeamHyps] = None,
) -> BatchedBeamHyps:
"""
Pure PyTorch implementation
Expand All @@ -134,22 +137,11 @@ def batched_modified_adaptive_expansion_search_torch(

encoder_output_projected = self.joint.project_encoder(encoder_output)
float_dtype = encoder_output_projected.dtype

# init empty batched hypotheses
batched_hyps = BatchedBeamHyps(
batch_size=batch_size,
beam_size=self.beam_size,
blank_index=self._blank_index,
init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time,
device=device,
float_dtype=float_dtype,
store_prefix_hashes=True,
)

last_labels_wb = torch.full(
[batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long
)


# import pdb; pdb.set_trace()
# encoder_output_projected = encoder_output
# float_dtype = encoder_output.dtype

batch_indices = (
torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.beam_size).clone()
) # size: batch_size x beam_size
Expand All @@ -162,6 +154,28 @@ def batched_modified_adaptive_expansion_search_torch(
.clone()
) # size: batch_size x beam_size x beam_size + maes_expansion_beta

if prev_batched_state is not None and prev_batched_state.batched_hyps is not None:
batched_hyps = prev_batched_state.batched_hyps
time_indices = torch.zeros_like(beam_indices)
last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices)
safe_time_indices = torch.minimum(time_indices, last_timesteps)
active_mask = time_indices <= last_timesteps
else:
# init empty batched hypotheses
batched_hyps = BatchedBeamHyps(
batch_size=batch_size,
beam_size=self.beam_size,
blank_index=self._blank_index,
init_length=max_time * (self.maes_num_steps + 1) if self.maes_num_steps is not None else max_time,
device=device,
float_dtype=float_dtype,
store_prefix_hashes=True,
)
time_indices = torch.zeros_like(beam_indices)
safe_time_indices = torch.zeros_like(time_indices) # time indices, guaranteed to be < out_len
last_timesteps = (encoder_output_length - 1)[:, None].expand_as(beam_indices)
active_mask = time_indices <= last_timesteps

time_indices = torch.zeros_like(batch_indices)
safe_time_indices = torch.zeros_like(time_indices)
last_timesteps = (encoder_output_length - 1)[:, None].expand(batch_size, self.beam_size)
Expand All @@ -176,14 +190,34 @@ def batched_modified_adaptive_expansion_search_torch(
) # vocab_size_no_blank
lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha

decoder_output, decoder_state, *_ = self.decoder.predict(
last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size
)
# do not recalculate joint projection
decoder_output = self.joint.project_prednet(decoder_output)
if prev_batched_state is None:
last_labels_wb = torch.full(
[batch_size, self.beam_size], fill_value=self._SOS, device=device, dtype=torch.long
)
decoder_state = self.decoder.initialize_state(
torch.empty(
[
batch_size * self.beam_size,
],
dtype=float_dtype,
device=device,
)
)

decoder_output, state, *_ = self.decoder.predict(
last_labels_wb.view(-1, 1), None, add_sos=False, batch_size=batch_size * self.beam_size
)
# do not recalculate joint projection
decoder_output = self.joint.project_prednet(decoder_output) # size: [(batch_size x beam_size), 1, Dim]
self.decoder.batch_replace_states_all(state, dst_states=decoder_state)
else:
# Continuing from the previous chunk: reuse the predictor outputs and states saved in prev_batched_state
decoder_output = prev_batched_state.predictor_outputs
decoder_state = prev_batched_state.predictor_states

while active_mask.any(): # frames loop
to_update = active_mask.clone() # mask for expansions loop
# import pdb; pdb.set_trace()

# step 1: get joint output
logits = (
Expand Down Expand Up @@ -235,6 +269,7 @@ def batched_modified_adaptive_expansion_search_torch(
next_labels = next_labels.view(batch_size, -1)[batch_indices, idx]
hyp_indices = expansion_beam_indices.view(batch_size, -1)[batch_indices, idx]

# import pdb; pdb.set_trace()
# step 3.3: update batched beam hypotheses structure
batched_hyps.add_results_(hyp_indices, next_labels, next_hyps_probs)

Expand Down Expand Up @@ -311,6 +346,7 @@ def batched_modified_adaptive_expansion_search_torch(

expansion_steps += 1
if to_update.any():
# import pdb; pdb.set_trace()
# step 4: force blank to active hypotheses
next_hyps_probs = torch.where(to_update, batched_hyps.scores + logps[..., -1], batched_hyps.scores)
next_labels = torch.where(to_update, self._blank_index, -1)
Expand All @@ -320,8 +356,29 @@ def batched_modified_adaptive_expansion_search_torch(
time_indices += 1
active_mask = time_indices <= last_timesteps
safe_time_indices = torch.where(active_mask, time_indices, last_timesteps)

return batched_hyps

# import pdb; pdb.set_trace()
last_labels = batched_hyps.get_last_labels(pad_id=self._SOS)
batched_hyps.next_timestamp.fill_(0)
decoding_state = BatchedLabelLoopingState(
predictor_states=decoder_state,
predictor_outputs=decoder_output,
labels=(
torch.where(last_labels == self._SOS, prev_batched_state.labels, last_labels)
if prev_batched_state is not None
else last_labels
),
decoded_lengths=(
encoder_output_length.clone()
if prev_batched_state is None
else encoder_output_length + prev_batched_state.decoded_lengths
),
# fusion_states_list=fusion_states_list if self.fusion_models is not None else None,
time_jumps=None,
batched_hyps=batched_hyps, # Save batched_hyps object for next chunk
)

return batched_hyps, None, decoding_state

def combine_scores(self, log_probs, lm_scores):
"""
Expand Down Expand Up @@ -438,5 +495,6 @@ def __call__(
self,
x: torch.Tensor,
out_len: torch.Tensor,
) -> BatchedBeamHyps:
return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len)
prev_batched_state: Optional[BatchedBeamHyps] = None,
) -> tuple[BatchedBeamHyps, Optional[rnnt_utils.BatchedAlignments], BatchedLabelLoopingState]:
return self.batched_modified_adaptive_expansion_search_torch(encoder_output=x, encoder_output_length=out_len, prev_batched_state=prev_batched_state)
Loading
Loading