Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 115 additions & 27 deletions scripts/convert_cosmos3_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,49 @@
import re

import torch
from cosmos3.common.init import init_script


init_script()

from accelerate import init_empty_weights # noqa: E402
from cosmos3.args import _CHECKPOINTS # noqa: E402
from cosmos3.model import Cosmos3OmniModel # noqa: E402
from projects.cosmos3.vfm.models.omni_mot_model import OmniMoTModel # noqa: E402
from transformers import AutoTokenizer # noqa: E402

from diffusers import AutoencoderKLWan, UniPCMultistepScheduler # noqa: E402
from diffusers.models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer # noqa: E402
from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer # noqa: E402
from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline # noqa: E402


# Fallback values for the sound tokenizer configuration: `_build_sound_tokenizer`
# reads each constructor argument with `config.get(key, DEFAULT_SOUND_TOKENIZER_CONFIG[key])`,
# so an entry here is used whenever the loaded config omits that key.
DEFAULT_SOUND_TOKENIZER_CONFIG = {
    # Top-level model / audio keys.
    "model_type": "autoencoder_v2",
    "sampling_rate": 48000,
    "stereo": True,
    "use_wav_as_input": True,
    "normalize_volume": True,
    "hop_size": 1920,
    "input_channels": 1,
    # `enc_*` keys.
    "enc_type": "spec_convnext",
    "enc_dim": 192,
    "enc_intermediate_dim": 768,
    "enc_num_layers": 12,
    "enc_num_blocks": 2,
    "enc_n_fft": 64,
    "enc_hop_length": 16,
    "enc_latent_dim": 128,
    "enc_c_mults": [1, 2, 4],
    "enc_strides": [4, 5, 6],
    "enc_identity_init": False,
    "enc_use_snake": True,
    # `dec_*` keys (and `vocoder_input_dim`, which is listed among them).
    "dec_type": "oobleck",
    "vocoder_input_dim": 64,
    "dec_dim": 320,
    "dec_c_mults": [1, 2, 4, 8, 16],
    "dec_strides": [2, 4, 5, 6, 8],
    "dec_use_snake": True,
    "dec_final_tanh": False,
    "dec_out_channels": 2,
    "dec_anti_aliasing": False,
    "dec_use_nearest_upsample": False,
    "dec_use_tanh_at_final": False,
    # Bottleneck keys.
    "bottleneck_type": "vae",
    "bottleneck": {"type": "vae"},
    # Remaining keys without an `enc_` / `dec_` prefix.
    "activation": "snakebeta",
    "snake_logscale": True,
    "anti_aliasing": False,
    "use_cuda_kernel": False,
    "causal": False,
    "padding_mode": "zeros",
    "latent_mean": None,
    "latent_std": None,
}


Expand Down Expand Up @@ -114,8 +133,10 @@ def _sound_tokenizer_strip_per_key_prefixes(state_dict: dict[str, torch.Tensor])
return out


def _sound_tokenizer_filter_decoder(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
return {key: value for key, value in state_dict.items() if key.startswith("decoder.")}
def _sound_tokenizer_filter_supported_modules(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
return {
key: value for key, value in state_dict.items() if key.startswith("encoder.") or key.startswith("decoder.")
}


def _sound_tokenizer_infer_num_blocks(state_dict: dict[str, torch.Tensor]) -> int:
Expand Down Expand Up @@ -185,7 +206,7 @@ def _remap(key: str) -> str:
def _sound_tokenizer_reshape_snake_params(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
out: dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
if (key.endswith(".alpha") or key.endswith(".beta")) and value.ndim == 1:
if key.startswith("decoder.") and (key.endswith(".alpha") or key.endswith(".beta")) and value.ndim == 1:
value = value.unsqueeze(0).unsqueeze(-1).contiguous()
out[key] = value
return out
Expand All @@ -197,7 +218,11 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) ->
candidate_keys = [
key
for key in state_dict
if key.endswith(".weight") and any(f".{layer}." in key for layer in ("conv1", "conv2", "conv_t1"))
if key.endswith(".weight")
and (
any(f".{layer}." in key for layer in ("conv1", "conv2", "conv_t1"))
or re.fullmatch(r"encoder\.layers\.\d+\.weight", key)
)
]
for key in candidate_keys:
stem = key[: -len(".weight")]
Expand All @@ -216,8 +241,10 @@ def _sound_tokenizer_reapply_weight_norm(state_dict: dict[str, torch.Tensor]) ->
def _remap_avae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Convert a legacy AVAE state dict into the Cosmos3AVAEAudioTokenizer state dict."""
state_dict = _sound_tokenizer_strip_per_key_prefixes(state_dict)
state_dict = _sound_tokenizer_filter_decoder(state_dict)
state_dict = _sound_tokenizer_filter_supported_modules(state_dict)
if not state_dict:
raise RuntimeError("Sound tokenizer state dict has no `encoder.*` or `decoder.*` keys after prefix stripping.")
if not any(key.startswith("decoder.") for key in state_dict):
raise RuntimeError("Sound tokenizer state dict has no `decoder.*` keys after prefix stripping.")
state_dict = _sound_tokenizer_remap_flat_layout(state_dict)
state_dict = _sound_tokenizer_reshape_snake_params(state_dict)
Expand All @@ -230,20 +257,67 @@ def _remap_avae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, tor
def _build_sound_tokenizer(
checkpoint_path: pathlib.Path,
config_path: pathlib.Path | None,
) -> Cosmos3AVAEAudioTokenizer:
):
from diffusers.models.autoencoders.autoencoder_cosmos3_audio import Cosmos3AVAEAudioTokenizer

config = _load_sound_tokenizer_config(config_path, fallback_config_path=pathlib.Path())
print(f"Loading AVAE sound tokenizer weights from {checkpoint_path} …")
raw_state_dict = _load_sound_tokenizer_state_dict(checkpoint_path)
state_dict = _remap_avae_state_dict(raw_state_dict)
print(f" Remapped {len(raw_state_dict)} → {len(state_dict)} decoder keys.")
has_encoder = any(key.startswith("encoder.") for key in state_dict)
print(
f" Remapped {len(raw_state_dict)} → {len(state_dict)} tokenizer keys "
f"({'encoder+decoder' if has_encoder else 'decoder-only'})."
)

sound_tokenizer = Cosmos3AVAEAudioTokenizer(
model_type=config.get("model_type", DEFAULT_SOUND_TOKENIZER_CONFIG["model_type"]),
sampling_rate=config.get("sampling_rate", DEFAULT_SOUND_TOKENIZER_CONFIG["sampling_rate"]),
stereo=config.get("stereo", DEFAULT_SOUND_TOKENIZER_CONFIG["stereo"]),
use_wav_as_input=config.get("use_wav_as_input", DEFAULT_SOUND_TOKENIZER_CONFIG["use_wav_as_input"]),
normalize_volume=config.get("normalize_volume", DEFAULT_SOUND_TOKENIZER_CONFIG["normalize_volume"]),
hop_size=config.get("hop_size", DEFAULT_SOUND_TOKENIZER_CONFIG["hop_size"]),
input_channels=config.get("input_channels", DEFAULT_SOUND_TOKENIZER_CONFIG["input_channels"]),
enc_type=config.get("enc_type", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_type"]),
enc_dim=config.get("enc_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_dim"]),
enc_intermediate_dim=config.get(
"enc_intermediate_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_intermediate_dim"]
),
enc_num_layers=config.get("enc_num_layers", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_num_layers"]),
enc_num_blocks=config.get("enc_num_blocks", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_num_blocks"]),
enc_n_fft=config.get("enc_n_fft", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_n_fft"]),
enc_hop_length=config.get("enc_hop_length", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_hop_length"]),
enc_latent_dim=config.get("enc_latent_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_latent_dim"]),
enc_c_mults=tuple(config.get("enc_c_mults", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_c_mults"])),
enc_strides=tuple(config.get("enc_strides", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_strides"])),
enc_identity_init=config.get("enc_identity_init", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_identity_init"]),
enc_use_snake=config.get("enc_use_snake", DEFAULT_SOUND_TOKENIZER_CONFIG["enc_use_snake"]),
dec_type=config.get("dec_type", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_type"]),
vocoder_input_dim=config.get("vocoder_input_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["vocoder_input_dim"]),
dec_dim=config.get("dec_dim", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_dim"]),
dec_c_mults=tuple(config.get("dec_c_mults", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_c_mults"])),
dec_strides=tuple(config.get("dec_strides", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_strides"])),
dec_use_snake=config.get("dec_use_snake", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_snake"]),
dec_final_tanh=config.get("dec_final_tanh", False),
dec_out_channels=config.get("dec_out_channels", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_out_channels"]),
dec_anti_aliasing=config.get("dec_anti_aliasing", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_anti_aliasing"]),
dec_use_nearest_upsample=config.get(
"dec_use_nearest_upsample", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_nearest_upsample"]
),
dec_use_tanh_at_final=config.get(
"dec_use_tanh_at_final", DEFAULT_SOUND_TOKENIZER_CONFIG["dec_use_tanh_at_final"]
),
bottleneck_type=config.get("bottleneck_type", DEFAULT_SOUND_TOKENIZER_CONFIG["bottleneck_type"]),
bottleneck=config.get("bottleneck", DEFAULT_SOUND_TOKENIZER_CONFIG["bottleneck"]),
activation=config.get("activation", DEFAULT_SOUND_TOKENIZER_CONFIG["activation"]),
snake_logscale=config.get("snake_logscale", DEFAULT_SOUND_TOKENIZER_CONFIG["snake_logscale"]),
anti_aliasing=config.get("anti_aliasing", DEFAULT_SOUND_TOKENIZER_CONFIG["anti_aliasing"]),
use_cuda_kernel=config.get("use_cuda_kernel", DEFAULT_SOUND_TOKENIZER_CONFIG["use_cuda_kernel"]),
causal=config.get("causal", DEFAULT_SOUND_TOKENIZER_CONFIG["causal"]),
padding_mode=config.get("padding_mode", DEFAULT_SOUND_TOKENIZER_CONFIG["padding_mode"]),
latent_mean=config.get("latent_mean", DEFAULT_SOUND_TOKENIZER_CONFIG["latent_mean"]),
latent_std=config.get("latent_std", DEFAULT_SOUND_TOKENIZER_CONFIG["latent_std"]),
encoder_enabled=has_encoder,
)
load_result = sound_tokenizer.load_state_dict(state_dict, strict=True)
if load_result.missing_keys or load_result.unexpected_keys:
Expand All @@ -255,8 +329,8 @@ def _build_sound_tokenizer(


@contextlib.contextmanager
def _skip_source_sound_tokenizer_load():
original_set_up_tokenizers = OmniMoTModel.set_up_tokenizers
def _skip_source_sound_tokenizer_load(omni_mot_model_cls):
original_set_up_tokenizers = omni_mot_model_cls.set_up_tokenizers

def set_up_tokenizers_without_sound(self):
if not getattr(self.config, "sound_gen", False):
Expand All @@ -269,14 +343,28 @@ def set_up_tokenizers_without_sound(self):
finally:
self.config.sound_gen = sound_gen

OmniMoTModel.set_up_tokenizers = set_up_tokenizers_without_sound
omni_mot_model_cls.set_up_tokenizers = set_up_tokenizers_without_sound
try:
yield
finally:
OmniMoTModel.set_up_tokenizers = original_set_up_tokenizers
omni_mot_model_cls.set_up_tokenizers = original_set_up_tokenizers


def main():
from cosmos3.common.init import init_script

init_script()

from accelerate import init_empty_weights
from cosmos3.args import _CHECKPOINTS
from cosmos3.model import Cosmos3OmniModel
from projects.cosmos3.vfm.models.omni_mot_model import OmniMoTModel
from transformers import AutoTokenizer

from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
from diffusers.models.transformers.transformer_cosmos3 import Cosmos3OmniTransformer
from diffusers.pipelines.cosmos.pipeline_cosmos3_omni import Cosmos3OmniPipeline

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--checkpoint-path",
Expand Down Expand Up @@ -330,7 +418,7 @@ def main():

print("Instantiating model and loading weights from DCP checkpoint …")
print("Skipping source AVAE tokenizer instantiation during converter-only model load …")
with _skip_source_sound_tokenizer_load():
with _skip_source_sound_tokenizer_load(OmniMoTModel):
_tmp = Cosmos3OmniModel.from_pretrained_dcp(checkpoint_path).model

# Extract network components and architecture config from DCP model
Expand Down
Loading
Loading