diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 25b9f0ec2fbe..7b474781a525 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -549,6 +549,8 @@ title: Ideogram 4 - local: api/pipelines/pix2pix title: InstructPix2Pix + - local: api/pipelines/joyai_echo + title: JoyAI Echo - local: api/pipelines/joyimage_edit title: JoyImage Edit - local: api/pipelines/kandinsky diff --git a/docs/source/en/api/pipelines/joyai_echo.md b/docs/source/en/api/pipelines/joyai_echo.md new file mode 100644 index 000000000000..c47de4e7356d --- /dev/null +++ b/docs/source/en/api/pipelines/joyai_echo.md @@ -0,0 +1,43 @@ +# JoyAI-Echo + +JoyAI-Echo is a text-to-audio-video generation pipeline for multi-shot video stories. It builds on the LTX-2 component +layout and adds the JoyAI-Echo few-step DMD denoising schedule plus a paired audio-video memory bank for cross-shot +consistency. + +The pipeline accepts one prompt per shot. When a list of prompts is passed, generated video and audio latents from +earlier shots are kept as memory tokens for later shots. 
+ +```py +import torch +from diffusers import JoyAIEchoPipeline +from diffusers.utils import encode_video + +pipe = JoyAIEchoPipeline.from_pretrained("path/to/converted-joyai-echo", torch_dtype=torch.bfloat16) +pipe.enable_model_cpu_offload() + +output = pipe( + [ + "A cinematic opening shot of the protagonist entering a quiet train station.", + "The same protagonist speaks softly while the camera follows through the platform.", + ], + height=736, + width=1280, + num_frames=241, + frame_rate=25.0, +) + +for i, (frames, audio) in enumerate(zip(output.frames, output.audio)): + encode_video(frames[0], fps=25, audio=audio[0].float().cpu(), output_path=f"shot_{i:03d}.mp4") +``` + +## JoyAIEchoPipeline + +[[autodoc]] JoyAIEchoPipeline + +## JoyAIEchoPipelineOutput + +[[autodoc]] pipelines.joyai_echo.JoyAIEchoPipelineOutput + +## JoyAIEchoShotOutput + +[[autodoc]] pipelines.joyai_echo.JoyAIEchoShotOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f4eb50a709a..09e0940c245b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -272,6 +272,7 @@ "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Ideogram4Transformer2DModel", + "JoyAIEchoTransformer3DModel", "JoyImageEditTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", @@ -616,6 +617,10 @@ "IFPipeline", "IFSuperResolutionPipeline", "ImageTextPipelineOutput", + "JoyAIEchoOriginalCheckpointPipeline", + "JoyAIEchoPipeline", + "JoyAIEchoPipelineOutput", + "JoyAIEchoShotOutput", "JoyImageEditPipeline", "JoyImageEditPipelineOutput", "Kandinsky3Img2ImgPipeline", @@ -1126,6 +1131,7 @@ HunyuanVideoTransformer3DModel, I2VGenXLUNet, Ideogram4Transformer2DModel, + JoyAIEchoTransformer3DModel, JoyImageEditTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, @@ -1445,6 +1451,10 @@ IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, + JoyAIEchoOriginalCheckpointPipeline, + JoyAIEchoPipeline, + JoyAIEchoPipelineOutput, + JoyAIEchoShotOutput, 
JoyImageEditPipeline, JoyImageEditPipelineOutput, Kandinsky3Img2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d8342cb50fd5..289cef89d43f 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -119,6 +119,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_ideogram4"] = ["Ideogram4Transformer2DModel"] + _import_structure["transformers.transformer_joyai_echo"] = ["JoyAIEchoTransformer3DModel"] _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] @@ -250,6 +251,7 @@ HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, Ideogram4Transformer2DModel, + JoyAIEchoTransformer3DModel, JoyImageEditTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6ee8ca55de33..b9d376e175e2 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -40,6 +40,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_ideogram4 import Ideogram4Transformer2DModel + from .transformer_joyai_echo import JoyAIEchoTransformer3DModel from .transformer_joyimage import JoyImageEditTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer diff --git 
# Copyright 2025 JoyAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from ...utils import apply_lora_scale, logging
from .transformer_ltx2 import AudioVisualModelOutput, LTX2VideoTransformer3DModel


logger = logging.get_logger(__name__)


class JoyAIEchoTransformer3DModel(LTX2VideoTransformer3DModel):
    """
    JoyAI-Echo audiovisual transformer with memory mask support.

    Inherits all architecture and weights from LTX2VideoTransformer3DModel, adding support for
    paired audio-video memory attention masks (audio_self_attention_mask, a2v_cross_attention_mask,
    v2a_cross_attention_mask) that are required for multi-shot generation with memory.
    """

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        audio_encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        audio_timestep: torch.LongTensor | None = None,
        sigma: torch.Tensor | None = None,
        audio_sigma: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        audio_encoder_attention_mask: torch.Tensor | None = None,
        num_frames: int | None = None,
        height: int | None = None,
        width: int | None = None,
        fps: float = 24.0,
        audio_num_frames: int | None = None,
        video_coords: torch.Tensor | None = None,
        audio_coords: torch.Tensor | None = None,
        isolate_modalities: bool = False,
        spatio_temporal_guidance_blocks: list[int] | None = None,
        perturbation_mask: torch.Tensor | None = None,
        use_cross_timestep: bool = False,
        attention_kwargs: dict[str, Any] | None = None,
        video_self_attention_mask: torch.Tensor | None = None,
        audio_self_attention_mask: torch.Tensor | None = None,
        a2v_cross_attention_mask: torch.Tensor | None = None,
        v2a_cross_attention_mask: torch.Tensor | None = None,
        return_dict: bool = True,
    ) -> AudioVisualModelOutput | tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with memory mask support for multi-shot generation.

        Args:
            hidden_states (`torch.Tensor`):
                Video tokens; projected with `proj_in` before the transformer blocks.
            audio_hidden_states (`torch.Tensor`):
                Audio tokens; projected with `audio_proj_in` before the transformer blocks.
            encoder_hidden_states (`torch.Tensor`), audio_encoder_hidden_states (`torch.Tensor`):
                Prompt contexts for the video and audio streams. They are passed through
                `caption_projection` / `audio_caption_projection` when `config.use_prompt_embeddings` is set.
            timestep (`torch.Tensor`):
                Video timestep fed to `time_embed`.
            audio_timestep (`torch.Tensor`, *optional*):
                Audio timestep fed to `audio_time_embed`. Defaults to `timestep`.
            sigma (`torch.Tensor`, *optional*):
                Required when the model uses prompt modulation or when `use_cross_timestep=True`.
            audio_sigma (`torch.Tensor`, *optional*):
                Defaults to `sigma`.
            encoder_attention_mask, audio_encoder_attention_mask (`torch.Tensor`, *optional*):
                When 2-D, converted from a 0/1 keep mask into an additive bias (`0` for kept, `-10000` for masked)
                and given a broadcast dimension at position 1. Other ranks are forwarded unchanged.
            num_frames, height, width, fps, audio_num_frames:
                Only used to build default RoPE coordinates when `video_coords` / `audio_coords` are not given.
            video_coords, audio_coords (`torch.Tensor`, *optional*):
                Explicit positional coordinates; when provided, the size arguments above are ignored.
            isolate_modalities (`bool`):
                When `True`, both audio-to-video and video-to-audio cross attention are disabled in every block.
            spatio_temporal_guidance_blocks (`list[int]`, *optional*):
                Indices of blocks that receive `perturbation_mask`; other blocks receive `None`.
            perturbation_mask (`torch.Tensor`, *optional*):
                A 1-D mask is reshaped to `[B, 1, 1]`. When STG blocks are requested and no mask is given, an
                all-zero mask is used.
            use_cross_timestep (`bool`):
                When `True`, the video cross-attention modulation is driven by `audio_sigma` and the audio one by
                `sigma` instead of the per-modality timesteps.
            attention_kwargs (`dict`, *optional*):
                Not read in this method body; it is the argument named by the `apply_lora_scale` decorator.
            video_self_attention_mask (`torch.Tensor`, *optional*):
                0/1 mask for video self-attention; converted to an additive bias before being passed to the blocks.
            audio_self_attention_mask (`torch.Tensor`, *optional*):
                Multiplicative mask [B, T_a, T_a] for audio self-attention (0/1 float); converted to an additive
                bias here. Used to block attention between memory and target audio tokens.
            a2v_cross_attention_mask (`torch.Tensor`, *optional*):
                Bool mask [B, T_v, T_a] for audio-to-video cross attention.
                True = attend (per-slot pairing for paired memory). Forwarded to the blocks unchanged.
            v2a_cross_attention_mask (`torch.Tensor`, *optional*):
                Bool mask [B, T_a, T_v] for video-to-audio cross attention.
                True = attend (per-slot pairing for paired memory). Forwarded to the blocks unchanged.
            return_dict (`bool`):
                Whether to return an `AudioVisualModelOutput` instead of a plain tuple.

        Returns:
            `AudioVisualModelOutput` with `sample` and `audio_sample` when `return_dict=True`, otherwise the tuple
            `(output, audio_output)`.

        Raises:
            ValueError: If `sigma` is `None` although prompt modulation or `use_cross_timestep` needs it.
        """
        audio_timestep = audio_timestep if audio_timestep is not None else timestep
        audio_sigma = audio_sigma if audio_sigma is not None else sigma

        # `sigma` is optional in the signature but dereferenced below on these two paths; fail early with a
        # readable message instead of an AttributeError on `None`.
        if sigma is None and (self.prompt_modulation or use_cross_timestep):
            raise ValueError(
                "`sigma` must be provided when the model uses prompt modulation or when `use_cross_timestep=True`."
            )

        # Convert 0/1 keep masks into additive attention biases.
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
            audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
            audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

        if video_self_attention_mask is not None:
            video_self_attention_mask = (1 - video_self_attention_mask.to(hidden_states.dtype)) * -10000.0

        if audio_self_attention_mask is not None:
            audio_self_attention_mask = (1 - audio_self_attention_mask.to(audio_hidden_states.dtype)) * -10000.0

        batch_size = hidden_states.size(0)

        # 1. Prepare RoPE positional embeddings
        if video_coords is None:
            video_coords = self.rope.prepare_video_coords(
                batch_size, num_frames, height, width, hidden_states.device, fps=fps
            )
        if audio_coords is None:
            audio_coords = self.audio_rope.prepare_audio_coords(
                batch_size, audio_num_frames, audio_hidden_states.device
            )

        video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
        audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)

        # Cross-modal attention only uses the first coordinate axis of each modality.
        video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
        audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
            audio_coords[:, 0:1, :], device=audio_hidden_states.device
        )

        # 2. Patchify input projections
        hidden_states = self.proj_in(hidden_states)
        audio_hidden_states = self.audio_proj_in(audio_hidden_states)

        # 3. Prepare timestep embeddings and modulation parameters
        timestep_cross_attn_gate_scale_factor = (
            self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
        )

        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        temb_audio, audio_embedded_timestep = self.audio_time_embed(
            audio_timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
        temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
        audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))

        if self.prompt_modulation:
            temb_prompt, _ = self.prompt_adaln(
                sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )
            temb_prompt_audio, _ = self.audio_prompt_adaln(
                audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
            )
            temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
            temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
        else:
            temb_prompt = temb_prompt_audio = None

        # 3.2. Prepare global modality cross attention modulation parameters
        video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten()
        video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
            video_ca_timestep,
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
            video_ca_timestep * timestep_cross_attn_gate_scale_factor,
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
            batch_size, -1, video_cross_attn_scale_shift.shape[-1]
        )
        video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

        audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten()
        audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
            audio_ca_timestep,
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
        audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
            audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
        audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
            batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
        )
        audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])

        # 4. Prepare prompt embeddings
        if self.config.use_prompt_embeddings:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

            audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
            audio_encoder_hidden_states = audio_encoder_hidden_states.view(
                batch_size, -1, audio_hidden_states.size(-1)
            )

        # 5. Run transformer blocks
        spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or []
        if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None:
            # Allocate the default mask next to the activations so it never mixes devices with them.
            perturbation_mask = torch.zeros((batch_size,), device=hidden_states.device)
        if perturbation_mask is not None and perturbation_mask.ndim == 1:
            perturbation_mask = perturbation_mask[:, None, None]
        all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
        stg_blocks = set(spatio_temporal_guidance_blocks)

        for block_idx, block in enumerate(self.transformer_blocks):
            # Only the blocks selected for STG see the perturbation mask.
            block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None
            block_all_perturbed = all_perturbed if block_idx in stg_blocks else False

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional order must mirror the keyword call in the `else` branch below.
                hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    audio_hidden_states,
                    encoder_hidden_states,
                    audio_encoder_hidden_states,
                    temb,
                    temb_audio,
                    video_cross_attn_scale_shift,
                    audio_cross_attn_scale_shift,
                    video_cross_attn_a2v_gate,
                    audio_cross_attn_v2a_gate,
                    temb_prompt,
                    temb_prompt_audio,
                    video_rotary_emb,
                    audio_rotary_emb,
                    video_cross_attn_rotary_emb,
                    audio_cross_attn_rotary_emb,
                    encoder_attention_mask,
                    audio_encoder_attention_mask,
                    video_self_attention_mask,
                    audio_self_attention_mask,
                    a2v_cross_attention_mask,
                    v2a_cross_attention_mask,
                    not isolate_modalities,
                    not isolate_modalities,
                    block_perturbation_mask,
                    block_all_perturbed,
                )
            else:
                hidden_states, audio_hidden_states = block(
                    hidden_states=hidden_states,
                    audio_hidden_states=audio_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    audio_encoder_hidden_states=audio_encoder_hidden_states,
                    temb=temb,
                    temb_audio=temb_audio,
                    temb_ca_scale_shift=video_cross_attn_scale_shift,
                    temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
                    temb_ca_gate=video_cross_attn_a2v_gate,
                    temb_ca_audio_gate=audio_cross_attn_v2a_gate,
                    temb_prompt=temb_prompt,
                    temb_prompt_audio=temb_prompt_audio,
                    video_rotary_emb=video_rotary_emb,
                    audio_rotary_emb=audio_rotary_emb,
                    ca_video_rotary_emb=video_cross_attn_rotary_emb,
                    ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
                    encoder_attention_mask=encoder_attention_mask,
                    audio_encoder_attention_mask=audio_encoder_attention_mask,
                    self_attention_mask=video_self_attention_mask,
                    audio_self_attention_mask=audio_self_attention_mask,
                    a2v_cross_attention_mask=a2v_cross_attention_mask,
                    v2a_cross_attention_mask=v2a_cross_attention_mask,
                    use_a2v_cross_attention=not isolate_modalities,
                    use_v2a_cross_attention=not isolate_modalities,
                    perturbation_mask=block_perturbation_mask,
                    all_perturbed=block_all_perturbed,
                )

        # 6. Output layers
        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        output = self.proj_out(hidden_states)

        audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
        audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]

        audio_hidden_states = self.audio_norm_out(audio_hidden_states)
        audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
        audio_output = self.audio_proj_out(audio_hidden_states)

        if not return_dict:
            return (output, audio_output)
        return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
"JoyAIEchoPipeline", + "JoyAIEchoPipelineOutput", + "JoyAIEchoShotOutput", + ] _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] @@ -751,6 +757,12 @@ from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .ideogram4 import Ideogram4Pipeline, Ideogram4PromptEnhancerHead + from .joyai_echo import ( + JoyAIEchoOriginalCheckpointPipeline, + JoyAIEchoPipeline, + JoyAIEchoPipelineOutput, + JoyAIEchoShotOutput, + ) from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, diff --git a/src/diffusers/pipelines/joyai_echo/__init__.py b/src/diffusers/pipelines/joyai_echo/__init__.py new file mode 100644 index 000000000000..1fd37aaf70e4 --- /dev/null +++ b/src/diffusers/pipelines/joyai_echo/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_joyai_echo_original_checkpoint"] = ["JoyAIEchoOriginalCheckpointPipeline"] + _import_structure["pipeline_joyai_echo"] = ["JoyAIEchoMemoryBank", "JoyAIEchoMemorySlot", "JoyAIEchoPipeline"] + _import_structure["pipeline_output"] = ["JoyAIEchoPipelineOutput", "JoyAIEchoShotOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + 
try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyai_echo import JoyAIEchoMemoryBank, JoyAIEchoMemorySlot, JoyAIEchoPipeline + from .pipeline_joyai_echo_original_checkpoint import JoyAIEchoOriginalCheckpointPipeline + from .pipeline_output import JoyAIEchoPipelineOutput, JoyAIEchoShotOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo.py b/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo.py new file mode 100644 index 000000000000..02030319798e --- /dev/null +++ b/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo.py @@ -0,0 +1,634 @@ +# Copyright 2026 JD.com and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +import torch +from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..ltx2.connectors import LTX2TextConnectors +from ..ltx2.pipeline_ltx2 import LTX2Pipeline +from ..ltx2.vocoder import LTX2Vocoder, LTX2VocoderWithBWE +from .pipeline_output import JoyAIEchoPipelineOutput, JoyAIEchoShotOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class JoyAIEchoMemorySlot: + r""" + A paired audio-video memory slot used by [`JoyAIEchoPipeline`]. + + Args: + latents (`torch.Tensor`): + Packed video latent tokens of shape `(batch_size, sequence_length, channels)`. + video_coords (`torch.Tensor`): + Video positional coordinates of shape `(batch_size, 3, sequence_length, 2)`. + audio_latents (`torch.Tensor`, *optional*): + Packed audio latent tokens of shape `(batch_size, sequence_length, channels)`. + audio_coords (`torch.Tensor`, *optional*): + Audio positional coordinates of shape `(batch_size, 1, sequence_length, 2)`. + """ + + latents: torch.Tensor + video_coords: torch.Tensor + audio_latents: torch.Tensor | None = None + audio_coords: torch.Tensor | None = None + + +class JoyAIEchoMemoryBank: + r""" + FIFO paired audio-video memory bank for JoyAI-Echo multi-shot generation. + + The official JoyAI-Echo inference script stores selected frames and audio windows. 
In diffusers we keep the already + packed latent tokens, which avoids an additional VAE encode pass between shots and matches the in-context token + interface used by existing LTX-2 pipelines. + """ + + def __init__(self, max_size: int = 7): + self.max_size = int(max_size) + self.slots: list[JoyAIEchoMemorySlot] = [] + + def __len__(self) -> int: + return len(self.slots) + + def append(self, slot: JoyAIEchoMemorySlot) -> None: + if self.max_size <= 0: + return + self.slots.append(slot) + if len(self.slots) > self.max_size: + self.slots = self.slots[-self.max_size :] + + def get_video_memory(self, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor] | None: + if len(self.slots) == 0: + return None + latents = torch.cat([slot.latents.to(device=device, dtype=dtype) for slot in self.slots], dim=1) + coords = torch.cat([slot.video_coords.to(device=device) for slot in self.slots], dim=2) + return latents, coords + + def get_audio_memory(self, device: torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor] | None: + audio_slots = [slot for slot in self.slots if slot.audio_latents is not None and slot.audio_coords is not None] + if len(audio_slots) == 0: + return None + latents = torch.cat([slot.audio_latents.to(device=device, dtype=dtype) for slot in audio_slots], dim=1) + coords = torch.cat([slot.audio_coords.to(device=device) for slot in audio_slots], dim=2) + return latents, coords + + +def _as_prompt_list(prompt: str | list[str]) -> list[str]: + if isinstance(prompt, str): + return [prompt] + return prompt + + +def _select_memory_video_tokens( + latents: torch.Tensor, + video_coords: torch.Tensor, + latent_num_frames: int, + frame_index: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + tokens_per_frame = latents.shape[1] // int(latent_num_frames) + if tokens_per_frame <= 0: + raise ValueError("Cannot select JoyAI-Echo memory video tokens from an empty latent sequence.") + frame_index = int(latent_num_frames) // 2 
if frame_index is None else int(frame_index) + frame_index = max(0, min(frame_index, int(latent_num_frames) - 1)) + start = frame_index * tokens_per_frame + end = start + tokens_per_frame + return latents[:, start:end].contiguous(), video_coords[:, :, start:end].contiguous() + + +def _select_memory_audio_tokens( + audio_latents: torch.Tensor, + audio_coords: torch.Tensor, + window_size: int = 96, +) -> tuple[torch.Tensor, torch.Tensor]: + total_frames = audio_latents.shape[1] + window_len = min(int(total_frames), max(1, int(window_size))) + start = max((int(total_frames) - window_len) // 2, 0) + end = start + window_len + return audio_latents[:, start:end].contiguous(), audio_coords[:, :, start:end].contiguous() + + +class JoyAIEchoPipeline(LTX2Pipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for JoyAI-Echo text-to-audio-video multi-shot generation. + + JoyAI-Echo extends LTX-2 audio-video generation with few-step DMD denoising and a paired cross-shot memory bank. + This pipeline keeps the diffusers LTX-2 component layout (`transformer`, `vae`, `audio_vae`, `vocoder`, + `text_encoder`, `connectors`) and adds the JoyAI-Echo inference loop on top. + + Args: + scheduler: + Scheduler registered for compatibility with LTX-2 checkpoints. JoyAI-Echo's distilled inference uses the + explicit `denoising_sigmas` schedule passed to `__call__`. + vae ([`AutoencoderKLLTX2Video`]): + Video VAE used to decode video latents. + audio_vae ([`AutoencoderKLLTX2Audio`]): + Audio VAE used to decode audio latents to mel spectrograms. + text_encoder ([`Gemma3ForConditionalGeneration`]): + Gemma text encoder. + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + Gemma tokenizer. + connectors ([`LTX2TextConnectors`]): + Connector stack adapting Gemma hidden states to video and audio contexts. + transformer ([`LTX2VideoTransformer3DModel`]): + Distilled LTX-2 audio-video transformer. 
class JoyAIEchoPipeline(LTX2Pipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
    r"""
    Pipeline for JoyAI-Echo text-to-audio-video multi-shot generation.

    JoyAI-Echo extends LTX-2 audio-video generation with few-step DMD denoising and a paired cross-shot memory bank.
    This pipeline keeps the diffusers LTX-2 component layout (`transformer`, `vae`, `audio_vae`, `vocoder`,
    `text_encoder`, `connectors`) and adds the JoyAI-Echo inference loop on top.

    Args:
        scheduler:
            Scheduler registered for compatibility with LTX-2 checkpoints. JoyAI-Echo's distilled inference uses the
            explicit `denoising_sigmas` schedule passed to `__call__`.
        vae ([`AutoencoderKLLTX2Video`]):
            Video VAE used to decode video latents.
        audio_vae ([`AutoencoderKLLTX2Audio`]):
            Audio VAE used to decode audio latents to mel spectrograms.
        text_encoder ([`Gemma3ForConditionalGeneration`]):
            Gemma text encoder.
        tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
            Gemma tokenizer.
        connectors ([`LTX2TextConnectors`]):
            Connector stack adapting Gemma hidden states to video and audio contexts.
        transformer ([`LTX2VideoTransformer3DModel`]):
            Distilled LTX-2 audio-video transformer.
        vocoder ([`LTX2Vocoder`] or [`LTX2VocoderWithBWE`]):
            Vocoder used to convert generated mel spectrograms to waveform.
        processor ([`Gemma3Processor`], *optional*):
            Optional Gemma processor.
    """

    model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
    _optional_components = ["processor", "scheduler"]
    _callback_tensor_inputs = ["latents", "audio_latents", "prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKLLTX2Video,
        audio_vae: AutoencoderKLLTX2Audio,
        text_encoder: Gemma3ForConditionalGeneration,
        tokenizer: GemmaTokenizer | GemmaTokenizerFast,
        connectors: LTX2TextConnectors,
        transformer: LTX2VideoTransformer3DModel,
        vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
        processor: Gemma3Processor | None = None,
        scheduler=None,
    ):
        # All components are handed to the LTX-2 base pipeline unchanged; only the argument order differs
        # (the scheduler is last and optional here).
        super().__init__(
            scheduler=scheduler,
            vae=vae,
            audio_vae=audio_vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            connectors=connectors,
            transformer=transformer,
            vocoder=vocoder,
            processor=processor,
        )

    @staticmethod
    def _add_flow_noise(sample: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Return the interpolation `(1 - sigma) * sample + sigma * noise`.

        `sigma` gets trailing singleton dimensions appended until it has as many dimensions as `sample`, so it
        broadcasts from the leading dimensions.
        """
        while sigma.ndim < sample.ndim:
            sigma = sigma.unsqueeze(-1)
        return (1 - sigma) * sample + sigma * noise

    @staticmethod
    def _repeat_token_timestep(sigma: torch.Tensor, num_tokens: int) -> torch.Tensor:
        """Expand a scalar or 1-D `sigma` to a writable `(len(sigma), num_tokens)` tensor.

        A 0-dim input is first promoted to shape `(1,)`. Inputs with two or more dimensions are returned as-is
        (not copied, and `num_tokens` is not checked against them).
        """
        if sigma.ndim == 0:
            sigma = sigma[None]
        if sigma.ndim == 1:
            # `.clone()` materializes the expanded view so callers can assign into slices of it.
            return sigma[:, None].expand(-1, num_tokens).clone()
        return sigma

    @staticmethod
    def _build_video_memory_attention_mask(
        num_memory_tokens: int,
        num_target_tokens: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Build a 0/1 mask of shape `(batch_size, M + T, M + T)` for a `[memory | target]` token sequence.

        With `M = num_memory_tokens` and `T = num_target_tokens`, the result is 1 inside the memory-memory block
        and the target-target block and 0 in both memory/target cross blocks.
        """
        total_tokens = num_memory_tokens + num_target_tokens
        attention_mask = torch.ones(batch_size, total_tokens, total_tokens, device=device, dtype=dtype)
        # Zero every entry that touches a memory row or column ...
        attention_mask[:, :, :num_memory_tokens] = 0
        attention_mask[:, :num_memory_tokens, :] = 0
        # ... then re-enable the memory-memory block.
        attention_mask[:, :num_memory_tokens, :num_memory_tokens] = 1
        return attention_mask

    def _prepare_prompt_context(
        self,
        prompt: str,
        device: torch.device,
        max_sequence_length: int,
        prompt_embeds: torch.Tensor | None = None,
        prompt_attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode one prompt (or take precomputed embeddings) and run it through `self.connectors`.

        Args:
            prompt: Prompt text; only used when `prompt_embeds` is `None`.
            device: Device the embeddings and mask are moved to.
            max_sequence_length: Forwarded to `encode_prompt`.
            prompt_embeds: Precomputed embeddings. When given, `encode_prompt` is skipped.
            prompt_attention_mask: Mask for `prompt_embeds`. When embeddings are given without a mask, an all-ones
                mask over the first two dimensions of `prompt_embeds` is used.

        Returns:
            Whatever `self.connectors(...)` returns; annotated as a 3-tuple of tensors.
        """
        if prompt_embeds is None:
            # Classifier-free guidance is disabled, so only the positive branch of `encode_prompt` is kept.
            prompt_embeds, prompt_attention_mask, _, _ = self.encode_prompt(
                prompt=prompt,
                negative_prompt=None,
                do_classifier_free_guidance=False,
                num_videos_per_prompt=1,
                prompt_embeds=None,
                negative_prompt_embeds=None,
                prompt_attention_mask=prompt_attention_mask,
                negative_prompt_attention_mask=None,
                max_sequence_length=max_sequence_length,
                device=device,
            )
        elif prompt_attention_mask is None:
            prompt_attention_mask = torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)

        prompt_embeds = prompt_embeds.to(device=device)
        prompt_attention_mask = prompt_attention_mask.to(device=device)

        # Use the tokenizer's padding side when a tokenizer is present; fall back to "left" otherwise.
        tokenizer_padding_side = "left"
        if getattr(self, "tokenizer", None) is not None:
            tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")

        return self.connectors(prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side)

    def _get_execution_device(self) -> torch.device:
        """Return the device to run on.

        Uses `self._execution_device` when that attribute is available. Otherwise it returns the device of the
        first parameter (or, failing that, first buffer) of the first `torch.nn.Module` component that has one,
        and falls back to CPU when no component holds any tensor.
        """
        try:
            return self._execution_device
        except AttributeError:
            pass

        for component in self.components.values():
            if not isinstance(component, torch.nn.Module):
                continue
            # Each inner loop returns on its first element; it only falls through when the iterator is empty.
            for tensor in component.parameters(recurse=True):
                return tensor.device
            for tensor in component.buffers(recurse=True):
                return tensor.device
        return torch.device("cpu")

    def _decode_latents(
        self,
        latents: torch.Tensor,
        audio_latents: torch.Tensor,
        latent_num_frames: int,
        latent_height: int,
        latent_width: int,
        audio_num_frames: int,
        latent_mel_bins: int,
        output_type: str,
        decode_timestep: float = 0.0,
        decode_noise_scale: float | None = None,
        generator: torch.Generator | None = None,
    ) -> tuple[Any, Any]:
        """Unpack packed video/audio latents and decode them to video frames and a waveform.

        Args:
            latents: Packed video latents; unpacked with `_unpack_latents` using the given latent grid size.
            audio_latents: Packed, normalized audio latents; denormalized and unpacked before decoding.
            latent_num_frames, latent_height, latent_width: Latent grid size used for unpacking the video latents.
            audio_num_frames, latent_mel_bins: Sizes used for unpacking the audio latents.
            output_type: `"latent"` returns the denormalized, unpacked latents without decoding; any other value
                is forwarded to `video_processor.postprocess_video`.
            decode_timestep: Timestep passed to the video VAE when it is timestep-conditioned.
            decode_noise_scale: Noise mixing weight before decoding; defaults to `decode_timestep`.
            generator: RNG used for the decode noise.

        Returns:
            `(video, audio)`; for `output_type="latent"` these are the video and audio latents instead.
        """
        device = latents.device
        latents = self._unpack_latents(
            latents,
            latent_num_frames,
            latent_height,
            latent_width,
            self.transformer_spatial_patch_size,
            self.transformer_temporal_patch_size,
        )
        audio_latents = self._denormalize_audio_latents(
            audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
        )
        audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)

        if output_type == "latent":
            latents = self._denormalize_latents(
                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
            )
            return latents, audio_latents

        if not self.vae.config.timestep_conditioning:
            timestep = None
        else:
            # Timestep-conditioned VAE: blend fresh noise into the still-normalized latents before decoding.
            noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
            timestep = torch.tensor([decode_timestep], device=device, dtype=latents.dtype)
            if decode_noise_scale is None:
                decode_noise_scale = decode_timestep
            latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

        latents = self._denormalize_latents(
            latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
        )
        video = self.vae.decode(latents.to(self.vae.dtype), timestep, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)

        # Audio path: latents -> mel spectrograms -> waveform.
        mel_spectrograms = self.audio_vae.decode(audio_latents.to(self.audio_vae.dtype), return_dict=False)[0]
        audio = self.vocoder(mel_spectrograms)
        return video, audio

    def _denoise_shot(
        self,
        latents: torch.Tensor,
        audio_latents: torch.Tensor,
        denoising_sigmas: torch.Tensor,
        prompt_embeds: torch.Tensor,
        audio_prompt_embeds: torch.Tensor,
        prompt_attention_mask: torch.Tensor,
        video_coords: torch.Tensor,
        audio_coords: torch.Tensor,
        latent_num_frames: int,
        latent_height: int,
        latent_width: int,
        audio_num_frames: int,
        frame_rate: float,
        memory_bank: JoyAIEchoMemoryBank,
        transformer_outputs_x0: bool,
        generator: torch.Generator | None = None,
        attention_kwargs: dict[str, Any] | None = None,
        callback_on_step_end: Callable[[Any, int, torch.Tensor, dict], dict] | None = None,
        callback_on_step_end_tensor_inputs: list[str] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the few-step denoising loop for one shot, with memory tokens prepended to the model inputs.

        Each step walks one consecutive pair `(sigma, next_sigma)` of `denoising_sigmas`: the transformer is
        called on `[memory | current]` tokens, the memory prefix is cut from its outputs, the result is turned
        into a clean-sample estimate, and that estimate is re-noised to `next_sigma` (or kept as-is when
        `next_sigma` is 0).
        """
        batch_size = latents.shape[0]
        device = latents.device
        dtype = latents.dtype

        # Memory is read once, before the loop; token counts are 0 when the bank has nothing to offer.
        memory_video = memory_bank.get_video_memory(device=device, dtype=dtype)
        memory_audio = memory_bank.get_audio_memory(device=device, dtype=dtype)
        memory_video_tokens = 0 if memory_video is None else memory_video[0].shape[1]
        memory_audio_tokens = 0 if memory_audio is None else memory_audio[0].shape[1]

        with self.progress_bar(total=max(len(denoising_sigmas) - 1, 0)) as progress_bar:
            for step_idx, sigma in enumerate(denoising_sigmas[:-1]):
                sigma = sigma.to(device=device, dtype=torch.float32)
                next_sigma = denoising_sigmas[step_idx + 1].to(device=device, dtype=torch.float32)

                video_model_input = latents
                audio_model_input = audio_latents
                video_model_coords = video_coords
                audio_model_coords = audio_coords
                # No video self-attention mask is built here, so memory and target video tokens attend freely.
                video_attention_mask = None

                if memory_video is not None:
                    # Memory tokens go first along the sequence axis; coords are concatenated to match.
                    video_memory_latents, video_memory_coords = memory_video
                    video_model_input = torch.cat([video_memory_latents, latents], dim=1)
                    video_model_coords = torch.cat([video_memory_coords, video_coords], dim=2)

                if memory_audio is not None:
                    audio_memory_latents, audio_memory_coords = memory_audio
                    audio_model_input = torch.cat([audio_memory_latents, audio_latents], dim=1)
                    audio_model_coords = torch.cat([audio_memory_coords, audio_coords], dim=2)

                # Per-token timesteps: memory tokens are marked with timestep 0, the rest with `sigma`.
                video_timestep = self._repeat_token_timestep(sigma.expand(batch_size), video_model_input.shape[1])
                audio_timestep = self._repeat_token_timestep(sigma.expand(batch_size), audio_model_input.shape[1])
                if memory_video_tokens > 0:
                    video_timestep[:, :memory_video_tokens] = 0
                if memory_audio_tokens > 0:
                    audio_timestep[:, :memory_audio_tokens] = 0

                pred_video, pred_audio = self.transformer(
                    hidden_states=video_model_input.to(dtype=prompt_embeds.dtype),
                    audio_hidden_states=audio_model_input.to(dtype=prompt_embeds.dtype),
                    encoder_hidden_states=prompt_embeds,
                    audio_encoder_hidden_states=audio_prompt_embeds,
                    timestep=video_timestep,
                    audio_timestep=audio_timestep,
                    sigma=sigma.expand(batch_size),
                    audio_sigma=sigma.expand(batch_size),
                    encoder_attention_mask=prompt_attention_mask,
                    audio_encoder_attention_mask=prompt_attention_mask,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    fps=frame_rate,
                    audio_num_frames=audio_num_frames,
                    video_coords=video_model_coords,
                    audio_coords=audio_model_coords,
                    isolate_modalities=False,
                    spatio_temporal_guidance_blocks=None,
                    perturbation_mask=None,
                    use_cross_timestep=False,
                    attention_kwargs=attention_kwargs,
                    video_self_attention_mask=video_attention_mask,
                    return_dict=False,
                )
                # Drop the predictions for the memory prefix; only the current shot's tokens are updated.
                pred_video = pred_video[:, memory_video_tokens:].float()
                pred_audio = pred_audio[:, memory_audio_tokens:].float()

                if not transformer_outputs_x0:
                    # Convert the model output to a clean-sample estimate: x0 = x_t - sigma * prediction.
                    pred_video = latents.float() - pred_video * sigma
                    pred_audio = audio_latents.float() - pred_audio * sigma

                if next_sigma > 0:
                    # Re-noise the estimate to the next sigma level; video noise is drawn before audio noise.
                    video_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
                    audio_noise = randn_tensor(
                        audio_latents.shape, generator=generator, device=device, dtype=audio_latents.dtype
                    )
                    latents = self._add_flow_noise(pred_video, video_noise, next_sigma).to(dtype=dtype)
                    audio_latents = self._add_flow_noise(pred_audio, audio_noise, next_sigma).to(dtype=dtype)
                else:
                    latents = pred_video.to(dtype=dtype)
                    audio_latents = pred_audio.to(dtype=dtype)

                if callback_on_step_end is not None:
                    # Requested tensors are looked up by name among this method's local variables.
                    callback_kwargs = {}
                    for name in callback_on_step_end_tensor_inputs or []:
                        callback_kwargs[name] = locals()[name]
                    callback_outputs = callback_on_step_end(self, step_idx, sigma, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
audio_latents = callback_outputs.pop("audio_latents", audio_latents) + + progress_bar.update() + + return latents, audio_latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int = 736, + width: int = 1280, + num_frames: int = 241, + frame_rate: float = 25.0, + denoising_sigmas: list[float] | torch.Tensor | None = None, + memory_max_size: int = 7, + generator: torch.Generator | None = None, + prompt_embeds: list[torch.Tensor] | torch.Tensor | None = None, + prompt_attention_mask: list[torch.Tensor] | torch.Tensor | None = None, + output_type: str = "pil", + return_latents: bool = False, + return_dict: bool = True, + decode_timestep: float = 0.0, + decode_noise_scale: float | None = None, + transformer_outputs_x0: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[Any, int, torch.Tensor, dict], dict] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents", "audio_latents"], + max_sequence_length: int = 1024, + ) -> JoyAIEchoPipelineOutput | tuple: + r""" + Generates one or more JoyAI-Echo shots. + + Args: + prompt (`str` or `list[str]`): + One prompt per shot. Passing a list enables cross-shot memory conditioning. + height (`int`, *optional*, defaults to `736`): + Generated video height. + width (`int`, *optional*, defaults to `1280`): + Generated video width. + num_frames (`int`, *optional*, defaults to `241`): + Number of video frames per shot. + frame_rate (`float`, *optional*, defaults to `25.0`): + Video frame rate. + denoising_sigmas (`list[float]` or `torch.Tensor`, *optional*): + JoyAI-Echo DMD sigma schedule. Defaults to the official inference schedule. + memory_max_size (`int`, *optional*, defaults to `7`): + Maximum number of previous shots kept as paired audio-video memory. + transformer_outputs_x0 (`bool`, *optional*, defaults to `True`): + Whether the transformer directly predicts `x0`. Set to `False` for velocity-prediction LTX-2 + transformers. 
+ """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if denoising_sigmas is None: + denoising_sigmas = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] + + if isinstance(denoising_sigmas, torch.Tensor): + denoising_sigmas = denoising_sigmas.detach().float() + else: + denoising_sigmas = torch.tensor(denoising_sigmas, dtype=torch.float32) + if denoising_sigmas.ndim != 1 or denoising_sigmas.shape[0] < 2: + raise ValueError("`denoising_sigmas` must be a 1D sequence with at least two values.") + + prompts = _as_prompt_list(prompt) + device = self._get_execution_device() + batch_size = 1 + + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + num_channels_latents = self.transformer.config.in_channels + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + num_mel_bins = self.audio_vae.config.mel_bins + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = self.audio_vae.config.latent_channels + + video_coords = self.transformer.rope.prepare_video_coords( + batch_size, latent_num_frames, latent_height, latent_width, device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords(batch_size, audio_num_frames, device) + + memory_bank = JoyAIEchoMemoryBank(max_size=memory_max_size) + shots: list[JoyAIEchoShotOutput] = [] + + for shot_idx, shot_prompt in enumerate(prompts): + current_prompt_embeds = None + current_prompt_attention_mask = None + if isinstance(prompt_embeds, list): + current_prompt_embeds = prompt_embeds[shot_idx] 
+ elif isinstance(prompt_embeds, torch.Tensor): + current_prompt_embeds = prompt_embeds + if isinstance(prompt_attention_mask, list): + current_prompt_attention_mask = prompt_attention_mask[shot_idx] + elif isinstance(prompt_attention_mask, torch.Tensor): + current_prompt_attention_mask = prompt_attention_mask + + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = ( + self._prepare_prompt_context( + shot_prompt, + device=device, + max_sequence_length=max_sequence_length, + prompt_embeds=current_prompt_embeds, + prompt_attention_mask=current_prompt_attention_mask, + ) + ) + + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + noise_scale=0.0, + dtype=torch.float32, + device=device, + generator=generator, + ) + audio_latents = self.prepare_audio_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=0.0, + dtype=torch.float32, + device=device, + generator=generator, + ) + + latents, audio_latents = self._denoise_shot( + latents=latents, + audio_latents=audio_latents, + denoising_sigmas=denoising_sigmas.to(device), + prompt_embeds=connector_prompt_embeds, + audio_prompt_embeds=connector_audio_prompt_embeds, + prompt_attention_mask=connector_attention_mask, + video_coords=video_coords, + audio_coords=audio_coords, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + audio_num_frames=audio_num_frames, + frame_rate=frame_rate, + memory_bank=memory_bank, + transformer_outputs_x0=transformer_outputs_x0, + generator=generator, + attention_kwargs=attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + memory_video_latents, memory_video_coords = _select_memory_video_tokens( + 
latents.detach().cpu(), + video_coords.detach().cpu(), + latent_num_frames=latent_num_frames, + ) + memory_audio_latents, memory_audio_coords = _select_memory_audio_tokens( + audio_latents.detach().cpu(), + audio_coords.detach().cpu(), + window_size=96, + ) + memory_bank.append( + JoyAIEchoMemorySlot( + latents=memory_video_latents, + video_coords=memory_video_coords, + audio_latents=memory_audio_latents, + audio_coords=memory_audio_coords, + ) + ) + + frames, audio = self._decode_latents( + latents, + audio_latents, + latent_num_frames=latent_num_frames, + latent_height=latent_height, + latent_width=latent_width, + audio_num_frames=audio_num_frames, + latent_mel_bins=latent_mel_bins, + output_type=output_type, + decode_timestep=decode_timestep, + decode_noise_scale=decode_noise_scale, + generator=generator, + ) + shots.append( + JoyAIEchoShotOutput( + frames=frames, + audio=audio, + latents=latents if return_latents else None, + audio_latents=audio_latents if return_latents else None, + ) + ) + + self.maybe_free_model_hooks() + + frames = [shot.frames for shot in shots] + audio = [shot.audio for shot in shots] + if not return_dict: + return frames, audio + return JoyAIEchoPipelineOutput(frames=frames, audio=audio, shots=shots) diff --git a/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo_original_checkpoint.py b/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo_original_checkpoint.py new file mode 100644 index 000000000000..ea9f337d3023 --- /dev/null +++ b/src/diffusers/pipelines/joyai_echo/pipeline_joyai_echo_original_checkpoint.py @@ -0,0 +1,411 @@ +# Copyright 2026 JD.com and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); + +from __future__ import annotations + +import gc +import json +import sys +import time +from pathlib import Path +from typing import Any + +import torch + +from ..pipeline_utils import DiffusionPipeline + + +class JoyAIEchoOriginalCheckpointPipeline(DiffusionPipeline): + r""" + Diffusers pipeline wrapper for running the original JoyAI-Echo release checkpoint. + + This class provides a diffusers entrypoint for the released JoyAI-Echo safetensors checkpoint while preserving the + official inference math: Gemma prompt encoding is separated from generator loading, the distilled DMD sigma schedule + predicts `x0`, and paired audio-video memory is chained across shots. + + Args: + checkpoint_path (`str`): + Path to the original JoyAI-Echo `.safetensors` checkpoint. + gemma_path (`str`): + Path to the Gemma text encoder directory. + original_repo (`str`): + Path to a JoyAI-Echo checkout containing `ltx-core`, `ltx-pipelines`, and `ltx-distillation`. + device (`str`, defaults to `"cuda"`): + Device used for inference. + torch_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Model dtype. 
+ """ + + _optional_components = [] + + def __init__( + self, + checkpoint_path: str, + gemma_path: str, + original_repo: str, + device: str | torch.device = "cuda", + torch_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.checkpoint_path = str(Path(checkpoint_path).expanduser().resolve()) + self.gemma_path = str(Path(gemma_path).expanduser().resolve()) + self.original_repo = str(Path(original_repo).expanduser().resolve()) + self._joyai_echo_device = torch.device(device) + self.torch_dtype = torch_dtype + self.register_to_config( + checkpoint_path=self.checkpoint_path, + gemma_path=self.gemma_path, + original_repo=self.original_repo, + ) + + self._ensure_original_modules() + self.generator = None + self.video_vae = None + self.audio_vae = None + self.base_pipeline = None + self.memory_pipeline = None + self.audio_sample_rate = None + + @classmethod + def from_original_checkpoint( + cls, + checkpoint_path: str, + gemma_path: str, + original_repo: str, + device: str | torch.device = "cuda", + torch_dtype: torch.dtype = torch.bfloat16, + ) -> "JoyAIEchoOriginalCheckpointPipeline": + return cls( + checkpoint_path=checkpoint_path, + gemma_path=gemma_path, + original_repo=original_repo, + device=device, + torch_dtype=torch_dtype, + ) + + def _ensure_original_modules(self) -> None: + repo = Path(self.original_repo) + for subpath in ["ltx-core/src", "ltx-pipelines/src", "ltx-distillation/src"]: + path = str(repo / subpath) + if path not in sys.path: + sys.path.insert(0, path) + + @staticmethod + def _empty_cuda(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.empty_cache() + + @staticmethod + def _move(module, target_device) -> None: + if module is not None: + module.to(target_device) + + def encode_prompts(self, prompts: list[str]) -> list[dict[str, Any]]: + self._ensure_original_modules() + from ltx_distillation.models.text_encoder_wrapper import create_text_encoder_wrapper + + text_encoder = create_text_encoder_wrapper( + 
checkpoint_path=self.checkpoint_path, + gemma_path=self.gemma_path, + device=self._joyai_echo_device, + dtype=self.torch_dtype, + ) + text_encoder.eval() + + cached = [] + for prompt in prompts: + cond = text_encoder([prompt]) + cached.append({k: (v.detach().cpu() if isinstance(v, torch.Tensor) else v) for k, v in cond.items()}) + del cond + + del text_encoder + gc.collect() + self._empty_cuda(self._joyai_echo_device) + return cached + + def load_generator( + self, + denoising_sigmas: list[float] | torch.Tensor, + video_height: int, + video_width: int, + memory_downscale_factor: int = 1, + ) -> None: + self._ensure_original_modules() + from ltx_distillation.inference.bidirectional_pipeline import BidirectionalAVInferencePipeline + from ltx_distillation.inference.memory_bidirectional_pipeline import BidirectionalMemoryAVInferencePipeline + from ltx_distillation.models.ltx_wrapper import create_ltx2_wrapper + from ltx_distillation.models.vae_wrapper import create_vae_wrappers + from ltx_distillation.utils import add_noise + + self.generator = create_ltx2_wrapper( + checkpoint_path=self.checkpoint_path, + gemma_path=self.gemma_path, + device=self._joyai_echo_device, + dtype=self.torch_dtype, + video_height=int(video_height), + video_width=int(video_width), + loras=(), + ) + self.generator.eval() + + self.video_vae, self.audio_vae = create_vae_wrappers( + checkpoint_path=self.checkpoint_path, + device=torch.device("cpu"), + dtype=self.torch_dtype, + with_video_encoder=True, + with_audio_encoder=True, + decoder_device=torch.device("cpu"), + ) + self.video_vae.eval() + self.audio_vae.eval() + + denoising_sigmas = torch.as_tensor(denoising_sigmas, device=self._joyai_echo_device, dtype=torch.float32) + self.base_pipeline = BidirectionalAVInferencePipeline( + generator=self.generator, + add_noise_fn=add_noise, + denoising_sigmas=denoising_sigmas, + ) + self.memory_pipeline = BidirectionalMemoryAVInferencePipeline( + generator=self.generator, + add_noise_fn=add_noise, + 
denoising_sigmas=denoising_sigmas, + memory_downscale_factor=int(memory_downscale_factor), + ) + self.audio_sample_rate = self.audio_vae.get_output_sample_rate() or 24000 + + def _stage_for_denoise(self) -> None: + self._move(self.video_vae.encoder, "cpu") + self._move(self.video_vae.decoder, "cpu") + self._move(self.audio_vae.encoder, "cpu") + self._move(self.audio_vae.decoder, "cpu") + self._move(self.audio_vae.vocoder, "cpu") + self._move(self.generator, self._joyai_echo_device) + self._empty_cuda(self._joyai_echo_device) + + def _stage_for_video_encode(self) -> None: + self._move(self.video_vae.encoder, self._joyai_echo_device) + + def _stage_after_video_encode(self) -> None: + self._move(self.video_vae.encoder, "cpu") + self._empty_cuda(self._joyai_echo_device) + + def _stage_for_decode(self) -> None: + self._move(self.generator, "cpu") + self._empty_cuda(self._joyai_echo_device) + self._move(self.video_vae.decoder, self._joyai_echo_device) + self._move(self.audio_vae.decoder, self._joyai_echo_device) + self._move(self.audio_vae.vocoder, self._joyai_echo_device) + + @torch.no_grad() + def __call__( + self, + prompts: list[str], + output_dir: str | Path, + cached_conds: list[dict[str, Any]] | None = None, + num_frames: int = 241, + height: int = 736, + width: int = 1280, + fps: int = 25, + seed: int = 12345, + memory_max_size: int = 7, + num_fix_frames: int = 3, + save_mode: str = "random_every_shot_frame", + enable_audio_memory: bool = True, + v2a_grad_scale: float = 2.0, + memory_position_mode: str = "reference", + audio_memory_window_size: int = 96, + audio_memory_window_selection_mode: str = "max_response", + video_memory_frame_selection_mode: str = "center", + video_memory_clip_num_frames: int = 9, + audio_memory_sample_rate: int = 16000, + audio_memory_mel_bins: int = 128, + audio_memory_mel_hop_length: int = 160, + audio_memory_n_fft: int = 1024, + audio_memory_downsample_factor: int = 4, + audio_memory_is_causal: bool = True, + ) -> dict[str, Any]: + 
self._ensure_original_modules() + import torchaudio + from ltx_distillation.inference.memory_multishot import ( + PairedAudioVideoMemoryBank, + audio_waveform_stats, + build_paired_audio_memory_kwargs, + video_uint8_to_pil_frames, + ) + from ltx_distillation.utils import ( + compute_latent_shapes, + concat_shot_audios, + concat_shot_videos, + decode_benchmark_sample, + encode_memory_frames_batch, + save_memory_bank_frames, + write_benchmark_media, + ) + + if self.generator is None: + raise RuntimeError("Call `load_generator(...)` before running the pipeline.") + if cached_conds is None: + cached_conds = self.encode_prompts(prompts) + if len(cached_conds) != len(prompts): + raise ValueError("`cached_conds` length must match `prompts` length.") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + video_shape, audio_shape = compute_latent_shapes( + num_frames=int(num_frames), + video_height=int(height), + video_width=int(width), + batch_size=1, + video_fps=float(fps), + ) + memory_bank = PairedAudioVideoMemoryBank( + max_size=int(memory_max_size), + save_mode=str(save_mode), + num_fix_frames=int(num_fix_frames), + ) + + shot_paths: list[Path] = [] + shot_audios: list[torch.Tensor] = [] + metadata: dict[str, Any] = { + "checkpoint": self.checkpoint_path, + "gemma_path": self.gemma_path, + "num_prompts": len(prompts), + "shots": [], + } + + run_started = time.perf_counter() + for shot_idx, prompt in enumerate(prompts): + conditional_dict = { + k: (v.to(self._joyai_echo_device) if isinstance(v, torch.Tensor) else v) + for k, v in cached_conds[shot_idx].items() + } + prompt_seed = int(seed) + shot_idx + memory_size_before = len(memory_bank) + memory_video = None + memory_audio_kwargs: dict[str, Any] = {} + + self._stage_for_denoise() + with torch.random.fork_rng(devices=[self._joyai_echo_device]): + torch.manual_seed(prompt_seed) + if self._joyai_echo_device.type == "cuda": + torch.cuda.manual_seed(prompt_seed) + + if len(memory_bank) > 0: 
+ self._stage_for_video_encode() + memory_video = encode_memory_frames_batch( + video_vae=self.video_vae, + batch_memory_frames=[memory_bank.get_memory_frames()], + target_h=int(height), + target_w=int(width), + device=self._joyai_echo_device, + dtype=self.torch_dtype, + ) + self._stage_after_video_encode() + + memory_audio_kwargs = build_paired_audio_memory_kwargs( + memory_bank, + enable_audio_memory=bool(enable_audio_memory), + v2a_grad_scale=float(v2a_grad_scale), + memory_position_mode=str(memory_position_mode), + ) + video_latent, audio_latent = self.memory_pipeline.generate( + video_shape=tuple(video_shape), + audio_shape=tuple(audio_shape), + conditional_dict=conditional_dict, + memory_video=memory_video, + seed=prompt_seed, + **memory_audio_kwargs, + ) + else: + video_latent, audio_latent = self.base_pipeline.generate( + video_shape=tuple(video_shape), + audio_shape=tuple(audio_shape), + conditional_dict=conditional_dict, + seed=prompt_seed, + ) + + del conditional_dict, memory_video, memory_audio_kwargs + + self._stage_for_decode() + audio_memory_latent = ( + audio_latent.detach().cpu().contiguous() + if (enable_audio_memory and audio_latent is not None) + else None + ) + video_uint8, audio_waveform = decode_benchmark_sample( + self.video_vae, self.audio_vae, video_latent, audio_latent + ) + memory_frames_for_bank = video_uint8_to_pil_frames(video_uint8) + + new_memory_metadata: dict[str, Any] = {} + if audio_memory_latent is not None: + new_memory_metadata = memory_bank.save_memory_slot( + memory_frames_for_bank, + audio_memory_latent, + audio_window_size=int(audio_memory_window_size), + video_clip_num_frames=int(video_memory_clip_num_frames), + audio_waveform=audio_waveform, + audio_sample_rate=int(audio_memory_sample_rate), + video_fps=float(fps), + audio_window_selection_mode=str(audio_memory_window_selection_mode), + video_frame_selection_mode=str(video_memory_frame_selection_mode), + audio_memory_mel_bins=int(audio_memory_mel_bins), + 
audio_memory_mel_hop_length=int(audio_memory_mel_hop_length), + audio_memory_n_fft=int(audio_memory_n_fft), + audio_memory_downsample_factor=int(audio_memory_downsample_factor), + audio_memory_is_causal=bool(audio_memory_is_causal), + ) + + save_memory_bank_frames( + memory_bank.get_memory_frames(), output_dir / "memory_bank" / f"shot_{shot_idx:03d}" + ) + + shot_path = output_dir / f"shot_{shot_idx:03d}.mp4" + write_result = write_benchmark_media( + output_path=shot_path, + video_uint8=video_uint8, + audio_waveform=audio_waveform, + fps=int(fps), + audio_sr=int(self.audio_sample_rate), + ) + shot_paths.append(shot_path) + if audio_waveform is not None: + shot_audios.append(audio_waveform.cpu()) + + metadata["shots"].append( + { + "shot_idx": int(shot_idx), + "prompt": prompt, + "output_path": str(shot_path), + "memory_size_before": int(memory_size_before), + "memory_size_after": int(len(memory_bank)), + "new_memory_entry": new_memory_metadata, + "audio_latent_shape": list(audio_latent.shape) if audio_latent is not None else None, + "wrote_audio_in_mp4": bool(write_result["wrote_audio_in_mp4"]), + "wrote_sidecar_wav": bool(write_result["wrote_sidecar_wav"]), + "audio_stats": write_result["audio_stats"], + "memory_entries": memory_bank.get_memory_metadata(), + } + ) + + del video_latent, audio_latent, video_uint8, audio_waveform, audio_memory_latent, memory_frames_for_bank + self._empty_cuda(self._joyai_echo_device) + + combined_path = output_dir / "combined_shots.mp4" + concat_shot_videos(shot_paths, combined_path) + combined_audio = concat_shot_audios(shot_audios) + combined_audio_path = None + if combined_audio is not None: + combined_audio_path = output_dir / "combined_shots.wav" + torchaudio.save(str(combined_audio_path), combined_audio, sample_rate=int(self.audio_sample_rate)) + + metadata["combined_path"] = str(combined_path) + metadata["combined_audio_path"] = str(combined_audio_path) if combined_audio_path else None + metadata["combined_audio_stats"] = 
audio_waveform_stats(combined_audio) + metadata["run_total_sec"] = round(time.perf_counter() - run_started, 3) + (output_dir / "run_metadata.json").write_text( + json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8" + ) + return metadata diff --git a/src/diffusers/pipelines/joyai_echo/pipeline_output.py b/src/diffusers/pipelines/joyai_echo/pipeline_output.py new file mode 100644 index 000000000000..4817178d0b69 --- /dev/null +++ b/src/diffusers/pipelines/joyai_echo/pipeline_output.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import torch + +from ...utils import BaseOutput + + +@dataclass +class JoyAIEchoShotOutput(BaseOutput): + r""" + Output class for one JoyAI-Echo shot. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + Generated video frames for the shot. + audio (`torch.Tensor` or `np.ndarray`): + Generated waveform for the shot. + latents (`torch.Tensor`, *optional*): + Generated packed video latents before decoding. + audio_latents (`torch.Tensor`, *optional*): + Generated packed audio latents before decoding. + """ + + frames: torch.Tensor + audio: torch.Tensor + latents: torch.Tensor | None = None + audio_latents: torch.Tensor | None = None + + +@dataclass +class JoyAIEchoPipelineOutput(BaseOutput): + r""" + Output class for JoyAI-Echo multi-shot generation. + + Args: + frames (`list`): + Generated video frames for each shot. + audio (`list`): + Generated waveform for each shot. + shots (`list[JoyAIEchoShotOutput]`): + Per-shot structured outputs. 
+ """ + + frames: list + audio: list + shots: list[JoyAIEchoShotOutput] diff --git a/tests/pipelines/joyai_echo/__init__.py b/tests/pipelines/joyai_echo/__init__.py new file mode 100644 index 000000000000..8ea577c73919 --- /dev/null +++ b/tests/pipelines/joyai_echo/__init__.py @@ -0,0 +1,4 @@ +from .test_joyai_echo import JoyAIEchoPipelineFastTests + + +__all__ = ["JoyAIEchoPipelineFastTests"] diff --git a/tests/pipelines/joyai_echo/test_joyai_echo.py b/tests/pipelines/joyai_echo/test_joyai_echo.py new file mode 100644 index 000000000000..18e8f0c7e15c --- /dev/null +++ b/tests/pipelines/joyai_echo/test_joyai_echo.py @@ -0,0 +1,170 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from types import SimpleNamespace + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler, JoyAIEchoPipeline + + +class DummyCoords: + def __init__(self, dims): + self.dims = dims + + def prepare_video_coords(self, batch_size, num_frames, height, width, device, fps=24.0): + return torch.zeros(batch_size, 3, num_frames * height * width, 2, device=device) + + def prepare_audio_coords(self, batch_size, num_frames, device): + return torch.zeros(batch_size, 1, num_frames, 2, device=device) + + +class DummyTransformer(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(in_channels=3, patch_size=1, patch_size_t=1) + self.rope = DummyCoords(3) + self.audio_rope = DummyCoords(1) + self.seen_video_tokens = [] + self.seen_audio_tokens = [] + + @property + def dtype(self): + return torch.float32 + + def forward(self, hidden_states, audio_hidden_states, **kwargs): + self.seen_video_tokens.append(hidden_states.shape[1]) + self.seen_audio_tokens.append(audio_hidden_states.shape[1]) + return hidden_states, audio_hidden_states + + +class DummyConnectors(torch.nn.Module): + @property + def dtype(self): + return torch.float32 + + def forward(self, prompt_embeds, prompt_attention_mask, padding_side="left"): + return prompt_embeds, prompt_embeds, prompt_attention_mask + + +class DummyVideoVAE(torch.nn.Module): + spatial_compression_ratio = 1 + temporal_compression_ratio = 1 + + def __init__(self): + super().__init__() + self.config = SimpleNamespace(scaling_factor=1.0, timestep_conditioning=False) + self.latents_mean = torch.zeros(3) + self.latents_std = torch.ones(3) + + @property + def dtype(self): + return torch.float32 + + def decode(self, latents, timestep=None, return_dict=False): + return (latents,) + + +class DummyAudioVAE(torch.nn.Module): + mel_compression_ratio = 1 + temporal_compression_ratio = 1 + + def __init__(self): + super().__init__() + self.config = SimpleNamespace( + mel_bins=2, + 
latent_channels=1, + output_channels=1, + sample_rate=16000, + mel_hop_length=160, + ) + self.latents_mean = torch.zeros(2) + self.latents_std = torch.ones(2) + + @property + def dtype(self): + return torch.float32 + + def decode(self, latents, return_dict=False): + return (latents,) + + +class DummyVocoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(out_channels=1, output_sampling_rate=16000) + + @property + def dtype(self): + return torch.float32 + + def forward(self, mel_spectrograms): + return mel_spectrograms.flatten(2) + + +class JoyAIEchoPipelineFastTests(unittest.TestCase): + def get_dummy_components(self): + return { + "scheduler": FlowMatchEulerDiscreteScheduler(), + "vae": DummyVideoVAE(), + "audio_vae": DummyAudioVAE(), + "text_encoder": None, + "tokenizer": None, + "connectors": DummyConnectors(), + "transformer": DummyTransformer(), + "vocoder": DummyVocoder(), + "processor": None, + } + + def test_multishot_memory_prefix(self): + components = self.get_dummy_components() + pipe = JoyAIEchoPipeline(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt_embeds = [ + torch.zeros(1, 2, 4), + torch.zeros(1, 2, 4), + ] + prompt_attention_mask = [ + torch.ones(1, 2, dtype=torch.long), + torch.ones(1, 2, dtype=torch.long), + ] + + output = pipe( + ["first shot", "second shot"], + height=4, + width=4, + num_frames=1, + frame_rate=100.0, + denoising_sigmas=[1.0, 0.0], + memory_max_size=1, + generator=generator, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + output_type="pt", + return_latents=True, + ) + + self.assertEqual(len(output.shots), 2) + self.assertEqual(output.frames[0].shape, (1, 1, 3, 4, 4)) + self.assertEqual(output.audio[0].shape, (1, 1, 2)) + self.assertEqual(components["transformer"].seen_video_tokens, [16, 32]) + self.assertEqual(components["transformer"].seen_audio_tokens, 
[1, 2]) + + +if __name__ == "__main__": + unittest.main()