diff --git a/CLAUDE.md b/CLAUDE.md index 7a1cd849f9e7..eed65f2ba641 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -21,9 +21,9 @@ Requires Python 3.10+, PyTorch 2.6+. - **Line length: 119** (not default 88) — consistent across black, isort, flake8 - Black with `skip_string_normalization = true` - isort with `profile = black` -- Check: `python setup.py style --scope ` -- Fix: `python setup.py style --scope --fix` -- **Incremental reformatting**: most collections are excluded from black (see `extend-exclude` in pyproject.toml). The files are reformatted when somebody makes changes to avoid a single big reformatting PR. Do not reformat files outside your changes. +- Check: `isort --check && black --check ` or `isort --check . && black --check .` +- Fix: `isort && black ` or `isort . && black .` +- Jupyter Notebooks are excluded from automatic black reformatting (see `extend-exclude`), but can be still reformatted when passed directly. Do not reformat notebooks outside your changes. ## Testing diff --git a/examples/asr/asr_adapters/scoring_and_analysis.py b/examples/asr/asr_adapters/scoring_and_analysis.py index 2a5602a65658..6bce1bf5023c 100644 --- a/examples/asr/asr_adapters/scoring_and_analysis.py +++ b/examples/asr/asr_adapters/scoring_and_analysis.py @@ -202,7 +202,12 @@ def display_results(df_all: pd.DataFrame, category: str, best_config: pd.Series, def get_best_config( - df_exp: pd.DataFrame, dataset_type_col: str, key_info: dict, topk: int, show_analysis: bool, exp_type: str, + df_exp: pd.DataFrame, + dataset_type_col: str, + key_info: dict, + topk: int, + show_analysis: bool, + exp_type: str, ): """Get the best hyperparameter configuration for a given subset of experiments. diff --git a/examples/asr/export/transducer/infer_transducer_onnx.py b/examples/asr/export/transducer/infer_transducer_onnx.py index 2d39941c6282..e6e7dab3ee88 100644 --- a/examples/asr/export/transducer/infer_transducer_onnx.py +++ b/examples/asr/export/transducer/infer_transducer_onnx.py @@ -60,7 +60,11 @@ def parse_arguments(): parser = ArgumentParser() parser.add_argument( - "--nemo_model", type=str, default=None, required=False, help="Path to .nemo file", + "--nemo_model", + type=str, + default=None, + required=False, + help="Path to .nemo file", ) parser.add_argument( '--pretrained_model', type=str, default=None, required=False, help='Name of a pretrained NeMo file' diff --git a/examples/asr/export/transducer/infer_transducer_ts.py b/examples/asr/export/transducer/infer_transducer_ts.py index 8e7b71a1d7d2..6ee0c06869f7 100644 --- a/examples/asr/export/transducer/infer_transducer_ts.py +++ b/examples/asr/export/transducer/infer_transducer_ts.py @@ -63,7 +63,11 @@ def parse_arguments(): parser = ArgumentParser() parser.add_argument( - "--nemo_model", type=str, default=None, required=False, help="Path to .nemo file", + "--nemo_model", + type=str, + default=None, + required=False, + help="Path to .nemo file", ) parser.add_argument( '--pretrained_model', type=str, default=None, required=False, help='Name of a pretrained NeMo file' diff --git a/examples/asr/speech_classification/vad_infer.py b/examples/asr/speech_classification/vad_infer.py index 8ab040b34c79..08ebb78d05fc 100644 --- a/examples/asr/speech_classification/vad_infer.py +++ b/examples/asr/speech_classification/vad_infer.py @@ -91,7 +91,9 @@ def main(cfg): 'vad_stream': True, 'sample_rate': 16000, 'manifest_filepath': manifest_vad_input, - 'labels': ['infer',], + 'labels': [ + 'infer', + ], 'num_workers': cfg.num_workers, 'shuffle': False, 'window_length_in_sec': cfg.vad.parameters.window_length_in_sec, diff --git a/examples/tts/aligner_heteronym_disambiguation.py b/examples/tts/aligner_heteronym_disambiguation.py index c97d3db5f24f..a839f5f67b99 100644 --- a/examples/tts/aligner_heteronym_disambiguation.py +++ b/examples/tts/aligner_heteronym_disambiguation.py @@ -44,8 +44,7 @@ def get_args(): - """Retrieve arguments for disambiguation. - """ + """Retrieve arguments for disambiguation.""" parser = argparse.ArgumentParser("G2P disambiguation using Aligner input embedding distances.") # TODO(jocelynh): Make this required=False with default download from NGC once ckpt uploaded parser.add_argument('--model', required=True, type=str, help="Path to Aligner model checkpoint (.nemo file).") @@ -85,8 +84,7 @@ def get_args(): def load_and_prepare_audio(aligner, audio_path, target_sr, device): - """Loads and resamples audio to target sample rate (if necessary), and preprocesses for Aligner input. - """ + """Loads and resamples audio to target sample rate (if necessary), and preprocesses for Aligner input.""" # Load audio and get length for preprocessing audio_data, orig_sr = sf.read(audio_path) if orig_sr != target_sr: @@ -238,8 +236,7 @@ def disambiguate_candidates(aligner, text, spec, spec_len, confidence, device, h def disambiguate_dataset( aligner, manifest_path, out_path, sr, heteronyms, confidence, device, verbose, heteronyms_only=True ): - """Disambiguates the phonemes for all words with ambiguous pronunciations in the given manifest. - """ + """Disambiguates the phonemes for all words with ambiguous pronunciations in the given manifest.""" log_file = open('disambiguation_logs.txt', 'w') if verbose else None with open(out_path, 'w') as f_out: diff --git a/external/get_collections.py b/external/get_collections.py index d546ccb05709..6eaaf285ff99 100644 --- a/external/get_collections.py +++ b/external/get_collections.py @@ -25,8 +25,8 @@ def process_collection(id, col): - """ Helper function processing the collection. - + """Helper function processing the collection. + Args: id: (short) name of the collection. col: a collection (python module). @@ -41,7 +41,7 @@ def process_collection(id, col): def main(): - """ Main function generating a JSON file with list of NeMo collections. """ + """Main function generating a JSON file with list of NeMo collections.""" # Parse filename. parser = argparse.ArgumentParser() parser.add_argument('--filename', help='Name of the output JSON file', type=str, default="collections.json") diff --git a/external/get_modules.py b/external/get_modules.py index c080be9e8e5a..704f91d77d9d 100644 --- a/external/get_modules.py +++ b/external/get_modules.py @@ -26,8 +26,8 @@ def process_member(name, obj, module_list): - """ Helper function processing the passed object and, if ok, adding a record to the module list. - + """Helper function processing the passed object and, if ok, adding a record to the module list. + Args: name: name of the member obj: member (class/function etc.) @@ -74,7 +74,7 @@ def process_member(name, obj, module_list): def main(): - """ Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions. """ + """Main function analysing the indicated NeMo collection and generating a JSON file with module descriptions.""" # Parse filename. parser = argparse.ArgumentParser() parser.add_argument('--collection', help='ID of the collection', type=str) diff --git a/nemo/collections/asr/data/audio_to_ctm_dataset.py b/nemo/collections/asr/data/audio_to_ctm_dataset.py index 54503053ae36..c2b8dffdedc6 100644 --- a/nemo/collections/asr/data/audio_to_ctm_dataset.py +++ b/nemo/collections/asr/data/audio_to_ctm_dataset.py @@ -24,8 +24,7 @@ @dataclass class FrameCtmUnit: - """A container class for one CTM unit with start and length countable in frames. - """ + """A container class for one CTM unit with start and length countable in frames.""" label: str start_frame: int diff --git a/nemo/collections/asr/data/audio_to_label_dataset.py b/nemo/collections/asr/data/audio_to_label_dataset.py index dcead6df94b8..c70843461bb1 100644 --- a/nemo/collections/asr/data/audio_to_label_dataset.py +++ b/nemo/collections/asr/data/audio_to_label_dataset.py @@ -131,7 +131,11 @@ def get_tarred_classification_label_dataset( def get_concat_tarred_speech_label_dataset( - featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int, + featurizer, + config: dict, + shuffle_n: int, + global_rank: int, + world_size: int, ): tarred_audio_filepaths = config['tarred_audio_filepaths'] manifest_filepaths = config['manifest_filepath'] @@ -143,7 +147,11 @@ def get_concat_tarred_speech_label_dataset( conf['manifest_filepath'] = manifest_filepath conf['tarred_audio_filepaths'] = tarred_audio_filepath dataset = get_tarred_speech_label_dataset( - config=conf, featurizer=featurizer, shuffle_n=shuffle_n, global_rank=global_rank, world_size=world_size, + config=conf, + featurizer=featurizer, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, ) datasets.append(dataset) @@ -160,7 +168,11 @@ def get_concat_tarred_speech_label_dataset( def get_tarred_speech_label_dataset( - featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int, + featurizer, + config: dict, + shuffle_n: int, + global_rank: int, + world_size: int, ) -> audio_to_label.TarredAudioToSpeechLabelDataset: """ InInstantiates a Speech Label (e.g. VAD, speaker recognition) TarredAudioLabelDataset. diff --git a/nemo/collections/asr/data/feature_to_label.py b/nemo/collections/asr/data/feature_to_label.py index 058d0157fcbd..b635957a7b23 100644 --- a/nemo/collections/asr/data/feature_to_label.py +++ b/nemo/collections/asr/data/feature_to_label.py @@ -26,7 +26,7 @@ def _feature_collate_fn(batch): """collate batch of feat sig, feat len, labels, labels len, assuming all features have the same shape. Args: batch (FloatTensor, LongTensor, LongTensor, LongTensor): A tuple of tuples of feature, feature lengths, - encoded labels, and encoded labels length. + encoded labels, and encoded labels length. """ packed_batch = list(zip(*batch)) if len(packed_batch) == 5: @@ -61,7 +61,7 @@ def _audio_feature_collate_fn(batch, feat_pad_val, label_pad_id): Args: batch (Optional[FloatTensor], Optional[LongTensor], LongTensor, LongTensor): A tuple of tuples of feature, feature lengths, - labels, and label lengths. This collate func assumes the + labels, and label lengths. This collate func assumes the features are torch tensors of Log-Melspectrogram (i.e. [N_MEL, T]). """ packed_batch = list(zip(*batch)) @@ -178,8 +178,7 @@ class _FeatureSeqSpeakerLabelDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" # TODO output type for external features output_types = { 'external_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), @@ -197,16 +196,26 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: ) else: output_types.update( - {'label': NeuralType(('B', 'T'), LabelsType()), 'label_length': NeuralType(tuple('B'), LengthsType()),} + { + 'label': NeuralType(('B', 'T'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } ) return output_types def __init__( - self, *, manifest_filepath: str, labels: List[str], feature_loader, is_speaker_emb: bool = False, + self, + *, + manifest_filepath: str, + labels: List[str], + feature_loader, + is_speaker_emb: bool = False, ): super().__init__() - self.collection = collections.ASRFeatureSequenceLabel(manifests_files=manifest_filepath.split(','),) + self.collection = collections.ASRFeatureSequenceLabel( + manifests_files=manifest_filepath.split(','), + ) self.feature_loader = feature_loader self.labels = labels if labels else self.collection.uniq_labels @@ -259,12 +268,12 @@ def _collate_fn(self, batch): class FeatureToLabelDataset(Dataset): """ - Dataset that loads tensors via a json file containing paths to feature files and their labels. + Dataset that loads tensors via a json file containing paths to feature files and their labels. Each new line is a different sample. Example below: and their target labels. JSON files should be of the following format: {"feature_filepath": "/path/to/audio_feature.pt", "label": "1"} ... - {"feature_filepath": "/path/to/audio_feature.pt", "label": "0"} + {"feature_filepath": "/path/to/audio_feature.pt", "label": "0"} Args: manifest_filepath (str): Path to JSON containing data. labels (Optional[list]): List of unique labels collected from all samples. @@ -283,8 +292,7 @@ class FeatureToLabelDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" output_types = { 'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), 'feat_length': NeuralType(tuple('B'), LengthsType()), @@ -375,12 +383,12 @@ def _vad_segment_collate_fn(self, batch): class FeatureToMultiLabelDataset(Dataset): """ - Dataset that loads tensors via a json file containing paths to feature files and their labels. + Dataset that loads tensors via a json file containing paths to feature files and their labels. Each new line is a different sample. Example below: and their target labels. JSON files should be of the following format: {"feature_filepath": "/path/to/audio_feature.pt", "label": "1 1 0 0 1"} ... - {"feature_filepath": "/path/to/audio_feature.pt", "label": "0 1 0 0"} + {"feature_filepath": "/path/to/audio_feature.pt", "label": "0 1 0 0"} Args: manifest_filepath (str): Path to JSON containing data. labels (Optional[list]): List of unique labels collected from all samples. @@ -397,8 +405,7 @@ class FeatureToMultiLabelDataset(Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" output_types = { 'audio_feat': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), 'feat_length': NeuralType(tuple('B'), LengthsType()), diff --git a/nemo/collections/asr/data/feature_to_label_dataset.py b/nemo/collections/asr/data/feature_to_label_dataset.py index 08803f43ce8d..66ad047aac90 100644 --- a/nemo/collections/asr/data/feature_to_label_dataset.py +++ b/nemo/collections/asr/data/feature_to_label_dataset.py @@ -28,7 +28,9 @@ def get_feature_seq_speakerlabel_dataset( An instance of FeatureToSeqSpeakerLabelDataset. """ dataset = feature_to_label.FeatureToSeqSpeakerLabelDataset( - manifest_filepath=config['manifest_filepath'], labels=config['labels'], feature_loader=feature_loader, + manifest_filepath=config['manifest_filepath'], + labels=config['labels'], + feature_loader=feature_loader, ) return dataset diff --git a/nemo/collections/asr/data/huggingface/hf_audio_to_text_dataset.py b/nemo/collections/asr/data/huggingface/hf_audio_to_text_dataset.py index 0b36d58666f6..8e2e62cf7bf6 100644 --- a/nemo/collections/asr/data/huggingface/hf_audio_to_text_dataset.py +++ b/nemo/collections/asr/data/huggingface/hf_audio_to_text_dataset.py @@ -23,7 +23,11 @@ def get_hf_audio_to_text_bpe_dataset( - config: DictConfig, global_rank: int, world_size: int, tokenizer, augmentor=None, + config: DictConfig, + global_rank: int, + world_size: int, + tokenizer, + augmentor=None, ): if "streaming" in config and config["streaming"]: dataset = HFIterableAudioToBPEDataset( @@ -72,7 +76,10 @@ def get_hf_audio_to_text_bpe_dataset( def get_hf_audio_to_text_char_dataset( - config: DictConfig, global_rank: int, world_size: int, augmentor=None, + config: DictConfig, + global_rank: int, + world_size: int, + augmentor=None, ): if "streaming" in config and config["streaming"]: dataset = HFIterableAudioToCharDataset( diff --git a/nemo/collections/asr/losses/angularloss.py b/nemo/collections/asr/losses/angularloss.py index e2aee9bba6ea..57459b1f478f 100644 --- a/nemo/collections/asr/losses/angularloss.py +++ b/nemo/collections/asr/losses/angularloss.py @@ -27,13 +27,12 @@ class AngularSoftmaxLoss(Loss, Typing): reference: https://openaccess.thecvf.com/content_CVPR_2019/papers/Deng_ArcFace_Additive_Angular_Margin_Loss_for_Deep_Face_Recognition_CVPR_2019_paper.pdf args: scale: scale value for cosine angle - margin: margin value added to cosine angle + margin: margin value added to cosine angle """ @property def input_types(self): - """Input types definitions for AnguarLoss. - """ + """Input types definitions for AnguarLoss.""" return { "logits": NeuralType(('B', 'D'), LogitsType()), "labels": NeuralType(('B',), LabelsType()), diff --git a/nemo/collections/asr/losses/ctc.py b/nemo/collections/asr/losses/ctc.py index 8a1f72448893..d6373941408a 100644 --- a/nemo/collections/asr/losses/ctc.py +++ b/nemo/collections/asr/losses/ctc.py @@ -25,8 +25,7 @@ class CTCLoss(nn.CTCLoss, Serialization, Typing): @property def input_types(self): - """Input types definitions for CTCLoss. - """ + """Input types definitions for CTCLoss.""" return { "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), "targets": NeuralType(('B', 'T'), LabelsType()), diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py index c8eee90a2eb5..9d34148ae393 100644 --- a/nemo/collections/asr/losses/rnnt_pytorch.py +++ b/nemo/collections/asr/losses/rnnt_pytorch.py @@ -24,8 +24,7 @@ class RNNTLossPytorch(Loss): @property def input_types(self): - """Input types definitions for CTCLoss. - """ + """Input types definitions for CTCLoss.""" return { "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "labels": NeuralType(('B', 'T'), LabelsType()), @@ -126,8 +125,7 @@ class TDTLossPytorch(Loss): @property def input_types(self): - """Input types definitions for CTCLoss. - """ + """Input types definitions for CTCLoss.""" return { "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "labels": NeuralType(('B', 'T'), LabelsType()), @@ -256,8 +254,7 @@ class MultiblankRNNTLossPytorch(Loss): @property def input_types(self): - """Input types definitions for CTCLoss. - """ + """Input types definitions for CTCLoss.""" return { "acts": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "labels": NeuralType(('B', 'T'), LabelsType()), diff --git a/nemo/collections/asr/losses/ssl_losses/ctc.py b/nemo/collections/asr/losses/ssl_losses/ctc.py index e71d60ac4956..0f3710482c32 100644 --- a/nemo/collections/asr/losses/ssl_losses/ctc.py +++ b/nemo/collections/asr/losses/ssl_losses/ctc.py @@ -22,8 +22,7 @@ class CTCLossForSSL(Loss): @property def input_types(self): - """Input types definitions for Contrastive. - """ + """Input types definitions for Contrastive.""" return { "spec_masks": NeuralType(("B", "D", "T"), SpectrogramType()), "decoder_outputs": NeuralType(("B", "T", "D"), VoidType()), diff --git a/nemo/collections/asr/losses/ssl_losses/rnnt.py b/nemo/collections/asr/losses/ssl_losses/rnnt.py index 0336063638f7..be68f00003d1 100644 --- a/nemo/collections/asr/losses/ssl_losses/rnnt.py +++ b/nemo/collections/asr/losses/ssl_losses/rnnt.py @@ -22,8 +22,7 @@ class RNNTLossForSSL(Loss): @property def input_types(self): - """Input types definitions for Contrastive. - """ + """Input types definitions for Contrastive.""" return { "spec_masks": NeuralType(("B", "D", "T"), SpectrogramType()), "decoder_outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), diff --git a/nemo/collections/asr/models/online_diarizer.py b/nemo/collections/asr/models/online_diarizer.py index 7074b92f4c04..b2b3dfd7b78f 100644 --- a/nemo/collections/asr/models/online_diarizer.py +++ b/nemo/collections/asr/models/online_diarizer.py @@ -445,7 +445,9 @@ def _extract_online_embeddings( @timeit def _perform_online_clustering( - self, uniq_embs_and_timestamps: Dict[str, torch.Tensor], cuda=False, + self, + uniq_embs_and_timestamps: Dict[str, torch.Tensor], + cuda=False, ) -> torch.Tensor: """ Launch online clustering for `uniq_embs_and_timestamps` input variable. @@ -476,7 +478,10 @@ def _perform_online_clustering( base_segment_indexes = torch.tensor(self.segment_indexes[self.base_scale_index]).to(curr_emb.device) merged_clus_labels = self.online_clus.forward_infer( - curr_emb=curr_emb, base_segment_indexes=base_segment_indexes, frame_index=self.frame_index, cuda=cuda, + curr_emb=curr_emb, + base_segment_indexes=base_segment_indexes, + frame_index=self.frame_index, + cuda=cuda, ) # Update history data for scale_idx, (window, shift) in self.multiscale_args_dict['scale_dict'].items(): @@ -572,7 +577,10 @@ def diarize_step(self, audio_buffer: torch.Tensor, vad_timestamps: torch.Tensor) ) # Step 3 - Clustering: Perform an online version of clustering algorithm - cluster_label_hyp = self._perform_online_clustering(embs_and_timestamps[self.uniq_id], cuda=self.cuda,) + cluster_label_hyp = self._perform_online_clustering( + embs_and_timestamps[self.uniq_id], + cuda=self.cuda, + ) # Step 4: Generate RTTM style diarization labels from segment ranges and cluster labels diar_hyp, _ = generate_cluster_labels(self.memory_segment_ranges[self.base_scale_index], cluster_label_hyp) diff --git a/nemo/collections/asr/modules/beam_search_decoder.py b/nemo/collections/asr/modules/beam_search_decoder.py index b39804ae8e65..ac6c7ac18870 100644 --- a/nemo/collections/asr/modules/beam_search_decoder.py +++ b/nemo/collections/asr/modules/beam_search_decoder.py @@ -43,8 +43,7 @@ class BeamSearchDecoderWithLM(NeuralModule): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), "log_probs_length": NeuralType(tuple('B'), LengthsType()), @@ -52,8 +51,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"predictions": NeuralType(('B', 'T'), PredictionsType())} def __init__( diff --git a/nemo/collections/asr/modules/flashlight_decoder.py b/nemo/collections/asr/modules/flashlight_decoder.py index 05a111ebb7e7..419f7fe803e1 100644 --- a/nemo/collections/asr/modules/flashlight_decoder.py +++ b/nemo/collections/asr/modules/flashlight_decoder.py @@ -196,7 +196,14 @@ def __init__( ) self.decoder = LexiconDecoder( - self.decoder_opts, self.trie, self.lm, self.silence, self.blank, self.unk_word, [], False, + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + False, ) else: from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions diff --git a/nemo/collections/asr/modules/rnn_encoder.py b/nemo/collections/asr/modules/rnn_encoder.py index 0ebb89f545e2..bb87701ae22c 100644 --- a/nemo/collections/asr/modules/rnn_encoder.py +++ b/nemo/collections/asr/modules/rnn_encoder.py @@ -67,8 +67,7 @@ def input_example(self): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return OrderedDict( { "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), @@ -78,8 +77,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return OrderedDict( { "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), diff --git a/nemo/collections/asr/modules/transformer/bridge_encoders.py b/nemo/collections/asr/modules/transformer/bridge_encoders.py index 5c72d27b9ebf..901b6c655bfc 100644 --- a/nemo/collections/asr/modules/transformer/bridge_encoders.py +++ b/nemo/collections/asr/modules/transformer/bridge_encoders.py @@ -50,12 +50,17 @@ def __init__( if self.hidden_init_method not in self.supported_init_methods: raise ValueError( "Unknown hidden_init_method = {hidden_init_method}, supported methods are {supported_init_methods}".format( - hidden_init_method=self.hidden_init_method, supported_init_methods=self.supported_init_methods, + hidden_init_method=self.hidden_init_method, + supported_init_methods=self.supported_init_methods, ) ) # attention bridge - self.att_bridge = AttentionBridge(hidden_size=hidden_size, k=hidden_steps, bridge_size=inner_size,) + self.att_bridge = AttentionBridge( + hidden_size=hidden_size, + k=hidden_steps, + bridge_size=inner_size, + ) if self.hidden_init_method == "enc": self.init_hidden_enc = TransformerEncoder( diff --git a/nemo/collections/asr/modules/transformer/decoder_module.py b/nemo/collections/asr/modules/transformer/decoder_module.py index d1cb8ac9b1f0..c65d2db16af0 100644 --- a/nemo/collections/asr/modules/transformer/decoder_module.py +++ b/nemo/collections/asr/modules/transformer/decoder_module.py @@ -22,7 +22,7 @@ class DecoderModule(NeuralModule, ABC): - """ Base class for decoder neural module to be used in NLP models. """ + """Base class for decoder neural module to be used in NLP models.""" @property def input_types(self) -> Optional[Dict[str, NeuralType]]: diff --git a/nemo/collections/asr/modules/transformer/encoder_module.py b/nemo/collections/asr/modules/transformer/encoder_module.py index bd3912e0e693..991034be91bd 100644 --- a/nemo/collections/asr/modules/transformer/encoder_module.py +++ b/nemo/collections/asr/modules/transformer/encoder_module.py @@ -22,7 +22,7 @@ class EncoderModule(NeuralModule, ABC): - """ Base class for encoder neural module to be used in NLP models. """ + """Base class for encoder neural module to be used in NLP models.""" @property def input_types(self) -> Optional[Dict[str, NeuralType]]: diff --git a/nemo/collections/asr/modules/transformer/perceiver_encoders.py b/nemo/collections/asr/modules/transformer/perceiver_encoders.py index e836e20be7bc..04e01cff25e8 100644 --- a/nemo/collections/asr/modules/transformer/perceiver_encoders.py +++ b/nemo/collections/asr/modules/transformer/perceiver_encoders.py @@ -53,7 +53,8 @@ def __init__( if self.hidden_init_method not in self.supported_init_methods: raise ValueError( "Unknown hidden_init_method = {hidden_init_method}, supported methods are {supported_init_methods}".format( - hidden_init_method=self.hidden_init_method, supported_init_methods=self.supported_init_methods, + hidden_init_method=self.hidden_init_method, + supported_init_methods=self.supported_init_methods, ) ) @@ -77,7 +78,11 @@ def __init__( self.init_cross_att.diagonal = diagonal elif self.hidden_init_method == "bridge": # initialize latent with attention bridge - self.att_bridge = AttentionBridge(hidden_size=hidden_size, k=hidden_steps, bridge_size=inner_size,) + self.att_bridge = AttentionBridge( + hidden_size=hidden_size, + k=hidden_steps, + bridge_size=inner_size, + ) # cross-attention encoder layer = TransformerDecoder( @@ -150,7 +155,10 @@ def forward(self, encoder_states, encoder_mask): ) elif self._hidden_init_method == "bridge": # initialize latent with attention bridge - hidden_states = self.att_bridge(hidden=encoder_states, hidden_mask=encoder_mask,) + hidden_states = self.att_bridge( + hidden=encoder_states, + hidden_mask=encoder_mask, + ) # apply block (cross-attention, self-attention) multiple times # for block in range(self._hidden_blocks): @@ -166,7 +174,10 @@ def forward(self, encoder_states, encoder_mask): ) # self-attention over hidden - hidden_states = self_att(encoder_states=hidden_states, encoder_mask=hidden_mask,) + hidden_states = self_att( + encoder_states=hidden_states, + encoder_mask=hidden_mask, + ) # residual connection hidden_states += residual diff --git a/nemo/collections/asr/modules/transformer/reduction_encoders.py b/nemo/collections/asr/modules/transformer/reduction_encoders.py index 0c3355b0949f..90aec5e3e32d 100644 --- a/nemo/collections/asr/modules/transformer/reduction_encoders.py +++ b/nemo/collections/asr/modules/transformer/reduction_encoders.py @@ -57,7 +57,8 @@ def __init__( if self.hidden_init_method not in self.supported_init_methods: raise ValueError( "Unknown hidden_init_method = {hidden_init_method}, supported methods are {supported_init_methods}".format( - hidden_init_method=self.hidden_init_method, supported_init_methods=self.supported_init_methods, + hidden_init_method=self.hidden_init_method, + supported_init_methods=self.supported_init_methods, ) ) diff --git a/nemo/collections/asr/modules/transformer/text_generation.py b/nemo/collections/asr/modules/transformer/text_generation.py index a261e925691f..b0980f96a652 100644 --- a/nemo/collections/asr/modules/transformer/text_generation.py +++ b/nemo/collections/asr/modules/transformer/text_generation.py @@ -64,10 +64,10 @@ def generate( Args: inputs (Union[List[str], Tensor, List[dict]]): - Can be one of the 3 types: + Can be one of the 3 types: 1. List of strings. Each element of the list provides input prompt. The model will apply tokenizer on it. E.g [‘sentence’, ‘sentence2’ … ] - 2. Tuple of Pytorch Tensors (context_tokens, context_lengths). The `context_tokens` has shape (batch_size, seq_length), it's the batched sequences of tokens used as a prompst for the generation or as model inputs to the encoder. + 2. Tuple of Pytorch Tensors (context_tokens, context_lengths). The `context_tokens` has shape (batch_size, seq_length), it's the batched sequences of tokens used as a prompst for the generation or as model inputs to the encoder. The generative model will skip the tokenization and padding step. The `context_lengths` has shape (batch_size,), it indicates the length of the context tokens for each of the input sequences. E.g. ( torch.tensor([[23,5234,23,35,…], [223,323,23,23232,232,...] …]), torch.tensor([20, 30, …])) 3. List of python dict objects. Used for prompt/p-tuning inputs where a set of key-value pairs are converted into input token embeddings for the model. @@ -84,7 +84,7 @@ def generate( use_greedy: bool, Whether or not to use sampling ; use greedy decoding otherwise top_k: int, The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p: float, If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. - repetition_penalty: float, The parameter for repetition penalty. 1.0 means no penalty. + repetition_penalty: float, The parameter for repetition penalty. 1.0 means no penalty. add_BOS: bool, Whether add the bos token at the begining of the prompt all_probs: bool # whether return the log prob for all the tokens in vocab compute_logprob: bool # a flag used to compute logprob of all the input text, a very special case of running inference, default False diff --git a/nemo/collections/asr/modules/transformer/transformer_bottleneck.py b/nemo/collections/asr/modules/transformer/transformer_bottleneck.py index c463b4de1c70..758287107dd3 100644 --- a/nemo/collections/asr/modules/transformer/transformer_bottleneck.py +++ b/nemo/collections/asr/modules/transformer/transformer_bottleneck.py @@ -212,7 +212,9 @@ def _build_encoder(self, arch, **kwargs): def input_types(self) -> Optional[Dict[str, NeuralType]]: input_types = super().input_types input_types.update( - {"return_mask": NeuralType((), BoolType(), True),} + { + "return_mask": NeuralType((), BoolType(), True), + } ) return input_types @@ -221,7 +223,9 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: def output_types(self) -> Optional[Dict[str, NeuralType]]: output_types = super().output_types output_types.update( - {"hidden_mask": NeuralType(('B', 'T'), MaskType(), True),} + { + "hidden_mask": NeuralType(('B', 'T'), MaskType(), True), + } ) return output_types @@ -245,7 +249,8 @@ def forward(self, input_ids, encoder_mask, return_mask=None): encoder_hidden_mask = encoder_mask else: encoder_hidden_states, encoder_hidden_mask = self._encoder( - encoder_states=embeddings, encoder_mask=encoder_mask, + encoder_states=embeddings, + encoder_mask=encoder_mask, ) if return_mask: diff --git a/nemo/collections/asr/parts/k2/w_transducer.py b/nemo/collections/asr/parts/k2/w_transducer.py index b38a6c560fcd..c532aafffefb 100644 --- a/nemo/collections/asr/parts/k2/w_transducer.py +++ b/nemo/collections/asr/parts/k2/w_transducer.py @@ -289,9 +289,9 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) arcs[:-3, 1] = self.relabel_states(arcs[:-3, 1], text_length + 1, num_frames) if self.last_blank_mode == self.LastBlankMode.ALLOW_IGNORE: - arcs[ - num_forward_arcs_base + (num_frames - 1) : num_forward_arcs_base + (num_frames - 1) * 2, 1 - ] = num_grid_states + arcs[num_forward_arcs_base + (num_frames - 1) : num_forward_arcs_base + (num_frames - 1) * 2, 1] = ( + num_grid_states + ) # sort by start state - required in k2 # TODO: maybe it is more optimal to avoid sort, construct arcs in ascending order @@ -305,7 +305,11 @@ def get_grid(self, units_tensor: torch.Tensor, num_frames: int, vocab_size: int) return rnnt_graph def forward( - self, acts: torch.Tensor, labels: torch.Tensor, act_lens: torch.Tensor, label_lens: torch.Tensor, + self, + acts: torch.Tensor, + labels: torch.Tensor, + act_lens: torch.Tensor, + label_lens: torch.Tensor, ): """ Forward method is similar to RNN-T Graph-Transducer forward method, diff --git a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py index fcf5d5cebfeb..2ed90dbc9ec6 100644 --- a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py +++ b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py @@ -203,8 +203,7 @@ class SpecAugmentNumba(nn.Module, Typing): @property def input_types(self): - """Returns definitions of module input types - """ + """Returns definitions of module input types""" return { "input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()), "length": NeuralType(tuple('B'), LengthsType()), @@ -212,12 +211,17 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output types - """ + """Returns definitions of module output types""" return {"augmented_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())} def __init__( - self, freq_masks=0, time_masks=0, freq_width=10, time_width=0.1, rng=None, mask_value=0.0, + self, + freq_masks=0, + time_masks=0, + freq_width=10, + time_width=0.1, + rng=None, + mask_value=0.0, ): super().__init__() # Message to mention that numba specaugment kernel will be available diff --git a/nemo/collections/asr/parts/submodules/causal_convs.py b/nemo/collections/asr/parts/submodules/causal_convs.py index 32f08a8d2feb..5946a839b6d8 100644 --- a/nemo/collections/asr/parts/submodules/causal_convs.py +++ b/nemo/collections/asr/parts/submodules/causal_convs.py @@ -62,7 +62,8 @@ def __init__( ) def forward( - self, x, + self, + x, ): x = F.pad(x, pad=(self._left_padding, self._right_padding, self._left_padding, self._right_padding)) x = super().forward(x) diff --git a/nemo/collections/asr/parts/submodules/classifier.py b/nemo/collections/asr/parts/submodules/classifier.py index 7d9e42593c1c..1b7250904824 100644 --- a/nemo/collections/asr/parts/submodules/classifier.py +++ b/nemo/collections/asr/parts/submodules/classifier.py @@ -37,7 +37,11 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]: """ return {"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())} - def __init__(self, hidden_size: int, dropout: float = 0.0,) -> None: + def __init__( + self, + hidden_size: int, + dropout: float = 0.0, + ) -> None: """ Initializes the Classifier base module. Args: diff --git a/nemo/collections/asr/parts/submodules/stateless_net.py b/nemo/collections/asr/parts/submodules/stateless_net.py index 7581fdc2834d..8d5345f5dfc4 100644 --- a/nemo/collections/asr/parts/submodules/stateless_net.py +++ b/nemo/collections/asr/parts/submodules/stateless_net.py @@ -65,7 +65,9 @@ def __init__(self, context_size, vocab_size, emb_dim, blank_idx, normalization_m self.blank_idx = blank_idx def forward( - self, y: Optional[torch.Tensor] = None, state: Optional[List[torch.Tensor]] = None, + self, + y: Optional[torch.Tensor] = None, + state: Optional[List[torch.Tensor]] = None, ): """ Although this is a *stateless* net, we use the "state" parameter to diff --git a/nemo/collections/asr/parts/utils/longform_clustering.py b/nemo/collections/asr/parts/utils/longform_clustering.py index 171c074d9e10..e000e414c0d1 100644 --- a/nemo/collections/asr/parts/utils/longform_clustering.py +++ b/nemo/collections/asr/parts/utils/longform_clustering.py @@ -28,11 +28,11 @@ class LongFormSpeakerClustering(torch.nn.Module): def __init__(self, cuda: bool = False): """ Initializes a speaker clustering class tailored for long-form audio, leveraging methods from the `SpeakerClustering` class. - The clustering algorithm for long-form content is executed via the `forward_infer` function (not shown here). Input embedding - vectors are divided into chunks, each of size `embeddings_per_chunk`. Within every chunk, the clustering algorithm aims - to identify `chunk_cluster_count` distinct clusters. The resulting clustering labels are then expanded to match the original + The clustering algorithm for long-form content is executed via the `forward_infer` function (not shown here). Input embedding + vectors are divided into chunks, each of size `embeddings_per_chunk`. Within every chunk, the clustering algorithm aims + to identify `chunk_cluster_count` distinct clusters. The resulting clustering labels are then expanded to match the original length of the input embeddings. - + NOTE: torch.jit.script currently does not support inherited methods with a `super()` call. Args: @@ -49,7 +49,7 @@ def __init__(self, cuda: bool = False): def check_input(self, embeddings_per_chunk: int, chunk_cluster_count: int, max_num_speakers: int) -> None: """ Checks the validity of the input parameters. - + Args: embeddings_per_chunk (int): The size of the windows in which the algorithm aims to identify `chunk_cluster_count` clusters. @@ -86,23 +86,23 @@ def unpack_labels( Unpack the labels from the aggregated labels to the original labels. Args: - Y_aggr (Tensor): + Y_aggr (Tensor): Aggregated label vector from the merged segments. - window_range_list (List[List[int]]): + window_range_list (List[List[int]]): List of window ranges for each of the merged segments. - absolute_merge_mapping (List[List[torch.Tensor]]): + absolute_merge_mapping (List[List[torch.Tensor]]): List of absolute mappings for each of the merged segments. Each list element contains two tensors: - The first tensor represents the absolute index of the bypassed segment (segments that remain unchanged). - The second tensor represents the absolute index of the merged segment (segments that have had their indexes changed). - org_len (int): + org_len (int): Original length of the labels. In most cases, this is a fairly large number (on the order of 10^5). Returns: - Y_unpack (Tensor): + Y_unpack (Tensor): Unpacked labels derived from the aggregated labels. """ Y_unpack = torch.zeros((org_len,)).long().to(Y_aggr.device) - for (win_rng, abs_mapping) in zip(window_range_list, absolute_merge_mapping): + for win_rng, abs_mapping in zip(window_range_list, absolute_merge_mapping): inferred_merged_embs = Y_aggr[win_rng[0] : win_rng[1]] if len(abs_mapping[1]) > 0: Y_unpack[abs_mapping[1]] = inferred_merged_embs[-1].clone() # Merged @@ -114,13 +114,16 @@ def unpack_labels( return Y_unpack def split_embs_to_windows( - self, index: int, emb: torch.Tensor, embeddings_per_chunk: int, + self, + index: int, + emb: torch.Tensor, + embeddings_per_chunk: int, ) -> Tuple[torch.Tensor, int]: """ Splits the embedding tensor into smaller window-sized tensors based on a given index. - + Args: - index (int): The index of the desired window. This determines the starting point + index (int): The index of the desired window. This determines the starting point of the window using the formula: start = embeddings_per_chunk * index emb (Tensor): The embedding tensor which needs to be split. @@ -128,9 +131,9 @@ def split_embs_to_windows( The size of the windows in which the algorithm aims to identify `chunk_cluster_count` clusters. Returns: - emb_part (Tensor): + emb_part (Tensor): The window-sized tensor, which is a portion of the `emb`. - offset_index (int): + offset_index (int): The starting position of the window in the `emb` tensor. """ if embeddings_per_chunk * (index + 1) > emb.shape[0]: @@ -146,7 +149,7 @@ def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor: A function wrapper designed for performing inference using an exported script format. Note: - A dictionary is used to facilitate inference with the exported jit model in the Triton server. + A dictionary is used to facilitate inference with the exported jit model in the Triton server. This is done using an easy-to-understand naming convention. See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#special-conventions-for-pytorch-backend @@ -184,7 +187,7 @@ def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor: def get_div_ceil_count(self, numer: int, denomin: int) -> int: """ Calculates the ceiling of the division of two integers. - + Args: numer (int): Numerator, the number of segments or clusters, for example. denomin (int): Denominator, the number of speakers or clusters, for example. @@ -328,7 +331,10 @@ def long_forward_infer( # `class_target_vol` is a list of cluster-indices from overclustering for spk_idx, merge_quantity in enumerate(list(class_target_vol)): merged_embs, merged_clus_labels, index_mapping = run_reducer( - pre_embs=emb_part, target_spk_idx=spk_idx, merge_quantity=merge_quantity, pre_clus_labels=Y_part, + pre_embs=emb_part, + target_spk_idx=spk_idx, + merge_quantity=merge_quantity, + pre_clus_labels=Y_part, ) total_emb.append(merged_embs) absolute_index_mapping = [x + offset_index for x in index_mapping] diff --git a/nemo/collections/asr/parts/utils/numba_utils.py b/nemo/collections/asr/parts/utils/numba_utils.py index 867ecf59521b..ec5c18ddc405 100644 --- a/nemo/collections/asr/parts/utils/numba_utils.py +++ b/nemo/collections/asr/parts/utils/numba_utils.py @@ -61,7 +61,7 @@ def _phase_vocoder_kernel(D, time_steps, phi_advance, d_stretch, phase_acc, scal """ two_pi = 2.0 * np.pi - for (t, step) in enumerate(time_steps): + for t, step in enumerate(time_steps): columns = D[:, int(step) : int(step + 2)] columns_0 = columns[:, 0] columns_1 = columns[:, 1] diff --git a/nemo/collections/asr/parts/utils/offline_clustering.py b/nemo/collections/asr/parts/utils/offline_clustering.py index 3f6c90d945ef..71291a665bcf 100644 --- a/nemo/collections/asr/parts/utils/offline_clustering.py +++ b/nemo/collections/asr/parts/utils/offline_clustering.py @@ -150,7 +150,13 @@ def kmeans_plusplus_torch( centers = torch.zeros(n_clusters, n_features, dtype=X.dtype) center_id = torch.randint(0, n_samples, (1,)).long() - indices = torch.full([n_clusters,], -1, dtype=torch.int) + indices = torch.full( + [ + n_clusters, + ], + -1, + dtype=torch.int, + ) centers[0] = X[center_id].squeeze(0) indices[0] = center_id.squeeze(0) @@ -511,7 +517,7 @@ def getMultiScaleCosAffinityMatrix( Returns: fused_sim_d (Tensor): - An affinity matrix that is obtained by calculating the weighted sum of + An affinity matrix that is obtained by calculating the weighted sum of the multiple affinity matrices from the different scales. """ multiscale_weights = torch.squeeze(multiscale_weights, dim=0).to(device) @@ -986,8 +992,12 @@ def forward(self) -> Tuple[torch.Tensor, torch.Tensor]: est_spk_n_dict: Dict[int, torch.Tensor] = {} self.p_value_list = self.getPvalueList() p_volume = self.p_value_list.shape[0] - eig_ratio_list = torch.zeros(p_volume,) - est_num_of_spk_list = torch.zeros(p_volume,) + eig_ratio_list = torch.zeros( + p_volume, + ) + est_num_of_spk_list = torch.zeros( + p_volume, + ) if self.parallelism: futures: List[torch.jit.Future[torch.Tensor]] = [] @@ -1176,10 +1186,10 @@ def forward_unit_infer( kmeans_random_trials: int = 1, ) -> torch.LongTensor: """ - This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments - in the given input embeddings. - - Args: + This function takes a cosine similarity matrix `mat` and returns the speaker labels for the segments + in the given input embeddings. + + Args: mat (Tensor): Cosine similarity matrix (affinity matrix) calculated from the provided speaker embeddings. oracle_num_speakers (int): @@ -1202,8 +1212,8 @@ def forward_unit_infer( This value should be optimized on a development set for best results. By default, it is set to -1.0, and the function performs NME-analysis to estimate the threshold. kmeans_random_trials (int): - The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1. - + The number of random trials for initializing k-means clustering. More trials can result in more stable clustering. The default is 1. + Returns: Y (LongTensor): Speaker labels (clustering output) in integer format for the segments in the given input embeddings. diff --git a/nemo/collections/asr/parts/utils/online_clustering.py b/nemo/collections/asr/parts/utils/online_clustering.py index 23ebe6c6dbbf..d61e63c6e9a6 100644 --- a/nemo/collections/asr/parts/utils/online_clustering.py +++ b/nemo/collections/asr/parts/utils/online_clustering.py @@ -180,9 +180,9 @@ def calculate_removable_counts(removable_counts_mat: torch.Tensor, remain_count: >>> removable_counts_mat = [5, 3, 1] >>> remain_count = 6 >>> num_clus = 3 - + Interim results: - >>> diff_counts + >>> diff_counts [1, 2, 2] >>> gradual_counts [3, 4, 2] @@ -190,7 +190,7 @@ def calculate_removable_counts(removable_counts_mat: torch.Tensor, remain_count: [3, 7, 9] Return: - >>> removable_counts_mat + >>> removable_counts_mat [2, 1, 0] Args: @@ -239,7 +239,9 @@ def calculate_removable_counts(removable_counts_mat: torch.Tensor, remain_count: def get_merge_quantity( - num_to_be_removed: int, pre_clus_labels: torch.Tensor, min_count_per_cluster: int, + num_to_be_removed: int, + pre_clus_labels: torch.Tensor, + min_count_per_cluster: int, ) -> torch.Tensor: """ Determine which embeddings we need to reduce or merge in history buffer. @@ -257,8 +259,8 @@ def get_merge_quantity( >>> pre_clus_labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2] >>> min_count_per_cluster = 2 >>> get_merge_quantity(num_to_be_removed, pre_clus_labels, min_count_per_cluster) - Return: - torch.tensor([2, 1, 0]) + Return: + torch.tensor([2, 1, 0]) >>> # Sum should be equal to `num_to_be_removed` which is 3 Args: @@ -352,7 +354,7 @@ def get_closest_embeddings(affinity_mat: torch.Tensor, n_closest: int) -> Tuple[ >>> affinity_mat = [[1.0, 0.2, 0.8], [0.2, 1.0, 0.4], [0.8, 0.4, 1.0]] - >>> affinity_mat.sum(0) + >>> affinity_mat.sum(0) [2.0, 1.6, 2.2] # The closest two embedding vectors are at index 0 and 2. @@ -388,15 +390,18 @@ def get_closest_embeddings(affinity_mat: torch.Tensor, n_closest: int) -> Tuple[ def run_reducer( - pre_embs: torch.Tensor, target_spk_idx: int, merge_quantity: int, pre_clus_labels: torch.Tensor, + pre_embs: torch.Tensor, + target_spk_idx: int, + merge_quantity: int, + pre_clus_labels: torch.Tensor, ): """ Reduce the number of embedding vectors by merging the closest embedding vectors. - - This merging algorithm is based on the assumption that the closest embeddings + - This merging algorithm is based on the assumption that the closest embeddings are the most redundant embedding vectors. - - The closest embedding vectors are chosen by selecting the highest top-N sum of + - The closest embedding vectors are chosen by selecting the highest top-N sum of each column in a given affinity matrix. - - If merge_quantity is N, we choose (N+1) vectors into 1 embedding vector. + - If merge_quantity is N, we choose (N+1) vectors into 1 embedding vector. Thus, we reduce N embeddings in the original embedding vector set. Example: @@ -404,12 +409,12 @@ def run_reducer( >>> affinity_mat = [[1.0, 0.2, 0.8], [0.2, 1.0, 0.4], [0.8, 0.4, 1.0]] - >>> affinity_mat.sum(0) + >>> affinity_mat.sum(0) [2.0, 1.6, 2.2] The first and the third embedding vectors are merged into one embedding vector. >>> index_mapping # (bypassed indices, merged indices) - ([1], [0, 2]) + ([1], [0, 2]) Args: pre_embs (Tensor): @@ -424,7 +429,7 @@ def run_reducer( The original cluster (speaker) index Returns: - merged_embs (torch.Tensor): + merged_embs (torch.Tensor): The merged embedding vectors. merged_clus_labels (torch.Tensor): The cluster (speaker) indices for the merged embedding vectors. @@ -551,7 +556,7 @@ class OnlineSpeakerClustering(torch.nn.Module): temporal_label_major_vote_buffer_size (int): Buffer size for major-voting the num_spk_stat (list): - List of number of speakers for major voting. Number of speakers are estimated through + List of number of speakers for major voting. Number of speakers are estimated through majority voting of `self.num_spk_stat` list. p_value_hist (list): List of p_values for major voting. @@ -559,7 +564,7 @@ class OnlineSpeakerClustering(torch.nn.Module): saved to `self.p_value_hist`. Attributes for counters and buffers in streaming system: - + is_online (bool): - If self.is_online is False: FIFO queue does not push out any speaker embedding vector @@ -717,7 +722,7 @@ def limit_frames_per_speaker(self, frame_index: int, est_num_of_spk: int) -> int Unique index for each segment and embedding vector est_num_of_spk (int): Estimated number of speakers - + Returns: (int) Estimated number of speakers capped by `self.min_frame_per_spk` """ @@ -777,7 +782,7 @@ def prepare_embedding_update( hist_curr_boundary (int): The current boundary of between history buffer and current buffer. This is the new history-current buffer boundary while self.history_buffer_seg_end is the old one. - Thus, the new set of embedding vectors are collected from + Thus, the new set of embedding vectors are collected from `label_stt=self.hist_buffer_seg_end` to `label_end=hist_curr_boundary`. total_segments_processed_count (int): The number of segments that are processed so far in integer format. @@ -930,7 +935,7 @@ def update_speaker_history_buffer( Step (2) |-----------------------|ABCDEF--------------XY| - |---------emb_in-------| + |---------emb_in-------| The newly accepted embeddings go through a FIFO queue (first come, first merge) history buffer = 22 @@ -947,7 +952,7 @@ def update_speaker_history_buffer( Step (4) |======================|CDEF--------------XY| |-----hist_emb_buff----| - + After clustering, `self.Y_fullhist` is updated as: |0000000000011111111111|11110000110010010011| @@ -1181,7 +1186,10 @@ def forward_infer( if cuda and (curr_emb.device == torch.device("cpu") or base_segment_indexes.device == torch.device("cpu")): raise ValueError(f"CUDA is enabled but the input {curr_emb} or {base_segment_indexes} is not on the GPU.") - merged_embs, add_new = self.get_reduced_mat(emb_in=curr_emb, base_segment_indexes=base_segment_indexes,) + merged_embs, add_new = self.get_reduced_mat( + emb_in=curr_emb, + base_segment_indexes=base_segment_indexes, + ) # Perform clustering on the embedding matrix containing history and current FIFO buffer merged_embeddings if merged_embs.shape[0] == 1: Y = torch.zeros((1,), dtype=torch.int32) diff --git a/nemo/collections/asr/parts/utils/optimization_utils.py b/nemo/collections/asr/parts/utils/optimization_utils.py index f947007e59b4..56f9d667e169 100644 --- a/nemo/collections/asr/parts/utils/optimization_utils.py +++ b/nemo/collections/asr/parts/utils/optimization_utils.py @@ -52,26 +52,26 @@ def unravel_index(index: int, shape: torch.Tensor): @torch.jit.script class LinearSumAssignmentSolver(object): """ - A Solver class for the linear sum assignment (LSA) problem. - Designed for torch.jit.script compatibility in NeMo. - - The LSA problem is also referred to as bipartite matching problem. An LSA problem is described - by a matrix `cost_mat`, where each cost_mat[i,j] is the cost of matching vertex i of the first partite - set (e.g. a "worker") and vertex j of the second set (e.g. a "job"). - - Thus, the goal of LSA-solver is to find a complete assignment of column element to row element with - the minimal cost. Note that the solution may not be unique and there could be multiple solutions that + A Solver class for the linear sum assignment (LSA) problem. + Designed for torch.jit.script compatibility in NeMo. + + The LSA problem is also referred to as bipartite matching problem. An LSA problem is described + by a matrix `cost_mat`, where each cost_mat[i,j] is the cost of matching vertex i of the first partite + set (e.g. a "worker") and vertex j of the second set (e.g. a "job"). + + Thus, the goal of LSA-solver is to find a complete assignment of column element to row element with + the minimal cost. Note that the solution may not be unique and there could be multiple solutions that yield the same minimal cost. - LSA problem solver is needed for the following tasks in NeMo: + LSA problem solver is needed for the following tasks in NeMo: - Permutation Invariant Loss (PIL) for diarization model training - - Label permutation matching for online speaker diarzation - - Concatenated minimum-permutation Word Error Rate (cp-WER) calculation + - Label permutation matching for online speaker diarzation + - Concatenated minimum-permutation Word Error Rate (cp-WER) calculation - This implementation is based on the LAP solver from scipy: + This implementation is based on the LAP solver from scipy: https://github.com/scipy/scipy/blob/v0.18.1/scipy/optimize/_hungarian.py The scipy implementation comes with the following license: - + Copyright (c) 2008 Brian M. Clapper , Gael Varoquaux Author: Brian M. Clapper, Gael Varoquaux License: 3-clause BSD @@ -139,8 +139,8 @@ def _step2(self): Goal: Make sure assignment with cost sum 0 is feasible. Procedure: - - Find a zero in the resulting cost matrix. - - If there are no marked zeros in its row or column, mark the zero. + - Find a zero in the resulting cost matrix. + - If there are no marked zeros in its row or column, mark the zero. - Repeat for each element in the matrix. - Go to step 3. """ @@ -158,11 +158,11 @@ def _step2(self): def _step3(self) -> int: """ Step 3 - + Goal: All zeros in the matrix must be covered by marking with the least numbers of rows and columns. Procedure: - - Cover each column containing a marked zero. + - Cover each column containing a marked zero. - If n columns are covered, the marked zeros describe a complete set of unique assignments. In this case, Go to Step 0 (Done state) - Otherwise, Go to Step 4. @@ -181,10 +181,10 @@ def _step4(self, bypass: bool = False) -> int: Goal: Cover all columns containing a marked zero. Procedure: - - Find a non-covered zero and put a prime mark on it. + - Find a non-covered zero and put a prime mark on it. - If there is no marked zero in the row containing this primed zero, Go to Step 5. - - Otherwise, cover this row and uncover the column containing the marked zero. - - Continue in this manner until there are no uncovered zeros left. + - Otherwise, cover this row and uncover the column containing the marked zero. + - Continue in this manner until there are no uncovered zeros left. - Save the smallest uncovered value. - Go to Step 6. """ @@ -219,15 +219,15 @@ def _step5(self) -> int: Step 5 Goal: Construct a series of alternating primed and marked zeros as follows. - + Procedure: - Let Z0 represent the uncovered primed zero found in Step 4. - Let Z1 denote the marked zero in the column of Z0 (if any). - Let Z2 denote the primed zero in the row of Z1 (there will always be one). - - Continue until the series terminates at a primed zero that has no marked zero in its column. + - Continue until the series terminates at a primed zero that has no marked zero in its column. - Unmark each marked zero of the series. - Mark each primed zero of the series. - - Erase all primes and uncover every line in the matrix. + - Erase all primes and uncover every line in the matrix. - Return to Step 3 """ count = torch.tensor(0) diff --git a/nemo/collections/common/losses/bce_logits_loss.py b/nemo/collections/common/losses/bce_logits_loss.py index b65c41985afc..dff09e3e09f5 100644 --- a/nemo/collections/common/losses/bce_logits_loss.py +++ b/nemo/collections/common/losses/bce_logits_loss.py @@ -32,8 +32,7 @@ class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, Serialization, Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "logits": NeuralType(["B"] + ["ANY"] * (self._logits_dim - 1), LogitsType()), "labels": [NeuralType(["B"] + ["ANY"] * (self._logits_dim - 2), LabelsType())], @@ -42,8 +41,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__( diff --git a/nemo/collections/common/losses/cross_entropy.py b/nemo/collections/common/losses/cross_entropy.py index 753cc089a981..820b2cc7450e 100644 --- a/nemo/collections/common/losses/cross_entropy.py +++ b/nemo/collections/common/losses/cross_entropy.py @@ -29,8 +29,7 @@ class CrossEntropyLoss(nn.CrossEntropyLoss, Serialization, Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "logits": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 1), LogitsType()), "labels": NeuralType(['B'] + ['ANY'] * (self._logits_dim - 2), LabelsType()), @@ -39,8 +38,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__(self, logits_ndim=2, weight=None, reduction='mean', ignore_index=-100): @@ -88,8 +86,7 @@ class NLLLoss(nn.NLLLoss, Serialization, Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), "labels": NeuralType(("B", "T"), LabelsType()), @@ -98,8 +95,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__(self, log_probs_ndim=2, weight=None, reduction='mean', ignore_index=-100): diff --git a/nemo/collections/common/losses/mse_loss.py b/nemo/collections/common/losses/mse_loss.py index 802e8ca49204..da1474216195 100644 --- a/nemo/collections/common/losses/mse_loss.py +++ b/nemo/collections/common/losses/mse_loss.py @@ -27,8 +27,7 @@ class MSELoss(nn.MSELoss, Serialization, Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "preds": NeuralType(tuple('B'), RegressionValuesType()), "labels": NeuralType(tuple('B'), LabelsType()), @@ -36,8 +35,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__(self, reduction: str = 'mean'): diff --git a/nemo/collections/common/losses/multi_similarity_loss.py b/nemo/collections/common/losses/multi_similarity_loss.py index 022f6d6de691..8bea184fdc64 100644 --- a/nemo/collections/common/losses/multi_similarity_loss.py +++ b/nemo/collections/common/losses/multi_similarity_loss.py @@ -27,14 +27,12 @@ class MultiSimilarityLoss(Loss): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return {"logits": NeuralType(('B', 'D'), LogitsType()), "labels": NeuralType(('B'), LabelsType())} @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__( diff --git a/nemo/collections/common/losses/smoothed_cross_entropy.py b/nemo/collections/common/losses/smoothed_cross_entropy.py index 265251acc390..7de0220c8be5 100644 --- a/nemo/collections/common/losses/smoothed_cross_entropy.py +++ b/nemo/collections/common/losses/smoothed_cross_entropy.py @@ -46,8 +46,7 @@ class SmoothedCrossEntropyLoss(Loss): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), "labels": NeuralType(("B", "T"), LabelsType()), @@ -56,8 +55,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__( @@ -118,8 +116,7 @@ class SmoothedNLLLoss(NeuralModule, Exportable): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "log_probs": NeuralType(("B", "T", "D"), LogprobsType()), "labels": NeuralType(("B", "T"), LabelsType()), @@ -129,8 +126,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return {"loss": NeuralType(elements_type=LossType())} def __init__(self, reduction='mean', label_smoothing=0.0, eps=1e-8, **kwargs): diff --git a/nemo/collections/common/losses/spanning_loss.py b/nemo/collections/common/losses/spanning_loss.py index a12dab64afd0..b5d5c4f8f7b8 100644 --- a/nemo/collections/common/losses/spanning_loss.py +++ b/nemo/collections/common/losses/spanning_loss.py @@ -27,8 +27,7 @@ class SpanningLoss(Loss): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "logits": NeuralType(('B', 'T', 'D'), LogitsType()), "start_positions": NeuralType(tuple('B'), ChannelType()), @@ -37,15 +36,16 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "loss": NeuralType(elements_type=LossType()), "start_logits": NeuralType(('B', 'T'), LogitsType()), "end_logits": NeuralType(('B', 'T'), LogitsType()), } - def __init__(self,): + def __init__( + self, + ): super().__init__() @typecheck() diff --git a/nemo/collections/common/metrics/punct_er.py b/nemo/collections/common/metrics/punct_er.py index 933c1581f016..6bff496d594a 100644 --- a/nemo/collections/common/metrics/punct_er.py +++ b/nemo/collections/common/metrics/punct_er.py @@ -28,19 +28,21 @@ def punctuation_error_rate( - references: list[str], hypotheses: list[str], punctuation_marks: list[str], punctuation_mask: str = "[PUNCT]", + references: list[str], + hypotheses: list[str], + punctuation_marks: list[str], + punctuation_mask: str = "[PUNCT]", ) -> None: - """ Computes Punctuation Error Rate - + Args: references (list[str]) - list of references hypotheses (list[str]) - list of hypotheses punctuation_marks (list[str]) - list of punctuation marks for computing metrics punctuation_mask (str, by default "[PUNCT]") - mask token that will be applied to given punctuation marks while edit distance calculation - + Return: punct_er (float) - Punctuation Error Rate """ @@ -72,14 +74,14 @@ class OccurancePunctuationErrorRate: Args to init: punctuation_marks (list[str]) - list of punctuation marks for computing metrics punctuation_mask (str, by default "[PUNCT]") - mask token that will be applied to - given punctuation marks while edit distance calculation - + given punctuation marks while edit distance calculation + How to use: 1. Create object of OccurancePunctuationErrorRate class. Example: punctuation_marks = [".", ",", "!", "?"] oper_obj = OccurancePunctuationErrorRate(punctuation_marks) - + 2. To compute punctuation metrics, pass reference and hypothesis string to the "compute" method of created object. Example: @@ -94,16 +96,16 @@ class OccurancePunctuationErrorRate: ',': {'Correct': 0, 'Deletions': 1, 'Insertions': 0, 'Substitutions': 0}, '!': {'Correct': 1, 'Deletions': 0, 'Insertions': 0, 'Substitutions': 0}, '?': {'Correct': 0, 'Deletions': 0, 'Insertions': 1, 'Substitutions': 0}} - + 2. Dict of substitutions absolute amounts between given punctuation marks: Example: {'.': {'.': 0, ',': 0, '!': 1, '?': 0}, ',': {'.': 0, ',': 0, '!': 0, '?': 0}, '!': {'.': 0, ',': 0, '!': 0, '?': 0}, '?': {'.': 0, ',': 0, '!': 0, '?': 0}} - + 3. namedtuple "PunctuationRates" of punctuation operation rates (in range from 0 to 1): - 3.1. correct_rate - overall correct rate + 3.1. correct_rate - overall correct rate Example: correct_rate=0.25 3.2. deletions_rate - overall deletions rate Example: deletions_rate=0.25 @@ -114,14 +116,14 @@ class OccurancePunctuationErrorRate: 3.5. punct_er - Punctuation Error Rate Example: punct_er=0.75 3.6. operation_rates - dict of operations rates for each given punctuation mark - Example: + Example: operation_rates={ '.': {'Correct': 0.0, 'Deletions': 0.0, 'Insertions': 0.0, 'Substitutions': 1.0}, ',': {'Correct': 0.0, 'Deletions': 1.0, 'Insertions': 0.0, 'Substitutions': 0.0}, '!': {'Correct': 1.0, 'Deletions': 0.0, 'Insertions': 0.0, 'Substitutions': 0.0}, '?': {'Correct': 0.0, 'Deletions': 0.0, 'Insertions': 1.0, 'Substitutions': 0.0} } - + 3.7. substitution_rates - dict of substitution rates for each given punctuation mark Example: substitution_rates={ @@ -320,59 +322,59 @@ def compute(self, reference: str, hypothesis: str): class DatasetPunctuationErrorRate: """ - Class for computation the total puncutation-related absolute amounts of operations and their rates + Class for computation the total puncutation-related absolute amounts of operations and their rates in pairs of reference and hypothesis strins: - Absolute amounts of correct predictions, deletions, insertions and substitutions for each given punctuation mark - Rates of correct predictions, deletions, insertions - and substitutions for each given punctuation mark + and substitutions for each given punctuation mark - Total rates of correct predictions, deletions, insertions - and substiturions in pairs of reference and hypothesis strings + and substiturions in pairs of reference and hypothesis strings - Punctuation Error Rate - + Args to init: references (list[str]) - list of references hypotheses (list[str]) - list of hypotheses punctuation_marks (list[str]) - list of punctuation marks for computing metrics punctuation_mask (str, by default "[PUNCT]") - mask token that will be applied to given punctuation marks while edit distance calculation - + How to use: 1. Create object of DatasetPunctuationErrorRate class. Example: references = ["Hi, dear! Nice to see you. What's"] - hypotheses = ["Hi dear! Nice to see you! What's?"] + hypotheses = ["Hi dear! Nice to see you! What's?"] punctuation_marks = [".", ",", "!", "?"] - + dper_obj = DatasetPunctuationErrorRate(references, hypotheses, punctuation_marks) - + 2. To compute punctuation metrics, call the class method "compute()". Example: - dper_obj.compute() - + dper_obj.compute() + Result: The following atributes of class object will be updated with calculated metrics values. The values are available with calling the atributes: - - dper_obj.operation_rates - dict, rates of correctness and errors for each punctuation mark + + dper_obj.operation_rates - dict, rates of correctness and errors for each punctuation mark from `preset dper_obj.punctuation_marks` list. - + dper_obj.substitution_rates - dict, substitution rates between puncutation marks from `preset dper_obj.punctuation_marks` list. - - dper_obj.correct_rate - float, total rate of correctness between provided pairs of + + dper_obj.correct_rate - float, total rate of correctness between provided pairs of references and hypotheses. - - dper_obj.deletions_rate - float, total rate of deletions between provided pairs of + + dper_obj.deletions_rate - float, total rate of deletions between provided pairs of references and hypotheses. - - dper_obj.insertions_rate - float, total rate of insertions between provided pairs of + + dper_obj.insertions_rate - float, total rate of insertions between provided pairs of references and hypotheses. - - dper_obj.substitutions_rate - float, total rate of substitutions between provided pairs of + + dper_obj.substitutions_rate - float, total rate of substitutions between provided pairs of references and hypotheses. - - dper_obj.punct_er - float, total Punctuation Error Rate between provided pairs of + + dper_obj.punct_er - float, total Punctuation Error Rate between provided pairs of references and hypotheses. """ diff --git a/nemo/collections/common/parts/adapter_modules.py b/nemo/collections/common/parts/adapter_modules.py index 2084147f9cbc..0daa8ee083e7 100644 --- a/nemo/collections/common/parts/adapter_modules.py +++ b/nemo/collections/common/parts/adapter_modules.py @@ -61,7 +61,9 @@ def get_default_strategy_config(self) -> 'dataclass': """ return adapter_mixin_strategies.ResidualAddAdapterStrategyConfig() - def adapter_unfreeze(self,): + def adapter_unfreeze( + self, + ): """ Sets the requires grad for all parameters in the adapter to True. This method should be overridden for any custom unfreeze behavior that is required. @@ -72,7 +74,6 @@ def adapter_unfreeze(self,): class LinearAdapter(nn.Module, AdapterModuleUtil): - """ Simple Linear Feedforward Adapter module with LayerNorm and singe hidden layer with activation function. Note: The adapter explicitly initializes its final layer with all zeros in order to avoid affecting the diff --git a/nemo/collections/common/parts/mlm_scorer.py b/nemo/collections/common/parts/mlm_scorer.py index c38e4b25ed72..a6d8547affe8 100644 --- a/nemo/collections/common/parts/mlm_scorer.py +++ b/nemo/collections/common/parts/mlm_scorer.py @@ -86,7 +86,7 @@ def score_sentence(self, sentence: str): def __mask_text__(self, idx: int, tokens: List[str]): """ - replaces string at index idx in list `tokens` with a masked token and returns the modified list. + replaces string at index idx in list `tokens` with a masked token and returns the modified list. """ masked = tokens.copy() masked[idx] = self.MASK_LABEL diff --git a/nemo/collections/common/parts/preprocessing/parsers.py b/nemo/collections/common/parts/preprocessing/parsers.py index 10a3522ef241..54be78dc4ffa 100644 --- a/nemo/collections/common/parts/preprocessing/parsers.py +++ b/nemo/collections/common/parts/preprocessing/parsers.py @@ -207,7 +207,9 @@ def _normalize(self, text: str) -> Optional[str]: # noinspection PyBroadException try: text = cleaners.clean_text( - string=text, table=self._table, punctuation_to_replace=self.PUNCTUATION_TO_REPLACE, + string=text, + table=self._table, + punctuation_to_replace=self.PUNCTUATION_TO_REPLACE, ) except Exception: return None @@ -218,7 +220,11 @@ def _normalize(self, text: str) -> Optional[str]: NAME_TO_PARSER = {'base': CharParser, 'en': ENCharParser, 'ru': RUCharParser} -def make_parser(labels: Optional[List[str]] = None, name: str = 'base', **kwargs,) -> CharParser: +def make_parser( + labels: Optional[List[str]] = None, + name: str = 'base', + **kwargs, +) -> CharParser: """Creates parser from labels, set of arguments and concise parser name. Args: diff --git a/nemo/collections/common/tokenizers/word_tokenizer.py b/nemo/collections/common/tokenizers/word_tokenizer.py index f3431af9d734..562a8f19097e 100644 --- a/nemo/collections/common/tokenizers/word_tokenizer.py +++ b/nemo/collections/common/tokenizers/word_tokenizer.py @@ -37,7 +37,7 @@ def __init__( Args: vocab_file: path to file with vocabulary which consists of characters separated by \n - mask_token: mask token + mask_token: mask token bos_token: the beginning of sequence token eos_token: the end of sequence token. Usually equal to sep_token pad_token: token to use for padding diff --git a/nemo/collections/tts/g2p/data/ctc.py b/nemo/collections/tts/g2p/data/ctc.py index d96f4c0c718b..34b3102e9f06 100644 --- a/nemo/collections/tts/g2p/data/ctc.py +++ b/nemo/collections/tts/g2p/data/ctc.py @@ -109,7 +109,9 @@ def __init__( item[grapheme_field] = item[grapheme_field][:max_source_len] removed_source_max += 1 self.data.append( - {"graphemes": item[grapheme_field],} + { + "graphemes": item[grapheme_field], + } ) logging.info( @@ -123,7 +125,7 @@ def __getitem__(self, index): return self.data[index] def map(self, text: str) -> List[int]: - """ Creates a mapping from target labels to ids.""" + """Creates a mapping from target labels to ids.""" tokens = [] for word_id, word in enumerate(text.split()): tokens.append(self.labels_tkn2id[word]) diff --git a/nemo/collections/tts/g2p/data/t5.py b/nemo/collections/tts/g2p/data/t5.py index 0fc39f9ebaa1..5edaa18a302d 100644 --- a/nemo/collections/tts/g2p/data/t5.py +++ b/nemo/collections/tts/g2p/data/t5.py @@ -110,7 +110,11 @@ def _collate_fn(self, batch): # Encode inputs (graphemes) input_encoding = self.tokenizer( - graphemes_batch, padding='longest', max_length=self.max_source_len, truncation=True, return_tensors='pt', + graphemes_batch, + padding='longest', + max_length=self.max_source_len, + truncation=True, + return_tensors='pt', ) input_ids, attention_mask = input_encoding.input_ids, input_encoding.attention_mask output = (input_ids, attention_mask) @@ -120,7 +124,10 @@ def _collate_fn(self, batch): # Encode targets (phonemes) phonemes_batch = [entry["phonemes"] for entry in batch] target_encoding = self.tokenizer( - phonemes_batch, padding='longest', max_length=self.max_target_len, truncation=True, + phonemes_batch, + padding='longest', + max_length=self.max_target_len, + truncation=True, ) labels = target_encoding.input_ids diff --git a/nemo/collections/tts/losses/hifigan_losses.py b/nemo/collections/tts/losses/hifigan_losses.py index 649f075994d8..559860ddd7a5 100644 --- a/nemo/collections/tts/losses/hifigan_losses.py +++ b/nemo/collections/tts/losses/hifigan_losses.py @@ -96,7 +96,7 @@ def forward(self, disc_real_outputs, disc_generated_outputs): g_losses = [] for dr, dg in zip(disc_real_outputs, disc_generated_outputs): r_loss = torch.mean((1 - dr) ** 2) - g_loss = torch.mean(dg ** 2) + g_loss = torch.mean(dg**2) loss += r_loss + g_loss r_losses.append(r_loss.item()) g_losses.append(g_loss.item()) diff --git a/nemo/collections/tts/losses/stftlosses.py b/nemo/collections/tts/losses/stftlosses.py index 65320bf9dd65..971667beab5b 100644 --- a/nemo/collections/tts/losses/stftlosses.py +++ b/nemo/collections/tts/losses/stftlosses.py @@ -66,7 +66,7 @@ def stft(x, fft_size, hop_size, win_length, window): imag = x_stft[..., 1] # NOTE(kan-bayashi): clamp is needed to avoid nan or inf - return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) class SpectralConvergenceLoss(Loss): diff --git a/nemo/collections/tts/modules/adapters.py b/nemo/collections/tts/modules/adapters.py index df5bdff84dc5..23486441ff8a 100644 --- a/nemo/collections/tts/modules/adapters.py +++ b/nemo/collections/tts/modules/adapters.py @@ -24,7 +24,7 @@ class FFTransformerDecoderAdapter(FFTransformerDecoder, adapter_mixins.AdapterModuleMixin): - """ Inherit from FFTransformerDecoder and add support for adapter""" + """Inherit from FFTransformerDecoder and add support for adapter""" def add_adapter(self, name: str, cfg: dict): cfg = self._update_adapter_cfg_input_dim(cfg) @@ -54,13 +54,13 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig): class FFTransformerEncoderAdapter( FFTransformerDecoderAdapter, FFTransformerEncoder, adapter_mixins.AdapterModuleMixin ): - """ Inherit from FFTransformerEncoder and add support for adapter""" + """Inherit from FFTransformerEncoder and add support for adapter""" pass class AlignmentEncoderAdapter(AlignmentEncoder, adapter_mixins.AdapterModuleMixin): - """ Inherit from AlignmentEncoder and add support for adapter""" + """Inherit from AlignmentEncoder and add support for adapter""" def add_adapter(self, name: str, cfg: dict): @@ -106,7 +106,7 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig, module_dim: int): class TemporalPredictorAdapter(TemporalPredictor, adapter_mixins.AdapterModuleMixin): - """ Inherit from TemporalPredictor and add support for adapter""" + """Inherit from TemporalPredictor and add support for adapter""" def add_adapter(self, name: str, cfg: dict): cfg = self._update_adapter_cfg_input_dim(cfg) diff --git a/nemo/collections/tts/modules/hifigan_modules.py b/nemo/collections/tts/modules/hifigan_modules.py index a600b1669d15..5c48c4ca77db 100644 --- a/nemo/collections/tts/modules/hifigan_modules.py +++ b/nemo/collections/tts/modules/hifigan_modules.py @@ -205,7 +205,7 @@ def __init__( self.ups.append( weight_norm( ConvTranspose1d( - upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, diff --git a/nemo/collections/tts/parts/mixins/fastpitch_adapter_mixins.py b/nemo/collections/tts/parts/mixins/fastpitch_adapter_mixins.py index 375cf1fe51ee..bfb8011e22ce 100644 --- a/nemo/collections/tts/parts/mixins/fastpitch_adapter_mixins.py +++ b/nemo/collections/tts/parts/mixins/fastpitch_adapter_mixins.py @@ -21,7 +21,7 @@ class FastPitchAdapterModelMixin(AdapterModelPTMixin): - """ FastPitch Adapter Mixin that can augment any Encoder module with Adapter module support. + """FastPitch Adapter Mixin that can augment any Encoder module with Adapter module support. This mixin class should be used only with a top level ModelPT subclass, that includes an `encoder` submodule. This mixin class adds several utility methods which are propagated to the `encoder`. An Adapter module is any Pytorch nn.Module that possess a few properties : diff --git a/nemo/collections/tts/parts/preprocessing/audio_trimming.py b/nemo/collections/tts/parts/preprocessing/audio_trimming.py index 71a10e4a5cc6..105e17902a4f 100644 --- a/nemo/collections/tts/parts/preprocessing/audio_trimming.py +++ b/nemo/collections/tts/parts/preprocessing/audio_trimming.py @@ -25,19 +25,18 @@ class AudioTrimmer(ABC): - """Interface for silence trimming implementations - """ + """Interface for silence trimming implementations""" @abstractmethod def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str) -> Tuple[np.array, int, int]: """Trim starting and trailing silence from the input audio. - Args: - audio: Numpy array containing audio samples. Float [-1.0, 1.0] format. - sample_rate: Sample rate of input audio. - audio_id: String identifier (eg. file name) used for logging. + Args: + audio: Numpy array containing audio samples. Float [-1.0, 1.0] format. + sample_rate: Sample rate of input audio. + audio_id: String identifier (eg. file name) used for logging. - Returns numpy array with trimmed audio, and integer sample indices representing the start and end - of speech within the original audio array. + Returns numpy array with trimmed audio, and integer sample indices representing the start and end + of speech within the original audio array. """ raise NotImplementedError @@ -54,22 +53,22 @@ def __init__( volume_norm: bool = True, ) -> None: """Energy/power based silence trimming using Librosa backend. - Args: - db_threshold: Audio frames at least db_threshold decibels below ref_amplitude will be - considered silence. - ref_amplitude: Amplitude threshold for classifying speech versus silence. - speech_frame_threshold: Start and end of speech will be detected where there are at least - speech_frame_threshold consecutive audio frames classified as speech. Setting this value higher - is more robust to false-positives (silence detected as speech), but setting it too high may result - in very short speech segments being cut out from the audio. - trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match - the win_length used any other part of the code or model. - trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match - the hop_length used any other part of the code or model. - pad_seconds: Audio duration in seconds to keep before and after each speech segment. - Set this to at least 0.1 to avoid cutting off any speech audio, with larger values - being safer but increasing the average silence duration left afterwards. - volume_norm: Whether to normalize the volume of audio before doing speech detection. + Args: + db_threshold: Audio frames at least db_threshold decibels below ref_amplitude will be + considered silence. + ref_amplitude: Amplitude threshold for classifying speech versus silence. + speech_frame_threshold: Start and end of speech will be detected where there are at least + speech_frame_threshold consecutive audio frames classified as speech. Setting this value higher + is more robust to false-positives (silence detected as speech), but setting it too high may result + in very short speech segments being cut out from the audio. + trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match + the win_length used any other part of the code or model. + trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match + the hop_length used any other part of the code or model. + pad_seconds: Audio duration in seconds to keep before and after each speech segment. + Set this to at least 0.1 to avoid cutting off any speech audio, with larger values + being safer but increasing the average silence duration left afterwards. + volume_norm: Whether to normalize the volume of audio before doing speech detection. """ assert db_threshold >= 0 assert ref_amplitude >= 0 @@ -99,7 +98,9 @@ def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str = "") -> T ) start_frame, end_frame = get_start_and_end_of_speech_frames( - is_speech=speech_frames, speech_frame_threshold=self.speech_frame_threshold, audio_id=audio_id, + is_speech=speech_frames, + speech_frame_threshold=self.speech_frame_threshold, + audio_id=audio_id, ) if not start_frame and not end_frame: return np.array([]), 0, 0 @@ -135,21 +136,21 @@ def __init__( ) -> None: """Voice activity detection (VAD) based silence trimming. - Args: - model_name: NeMo VAD model to load. Valid configurations can be found with - EncDecClassificationModel.list_available_models() - vad_sample_rate: Sample rate used for pretrained VAD model. - vad_threshold: Softmax probability [0, 1] of VAD output, above which audio frames will be classified - as speech. - device: Device "cpu" or "cuda" to use for running the VAD model. - trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match - the win_length used any other part of the code or model. - trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match - the hop_length used any other part of the code or model. - pad_seconds: Audio duration in seconds to keep before and after each speech segment. - Set this to at least 0.1 to avoid cutting off any speech audio, with larger values - being safer but increasing the average silence duration left afterwards. - volume_norm: Whether to normalize the volume of audio before doing speech detection. + Args: + model_name: NeMo VAD model to load. Valid configurations can be found with + EncDecClassificationModel.list_available_models() + vad_sample_rate: Sample rate used for pretrained VAD model. + vad_threshold: Softmax probability [0, 1] of VAD output, above which audio frames will be classified + as speech. + device: Device "cpu" or "cuda" to use for running the VAD model. + trim_win_length: Length of audio frames to use when doing speech detection. This does not need to match + the win_length used any other part of the code or model. + trim_hop_length: Stride of audio frames to use when doing speech detection. This does not need to match + the hop_length used any other part of the code or model. + pad_seconds: Audio duration in seconds to keep before and after each speech segment. + Set this to at least 0.1 to avoid cutting off any speech audio, with larger values + being safer but increasing the average silence duration left afterwards. + volume_norm: Whether to normalize the volume of audio before doing speech detection. """ assert vad_sample_rate > 0 assert vad_threshold >= 0 @@ -217,7 +218,9 @@ def trim_audio(self, audio: np.array, sample_rate: int, audio_id: str = "") -> T speech_frames = self._detect_speech(audio=vad_audio) start_frame, end_frame = get_start_and_end_of_speech_frames( - is_speech=speech_frames, speech_frame_threshold=self.speech_frame_threshold, audio_id=audio_id, + is_speech=speech_frames, + speech_frame_threshold=self.speech_frame_threshold, + audio_id=audio_id, ) if not start_frame and not end_frame: return np.array([]), 0, 0 @@ -258,12 +261,12 @@ def get_start_and_end_of_speech_frames( is_speech: np.array, speech_frame_threshold: int, audio_id: str = "" ) -> Tuple[int, int]: """Finds the speech frames corresponding to the start and end of speech for an utterance. - Args: - is_speech: [num_frames] boolean array with true entries labeling speech frames. - speech_frame_threshold: The number of consecutive speech frames required to classify the speech boundaries. - audio_id: String identifier (eg. file name) used for logging. + Args: + is_speech: [num_frames] boolean array with true entries labeling speech frames. + speech_frame_threshold: The number of consecutive speech frames required to classify the speech boundaries. + audio_id: String identifier (eg. file name) used for logging. - Returns integers representing the frame indices of the start (inclusive) and end (exclusive) of speech. + Returns integers representing the frame indices of the start (inclusive) and end (exclusive) of speech. """ num_frames = is_speech.shape[0] @@ -295,14 +298,14 @@ def pad_sample_indices( start_sample: int, end_sample: int, max_sample: int, sample_rate: int, pad_seconds: float ) -> Tuple[int, int]: """Shift the input sample indices by pad_seconds in front and back within [0, max_sample] - Args: - start_sample: Start sample index - end_sample: End sample index - max_sample: Maximum sample index - sample_rate: Sample rate of audio - pad_seconds: Amount to pad/shift the indices by. - - Returns the sample indices after padding by the input amount. + Args: + start_sample: Start sample index + end_sample: End sample index + max_sample: Maximum sample index + sample_rate: Sample rate of audio + pad_seconds: Amount to pad/shift the indices by. + + Returns the sample indices after padding by the input amount. """ pad_samples = int(pad_seconds * sample_rate) start_sample = start_sample - pad_samples diff --git a/nemo/core/config/pytorch_lightning.py b/nemo/core/config/pytorch_lightning.py index 98c1a1157f3f..de9c1a967e4f 100644 --- a/nemo/core/config/pytorch_lightning.py +++ b/nemo/core/config/pytorch_lightning.py @@ -78,5 +78,7 @@ class TrainerConfig: # Register the trainer config. cs.store( - group="trainer", name="trainer", node=TrainerConfig, + group="trainer", + name="trainer", + node=TrainerConfig, ) diff --git a/nemo/core/neural_types/axes.py b/nemo/core/neural_types/axes.py index 17589b74ce22..0d9c3eef3f34 100644 --- a/nemo/core/neural_types/axes.py +++ b/nemo/core/neural_types/axes.py @@ -83,11 +83,11 @@ def from_str(label): class AxisType(object): """This class represents axis semantics and (optionally) it's dimensionality - Args: - kind (AxisKindAbstract): what kind of axis it is? For example Batch, Height, etc. - size (int, optional): specify if the axis should have a fixed size. By default it is set to None and you - typically do not want to set it for Batch and Time - is_list (bool, default=False): whether this is a list or a tensor axis + Args: + kind (AxisKindAbstract): what kind of axis it is? For example Batch, Height, etc. + size (int, optional): specify if the axis should have a fixed size. By default it is set to None and you + typically do not want to set it for Batch and Time + is_list (bool, default=False): whether this is a list or a tensor axis """ def __init__(self, kind: AxisKindAbstract, size: Optional[int] = None, is_list=False): diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index d00ba72df043..3d234f2dca5a 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -131,7 +131,7 @@ def compare(self, second) -> NeuralTypeComparisonResult: return NeuralTypeComparisonResult.INCOMPATIBLE def compare_and_raise_error(self, parent_type_name, port_name, second_object): - """ Method compares definition of one type with another and raises an error if not compatible. """ + """Method compares definition of one type with another and raises an error if not compatible.""" if torch.jit.is_scripting(): # suppress for TorchScript return diff --git a/nemo/core/optim/adafactor.py b/nemo/core/optim/adafactor.py index 3954a0cdbcc4..925854df338f 100644 --- a/nemo/core/optim/adafactor.py +++ b/nemo/core/optim/adafactor.py @@ -180,7 +180,7 @@ def step(self, closure=None): group["lr"] = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) - update = (grad ** 2) + group["eps"][0] + update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] diff --git a/nemo/core/optim/novograd.py b/nemo/core/optim/novograd.py index 8e39ab9719cf..529697eb997b 100644 --- a/nemo/core/optim/novograd.py +++ b/nemo/core/optim/novograd.py @@ -60,7 +60,12 @@ def __init__( ): _check_valid_opt_params(lr, eps, betas) defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad, + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad, ) self.luc = luc self.luc_trust = luc_trust diff --git a/nemo/core/utils/process_launcher/launcher.py b/nemo/core/utils/process_launcher/launcher.py index c121a33cb373..f1072af14278 100644 --- a/nemo/core/utils/process_launcher/launcher.py +++ b/nemo/core/utils/process_launcher/launcher.py @@ -163,7 +163,11 @@ def execute_job( return proc, res, (std_error_buffer, drainerthread) -def launch(launcher, job_overrides: Sequence[Sequence[str]], initial_job_idx: int,) -> Sequence[JobReturn]: +def launch( + launcher, + job_overrides: Sequence[Sequence[str]], + initial_job_idx: int, +) -> Sequence[JobReturn]: """ Args: launcher: Reference to the Launched subclass @@ -188,7 +192,8 @@ def launch(launcher, job_overrides: Sequence[Sequence[str]], initial_job_idx: in logging.info( "ProcessLauncher({}) is launching {} jobs".format( - ",".join([f"{k}={v}" for k, v in runner_cfg.items()]), len(job_overrides), + ",".join([f"{k}={v}" for k, v in runner_cfg.items()]), + len(job_overrides), ) ) logging.info("Launching jobs, sweep output dir : {}".format(sweep_dir)) @@ -348,7 +353,13 @@ def __init__(self, **kwargs: Any) -> None: self.runner = kwargs # type: ProcessLauncherConfig - def setup(self, *, hydra_context: HydraContext, task_function: TaskFunction, config: DictConfig,) -> None: + def setup( + self, + *, + hydra_context: HydraContext, + task_function: TaskFunction, + config: DictConfig, + ) -> None: self.config = config self.task_function = task_function self.hydra_context = hydra_context @@ -359,7 +370,10 @@ def launch(self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int) - ConfigStore.instance().store( - group="hydra/launcher", name="nemo_launcher", node=ProcessLauncherConfig, provider="nemo_process_launcher", + group="hydra/launcher", + name="nemo_launcher", + node=ProcessLauncherConfig, + provider="nemo_process_launcher", ) Plugins.instance().register(ProcessLauncher) diff --git a/nemo/utils/debug_hook.py b/nemo/utils/debug_hook.py index 4f07f74bc032..5cc393f91735 100644 --- a/nemo/utils/debug_hook.py +++ b/nemo/utils/debug_hook.py @@ -133,11 +133,11 @@ def backward_hook(module, inputs, outputs): def get_tensor_hook(module, name, trainer, rank, logger, dump_to_file=False): """ - A tensor hook to dump all of the tensor weight norms and grad norms at the end of each of the backward steps. - For more details about the tensor hook, check https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html + A tensor hook to dump all of the tensor weight norms and grad norms at the end of each of the backward steps. + For more details about the tensor hook, check https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html Args: - module: the model module + module: the model module name: tensor name trainer: PTL trainer rank: worker rank diff --git a/nemo/utils/exceptions.py b/nemo/utils/exceptions.py index 75a06b93c816..5eed8298bf7a 100644 --- a/nemo/utils/exceptions.py +++ b/nemo/utils/exceptions.py @@ -14,7 +14,7 @@ class NeMoBaseException(Exception): - """ NeMo Base Exception. All exceptions created in NeMo should inherit from this class""" + """NeMo Base Exception. All exceptions created in NeMo should inherit from this class""" class LightningNotInstalledException(NeMoBaseException): diff --git a/nemo/utils/metaclasses.py b/nemo/utils/metaclasses.py index 5fad7cb15013..b4cccb14a425 100644 --- a/nemo/utils/metaclasses.py +++ b/nemo/utils/metaclasses.py @@ -17,8 +17,8 @@ class Singleton(type): - """ Implementation of a generic, tread-safe singleton meta-class. - Can be used as meta-class, i.e. will create + """Implementation of a generic, tread-safe singleton meta-class. + Can be used as meta-class, i.e. will create """ # List of instances - one per class. @@ -27,7 +27,7 @@ class Singleton(type): __lock = threading.Lock() def __call__(cls, *args, **kwargs): - """ Returns singleton instance. A thread safe implementation. """ + """Returns singleton instance. A thread safe implementation.""" if cls not in cls.__instances: # Enter critical section. with cls.__lock: diff --git a/nemo/utils/timers.py b/nemo/utils/timers.py index a35c257652b9..7aed85ddd3e1 100644 --- a/nemo/utils/timers.py +++ b/nemo/utils/timers.py @@ -1,6 +1,7 @@ """ This module support timing of code blocks. """ + # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/pyproject.toml b/pyproject.toml index a84df104c516..54ac5b8b2724 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,32 +92,16 @@ skip_string_normalization = true # https://black.readthedocs.io/en/stable/the_black_code_style/index.html # `required_version` is necessary for consistency (other `black` versions will fail to reformat files) required_version = "24" -target-version = ['py310', 'py311', 'py312'] +target-version = ['py310', 'py311', 'py312', 'py313'] +# by default exclude Jupyter Notebooks (but can be reformated when passed directly) extend-exclude = ''' -# A regex preceded with ^/ will apply only to files and directories -# in the root of the project. -# include here only current collections, new collections should not be ignored -# exclude the collection once it is reformatted (due to changes in PRs) ( - ^\/docs\/ - | ^\/external\/ - | ^\/examples\/ - | ^\/nemo\/collections\/asr\/ - | ^\/nemo\/collections\/common\/ - | ^\/nemo\/collections\/multimodal\/ - | ^\/nemo\/collections\/nlp\/ - | ^\/nemo\/collections\/tts\/ - | ^\/nemo\/collections\/vision\/ - | ^\/nemo\/core\/ - | ^\/nemo\/utils\/ - | ^\/scripts\/ - | ^\/tests\/ - | ^\/tools\/ - | ^\/tutorials\/ - | ^\/setup.py + \.ipynb + | \.ipynb_checkpoints ) ''' + [tool.pytest.ini_options] # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. # -vv will also display tests with durration = 0.00s diff --git a/scripts/asr_language_modeling/ngram_lm/make_phone_lm.py b/scripts/asr_language_modeling/ngram_lm/make_phone_lm.py index e132b0cfcd6e..be59f6824f2f 100755 --- a/scripts/asr_language_modeling/ngram_lm/make_phone_lm.py +++ b/scripts/asr_language_modeling/ngram_lm/make_phone_lm.py @@ -214,7 +214,8 @@ def AddRawCountsFromStandardInput(self): lines_processed += 1 if lines_processed == 0 or args.verbose > 0: print( - "make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr, + "make_phone_lm.py: processed {0} lines of input".format(lines_processed), + file=sys.stderr, ) # This backs off the counts by subtracting 1 and assigning the subtracted @@ -271,7 +272,8 @@ def Print(self, info_string): if self.backoff_symbol in counts_for_hist.word_to_count: total_excluding_backoff -= counts_for_hist.word_to_count[self.backoff_symbol] print( - "total count = {0}, excluding backoff = {1}".format(total, total_excluding_backoff), file=sys.stderr, + "total count = {0}, excluding backoff = {1}".format(total, total_excluding_backoff), + file=sys.stderr, ) def GetHistToStateMap(self): @@ -297,7 +299,8 @@ def GetProb(self, hist, word): total_count = float(counts_for_hist.total_count) if not word in counts_for_hist.word_to_count: print( - "make_phone_lm.py: no prob for {0} -> {1} " "[no such count]".format(hist, word), file=sys.stderr, + "make_phone_lm.py: no prob for {0} -> {1} " "[no such count]".format(hist, word), + file=sys.stderr, ) return None prob = float(counts_for_hist.word_to_count[word]) / total_count @@ -415,7 +418,11 @@ def PrintAsFst(self, word_disambig_symbol): assert word == self.backoff_symbol backoff_fst_state = hist_to_state[hist[1 : len(hist)]] print( - this_fst_state, backoff_fst_state, word_disambig_symbol, 0, this_cost, + this_fst_state, + backoff_fst_state, + word_disambig_symbol, + 0, + this_cost, ) # This function returns a set of n-grams that cannot currently be pruned @@ -600,7 +607,8 @@ def GetLikeChangeFromPruningNgram(self, hist, word): backoff_count = self.GetProb(hist[1:], word) * backoff_total except Exception: print( - "problem getting backoff count: hist = {0}, word = {1}".format(hist, word), file=sys.stderr, + "problem getting backoff count: hist = {0}, word = {1}".format(hist, word), + file=sys.stderr, ) sys.exit(1) @@ -788,7 +796,12 @@ def PrintAsArpa(self): # print the number of n-grams. Add 1 for the 1-gram # section because of , we print -99 as the prob so we # have a place to put the backoff prob. - print("ngram {0}={1}".format(hist_len + 1, self.GetNumNgrams(hist_len) + (1 if hist_len == 0 else 0),)) + print( + "ngram {0}={1}".format( + hist_len + 1, + self.GetNumNgrams(hist_len) + (1 if hist_len == 0 else 0), + ) + ) print("") @@ -808,7 +821,8 @@ def PrintAsArpa(self): assert prob != None and prob > 0 backoff_prob = self.GetProb((hist) + (word,), self.backoff_symbol) line = "{0}\t{1}".format( - "%.5f" % math.log10(prob), " ".join(self.IntToString(x) for x in hist + (word,)), + "%.5f" % math.log10(prob), + " ".join(self.IntToString(x) for x in hist + (word,)), ) if backoff_prob != None: line += "\t{0}".format("%.5f" % math.log10(backoff_prob)) diff --git a/scripts/dataset_processing/fisher_audio_to_wav.py b/scripts/dataset_processing/fisher_audio_to_wav.py index 0f231b6b6073..c4ea1f5d0371 100644 --- a/scripts/dataset_processing/fisher_audio_to_wav.py +++ b/scripts/dataset_processing/fisher_audio_to_wav.py @@ -30,10 +30,18 @@ parser = argparse.ArgumentParser(description='Convert Fisher .sph to .wav') parser.add_argument( - "--data_root", default=None, type=str, required=True, help="The path to the root Fisher dataset folder.", + "--data_root", + default=None, + type=str, + required=True, + help="The path to the root Fisher dataset folder.", ) parser.add_argument( - "--dest_root", default=None, type=str, required=True, help="Path to the destination root directory.", + "--dest_root", + default=None, + type=str, + required=True, + help="Path to the destination root directory.", ) args = parser.parse_args() @@ -85,13 +93,27 @@ def main(): logging.info("\n\nConverting audio for Part 1") __process_set( - os.path.join(data_root, "LDC2004S13-Part1", "fisher_eng_tr_sp_d*", "audio", "*", "*.sph",), + os.path.join( + data_root, + "LDC2004S13-Part1", + "fisher_eng_tr_sp_d*", + "audio", + "*", + "*.sph", + ), os.path.join(dest_root, "LDC2004S13-Part1", "audio_wav"), ) logging.info("\n\nConverting audio for Part 2") __process_set( - os.path.join(data_root, "LDC2005S13-Part2", "fe_03_p2_sph*", "audio", "*", "*.sph",), + os.path.join( + data_root, + "LDC2005S13-Part2", + "fe_03_p2_sph*", + "audio", + "*", + "*.sph", + ), os.path.join(dest_root, "LDC2005S13-Part2", "audio_wav"), ) diff --git a/scripts/dataset_processing/g2p/syllabify.py b/scripts/dataset_processing/g2p/syllabify.py index 400e531de5c1..725f2a4daa04 100644 --- a/scripts/dataset_processing/g2p/syllabify.py +++ b/scripts/dataset_processing/g2p/syllabify.py @@ -230,7 +230,7 @@ def syllabify(pron, alaska_rule=True): nuclei = [] onsets = [] i = -1 - for (j, seg) in enumerate(mypron): + for j, seg in enumerate(mypron): if seg in VOWELS: nuclei.append([seg]) onsets.append(mypron[i + 1 : j]) # actually interludes, r.n. @@ -276,7 +276,7 @@ def destress(syllab): Generate a syllabification with nuclear stress information removed """ syls = [] - for (onset, nucleus, coda) in syllab: + for onset, nucleus, coda in syllab: nuke = [p[:-1] if p[-1] in {"0", "1", "2"} else p for p in nucleus] syls.append((onset, nuke, coda)) return syls diff --git a/scripts/dataset_processing/get_aishell_data.py b/scripts/dataset_processing/get_aishell_data.py index 0390dee95079..8147a5f2333a 100644 --- a/scripts/dataset_processing/get_aishell_data.py +++ b/scripts/dataset_processing/get_aishell_data.py @@ -136,7 +136,11 @@ def __process_data(data_folder: str, dst_folder: str): duration = float(duration) json_lines.append( json.dumps( - {"audio_filepath": os.path.abspath(audio_path), "duration": duration, "text": text,}, + { + "audio_filepath": os.path.abspath(audio_path), + "duration": duration, + "text": text, + }, ensure_ascii=False, ) ) diff --git a/scripts/dataset_processing/get_librispeech_data.py b/scripts/dataset_processing/get_librispeech_data.py index 48cddfd4bf63..75d1dd491a72 100644 --- a/scripts/dataset_processing/get_librispeech_data.py +++ b/scripts/dataset_processing/get_librispeech_data.py @@ -190,8 +190,15 @@ def main(): __extract_file(filepath, data_root) logging.info("Processing {0}".format(data_set)) __process_data( - os.path.join(os.path.join(data_root, "LibriSpeech"), data_set.replace("_", "-"),), - os.path.join(os.path.join(data_root, "LibriSpeech"), data_set.replace("_", "-"),) + "-processed", + os.path.join( + os.path.join(data_root, "LibriSpeech"), + data_set.replace("_", "-"), + ), + os.path.join( + os.path.join(data_root, "LibriSpeech"), + data_set.replace("_", "-"), + ) + + "-processed", os.path.join(data_root, data_set + ".json"), num_workers=num_workers, ) diff --git a/scripts/dataset_processing/get_openslr_rir_data.py b/scripts/dataset_processing/get_openslr_rir_data.py index 898555750469..adb7a78ef323 100644 --- a/scripts/dataset_processing/get_openslr_rir_data.py +++ b/scripts/dataset_processing/get_openslr_rir_data.py @@ -117,7 +117,8 @@ def __process_data(data_folder: str, dst_folder: str, manifest_file: str): else: for chan in range(1, n_chans + 1): chan_file_name = os.path.join( - real_rir_folder, os.path.splitext(os.path.basename(rir_f))[0] + "-" + str(chan) + ".wav", + real_rir_folder, + os.path.splitext(os.path.basename(rir_f))[0] + "-" + str(chan) + ".wav", ) _ = subprocess.check_output(f"sox {rir_f} {chan_file_name} remix {chan}", shell=True) diff --git a/scripts/dataset_processing/kaldi2json.py b/scripts/dataset_processing/kaldi2json.py index c2be1aa28fdc..a273723e8194 100644 --- a/scripts/dataset_processing/kaldi2json.py +++ b/scripts/dataset_processing/kaldi2json.py @@ -24,10 +24,16 @@ def main(): parser = argparse.ArgumentParser(description="Convert kaldi data folder to manifest.json") parser.add_argument( - "--data_dir", required=True, type=str, help="data in kaldi format", + "--data_dir", + required=True, + type=str, + help="data in kaldi format", ) parser.add_argument( - "--manifest", required=True, type=str, help="path to store the manifest file", + "--manifest", + required=True, + type=str, + help="path to store the manifest file", ) parser.add_argument( "--with_aux_data", @@ -74,12 +80,20 @@ def main(): # read text text = pd.read_csv( - required_data["text"], sep="^([^ ]+) ", engine="python", header=None, usecols=[1, 2], names=["label", "text"], + required_data["text"], + sep="^([^ ]+) ", + engine="python", + header=None, + usecols=[1, 2], + names=["label", "text"], ) # read segments segments = pd.read_csv( - required_data["duration"], sep=" ", header=None, names=["label", "wav_label", "offset", "end"], + required_data["duration"], + sep=" ", + header=None, + names=["label", "wav_label", "offset", "end"], ) # add offset if needed if len(segments.offset) > len(segments.offset[segments.offset == 0.0]): @@ -89,7 +103,10 @@ def main(): # merge data wav_segments_text = pd.merge( - pd.merge(segments, wavscp, how="inner", on="wav_label"), text, how="inner", on="label", + pd.merge(segments, wavscp, how="inner", on="wav_label"), + text, + how="inner", + on="label", ) if args.with_aux_data: diff --git a/scripts/dataset_processing/process_an4_data.py b/scripts/dataset_processing/process_an4_data.py index fb6e0bcc3ac6..3a0587c78364 100644 --- a/scripts/dataset_processing/process_an4_data.py +++ b/scripts/dataset_processing/process_an4_data.py @@ -37,7 +37,10 @@ def build_manifest(data_root, transcripts_path, manifest_path, wav_path): file_id = line[line.find('(') + 1 : -2] # e.g. "cen4-fash-b" audio_path = os.path.join( - data_root, wav_path, file_id[file_id.find('-') + 1 : file_id.rfind('-')], file_id + '.wav', + data_root, + wav_path, + file_id[file_id.find('-') + 1 : file_id.rfind('-')], + file_id + '.wav', ) duration = librosa.core.get_duration(filename=audio_path) diff --git a/scripts/dataset_processing/process_fisher_data.py b/scripts/dataset_processing/process_fisher_data.py index 96ddced060e1..a345571046ab 100644 --- a/scripts/dataset_processing/process_fisher_data.py +++ b/scripts/dataset_processing/process_fisher_data.py @@ -39,7 +39,11 @@ parser = argparse.ArgumentParser(description="Fisher Data Processing") parser.add_argument( - "--audio_root", default=None, type=str, required=True, help="The path to the root of the audio (wav) data folder.", + "--audio_root", + default=None, + type=str, + required=True, + help="The path to the root of the audio (wav) data folder.", ) parser.add_argument( "--transcript_root", @@ -49,21 +53,34 @@ help="The path to the root of the transcript data folder.", ) parser.add_argument( - "--dest_root", default=None, type=str, required=True, help="Path to the destination root directory.", + "--dest_root", + default=None, + type=str, + required=True, + help="Path to the destination root directory.", ) # Optional arguments parser.add_argument( - "--min_slice_duration", default=10.0, type=float, help="Minimum audio slice duration after processing.", + "--min_slice_duration", + default=10.0, + type=float, + help="Minimum audio slice duration after processing.", ) parser.add_argument( - "--keep_low_conf", action="store_true", help="Keep all utterances with low confidence transcripts", + "--keep_low_conf", + action="store_true", + help="Keep all utterances with low confidence transcripts", ) parser.add_argument( - "--remove_noises", action="store_true", help="Removes transcripted noises such as [laughter].", + "--remove_noises", + action="store_true", + help="Removes transcripted noises such as [laughter].", ) parser.add_argument( - "--noises_to_emoji", action="store_true", help="Converts transcripts for noises to an emoji character.", + "--noises_to_emoji", + action="store_true", + help="Converts transcripts for noises to an emoji character.", ) args = parser.parse_args() @@ -271,7 +288,10 @@ def __process_one_file( # Append utterance to buffer transcript_buffers[idx] += content audio_buffers[idx].append( - audio_data[floor(t_start * sample_rate) : ceil(t_end * sample_rate), idx,] + audio_data[ + floor(t_start * sample_rate) : ceil(t_end * sample_rate), + idx, + ] ) buffer_durations[idx] += duration @@ -310,7 +330,14 @@ def __partition_name(file_count): def __process_data( - audio_root, transcript_root, dst_root, min_slice_duration, file_count, keep_low_conf, rem_noises, emojify, + audio_root, + transcript_root, + dst_root, + min_slice_duration, + file_count, + keep_low_conf, + rem_noises, + emojify, ): """ Converts Fisher wav files to numpy arrays, segments audio and transcripts. diff --git a/scripts/dataset_processing/process_hub5_data.py b/scripts/dataset_processing/process_hub5_data.py index e67a0d04a01f..c7d8c960d2a4 100644 --- a/scripts/dataset_processing/process_hub5_data.py +++ b/scripts/dataset_processing/process_hub5_data.py @@ -32,7 +32,11 @@ parser = argparse.ArgumentParser(description="Prepare HUB5 data for training/eval") parser.add_argument( - "--data_root", default=None, type=str, required=True, help="The path to the root LDC HUB5 dataset directory.", + "--data_root", + default=None, + type=str, + required=True, + help="The path to the root LDC HUB5 dataset directory.", ) parser.add_argument( "--dest_root", @@ -44,13 +48,25 @@ # Optional arguments parser.add_argument( - "--min_slice_duration", default=10.0, type=float, help="Minimum audio slice duration after processing.", + "--min_slice_duration", + default=10.0, + type=float, + help="Minimum audio slice duration after processing.", ) args = parser.parse_args() StmUtterance = namedtuple( - 'StmUtterance', ['filename', 'channel', 'speaker_id', 'begin', 'end', 'label', 'transcript',], + 'StmUtterance', + [ + 'filename', + 'channel', + 'speaker_id', + 'begin', + 'end', + 'label', + 'transcript', + ], ) STM_LINE_FMT = re.compile(r"^(\w+)\s+(\w+)\s+(\w+)\s+([0-9.]+)\s+([0-9.]+)\s+(<.*>)?\s+(.+)$") @@ -62,7 +78,12 @@ def get_utt_id(segment): """ Gives utterance IDs in a form like: en_4156-a-36558-37113 """ - return "{}-{}-{}-{}".format(segment.filename, segment.channel, int(segment.begin * 100), int(segment.end * 100),) + return "{}-{}-{}-{}".format( + segment.filename, + segment.channel, + int(segment.begin * 100), + int(segment.end * 100), + ) def convert_utterances(sph_path, wav_path): @@ -92,7 +113,12 @@ def process_transcripts(dataset_root): """ Reads in transcripts for each audio segment and processes them. """ - stm_path = os.path.join(dataset_root, "2000_hub5_eng_eval_tr", "reference", "hub5e00.english.000405.stm",) + stm_path = os.path.join( + dataset_root, + "2000_hub5_eng_eval_tr", + "reference", + "hub5e00.english.000405.stm", + ) results = [] chars = set() @@ -198,7 +224,10 @@ def segment_audio(info_list, dest_root, min_slice_duration): transcript_buffer += info.transcript channel = 0 if info.channel.lower() == 'a' else 1 audio_buffer.append( - audio_data[floor(info.begin * sample_rate) : ceil(info.end * sample_rate), channel,] + audio_data[ + floor(info.begin * sample_rate) : ceil(info.end * sample_rate), + channel, + ] ) buffer_duration += info.end - info.begin diff --git a/scripts/dataset_processing/speaker_tasks/get_aishell_diarization_data.py b/scripts/dataset_processing/speaker_tasks/get_aishell_diarization_data.py index c48268921f91..7528c28291ad 100644 --- a/scripts/dataset_processing/speaker_tasks/get_aishell_diarization_data.py +++ b/scripts/dataset_processing/speaker_tasks/get_aishell_diarization_data.py @@ -64,7 +64,9 @@ def __process_data(dataset_url: str, dataset_path: Path, manifest_output_path: P with open(rttm_list, 'w') as f: f.write('\n'.join(rttm_files)) create_manifest( - str(audio_list), manifest_output_path, rttm_path=str(rttm_list), + str(audio_list), + manifest_output_path, + rttm_path=str(rttm_list), ) diff --git a/scripts/dataset_processing/speaker_tasks/get_ami_data.py b/scripts/dataset_processing/speaker_tasks/get_ami_data.py index e3fc64db6f37..59ab274ea919 100644 --- a/scripts/dataset_processing/speaker_tasks/get_ami_data.py +++ b/scripts/dataset_processing/speaker_tasks/get_ami_data.py @@ -40,7 +40,10 @@ default='AMI_test_manifest.json', ) parser.add_argument( - "--dev_manifest_filepath", help="path to output dev manifest file", type=str, default='AMI_dev_manifest.json', + "--dev_manifest_filepath", + help="path to output dev manifest file", + type=str, + default='AMI_dev_manifest.json', ) parser.add_argument( "--train_manifest_filepath", diff --git a/scripts/dataset_processing/speaker_tasks/get_voxconverse.py b/scripts/dataset_processing/speaker_tasks/get_voxconverse.py index b2c563374c41..7968e220ca96 100644 --- a/scripts/dataset_processing/speaker_tasks/get_voxconverse.py +++ b/scripts/dataset_processing/speaker_tasks/get_voxconverse.py @@ -45,7 +45,9 @@ def _generate_manifest(data_root: Path, audio_path: Path, rttm_path: Path, manif with open(rttm_list, 'w') as f: f.write('\n'.join([str(os.path.join(rttm_path, x)) for x in os.listdir(rttm_path)])) create_manifest( - audio_list, str(manifest_output_path), rttm_path=rttm_list, + audio_list, + str(manifest_output_path), + rttm_path=rttm_list, ) diff --git a/scripts/dataset_processing/tts/aishell3/get_data.py b/scripts/dataset_processing/tts/aishell3/get_data.py index 1b3043bbf0d3..f3a89cc43a5b 100755 --- a/scripts/dataset_processing/tts/aishell3/get_data.py +++ b/scripts/dataset_processing/tts/aishell3/get_data.py @@ -81,7 +81,10 @@ def __process_transcript(file_path: str): cc = OpenCC('t2s') # Create normalizer text_normalizer = Normalizer( - lang="zh", input_case="cased", overwrite_cache=True, cache_dir=str(file_path / "cache_dir"), + lang="zh", + input_case="cased", + overwrite_cache=True, + cache_dir=str(file_path / "cache_dir"), ) text_normalizer_call_kwargs = {"punct_pre_process": True, "punct_post_process": True} normalizer_call = lambda x: text_normalizer.normalize(x, **text_normalizer_call_kwargs) @@ -160,7 +163,11 @@ def main(): __extract_file(str(tarred_data_path), str(args.data_root)) __process_data( - args.data_root, args.val_size, args.test_size, args.seed_for_ds_split, args.manifests_path, + args.data_root, + args.val_size, + args.test_size, + args.seed_for_ds_split, + args.manifests_path, ) diff --git a/scripts/dataset_processing/tts/compute_feature_stats.py b/scripts/dataset_processing/tts/compute_feature_stats.py index 3c1b4a2ecf71..c3d5ac516c3c 100644 --- a/scripts/dataset_processing/tts/compute_feature_stats.py +++ b/scripts/dataset_processing/tts/compute_feature_stats.py @@ -71,16 +71,28 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute TTS feature statistics.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Compute TTS feature statistics.", ) parser.add_argument( - "--feature_config_path", required=True, type=Path, help="Path to feature config file.", + "--feature_config_path", + required=True, + type=Path, + help="Path to feature config file.", ) parser.add_argument( - "--manifest_path", required=True, type=Path, action="append", help="Path(s) to training manifest.", + "--manifest_path", + required=True, + type=Path, + action="append", + help="Path(s) to training manifest.", ) parser.add_argument( - "--audio_dir", required=True, type=Path, action="append", help="Path(s) to base directory with audio data.", + "--audio_dir", + required=True, + type=Path, + action="append", + help="Path(s) to base directory with audio data.", ) parser.add_argument( "--feature_dir", @@ -90,7 +102,10 @@ def get_args(): help="Path(s) to directory where feature data was stored.", ) parser.add_argument( - "--feature_names", default="pitch,energy", type=str, help="Comma separated list of features to process.", + "--feature_names", + default="pitch,energy", + type=str, + help="Comma separated list of features to process.", ) parser.add_argument( "--mask_field", @@ -141,7 +156,7 @@ def main(): f"{len(feature_dirs)}" ) - for (manifest_path, audio_dir, feature_dir) in zip(manifest_paths, audio_dirs, feature_dirs): + for manifest_path, audio_dir, feature_dir in zip(manifest_paths, audio_dirs, feature_dirs): if not manifest_path.exists(): raise ValueError(f"Manifest {manifest_path} does not exist.") @@ -172,7 +187,7 @@ def main(): # for that speaker feature_stats = {name: defaultdict(list) for name in feature_names} - for (manifest_path, audio_dir, feature_dir) in zip(manifest_paths, audio_dirs, feature_dirs): + for manifest_path, audio_dir, feature_dir in zip(manifest_paths, audio_dirs, feature_dirs): entries = read_manifest(manifest_path) for entry in tqdm(entries): diff --git a/scripts/dataset_processing/tts/compute_features.py b/scripts/dataset_processing/tts/compute_features.py index 0f9953ef63fd..0f49110e914e 100644 --- a/scripts/dataset_processing/tts/compute_features.py +++ b/scripts/dataset_processing/tts/compute_features.py @@ -38,19 +38,32 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute TTS features.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Compute TTS features.", ) parser.add_argument( - "--feature_config_path", required=True, type=Path, help="Path to feature config file.", + "--feature_config_path", + required=True, + type=Path, + help="Path to feature config file.", ) parser.add_argument( - "--manifest_path", required=True, type=Path, help="Path to training manifest.", + "--manifest_path", + required=True, + type=Path, + help="Path to training manifest.", ) parser.add_argument( - "--audio_dir", required=True, type=Path, help="Path to base directory with audio data.", + "--audio_dir", + required=True, + type=Path, + help="Path to base directory with audio data.", ) parser.add_argument( - "--feature_dir", required=True, type=Path, help="Path to directory where feature data will be stored.", + "--feature_dir", + required=True, + type=Path, + help="Path to directory where feature data will be stored.", ) parser.add_argument( "--dedupe_files", @@ -58,7 +71,9 @@ def get_args(): help="If given, will only process the first manifest entry found for each audio file.", ) parser.add_argument( - "--overwrite", action=argparse.BooleanOptionalAction, help="Whether to overwrite existing feature files.", + "--overwrite", + action=argparse.BooleanOptionalAction, + help="Whether to overwrite existing feature files.", ) parser.add_argument( "--num_workers", default=1, type=int, help="Number of parallel threads to use. If -1 all CPUs are used." diff --git a/scripts/dataset_processing/tts/compute_speaker_stats.py b/scripts/dataset_processing/tts/compute_speaker_stats.py index 5061edb216c9..dd4f9664bad1 100644 --- a/scripts/dataset_processing/tts/compute_speaker_stats.py +++ b/scripts/dataset_processing/tts/compute_speaker_stats.py @@ -42,13 +42,20 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute speaker level pitch statistics.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Compute speaker level pitch statistics.", ) parser.add_argument( - "--manifest_path", required=True, type=Path, help="Path to training manifest.", + "--manifest_path", + required=True, + type=Path, + help="Path to training manifest.", ) parser.add_argument( - "--sup_data_path", default=Path("sup_data"), type=Path, help="Path to base directory with supplementary data.", + "--sup_data_path", + default=Path("sup_data"), + type=Path, + help="Path to base directory with supplementary data.", ) parser.add_argument( "--pitch_stats_path", diff --git a/scripts/dataset_processing/tts/create_speaker_map.py b/scripts/dataset_processing/tts/create_speaker_map.py index ab8dd7b0828b..289415d9bef2 100644 --- a/scripts/dataset_processing/tts/create_speaker_map.py +++ b/scripts/dataset_processing/tts/create_speaker_map.py @@ -48,10 +48,17 @@ def get_args(): description="Create mapping from speaker names to numerical speaker indices.", ) parser.add_argument( - "--manifest_path", required=True, type=Path, action="append", help="Path to training manifest(s).", + "--manifest_path", + required=True, + type=Path, + action="append", + help="Path to training manifest(s).", ) parser.add_argument( - "--speaker_map_path", required=True, type=Path, help="Path for output speaker index JSON", + "--speaker_map_path", + required=True, + type=Path, + help="Path for output speaker index JSON", ) parser.add_argument( "--overwrite", diff --git a/scripts/dataset_processing/tts/generate_mels.py b/scripts/dataset_processing/tts/generate_mels.py index 24e09a7c50e4..9307df5ddcf0 100644 --- a/scripts/dataset_processing/tts/generate_mels.py +++ b/scripts/dataset_processing/tts/generate_mels.py @@ -118,7 +118,12 @@ def __generate_mels(entry, spec_model, device, use_beta_binomial_interpolator, m ) spectrogram = spec_model.forward( - text=text, input_lens=text_len, spec=spect, mel_lens=spect_len, attn_prior=attn_prior, speaker=speaker, + text=text, + input_lens=text_len, + spec=spect, + mel_lens=spect_len, + attn_prior=attn_prior, + speaker=speaker, )[0] save_path = mel_root / f"{Path(entry['audio_filepath']).stem}.npy" diff --git a/scripts/dataset_processing/tts/hui_acg/get_data.py b/scripts/dataset_processing/tts/hui_acg/get_data.py index 668d532f321a..6f7da30b4f97 100644 --- a/scripts/dataset_processing/tts/hui_acg/get_data.py +++ b/scripts/dataset_processing/tts/hui_acg/get_data.py @@ -127,7 +127,14 @@ def __save_json(json_file, dict_list): def __process_data( - dataset_path, stat_path_root, speaker_id, min_duration, max_duration, val_size, test_size, seed_for_ds_split, + dataset_path, + stat_path_root, + speaker_id, + min_duration, + max_duration, + val_size, + test_size, + seed_for_ds_split, ): logging.info(f"Preparing JSON split for speaker {speaker_id}.") # parse statistic.txt @@ -190,7 +197,10 @@ def __text_normalization(json_file, num_workers=-1): "punct_post_process": True, } text_normalizer = Normalizer( - lang="de", input_case="cased", overwrite_cache=True, cache_dir=str(json_file.parent / "cache_dir"), + lang="de", + input_case="cased", + overwrite_cache=True, + cache_dir=str(json_file.parent / "cache_dir"), ) def normalizer_call(x): diff --git a/scripts/dataset_processing/tts/ljspeech/get_data.py b/scripts/dataset_processing/tts/ljspeech/get_data.py index 8007b5a0f05a..912d20ce0274 100644 --- a/scripts/dataset_processing/tts/ljspeech/get_data.py +++ b/scripts/dataset_processing/tts/ljspeech/get_data.py @@ -62,7 +62,10 @@ def __extract_file(filepath, data_dir): def __process_data(data_root): text_normalizer = Normalizer( - lang="en", input_case="cased", overwrite_cache=True, cache_dir=data_root / "cache_dir", + lang="en", + input_case="cased", + overwrite_cache=True, + cache_dir=data_root / "cache_dir", ) text_normalizer_call_kwargs = {"punct_pre_process": True, "punct_post_process": True} normalizer_call = lambda x: text_normalizer.normalize(x, **text_normalizer_call_kwargs) diff --git a/scripts/dataset_processing/tts/preprocess_audio.py b/scripts/dataset_processing/tts/preprocess_audio.py index fd027269160c..d6097dfcff71 100644 --- a/scripts/dataset_processing/tts/preprocess_audio.py +++ b/scripts/dataset_processing/tts/preprocess_audio.py @@ -59,19 +59,32 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute speaker level pitch statistics.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Compute speaker level pitch statistics.", ) parser.add_argument( - "--input_manifest", required=True, type=Path, help="Path to input training manifest.", + "--input_manifest", + required=True, + type=Path, + help="Path to input training manifest.", ) parser.add_argument( - "--input_audio_dir", required=True, type=Path, help="Path to base directory with audio files.", + "--input_audio_dir", + required=True, + type=Path, + help="Path to base directory with audio files.", ) parser.add_argument( - "--output_manifest", required=True, type=Path, help="Path to output training manifest with processed audio.", + "--output_manifest", + required=True, + type=Path, + help="Path to output training manifest with processed audio.", ) parser.add_argument( - "--output_audio_dir", required=True, type=Path, help="Path to output directory for audio files.", + "--output_audio_dir", + required=True, + type=Path, + help="Path to output directory for audio files.", ) parser.add_argument( "--overwrite_audio", diff --git a/scripts/dataset_processing/tts/preprocess_text.py b/scripts/dataset_processing/tts/preprocess_text.py index 6afab42a1d6b..e4324bc7f92d 100644 --- a/scripts/dataset_processing/tts/preprocess_text.py +++ b/scripts/dataset_processing/tts/preprocess_text.py @@ -49,13 +49,20 @@ def get_args(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Process and normalize text data.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Process and normalize text data.", ) parser.add_argument( - "--input_manifest", required=True, type=Path, help="Path to input training manifest.", + "--input_manifest", + required=True, + type=Path, + help="Path to input training manifest.", ) parser.add_argument( - "--output_manifest", required=True, type=Path, help="Path to output training manifest with processed text.", + "--output_manifest", + required=True, + type=Path, + help="Path to output training manifest with processed text.", ) parser.add_argument( "--overwrite", @@ -63,13 +70,21 @@ def get_args(): help="Whether to overwrite the output manifest file if it exists.", ) parser.add_argument( - "--text_key", default="text", type=str, help="Input text field to normalize.", + "--text_key", + default="text", + type=str, + help="Input text field to normalize.", ) parser.add_argument( - "--normalized_text_key", default="normalized_text", type=str, help="Output field to save normalized text to.", + "--normalized_text_key", + default="normalized_text", + type=str, + help="Output field to save normalized text to.", ) parser.add_argument( - "--lower_case", action=argparse.BooleanOptionalAction, help="Whether to convert the final text to lower case.", + "--lower_case", + action=argparse.BooleanOptionalAction, + help="Whether to convert the final text to lower case.", ) parser.add_argument( "--normalizer_config_path", diff --git a/scripts/dataset_processing/tts/resynthesize_dataset.py b/scripts/dataset_processing/tts/resynthesize_dataset.py index 652fde299572..f6002ee8b1c1 100644 --- a/scripts/dataset_processing/tts/resynthesize_dataset.py +++ b/scripts/dataset_processing/tts/resynthesize_dataset.py @@ -218,13 +218,22 @@ def argument_parser() -> argparse.ArgumentParser: description="Resynthesize TTS dataset using a pretrained text-to-spectrogram model", ) parser.add_argument( - "--model-path", required=True, type=Path, help="Path to a checkpoint (either .nemo or .ckpt)", + "--model-path", + required=True, + type=Path, + help="Path to a checkpoint (either .nemo or .ckpt)", ) parser.add_argument( - "--input-json-manifest", required=True, type=Path, help="Path to the input JSON manifest", + "--input-json-manifest", + required=True, + type=Path, + help="Path to the input JSON manifest", ) parser.add_argument( - "--input-sup-data-path", required=True, type=Path, help="sup_data_path for the JSON manifest", + "--input-sup-data-path", + required=True, + type=Path, + help="sup_data_path for the JSON manifest", ) parser.add_argument( "--output-folder", diff --git a/scripts/dataset_processing/tts/sfbilingual/get_data.py b/scripts/dataset_processing/tts/sfbilingual/get_data.py index 806f9882a9f4..7db28d36dd40 100755 --- a/scripts/dataset_processing/tts/sfbilingual/get_data.py +++ b/scripts/dataset_processing/tts/sfbilingual/get_data.py @@ -64,7 +64,10 @@ def __process_transcript(file_path: str): cc = OpenCC('t2s') # Create normalizer text_normalizer = Normalizer( - lang="zh", input_case="cased", overwrite_cache=True, cache_dir=str(file_path / "cache_dir"), + lang="zh", + input_case="cased", + overwrite_cache=True, + cache_dir=str(file_path / "cache_dir"), ) text_normalizer_call_kwargs = {"punct_pre_process": True, "punct_post_process": True} normalizer_call = lambda x: text_normalizer.normalize(x, **text_normalizer_call_kwargs) @@ -122,7 +125,11 @@ def main(): dataset_root = args.data_root dataset_root.mkdir(parents=True, exist_ok=True) __process_data( - dataset_root, args.val_size, args.test_size, args.seed_for_ds_split, args.manifests_path, + dataset_root, + args.val_size, + args.test_size, + args.seed_for_ds_split, + args.manifests_path, ) diff --git a/scripts/dataset_processing/tts/thorsten_neutral/get_data.py b/scripts/dataset_processing/tts/thorsten_neutral/get_data.py index d49d362064fd..51104123d7c8 100644 --- a/scripts/dataset_processing/tts/thorsten_neutral/get_data.py +++ b/scripts/dataset_processing/tts/thorsten_neutral/get_data.py @@ -127,7 +127,10 @@ def __text_normalization(json_file, num_workers=-1): "punct_post_process": True, } text_normalizer = Normalizer( - lang="de", input_case="cased", overwrite_cache=True, cache_dir=str(json_file.parent / "cache_dir"), + lang="de", + input_case="cased", + overwrite_cache=True, + cache_dir=str(json_file.parent / "cache_dir"), ) def normalizer_call(x): diff --git a/scripts/installers/setup_os2s_decoders.py b/scripts/installers/setup_os2s_decoders.py index 9728829aaa44..c14f5018e3e2 100644 --- a/scripts/installers/setup_os2s_decoders.py +++ b/scripts/installers/setup_os2s_decoders.py @@ -124,7 +124,12 @@ def compile_test(header, library): name='_swig_decoders', sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), language='c++', - include_dirs=['.', 'kenlm', 'openfst-1.6.3/src/include', 'ThreadPool',], + include_dirs=[ + '.', + 'kenlm', + 'openfst-1.6.3/src/include', + 'ThreadPool', + ], libraries=LIBS, extra_compile_args=ARGS, ) diff --git a/scripts/speaker_tasks/create_synth_vad_manifest.py b/scripts/speaker_tasks/create_synth_vad_manifest.py index 87828f2a001f..b7191a1c6a03 100644 --- a/scripts/speaker_tasks/create_synth_vad_manifest.py +++ b/scripts/speaker_tasks/create_synth_vad_manifest.py @@ -46,14 +46,14 @@ def generate_manifest_entry(inputs): """ - Generates a manifest entry for a single audio file. + Generates a manifest entry for a single audio file. This function is parallelized using multiprocessing.Pool. Args: inputs (tuple): Tuple containing audio file path and frame length in seconds. - inputs[0]: + inputs[0]: audio_filepath (str): Path to audio file. - inputs[1]: + inputs[1]: vad_frame_unit_secs (float): Duration in seconds for each frame label. Returns: diff --git a/scripts/speaker_tasks/eval_diar_with_asr.py b/scripts/speaker_tasks/eval_diar_with_asr.py index 9fc651e953cd..9e456afc54a9 100644 --- a/scripts/speaker_tasks/eval_diar_with_asr.py +++ b/scripts/speaker_tasks/eval_diar_with_asr.py @@ -80,8 +80,7 @@ def get_pyannote_objs_from_rttms(rttm_file_path_list): - """Generate PyAnnote objects from RTTM file list - """ + """Generate PyAnnote objects from RTTM file list""" pyannote_obj_list = [] for rttm_file in rttm_file_path_list: rttm_file = rttm_file.strip() @@ -94,8 +93,7 @@ def get_pyannote_objs_from_rttms(rttm_file_path_list): def make_meta_dict(hyp_rttm_list, ref_rttm_list): - """Create a temporary `audio_rttm_map_dict` for evaluation - """ + """Create a temporary `audio_rttm_map_dict` for evaluation""" meta_dict = {} for k, rttm_file in enumerate(ref_rttm_list): uniq_id = get_uniqname_from_filepath(rttm_file) @@ -107,8 +105,7 @@ def make_meta_dict(hyp_rttm_list, ref_rttm_list): def make_trans_info_dict(hyp_json_list_path): - """Create `trans_info_dict` from the `.json` files - """ + """Create `trans_info_dict` from the `.json` files""" trans_info_dict = {} for json_file in hyp_json_list_path: json_file = json_file.strip() @@ -120,8 +117,7 @@ def make_trans_info_dict(hyp_json_list_path): def read_file_path(list_path): - """Read file path and strip to remove line change symbol - """ + """Read file path and strip to remove line change symbol""" return sorted([x.strip() for x in read_file(list_path)]) diff --git a/scripts/speaker_tasks/filelist_to_manifest.py b/scripts/speaker_tasks/filelist_to_manifest.py index 063b2b1f6596..8aee010f63c3 100644 --- a/scripts/speaker_tasks/filelist_to_manifest.py +++ b/scripts/speaker_tasks/filelist_to_manifest.py @@ -237,5 +237,11 @@ def main(filelist, manifest, id, out, split=False, create_segments=False, min_co args = parser.parse_args() main( - args.filelist, args.manifest, args.id, args.out, args.split, args.create_segments, args.min_spkrs_count, + args.filelist, + args.manifest, + args.id, + args.out, + args.split, + args.create_segments, + args.min_spkrs_count, ) diff --git a/scripts/speaker_tasks/multispeaker_data_analysis.py b/scripts/speaker_tasks/multispeaker_data_analysis.py index bc33426ea574..8245066342d9 100644 --- a/scripts/speaker_tasks/multispeaker_data_analysis.py +++ b/scripts/speaker_tasks/multispeaker_data_analysis.py @@ -151,14 +151,25 @@ def run_multispeaker_data_analysis( queue = [] for rttm_file in tqdm(rttm_list): queue.append( - {"rttm_file": rttm_file, "session_dur": session_dur, "precise": precise,} + { + "rttm_file": rttm_file, + "session_dur": session_dur, + "precise": precise, + } ) if num_workers <= 1: results = [process_sample(sess_dict) for sess_dict in tqdm(queue)] else: with multiprocessing.Pool(processes=num_workers) as p: - results = list(tqdm(p.imap(process_sample, queue), total=len(queue), desc='Processing', leave=True,)) + results = list( + tqdm( + p.imap(process_sample, queue), + total=len(queue), + desc='Processing', + leave=True, + ) + ) for item in results: total_duration += item["session_dur"] diff --git a/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py b/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py index 6857120c0023..e15514d26510 100644 --- a/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py +++ b/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py @@ -52,7 +52,9 @@ def main( parser.add_argument("--paths2ctm_files", help="path to ctm files", type=str) parser.add_argument("--manifest_filepath", help="path to output manifest file", type=str, required=True) parser.add_argument( - "--add_duration", help="add duration of audio files to output manifest files.", action='store_true', + "--add_duration", + help="add duration of audio files to output manifest files.", + action='store_true', ) args = parser.parse_args() diff --git a/scripts/speech_recognition/code_switching/code_switching_audio_data_creation.py b/scripts/speech_recognition/code_switching/code_switching_audio_data_creation.py index c53b3eeaac36..1a3e70a33588 100644 --- a/scripts/speech_recognition/code_switching/code_switching_audio_data_creation.py +++ b/scripts/speech_recognition/code_switching/code_switching_audio_data_creation.py @@ -124,7 +124,6 @@ def create_cs_data( cs_data_sampling_rate: int, is_lid_manifest: bool, ): - """ Args: intermediate_cs_manifest_list: the intermediate cs manifest obtained from code_switching_manifest_creation.py as a list diff --git a/scripts/ssl_tts/make_supdata.py b/scripts/ssl_tts/make_supdata.py index 9f4c391b2f6e..5bbf8bad65af 100644 --- a/scripts/ssl_tts/make_supdata.py +++ b/scripts/ssl_tts/make_supdata.py @@ -69,7 +69,10 @@ def __len__(self): def _get_wav_from_filepath(self, audio_filepath): features = AudioSegment.segment_from_file( - audio_filepath, target_sr=self.sample_rate, n_segments=-1, trim=False, + audio_filepath, + target_sr=self.sample_rate, + n_segments=-1, + trim=False, ) audio_samples = features.samples audio, audio_length = torch.tensor(audio_samples), torch.tensor(audio_samples.shape[0]).long() @@ -181,7 +184,12 @@ def get_mel_spectrogram(fb, wav, stft_params): def load_wav(wav_path, sample_rate=22050, pad_multiple=1024): - wav = AudioSegment.segment_from_file(wav_path, target_sr=sample_rate, n_segments=-1, trim=False,).samples + wav = AudioSegment.segment_from_file( + wav_path, + target_sr=sample_rate, + n_segments=-1, + trim=False, + ).samples if wav.shape[0] % pad_multiple != 0: wav = np.concatenate([wav, np.zeros(pad_multiple - wav.shape[0] % pad_multiple)]) @@ -269,7 +277,9 @@ def _is_valid_pitch(pitch_mean, pitch_std): def main(): parser = argparse.ArgumentParser(description='Evaluate the model') parser.add_argument( - '--ssl_model_ckpt_path', type=str, required=True, + '--ssl_model_ckpt_path', + type=str, + required=True, ) parser.add_argument('--manifest_paths', type=str, required=True) parser.add_argument('--sup_data_dir', type=str, default=None) @@ -374,7 +384,9 @@ def main(): audio_seg_len = torch.tensor([len(segment) for segment in audio_segmented]).to(device).long() _, batch_speaker_embeddings, _, _, _ = ssl_model.forward_for_export( - input_signal=audio_segmented.to(device), input_signal_length=audio_seg_len, normalize_content=True, + input_signal=audio_segmented.to(device), + input_signal_length=audio_seg_len, + normalize_content=True, ) for idx in range(batch['audio'].shape[0]): diff --git a/scripts/ssl_tts/ssl_tts_vc.py b/scripts/ssl_tts/ssl_tts_vc.py index 9d9e1205da5a..66191de1122a 100644 --- a/scripts/ssl_tts/ssl_tts_vc.py +++ b/scripts/ssl_tts/ssl_tts_vc.py @@ -114,7 +114,7 @@ def get_ssl_features_disentangled( ssl_model, wav_featurizer, audio_path, emb_type="embedding_and_probs", use_unique_tokens=False, device="cpu" ): """ - Extracts content embedding, speaker embedding and duration tokens to be used as inputs for FastPitchModel_SSL + Extracts content embedding, speaker embedding and duration tokens to be used as inputs for FastPitchModel_SSL synthesizer. Content embedding and speaker embedding extracted using SSLDisentangler model. Args: ssl_model: SSLDisentangler model diff --git a/scripts/tokenizers/add_special_tokens_to_sentencepiece.py b/scripts/tokenizers/add_special_tokens_to_sentencepiece.py index 20e1cbc33e77..122c0bac908f 100644 --- a/scripts/tokenizers/add_special_tokens_to_sentencepiece.py +++ b/scripts/tokenizers/add_special_tokens_to_sentencepiece.py @@ -48,16 +48,28 @@ def edit_spt_model(): parser = ArgumentParser() parser.add_argument( - "--input_file", type=str, required=True, help="Path to sentencepiece model file", + "--input_file", + type=str, + required=True, + help="Path to sentencepiece model file", ) parser.add_argument( - "--output_file", type=str, required=True, help="Path to sentencepiece model file", + "--output_file", + type=str, + required=True, + help="Path to sentencepiece model file", ) parser.add_argument( - "--tokens", type=str, nargs='+', required=True, help="Special tokens to add to tokenizer", + "--tokens", + type=str, + nargs='+', + required=True, + help="Special tokens to add to tokenizer", ) parser.add_argument( - "--is_userdefined", action="store_true", help="When set, the new tokens are set as user_defined tokens", + "--is_userdefined", + action="store_true", + help="When set, the new tokens are set as user_defined tokens", ) args = parser.parse_args() diff --git a/scripts/voice_activity_detection/vad_tune_threshold.py b/scripts/voice_activity_detection/vad_tune_threshold.py index 2f41677fe3ae..6ffbbd7c92b7 100644 --- a/scripts/voice_activity_detection/vad_tune_threshold.py +++ b/scripts/voice_activity_detection/vad_tune_threshold.py @@ -68,7 +68,9 @@ required=True, ) parser.add_argument( - "--result_file", help="Filename of txt to store results", default="res", + "--result_file", + help="Filename of txt to store results", + default="res", ) parser.add_argument( "--vad_pred_method", @@ -82,7 +84,10 @@ default='DetER', ) parser.add_argument( - "--frame_length_in_sec", help="frame_length_in_sec ", type=float, default=0.01, + "--frame_length_in_sec", + help="frame_length_in_sec ", + type=float, + default=0.01, ) args = parser.parse_args() diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py index cb5a9816e237..1cb354f3020c 100644 --- a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py +++ b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py @@ -98,7 +98,18 @@ def test_compute_alphas_kernel(self, dtype): # alpha kernel gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( - x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + alphas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # sync kernel @@ -109,12 +120,12 @@ def test_compute_alphas_kernel(self, dtype): diff = ground_alphas - alphas[0].cpu().numpy() assert np.abs(diff).mean() <= threshold - assert np.square(diff).mean() <= (threshold ** 2) + assert np.square(diff).mean() <= (threshold**2) ll_diff = ground_log_likelihood - llForward[0].cpu().numpy() assert np.abs(ll_diff).mean() <= threshold - assert np.square(ll_diff).mean() <= (threshold ** 2) + assert np.square(ll_diff).mean() <= (threshold**2) @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") @pytest.mark.unit @@ -168,7 +179,18 @@ def test_compute_betas_kernel(self, dtype): # beta kernel gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( - x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + betas, + llBackward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # sync kernel @@ -179,12 +201,12 @@ def test_compute_betas_kernel(self, dtype): diff = ground_alphas - betas[0].cpu().numpy() assert np.abs(diff).mean() <= threshold - assert np.square(diff).mean() <= (threshold ** 2) + assert np.square(diff).mean() <= (threshold**2) ll_diff = ground_log_likelihood - llBackward[0].cpu().numpy() assert np.abs(ll_diff).mean() <= threshold - assert np.square(ll_diff).mean() <= (threshold ** 2) + assert np.square(ll_diff).mean() <= (threshold**2) @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") @pytest.mark.unit @@ -258,12 +280,34 @@ def test_compute_grads_kernel(self, dtype): # alpha kernel gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( - x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + alphas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # beta kernel gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( - x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + betas, + llBackward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # gamma kernel @@ -296,7 +340,7 @@ def test_compute_grads_kernel(self, dtype): diff = true_grads - grads[0].cpu().numpy() assert np.abs(diff).mean() <= threshold - assert np.square(diff).mean() <= (threshold ** 2) * 5.0 + assert np.square(diff).mean() <= (threshold**2) * 5.0 @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") @pytest.mark.unit @@ -370,12 +414,34 @@ def test_compute_grads_kernel_fastemit(self, dtype): # alpha kernel gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( - x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + alphas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # beta kernel gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( - x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + betas, + llBackward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # gamma kernel @@ -408,7 +474,7 @@ def test_compute_grads_kernel_fastemit(self, dtype): diff = true_grads - grads[0].cpu().numpy() assert np.abs(diff).mean() <= threshold - assert np.square(diff).mean() <= (threshold ** 2) * 5 + assert np.square(diff).mean() <= (threshold**2) * 5 @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") @pytest.mark.unit @@ -482,12 +548,34 @@ def test_compute_grads_kernel_clamp(self, dtype): # alpha kernel gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( - x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + alphas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # beta kernel gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( - x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + x_c, + denom, + betas, + llBackward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, ) # gamma kernel @@ -520,7 +608,7 @@ def test_compute_grads_kernel_clamp(self, dtype): diff = true_grads - grads[0].cpu().numpy() assert np.abs(diff).mean() <= threshold - assert np.square(diff).mean() <= (threshold ** 2) * 5 + assert np.square(diff).mean() <= (threshold**2) * 5 class TestTDTCUDAKernels: diff --git a/tests/collections/common/loss_inputs.py b/tests/collections/common/loss_inputs.py index cc3c384efe4f..435d1fc99ee3 100644 --- a/tests/collections/common/loss_inputs.py +++ b/tests/collections/common/loss_inputs.py @@ -51,7 +51,8 @@ class LossInput: NO_ZERO_NUM_MEASUREMENTS = LossInput( - loss_sum_or_avg=torch.rand(NUM_BATCHES) * 2.0 - 1.0, num_measurements=torch.randint(1, 100, (NUM_BATCHES,)), + loss_sum_or_avg=torch.rand(NUM_BATCHES) * 2.0 - 1.0, + num_measurements=torch.randint(1, 100, (NUM_BATCHES,)), ) SOME_NUM_MEASUREMENTS_ARE_ZERO = LossInput( @@ -65,5 +66,6 @@ class LossInput: ) ALL_NUM_MEASUREMENTS_ARE_ZERO = LossInput( - loss_sum_or_avg=torch.rand(NUM_BATCHES) * 2.0 - 1.0, num_measurements=torch.zeros(NUM_BATCHES, dtype=torch.int32), + loss_sum_or_avg=torch.rand(NUM_BATCHES) * 2.0 - 1.0, + num_measurements=torch.zeros(NUM_BATCHES, dtype=torch.int32), ) diff --git a/tests/collections/common/mixins/test_adapter_common_model_mixin.py b/tests/collections/common/mixins/test_adapter_common_model_mixin.py index 22cd3fdb31bd..465702d5183c 100644 --- a/tests/collections/common/mixins/test_adapter_common_model_mixin.py +++ b/tests/collections/common/mixins/test_adapter_common_model_mixin.py @@ -37,7 +37,7 @@ class MockLinearAdapter2(LinearAdapter): class CommonModule(NeuralModule): - """ Define a default neural module (without adapter support)""" + """Define a default neural module (without adapter support)""" def __init__(self): super().__init__() @@ -60,7 +60,7 @@ def num_params(self): class CommonModuleAdapter(CommonModule, AdapterModuleMixin): - """ Subclass the DefaultModule, adding adapter module support""" + """Subclass the DefaultModule, adding adapter module support""" def forward(self, x): x = super().forward(x) @@ -73,7 +73,9 @@ def forward(self, x): return x - def get_accepted_adapter_types(self,) -> 'Set[type]': + def get_accepted_adapter_types( + self, + ) -> 'Set[type]': types = super().get_accepted_adapter_types() if len(types) == 0: diff --git a/tests/core/test_fileio.py b/tests/core/test_fileio.py index f737b23216d8..3f5f3c193920 100644 --- a/tests/core/test_fileio.py +++ b/tests/core/test_fileio.py @@ -107,8 +107,8 @@ def asr_model(): class TestFileIO: @pytest.mark.unit def test_to_from_config_file(self, asr_model): - """" Test makes sure that the second instance created with the same configuration (BUT NOT checkpoint) - has different weights. """ + """ " Test makes sure that the second instance created with the same configuration (BUT NOT checkpoint) + has different weights.""" with tempfile.NamedTemporaryFile() as fp: yaml_filename = fp.name @@ -127,8 +127,8 @@ def test_to_from_config_file(self, asr_model): @pytest.mark.unit def test_save_restore_from_nemo_file(self, asr_model): - """" Test makes sure that the second instance created from the same configuration AND checkpoint - has the same weights. """ + """ " Test makes sure that the second instance created from the same configuration AND checkpoint + has the same weights.""" with tempfile.NamedTemporaryFile() as fp: filename = fp.name @@ -152,7 +152,7 @@ def test_save_restore_from_nemo_file(self, asr_model): @requires_eff @pytest.mark.unit def test_eff_save_restore_from_nemo_file_encrypted(self, asr_model): - """" Test makes sure that after encrypted save-restore the model has the same weights. """ + """ " Test makes sure that after encrypted save-restore the model has the same weights.""" with tempfile.NamedTemporaryFile() as fp: filename = fp.name @@ -181,7 +181,7 @@ def test_eff_save_restore_from_nemo_file_encrypted(self, asr_model): @pytest.mark.unit def test_save_restore_from_nemo_file_with_override(self, asr_model, tmpdir): - """" Test makes sure that the second instance created from the same configuration AND checkpoint + """ " Test makes sure that the second instance created from the same configuration AND checkpoint has the same weights. Args: diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 42b1e22b711f..2bf8e8fded6a 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -42,7 +42,11 @@ def test_from_config_dict_with_cls(self): config = DictConfig( { 'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', - 'params': {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120,}, + 'params': { + 'rect_freq': 50, + 'rect_masks': 5, + 'rect_time': 120, + }, } ) obj = Serialization.from_config_dict(config=config) @@ -124,7 +128,11 @@ def test_config_updated(self): config = DictConfig( { 'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', - 'params': {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120,}, + 'params': { + 'rect_freq': 50, + 'rect_masks': 5, + 'rect_time': 120, + }, } ) obj = Serialization.from_config_dict(config=config) diff --git a/tests/core/test_torch_jit_script.py b/tests/core/test_torch_jit_script.py index a5c2be36a6ab..6ae7ce79ed37 100644 --- a/tests/core/test_torch_jit_script.py +++ b/tests/core/test_torch_jit_script.py @@ -154,7 +154,8 @@ def test_simple_linear_with_types(self): @pytest.mark.unit @pytest.mark.parametrize( - "neural_type", get_all_neural_types(), + "neural_type", + get_all_neural_types(), ) def test_element_compilable(self, neural_type: Type[nelements.ElementType]): """ diff --git a/tests/hydra/test_hydra_runner.py b/tests/hydra/test_hydra_runner.py index 1da0a914cfaf..11c3f3dc6601 100644 --- a/tests/hydra/test_hydra_runner.py +++ b/tests/hydra/test_hydra_runner.py @@ -22,8 +22,7 @@ class TestHydraRunner: @pytest.mark.integration def test_no_config(self): - """"Test app without config - fields missing causes error. - """ + """ "Test app without config - fields missing causes error.""" # Create system call. call = "python tests/hydra/my_app.py" @@ -33,8 +32,7 @@ def test_no_config(self): @pytest.mark.integration def test_config1(self): - """"Test injection of valid config1. - """ + """ "Test injection of valid config1.""" # Create system call. call = "python tests/hydra/my_app.py --config-name config1.yaml" @@ -48,8 +46,7 @@ def test_config1(self): @pytest.mark.integration def test_config1_invalid(self): - """"Test injection of invalid config1. - """ + """ "Test injection of invalid config1.""" # Create system call. call = "python tests/hydra/my_app.py --config-name config1_invalid.yaml" @@ -59,8 +56,7 @@ def test_config1_invalid(self): @pytest.mark.integration def test_config2(self): - """"Test injection of valid config2 from a different folder. - """ + """ "Test injection of valid config2 from a different folder.""" # Create system call. call = "python tests/hydra/my_app.py --config-path config_subdir --config-name config2.yaml" @@ -74,8 +70,7 @@ def test_config2(self): @pytest.mark.integration def test_config2_invalid(self): - """"Test injection of invalid config2 from a different folder. - """ + """ "Test injection of invalid config2 from a different folder.""" # Create system call. call = "python tests/hydra/my_app.py --config-path config_subdir --config-name config2_invalid.yaml" @@ -85,8 +80,7 @@ def test_config2_invalid(self): @pytest.mark.integration def test_config2_filepath_schema(self): - """"Test injection of valid config2 - using namepath with schema is prohibited. - """ + """ "Test injection of valid config2 - using namepath with schema is prohibited.""" # Create system call. call = "python tests/hydra/my_app.py --config-name config_subdir/config2_invalid.yaml" diff --git a/tools/ctc_segmentation/scripts/cut_audio.py b/tools/ctc_segmentation/scripts/cut_audio.py index 3dd934b5a6c2..cc62e2c55fc0 100644 --- a/tools/ctc_segmentation/scripts/cut_audio.py +++ b/tools/ctc_segmentation/scripts/cut_audio.py @@ -48,7 +48,7 @@ def process_alignment(alignment_file: str, manifest: str, clips_dir: str, args): - """ Cut original audio file into audio segments based on alignment_file + """Cut original audio file into audio segments based on alignment_file Args: alignment_file: path to the file with segmented text and corresponding time stamps. diff --git a/tools/ctc_segmentation/scripts/get_metrics_and_filter.py b/tools/ctc_segmentation/scripts/get_metrics_and_filter.py index 60b658e01af1..24c84d8b52e0 100644 --- a/tools/ctc_segmentation/scripts/get_metrics_and_filter.py +++ b/tools/ctc_segmentation/scripts/get_metrics_and_filter.py @@ -26,7 +26,9 @@ parser = argparse.ArgumentParser("Calculate metrics and filters out samples based on thresholds") parser.add_argument( - "--manifest", required=True, help="Path .json manifest file with ASR predictions saved at `pred_text` field.", + "--manifest", + required=True, + help="Path .json manifest file with ASR predictions saved at `pred_text` field.", ) parser.add_argument( "--edge_len", type=int, help="Number of characters to use for CER calculation at the edges", default=5 @@ -119,7 +121,7 @@ def _apply_filters( min_dur=1, original_duration=0, ): - """ Filters out samples that do not satisfy specified threshold values and saves remaining samples to manifest_out""" + """Filters out samples that do not satisfy specified threshold values and saves remaining samples to manifest_out""" remaining_duration = 0 segmented_duration = 0 with open(manifest, "r") as f, open(manifest_out, "w") as f_out: diff --git a/tools/ctc_segmentation/scripts/prepare_data.py b/tools/ctc_segmentation/scripts/prepare_data.py index 476e719eb51b..d54d31ee1779 100644 --- a/tools/ctc_segmentation/scripts/prepare_data.py +++ b/tools/ctc_segmentation/scripts/prepare_data.py @@ -53,7 +53,10 @@ help='Add target language based on the num2words list of supported languages', ) parser.add_argument( - "--cut_prefix", type=int, default=0, help="Number of seconds to cut from the beginning of the audio files.", + "--cut_prefix", + type=int, + default=0, + help="Number of seconds to cut from the beginning of the audio files.", ) parser.add_argument( "--model", type=str, default="QuartzNet15x5Base-En", help="Pre-trained model name or path to model checkpoint" @@ -75,7 +78,10 @@ help="Set to True to use NeMo Normalization tool to convert numbers from written to spoken format.", ) parser.add_argument( - "--batch_size", type=int, default=100, help="Batch size for NeMo Normalization tool.", + "--batch_size", + type=int, + default=100, + help="Batch size for NeMo Normalization tool.", ) diff --git a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py index c9d9ed2d8731..e32474b7217a 100644 --- a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py +++ b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py @@ -43,7 +43,10 @@ parser.add_argument("--window_len", type=int, default=8000, help="Window size for ctc segmentation algorithm") parser.add_argument("--sample_rate", type=int, default=16000, help="Sampling rate, Hz") parser.add_argument( - "--model", type=str, default="QuartzNet15x5Base-En", help="Path to model checkpoint or pre-trained model name", + "--model", + type=str, + default="QuartzNet15x5Base-En", + help="Path to model checkpoint or pre-trained model name", ) parser.add_argument("--debug", action="store_true", help="Flag to enable debugging messages") parser.add_argument( diff --git a/tools/ctc_segmentation/scripts/utils.py b/tools/ctc_segmentation/scripts/utils.py index f6d4c0c4217b..dc9e3fd5f987 100644 --- a/tools/ctc_segmentation/scripts/utils.py +++ b/tools/ctc_segmentation/scripts/utils.py @@ -133,7 +133,7 @@ def get_segments( def _prepare_tokenized_text_for_bpe_model(text: List[str], tokenizer, vocabulary: List[str], blank_idx: int = 0): - """ Creates a transition matrix for BPE-based models""" + """Creates a transition matrix for BPE-based models""" space_idx = vocabulary.index("▁") ground_truth_mat = [[-1, -1]] utt_begin_indices = [] @@ -301,7 +301,9 @@ def write_output( def write_labels_for_audacity( - out_path: str, segments: List[Tuple[float]], text_no_preprocessing: str, + out_path: str, + segments: List[Tuple[float]], + text_no_preprocessing: str, ): """ Write the segmentation output to a file ready to be imported in Audacity with the unprocessed text as labels diff --git a/tools/customization_dataset_preparation/customization_dataset_preparation.py b/tools/customization_dataset_preparation/customization_dataset_preparation.py index 53582f5489f1..ca7586e767ed 100644 --- a/tools/customization_dataset_preparation/customization_dataset_preparation.py +++ b/tools/customization_dataset_preparation/customization_dataset_preparation.py @@ -92,7 +92,7 @@ def recommend_hyperparameters_human_readable(recommended_hyperparameters): def recommend_hyperparameters(df, model=None): """ Makes recommendations on the batch_size to use for training, based on the dataset size - + """ potential_batch_sizes = [2, 4, 8, 12, 16, 32, 64, 128] diff --git a/tools/customization_dataset_preparation/tests/test_customization_dataset_preparation.py b/tools/customization_dataset_preparation/tests/test_customization_dataset_preparation.py index 7bc9b701672e..4faa170ef139 100644 --- a/tools/customization_dataset_preparation/tests/test_customization_dataset_preparation.py +++ b/tools/customization_dataset_preparation/tests/test_customization_dataset_preparation.py @@ -155,7 +155,10 @@ def test_get_common_suffix(): def test_warn_missing_suffix(): df_no_common = pd.DataFrame( - {'prompt': [f'prompt{i}' for i in range(100)], 'completion': [f'completion{i}' for i in range(100)],} + { + 'prompt': [f'prompt{i}' for i in range(100)], + 'completion': [f'completion{i}' for i in range(100)], + } ) message = f"TODO: prompt does not have common suffix, please add one (e.g. \\n) at the end of prompt_template\n" message += ( @@ -164,7 +167,10 @@ def test_warn_missing_suffix(): assert warn_missing_suffix(df_no_common) == message df_common = pd.DataFrame( - {'prompt': [f'prompt{i} answer:' for i in range(100)], 'completion': [f'completion{i}\n' for i in range(100)],} + { + 'prompt': [f'prompt{i} answer:' for i in range(100)], + 'completion': [f'completion{i}\n' for i in range(100)], + } ) assert warn_missing_suffix(df_common) is None @@ -233,7 +239,11 @@ def test_drop_duplicated_rows(): def test_template_mapper(): - df = pd.DataFrame({'prompt': ['prompt sample'],}) + df = pd.DataFrame( + { + 'prompt': ['prompt sample'], + } + ) template = "{prompt}" field_names = ['prompt'] @@ -265,7 +275,12 @@ def test_convert_into_template(): with pytest.raises(ValueError): convert_into_template(df_non_existant_field_name, template) - df = pd.DataFrame({'question': ['question sample'], 'context': ['context sample'],}) + df = pd.DataFrame( + { + 'question': ['question sample'], + 'context': ['context sample'], + } + ) df_prompt = pd.DataFrame( { @@ -345,7 +360,10 @@ def test_get_prepared_filename(): prepared_val_filename = "tmp/sample_prepared_val.jsonl" assert get_prepared_filename(filename) == (prepared_filename, None) assert get_prepared_filename(filename, split_train_validation=True) == ( - [prepared_train_filename, prepared_val_filename,], + [ + prepared_train_filename, + prepared_val_filename, + ], None, ) csv_filename = "tmp/sample.csv" diff --git a/tools/nemo_forced_aligner/utils/make_output_manifest.py b/tools/nemo_forced_aligner/utils/make_output_manifest.py index 7ee3fc77f7ab..c9faec8531b5 100644 --- a/tools/nemo_forced_aligner/utils/make_output_manifest.py +++ b/tools/nemo_forced_aligner/utils/make_output_manifest.py @@ -16,7 +16,8 @@ def write_manifest_out_line( - f_manifest_out, utt_obj, + f_manifest_out, + utt_obj, ): data = {"audio_filepath": utt_obj.audio_filepath} diff --git a/tools/nmt_grpc_service/api/nmt_pb2.py b/tools/nmt_grpc_service/api/nmt_pb2.py index 14edb2e8bc10..5578b445ca0c 100644 --- a/tools/nmt_grpc_service/api/nmt_pb2.py +++ b/tools/nmt_grpc_service/api/nmt_pb2.py @@ -226,7 +226,7 @@ (_message.Message,), { 'DESCRIPTOR': _TRANSLATETEXTREQUEST, - '__module__': 'nmt_pb2' + '__module__': 'nmt_pb2', # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextRequest) }, ) @@ -237,7 +237,7 @@ (_message.Message,), { 'DESCRIPTOR': _TRANSLATION, - '__module__': 'nmt_pb2' + '__module__': 'nmt_pb2', # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.Translation) }, ) @@ -248,7 +248,7 @@ (_message.Message,), { 'DESCRIPTOR': _TRANSLATETEXTRESPONSE, - '__module__': 'nmt_pb2' + '__module__': 'nmt_pb2', # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextResponse) }, ) diff --git a/tools/nmt_grpc_service/asr_nmt_client.py b/tools/nmt_grpc_service/asr_nmt_client.py index fbdc4ea5933b..91264088934f 100644 --- a/tools/nmt_grpc_service/asr_nmt_client.py +++ b/tools/nmt_grpc_service/asr_nmt_client.py @@ -103,6 +103,7 @@ def listen_print_loop(responses, nmt_stub, target_language, asr_only=False): output=True, ) + # read data def generator(w, s): d = w.readframes(CHUNK) diff --git a/tools/nmt_grpc_service/server.py b/tools/nmt_grpc_service/server.py index 0a35272a7441..c2beb1cbc2d0 100644 --- a/tools/nmt_grpc_service/server.py +++ b/tools/nmt_grpc_service/server.py @@ -28,7 +28,10 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_dir", required=True, type=str, help="Path to a folder containing .nemo translation model files.", + "--model_dir", + required=True, + type=str, + help="Path to a folder containing .nemo translation model files.", ) parser.add_argument( "--punctuation_model", @@ -56,7 +59,13 @@ class RivaTranslateServicer(nmtsrv.RivaTranslateServicer): """Provides methods that implement functionality of route guide server.""" def __init__( - self, model_dir, punctuation_model_path, beam_size=1, len_pen=0.6, max_delta_length=5, batch_size=256, + self, + model_dir, + punctuation_model_path, + beam_size=1, + len_pen=0.6, + max_delta_length=5, + batch_size=256, ): self._models = {} self._beam_size = beam_size