diff --git a/scripts/convert_cosmos3_to_diffusers.py b/scripts/convert_cosmos3_to_diffusers.py index 59fe57e2c07c..49a93ca1742e 100644 --- a/scripts/convert_cosmos3_to_diffusers.py +++ b/scripts/convert_cosmos3_to_diffusers.py @@ -18,30 +18,49 @@ import re import torch -from cosmos3.common.init import init_script - - -init_script() - -from accelerate import init_empty_weights # noqa: E402 -from cosmos3.args import _CHECKPOINTS # noqa: E402 -from cosmos3.model import Cosmos3OmniModel # noqa: E402 -from projects.cosmos3.vfm.models.omni_mot_model import OmniMoTModel # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 - -from diffusers import AutoencoderKLWan, UniPCMultistepScheduler # noqa: E402 -from diffusers.models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer # noqa: E402 -from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer # noqa: E402 -from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline # noqa: E402 DEFAULT_SOUND_TOKENIZER_CONFIG = { + "model_type": "autoencoder_v2", "sampling_rate": 48000, + "stereo": True, + "use_wav_as_input": True, + "normalize_volume": True, + "hop_size": 1920, + "input_channels": 1, + "enc_type": "spec_convnext", + "enc_dim": 192, + "enc_intermediate_dim": 768, + "enc_num_layers": 12, + "enc_num_blocks": 2, + "enc_n_fft": 64, + "enc_hop_length": 16, + "enc_latent_dim": 128, + "enc_c_mults": [1, 2, 4], + "enc_strides": [4, 5, 6], + "enc_identity_init": False, + "enc_use_snake": True, + "dec_type": "oobleck", "vocoder_input_dim": 64, "dec_dim": 320, "dec_c_mults": [1, 2, 4, 8, 16], "dec_strides": [2, 4, 5, 6, 8], + "dec_use_snake": True, + "dec_final_tanh": False, "dec_out_channels": 2, + "dec_anti_aliasing": False, + "dec_use_nearest_upsample": False, + "dec_use_tanh_at_final": False, + "bottleneck_type": "vae", + "bottleneck": {"type": "vae"}, + "activation": "snakebeta", + "snake_logscale": True, + "anti_aliasing": False, + "use_cuda_kernel": False, + "causal": False, + "padding_mode": "zeros", + "latent_mean": None, + "latent_std": None, } @@ -114,8 +133,10 @@ def _sound_tokenizer_strip_per_key_prefixes(state_dict: dict[str, torch.Tensor]) return out -def _sound_tokenizer_filter_decoder(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - return {key: value for key, value in state_dict.items() if key.startswith("decoder.")} +def _sound_tokenizer_filter_supported_modules(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return { + key: value for key, value in state_dict.items() if key.startswith("encoder.") or key.startswith("decoder.") + } def _sound_tokenizer_infer_num_blocks(state_dict: dict[str, torch.Tensor]) -> int: @@ -185,7 +206,7 @@ def _remap(key: str) -> str: def _sound_tokenizer_reshape_snake_params(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: out: dict[str, torch.Tensor] = {} for key, value in state_dict.items(): - if (key.endswith(".alpha") or key.endswith(".beta")) and value.ndim == 1: + if key.startswith("decoder.") and (key.endswith(".alpha") or key.endswith(".beta")) and value.ndim == 1: value = value.unsqueeze(0).unsqueeze(-1).contiguous() out[key] = value return out @@ -197,7 +218,11 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) -> candidate_keys = [ key for key in state_dict - if key.endswith(".weight") and any(f".{layer}." in key for layer in ("conv1", "conv2", "conv_t1")) + if key.endswith(".weight") + and ( + any(f".{layer}." in key for layer in ("conv1", "conv2", "conv_t1")) + or re.fullmatch(r"encoder\.layers\.\d+\.weight", key) + ) ] for key in candidate_keys: stem = key[: -len(".weight")] @@ -216,8 +241,10 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) -> def _remap_avae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """Convert a legacy AVAE state dict into the Cosmos3AVAEAudioTokenizer state dict.""" state_dict = _sound_tokenizer_strip_per_key_prefixes(state_dict) - state_dict = _sound_tokenizer_filter_decoder(state_dict) + state_dict = _sound_tokenizer_filter_supported_modules(state_dict) if not state_dict: + raise RuntimeError("Sound tokenizer state dict has no `encoder.*` or `decoder.*` keys after prefix stripping.") + if not any(key.startswith("decoder.") for key in state_dict): raise RuntimeError("Sound tokenizer state dict has no `decoder.*` keys after prefix stripping.") state_dict = _sound_tokenizer_remap_flat_layout(state_dict) state_dict = _sound_tokenizer_reshape_snake_params(state_dict) @@ -230,20 +257,67 @@ def _remap_avae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, tor def _build_sound_tokenizer( checkpoint_path: pathlib.Path, config_path: pathlib.Path | None, -) -> Cosmos3AVAEAudioTokenizer: +): + from diffusers.models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer + config = _load_sound_tokenizer_config(config_path, fallback_config_path=pathlib.Path()) print(f"Loading AVAE sound tokenizer weights from {checkpoint_path} …") raw_state_dict = _load_sound_tokenizer_state_dict(checkpoint_path) state_dict = _remap_avae_state_dict(raw_state_dict) - print(f" Remapped {len(raw_state_dict)} → {len(state_dict)} decoder keys.") + has_encoder = any(key.startswith("encoder.") for key in state_dict) + print( + f" Remapped {len(raw_state_dict)} → {len(state_dict)} tokenizer keys " + f"({'encoder+decoder' if has_encoder else 'decoder-only'})." + ) sound_tokenizer = Cosmos3AVAEAudioTokenizer( + model_type=config.get("model_type", DEFAULT_SOUND_TOKENIZER_CONFIG["model_type"]), sampling_rate=config.get("sampling_rate", DEFAULT_SOUND_TOKENIZER_CONFIG["sampling_rate"]), + stereo=config.get("stereo", DEFAULT_SOUND_TOKENIZER_CONFIG["stereo"]), + use_wav_as_input=config.get("use_wav_as_input", DEFAULT_SOUND_TOKENIZER_CONFIG["use_wav_as_input"]), + normalize_volume=config.get("normalize_volume", DEFAULT_SOUND_TOKENIZER_CONFIG["normalize_volume"]), + hop_size=config.get("hop_size", DEFAULT_SOUND_TOKENIZER_CONFIG["hop_size"]), + input_channels=config.get("input_channels", DEFAULT_SOUND_TOKENIZER_CONFIG["input_channels"]), + enc_type=config.get("enc_type", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_type"]), + enc_dim=config.get("enc_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_dim"]), + enc_intermediate_dim=config.get( + "enc_intermediate_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_intermediate_dim"] + ), + enc_num_layers=config.get("enc_num_layers", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_num_layers"]), + enc_num_blocks=config.get("enc_num_blocks", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_num_blocks"]), + enc_n_fft=config.get("enc_n_fft", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_n_fft"]), + enc_hop_length=config.get("enc_hop_length", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_hop_length"]), + enc_latent_dim=config.get("enc_latent_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_latent_dim"]), + enc_c_mults=tuple(config.get("enc_c_mults", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_c_mults"])), + enc_strides=tuple(config.get("enc_strides", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_strides"])), + enc_identity_init=config.get("enc_identity_init", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_identity_init"]), + enc_use_snake=config.get("enc_use_snake", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_use_snake"]), + dec_type=config.get("dec_type", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_type"]), vocoder_input_dim=config.get("vocoder_input_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["vocoder_input_dim"]), dec_dim=config.get("dec_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_dim"]), dec_c_mults=tuple(config.get("dec_c_mults", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_c_mults"])), dec_strides=tuple(config.get("dec_strides", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_strides"])), + dec_use_snake=config.get("dec_use_snake", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_snake"]), + dec_final_tanh=config.get("dec_final_tanh", False), dec_out_channels=config.get("dec_out_channels", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_out_channels"]), + dec_anti_aliasing=config.get("dec_anti_aliasing", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_anti_aliasing"]), + dec_use_nearest_upsample=config.get( + "dec_use_nearest_upsample", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_nearest_upsample"] + ), + dec_use_tanh_at_final=config.get( + "dec_use_tanh_at_final", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_tanh_at_final"] + ), + bottleneck_type=config.get("bottleneck_type", DEFAULT_SOUND_TOKENIZER_CONFIG["bottleneck_type"]), + bottleneck=config.get("bottleneck", DEFAULT_SOUND_TOKENIZER_CONFIG["bottleneck"]), + activation=config.get("activation", DEFAULT_SOUND_TOKENIZER_CONFIG["activation"]), + snake_logscale=config.get("snake_logscale", DEFAULT_SOUND_TOKENIZER_CONFIG["snake_logscale"]), + anti_aliasing=config.get("anti_aliasing", DEFAULT_SOUND_TOKENIZER_CONFIG["anti_aliasing"]), + use_cuda_kernel=config.get("use_cuda_kernel", DEFAULT_SOUND_TOKENIZER_CONFIG["use_cuda_kernel"]), + causal=config.get("causal", DEFAULT_SOUND_TOKENIZER_CONFIG["causal"]), + padding_mode=config.get("padding_mode", DEFAULT_SOUND_TOKENIZER_CONFIG["padding_mode"]), + latent_mean=config.get("latent_mean", DEFAULT_SOUND_TOKENIZER_CONFIG["latent_mean"]), + latent_std=config.get("latent_std", DEFAULT_SOUND_TOKENIZER_CONFIG["latent_std"]), + encoder_enabled=has_encoder, ) load_result = sound_tokenizer.load_state_dict(state_dict, strict=True) if load_result.missing_keys or load_result.unexpected_keys: @@ -255,8 +329,8 @@ def _build_sound_tokenizer( @contextlib.contextmanager -def _skip_source_sound_tokenizer_load(): - original_set_up_tokenizers = OmniMoTModel.set_up_tokenizers +def _skip_source_sound_tokenizer_load(omni_mot_model_cls): + original_set_up_tokenizers = omni_mot_model_cls.set_up_tokenizers def set_up_tokenizers_without_sound(self): if not getattr(self.config, "sound_gen", False): @@ -269,14 +343,28 @@ def set_up_tokenizers_without_sound(self): finally: self.config.sound_gen = sound_gen - OmniMoTModel.set_up_tokenizers = set_up_tokenizers_without_sound + omni_mot_model_cls.set_up_tokenizers = set_up_tokenizers_without_sound try: yield finally: - OmniMoTModel.set_up_tokenizers = original_set_up_tokenizers + omni_mot_model_cls.set_up_tokenizers = original_set_up_tokenizers def main(): + from cosmos3.common.init import init_script + + init_script() + + from accelerate import init_empty_weights + from cosmos3.args import _CHECKPOINTS + from cosmos3.model import Cosmos3OmniModel + from projects.cosmos3.vfm.models.omni_mot_model import OmniMoTModel + from transformers import AutoTokenizer + + from diffusers import AutoencoderKLWan, UniPCMultistepScheduler + from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer + from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--checkpoint-path", @@ -330,7 +418,7 @@ def main(): print("Instantiating model and loading weights from DCP checkpoint …") print("Skipping source AVAE tokenizer instantiation during converter-only model load …") - with _skip_source_sound_tokenizer_load(): + with _skip_source_sound_tokenizer_load(OmniMoTModel): _tmp = Cosmos3OmniModel.from_pretrained_dcp(checkpoint_path).model # Extract network components and architecture config from DCP model diff --git a/src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py b/src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py index d5d83d5f7076..f3356eebb80d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py +++ b/src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py @@ -13,22 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Cosmos3 AVAE Audio Tokenizer — decoder-only implementation. +"""Cosmos3 AVAE Audio Tokenizer. The decoder reuses the Oobleck architecture (Snake1d activations + weight-norm convs + residual units), inlined here -instead of imported so the audio module is self-contained. The corresponding encoder is intentionally not inlined: -upstream Cosmos3 uses a spec-convnext encoder whose tensor layout doesn't map onto Oobleck's encoder. +instead of imported so the audio module is self-contained. The encoder is the Cosmos3 SpecConvNeXt audio encoder used +by AVAE checkpoints; it is intentionally separate from Oobleck's waveform encoder because the tensor layouts and +bottleneck semantics are different. """ import math +from collections import OrderedDict +from dataclasses import dataclass import torch import torch.nn as nn +import torch.nn.functional as F from torch.nn.utils import weight_norm from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput from ...utils.accelerate_utils import apply_forward_hook -from ..modeling_utils import ModelMixin +from ..modeling_utils import ModelMixin, get_parameter_dtype +from ..normalization import FP32LayerNorm +from .autoencoder_oobleck import OobleckDiagonalGaussianDistribution # Copied from diffusers.models.autoencoders.autoencoder_oobleck.Snake1d @@ -47,17 +54,203 @@ def __init__(self, hidden_dim, logscale=True): self.logscale = logscale def forward(self, hidden_states): - shape = hidden_states.shape + return self._forward(hidden_states, self.alpha, self.beta, self.logscale) - alpha = self.alpha if not self.logscale else torch.exp(self.alpha) - beta = self.beta if not self.logscale else torch.exp(self.beta) + @staticmethod + def _forward(hidden_states, alpha, beta, logscale): + shape = hidden_states.shape + alpha = alpha if not logscale else torch.exp(alpha) + beta = beta if not logscale else torch.exp(beta) hidden_states = hidden_states.reshape(shape[0], shape[1], -1) hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) hidden_states = hidden_states.reshape(shape) return hidden_states +class Cosmos3AudioSnakeBeta(nn.Module): + """SnakeBeta activation used by the Cosmos3 SpecConvNeXt encoder.""" + + def __init__(self, hidden_dim: int, logscale: bool = True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(hidden_dim)) + self.beta = nn.Parameter(torch.zeros(hidden_dim)) + self.logscale = logscale + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return Snake1d._forward(hidden_states, self.alpha[None, :, None], self.beta[None, :, None], self.logscale) + + +class Cosmos3AudioConvNeXtBlock(nn.Module): + """1D ConvNeXt block used by the Cosmos3 SpecConvNeXt encoder.""" + + def __init__( + self, + hidden_dim: int, + intermediate_dim: int, + identity_init: bool = False, + use_snake: bool = True, + causal: bool = False, + ): + super().__init__() + self.causal = causal + + if causal: + self.dwconv = nn.Sequential( + nn.ConstantPad1d((6, 0), 0), + nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, groups=hidden_dim), + ) + else: + self.dwconv = nn.Sequential( + nn.ConstantPad1d((3, 3), 0), + nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, groups=hidden_dim), + ) + + self.norm = FP32LayerNorm(hidden_dim, eps=1e-5, bias=False) + self.pwconv1 = nn.Conv1d(hidden_dim, intermediate_dim, kernel_size=1) + self.act = Cosmos3AudioSnakeBeta(intermediate_dim) if use_snake else nn.GELU() + self.pwconv2 = nn.Conv1d(intermediate_dim, hidden_dim, kernel_size=1) + if identity_init: + nn.init.zeros_(self.pwconv2.weight) + if self.pwconv2.bias is not None: + nn.init.zeros_(self.pwconv2.bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.dwconv(hidden_states) + hidden_states = self.norm(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) + hidden_states = self.pwconv1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.pwconv2(hidden_states) + return residual + hidden_states + + +class Cosmos3AudioSpectrogramConvNeXtEncoder(nn.Module): + """Cosmos3 waveform-to-latent encoder using STFT features and ConvNeXt blocks.""" + + def __init__( + self, + input_channels: int, + stereo: bool, + channels: int, + latent_dim: int, + channel_multiples: tuple[int, ...], + strides: tuple[int, ...], + num_blocks: int, + n_fft: int, + hop_length: int, + identity_init: bool, + use_snake: bool, + causal: bool, + padding_mode: str, + ): + super().__init__() + + if causal: + raise NotImplementedError("Cosmos3 AVAE causal audio encoder is not supported yet.") + if len(channel_multiples) != len(strides): + raise ValueError( + "`enc_c_mults` and `enc_strides` must have the same length, got " + f"{len(channel_multiples)} and {len(strides)}." + ) + + self.input_channels = input_channels * (2 if stereo else 1) + self.channels = channels + self.latent_dim = latent_dim + self.channel_multiples = tuple(channel_multiples) + self.strides = tuple(strides) + self.num_blocks = num_blocks + self.n_fft = n_fft + self.hop_length = hop_length + self.causal = causal + + layers: list[nn.Module] = [ + weight_norm( + nn.Conv1d( + (n_fft + 2) * self.input_channels, + self.channel_multiples[0] * channels, + kernel_size=1, + bias=False, + ) + ) + ] + + for index, stride in enumerate(self.strides): + input_dim = self.channel_multiples[index] * channels + output_dim = ( + self.channel_multiples[index + 1] * channels + if index < len(self.channel_multiples) - 1 + else self.channel_multiples[-1] * channels + ) + + for _ in range(num_blocks): + layers.append( + Cosmos3AudioConvNeXtBlock( + hidden_dim=input_dim, + intermediate_dim=input_dim * 4, + identity_init=identity_init, + use_snake=use_snake, + causal=causal, + ) + ) + + layers.append( + weight_norm( + nn.Conv1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + padding_mode=padding_mode, + ) + ) + ) + + layers.append( + weight_norm(nn.Conv1d(self.channel_multiples[-1] * channels, latent_dim, kernel_size=1, bias=False)) + ) + self.layers = nn.Sequential(*layers) + + def _spectrogram(self, waveform: torch.Tensor) -> torch.Tensor: + pad_left = (self.n_fft - self.hop_length) // 2 + pad_right = (self.n_fft - self.hop_length) - pad_left + waveform = F.pad(waveform, (pad_left, pad_right)).float() + window = torch.hann_window(self.n_fft, device=waveform.device, dtype=waveform.dtype) + return torch.stft( + waveform, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.n_fft, + window=window, + center=False, + normalized=False, + onesided=True, + return_complex=True, + ) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_samples = audio.shape + if num_channels != self.input_channels: + raise ValueError( + f"Cosmos3 AVAE encoder expected {self.input_channels} audio channels, got {num_channels}." + ) + + if num_channels > 1: + audio = audio.reshape(batch_size * num_channels, 1, num_samples) + + spectrogram = self._spectrogram(audio.squeeze(1)) + real, imaginary = torch.view_as_real(spectrogram).chunk(2, dim=-1) + spectrogram = torch.cat([real, imaginary], dim=1).squeeze(-1) + + spectrogram = spectrogram.to(audio.dtype) + if num_channels > 1: + spectrogram = spectrogram.reshape(batch_size, num_channels * spectrogram.shape[1], spectrogram.shape[2]) + + hidden_states = self.layers(spectrogram) + return hidden_states.transpose(1, 2) + + # Copied from diffusers.models.autoencoders.autoencoder_oobleck.OobleckResidualUnit with Oobleck->Cosmos3Audio class Cosmos3AudioResidualUnit(nn.Module): """ @@ -180,22 +373,81 @@ def forward(self, hidden_state): return hidden_state +@dataclass +class Cosmos3AudioEncoderOutput(BaseOutput): + """Output of `Cosmos3AVAEAudioTokenizer.encode`.""" + + latent_dist: OobleckDiagonalGaussianDistribution + + +@dataclass +class Cosmos3AudioDecoderOutput(BaseOutput): + """Output of `Cosmos3AVAEAudioTokenizer.forward`.""" + + sample: torch.Tensor + + class Cosmos3AVAEAudioTokenizer(ModelMixin, ConfigMixin): - """Decoder-only audio tokenizer for Cosmos3 sound generation. + """Audio tokenizer for Cosmos3 sound generation. - Wraps the Cosmos3Audio decoder (an inlined copy of Oobleck) used in the AVAE (Audio VAE) component of the Cosmos3 - omni model. Provides the interface expected by ``Cosmos3OmniPipeline`` when ``enable_sound=True``. + Wraps the Cosmos3 AVAE SpecConvNeXt encoder and Oobleck-style decoder used by the Cosmos3 omni model. The decoder + API stays tensor-returning because ``Cosmos3OmniPipeline`` calls it directly when ``enable_sound=True``. - For now encoder part of the Tokenizer is not supported. The encoder support will be added in the future. + Only the shipped AVAE configuration (``model_type="autoencoder_v2"``, waveform input, ``spec_convnext`` encoder, + ``vae`` bottleneck, ``oobleck`` decoder, log-scale SnakeBeta, no latent normalization) is supported; any other value + raises ``NotImplementedError``. Parameters: + model_type (`str`, defaults to `"autoencoder_v2"`): AVAE model variant; only `"autoencoder_v2"` is supported. sampling_rate (`int`, defaults to `48000`): Audio sample rate in Hz. vocoder_input_dim (`int`, defaults to `64`): Latent channel count fed into the decoder (``== transformer sound_dim``). dec_dim (`int`, defaults to `320`): Base decoder channel count. - dec_c_mults (`tuple[int, ...]`, defaults to `(1, 2, 4, 8, 16)`): Channel multipliers. - dec_strides (`tuple[int, ...]`, defaults to `(2, 4, 5, 6, 8)`): Upsampling strides. + dec_c_mults (`tuple[int, ...]`, defaults to `(1, 2, 4, 8, 16)`): Decoder channel multipliers. + dec_strides (`tuple[int, ...]`, defaults to `(2, 4, 5, 6, 8)`): Decoder upsampling strides. dec_out_channels (`int`, defaults to `2`): Output audio channels (2 = stereo). + stereo (`bool`, defaults to `True`): Whether the audio is stereo; doubles the encoder's effective channel count. + use_wav_as_input (`bool`, defaults to `True`): Whether the encoder consumes raw waveforms; only `True` is + supported. + normalize_volume (`bool`, defaults to `True`): Whether `encode` peak-normalizes the waveform before encoding. + hop_size (`int`, *optional*): Waveform→latent temporal compression factor used for `encode` padding. Defaults + to `prod(dec_strides)` when `None`. + input_channels (`int`, defaults to `1`): Per-channel encoder input count before the `stereo` doubling. + enc_type (`str`, defaults to `"spec_convnext"`): Encoder type; only `"spec_convnext"` is supported. + enc_dim (`int`, defaults to `192`): Base encoder channel count. + enc_intermediate_dim (`int`, defaults to `768`): Unused; kept for config fidelity (ConvNeXt blocks use + ``input_dim * 4``). + enc_num_layers (`int`, defaults to `12`): Unused; kept for config fidelity (depth derives from `enc_num_blocks`). + enc_num_blocks (`int`, defaults to `2`): ConvNeXt blocks per encoder downsampling stage. + enc_n_fft (`int`, defaults to `64`): STFT FFT size for the encoder spectrogram front-end. + enc_hop_length (`int`, defaults to `16`): STFT hop length for the encoder spectrogram front-end. + enc_latent_dim (`int`, defaults to `128`): Encoder output channels; split into mean/scale by the VAE bottleneck + (so ``enc_latent_dim == 2 * vocoder_input_dim``). + enc_c_mults (`tuple[int, ...]`, defaults to `(1, 2, 4)`): Encoder channel multipliers per stage. + enc_strides (`tuple[int, ...]`, defaults to `(4, 5, 6)`): Encoder downsampling strides per stage. + enc_identity_init (`bool`, defaults to `False`): Whether to zero-init the ConvNeXt residual 1x1 convs. + enc_use_snake (`bool`, defaults to `True`): Whether ConvNeXt blocks use SnakeBeta (else GELU). + dec_type (`str`, defaults to `"oobleck"`): Decoder type; only `"oobleck"` is supported. + dec_use_snake (`bool`, defaults to `True`): Whether the decoder uses SnakeBeta; only `True` is supported. + dec_final_tanh (`bool`, defaults to `False`): Vestigial decoder tanh flag; only `False` is supported. + dec_anti_aliasing (`bool`, defaults to `False`): Decoder anti-aliasing flag; only `False` is supported. + dec_use_nearest_upsample (`bool`, defaults to `False`): Decoder upsample mode flag; only `False` is supported. + dec_use_tanh_at_final (`bool`, defaults to `False`): Decoder final-tanh flag; only `False` is supported. + bottleneck_type (`str`, defaults to `"vae"`): Bottleneck type; only `"vae"` is supported. + bottleneck (`dict`, *optional*): Bottleneck config; if given, its `"type"` must be `"vae"`. + activation (`str`, defaults to `"snakebeta"`): Activation family; only `"snakebeta"` is supported. + snake_logscale (`bool`, defaults to `True`): Whether SnakeBeta parameters are log-scaled; only `True` is + supported. + anti_aliasing (`bool`, defaults to `False`): Global anti-aliasing flag; only `False` is supported. + use_cuda_kernel (`bool`, defaults to `False`): Whether to use fused CUDA kernels; only `False` is supported. + causal (`bool`, defaults to `False`): Whether convolutions are causal; only `False` is supported by the encoder. + padding_mode (`str`, defaults to `"zeros"`): Convolution padding mode. + latent_mean (`float` or `list[float]`, *optional*): Latent normalization mean; latent normalization is not + implemented, so a non-`None` value raises ``NotImplementedError``. + latent_std (`float` or `list[float]`, *optional*): Latent normalization std; latent normalization is not + implemented, so a non-`None` value raises ``NotImplementedError``. + encoder_enabled (`bool`, defaults to `True`): Whether to instantiate the encoder. Set to `False` (or + auto-disabled on load) for decoder-only checkpoints, which cannot `encode`. """ _supports_gradient_checkpointing = False @@ -204,15 +456,97 @@ class Cosmos3AVAEAudioTokenizer(ModelMixin, ConfigMixin): @register_to_config def __init__( self, + model_type: str = "autoencoder_v2", sampling_rate: int = 48000, vocoder_input_dim: int = 64, dec_dim: int = 320, dec_c_mults: tuple = (1, 2, 4, 8, 16), dec_strides: tuple = (2, 4, 5, 6, 8), dec_out_channels: int = 2, + stereo: bool = True, + use_wav_as_input: bool = True, + normalize_volume: bool = True, + hop_size: int | None = None, + input_channels: int = 1, + enc_type: str = "spec_convnext", + enc_dim: int = 192, + enc_intermediate_dim: int = 768, + enc_num_layers: int = 12, + enc_num_blocks: int = 2, + enc_n_fft: int = 64, + enc_hop_length: int = 16, + enc_latent_dim: int = 128, + enc_c_mults: tuple = (1, 2, 4), + enc_strides: tuple = (4, 5, 6), + enc_identity_init: bool = False, + enc_use_snake: bool = True, + dec_type: str = "oobleck", + dec_use_snake: bool = True, + dec_final_tanh: bool = False, + dec_anti_aliasing: bool = False, + dec_use_nearest_upsample: bool = False, + dec_use_tanh_at_final: bool = False, + bottleneck_type: str = "vae", + bottleneck: dict | None = None, + activation: str = "snakebeta", + snake_logscale: bool = True, + anti_aliasing: bool = False, + use_cuda_kernel: bool = False, + causal: bool = False, + padding_mode: str = "zeros", + latent_mean: float | list[float] | None = None, + latent_std: float | list[float] | None = None, + encoder_enabled: bool = True, ): super().__init__() + if model_type != "autoencoder_v2": + raise NotImplementedError(f"Cosmos3 AVAE model type {model_type!r} is not supported.") + if not use_wav_as_input: + raise NotImplementedError("Cosmos3 AVAE tokenizer only supports waveform input.") + if enc_type != "spec_convnext": + raise NotImplementedError(f"Cosmos3 AVAE encoder type {enc_type!r} is not supported.") + if bottleneck is not None and bottleneck.get("type", bottleneck_type) != "vae": + raise NotImplementedError("Cosmos3 AVAE tokenizer only supports the VAE bottleneck.") + if bottleneck_type != "vae": + raise NotImplementedError("Cosmos3 AVAE tokenizer only supports the VAE bottleneck.") + if dec_type != "oobleck": + raise NotImplementedError(f"Cosmos3 AVAE decoder type {dec_type!r} is not supported.") + if ( + not dec_use_snake + or dec_final_tanh + or dec_anti_aliasing + or dec_use_nearest_upsample + or dec_use_tanh_at_final + ): + raise NotImplementedError("Cosmos3 AVAE decoder only supports the shipped Oobleck decoder configuration.") + if activation != "snakebeta" or not snake_logscale or anti_aliasing or use_cuda_kernel: + raise NotImplementedError("Cosmos3 AVAE tokenizer only supports the shipped SnakeBeta configuration.") + if latent_mean is not None or latent_std is not None: + raise NotImplementedError( + "Cosmos3 AVAE tokenizer does not apply latent normalization; `latent_mean`/`latent_std` must be None." + ) + + self.encoder = None + self._encoder_available = False + if encoder_enabled: + self.encoder = Cosmos3AudioSpectrogramConvNeXtEncoder( + input_channels=input_channels, + stereo=stereo, + channels=enc_dim, + latent_dim=enc_latent_dim, + channel_multiples=tuple(enc_c_mults), + strides=tuple(enc_strides), + num_blocks=enc_num_blocks, + n_fft=enc_n_fft, + hop_length=enc_hop_length, + identity_init=enc_identity_init, + use_snake=enc_use_snake, + causal=causal, + padding_mode=padding_mode, + ) + self._encoder_available = True + self.decoder = Cosmos3AudioDecoder( channels=dec_dim, input_channels=vocoder_input_dim, @@ -221,7 +555,62 @@ def __init__( channel_multiples=list(dec_c_mults), ) - self._hop_size: int = math.prod(dec_strides) + self._hop_size: int = int(hop_size) if hop_size is not None else math.prod(dec_strides) + + def _disable_encoder(self): + self.encoder = None + self._encoder_available = False + self.register_to_config(encoder_enabled=False) + + def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: + super()._fix_state_dict_keys_on_load(state_dict) + if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict): + self._disable_encoder() + + def _encode(self, sample: torch.Tensor) -> torch.Tensor: + return self.encoder(sample).transpose(1, 2) + + @apply_forward_hook + def encode( + self, + sample: torch.Tensor, + return_dict: bool = True, + force_pad: bool = False, + ) -> Cosmos3AudioEncoderOutput | tuple[OobleckDiagonalGaussianDistribution]: + """Encode a waveform into a VAE latent distribution. + + Args: + sample: Audio waveform tensor with shape ``[B, C, T]``. + return_dict: Whether to return a ``Cosmos3AudioEncoderOutput``. + force_pad: Whether to right-pad to ``hop_size`` even when the model is in training mode. + """ + if sample.ndim != 3: + raise ValueError(f"`sample` must have shape [B, C, T], got {tuple(sample.shape)}.") + + if self.encoder is None or not self._encoder_available: + raise ValueError( + "This Cosmos3 AVAE sound tokenizer was loaded from decoder-only weights and cannot encode audio. " + "Re-convert the AVAE checkpoint with encoder weights to use `encode()`." + ) + + hidden_states = sample + if self.config.normalize_volume: + hidden_states = hidden_states / (hidden_states.abs().max() + 1e-5) * 0.95 + + if force_pad or not self.training: + sample_length = hidden_states.shape[-1] + padding = (self._hop_size - (sample_length % self._hop_size)) % self._hop_size + if padding > 0: + hidden_states = F.pad(hidden_states, (0, padding), mode="constant", value=0) + + encoder_dtype = get_parameter_dtype(self.encoder) + moments = self._encode(hidden_states.to(dtype=encoder_dtype)) + posterior = OobleckDiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return Cosmos3AudioEncoderOutput(latent_dist=posterior) @apply_forward_hook def decode(self, latents: torch.Tensor) -> torch.Tensor: @@ -238,3 +627,23 @@ def decode(self, latents: torch.Tensor) -> torch.Tensor: latents = latents.unsqueeze(0) audio = self.decoder(latents).clamp(-1.0, 1.0) return audio.squeeze(0) if squeeze else audio + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + force_pad: bool = False, + ) -> Cosmos3AudioDecoderOutput | tuple[torch.Tensor]: + """Encode then decode a waveform; ``sample_posterior=False`` (default) decodes the distribution mode (mean), + whereas the upstream Cosmos3 AVAE always samples — pass ``sample_posterior=True`` for reference-equivalent + behavior.""" + posterior = self.encode(sample, force_pad=force_pad).latent_dist + latents = posterior.sample(generator=generator) if sample_posterior else posterior.mode() + decoded = self.decode(latents) + + if not return_dict: + return (decoded,) + + return Cosmos3AudioDecoderOutput(sample=decoded) diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos3_audio.py b/tests/models/autoencoders/test_models_autoencoder_cosmos3_audio.py new file mode 100644 index 000000000000..2d1eb05bf73a --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_cosmos3_audio.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from pathlib import Path + +import pytest +import torch + +from diffusers.models.autoencoders.autoencoder_cosmos3_audio import ( + Cosmos3AudioSnakeBeta, + Cosmos3AVAEAudioTokenizer, + Snake1d, +) +from diffusers.models.autoencoders.autoencoder_oobleck import OobleckDiagonalGaussianDistribution + + +def _get_tiny_cosmos3_audio_tokenizer() -> Cosmos3AVAEAudioTokenizer: + return Cosmos3AVAEAudioTokenizer( + sampling_rate=16, + hop_size=4, + input_channels=1, + stereo=True, + normalize_volume=True, + enc_dim=4, + enc_num_blocks=1, + enc_n_fft=8, + enc_hop_length=2, + enc_latent_dim=8, + enc_c_mults=(1,), + enc_strides=(2,), + vocoder_input_dim=4, + dec_dim=4, + dec_c_mults=(1, 2), + dec_strides=(2, 2), + dec_out_channels=2, + ) + + +def test_cosmos3_audio_tokenizer_encode_decode_forward_shapes(): + torch.manual_seed(0) + model = _get_tiny_cosmos3_audio_tokenizer().eval() + state_dict = model.state_dict() + assert "encoder.layers.1.norm.weight" in state_dict + assert "encoder.layers.1.norm.bias" not in state_dict + + audio = torch.randn(2, 2, 15) + + encoded = model.encode(audio) + assert isinstance(encoded.latent_dist, OobleckDiagonalGaussianDistribution) + assert encoded.latent_dist.mean.shape == (2, 4, 4) + assert encoded.latent_dist.scale.shape == (2, 4, 4) + + latents = encoded.latent_dist.mode() + decoded = model.decode(latents) + assert decoded.shape == (2, 2, 16) + assert decoded.min() >= -1.0 + assert decoded.max() <= 1.0 + + forward_output = model(audio) + assert forward_output.sample.shape == (2, 2, 16) + + tuple_output = model(audio, return_dict=False) + assert tuple_output[0].shape == (2, 2, 16) + + +def test_cosmos3_audio_tokenizer_encode_tuple_and_seeded_sample(): + torch.manual_seed(0) + model = _get_tiny_cosmos3_audio_tokenizer().eval() + audio = torch.randn(1, 2, 16) + + posterior = model.encode(audio, return_dict=False)[0] + sample_a = posterior.sample(generator=torch.Generator("cpu").manual_seed(13)) + sample_b = posterior.sample(generator=torch.Generator("cpu").manual_seed(13)) + + assert torch.allclose(sample_a, sample_b) + assert sample_a.shape == (1, 4, 4) + assert posterior.kl().ndim == 0 + + +def test_cosmos3_audio_snake_beta_matches_snake1d_with_1d_state(): + torch.manual_seed(0) + snake = Snake1d(4) + snake_beta = Cosmos3AudioSnakeBeta(4) + snake_beta.alpha.data.copy_(snake.alpha.flatten()) + snake_beta.beta.data.copy_(snake.beta.flatten()) + hidden_states = torch.randn(2, 4, 8) + + assert snake_beta.state_dict()["alpha"].shape == (4,) + assert torch.allclose(snake_beta(hidden_states), snake(hidden_states)) + + +def test_cosmos3_audio_tokenizer_decoder_only_state_disables_encode(): + model = _get_tiny_cosmos3_audio_tokenizer() + decoder_only_state_dict = {key: value for key, value in model.state_dict().items() if key.startswith("decoder.")} + + decoder_only_model = _get_tiny_cosmos3_audio_tokenizer() + decoder_only_model._fix_state_dict_keys_on_load(decoder_only_state_dict) + decoder_only_model.load_state_dict(decoder_only_state_dict, strict=True) + + assert decoder_only_model.encoder is None + with pytest.raises(ValueError, match="decoder-only weights"): + decoder_only_model.encode(torch.randn(1, 2, 16)) + + +def _load_converter_module(): + repo_root = Path(__file__).resolve().parents[3] + script_path = repo_root / "scripts" / "convert_cosmos3_to_diffusers.py" + spec = importlib.util.spec_from_file_location("convert_cosmos3_to_diffusers", script_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_cosmos3_audio_converter_keeps_encoder_and_remaps_decoder(): + converter = _load_converter_module() + state_dict = { + "generator.encoder.layers.0.weight": torch.ones(4, 20, 1), + "generator.encoder.layers.1.act.alpha": torch.zeros(16), + "generator.encoder.layers.1.act.beta": torch.zeros(16), + "generator.decoder.layers.0.weight": torch.ones(8, 4, 7), + "generator.decoder.layers.1.layers.0.alpha": torch.zeros(8), + "generator.decoder.layers.1.layers.1.weight": torch.ones(8, 4, 4), + "generator.decoder.layers.1.layers.2.layers.0.alpha": torch.zeros(4), + "generator.decoder.layers.1.layers.2.layers.1.weight": torch.ones(4, 4, 7), + "generator.decoder.layers.2.alpha": torch.zeros(4), + "generator.decoder.layers.3.weight": torch.ones(2, 4, 7), + } + + remapped = converter._remap_avae_state_dict(state_dict) + + assert not any(key.startswith("decoder.layers.") for key in remapped) + assert "encoder.layers.0.weight" not in remapped + assert "encoder.layers.0.weight_g" in remapped + assert "encoder.layers.0.weight_v" in remapped + assert remapped["encoder.layers.1.act.alpha"].shape == (16,) + assert remapped["decoder.conv1.weight_g"].shape == (8, 1, 1) + assert remapped["decoder.block.0.snake1.alpha"].shape == (1, 8, 1) + assert remapped["decoder.block.0.res_unit1.snake1.alpha"].shape == (1, 4, 1) + assert remapped["decoder.snake1.alpha"].shape == (1, 4, 1) + + +def test_cosmos3_audio_converter_allows_decoder_only_state_dict(): + converter = _load_converter_module() + state_dict = { + "decoder.conv1.weight": torch.ones(8, 4, 7), + "decoder.snake1.alpha": torch.zeros(4), + } + + remapped = converter._remap_avae_state_dict(state_dict) + + assert not any(key.startswith("encoder.") for key in remapped) + assert "decoder.conv1.weight_g" in remapped + assert remapped["decoder.snake1.alpha"].shape == (1, 4, 1)