From 23036c5ff8953c473a89b1ad500860597fdf106e Mon Sep 17 00:00:00 2001 From: Enas Albasiri Date: Wed, 6 May 2026 14:23:53 +0000 Subject: [PATCH 1/9] Add unified langID support for Hybrid and RNNT --- .../speech_to_text_rnnt_bpe_prompt.py | 95 +++ ...d_transducer_ctc_bpe_streaming_prompt.yaml | 426 +++++++++++ .../data/audio_to_text_lhotse_prompt_index.py | 147 ++++ nemo/collections/asr/models/__init__.py | 2 + .../hybrid_rnnt_ctc_bpe_models_prompt.py | 124 ++-- .../asr/models/rnnt_bpe_models_prompt.py | 681 ++++++++++++++++++ 6 files changed, 1415 insertions(+), 60 deletions(-) create mode 100644 examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py create mode 100644 examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml create mode 100644 nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py create mode 100644 nemo/collections/asr/models/rnnt_bpe_models_prompt.py diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py new file mode 100644 index 000000000000..cd1a76d90e57 --- /dev/null +++ b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Preparing the Tokenizer for the dataset +Use the `process_asr_text_tokenizer.py` script under /scripts/tokenizers/ in order to prepare the tokenizer. + +# Manifest file example: +{"audio_filepath":"/data/audio.wav","duration":12.12,"text":"The transcript.","target_lang":"en-US"} + +```sh +python /scripts/tokenizers/process_asr_text_tokenizer.py \ + --manifest= + OR + --data_file= \ + --data_root="" \ + --vocab_size= \ + --tokenizer=<"spe" or "wpe"> \ + --no_lower_case \ + --spe_type=<"unigram", "bpe", "char" or "word"> \ + --spe_character_coverage=1.0 \ + --log +``` + +# Training the model +```sh +python speech_to_text_rnnt_bpe_prompt.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.type= \ + trainer.devices=-1 \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + +# Fine-tune a model + +For documentation on fine-tuning this model, please visit - +https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations + +""" + +import lightning.pytorch as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecRNNTBPEModelWithPrompt +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + + +@hydra_runner( + config_path="../conf/fastconformer/hybrid_cache_aware_streaming/", + config_name="fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt_600m.yaml", +) +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecRNNTBPEModelWithPrompt(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml new file mode 100644 index 000000000000..4c933d9e40e8 --- /dev/null +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml @@ -0,0 +1,426 @@ +# Cache-aware streaming FastConformer-Hybrid-Transducer-CTC ASR model with prompt support +# Combines cache-aware streaming encoder with prompt-based multilingual capability + +# You may find more detail: +# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer +# Hybrid ASR: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#hybrid-transducer-ctc +# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer + +name: "FastConformer-Hybrid-Transducer-CTC-BPE-Prompt-Streaming" + +model: + sample_rate: 16000 + compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + # Prompt configuration + initialize_prompt_feature: true + num_prompts: 128 + norm: None + # Dictionary mapping prompt identifiers to their corresponding embedding indices + prompt_dictionary: { + # Language prompts (0-99) + 'en-US': 0, + 'en': 0, + 'en-GB': 1, + 'enGB': 1, + 'es-ES': 2, + 'esES': 2, + 'es-US': 3, + 'es': 3, + 'zh-CN': 4, + 'zh-ZH': 4, + 'zh-TW': 5, + 'hi-IN': 6, + 'hi': 6, + 'hi-HI': 6, + 'ar-AR': 7, + 'ar': 7, + 'fr-FR': 8, + 'fr': 8, + 'de-DE': 9, + 'de': 9, + 'ja-JP': 10, + 'ja-JA': 10, + 'ru-RU': 11, + 'ru': 11, + 'pt-BR': 12, + 'pt-PT': 13, + 'pt': 13, + 'ko-KR': 14, + 'ko': 14, + 'ko-KO': 14, + 'it-IT': 15, + 'it': 15, + 'nl-NL': 16, + 'nl': 16, + 'pl-PL': 17, + 'pl': 17, + 'tr-TR': 18, + 'tr': 18, + 'uk-UA': 19, + 'uk': 19, + 'ro-RO': 20, + 'ro': 20, + 'el-GR': 21, + 'el': 21, + 'cs-CZ': 22, + 'cs': 22, + 'hu-HU': 23, + 'hu': 23, + 'sv-SE': 24, + 'sv': 24, + 'da-DK': 25, + 'da': 25, + 'fi-FI': 26, + 'fi': 26, + 'no-NO': 27, + 'no': 27, + 'nb-NO': 103, + 'nb': 103, + 'sk-SK': 28, + 'sk': 28, + 'hr-HR': 29, + 'hr': 29, + 'bg-BG': 30, + 'bg': 30, + 'lt-LT': 31, + 'lt': 31, + # Granary languages (60-62) + 'et-EE': 60, + 'et': 60, + 'lv-LV': 61, + 'lv': 61, + 'sl-SI': 62, + 'sl': 62, + 'th-TH': 32, + 'vi-VN': 33, + 'id-ID': 34, + 'ms-MY': 35, + 'bn-IN': 36, + 'ur-PK': 37, + 'fa-IR': 38, + 'ta-IN': 39, + 'te-IN': 40, + 'mr-IN': 41, + 'gu-IN': 42, + 'kn-IN': 43, + 'ml-IN': 44, + 'si-LK': 45, + 'ne-NP': 46, + 'km-KH': 47, + 'sw-KE': 48, + 'am-ET': 49, + 'ha-NG': 50, + 'zu-ZA': 51, + 'yo-NG': 52, + 'ig-NG': 53, + 'af-ZA': 54, + 'rw-RW': 55, + 'so-SO': 56, + 'ny-MW': 57, + 'ln-CD': 58, + 'or-KE': 59, + 'he-IL': 64, + 'ku-TR': 65, + 'az-AZ': 66, + 'ka-GE': 67, + 'hy-AM': 68, + 'uz-UZ': 69, + 'tg-TJ': 70, + 'ky-KG': 71, + 'qu-PE': 80, + 'ay-BO': 81, + 'gn-PY': 82, + 'nah-MX': 83, + 'mi-NZ': 96, + 'haw-US': 97, + 'sm-WS': 98, + 'to-TO': 99, + 'fr-CA': 100, + 'mt-MT': 102, + 'auto': 101 + } + + train_ds: + manifest_filepath: /config/asr_P0_P1_riva_training_data_v2_fixed_spacing_with_lang_tag.yaml + sample_rate: ${model.sample_rate} + use_lhotse: true + shard_manifests: true + batch_duration: 400 + quadratic_duration: 15 + num_buckets: 30 + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 39.99 # you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: true + tarred_audio_filepaths: null + shuffle_n: 2048 + slice_length: 100 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + # prompt configs + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang + training_mode: true # 50% use auto (101), 50% use actual lang ID + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 2 + shuffle: false + use_start_end_token: false + num_workers: 2 + pin_memory: true + batch_duration: null + use_lhotse: true + use_bucketing: false + max_cuts: 8 + # prompt configurations for validation + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: true # pass lang ID 50% + + test_ds: + manifest_filepath: '/manifests/fleurs/fleurs_emea_emea_lhotse.dev.ast_update_lang_id_cross_only_unique_.json' + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + use_start_end_token: false + num_workers: 8 + pin_memory: true + use_lhotse: true + use_bucketing: false + # prompt configurations for testing + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: false # Always use actual lang ID during testing + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" # No normalization for mel-spectogram makes streaming easier + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 42 + d_model: 1024 + use_bias: false + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: true # Required for streaming + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + # For multi-lookahead models, you may specify a list of context sizes. + # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + att_context_size: [70, 6] # look-ahead = 6*8*0.01 = 0.48s + att_context_probs: null + att_context_style: chunked_limited # regular or chunked_limited + xscaling: false + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: causal # Required for streaming + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null # Currently only null is supported for export. + random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf + blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 2 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null # 'null' would set it automatically according to CPU/GPU device + preserve_memory: false # dramatically slows down training, but might preserve some memory + + # Fuses the computation of prediction net + joint net + loss + WER calculation + # to be run on sub-batches of size `fused_batch_size`. + fuse_loss_wer: true + fused_batch_size: 2 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. + + # greedy strategy config + greedy: + max_symbols: 10 + + # beam strategy config + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 # for Time Synchronous Decoding + alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + + # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder + aux_ctc: + ctc_loss_weight: 0.1 # the weight used to combine the CTC loss with the RNNT loss + use_cer: false + ctc_reduction: 'mean_batch' + decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: null + num_classes: -1 + vocabulary: [] + decoding: + strategy: "greedy" + + # config for InterCTC loss: https://arxiv.org/abs/2102.03216 + interctc: + loss_weights: [] + apply_at_layers: [] + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + # FastEmit regularization: https://arxiv.org/abs/2010.11148 + # You may enable FastEmit to reduce the latency of the model for streaming + fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. + + # Adds Gaussian noise to the gradients of the decoder to avoid overfitting + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 2.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: 500000 # computed at runtime if not set + val_check_interval: 0.5 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 0.5 + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 100 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + use_distributed_sampler: false + limit_train_batches: 1000 + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py new file mode 100644 index 000000000000..35bd9a090790 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simplified Lhotse dataset that returns language ID indices instead of full prompt tensors. +The model creates the prompt tensor using the actual encoded length. +""" + +import random +from typing import Dict, Optional, Tuple + +import torch +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors + +from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + + +class LhotseSpeechToTextBpeDatasetWithPromptIndex(torch.utils.data.Dataset): + """ + Simplified dataset class for speech-to-text with prompt support. + + Instead of computing full prompt tensors, this dataset returns just the + language ID index per sample. The model creates the prompt tensor using + the actual encoder output length, guaranteeing no size mismatch. + + Returns: + audio_signal: Audio waveform [B, T] + audio_signal_length: Audio lengths [B] + transcripts: Token IDs [B, T] + transcript_length: Token lengths [B] + prompt_indices: Language ID indices [B] (NOT full tensors) + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'audio_signal_length': NeuralType(tuple('B'), LengthsType()), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'prompt_indices': NeuralType(tuple('B'), LabelsType()), # Just indices, not full tensors + } + + def __init__(self, tokenizer, cfg): + super().__init__() + self.tokenizer = TokenizerWrapper(tokenizer) + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + + # Load prompt dictionary from config + self.prompt_dict = cfg.get('prompt_dictionary') + if not self.prompt_dict: + raise ValueError("prompt_dictionary is required in config") + + self.num_prompts = cfg.get('num_prompts', 128) + + # Field to use for prompt key (default to 'target_lang') + self.prompt_field = cfg.get('prompt_field', 'target_lang') + + # Training mode flag: when True, randomly use auto (101) 50% of the time + self.training_mode = cfg.get('training_mode', True) + + logging.info(f"LhotseSpeechToTextBpeDatasetWithPromptIndex: Returns indices only, model creates prompt tensor") + + def _get_prompt_index(self, prompt_key: str) -> int: + """Maps prompt keys to indices using the prompt dictionary.""" + if prompt_key not in self.prompt_dict: + available_keys = list(self.prompt_dict.keys()) + raise ValueError( + f"Unknown prompt key: '{prompt_key}'. Available: {available_keys[:10]}{'...' if len(available_keys) > 10 else ''}" + ) + return self.prompt_dict[prompt_key] + + def _get_prompt_index_for_cut(self, cut) -> int: + """ + Get prompt index for a cut, with training mode randomization. + During training: 50% chance to use auto (101), 50% actual language ID + During inference: always use the actual language ID + """ + if self.training_mode and random.random() < 0.5: + return 101 # Auto/language-agnostic + else: + return self._get_prompt_index(cut.supervisions[0].language) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts] + + # Get prompt indices (just the language ID per sample, NOT full tensors) + prompt_indices = torch.tensor( + [self._get_prompt_index_for_cut(c) for c in cuts], + dtype=torch.long + ) + + # Create final tensors + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=0) + + return ( + audio, # Audio signal [B, T] + audio_lens, # Audio lengths [B] + tokens, # Text tokens [B, T] + token_lens, # Token lengths [B] + prompt_indices, # Language ID indices [B] - model creates full tensor + ) + + +class TokenizerWrapper: + """Provide a unified interface for NeMo Tokenizer, AggregateTokenizer, and (char) Parser.""" + + def __init__(self, tokenizer): + self._tokenizer = tokenizer + if isinstance(tokenizer, AggregateTokenizer): + self._impl = self._call_agg_tokenizer + elif isinstance(tokenizer, TokenizerSpec): + self._impl = self._call_tokenizer + else: + self._impl = self._call_parser + + def __call__(self, text: str, lang: Optional[str] = None): + return self._impl(text, lang) + + def _call_agg_tokenizer(self, text: str, lang: Optional[str] = None): + assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer." + return self._tokenizer.text_to_ids(text, lang) + + def _call_tokenizer(self, text: str, lang: Optional[str] = None): + return self._tokenizer.text_to_ids(text) + + def _call_parser(self, text: str, lang: Optional[str] = None): + return self._tokenizer(text) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index cc9b3a74e1ea..67a99268175b 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -30,6 +30,7 @@ from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel # noqa: F401 from nemo.collections.asr.models.multitalker_asr_models import EncDecMultiTalkerRNNTBPEModel # noqa: F401 from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel # noqa: F401 +from nemo.collections.asr.models.rnnt_bpe_models_prompt import EncDecRNNTBPEModelWithPrompt # noqa: F401 from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel # noqa: F401 from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel # noqa: F401 from nemo.collections.asr.models.ssl_models import ( # noqa: F401 @@ -55,6 +56,7 @@ 'EncDecMultiTaskModel', 'EncDecMultiTalkerRNNTBPEModel', 'EncDecRNNTBPEModel', + 'EncDecRNNTBPEModelWithPrompt', 'EncDecRNNTModel', 'EncDecSpeakerLabelModel', 'EncDecTransfModelBPE', diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 992de84fc7a8..82a68e020cac 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -26,6 +26,7 @@ from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.data.audio_to_text_lhotse_prompt import LhotseSpeechToTextBpeDatasetWithPrompt +from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex from nemo.collections.asr.metrics.bleu import BLEU from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel @@ -55,8 +56,9 @@ class HybridRNNTCTCPromptTranscribeConfig(TranscribeConfig): Configuration for Hybrid RNNT-CTC BPE Model with Prompt Transcription """ - target_lang: str = "en-US" - prompt_field: str = "lang" + target_lang: str = "auto" + prompt_field: str = "target_lang" + class EncDecHybridRNNTCTCBPEModelWithPrompt(EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): @@ -108,9 +110,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup prompt settings - default to 128 prompts if not specified cfg.num_prompts = cfg.model_defaults.get('num_prompts', 128) - # Make sure prompt_dictionary exists if 'prompt_dictionary' not in cfg.model_defaults: - raise ValueError("No prompt_dictionary found in config.") + logging.warning( + "No prompt_dictionary in config; using empty dict " + "(expected during checkpoint restoration)." + ) + cfg.model_defaults.prompt_dictionary = {} # Set subsampling_factor in a place accessible to the class self.subsampling_factor = cfg.get('subsampling_factor', 8) @@ -196,11 +201,15 @@ def initialize_prompt_feature(self): def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): - dataset = LhotseSpeechToTextBpeDatasetWithPrompt(tokenizer=self.tokenizer, cfg=config) - logging.info("Setting up Lhotse dataset with prompt support") + # Use index-based dataset - returns prompt indices instead of full tensors + # The model creates prompt tensors after encoding, guaranteeing no size mismatch + dataset_config = OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) + if hasattr(self, 'cfg') and 'encoder' in self.cfg: + dataset_config['encoder'] = OmegaConf.to_container(self.cfg.encoder, resolve=True) if isinstance(self.cfg.encoder, DictConfig) else dict(self.cfg.encoder) + dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex(tokenizer=self.tokenizer, cfg=dataset_config) + logging.info("Setting up Lhotse dataset with prompt index support (model creates prompt tensors)") else: dataset = LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer) - logging.info("Setting up Lhotse dataset without prompt support") return get_lhotse_dataloader_from_config( config, global_rank=self.global_rank, @@ -318,24 +327,24 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT The model's outputs that are processed by `_transcribe_output_processing()`. """ # Handling DataLoader batch - should be a tuple of tensors - # Expected structure: (audio, audio_lens, tokens, token_lens, prompt_targets) - # For transcription, we may only have (audio, audio_lens) or (audio, audio_lens, ..., prompt_targets) + # Expected structure: (audio, audio_lens, tokens, token_lens, prompt_indices) + # For transcription, we may only have (audio, audio_lens) or (audio, audio_lens, ..., prompt_indices) audio, audio_lens = batch[0], batch[1] if len(batch) >= 5: - # Prompt provided by the dataloader (one-hot vectors) - prompt = batch[4] # This should be the prompt_targets from dataset + # Prompt indices provided by the dataloader (language ID indices) + prompt_indices = batch[4] # This should be the prompt_indices from dataset else: # Prompt to be built dynamically. - prompt = None + prompt_indices = None batch_size = audio.shape[0] - if prompt is None: + if prompt_indices is None: # The dataloader provided only audio + audio_lens, so we need to construct - # the prompt as one-hot vectors dynamically using TranscribeConfig. + # the prompt indices dynamically using TranscribeConfig. target_lang = trcfg.target_lang - # Get prompt dictionary and num_prompts from model config + # Get prompt dictionary from model config prompt_dict = self.cfg.model_defaults.get('prompt_dictionary') num_prompts = self.cfg.model_defaults.get('num_prompts', 128) @@ -351,25 +360,13 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT prompt_id = prompt_dict[target_lang] - # Preprocess audio to get the actual feature dimensions (like streaming does) - processed_signal, processed_signal_length = self.preprocessor(input_signal=audio, length=audio_lens) + # Create prompt index tensor for the batch - forward() will create the one-hot prompt + # from these indices using the actual encoder output length + prompt_indices = torch.full((batch_size,), prompt_id, dtype=torch.long, device=audio.device) - # Calculate exact hidden length using the same approach as streaming - time_length = processed_signal.shape[2] # Feature time dimension - subsampling_factor = self.cfg.get('subsampling_factor', 8) - hidden_length = math.ceil(time_length / subsampling_factor) - - # Create one-hot prompt tensor: (batch_size, time_steps, num_prompts) - prompt = torch.zeros(batch_size, hidden_length, num_prompts, dtype=torch.float32, device=audio.device) - prompt[:, :, prompt_id] = 1.0 # Set the target language prompt to 1 - - # Now call forward with preprocessed signal and prompt - encoded, encoded_len = self.forward( - processed_signal=processed_signal, processed_signal_length=processed_signal_length, prompt=prompt - ) - else: - # Prompt was provided, use normal forward path - encoded, encoded_len = self.forward(input_signal=audio, input_signal_length=audio_lens, prompt=prompt) + # Call forward with prompt_indices - the model creates prompt tensors after encoding + # This guarantees no size mismatch between encoded features and prompt tensors + encoded, encoded_len = self.forward(input_signal=audio, input_signal_length=audio_lens, prompt_indices=prompt_indices) # Prepare output dictionary based on decoder type if self.cur_decoder == "rnnt": @@ -383,6 +380,7 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT return output + @torch.no_grad() def transcribe( self, @@ -458,7 +456,7 @@ def transcribe( # Create transcription config if not provided if override_config is None: # Extract target_lang from prompt or use default - target_lang = prompt.get('target_lang', 'en-US') + target_lang = prompt.get('target_lang', 'auto') prompt_field = prompt.get('prompt_field', 'target_lang') trcfg = HybridRNNTCTCPromptTranscribeConfig( @@ -507,7 +505,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), - "prompt": NeuralType(('B', 'T', 'D'), LabelsType()), + "prompt_indices": NeuralType(tuple('B'), LabelsType()), # Language ID indices per sample } @property @@ -524,20 +522,12 @@ def forward( input_signal_length=None, processed_signal=None, processed_signal_length=None, - prompt=None, + prompt_indices=None, ): """ Forward pass of the model. Note that for RNNT Models, the forward pass of the model is a 3 step process, and this method only performs the first step - forward of the acoustic model. - Please refer to the `training_step` in order to see the full `forward` step for training - which - performs the forward of the acoustic model, the prediction network and then the joint network. - Finally, it computes the loss and possibly compute the detokenized text via the `decoding` step. - - Please refer to the `validation_step` in order to see the full `forward` step for inference - which - performs the forward of the acoustic model, the prediction network and then the joint network. - Finally, it computes the decoded tokens via the `decoding` step and possibly compute the batch metrics. - Args: input_signal: Tensor that represents a batch of raw audio signals, of shape [B, T]. T here represents timesteps, with 1 second of audio represented as @@ -548,13 +538,13 @@ def forward( of shape (B, D, T) that has undergone processing via some DALI preprocessor. processed_signal_length: Vector of length B, that contains the individual lengths of the processed audio sequences. - prompt: Tensor that represents the prompt embeddings, - of shape (B, T, D) where D is the number of supported prompts. - Used for prompt-conditioned encoding via concatenation with acoustic features. + prompt_indices: Tensor of shape [B] containing language ID indices per sample. + The model creates the prompt tensor after encoding using the actual + encoder output length, guaranteeing no size mismatch. Returns: A tuple of 2 elements - - 1) The log probabilities tensor of shape [B, T, D]. + 1) The encoded tensor of shape [B, D, T]. 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. """ has_input_signal = input_signal is not None and input_signal_length is not None @@ -579,17 +569,26 @@ def forward( encoded = torch.transpose(encoded, 1, 2) # B * D * T -> B * T * D if self.concat: - if prompt.shape[1] > encoded.shape[1]: - prompt = prompt[:, : encoded.shape[1], :] - out_dtype = encoded.dtype # this is dtype, which the decoder previously got from encoder + if prompt_indices is None: + raise ValueError("prompt_indices must be provided when concat mode is enabled.") + + # Create prompt tensor from indices using actual encoded length + batch_size = encoded.shape[0] + time_steps = encoded.shape[1] + num_prompts = self.num_prompts + + # Create one-hot prompt tensor: (batch_size, time_steps, num_prompts) + prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) + # Vectorized scatter: set each sample's language column to 1.0 (no Python loop) + prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype # Concatenate encoded states with prompt concat_enc_states = torch.cat([encoded, prompt], dim=-1) # Apply joint projection - encoded = self.prompt_kernel(concat_enc_states).to( - out_dtype - ) # cast: unexpectedly without cast dtype is different from out_dtype + encoded = self.prompt_kernel(concat_enc_states).to(out_dtype) encoded = torch.transpose(encoded, 1, 2) # B * T * D -> B * D * T return encoded, encoded_len @@ -602,13 +601,15 @@ def training_step(self, batch, batch_nb): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_indices = batch # forward() only performs encoder forward + # prompt_indices contains language ID indices [B], not full tensors + # The model creates prompt tensors after encoding using actual encoder output length if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) del signal # During training, loss must be computed, so decoder forward is necessary @@ -714,13 +715,14 @@ def training_step(self, batch, batch_nb): return {'loss': loss_value} def predict_step(self, batch, batch_idx, dataloader_idx=0): - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_indices = batch # forward() only performs encoder forward + # prompt_indices contains language ID indices [B], not full tensors if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) del signal if self.cur_decoder == 'rnnt': @@ -744,13 +746,14 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt = batch + signal, signal_len, transcript, transcript_len, prompt_indices = batch # forward() only performs encoder forward + # prompt_indices contains language ID indices [B], not full tensors if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt) + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) del signal tensorboard_logs = {} @@ -1010,3 +1013,4 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: List of available pre-trained models. """ return None + diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py new file mode 100644 index 000000000000..0b51919ba5a6 --- /dev/null +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -0,0 +1,681 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass +from math import ceil +from typing import Dict, List, Optional, Union + +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset +from nemo.collections.asr.data.audio_to_text_lhotse_prompt import LhotseSpeechToTextBpeDatasetWithPrompt +from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex +from nemo.collections.asr.metrics.wer import WER +from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel +from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType +from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import ( + AcousticEncodedRepresentation, + AudioSignal, + LabelsType, + LengthsType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging, model_utils + + +@dataclass +class RNNTPromptTranscribeConfig(TranscribeConfig): + target_lang: str = "auto" + prompt_field: str = "target_lang" + + +class EncDecRNNTBPEModelWithPrompt(EncDecRNNTBPEModel, ASRTranscriptionMixin): + """Encoder-decoder RNNT model with subword tokenization and prompt conditioning. + + This is the RNNT-only variant (no auxiliary CTC head) of the prompt-aware + cache-aware streaming model. The prompt mechanism concatenates a language-ID + one-hot vector to the encoder output and projects back to the original + dimension, allowing the decoder to condition on the target language. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + cfg = model_utils.convert_model_config_to_dict_config(cfg) + cfg = model_utils.maybe_update_config_version(cfg) + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + if not isinstance(cfg, DictConfig): + cfg = OmegaConf.create(cfg) + + self._setup_tokenizer(cfg.tokenizer) + + vocabulary = self.tokenizer.tokenizer.get_vocab() + + with open_dict(cfg): + cfg.labels = ListConfig(list(vocabulary)) + + with open_dict(cfg.decoder): + cfg.decoder.vocab_size = len(vocabulary) + + with open_dict(cfg.joint): + cfg.joint.num_classes = len(vocabulary) + cfg.joint.vocabulary = ListConfig(list(vocabulary)) + cfg.joint.jointnet.encoder_hidden = cfg.model_defaults.enc_hidden + cfg.joint.jointnet.pred_hidden = cfg.model_defaults.pred_hidden + + with open_dict(cfg): + cfg.num_prompts = cfg.model_defaults.get('num_prompts', 128) + + if 'prompt_dictionary' not in cfg.model_defaults: + raise ValueError( + "No prompt_dictionary found in config. " + "Please make sure your config has a prompt_dictionary in model_defaults." + ) + + self.subsampling_factor = cfg.get('subsampling_factor', 8) + + super().__init__(cfg=cfg, trainer=trainer) + + self.concat = False + + if self.cfg.model_defaults.get('initialize_prompt_feature', False): + self.initialize_prompt_feature() + + @classmethod + def restore_from( + cls, + restore_path, + override_config_path=None, + map_location=None, + strict=True, + return_config=False, + save_restore_connector=None, + trainer=None, + validate_access_integrity=True, + ): + """Delegate to base EncDecRNNTBPEModel to avoid subclass substitution. + + NeMo's from_config_dict checks issubclass(cls, checkpoint_target_cls) + and, when True, replaces the checkpoint class with cls. Because this + class is a direct subclass of the checkpoint's target class + (EncDecRNNTBPEModel), the substitution would try to fully instantiate + EncDecRNNTBPEModelWithPrompt with the checkpoint config — which lacks + prompt_dictionary and hangs. Delegating to the parent class keeps + cls == EncDecRNNTBPEModel so the checkpoint is loaded with its own + class, matching the behaviour that naturally occurs for hybrid models. + """ + return EncDecRNNTBPEModel.restore_from( + restore_path=restore_path, + override_config_path=override_config_path, + map_location=map_location, + strict=strict, + return_config=return_config, + save_restore_connector=save_restore_connector, + trainer=trainer, + validate_access_integrity=validate_access_integrity, + ) + + def initialize_prompt_feature(self): + """Initialize model components for prompt feature via concatenation.""" + logging.info("Model with prompt feature has been initialized (RNNT-only)") + + self.concat = True + self.num_prompts = self.cfg.get('num_prompts', 128) + + proj_in_size = self.num_prompts + self._cfg.model_defaults.enc_hidden + proj_out_size = self._cfg.model_defaults.enc_hidden + + self.prompt_kernel = torch.nn.Sequential( + torch.nn.Linear(proj_in_size, proj_out_size * 2), + torch.nn.ReLU(), + torch.nn.Linear(proj_out_size * 2, proj_out_size), + ) + + self.decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, + ) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self.cfg.get('use_cer', False), + log_prediction=self.cfg.get('log_prediction', True), + dist_sync_on_step=True, + ) + + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + # ------------------------------------------------------------------ + # Data loading + # ------------------------------------------------------------------ + def _setup_dataloader_from_config(self, config: Optional[Dict]): + if config.get("use_lhotse"): + if config.get('initialize_prompt_feature', True): + dataset_config = ( + OmegaConf.to_container(config, resolve=True) + if isinstance(config, DictConfig) + else dict(config) + ) + if hasattr(self, 'cfg') and 'encoder' in self.cfg: + dataset_config['encoder'] = ( + OmegaConf.to_container(self.cfg.encoder, resolve=True) + if isinstance(self.cfg.encoder, DictConfig) + else dict(self.cfg.encoder) + ) + dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex( + tokenizer=self.tokenizer, cfg=dataset_config + ) + logging.info( + "Setting up Lhotse dataset with prompt index support (RNNT-only model creates prompt tensors)" + ) + else: + dataset = LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer) + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=dataset, + tokenizer=self.tokenizer, + ) + + dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + tokenizer=self.tokenizer, + preprocessor_cfg=self.cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + if isinstance(dataset, AudioToBPEDALIDataset): + return dataset + + shuffle = config['shuffle'] + if isinstance(dataset, torch.utils.data.IterableDataset): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + collate_fn = dataset.datasets[0].collate_fn + else: + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + target_lang = config.get('target_lang', 'en-US') + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.preprocessor._sample_rate, + 'labels': self.joint.vocabulary, + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + 'use_lhotse': config.get('use_lhotse', True), + 'use_bucketing': False, + 'drop_last': False, + 'prompt_field': config.get('prompt_field', 'target_lang'), + 'initialize_prompt_feature': True, + 'prompt_dictionary': self.cfg.model_defaults.get('prompt_dictionary'), + 'num_prompts': self.cfg.model_defaults.get('num_prompts', 128), + 'subsampling_factor': self.cfg.get('subsampling_factor', 8), + 'default_lang': target_lang, + 'window_stride': self.cfg.preprocessor.get('window_stride', 0.01), + } + + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + return self._setup_dataloader_from_config(config=DictConfig(dl_config)) + + def setup_training_data(self, train_data_config: Optional[DictConfig]): + self._update_dataset_config(dataset_name='train', config=train_data_config) + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + self._update_dataset_config(dataset_name='validation', config=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + self._update_dataset_config(dataset_name='test', config=test_data_config) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "prompt_indices": NeuralType(tuple('B'), LabelsType()), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + prompt_indices=None, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) is False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded = torch.transpose(encoded, 1, 2) # B x D x T -> B x T x D + + if self.concat: + if prompt_indices is None: + raise ValueError("prompt_indices must be provided when concat mode is enabled.") + + batch_size = encoded.shape[0] + time_steps = encoded.shape[1] + num_prompts = self.num_prompts + + prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) + prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + concat_enc_states = torch.cat([encoded, prompt], dim=-1) + encoded = self.prompt_kernel(concat_enc_states).to(out_dtype) + + encoded = torch.transpose(encoded, 1, 2) # B x T x D -> B x D x T + return encoded, encoded_len + + # ------------------------------------------------------------------ + # Training + # ------------------------------------------------------------------ + def training_step(self, batch, batch_nb): + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + del signal + + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_nb + + if not self.joint.fuse_loss_wer: + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + loss_value = self.add_auxiliary_losses(loss_value) + + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if (sample_id + 1) % log_every_n_steps == 0: + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + _, scores, words = self.wer.compute() + self.wer.reset() + tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) + loss_value = self.add_auxiliary_losses(loss_value) + + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if compute_wer: + tensorboard_logs.update({'training_batch_wer': wer}) + + self.log_dict(tensorboard_logs) + + if self._optim_normalize_joint_txu: + self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()] + + return {'loss': loss_value} + + # ------------------------------------------------------------------ + # Validation / Test + # ------------------------------------------------------------------ + def validation_pass(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + del signal + + tensorboard_logs = {} + + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + tensorboard_logs['val_loss'] = loss_value + + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + else: + compute_wer = True + + if self.compute_eval_loss: + decoded, target_len, states = self.decoder(targets=transcript, target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + ) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return tensorboard_logs + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx) + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) + else: + self.validation_step_outputs.append(tensorboard_logs) + return tensorboard_logs + + def test_step(self, batch, batch_idx, dataloader_idx=0): + logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) + test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(test_logs) + else: + self.test_step_outputs.append(test_logs) + return test_logs + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, transcript, transcript_len, prompt_indices = batch + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) + else: + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices + ) + del signal + + best_hyp = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) + + batch_size = signal_len.shape[0] + sample_id = torch.arange(batch_idx * batch_size, (batch_idx + 1) * batch_size).cpu().detach().numpy() + + return list(zip(sample_id, best_hyp)) + + # ------------------------------------------------------------------ + # Transcription + # ------------------------------------------------------------------ + def _transcribe_forward(self, batch, trcfg: RNNTPromptTranscribeConfig) -> dict: + audio, audio_lens = batch[0], batch[1] + if len(batch) >= 5: + prompt_indices = batch[4] + else: + prompt_indices = None + + batch_size = audio.shape[0] + + if prompt_indices is None: + target_lang = trcfg.target_lang + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary') + num_prompts = self.cfg.model_defaults.get('num_prompts', 128) + + if not prompt_dict: + raise ValueError("Prompt dictionary is empty. Cannot create dynamic prompts.") + + if target_lang not in prompt_dict: + available_keys = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language: '{target_lang}'. " + f"Available languages: {available_keys[:10]}{'...' if len(available_keys) > 10 else ''}" + ) + + prompt_id = prompt_dict[target_lang] + prompt_indices = torch.full((batch_size,), prompt_id, dtype=torch.long, device=audio.device) + + encoded, encoded_len = self.forward( + input_signal=audio, input_signal_length=audio_lens, prompt_indices=prompt_indices + ) + + return dict(encoded=encoded, encoded_len=encoded_len) + + @torch.no_grad() + def transcribe( + self, + audio: List[str], + batch_size: int = 4, + return_hypotheses: bool = False, + partial_hypothesis: Optional[List['Hypothesis']] = None, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + augmentor: DictConfig = None, + verbose: bool = True, + timestamps: Optional[bool] = None, + override_config: Optional[RNNTPromptTranscribeConfig] = None, + **prompt, + ) -> TranscriptionReturnType: + if timestamps is not None: + decoding_cfg = self.cfg.decoding + if timestamps or (override_config is not None and override_config.timestamps): + return_hypotheses = True + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = True + decoding_cfg.preserve_alignments = True + else: + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = False + decoding_cfg.preserve_alignments = False + self.change_decoding_strategy(decoding_cfg, verbose=False) + + if override_config is None: + target_lang = prompt.get('target_lang', 'auto') + prompt_field = prompt.get('prompt_field', 'target_lang') + + trcfg = RNNTPromptTranscribeConfig( + batch_size=batch_size, + return_hypotheses=return_hypotheses, + num_workers=num_workers, + channel_selector=channel_selector, + augmentor=augmentor, + verbose=verbose, + timestamps=timestamps, + target_lang=target_lang, + prompt_field=prompt_field, + ) + else: + if not isinstance(override_config, RNNTPromptTranscribeConfig): + raise ValueError( + f"override_config must be of type {RNNTPromptTranscribeConfig}, " + f"but got {type(override_config)}" + ) + trcfg = override_config + + return super().transcribe( + audio=audio, + batch_size=batch_size, + return_hypotheses=return_hypotheses, + partial_hypothesis=partial_hypothesis, + num_workers=num_workers, + channel_selector=channel_selector, + augmentor=augmentor, + verbose=verbose, + timestamps=timestamps, + override_config=trcfg, + ) + + @classmethod + def get_transcribe_config(cls) -> RNNTPromptTranscribeConfig: + return RNNTPromptTranscribeConfig() + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + return None From abbdeb71b28ee2cbbc864a58fa48e47bda46da55 Mon Sep 17 00:00:00 2001 From: Enas Albasiri Date: Wed, 6 May 2026 14:23:53 +0000 Subject: [PATCH 2/9] Add training guide doc --- docs/TRAINING_GUIDE.md | 113 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 docs/TRAINING_GUIDE.md diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md new file mode 100644 index 000000000000..d3a030291bd7 --- /dev/null +++ b/docs/TRAINING_GUIDE.md @@ -0,0 +1,113 @@ +# Unified Language-ID ASR Training + +This guide covers how to train cache-aware streaming ASR models with unified language-ID. + +- NeMo container: `nvcr.io/nvidia/nemo:25.11.01` +- NeMo branch: https://github.com/ealbasiri/NeMo/tree/unified_archetecture_langid + +## Overview + +There are two model variants: + +| Variant | Training Script | Model Class | Description | +|---------|----------------|-------------|-------------| +| **Hybrid RNNT+CTC** | `speech_to_text_hybrid_rnnt_ctc_bpe_prompt.py` | `EncDecHybridRNNTCTCBPEModelWithPrompt` | Trains with both RNNT and CTC losses (CTC weight = 0.1 recommended). | +| **RNNT-only** | `speech_to_text_rnnt_bpe_prompt.py` | `EncDecRNNTBPEModelWithPrompt` | Trains with RNNT loss only. | + +Both variants use the same config, tokenizer, and data format. + +## Data Format + +### Manifest files must include target_lang or lang/lang field + +```json +{"audio_filepath": "/data/audio/sample.wav", "duration": 5.2, "text": "The transcript of the audio.", "target_lang": "en-US"} +``` + +`target_lang` Language ID used as the prompt (e.g., `en-US`, `ar-AR`, `de-DE`). Mapped to a prompt index via the `prompt_dictionary` in the config. + +Multi-language example: + +```json +{"audio_filepath": "/data/en/audio_001.wav", "duration": 5.7, "text": "No, I don't think I need any further assistance.", "target_lang": "en-US"} +{"audio_filepath": "/data/ko/audio_002.wav", "duration": 5.9, "text": "선생님들 모두 어 우리 민서가 수학에 대한 흥미는", "target_lang": "ko-KR"} +{"audio_filepath": "/data/ar/audio_003.wav", "duration": 9.3, "text": "وللإشارة فإن هذا الاتفاق لا يعد اتفاقا منفصلا", "target_lang": "ar-AR"} +{"audio_filepath": "/data/fr/audio_004.wav", "duration": 4.1, "text": "Bonjour, comment allez-vous aujourd'hui?", "target_lang": "fr-FR"} +``` + + +## Config File + +The config is located in the repo at: + +``` +examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml +``` + +### Key config sections + +**Prompt dictionary** -- maps language tags to prompt indices. This is defined in `model.model_defaults.prompt_dictionary`: + +```yaml +model: + model_defaults: + num_prompts: 128 + prompt_dictionary: { + 'en-US': 0, 'en': 0, 'en-GB': 1, 'es-ES': 2, 'es': 3, + 'zh-CN': 4, 'hi-IN': 6, 'ar-AR': 7, 'fr-FR': 8, 'de-DE': 9, + 'ja-JP': 10, 'ru-RU': 11, 'pt-BR': 12, 'ko-KR': 14, 'it-IT': 15, + # ... up to 128 language/locale slots + 'auto': 101 + } +``` + + +## Training + +### Option 1: Hybrid RNNT+CTC + +This trains with both RNNT and auxiliary CTC loss. + +```bash +python3 /code/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe_prompt.py \ + --config-path=/code/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming \ + --config-name=${CONFIG_NAME} +``` + +### Option 2: RNNT-only + +Same as above but uses the RNNT-only training script: + +```bash +python3 /code/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py \ + --config-path=/code/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming \ + --config-name=${CONFIG_NAME} +``` + + + +## Supported Languages + +The prompt dictionary supports 40+ languages. The full mapping is in the config YAML. Add new lang to the dictionary as needed. + +| Language | Tag | Prompt ID | +|----------|-----|-----------| +| English (US) | `en-US` | 0 | +| English (GB) | `en-GB` | 1 | +| Spanish (ES) | `es-ES` | 2 | +| Spanish (US) | `es` | 3 | +| Chinese | `zh-CN` | 4 | +| Hindi | `hi-IN` | 6 | +| Arabic | `ar-AR` | 7 | +| French | `fr-FR` | 8 | +| German | `de-DE` | 9 | +| Japanese | `ja-JP` | 10 | +| Russian | `ru-RU` | 11 | +| Portuguese (BR) | `pt-BR` | 12 | +| Korean | `ko-KR` | 14 | +| Italian | `it-IT` | 15 | +| Auto-detect | `auto` | 101 | + +See the full list in the config YAML under `model.model_defaults.prompt_dictionary`. + + From fea623a4d1a8a42d8de7df5f010650916a79d764 Mon Sep 17 00:00:00 2001 From: Enas Albasiri Date: Wed, 6 May 2026 14:23:53 +0000 Subject: [PATCH 3/9] Add unified prompt architecture for multilingual ASR - RNNT-only prompt model (EncDecRNNTBPEModelWithPrompt) and training script - RNNT-only streaming config (fastconformer_transducer_bpe_streaming_prompt.yaml) - Index-based dataset (LhotseSpeechToTextBpeDatasetWithPromptIndex) with per-dataset prompt_mode support (langID/auto/unified) via lhotse input_cfg tags - Backward-compatible hybrid model: accepts both old prompt tensors and new prompt_indices with auto-detection - Streaming inference: set_inference_prompt() + conformer_stream_step() override for both hybrid and RNNT-only models, with target_lang support in standard cache-aware streaming inference script - Config-driven strip_lang_tags in RNNT decoding to remove tags from output - Remove unused docs/TRAINING_GUIDE.md and hybrid streaming config --- docs/TRAINING_GUIDE.md | 113 ----- ...ech_to_text_cache_aware_streaming_infer.py | 10 + .../speech_to_text_rnnt_bpe_prompt.py | 4 +- ...ormer_transducer_bpe_streaming_prompt.yaml | 395 ++++++++++++++++ ...d_transducer_ctc_bpe_streaming_prompt.yaml | 426 ------------------ .../data/audio_to_text_lhotse_prompt_index.py | 60 ++- .../hybrid_rnnt_ctc_bpe_models_prompt.py | 278 +++++++++--- .../asr/models/rnnt_bpe_models_prompt.py | 140 +++++- .../asr/parts/submodules/rnnt_decoding.py | 11 + 9 files changed, 818 insertions(+), 619 deletions(-) delete mode 100644 docs/TRAINING_GUIDE.md create mode 100644 examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml delete mode 100644 examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml diff --git a/docs/TRAINING_GUIDE.md b/docs/TRAINING_GUIDE.md deleted file mode 100644 index d3a030291bd7..000000000000 --- a/docs/TRAINING_GUIDE.md +++ /dev/null @@ -1,113 +0,0 @@ -# Unified Language-ID ASR Training - -This guide covers how to train cache-aware streaming ASR models with unified language-ID. - -- NeMo container: `nvcr.io/nvidia/nemo:25.11.01` -- NeMo branch: https://github.com/ealbasiri/NeMo/tree/unified_archetecture_langid - -## Overview - -There are two model variants: - -| Variant | Training Script | Model Class | Description | -|---------|----------------|-------------|-------------| -| **Hybrid RNNT+CTC** | `speech_to_text_hybrid_rnnt_ctc_bpe_prompt.py` | `EncDecHybridRNNTCTCBPEModelWithPrompt` | Trains with both RNNT and CTC losses (CTC weight = 0.1 recommended). | -| **RNNT-only** | `speech_to_text_rnnt_bpe_prompt.py` | `EncDecRNNTBPEModelWithPrompt` | Trains with RNNT loss only. | - -Both variants use the same config, tokenizer, and data format. - -## Data Format - -### Manifest files must include target_lang or lang/lang field - -```json -{"audio_filepath": "/data/audio/sample.wav", "duration": 5.2, "text": "The transcript of the audio.", "target_lang": "en-US"} -``` - -`target_lang` Language ID used as the prompt (e.g., `en-US`, `ar-AR`, `de-DE`). Mapped to a prompt index via the `prompt_dictionary` in the config. - -Multi-language example: - -```json -{"audio_filepath": "/data/en/audio_001.wav", "duration": 5.7, "text": "No, I don't think I need any further assistance.", "target_lang": "en-US"} -{"audio_filepath": "/data/ko/audio_002.wav", "duration": 5.9, "text": "선생님들 모두 어 우리 민서가 수학에 대한 흥미는", "target_lang": "ko-KR"} -{"audio_filepath": "/data/ar/audio_003.wav", "duration": 9.3, "text": "وللإشارة فإن هذا الاتفاق لا يعد اتفاقا منفصلا", "target_lang": "ar-AR"} -{"audio_filepath": "/data/fr/audio_004.wav", "duration": 4.1, "text": "Bonjour, comment allez-vous aujourd'hui?", "target_lang": "fr-FR"} -``` - - -## Config File - -The config is located in the repo at: - -``` -examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml -``` - -### Key config sections - -**Prompt dictionary** -- maps language tags to prompt indices. This is defined in `model.model_defaults.prompt_dictionary`: - -```yaml -model: - model_defaults: - num_prompts: 128 - prompt_dictionary: { - 'en-US': 0, 'en': 0, 'en-GB': 1, 'es-ES': 2, 'es': 3, - 'zh-CN': 4, 'hi-IN': 6, 'ar-AR': 7, 'fr-FR': 8, 'de-DE': 9, - 'ja-JP': 10, 'ru-RU': 11, 'pt-BR': 12, 'ko-KR': 14, 'it-IT': 15, - # ... up to 128 language/locale slots - 'auto': 101 - } -``` - - -## Training - -### Option 1: Hybrid RNNT+CTC - -This trains with both RNNT and auxiliary CTC loss. - -```bash -python3 /code/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe_prompt.py \ - --config-path=/code/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming \ - --config-name=${CONFIG_NAME} -``` - -### Option 2: RNNT-only - -Same as above but uses the RNNT-only training script: - -```bash -python3 /code/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py \ - --config-path=/code/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming \ - --config-name=${CONFIG_NAME} -``` - - - -## Supported Languages - -The prompt dictionary supports 40+ languages. The full mapping is in the config YAML. Add new lang to the dictionary as needed. - -| Language | Tag | Prompt ID | -|----------|-----|-----------| -| English (US) | `en-US` | 0 | -| English (GB) | `en-GB` | 1 | -| Spanish (ES) | `es-ES` | 2 | -| Spanish (US) | `es` | 3 | -| Chinese | `zh-CN` | 4 | -| Hindi | `hi-IN` | 6 | -| Arabic | `ar-AR` | 7 | -| French | `fr-FR` | 8 | -| German | `de-DE` | 9 | -| Japanese | `ja-JP` | 10 | -| Russian | `ru-RU` | 11 | -| Portuguese (BR) | `pt-BR` | 12 | -| Korean | `ko-KR` | 14 | -| Italian | `it-IT` | 15 | -| Auto-detect | `auto` | 101 | - -See the full list in the config YAML under `model.model_defaults.prompt_dictionary`. - - diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 858feda00fa0..96b3d307fcb8 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -186,6 +186,11 @@ class TranscriptionConfig: # use_cer: bool = False debug_mode: bool = False # Whether to print more detail in the output. + # Language-ID prompt for prompt-conditioned models (e.g. EncDecRNNTBPEModelWithPrompt). + # Set to a language key from the model's prompt_dictionary (e.g. "en-US", "auto"). + # Ignored for models without prompt support. + target_lang: Optional[str] = None + def extract_transcriptions(hyps): """ @@ -363,6 +368,11 @@ def main(cfg: TranscriptionConfig): else: asr_model.change_decoding_strategy(cfg.ctc_decoding) + # Set language-ID prompt for prompt-conditioned models + if hasattr(asr_model, 'set_inference_prompt'): + lang = cfg.target_lang if cfg.target_lang is not None else "auto" + asr_model.set_inference_prompt(lang) + asr_model = asr_model.to(device=device, dtype=compute_dtype) asr_model.eval() diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py index cd1a76d90e57..ae6895b60412 100644 --- a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py +++ b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe_prompt.py @@ -71,8 +71,8 @@ @hydra_runner( - config_path="../conf/fastconformer/hybrid_cache_aware_streaming/", - config_name="fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt_600m.yaml", + config_path="../conf/fastconformer/cache_aware_streaming/", + config_name="fastconformer_transducer_bpe_streaming_prompt.yaml", ) def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml new file mode 100644 index 000000000000..7383eda7f40d --- /dev/null +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml @@ -0,0 +1,395 @@ +# Cache-aware streaming FastConformer-Transducer (RNNT-only) ASR model with prompt support +# Combines cache-aware streaming encoder with prompt-based multilingual capability +# This is the RNNT-only variant (no auxiliary CTC head). + +# You may find more detail: +# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer +# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer + +name: "FastConformer-Transducer-BPE-Prompt-Streaming" + +model: + sample_rate: 16000 + compute_eval_loss: false + log_prediction: true + skip_nan_grad: false + + model_defaults: + enc_hidden: ${model.encoder.d_model} + pred_hidden: 640 + joint_hidden: 640 + initialize_prompt_feature: true + num_prompts: 128 + norm: None + prompt_dictionary: { + 'en-US': 0, + 'en': 0, + 'en-GB': 1, + 'enGB': 1, + 'es-ES': 2, + 'esES': 2, + 'es-US': 3, + 'es': 3, + 'zh-CN': 4, + 'zh-ZH': 4, + 'zh-TW': 5, + 'hi-IN': 6, + 'hi': 6, + 'hi-HI': 6, + 'ar-AR': 7, + 'ar': 7, + 'fr-FR': 8, + 'fr': 8, + 'de-DE': 9, + 'de': 9, + 'ja-JP': 10, + 'ja-JA': 10, + 'ru-RU': 11, + 'ru': 11, + 'pt-BR': 12, + 'pt-PT': 13, + 'pt': 13, + 'ko-KR': 14, + 'ko': 14, + 'ko-KO': 14, + 'it-IT': 15, + 'it': 15, + 'nl-NL': 16, + 'nl': 16, + 'pl-PL': 17, + 'pl': 17, + 'tr-TR': 18, + 'tr': 18, + 'uk-UA': 19, + 'uk': 19, + 'ro-RO': 20, + 'ro': 20, + 'el-GR': 21, + 'el': 21, + 'cs-CZ': 22, + 'cs': 22, + 'hu-HU': 23, + 'hu': 23, + 'sv-SE': 24, + 'sv': 24, + 'da-DK': 25, + 'da': 25, + 'fi-FI': 26, + 'fi': 26, + 'no-NO': 27, + 'no': 27, + 'nb-NO': 103, + 'nb': 103, + 'sk-SK': 28, + 'sk': 28, + 'hr-HR': 29, + 'hr': 29, + 'bg-BG': 30, + 'bg': 30, + 'lt-LT': 31, + 'lt': 31, + 'et-EE': 60, + 'et': 60, + 'lv-LV': 61, + 'lv': 61, + 'sl-SI': 62, + 'sl': 62, + 'th-TH': 32, + 'vi-VN': 33, + 'id-ID': 34, + 'ms-MY': 35, + 'bn-IN': 36, + 'ur-PK': 37, + 'fa-IR': 38, + 'ta-IN': 39, + 'te-IN': 40, + 'mr-IN': 41, + 'gu-IN': 42, + 'kn-IN': 43, + 'ml-IN': 44, + 'si-LK': 45, + 'ne-NP': 46, + 'km-KH': 47, + 'sw-KE': 48, + 'am-ET': 49, + 'ha-NG': 50, + 'zu-ZA': 51, + 'yo-NG': 52, + 'ig-NG': 53, + 'af-ZA': 54, + 'rw-RW': 55, + 'so-SO': 56, + 'ny-MW': 57, + 'ln-CD': 58, + 'or-KE': 59, + 'he-IL': 64, + 'ku-TR': 65, + 'az-AZ': 66, + 'ka-GE': 67, + 'hy-AM': 68, + 'uz-UZ': 69, + 'tg-TJ': 70, + 'ky-KG': 71, + 'qu-PE': 80, + 'ay-BO': 81, + 'gn-PY': 82, + 'nah-MX': 83, + 'mi-NZ': 96, + 'haw-US': 97, + 'sm-WS': 98, + 'to-TO': 99, + 'fr-CA': 100, + 'mt-MT': 102, + 'auto': 101 + } + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + use_lhotse: true + shard_manifests: true + batch_duration: 400 + quadratic_duration: 15 + num_buckets: 30 + shuffle: true + num_workers: 8 + pin_memory: true + max_duration: 39.99 + min_duration: 0.1 + is_tarred: true + tarred_audio_filepaths: null + shuffle_n: 2048 + slice_length: 100 + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang + training_mode: true + + # Per-dataset prompt mode — controls how language prompts are selected during training. + # The mode is set per data source via lhotse input_cfg tags: + # tags: { prompt_mode: unified } + # + # Supported prompt_mode values: + # "langID" — always pass the real language ID (use for AST / language-forced tasks) + # "auto" — always pass auto/language-agnostic prompt (use for code-switching) + # "unified" — randomly choose auto vs lang ID (use for multilingual ASR, default) + # + # unified_auto_ratio controls the probability of selecting auto in "unified" mode. + prompt_mode_field: prompt_mode + default_prompt_mode: unified + unified_auto_ratio: 0.5 + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 2 + shuffle: false + use_start_end_token: false + num_workers: 2 + pin_memory: true + batch_duration: null + use_lhotse: true + use_bucketing: false + max_cuts: 8 + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: true + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 16 + shuffle: false + use_start_end_token: false + num_workers: 8 + pin_memory: true + use_lhotse: true + use_bucketing: false + prompt_field: target_lang + prompt_dictionary: ${model.model_defaults.prompt_dictionary} + num_prompts: ${model.model_defaults.num_prompts} + subsampling_factor: ${model.encoder.subsampling_factor} + training_mode: false + + tokenizer: + dir: ??? + type: bpe + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "NA" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 128 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 42 + d_model: 1024 + use_bias: false + + subsampling: dw_striding + subsampling_factor: 8 + subsampling_conv_channels: 256 + causal_downsampling: true + + reduction: null + reduction_position: null + reduction_factor: 1 + + ff_expansion_factor: 4 + + self_attention_model: rel_pos + n_heads: 8 + att_context_size: [70, 6] + att_context_probs: null + att_context_style: chunked_limited + xscaling: false + untie_biases: true + pos_emb_max_len: 5000 + + conv_kernel_size: 9 + conv_norm_type: 'layer_norm' + conv_context_size: causal + + dropout: 0.1 + dropout_pre_encoder: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 + + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.RNNTDecoder + normalization_mode: null + random_state_sampling: false + blank_as_pad: true + + prednet: + pred_hidden: ${model.model_defaults.pred_hidden} + pred_rnn_layers: 2 + t_max: null + dropout: 0.2 + + joint: + _target_: nemo.collections.asr.modules.RNNTJoint + log_softmax: null + preserve_memory: false + + fuse_loss_wer: true + fused_batch_size: 2 + + jointnet: + joint_hidden: ${model.model_defaults.joint_hidden} + activation: "relu" + dropout: 0.2 + + decoding: + strategy: "greedy_batch" + # Strip language-ID tags (e.g. ) from decoded output during inference. + strip_lang_tags: true + + greedy: + max_symbols: 10 + + beam: + beam_size: 2 + return_best_hypothesis: False + score_norm: true + tsd_max_sym_exp: 50 + alsd_max_target_len: 2.0 + + interctc: + loss_weights: [] + apply_at_layers: [] + + loss: + loss_name: "default" + + warprnnt_numba_kwargs: + fastemit_lambda: 5e-3 + clamp: -1.0 + + variational_noise: + start_step: 0 + std: 0.0 + + optim: + name: adamw + lr: 2.0 + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 + num_nodes: 1 + max_epochs: -1 + max_steps: 500000 + val_check_interval: 0.5 + accelerator: auto + strategy: + _target_: lightning.pytorch.strategies.DDPStrategy + gradient_as_bucket_view: true + accumulate_grad_batches: 1 + gradient_clip_val: 0.5 + precision: bf16 + log_every_n_steps: 100 + enable_progress_bar: True + num_sanity_val_steps: 0 + sync_batchnorm: true + enable_checkpointing: False + logger: false + benchmark: false + use_distributed_sampler: false + limit_train_batches: 1000 + + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: True + resume_if_exists: false + resume_ignore_no_checkpoint: false + + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml deleted file mode 100644 index 4c933d9e40e8..000000000000 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming_prompt.yaml +++ /dev/null @@ -1,426 +0,0 @@ -# Cache-aware streaming FastConformer-Hybrid-Transducer-CTC ASR model with prompt support -# Combines cache-aware streaming encoder with prompt-based multilingual capability - -# You may find more detail: -# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer -# Hybrid ASR: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#hybrid-transducer-ctc -# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer - -name: "FastConformer-Hybrid-Transducer-CTC-BPE-Prompt-Streaming" - -model: - sample_rate: 16000 - compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. - log_prediction: true # enables logging sample predictions in the output during training - skip_nan_grad: false - - model_defaults: - enc_hidden: ${model.encoder.d_model} - pred_hidden: 640 - joint_hidden: 640 - # Prompt configuration - initialize_prompt_feature: true - num_prompts: 128 - norm: None - # Dictionary mapping prompt identifiers to their corresponding embedding indices - prompt_dictionary: { - # Language prompts (0-99) - 'en-US': 0, - 'en': 0, - 'en-GB': 1, - 'enGB': 1, - 'es-ES': 2, - 'esES': 2, - 'es-US': 3, - 'es': 3, - 'zh-CN': 4, - 'zh-ZH': 4, - 'zh-TW': 5, - 'hi-IN': 6, - 'hi': 6, - 'hi-HI': 6, - 'ar-AR': 7, - 'ar': 7, - 'fr-FR': 8, - 'fr': 8, - 'de-DE': 9, - 'de': 9, - 'ja-JP': 10, - 'ja-JA': 10, - 'ru-RU': 11, - 'ru': 11, - 'pt-BR': 12, - 'pt-PT': 13, - 'pt': 13, - 'ko-KR': 14, - 'ko': 14, - 'ko-KO': 14, - 'it-IT': 15, - 'it': 15, - 'nl-NL': 16, - 'nl': 16, - 'pl-PL': 17, - 'pl': 17, - 'tr-TR': 18, - 'tr': 18, - 'uk-UA': 19, - 'uk': 19, - 'ro-RO': 20, - 'ro': 20, - 'el-GR': 21, - 'el': 21, - 'cs-CZ': 22, - 'cs': 22, - 'hu-HU': 23, - 'hu': 23, - 'sv-SE': 24, - 'sv': 24, - 'da-DK': 25, - 'da': 25, - 'fi-FI': 26, - 'fi': 26, - 'no-NO': 27, - 'no': 27, - 'nb-NO': 103, - 'nb': 103, - 'sk-SK': 28, - 'sk': 28, - 'hr-HR': 29, - 'hr': 29, - 'bg-BG': 30, - 'bg': 30, - 'lt-LT': 31, - 'lt': 31, - # Granary languages (60-62) - 'et-EE': 60, - 'et': 60, - 'lv-LV': 61, - 'lv': 61, - 'sl-SI': 62, - 'sl': 62, - 'th-TH': 32, - 'vi-VN': 33, - 'id-ID': 34, - 'ms-MY': 35, - 'bn-IN': 36, - 'ur-PK': 37, - 'fa-IR': 38, - 'ta-IN': 39, - 'te-IN': 40, - 'mr-IN': 41, - 'gu-IN': 42, - 'kn-IN': 43, - 'ml-IN': 44, - 'si-LK': 45, - 'ne-NP': 46, - 'km-KH': 47, - 'sw-KE': 48, - 'am-ET': 49, - 'ha-NG': 50, - 'zu-ZA': 51, - 'yo-NG': 52, - 'ig-NG': 53, - 'af-ZA': 54, - 'rw-RW': 55, - 'so-SO': 56, - 'ny-MW': 57, - 'ln-CD': 58, - 'or-KE': 59, - 'he-IL': 64, - 'ku-TR': 65, - 'az-AZ': 66, - 'ka-GE': 67, - 'hy-AM': 68, - 'uz-UZ': 69, - 'tg-TJ': 70, - 'ky-KG': 71, - 'qu-PE': 80, - 'ay-BO': 81, - 'gn-PY': 82, - 'nah-MX': 83, - 'mi-NZ': 96, - 'haw-US': 97, - 'sm-WS': 98, - 'to-TO': 99, - 'fr-CA': 100, - 'mt-MT': 102, - 'auto': 101 - } - - train_ds: - manifest_filepath: /config/asr_P0_P1_riva_training_data_v2_fixed_spacing_with_lang_tag.yaml - sample_rate: ${model.sample_rate} - use_lhotse: true - shard_manifests: true - batch_duration: 400 - quadratic_duration: 15 - num_buckets: 30 - shuffle: true - num_workers: 8 - pin_memory: true - max_duration: 39.99 # you may need to update it for your dataset - min_duration: 0.1 - # tarred datasets - is_tarred: true - tarred_audio_filepaths: null - shuffle_n: 2048 - slice_length: 100 - # bucketing params - bucketing_strategy: "fully_randomized" - bucketing_batch_size: null - bucket_buffer_size: 10000 - shuffle_buffer_size: 10000 - # prompt configs - prompt_field: target_lang - prompt_dictionary: ${model.model_defaults.prompt_dictionary} - num_prompts: ${model.model_defaults.num_prompts} - subsampling_factor: ${model.encoder.subsampling_factor} - lang_field: target_lang - training_mode: true # 50% use auto (101), 50% use actual lang ID - - validation_ds: - manifest_filepath: ??? - sample_rate: ${model.sample_rate} - batch_size: 2 - shuffle: false - use_start_end_token: false - num_workers: 2 - pin_memory: true - batch_duration: null - use_lhotse: true - use_bucketing: false - max_cuts: 8 - # prompt configurations for validation - prompt_field: target_lang - prompt_dictionary: ${model.model_defaults.prompt_dictionary} - num_prompts: ${model.model_defaults.num_prompts} - subsampling_factor: ${model.encoder.subsampling_factor} - training_mode: true # pass lang ID 50% - - test_ds: - manifest_filepath: '/manifests/fleurs/fleurs_emea_emea_lhotse.dev.ast_update_lang_id_cross_only_unique_.json' - sample_rate: ${model.sample_rate} - batch_size: 16 - shuffle: false - use_start_end_token: false - num_workers: 8 - pin_memory: true - use_lhotse: true - use_bucketing: false - # prompt configurations for testing - prompt_field: target_lang - prompt_dictionary: ${model.model_defaults.prompt_dictionary} - num_prompts: ${model.model_defaults.num_prompts} - subsampling_factor: ${model.encoder.subsampling_factor} - training_mode: false # Always use actual lang ID during testing - - # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py - tokenizer: - dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe) - type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) - - preprocessor: - _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor - sample_rate: ${model.sample_rate} - normalize: "NA" # No normalization for mel-spectogram makes streaming easier - window_size: 0.025 - window_stride: 0.01 - window: "hann" - features: 128 - n_fft: 512 - frame_splicing: 1 - dither: 0.00001 - pad_to: 0 - - spec_augment: - _target_: nemo.collections.asr.modules.SpectrogramAugmentation - freq_masks: 2 # set to zero to disable it - time_masks: 10 # set to zero to disable it - freq_width: 27 - time_width: 0.05 - - encoder: - _target_: nemo.collections.asr.modules.ConformerEncoder - feat_in: ${model.preprocessor.features} - feat_out: -1 # you may set it if you need different output size other than the default d_model - n_layers: 42 - d_model: 1024 - use_bias: false - - # Sub-sampling parameters - subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding - subsampling_factor: 8 # must be power of 2 for striding and vggnet - subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model - causal_downsampling: true # Required for streaming - - # Reduction parameters: Can be used to add another subsampling layer at a given position. - reduction: null # pooling, striding, or null - reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder - reduction_factor: 1 - - # Feed forward module's params - ff_expansion_factor: 4 - - # Multi-headed Attention Module's params - self_attention_model: rel_pos # rel_pos or abs_pos - n_heads: 8 # may need to be lower for smaller d_models - # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention - # For multi-lookahead models, you may specify a list of context sizes. - # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s - att_context_size: [70, 6] # look-ahead = 6*8*0.01 = 0.48s - att_context_probs: null - att_context_style: chunked_limited # regular or chunked_limited - xscaling: false - untie_biases: true # unties the biases of the TransformerXL layers - pos_emb_max_len: 5000 - - # Convolution module's params - conv_kernel_size: 9 - conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) - # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size - # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] - conv_context_size: causal # Required for streaming - - ### regularization - dropout: 0.1 # The dropout used in most of the Conformer Modules - dropout_pre_encoder: 0.1 # The dropout used before the encoder - dropout_emb: 0.0 # The dropout used for embeddings - dropout_att: 0.1 # The dropout for multi-headed attention modules - - # set to non-zero to enable stochastic depth - stochastic_depth_drop_prob: 0.0 - stochastic_depth_mode: linear # linear or uniform - stochastic_depth_start_layer: 1 - - decoder: - _target_: nemo.collections.asr.modules.RNNTDecoder - normalization_mode: null # Currently only null is supported for export. - random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf - blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference. - - prednet: - pred_hidden: ${model.model_defaults.pred_hidden} - pred_rnn_layers: 2 - t_max: null - dropout: 0.2 - - joint: - _target_: nemo.collections.asr.modules.RNNTJoint - log_softmax: null # 'null' would set it automatically according to CPU/GPU device - preserve_memory: false # dramatically slows down training, but might preserve some memory - - # Fuses the computation of prediction net + joint net + loss + WER calculation - # to be run on sub-batches of size `fused_batch_size`. - fuse_loss_wer: true - fused_batch_size: 2 - - jointnet: - joint_hidden: ${model.model_defaults.joint_hidden} - activation: "relu" - dropout: 0.2 - - decoding: - strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd. - - # greedy strategy config - greedy: - max_symbols: 10 - - # beam strategy config - beam: - beam_size: 2 - return_best_hypothesis: False - score_norm: true - tsd_max_sym_exp: 50 # for Time Synchronous Decoding - alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding - - # The section which would contain the decoder and decoding configs of the auxiliary CTC decoder - aux_ctc: - ctc_loss_weight: 0.1 # the weight used to combine the CTC loss with the RNNT loss - use_cer: false - ctc_reduction: 'mean_batch' - decoder: - _target_: nemo.collections.asr.modules.ConvASRDecoder - feat_in: null - num_classes: -1 - vocabulary: [] - decoding: - strategy: "greedy" - - # config for InterCTC loss: https://arxiv.org/abs/2102.03216 - interctc: - loss_weights: [] - apply_at_layers: [] - - loss: - loss_name: "default" - - warprnnt_numba_kwargs: - # FastEmit regularization: https://arxiv.org/abs/2010.11148 - # You may enable FastEmit to reduce the latency of the model for streaming - fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. - clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. - - # Adds Gaussian noise to the gradients of the decoder to avoid overfitting - variational_noise: - start_step: 0 - std: 0.0 - - optim: - name: adamw - lr: 2.0 - # optimizer arguments - betas: [0.9, 0.98] - weight_decay: 1e-3 - - # scheduler setup - sched: - name: NoamAnnealing - d_model: ${model.encoder.d_model} - # scheduler config override - warmup_steps: 10000 - warmup_ratio: null - min_lr: 1e-6 - -trainer: - devices: -1 # number of GPUs, -1 would use all available GPUs - num_nodes: 1 - max_epochs: -1 - max_steps: 500000 # computed at runtime if not set - val_check_interval: 0.5 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations - accelerator: auto - strategy: - _target_: lightning.pytorch.strategies.DDPStrategy - gradient_as_bucket_view: true - accumulate_grad_batches: 1 - gradient_clip_val: 0.5 - precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. - log_every_n_steps: 100 # Interval of logging. - enable_progress_bar: True - num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it - sync_batchnorm: true - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager - benchmark: false # needs to be false for models with variable-length speech input as it slows down training - use_distributed_sampler: false - limit_train_batches: 1000 - - -exp_manager: - exp_dir: null - name: ${name} - create_tensorboard_logger: true - create_checkpoint_callback: true - checkpoint_callback_params: - # in case of multiple validation sets, first one is used - monitor: "val_wer" - mode: "min" - save_top_k: 5 - always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints - resume_if_exists: false - resume_ignore_no_checkpoint: false - - create_wandb_logger: false - wandb_logger_kwargs: - name: null - project: null diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py index 35bd9a090790..f3c93bc3557a 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -73,10 +73,26 @@ def __init__(self, tokenizer, cfg): # Field to use for prompt key (default to 'target_lang') self.prompt_field = cfg.get('prompt_field', 'target_lang') - # Training mode flag: when True, randomly use auto (101) 50% of the time self.training_mode = cfg.get('training_mode', True) - logging.info(f"LhotseSpeechToTextBpeDatasetWithPromptIndex: Returns indices only, model creates prompt tensor") + # Per-dataset prompt mode is read from cut.custom["prompt_mode"] at runtime. + # Supported values: + # "langID" — always pass the real language ID + # "auto" — always pass auto (101) + # "unified" — randomize: auto with probability unified_auto_ratio, else lang ID + # Set via lhotse input_cfg tags, e.g. tags: { prompt_mode: langID } + self.prompt_mode_field = cfg.get('prompt_mode_field', 'prompt_mode') + self.default_prompt_mode = cfg.get('default_prompt_mode', 'unified') + self.unified_auto_ratio = cfg.get('unified_auto_ratio', 0.5) + + # Index used for the language-agnostic / auto prompt + self.auto_index = self.prompt_dict.get('auto', 101) + + logging.info( + f"LhotseSpeechToTextBpeDatasetWithPromptIndex: " + f"default_prompt_mode={self.default_prompt_mode}, " + f"unified_auto_ratio={self.unified_auto_ratio}" + ) def _get_prompt_index(self, prompt_key: str) -> int: """Maps prompt keys to indices using the prompt dictionary.""" @@ -87,15 +103,45 @@ def _get_prompt_index(self, prompt_key: str) -> int: ) return self.prompt_dict[prompt_key] + def _get_prompt_mode(self, cut) -> str: + """Resolve the prompt_mode for a cut from its custom tags.""" + if cut.custom is not None: + mode = cut.custom.get(self.prompt_mode_field) + if mode is not None: + return mode + return self.default_prompt_mode + def _get_prompt_index_for_cut(self, cut) -> int: """ - Get prompt index for a cut, with training mode randomization. - During training: 50% chance to use auto (101), 50% actual language ID - During inference: always use the actual language ID + Determine the prompt index for a cut based on its prompt_mode tag. + + During inference (training_mode=False): always returns the real lang ID + regardless of prompt_mode. + + During training, behaviour depends on prompt_mode (set per-dataset via + lhotse input_cfg tags): + "langID" — always return the real language ID + "auto" — always return auto index (language-agnostic) + "unified" — return auto with probability unified_auto_ratio, + otherwise the real language ID """ - if self.training_mode and random.random() < 0.5: - return 101 # Auto/language-agnostic + if not self.training_mode: + return self._get_prompt_index(cut.supervisions[0].language) + + mode = self._get_prompt_mode(cut) + + if mode == 'langID': + return self._get_prompt_index(cut.supervisions[0].language) + elif mode == 'auto': + return self.auto_index + elif mode == 'unified': + if random.random() < self.unified_auto_ratio: + return self.auto_index + return self._get_prompt_index(cut.supervisions[0].language) else: + logging.warning(f"Unknown prompt_mode '{mode}', falling back to unified") + if random.random() < self.unified_auto_ratio: + return self.auto_index return self._get_prompt_index(cut.supervisions[0].language) def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 82a68e020cac..83dc6b42933f 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import os from dataclasses import dataclass from math import ceil @@ -31,6 +30,7 @@ from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig @@ -198,6 +198,168 @@ def initialize_prompt_feature(self): # setting the RNNT decoder as the default one self.cur_decoder = "rnnt" + # Streaming inference with language-ID prompt + + def set_inference_prompt(self, target_lang: str): + """ + Set the language prompt for streaming inference. + + Call this before ``conformer_stream_step`` to condition decoding on + a specific language, following the same pattern as + ``change_decoding_strategy``. + + Args: + target_lang: A key from the model's ``prompt_dictionary`` + (e.g. ``"en-US"``, ``"auto"``). + """ + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) + if target_lang not in prompt_dict: + available = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language '{target_lang}'. " + f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" + ) + self._inference_prompt_index = prompt_dict[target_lang] + logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") + + def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: + """ + Inject the language-ID prompt into encoder output during streaming. + + ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. + Returns the same shape after prompt concatenation + projection. + """ + if not self.concat or not hasattr(self, '_inference_prompt_index'): + return encoded + + encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) + + batch_size, time_steps, _ = encoded.shape + prompt = torch.zeros( + batch_size, time_steps, self.num_prompts, + dtype=encoded.dtype, device=encoded.device, + ) + idx = torch.full( + (batch_size,), self._inference_prompt_index, + dtype=torch.long, device=encoded.device, + ) + prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) + return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) + + def conformer_stream_step( + self, + processed_signal, + processed_signal_length=None, + cache_last_channel=None, + cache_last_time=None, + cache_last_channel_len=None, + keep_all_outputs=True, + previous_hypotheses=None, + previous_pred_out=None, + drop_extra_pre_encoded=None, + return_transcription=True, + return_log_probs=False, + bypass_pre_encode=False, + ): + """Cache-aware streaming step with language-ID prompt injection. + + Identical to the base ``ASRModuleMixin.conformer_stream_step`` except + that after the encoder step, ``_apply_prompt_to_encoded`` concatenates + the one-hot language prompt and projects back to enc_hidden. + + Set the target language via ``set_inference_prompt(target_lang)`` + before calling this method. + """ + import nemo.collections.asr.models as asr_models + + if not isinstance(self.encoder, StreamingEncoder): + raise NotImplementedError("Encoder does not support streaming!") + + ( + encoded, + encoded_len, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_next_len, + ) = self.encoder.cache_aware_stream_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + keep_all_outputs=keep_all_outputs, + drop_extra_pre_encoded=drop_extra_pre_encoded, + bypass_pre_encode=bypass_pre_encode, + ) + + encoded = self._apply_prompt_to_encoded(encoded) + + if isinstance(self, asr_models.EncDecCTCModel) or ( + isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc" + ): + if hasattr(self, "ctc_decoder"): + decoding = self.ctc_decoding + decoder = self.ctc_decoder + else: + decoding = self.decoding + decoder = self.decoder + + log_probs = decoder(encoder_output=encoded) + predictions_tensor = log_probs.argmax(dim=-1, keepdim=False) + + greedy_predictions = [] + if return_transcription: + all_hyp_or_transcribed_texts = [] + else: + all_hyp_or_transcribed_texts = None + + for preds_idx, preds in enumerate(predictions_tensor): + if encoded_len is None: + preds_cur = predictions_tensor[preds_idx] + else: + preds_cur = predictions_tensor[preds_idx, : encoded_len[preds_idx]] + if previous_pred_out is not None: + greedy_predictions_concat = torch.cat((previous_pred_out[preds_idx], preds_cur), dim=-1) + encoded_len[preds_idx] += len(previous_pred_out[preds_idx]) + else: + greedy_predictions_concat = preds_cur + greedy_predictions.append(greedy_predictions_concat) + + if return_transcription: + decoded_out = decoding.ctc_decoder_predictions_tensor( + decoder_outputs=greedy_predictions_concat.unsqueeze(0), + decoder_lengths=encoded_len[preds_idx : preds_idx + 1], + return_hypotheses=False, + ) + all_hyp_or_transcribed_texts.append(decoded_out[0]) + best_hyp = None + else: + best_hyp = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, + encoded_lengths=encoded_len, + return_hypotheses=True, + partial_hypotheses=previous_hypotheses, + ) + greedy_predictions = [hyp.y_sequence for hyp in best_hyp] + all_hyp_or_transcribed_texts = best_hyp + + result = [ + greedy_predictions, + all_hyp_or_transcribed_texts, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_next_len, + best_hyp, + ] + if return_log_probs: + result.append(log_probs) + result.append(encoded_len) + + return tuple(result) + def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): @@ -326,32 +488,25 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT Returns: The model's outputs that are processed by `_transcribe_output_processing()`. """ - # Handling DataLoader batch - should be a tuple of tensors - # Expected structure: (audio, audio_lens, tokens, token_lens, prompt_indices) - # For transcription, we may only have (audio, audio_lens) or (audio, audio_lens, ..., prompt_indices) audio, audio_lens = batch[0], batch[1] + prompt, prompt_indices = None, None + if len(batch) >= 5: - # Prompt indices provided by the dataloader (language ID indices) - prompt_indices = batch[4] # This should be the prompt_indices from dataset - else: - # Prompt to be built dynamically. - prompt_indices = None + prompt_or_indices = batch[4] + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices batch_size = audio.shape[0] - if prompt_indices is None: - # The dataloader provided only audio + audio_lens, so we need to construct - # the prompt indices dynamically using TranscribeConfig. + if prompt is None and prompt_indices is None: target_lang = trcfg.target_lang - - # Get prompt dictionary from model config prompt_dict = self.cfg.model_defaults.get('prompt_dictionary') - num_prompts = self.cfg.model_defaults.get('num_prompts', 128) if not prompt_dict: raise ValueError("Prompt dictionary is empty. Cannot create dynamic prompts.") - # Get the prompt index for the target language if target_lang not in prompt_dict: available_keys = list(prompt_dict.keys()) raise ValueError( @@ -359,14 +514,11 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT ) prompt_id = prompt_dict[target_lang] - - # Create prompt index tensor for the batch - forward() will create the one-hot prompt - # from these indices using the actual encoder output length prompt_indices = torch.full((batch_size,), prompt_id, dtype=torch.long, device=audio.device) - # Call forward with prompt_indices - the model creates prompt tensors after encoding - # This guarantees no size mismatch between encoded features and prompt tensors - encoded, encoded_len = self.forward(input_signal=audio, input_signal_length=audio_lens, prompt_indices=prompt_indices) + encoded, encoded_len = self.forward( + input_signal=audio, input_signal_length=audio_lens, prompt=prompt, prompt_indices=prompt_indices + ) # Prepare output dictionary based on decoder type if self.cur_decoder == "rnnt": @@ -505,7 +657,8 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), - "prompt_indices": NeuralType(tuple('B'), LabelsType()), # Language ID indices per sample + "prompt": NeuralType(('B', 'T', 'D'), LabelsType(), optional=True), + "prompt_indices": NeuralType(tuple('B'), LabelsType(), optional=True), } @property @@ -522,6 +675,7 @@ def forward( input_signal_length=None, processed_signal=None, processed_signal_length=None, + prompt=None, prompt_indices=None, ): """ @@ -538,9 +692,12 @@ def forward( of shape (B, D, T) that has undergone processing via some DALI preprocessor. processed_signal_length: Vector of length B, that contains the individual lengths of the processed audio sequences. + prompt: (backward-compatible) Pre-built one-hot prompt tensor of shape [B, T, D]. + If provided, used directly (trimmed to encoder length if needed). prompt_indices: Tensor of shape [B] containing language ID indices per sample. The model creates the prompt tensor after encoding using the actual encoder output length, guaranteeing no size mismatch. + Ignored if ``prompt`` is also provided. Returns: A tuple of 2 elements - @@ -569,25 +726,22 @@ def forward( encoded = torch.transpose(encoded, 1, 2) # B * D * T -> B * T * D if self.concat: - if prompt_indices is None: - raise ValueError("prompt_indices must be provided when concat mode is enabled.") - - # Create prompt tensor from indices using actual encoded length - batch_size = encoded.shape[0] - time_steps = encoded.shape[1] - num_prompts = self.num_prompts - - # Create one-hot prompt tensor: (batch_size, time_steps, num_prompts) - prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) - # Vectorized scatter: set each sample's language column to 1.0 (no Python loop) - prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) - - out_dtype = encoded.dtype + if prompt is not None: + # Backward-compatible path: caller provided a pre-built [B, T, D] one-hot tensor + if prompt.shape[1] > encoded.shape[1]: + prompt = prompt[:, : encoded.shape[1], :] + elif prompt_indices is not None: + # New path: build one-hot from per-sample language ID indices + batch_size = encoded.shape[0] + time_steps = encoded.shape[1] + num_prompts = self.num_prompts + prompt = torch.zeros(batch_size, time_steps, num_prompts, dtype=encoded.dtype, device=encoded.device) + prompt.scatter_(2, prompt_indices.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + else: + raise ValueError("Either prompt or prompt_indices must be provided when concat mode is enabled.") - # Concatenate encoded states with prompt + out_dtype = encoded.dtype concat_enc_states = torch.cat([encoded, prompt], dim=-1) - - # Apply joint projection encoded = self.prompt_kernel(concat_enc_states).to(out_dtype) encoded = torch.transpose(encoded, 1, 2) # B * T * D -> B * D * T @@ -601,15 +755,21 @@ def training_step(self, batch, batch_nb): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt_indices = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + # Detect whether batch[4] is old-style prompt [B, T, D] or new-style indices [B] + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward - # prompt_indices contains language ID indices [B], not full tensors - # The model creates prompt tensors after encoding using actual encoder output length if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal # During training, loss must be computed, so decoder forward is necessary @@ -715,14 +875,20 @@ def training_step(self, batch, batch_nb): return {'loss': loss_value} def predict_step(self, batch, batch_idx, dataloader_idx=0): - signal, signal_len, transcript, transcript_len, prompt_indices = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward - # prompt_indices contains language ID indices [B], not full tensors if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal if self.cur_decoder == 'rnnt': @@ -746,14 +912,20 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): if self.is_interctc_enabled(): AccessMixin.set_access_enabled(access_enabled=True) - signal, signal_len, transcript, transcript_len, prompt_indices = batch + signal, signal_len, transcript, transcript_len, prompt_or_indices = batch + + prompt, prompt_indices = None, None + if prompt_or_indices.dim() == 1: + prompt_indices = prompt_or_indices + else: + prompt = prompt_or_indices - # forward() only performs encoder forward - # prompt_indices contains language ID indices [B], not full tensors if isinstance(batch, DALIOutputs) and batch.has_processed_signal: encoded, encoded_len = self.forward(processed_signal=signal, processed_signal_length=signal_len) else: - encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len, prompt_indices=prompt_indices) + encoded, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_indices=prompt_indices + ) del signal tensorboard_logs = {} diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 0b51919ba5a6..57bdc2b41cb8 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -24,11 +24,11 @@ from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset -from nemo.collections.asr.data.audio_to_text_lhotse_prompt import LhotseSpeechToTextBpeDatasetWithPrompt from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding @@ -92,10 +92,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): cfg.num_prompts = cfg.model_defaults.get('num_prompts', 128) if 'prompt_dictionary' not in cfg.model_defaults: - raise ValueError( - "No prompt_dictionary found in config. " - "Please make sure your config has a prompt_dictionary in model_defaults." + logging.warning( + "No prompt_dictionary in config; using empty dict " + "(expected during checkpoint restoration)." ) + cfg.model_defaults.prompt_dictionary = {} self.subsampling_factor = cfg.get('subsampling_factor', 8) @@ -175,9 +176,124 @@ def initialize_prompt_feature(self): self.joint.set_loss(self.loss) self.joint.set_wer(self.wer) - # ------------------------------------------------------------------ + # Streaming inference with language-ID prompt + def set_inference_prompt(self, target_lang: str): + """ + Set the language prompt for streaming inference. + + Call this before ``conformer_stream_step`` to condition decoding on + a specific language, following the same pattern as + ``change_decoding_strategy``. + + Args: + target_lang: A key from the model's ``prompt_dictionary`` + (e.g. ``"en-US"``, ``"auto"``). + """ + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) + if target_lang not in prompt_dict: + available = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language '{target_lang}'. " + f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" + ) + self._inference_prompt_index = prompt_dict[target_lang] + logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") + + def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: + """ + Inject the language-ID prompt into encoder output during streaming. + + ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. + Returns the same shape after prompt concatenation + projection. + """ + if not self.concat or not hasattr(self, '_inference_prompt_index'): + return encoded + + encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) + + batch_size, time_steps, _ = encoded.shape + prompt = torch.zeros( + batch_size, time_steps, self.num_prompts, + dtype=encoded.dtype, device=encoded.device, + ) + idx = torch.full( + (batch_size,), self._inference_prompt_index, + dtype=torch.long, device=encoded.device, + ) + prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) + return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) + + def conformer_stream_step( + self, + processed_signal, + processed_signal_length=None, + cache_last_channel=None, + cache_last_time=None, + cache_last_channel_len=None, + keep_all_outputs=True, + previous_hypotheses=None, + previous_pred_out=None, + drop_extra_pre_encoded=None, + return_transcription=True, + return_log_probs=False, + bypass_pre_encode=False, + ): + """Cache-aware streaming step with language-ID prompt injection. + + Identical to the base ``ASRModuleMixin.conformer_stream_step`` except + that after the encoder step, ``_apply_prompt_to_encoded`` concatenates + the one-hot language prompt and projects back to enc_hidden. + + Set the target language via ``set_inference_prompt(target_lang)`` + before calling this method. + """ + import nemo.collections.asr.models as asr_models + + if not isinstance(self.encoder, StreamingEncoder): + raise NotImplementedError("Encoder does not support streaming!") + + ( + encoded, + encoded_len, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_next_len, + ) = self.encoder.cache_aware_stream_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + keep_all_outputs=keep_all_outputs, + drop_extra_pre_encoded=drop_extra_pre_encoded, + bypass_pre_encode=bypass_pre_encode, + ) + + encoded = self._apply_prompt_to_encoded(encoded) + + best_hyp = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, + encoded_lengths=encoded_len, + return_hypotheses=True, + partial_hypotheses=previous_hypotheses, + ) + greedy_predictions = [hyp.y_sequence for hyp in best_hyp] + all_hyp_or_transcribed_texts = best_hyp + + result = [ + greedy_predictions, + all_hyp_or_transcribed_texts, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_next_len, + best_hyp, + ] + return tuple(result) + # Data loading - # ------------------------------------------------------------------ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): @@ -308,9 +424,6 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): self._update_dataset_config(dataset_name='test', config=test_data_config) self._test_dl = self._setup_dataloader_from_config(config=test_data_config) - # ------------------------------------------------------------------ - # Forward - # ------------------------------------------------------------------ @property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): @@ -380,9 +493,6 @@ def forward( encoded = torch.transpose(encoded, 1, 2) # B x T x D -> B x D x T return encoded, encoded_len - # ------------------------------------------------------------------ - # Training - # ------------------------------------------------------------------ def training_step(self, batch, batch_nb): if AccessMixin.is_access_enabled(self.model_guid): AccessMixin.reset_registry(self) @@ -468,9 +578,6 @@ def training_step(self, batch, batch_nb): return {'loss': loss_value} - # ------------------------------------------------------------------ - # Validation / Test - # ------------------------------------------------------------------ def validation_pass(self, batch, batch_idx, dataloader_idx=0): signal, signal_len, transcript, transcript_len, prompt_indices = batch @@ -572,9 +679,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return list(zip(sample_id, best_hyp)) - # ------------------------------------------------------------------ - # Transcription - # ------------------------------------------------------------------ def _transcribe_forward(self, batch, trcfg: RNNTPromptTranscribeConfig) -> dict: audio, audio_lens = batch[0], batch[1] if len(batch) >= 5: diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index c9a0989d1022..15b3139bcdae 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -329,6 +329,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu punct_pattern = '|'.join([re.escape(p) for p in self.supported_punctuation]) self.space_before_punct_pattern = re.compile(r'(\s)(' + punct_pattern + ')') + self.strip_lang_tags = self.cfg.get('strip_lang_tags', False) + if self.strip_lang_tags: + self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>') + # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) @@ -949,10 +953,13 @@ def decode_ids_to_str(self, tokens: List[int]) -> str: def decode_tokens_to_str_with_strip_punctuation(self, tokens: List[int]) -> str: """ Decodes a list of tokens to a string and removes a space before supported punctuation marks. + Optionally strips language-ID tags (e.g. ````) when ``strip_lang_tags`` is enabled. """ text = self.decode_ids_to_str(tokens) if self.supported_punctuation: text = self.space_before_punct_pattern.sub(r'\2', text) + if self.strip_lang_tags: + text = self.lang_tag_pattern.sub('', text).strip() return text def update_joint_fused_batch_size(self): @@ -1855,6 +1862,10 @@ class RNNTDecodingConfig: # config for multiblank decoding. big_blank_durations: Optional[List[int]] = field(default_factory=list) + # Strip language-ID tags (e.g. ) from decoded output. + # Enable for prompt-conditioned models that emit locale tags after punctuation. + strip_lang_tags: bool = True + @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig): From 6585529e4cce95025342a4b7923af74297068e2f Mon Sep 17 00:00:00 2001 From: Enas Albasiri Date: Wed, 6 May 2026 14:44:30 +0000 Subject: [PATCH 4/9] fix code style black formatting, flake8 errors, and type hints Signed-off-by: Enas Albasiri --- ...ech_to_text_cache_aware_streaming_infer.py | 4 +-- .../data/audio_to_text_lhotse_prompt_index.py | 26 +++++++-------- .../hybrid_rnnt_ctc_bpe_models_prompt.py | 30 ++++++++++------- .../asr/models/rnnt_bpe_models_prompt.py | 32 +++++++++---------- .../asr/parts/submodules/rnnt_decoding.py | 4 +-- 5 files changed, 47 insertions(+), 49 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 96b3d307fcb8..6eecb05b5dfa 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -293,9 +293,7 @@ def perform_streaming( pred_out_offline_cat = torch.cat(pred_out_offline) if pred_out_stream_cat.size() == pred_out_offline_cat.size(): diff_num = torch.sum(pred_out_stream_cat != pred_out_offline_cat).cpu().numpy() - logging.info( - f"Found {diff_num} differences in the outputs of the model in streaming mode vs offline mode." - ) + logging.info(f"Found {diff_num} differences in the outputs of the model in streaming mode vs offline mode.") else: logging.info( f"The shape of the outputs of the model in streaming mode ({pred_out_stream_cat.size()}) is different from offline mode ({pred_out_offline_cat.size()})." diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py index f3c93bc3557a..6647ad3ae707 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -34,11 +34,11 @@ class LhotseSpeechToTextBpeDatasetWithPromptIndex(torch.utils.data.Dataset): """ Simplified dataset class for speech-to-text with prompt support. - + Instead of computing full prompt tensors, this dataset returns just the language ID index per sample. The model creates the prompt tensor using the actual encoder output length, guaranteeing no size mismatch. - + Returns: audio_signal: Audio waveform [B, T] audio_signal_length: Audio lengths [B] @@ -57,7 +57,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'prompt_indices': NeuralType(tuple('B'), LabelsType()), # Just indices, not full tensors } - def __init__(self, tokenizer, cfg): + def __init__(self, tokenizer: TokenizerSpec, cfg: Dict) -> None: super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) @@ -67,7 +67,7 @@ def __init__(self, tokenizer, cfg): self.prompt_dict = cfg.get('prompt_dictionary') if not self.prompt_dict: raise ValueError("prompt_dictionary is required in config") - + self.num_prompts = cfg.get('num_prompts', 128) # Field to use for prompt key (default to 'target_lang') @@ -149,21 +149,18 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts] # Get prompt indices (just the language ID per sample, NOT full tensors) - prompt_indices = torch.tensor( - [self._get_prompt_index_for_cut(c) for c in cuts], - dtype=torch.long - ) + prompt_indices = torch.tensor([self._get_prompt_index_for_cut(c) for c in cuts], dtype=torch.long) # Create final tensors token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) tokens = collate_vectors(tokens, padding_value=0) return ( - audio, # Audio signal [B, T] - audio_lens, # Audio lengths [B] - tokens, # Text tokens [B, T] - token_lens, # Token lengths [B] - prompt_indices, # Language ID indices [B] - model creates full tensor + audio, # Audio signal [B, T] + audio_lens, # Audio lengths [B] + tokens, # Text tokens [B, T] + token_lens, # Token lengths [B] + prompt_indices, # Language ID indices [B] - model creates full tensor ) @@ -183,7 +180,8 @@ def __call__(self, text: str, lang: Optional[str] = None): return self._impl(text, lang) def _call_agg_tokenizer(self, text: str, lang: Optional[str] = None): - assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer." + if lang is None: + raise ValueError("Expected 'lang' to be set for AggregateTokenizer.") return self._tokenizer.text_to_ids(text, lang) def _call_tokenizer(self, text: str, lang: Optional[str] = None): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 83dc6b42933f..12b8b5d3b465 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -24,7 +24,6 @@ from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset, DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset -from nemo.collections.asr.data.audio_to_text_lhotse_prompt import LhotseSpeechToTextBpeDatasetWithPrompt from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex from nemo.collections.asr.metrics.bleu import BLEU from nemo.collections.asr.metrics.wer import WER @@ -60,7 +59,6 @@ class HybridRNNTCTCPromptTranscribeConfig(TranscribeConfig): prompt_field: str = "target_lang" - class EncDecHybridRNNTCTCBPEModelWithPrompt(EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): """Base class for encoder decoder RNNT-based models with auxiliary CTC decoder/loss, subword tokenization, and prompt conditioning.""" @@ -112,8 +110,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if 'prompt_dictionary' not in cfg.model_defaults: logging.warning( - "No prompt_dictionary in config; using empty dict " - "(expected during checkpoint restoration)." + "No prompt_dictionary in config; using empty dict " "(expected during checkpoint restoration)." ) cfg.model_defaults.prompt_dictionary = {} @@ -236,12 +233,17 @@ def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: batch_size, time_steps, _ = encoded.shape prompt = torch.zeros( - batch_size, time_steps, self.num_prompts, - dtype=encoded.dtype, device=encoded.device, + batch_size, + time_steps, + self.num_prompts, + dtype=encoded.dtype, + device=encoded.device, ) idx = torch.full( - (batch_size,), self._inference_prompt_index, - dtype=torch.long, device=encoded.device, + (batch_size,), + self._inference_prompt_index, + dtype=torch.long, + device=encoded.device, ) prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) @@ -365,9 +367,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get('initialize_prompt_feature', True): # Use index-based dataset - returns prompt indices instead of full tensors # The model creates prompt tensors after encoding, guaranteeing no size mismatch - dataset_config = OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) + dataset_config = ( + OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) + ) if hasattr(self, 'cfg') and 'encoder' in self.cfg: - dataset_config['encoder'] = OmegaConf.to_container(self.cfg.encoder, resolve=True) if isinstance(self.cfg.encoder, DictConfig) else dict(self.cfg.encoder) + dataset_config['encoder'] = ( + OmegaConf.to_container(self.cfg.encoder, resolve=True) + if isinstance(self.cfg.encoder, DictConfig) + else dict(self.cfg.encoder) + ) dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex(tokenizer=self.tokenizer, cfg=dataset_config) logging.info("Setting up Lhotse dataset with prompt index support (model creates prompt tensors)") else: @@ -532,7 +540,6 @@ def _transcribe_forward(self, batch: tuple[torch.Tensor, ...], trcfg: HybridRNNT return output - @torch.no_grad() def transcribe( self, @@ -1185,4 +1192,3 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: List of available pre-trained models. """ return None - diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 57bdc2b41cb8..334f2a374be6 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -49,6 +49,8 @@ @dataclass class RNNTPromptTranscribeConfig(TranscribeConfig): + """Transcription configuration for RNNT BPE Model with Prompt conditioning.""" + target_lang: str = "auto" prompt_field: str = "target_lang" @@ -93,8 +95,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if 'prompt_dictionary' not in cfg.model_defaults: logging.warning( - "No prompt_dictionary in config; using empty dict " - "(expected during checkpoint restoration)." + "No prompt_dictionary in config; using empty dict " "(expected during checkpoint restoration)." ) cfg.model_defaults.prompt_dictionary = {} @@ -213,12 +214,17 @@ def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: batch_size, time_steps, _ = encoded.shape prompt = torch.zeros( - batch_size, time_steps, self.num_prompts, - dtype=encoded.dtype, device=encoded.device, + batch_size, + time_steps, + self.num_prompts, + dtype=encoded.dtype, + device=encoded.device, ) idx = torch.full( - (batch_size,), self._inference_prompt_index, - dtype=torch.long, device=encoded.device, + (batch_size,), + self._inference_prompt_index, + dtype=torch.long, + device=encoded.device, ) prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) @@ -250,8 +256,6 @@ def conformer_stream_step( Set the target language via ``set_inference_prompt(target_lang)`` before calling this method. """ - import nemo.collections.asr.models as asr_models - if not isinstance(self.encoder, StreamingEncoder): raise NotImplementedError("Encoder does not support streaming!") @@ -298,9 +302,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): dataset_config = ( - OmegaConf.to_container(config, resolve=True) - if isinstance(config, DictConfig) - else dict(config) + OmegaConf.to_container(config, resolve=True) if isinstance(config, DictConfig) else dict(config) ) if hasattr(self, 'cfg') and 'encoder' in self.cfg: dataset_config['encoder'] = ( @@ -308,9 +310,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if isinstance(self.cfg.encoder, DictConfig) else dict(self.cfg.encoder) ) - dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex( - tokenizer=self.tokenizer, cfg=dataset_config - ) + dataset = LhotseSpeechToTextBpeDatasetWithPromptIndex(tokenizer=self.tokenizer, cfg=dataset_config) logging.info( "Setting up Lhotse dataset with prompt index support (RNNT-only model creates prompt tensors)" ) @@ -691,7 +691,6 @@ def _transcribe_forward(self, batch, trcfg: RNNTPromptTranscribeConfig) -> dict: if prompt_indices is None: target_lang = trcfg.target_lang prompt_dict = self.cfg.model_defaults.get('prompt_dictionary') - num_prompts = self.cfg.model_defaults.get('num_prompts', 128) if not prompt_dict: raise ValueError("Prompt dictionary is empty. Cannot create dynamic prompts.") @@ -758,8 +757,7 @@ def transcribe( else: if not isinstance(override_config, RNNTPromptTranscribeConfig): raise ValueError( - f"override_config must be of type {RNNTPromptTranscribeConfig}, " - f"but got {type(override_config)}" + f"override_config must be of type {RNNTPromptTranscribeConfig}, " f"but got {type(override_config)}" ) trcfg = override_config diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 15b3139bcdae..9fb6b353b65b 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -763,9 +763,7 @@ def rnnt_decoder_predictions_tensor( if return_hypotheses: # greedy decoding, can get high-level confidence scores - if self.preserve_frame_confidence and ( - self.preserve_word_confidence or self.preserve_token_confidence - ): + if self.preserve_frame_confidence and (self.preserve_word_confidence or self.preserve_token_confidence): hypotheses = self.compute_confidence(hypotheses) return hypotheses From 2ca55f4b8ec87b1c6c1d32cec46e33d8dd9f7d23 Mon Sep 17 00:00:00 2001 From: Enas Albasiri Date: Wed, 6 May 2026 15:07:02 +0000 Subject: [PATCH 5/9] fix code style black formatting Signed-off-by: Enas Albasiri --- .../speech_to_text_cache_aware_streaming_infer.py | 4 +++- nemo/collections/asr/models/rnnt_bpe_models_prompt.py | 3 ++- nemo/collections/asr/parts/submodules/rnnt_decoding.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 6eecb05b5dfa..96b3d307fcb8 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -293,7 +293,9 @@ def perform_streaming( pred_out_offline_cat = torch.cat(pred_out_offline) if pred_out_stream_cat.size() == pred_out_offline_cat.size(): diff_num = torch.sum(pred_out_stream_cat != pred_out_offline_cat).cpu().numpy() - logging.info(f"Found {diff_num} differences in the outputs of the model in streaming mode vs offline mode.") + logging.info( + f"Found {diff_num} differences in the outputs of the model in streaming mode vs offline mode." + ) else: logging.info( f"The shape of the outputs of the model in streaming mode ({pred_out_stream_cat.size()}) is different from offline mode ({pred_out_offline_cat.size()})." diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 334f2a374be6..77b145c569c5 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -757,7 +757,8 @@ def transcribe( else: if not isinstance(override_config, RNNTPromptTranscribeConfig): raise ValueError( - f"override_config must be of type {RNNTPromptTranscribeConfig}, " f"but got {type(override_config)}" + f"override_config must be of type {RNNTPromptTranscribeConfig}, " + f"but got {type(override_config)}" ) trcfg = override_config diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 9fb6b353b65b..15b3139bcdae 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -763,7 +763,9 @@ def rnnt_decoder_predictions_tensor( if return_hypotheses: # greedy decoding, can get high-level confidence scores - if self.preserve_frame_confidence and (self.preserve_word_confidence or self.preserve_token_confidence): + if self.preserve_frame_confidence and ( + self.preserve_word_confidence or self.preserve_token_confidence + ): hypotheses = self.compute_confidence(hypotheses) return hypotheses From 299cf1e718e83cb08b94677c5d240e1d2c18e845 Mon Sep 17 00:00:00 2001 From: Jinhan Date: Tue, 12 May 2026 21:38:34 -0700 Subject: [PATCH 6/9] Resolve comments: 1. Set strip_lang_tags default False for inference script, manual enable by setting stip_lang_tags=True at inference(or training) time. 2. Import TokenizerWrapper in dataloader Signed-off-by: Jinhan --- ...ech_to_text_cache_aware_streaming_infer.py | 5 ++++ .../data/audio_to_text_lhotse_prompt_index.py | 29 +------------------ .../asr/parts/submodules/rnnt_decoding.py | 7 ++++- 3 files changed, 12 insertions(+), 29 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 96b3d307fcb8..8b777312e650 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -190,6 +190,9 @@ class TranscriptionConfig: # Set to a language key from the model's prompt_dictionary (e.g. "en-US", "auto"). # Ignored for models without prompt support. target_lang: Optional[str] = None + # whether to strip the language tags from the transcriptions + # Ignored for model without prompt support + strip_lang_tags: bool = False def extract_transcriptions(hyps): @@ -372,6 +375,8 @@ def main(cfg: TranscriptionConfig): if hasattr(asr_model, 'set_inference_prompt'): lang = cfg.target_lang if cfg.target_lang is not None else "auto" asr_model.set_inference_prompt(lang) + asr_model.decoding.strip_lang_tags = cfg.strip_lang_tags + asr_model.decoding.set_strip_lang_tags(cfg.strip_lang_tags) asr_model = asr_model.to(device=device, dtype=compute_dtype) asr_model.eval() diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py index 6647ad3ae707..249248a75f4c 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -25,7 +25,7 @@ from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_vectors -from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.utils import logging @@ -162,30 +162,3 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: token_lens, # Token lengths [B] prompt_indices, # Language ID indices [B] - model creates full tensor ) - - -class TokenizerWrapper: - """Provide a unified interface for NeMo Tokenizer, AggregateTokenizer, and (char) Parser.""" - - def __init__(self, tokenizer): - self._tokenizer = tokenizer - if isinstance(tokenizer, AggregateTokenizer): - self._impl = self._call_agg_tokenizer - elif isinstance(tokenizer, TokenizerSpec): - self._impl = self._call_tokenizer - else: - self._impl = self._call_parser - - def __call__(self, text: str, lang: Optional[str] = None): - return self._impl(text, lang) - - def _call_agg_tokenizer(self, text: str, lang: Optional[str] = None): - if lang is None: - raise ValueError("Expected 'lang' to be set for AggregateTokenizer.") - return self._tokenizer.text_to_ids(text, lang) - - def _call_tokenizer(self, text: str, lang: Optional[str] = None): - return self._tokenizer.text_to_ids(text) - - def _call_parser(self, text: str, lang: Optional[str] = None): - return self._tokenizer(text) diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 15b3139bcdae..be3ffcb25f41 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -685,6 +685,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu # Update the joint fused batch size or disable it entirely if needed. self.update_joint_fused_batch_size() + + def set_strip_lang_tags(self, strip_lang_tags: bool): + if strip_lang_tags: + logging.info("Setting strip_lang_tags to True and defined lang_tag_pattern to ") + self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>') @abstractproperty def tokenizer_type(self): @@ -1864,7 +1869,7 @@ class RNNTDecodingConfig: # Strip language-ID tags (e.g. ) from decoded output. # Enable for prompt-conditioned models that emit locale tags after punctuation. - strip_lang_tags: bool = True + strip_lang_tags: bool = False @dataclass From 2d8fcad824680d21cdc316e0b0a7dce119190a15 Mon Sep 17 00:00:00 2001 From: Jinhan Date: Fri, 15 May 2026 14:19:18 -0700 Subject: [PATCH 7/9] Resolve critical structural comments and remove un-used config from rnnt prompt model yaml Signed-off-by: Jinhan --- ...ormer_transducer_bpe_streaming_prompt.yaml | 4 - .../hybrid_rnnt_ctc_bpe_models_prompt.py | 111 ------------------ .../asr/models/rnnt_bpe_models_prompt.py | 65 ---------- nemo/collections/asr/parts/mixins/mixins.py | 7 ++ 4 files changed, 7 insertions(+), 180 deletions(-) diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml index 7383eda7f40d..4fc45a047c2e 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml @@ -147,7 +147,6 @@ model: manifest_filepath: ??? sample_rate: ${model.sample_rate} use_lhotse: true - shard_manifests: true batch_duration: 400 quadratic_duration: 15 num_buckets: 30 @@ -158,9 +157,7 @@ model: min_duration: 0.1 is_tarred: true tarred_audio_filepaths: null - shuffle_n: 2048 slice_length: 100 - bucketing_strategy: "fully_randomized" bucketing_batch_size: null bucket_buffer_size: 10000 shuffle_buffer_size: 10000 @@ -196,7 +193,6 @@ model: batch_duration: null use_lhotse: true use_bucketing: false - max_cuts: 8 prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 12b8b5d3b465..07f2b96b95f3 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -251,117 +251,6 @@ def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) - def conformer_stream_step( - self, - processed_signal, - processed_signal_length=None, - cache_last_channel=None, - cache_last_time=None, - cache_last_channel_len=None, - keep_all_outputs=True, - previous_hypotheses=None, - previous_pred_out=None, - drop_extra_pre_encoded=None, - return_transcription=True, - return_log_probs=False, - bypass_pre_encode=False, - ): - """Cache-aware streaming step with language-ID prompt injection. - - Identical to the base ``ASRModuleMixin.conformer_stream_step`` except - that after the encoder step, ``_apply_prompt_to_encoded`` concatenates - the one-hot language prompt and projects back to enc_hidden. - - Set the target language via ``set_inference_prompt(target_lang)`` - before calling this method. - """ - import nemo.collections.asr.models as asr_models - - if not isinstance(self.encoder, StreamingEncoder): - raise NotImplementedError("Encoder does not support streaming!") - - ( - encoded, - encoded_len, - cache_last_channel_next, - cache_last_time_next, - cache_last_channel_next_len, - ) = self.encoder.cache_aware_stream_step( - processed_signal=processed_signal, - processed_signal_length=processed_signal_length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - keep_all_outputs=keep_all_outputs, - drop_extra_pre_encoded=drop_extra_pre_encoded, - bypass_pre_encode=bypass_pre_encode, - ) - - encoded = self._apply_prompt_to_encoded(encoded) - - if isinstance(self, asr_models.EncDecCTCModel) or ( - isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc" - ): - if hasattr(self, "ctc_decoder"): - decoding = self.ctc_decoding - decoder = self.ctc_decoder - else: - decoding = self.decoding - decoder = self.decoder - - log_probs = decoder(encoder_output=encoded) - predictions_tensor = log_probs.argmax(dim=-1, keepdim=False) - - greedy_predictions = [] - if return_transcription: - all_hyp_or_transcribed_texts = [] - else: - all_hyp_or_transcribed_texts = None - - for preds_idx, preds in enumerate(predictions_tensor): - if encoded_len is None: - preds_cur = predictions_tensor[preds_idx] - else: - preds_cur = predictions_tensor[preds_idx, : encoded_len[preds_idx]] - if previous_pred_out is not None: - greedy_predictions_concat = torch.cat((previous_pred_out[preds_idx], preds_cur), dim=-1) - encoded_len[preds_idx] += len(previous_pred_out[preds_idx]) - else: - greedy_predictions_concat = preds_cur - greedy_predictions.append(greedy_predictions_concat) - - if return_transcription: - decoded_out = decoding.ctc_decoder_predictions_tensor( - decoder_outputs=greedy_predictions_concat.unsqueeze(0), - decoder_lengths=encoded_len[preds_idx : preds_idx + 1], - return_hypotheses=False, - ) - all_hyp_or_transcribed_texts.append(decoded_out[0]) - best_hyp = None - else: - best_hyp = self.decoding.rnnt_decoder_predictions_tensor( - encoder_output=encoded, - encoded_lengths=encoded_len, - return_hypotheses=True, - partial_hypotheses=previous_hypotheses, - ) - greedy_predictions = [hyp.y_sequence for hyp in best_hyp] - all_hyp_or_transcribed_texts = best_hyp - - result = [ - greedy_predictions, - all_hyp_or_transcribed_texts, - cache_last_channel_next, - cache_last_time_next, - cache_last_channel_next_len, - best_hyp, - ] - if return_log_probs: - result.append(log_probs) - result.append(encoded_len) - - return tuple(result) - def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 77b145c569c5..237fe8ba30b2 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -232,71 +232,6 @@ def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) - def conformer_stream_step( - self, - processed_signal, - processed_signal_length=None, - cache_last_channel=None, - cache_last_time=None, - cache_last_channel_len=None, - keep_all_outputs=True, - previous_hypotheses=None, - previous_pred_out=None, - drop_extra_pre_encoded=None, - return_transcription=True, - return_log_probs=False, - bypass_pre_encode=False, - ): - """Cache-aware streaming step with language-ID prompt injection. - - Identical to the base ``ASRModuleMixin.conformer_stream_step`` except - that after the encoder step, ``_apply_prompt_to_encoded`` concatenates - the one-hot language prompt and projects back to enc_hidden. - - Set the target language via ``set_inference_prompt(target_lang)`` - before calling this method. - """ - if not isinstance(self.encoder, StreamingEncoder): - raise NotImplementedError("Encoder does not support streaming!") - - ( - encoded, - encoded_len, - cache_last_channel_next, - cache_last_time_next, - cache_last_channel_next_len, - ) = self.encoder.cache_aware_stream_step( - processed_signal=processed_signal, - processed_signal_length=processed_signal_length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - keep_all_outputs=keep_all_outputs, - drop_extra_pre_encoded=drop_extra_pre_encoded, - bypass_pre_encode=bypass_pre_encode, - ) - - encoded = self._apply_prompt_to_encoded(encoded) - - best_hyp = self.decoding.rnnt_decoder_predictions_tensor( - encoder_output=encoded, - encoded_lengths=encoded_len, - return_hypotheses=True, - partial_hypotheses=previous_hypotheses, - ) - greedy_predictions = [hyp.y_sequence for hyp in best_hyp] - all_hyp_or_transcribed_texts = best_hyp - - result = [ - greedy_predictions, - all_hyp_or_transcribed_texts, - cache_last_channel_next, - cache_last_time_next, - cache_last_channel_next_len, - best_hyp, - ] - return tuple(result) - # Data loading def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index af973be3cc4c..a83467fdbb96 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -588,6 +588,11 @@ def change_subsampling_conv_chunking_factor( if update_config: with open_dict(self.cfg): self.cfg.encoder.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + def _apply_prompt_to_encoded(self, encoded: Tensor) -> Tensor: + """Hook for prompt-conditioned subclasses to inject a language prompt + into the encoder output. Default: no-op.""" + return encoded def conformer_stream_step( self, @@ -661,6 +666,8 @@ def conformer_stream_step( bypass_pre_encode=bypass_pre_encode, ) + encoded = self._apply_prompt_to_encoded(encoded) + if isinstance(self, asr_models.EncDecCTCModel) or ( isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc" ): From c76c347970811680f2225bfa1a237d2f58a4f779 Mon Sep 17 00:00:00 2001 From: Jinhan Date: Thu, 21 May 2026 12:27:36 -0700 Subject: [PATCH 8/9] Address comment: 1. Consolidate conformer stream step into mixin. 2. Consolidate streaming inference into mixin 3. Remove unused option in config.yaml 3. Consolidate strip_lang_tag into unified function. 4. Make lang_tag pattern customizable Signed-off-by: Jinhan --- ...ech_to_text_cache_aware_streaming_infer.py | 5 +- ...ormer_transducer_bpe_streaming_prompt.yaml | 4 +- .../data/audio_to_text_lhotse_prompt_index.py | 13 +--- .../hybrid_rnnt_ctc_bpe_models_prompt.py | 60 +--------------- .../asr/models/rnnt_bpe_models_prompt.py | 59 +--------------- nemo/collections/asr/parts/mixins/__init__.py | 2 + nemo/collections/asr/parts/mixins/mixins.py | 68 ++++++++++++++++++- .../asr/parts/submodules/rnnt_decoding.py | 31 +++++++-- 8 files changed, 104 insertions(+), 138 deletions(-) diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 8b777312e650..67c18a366075 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -193,6 +193,8 @@ class TranscriptionConfig: # whether to strip the language tags from the transcriptions # Ignored for model without prompt support strip_lang_tags: bool = False + # Optional regex describing the language tag to strip. Defaults to "". (r'\s*<[a-z]{2}-[A-Z]{2}>') + lang_tag_pattern: Optional[str] = None def extract_transcriptions(hyps): @@ -375,8 +377,7 @@ def main(cfg: TranscriptionConfig): if hasattr(asr_model, 'set_inference_prompt'): lang = cfg.target_lang if cfg.target_lang is not None else "auto" asr_model.set_inference_prompt(lang) - asr_model.decoding.strip_lang_tags = cfg.strip_lang_tags - asr_model.decoding.set_strip_lang_tags(cfg.strip_lang_tags) + asr_model.decoding.set_strip_lang_tags(cfg.strip_lang_tags, lang_tag_pattern=cfg.lang_tag_pattern) asr_model = asr_model.to(device=device, dtype=compute_dtype) asr_model.eval() diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml index 4fc45a047c2e..39dfa96efb3d 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml @@ -166,7 +166,6 @@ model: num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} lang_field: target_lang - training_mode: true # Per-dataset prompt mode — controls how language prompts are selected during training. # The mode is set per data source via lhotse input_cfg tags: @@ -197,7 +196,6 @@ model: prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} - training_mode: true test_ds: manifest_filepath: null @@ -213,7 +211,7 @@ model: prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} - training_mode: false + default_prompt_mode: langID tokenizer: dir: ??? diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py index 249248a75f4c..e99a31588e8e 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -73,8 +73,6 @@ def __init__(self, tokenizer: TokenizerSpec, cfg: Dict) -> None: # Field to use for prompt key (default to 'target_lang') self.prompt_field = cfg.get('prompt_field', 'target_lang') - self.training_mode = cfg.get('training_mode', True) - # Per-dataset prompt mode is read from cut.custom["prompt_mode"] at runtime. # Supported values: # "langID" — always pass the real language ID @@ -115,19 +113,14 @@ def _get_prompt_index_for_cut(self, cut) -> int: """ Determine the prompt index for a cut based on its prompt_mode tag. - During inference (training_mode=False): always returns the real lang ID - regardless of prompt_mode. - - During training, behaviour depends on prompt_mode (set per-dataset via - lhotse input_cfg tags): + Behaviour depends on prompt_mode (set per-dataset via lhotse + input_cfg tags, falling back to ``default_prompt_mode``): "langID" — always return the real language ID + (use for inference / language-forced tasks) "auto" — always return auto index (language-agnostic) "unified" — return auto with probability unified_auto_ratio, otherwise the real language ID """ - if not self.training_mode: - return self._get_prompt_index(cut.supervisions[0].language) - mode = self._get_prompt_mode(cut) if mode == 'langID': diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 07f2b96b95f3..e90851f3c8a4 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -28,7 +28,7 @@ from nemo.collections.asr.metrics.bleu import BLEU from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel -from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, PromptStreamingMixin, TranscribeConfig from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType @@ -59,7 +59,7 @@ class HybridRNNTCTCPromptTranscribeConfig(TranscribeConfig): prompt_field: str = "target_lang" -class EncDecHybridRNNTCTCBPEModelWithPrompt(EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): +class EncDecHybridRNNTCTCBPEModelWithPrompt(PromptStreamingMixin, EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): """Base class for encoder decoder RNNT-based models with auxiliary CTC decoder/loss, subword tokenization, and prompt conditioning.""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): @@ -195,62 +195,6 @@ def initialize_prompt_feature(self): # setting the RNNT decoder as the default one self.cur_decoder = "rnnt" - # Streaming inference with language-ID prompt - - def set_inference_prompt(self, target_lang: str): - """ - Set the language prompt for streaming inference. - - Call this before ``conformer_stream_step`` to condition decoding on - a specific language, following the same pattern as - ``change_decoding_strategy``. - - Args: - target_lang: A key from the model's ``prompt_dictionary`` - (e.g. ``"en-US"``, ``"auto"``). - """ - prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) - if target_lang not in prompt_dict: - available = list(prompt_dict.keys()) - raise ValueError( - f"Unknown target language '{target_lang}'. " - f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" - ) - self._inference_prompt_index = prompt_dict[target_lang] - logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") - - def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: - """ - Inject the language-ID prompt into encoder output during streaming. - - ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. - Returns the same shape after prompt concatenation + projection. - """ - if not self.concat or not hasattr(self, '_inference_prompt_index'): - return encoded - - encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) - - batch_size, time_steps, _ = encoded.shape - prompt = torch.zeros( - batch_size, - time_steps, - self.num_prompts, - dtype=encoded.dtype, - device=encoded.device, - ) - idx = torch.full( - (batch_size,), - self._inference_prompt_index, - dtype=torch.long, - device=encoded.device, - ) - prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) - - out_dtype = encoded.dtype - encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) - return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) - def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): if config.get('initialize_prompt_feature', True): diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 237fe8ba30b2..00eb42d3ff51 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -27,7 +27,7 @@ from nemo.collections.asr.data.audio_to_text_lhotse_prompt_index import LhotseSpeechToTextBpeDatasetWithPromptIndex from nemo.collections.asr.metrics.wer import WER from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel -from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins import ASRTranscriptionMixin, PromptStreamingMixin, TranscribeConfig from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType @@ -55,7 +55,7 @@ class RNNTPromptTranscribeConfig(TranscribeConfig): prompt_field: str = "target_lang" -class EncDecRNNTBPEModelWithPrompt(EncDecRNNTBPEModel, ASRTranscriptionMixin): +class EncDecRNNTBPEModelWithPrompt(PromptStreamingMixin, EncDecRNNTBPEModel, ASRTranscriptionMixin): """Encoder-decoder RNNT model with subword tokenization and prompt conditioning. This is the RNNT-only variant (no auxiliary CTC head) of the prompt-aware @@ -177,61 +177,6 @@ def initialize_prompt_feature(self): self.joint.set_loss(self.loss) self.joint.set_wer(self.wer) - # Streaming inference with language-ID prompt - def set_inference_prompt(self, target_lang: str): - """ - Set the language prompt for streaming inference. - - Call this before ``conformer_stream_step`` to condition decoding on - a specific language, following the same pattern as - ``change_decoding_strategy``. - - Args: - target_lang: A key from the model's ``prompt_dictionary`` - (e.g. ``"en-US"``, ``"auto"``). - """ - prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) - if target_lang not in prompt_dict: - available = list(prompt_dict.keys()) - raise ValueError( - f"Unknown target language '{target_lang}'. " - f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" - ) - self._inference_prompt_index = prompt_dict[target_lang] - logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") - - def _apply_prompt_to_encoded(self, encoded: torch.Tensor) -> torch.Tensor: - """ - Inject the language-ID prompt into encoder output during streaming. - - ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. - Returns the same shape after prompt concatenation + projection. - """ - if not self.concat or not hasattr(self, '_inference_prompt_index'): - return encoded - - encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) - - batch_size, time_steps, _ = encoded.shape - prompt = torch.zeros( - batch_size, - time_steps, - self.num_prompts, - dtype=encoded.dtype, - device=encoded.device, - ) - idx = torch.full( - (batch_size,), - self._inference_prompt_index, - dtype=torch.long, - device=encoded.device, - ) - prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) - - out_dtype = encoded.dtype - encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) - return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) - # Data loading def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): diff --git a/nemo/collections/asr/parts/mixins/__init__.py b/nemo/collections/asr/parts/mixins/__init__.py index c8bfd1503454..31b263a4e37a 100644 --- a/nemo/collections/asr/parts/mixins/__init__.py +++ b/nemo/collections/asr/parts/mixins/__init__.py @@ -19,6 +19,7 @@ ASRBPEMixin, ASRModuleMixin, DiarizationMixin, + PromptStreamingMixin, ) from nemo.collections.asr.parts.mixins.multitalker_asr_mixins import SpeakerKernelMixin from nemo.collections.asr.parts.mixins.transcription import ( @@ -35,6 +36,7 @@ 'ASRTranscriptionMixin', 'DiarizationMixin', 'InterCTCMixin', + 'PromptStreamingMixin', 'SpeakerKernelMixin', 'TranscribeConfig', 'TranscriptionMixin', diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index a83467fdbb96..51c03b1dfdce 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -591,7 +591,8 @@ def change_subsampling_conv_chunking_factor( def _apply_prompt_to_encoded(self, encoded: Tensor) -> Tensor: """Hook for prompt-conditioned subclasses to inject a language prompt - into the encoder output. Default: no-op.""" + into the encoder output. Default: no-op. See ``PromptStreamingMixin`` + for the prompt-aware override.""" return encoded def conformer_stream_step( @@ -904,3 +905,68 @@ def diarize(self, paths2audio_files: List[str], batch_size: int = 1) -> List[str Speaker labels """ pass + +class PromptStreamingMixin: + """Adds language-ID prompt conditioning to a cache-aware streaming ASR model. + + Overrides ``ASRModuleMixin._apply_prompt_to_encoded`` so that + ``conformer_stream_step`` injects a one-hot language prompt into the + encoder output. The host class is expected to set ``self.concat``, + ``self.num_prompts``, and ``self.prompt_kernel`` during its own + ``initialize_prompt_feature``. + """ + + def set_inference_prompt(self, target_lang: str): + """ + Set the language prompt for streaming inference. + + Call this before ``conformer_stream_step`` to condition decoding on + a specific language, following the same pattern as + ``change_decoding_strategy``. + + Args: + target_lang: A key from the model's ``prompt_dictionary`` + (e.g. ``"en-US"``, ``"auto"``). + """ + prompt_dict = self.cfg.model_defaults.get('prompt_dictionary', {}) + if target_lang not in prompt_dict: + available = list(prompt_dict.keys()) + raise ValueError( + f"Unknown target language '{target_lang}'. " + f"Available: {available[:20]}{'...' if len(available) > 20 else ''}" + ) + self._inference_prompt_index = prompt_dict[target_lang] + logging.info(f"Inference prompt set to '{target_lang}' (index {self._inference_prompt_index})") + + def _apply_prompt_to_encoded(self, encoded: Tensor) -> Tensor: + """ + Inject the language-ID prompt into encoder output during streaming. + + ``encoded`` arrives as (B, D, T) from the encoder cache-aware step. + Returns the same shape after prompt concatenation + projection. + """ + if not self.concat or not hasattr(self, '_inference_prompt_index'): + return encoded + + encoded = encoded.transpose(1, 2) # (B, D, T) -> (B, T, D) + + batch_size, time_steps, _ = encoded.shape + prompt = torch.zeros( + batch_size, + time_steps, + self.num_prompts, + dtype=encoded.dtype, + device=encoded.device, + ) + idx = torch.full( + (batch_size,), + self._inference_prompt_index, + dtype=torch.long, + device=encoded.device, + ) + prompt.scatter_(2, idx.view(batch_size, 1, 1).expand(-1, time_steps, -1), 1.0) + + out_dtype = encoded.dtype + encoded = self.prompt_kernel(torch.cat([encoded, prompt], dim=-1)).to(out_dtype) + return encoded.transpose(1, 2) # (B, T, D) -> (B, D, T) + diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index be3ffcb25f41..4160d00896fa 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -329,10 +329,11 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu punct_pattern = '|'.join([re.escape(p) for p in self.supported_punctuation]) self.space_before_punct_pattern = re.compile(r'(\s)(' + punct_pattern + ')') - self.strip_lang_tags = self.cfg.get('strip_lang_tags', False) - if self.strip_lang_tags: - self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>') - + self.set_strip_lang_tags( + self.cfg.get('strip_lang_tags', False), + lang_tag_pattern=self.cfg.get('lang_tag_pattern', None), + ) + # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) @@ -686,10 +687,22 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu # Update the joint fused batch size or disable it entirely if needed. self.update_joint_fused_batch_size() - def set_strip_lang_tags(self, strip_lang_tags: bool): + def set_strip_lang_tags(self, strip_lang_tags: bool, lang_tag_pattern: Optional[str] = None): + """ + Toggle language-tag stripping on decoded text. + + Args: + strip_lang_tags: Whether ``decode_tokens_to_str_with_strip_punctuation`` + should remove language tags from its output. + lang_tag_pattern: Optional regex (as a string) describing the tag to + strip. Defaults to ``\\s*<[a-z]{2}-[A-Z]{2}>`` (````). + Ignored when ``strip_lang_tags`` is False. + """ + self.strip_lang_tags = strip_lang_tags if strip_lang_tags: - logging.info("Setting strip_lang_tags to True and defined lang_tag_pattern to ") - self.lang_tag_pattern = re.compile(r'\s*<[a-z]{2}-[A-Z]{2}>') + pattern = lang_tag_pattern if lang_tag_pattern is not None else r'\s*<[a-z]{2}-[A-Z]{2}>' + logging.info(f"Setting strip_lang_tags to True with lang_tag_pattern={pattern!r}") + self.lang_tag_pattern = re.compile(pattern) @abstractproperty def tokenizer_type(self): @@ -1871,6 +1884,10 @@ class RNNTDecodingConfig: # Enable for prompt-conditioned models that emit locale tags after punctuation. strip_lang_tags: bool = False + # Optional regex (as a string) describing the language tag to strip. + # When None, defaults to ``DEFAULT_LANG_TAG_PATTERN`` (``\s*<[a-z]{2}-[A-Z]{2}>``). + lang_tag_pattern: Optional[str] = None + @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig): From 92887d057ebd2ee9b2a0c4ed611ef56738833c7e Mon Sep 17 00:00:00 2001 From: Jinhan Date: Thu, 21 May 2026 13:50:12 -0700 Subject: [PATCH 9/9] Unified all lang_field to be target_lang Signed-off-by: Jinhan --- .../fastconformer_transducer_bpe_streaming_prompt.yaml | 5 ++--- ...fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml | 10 +++++----- .../asr/data/audio_to_text_lhotse_prompt_index.py | 3 --- .../asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py | 5 ----- nemo/collections/asr/models/rnnt_bpe_models_prompt.py | 4 ---- 5 files changed, 7 insertions(+), 20 deletions(-) diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml index 39dfa96efb3d..ee58ad6ec9cf 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming_prompt.yaml @@ -161,7 +161,6 @@ model: bucketing_batch_size: null bucket_buffer_size: 10000 shuffle_buffer_size: 10000 - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} @@ -192,10 +191,10 @@ model: batch_duration: null use_lhotse: true use_bucketing: false - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang test_ds: manifest_filepath: null @@ -207,10 +206,10 @@ model: pin_memory: true use_lhotse: true use_bucketing: false - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang default_prompt_mode: langID tokenizer: diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml index adca086e938f..022ff06eb21e 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe_prompt.yaml @@ -130,10 +130,10 @@ model: bucket_buffer_size: 10000 shuffle_buffer_size: 10000 #prompt configs - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} + lang_field: target_lang validation_ds: manifest_filepath: ??? @@ -149,11 +149,11 @@ model: max_cuts: 8 # prompt configurations for validation - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} - + lang_field: target_lang + test_ds: manifest_filepath: ??? sample_rate: ${model.sample_rate} @@ -165,11 +165,11 @@ model: use_lhotse: true use_bucketing: false # prompt configurations for testing - prompt_field: target_lang prompt_dictionary: ${model.model_defaults.prompt_dictionary} num_prompts: ${model.model_defaults.num_prompts} subsampling_factor: ${model.encoder.subsampling_factor} - + lang_field: target_lang + default_prompt_mode: langID # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py # We recommend to use vocab size of 1024 with SPE Unigram for most languages tokenizer: diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py index e99a31588e8e..d836f592cad8 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompt_index.py @@ -70,9 +70,6 @@ def __init__(self, tokenizer: TokenizerSpec, cfg: Dict) -> None: self.num_prompts = cfg.get('num_prompts', 128) - # Field to use for prompt key (default to 'target_lang') - self.prompt_field = cfg.get('prompt_field', 'target_lang') - # Per-dataset prompt mode is read from cut.custom["prompt_mode"] at runtime. # Supported values: # "langID" — always pass the real language ID diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index e90851f3c8a4..e50bdfe436c2 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -56,7 +56,6 @@ class HybridRNNTCTCPromptTranscribeConfig(TranscribeConfig): """ target_lang: str = "auto" - prompt_field: str = "target_lang" class EncDecHybridRNNTCTCBPEModelWithPrompt(PromptStreamingMixin, EncDecHybridRNNTCTCBPEModel, ASRTranscriptionMixin): @@ -299,7 +298,6 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'use_lhotse': config.get('use_lhotse', True), 'use_bucketing': False, 'drop_last': False, - 'prompt_field': config.get('prompt_field', 'target_lang'), 'initialize_prompt_feature': True, 'prompt_dictionary': self.cfg.model_defaults.get('prompt_dictionary'), 'num_prompts': self.cfg.model_defaults.get('num_prompts', 128), @@ -411,7 +409,6 @@ def transcribe( override_config: (Optional[HybridRNNTCTCPromptTranscribeConfig]) override transcription config pre-defined by the user. **prompt: Optional input to construct the prompts for the model. Accepted formats include: target_lang: (str) target language ID for transcription (e.g., "en-US", "de-DE") - prompt_field: (str) field name to use for prompt extraction from manifest Additional prompt parameters can be passed and will be forwarded to the transcription config. Returns: @@ -449,7 +446,6 @@ def transcribe( if override_config is None: # Extract target_lang from prompt or use default target_lang = prompt.get('target_lang', 'auto') - prompt_field = prompt.get('prompt_field', 'target_lang') trcfg = HybridRNNTCTCPromptTranscribeConfig( batch_size=batch_size, @@ -460,7 +456,6 @@ def transcribe( verbose=verbose, timestamps=timestamps, target_lang=target_lang, - prompt_field=prompt_field, ) else: diff --git a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py index 00eb42d3ff51..b909dce398d8 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models_prompt.py +++ b/nemo/collections/asr/models/rnnt_bpe_models_prompt.py @@ -52,7 +52,6 @@ class RNNTPromptTranscribeConfig(TranscribeConfig): """Transcription configuration for RNNT BPE Model with Prompt conditioning.""" target_lang: str = "auto" - prompt_field: str = "target_lang" class EncDecRNNTBPEModelWithPrompt(PromptStreamingMixin, EncDecRNNTBPEModel, ASRTranscriptionMixin): @@ -262,7 +261,6 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'use_lhotse': config.get('use_lhotse', True), 'use_bucketing': False, 'drop_last': False, - 'prompt_field': config.get('prompt_field', 'target_lang'), 'initialize_prompt_feature': True, 'prompt_dictionary': self.cfg.model_defaults.get('prompt_dictionary'), 'num_prompts': self.cfg.model_defaults.get('num_prompts', 128), @@ -621,7 +619,6 @@ def transcribe( if override_config is None: target_lang = prompt.get('target_lang', 'auto') - prompt_field = prompt.get('prompt_field', 'target_lang') trcfg = RNNTPromptTranscribeConfig( batch_size=batch_size, @@ -632,7 +629,6 @@ def transcribe( verbose=verbose, timestamps=timestamps, target_lang=target_lang, - prompt_field=prompt_field, ) else: if not isinstance(override_config, RNNTPromptTranscribeConfig):