-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Adding Conformer encoder I/O-styled Transformer encoder #15703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
tango4j
wants to merge
5
commits into
main
Choose a base branch
from
add_tf_encoder_asr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,135
−101
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
f86e346
Adding Conformer encoder style Transformer encoder
tango4j ba9d40d
Adding final touch up
tango4j 5398604
Fixing Black issue
tango4j c19ca03
Adding relative position encoding and transformer-ctc yaml
tango4j 6725930
Apply black formatting
tango4j File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,216 @@ | ||
| # It contains the default values for training a Transformer-CTC ASR model with CTC loss and sub-word encoding. | ||
| # | ||
| # This config is the Transformer counterpart of ``fast-conformer_ctc_bpe.yaml``: same | ||
| # preprocessor / spec-augment / decoder / optimiser / trainer / exp_manager sections, but the | ||
| # encoder is the FlexAttention-based ``TransformerEncoder`` defined in | ||
| # ``nemo/collections/asr/modules/transformer_encoder.py``. By default it uses | ||
| # ``self_attention_model: rel_pos`` (Transformer-XL relative positional encoding wired into | ||
| # FlexAttention via a ``score_mod`` closure and a ``Q + pos_bias_u`` query rewrite). | ||
| # | ||
| # Use trainer.precision=bf16 on GPUs that support it; the FlexAttention kernel is compiled | ||
| # with ``torch.compile(dynamic=True)`` and works on CUDA out of the box. On CPU it falls back | ||
| # to the un-fused FlexAttention path. | ||
|
|
||
| name: "Transformer-CTC-BPE" | ||
|
|
||
| model: | ||
| sample_rate: 16000 | ||
| log_prediction: true # enables logging sample predictions in the output during training | ||
| ctc_reduction: 'mean_volume' | ||
| skip_nan_grad: false | ||
|
|
||
| train_ds: | ||
| manifest_filepath: ??? | ||
| sample_rate: ${model.sample_rate} | ||
| batch_size: 16 # you may increase batch_size if your memory allows | ||
| shuffle: true | ||
| num_workers: 8 | ||
| pin_memory: true | ||
| max_duration: 16.7 # it is set for LibriSpeech, you may need to update it for your dataset | ||
| min_duration: 0.1 | ||
| # tarred datasets | ||
| is_tarred: false | ||
| tarred_audio_filepaths: null | ||
| shuffle_n: 2048 | ||
| # bucketing params | ||
| bucketing_strategy: "fully_randomized" | ||
| bucketing_batch_size: null | ||
|
|
||
| validation_ds: | ||
| manifest_filepath: ??? | ||
| sample_rate: ${model.sample_rate} | ||
| batch_size: 16 | ||
| shuffle: false | ||
| use_start_end_token: false | ||
| num_workers: 8 | ||
| pin_memory: true | ||
|
|
||
| test_ds: | ||
| manifest_filepath: null | ||
| sample_rate: ${model.sample_rate} | ||
| batch_size: 16 | ||
| shuffle: false | ||
| use_start_end_token: false | ||
| num_workers: 8 | ||
| pin_memory: true | ||
|
|
||
| # recommend vocab size of 128 or 256 when training on ~1k hr datasets and 1k vocab size on 10+k hr datasets | ||
| # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py | ||
| tokenizer: | ||
| dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) | ||
| type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) | ||
|
|
||
| preprocessor: | ||
| _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor | ||
| sample_rate: ${model.sample_rate} | ||
| normalize: "per_feature" | ||
| window_size: 0.025 | ||
| window_stride: 0.01 | ||
| window: "hann" | ||
| features: 80 | ||
| n_fft: 512 | ||
| log: true | ||
| frame_splicing: 1 | ||
| dither: 0.00001 | ||
| pad_to: 0 | ||
| pad_value: 0.0 | ||
|
|
||
| spec_augment: | ||
| _target_: nemo.collections.asr.modules.SpectrogramAugmentation | ||
| freq_masks: 2 # set to zero to disable it | ||
| # you may use lower time_masks for smaller models to have a faster convergence | ||
| time_masks: 10 # set to zero to disable it | ||
| freq_width: 27 | ||
| time_width: 0.05 | ||
|
|
||
| encoder: | ||
| _target_: nemo.collections.asr.modules.TransformerEncoder | ||
| feat_in: ${model.preprocessor.features} | ||
| feat_out: -1 # you may set it if you need different output size other than the default d_model | ||
| # n_layers=31 chosen so the encoder has ~108.1M params, matching the FastConformer baseline | ||
| # (109.5M at d=512, L=17, conv_kernel=9, ff_x4) to within ~1.4%. A Conformer layer carries | ||
| # an extra convolution module and a sandwich-pair of FFNs, so the post-norm Transformer | ||
| # needs more layers (not more heads — heads only partition d_model and add no parameters) | ||
| # to reach the same capacity. | ||
| n_layers: 31 | ||
| d_model: 512 | ||
| n_heads: 8 | ||
|
|
||
| # Sub-sampling params (Conformer-style options are supported; ``feature_stacking`` is the | ||
| # Transformer-native default in the module itself, but for parity with the Canary-flash | ||
| # baseline we keep dw_striding x8 here.) | ||
| subsampling: dw_striding # feature_stacking, stacking, stacking_norm, vggnet, striding, dw_striding, striding_conv1d, dw_striding_conv1d | ||
| subsampling_factor: 8 # must be power of 2 for striding / vggnet variants | ||
| subsampling_conv_channels: 256 # -1 sets it to d_model | ||
| subsampling_conv_chunking_factor: 1 # 1 = auto-chunking, -1 = no chunking, otherwise power-of-2 | ||
| causal_downsampling: false | ||
|
|
||
| # Feed-forward module's params | ||
| ff_expansion: 4.0 # FFN hidden = ff_expansion * d_model | ||
|
|
||
| # Self-attention / positional encoding | ||
| # - ``rel_pos`` (default): Transformer-XL relative positional encoding, wired into | ||
| # FlexAttention via a ``score_mod`` closure plus a ``Q + pos_bias_u`` query rewrite. | ||
| # - ``abs_pos``: sinusoidal absolute positional encoding added before the first block. | ||
| # - ``no_pos`` (or ``null``): no positional encoding; pre-encoder output flows directly | ||
| # into ``embed_norm`` and the Transformer blocks. | ||
| self_attention_model: rel_pos | ||
| pos_emb_max_len: 5000 | ||
| xscaling: false # scale embeddings by sqrt(d_model); mostly a no-op when pre_block_norm=true | ||
|
|
||
| # Attention/FFN block options | ||
| qkv_bias: false # add a learnable bias to the fused Q/K/V projection (Whisper-style: false) | ||
| qk_norm: false # per-head LayerNorm on Q and K before the dot product (OLMo 2 / Gemma 3 style) | ||
| pre_block_norm: true # BERT/ViT-style: LayerNorm on embeddings before the first block | ||
| attn_mode: full # currently only "full" (bidirectional) is supported | ||
|
|
||
| # Regularization | ||
| drop_rate: 0.1 # dropout inside attention/FFN sublayers (corresponds to conformer's ``dropout``) | ||
| dropout_pre_encoder: 0.1 # dropout applied after positional encoding (unused when self_attention_model=no_pos) | ||
| dropout_emb: 0.0 # dropout for the positional embeddings (unused when self_attention_model=no_pos) | ||
|
|
||
| # Set to non-zero to enable stochastic depth | ||
| stochastic_depth_drop_prob: 0.0 | ||
| stochastic_depth_mode: linear # linear or uniform | ||
| stochastic_depth_start_layer: 1 | ||
|
|
||
| # When true, sync max-audio-length across distributed ranks before extending positional buffers | ||
| sync_max_audio_length: true | ||
|
|
||
| decoder: | ||
| _target_: nemo.collections.asr.modules.ConvASRDecoder | ||
| feat_in: null | ||
| num_classes: -1 | ||
| vocabulary: [] | ||
|
|
||
| # config for InterCTC loss: https://arxiv.org/abs/2102.03216 | ||
| # specify loss weights and which layers to use for InterCTC | ||
| # e.g., to reproduce the paper results, set loss_weights: [0.3] | ||
| # and apply_at_layers: [8] (assuming 18 layers). Note that final | ||
| # layer loss coefficient is automatically adjusted (to 0.7 in above example) | ||
| interctc: | ||
| loss_weights: [] | ||
| apply_at_layers: [] | ||
|
|
||
| optim: | ||
| name: adamw | ||
| lr: 1e-3 | ||
| # optimizer arguments | ||
| betas: [0.9, 0.98] | ||
| # less necessity for weight_decay as we already have large augmentations with SpecAug | ||
| # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used | ||
| # weight decay of 0.0 with lr of 2.0 also works fine | ||
| weight_decay: 1e-3 | ||
|
|
||
| # scheduler setup | ||
| sched: | ||
| name: CosineAnnealing | ||
| # scheduler config override | ||
| warmup_steps: 15000 | ||
| warmup_ratio: null | ||
| min_lr: 1e-4 | ||
|
|
||
| trainer: | ||
| devices: -1 # number of GPUs, -1 would use all available GPUs | ||
| num_nodes: 1 | ||
| max_epochs: 1000 | ||
| max_steps: -1 # computed at runtime if not set | ||
| val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
| accelerator: auto | ||
| strategy: | ||
| _target_: lightning.pytorch.strategies.DDPStrategy | ||
| gradient_as_bucket_view: true | ||
| accumulate_grad_batches: 1 | ||
| gradient_clip_val: 0.0 | ||
| precision: 32 # 16, 32, or bf16 | ||
| log_every_n_steps: 10 # Interval of logging. | ||
| enable_progress_bar: True | ||
| num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it | ||
| check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs | ||
| sync_batchnorm: true | ||
| enable_checkpointing: False # Provided by exp_manager | ||
| logger: false # Provided by exp_manager | ||
| benchmark: false # needs to be false for models with variable-length speech input as it slows down training | ||
|
|
||
| exp_manager: | ||
| exp_dir: null | ||
| name: ${name} | ||
| create_tensorboard_logger: true | ||
| create_checkpoint_callback: true | ||
| checkpoint_callback_params: | ||
| # in case of multiple validation sets, first one is used | ||
| monitor: "val_wer" | ||
| mode: "min" | ||
| save_top_k: 5 | ||
| always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints | ||
|
|
||
| resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. | ||
| # you need to set these two to True to continue the training | ||
| resume_if_exists: false | ||
| resume_ignore_no_checkpoint: false | ||
|
|
||
| # You may use this section to create a W&B logger | ||
| create_wandb_logger: false | ||
| wandb_logger_kwargs: | ||
| name: null | ||
| project: null |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you also move
https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yamlto this folder.examples/asr/conf/transformer/