Add Sound Encoder to Cosmos3#13911
Conversation
Signed-off-by: Maciej Bala <mbala@nvidia.com>
| def _disable_encoder(self): | ||
| self.encoder = None | ||
| self._encoder_available = False | ||
| self.register_to_config(encoder_enabled=False) | ||
|
|
||
| def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None: | ||
| super()._fix_state_dict_keys_on_load(state_dict) | ||
| if self.encoder is not None and not any(key.startswith("encoder.") for key in state_dict): | ||
| self._disable_encoder() | ||
|
|
There was a problem hiding this comment.
why do we need these two methods?
There was a problem hiding this comment.
It's an extra safety net for checkpoints that do not have the encoder weights. We will update the main checkpoint to have encoder weights, but I think it's still fine to keep this method in case of e.g. cached local checkpoints. We don't want them to break if people don't need the encoder weights.
| return hidden_states | ||
|
|
||
|
|
||
| class Cosmos3AudioSnakeBeta(nn.Module): |
There was a problem hiding this comment.
It looks like the existing Snake1d module implements essentially the same logic as Cosmos3AudioSnakeBeta, could we use it as well for the encoder?
There was a problem hiding this comment.
The math should be the same, but we'd need a reshape on load, since Cosmos3AudioSnakeBeta has 1D parameters instead of 3D. Let me think about it for a bit.
There was a problem hiding this comment.
I kept the separate classes for native checkpoint loading, but shared a forward implementation
dg845
left a comment
There was a problem hiding this comment.
Thanks for the PR! Left an initial design review :).
Signed-off-by: Maciej Bala <mbala@nvidia.com>
What does this PR do?
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.