Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import BiasingRequestItemConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_maes_batched_computer import ModifiedAESBatchedRNNTComputer
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
GreedyBatchedLabelLoopingComputerBase,
)
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.manifest_utils import filepath_to_absolute, read_manifest
from nemo.collections.asr.parts.utils.rnnt_utils import BatchedHyps, batched_hyps_to_hypotheses
Expand Down Expand Up @@ -232,9 +236,21 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

# Change Decoding Config
with open_dict(cfg.decoding):
if cfg.decoding.strategy != "greedy_batch" or cfg.decoding.greedy.loop_labels is not True:
if cfg.decoding.strategy == "greedy_batch":
if cfg.decoding.greedy.loop_labels is not True:
raise NotImplementedError(
"This script supports `greedy_batch` strategy only with Label-Looping algorithm"
)
cfg.decoding.greedy.preserve_alignments = False
elif cfg.decoding.strategy == "malsd_batch":
# MALSD beam search is supported for streaming
pass
elif cfg.decoding.strategy == "maes_batch":
# MAES beam search is supported for streaming
pass
else:
raise NotImplementedError(
"This script currently supports only `greedy_batch` strategy with Label-Looping algorithm"
"This script currently supports only `greedy_batch` with Label-Looping or `malsd_batch` strategy with MALSD"
)
cfg.decoding.tdt_include_token_duration = cfg.timestamps
cfg.decoding.greedy.preserve_alignments = False
Expand Down Expand Up @@ -278,7 +294,16 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
asr_model.preprocessor.featurizer.pad_to = 0
asr_model.eval()

decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
# Get decoding computer based on strategy. Beam-search strategies expose the
# underlying computer via the private ``_decoding_computer`` attribute.
if cfg.decoding.strategy == "greedy_batch":
decoding_computer: GreedyBatchedLabelLoopingComputerBase = asr_model.decoding.decoding.decoding_computer
elif cfg.decoding.strategy == "malsd_batch":
decoding_computer = asr_model.decoding.decoding._decoding_computer
elif cfg.decoding.strategy == "maes_batch":
decoding_computer: ModifiedAESBatchedRNNTComputer = asr_model.decoding.decoding._decoding_computer
else:
raise ValueError(f"Unsupported decoding strategy: {cfg.decoding.strategy}")

audio_sample_rate = model_cfg.preprocessor['sample_rate']

Expand Down Expand Up @@ -394,7 +419,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
)
rest_audio_lengths = audio_batch_lengths.clone()

# iterate over audio samples
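# Beam-search strategies (MALSD/MAES/TDT-MALSD) keep their batched_hyps inside
# the decoder state and reset chunk-local buffers per chunk, so their per-chunk
# hypotheses are read from ``state.batched_hyps``; greedy returns a fresh
# BatchedHyps on each call. Both kinds are accumulated across chunks into
# ``current_batched_hyps`` via ``merge_``.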
is_beam_search = isinstance(
decoding_computer,
(ModifiedALSDBatchedRNNTComputer, ModifiedAESBatchedRNNTComputer, ModifiedALSDBatchedTDTComputer),
)

# iterate over audio samples, decoding one chunk per iteration
while left_sample < audio_batch.shape[1]:
# add samples to buffer
chunk_length = min(right_sample, audio_batch.shape[1]) - left_sample
Expand Down Expand Up @@ -425,18 +458,47 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
encoder_output = encoder_output[:, encoder_context.left :]

# decode only chunk frames
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output,
out_len=torch.where(
is_last_chunk_batch,
encoder_output_len - encoder_context_batch.left,
encoder_context_batch.chunk,
),
prev_batched_state=state,
multi_biasing_ids=multi_biasing_ids,
out_len = torch.where(
is_last_chunk_batch,
encoder_output_len - encoder_context_batch.left,
encoder_context_batch.chunk,
)
# merge hyps with previous hyps
if current_batched_hyps is None:
if is_beam_search:
# Beam-search computers don't accept ``multi_biasing_ids`` yet.
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output, out_len=out_len, prev_batched_state=state
)
else:
chunk_batched_hyps, _, state = decoding_computer(
x=encoder_output,
out_len=out_len,
prev_batched_state=state,
multi_biasing_ids=multi_biasing_ids,
)

# Accumulate hypotheses across chunks.
if is_beam_search:
# ``flatten_`` resolves the chunk-local prefix tree without sorting and
# returns ``root_ptrs``: for each chunk-end beam ``i``, the beam index
# at the chunk's start it descends from. The loop body permutes beams
# via top-K, so beam ``i`` at the end of chunk ``N`` is generally a
# different logical hypothesis from beam ``i`` at the end of chunk
# ``N-1`` - it descends from beam ``root_ptrs[i]``. Threading these
# pointers via ``boundary_prev_ptr`` encodes the redirection so the
# final ``flatten_sort_`` in ``to_hyps_list`` walks back through the
# right beam history. ``is_chunk_continuation=True`` makes ``merge_``
# replace (not sum) the cumulative per-beam fields (``scores``,
# ``current_lengths_nb``, ...).
chunk_root_ptrs = state.batched_hyps.flatten_()
if current_batched_hyps is None:
current_batched_hyps = state.batched_hyps
else:
current_batched_hyps.merge_(
state.batched_hyps,
is_chunk_continuation=True,
boundary_prev_ptr=chunk_root_ptrs,
)
elif current_batched_hyps is None:
current_batched_hyps = chunk_batched_hyps
else:
current_batched_hyps.merge_(chunk_batched_hyps)
Expand All @@ -446,13 +508,19 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
left_sample = right_sample
right_sample = min(right_sample + context_samples.chunk, audio_batch.shape[1]) # add next chunk

# remove biasing requests from the decoder
if use_per_stream_biasing and audio_data.biasing_requests is not None:
for request in audio_data.biasing_requests:
if request is not None and request.multi_model_id is not None:
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
request.multi_model_id = None
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
# Convert batched hypotheses to list
if isinstance(current_batched_hyps, BatchedBeamHyps):
# Beam-search strategies (MALSD/MAES) accumulate a BatchedBeamHyps
all_hyps.extend(current_batched_hyps.to_hyps_list(score_norm=True))
else:
# Greedy batch returns BatchedHyps
# remove biasing requests from the decoder
if use_per_stream_biasing and audio_data.biasing_requests is not None:
for request in audio_data.biasing_requests:
if request is not None and request.multi_model_id is not None:
decoding_computer.biasing_multi_model.remove_model(request.multi_model_id)
request.multi_model_id = None
all_hyps.extend(batched_hyps_to_hypotheses(current_batched_hyps, None, batch_size=batch_size))
timer.stop(device=map_location)

# convert text
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1690,7 +1690,7 @@ def forward(
self.joint.eval()

inseq = encoder_output # [B, T, D]
batched_beam_hyps = self._decoding_computer(x=inseq, out_len=logitlen)
batched_beam_hyps, alignments, decoding_state = self._decoding_computer(x=inseq, out_len=logitlen)

batch_size = encoder_output.shape[0]
if self.return_best_hypothesis:
Expand Down
Loading
Loading