Skip to content

Add Sound Encoder to Cosmos3#13911

Draft
MaciejBalaNV wants to merge 3 commits into
huggingface:mainfrom
MaciejBalaNV:cosmos3_sound_encoder
Draft

Add Sound Encoder to Cosmos3#13911
MaciejBalaNV wants to merge 3 commits into
huggingface:mainfrom
MaciejBalaNV:cosmos3_sound_encoder

Conversation

@MaciejBalaNV

Copy link
Copy Markdown
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Maciej Bala <mbala@nvidia.com>
@github-actions github-actions Bot added models tests size/L PR with diff > 200 LOC labels Jun 10, 2026
Comment on lines +617 to +626
def _disable_encoder(self):
self.encoder = None
self._encoder_available = False
self.register_to_config(encoder_enabled=False)

def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
super()._fix_state_dict_keys_on_load(state_dict)
if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict):
self._disable_encoder()

@yiyixuxu yiyixuxu Jun 10, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need these two methods?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an extra safety net for checkpoints that do not have the encoder weights. We will update the main checkpoint to have encoder weights, but I think it's still fine to keep this method in case of e.g. cached local checkpoints. We don't want them to break if people don't need the encoder weights.

@yiyixuxu yiyixuxu requested a review from dg845 June 10, 2026 20:37
return hidden_states


class Cosmos3AudioSnakeBeta(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the existing Snake1d module implements essentially the same logic as Cosmos3AudioSnakeBeta, could we use it as well for the encoder?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The math should be the same, but we'd need a reshape on load, since Cosmos3AudioSnakeBeta has 1D parameters instead of 3D. Let me think about it for a bit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the separate classes for native checkpoint loading, but shared a forward implementation

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we potentially reshape the encoder Snake alpha/beta weights to 3D in the scripts/convert_cosmos3_to_diffusers.py conversion script? I think this would allow us to reuse Snake1d for the encoder.

Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated
Comment thread src/diffusers/models/autoencoders/autoencoder_cosmos3_audio.py Outdated

@dg845 dg845 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Left an initial design review :).

MaciejBalaNV and others added 2 commits June 12, 2026 10:18
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
@staticmethod
def _forward(hidden_states, alpha, beta, logscale):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Snake1d is # Copied from the Oobleck audio VAE for the Cosmos 3 decoder, I don't think we should modify it here (and this breaks the CI check that the implementations are synced). I think it would be better (for example) to convert the encoder Snake weights from 1D to 3D when converting the checkpoint, as suggested in #13911 (comment).

from diffusers.models.autoencoders.autoencoder_oobleck import OobleckDiagonalGaussianDistribution


def _get_tiny_cosmos3_audio_tokenizer() -> Cosmos3AVAEAudioTokenizer:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to refactor the tests to use the standard diffusers model test config + test mixins? For reference, here is what the LTX-2 audio VAE tests do:

class TestAutoencoderKLLTX2Audio(AutoencoderKLLTX2AudioTesterConfig, ModelTesterMixin):
base_precision = 1e-2
def test_outputs_equivalence(self):
pytest.skip("Unsupported test.")
class TestAutoencoderKLLTX2AudioTraining(AutoencoderKLLTX2AudioTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLLTX2Audio."""

So for example we would move the _get_tiny_cosmos3_audio_tokenizer logic into the get_init_dict method of a new Cosmos3AVAEAudioTokenizerTesterConfig class. I think we would still keep the Cosmos 3-specific tests below.

@dg845 dg845 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes! Left some follow up comments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/L PR with diff > 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants