Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions configs/platforms/ascend_npu/flux2_dev.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"model_cls": "flux2_dev",
"task": "t2i",
"infer_steps": 50,
"aspect_ratio": "16:9",
"sample_guide_scale": 4.0,
"vae_scale_factor": 16,
"feature_caching": "None",
"enable_cfg": false,
"patch_size": 2,
"tokenizer_max_length": 512,
"text_encoder_out_layers": [10, 20, 30],
"attn_type": "npu_flash_attn",
"cpu_offload": true,
"offload_granularity": "block",
"modulate_type": "torch",
"rope_type": "torch",
"layer_norm_type": "torch",
"rms_norm_type": "torch"
}
16 changes: 16 additions & 0 deletions configs/platforms/ascend_npu/hunyuan_video_t2v_480p.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"infer_steps": 50,
"transformer_model_name": "480p_i2v",
"fps": 24,
"target_video_length": 121,
"vae_stride": [4, 16, 16],
"sample_shift": 5.0,
"sample_guide_scale": 6.0,
"aspect_ratio": "16:9",
"enable_cfg": true,
"attn_type": "npu_flash_attn",
"modulate_type": "torch",
"rope_type": "torch",
"layer_norm_type": "torch",
"rms_norm_type": "torch"
}
46 changes: 46 additions & 0 deletions configs/platforms/ascend_npu/ltx2_3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"infer_steps": 30,
"target_video_length": 121,
"target_height": 512,
"target_width": 768,
"attn_type": "npu_flash_attn",
"sample_guide_scale": 3.0,
"sample_shift": [2.05, 0.95],
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"norm_modulate_backend": "torch",
"modulate_type": "torch",
"rope_type": "torch",
"layer_norm_type": "torch",
"rms_norm_type": "torch",
"num_channels_latents": 128,
"fps": 24,
"audio_fps": 24000,
"audio_mel_bins":16,
"double_precision_rope": false,
"use_tiling_vae": false,
"dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.

Suggested change
"dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",
"dit_original_ckpt": "checkpoints/ltx-2.3-22b-dev.safetensors",

"caption_proj_before_connector": true,
"cross_attention_adaln": true,
"apply_gated_attention": true,
"mm_guider": {
"enabled": true,
"video": {
"cfg_scale": 3.0,
"stg_scale": 1.0,
"stg_blocks": [28],
"modality_scale": 3.0,
"rescale_scale": 0.7,
"skip_step": 0
},
"audio": {
"cfg_scale": 7.0,
"stg_scale": 1.0,
"stg_blocks": [28],
"modality_scale": 3.0,
"rescale_scale": 0.7,
"skip_step": 0
}
}
}
2 changes: 1 addition & 1 deletion configs/platforms/ascend_npu/qwen_image_t2i_2512.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"aspect_ratio": "16:9",
"prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 34,
"attn_type": "flash_attn3",
"attn_type": "npu_flash_attn",
"enable_cfg": true,
"sample_guide_scale": 4.0,
"cpu_offload": true,
Expand Down
27 changes: 27 additions & 0 deletions configs/platforms/ascend_npu/wan_t2v_sf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"infer_steps": 4,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "npu_flash_attn",
"cross_attn_1_type": "npu_flash_attn",
"cross_attn_2_type": "npu_flash_attn",
"sample_guide_scale": 1,
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.

Suggested change
"dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",
"dit_original_ckpt": "checkpoints/self_forcing_dmd.pt",

"causal_rope_type": "torch",
"modulate_type": "torch",
"rope_type": "torch",
"layer_norm_type": "torch",
"rms_norm_type": "torch",
"ar_config": {
"local_attn_size": -1,
"num_frame_per_chunk": 3,
"timesteps_index": [0, 179, 358, 679],
"kv_offload": false,
"async_vae_decode": false
}
}
63 changes: 60 additions & 3 deletions lightx2v/models/networks/hunyuan_video/infer/pre_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from torch.nn import functional as F

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, flash_attn_varlen_qkvpacked_func, pad_input, sage_attn_no_pad_v2, unpad_input
from .module_io import HunyuanVideo15InferModuleOutput
from .posemb_layers import get_nd_rotary_pos_embed

Expand All @@ -21,6 +22,15 @@
TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False


_TOKEN_REFINER_ATTN_TYPE_BY_PLATFORM = {
"ascend_npu": "npu_flash_attn",
}


def _get_token_refiner_attn_type():
return _TOKEN_REFINER_ATTN_TYPE_BY_PLATFORM.get(PLATFORM, "flash_attn2")


def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate

Expand All @@ -40,6 +50,46 @@ def apply_gate(x, gate=None, tanh=False):
return x * gate.unsqueeze(1)


def _torch_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, causal: bool = False) -> torch.Tensor:
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None:
attn_mask = attn_mask[:, None, None, :].expand(-1, 1, q.shape[-2], -1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal)
return out.transpose(1, 2)


def _npu_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch, seqlen, heads, dim = q.shape
if attn_mask is None:
lengths = torch.full((batch,), seqlen, dtype=torch.int32, device=q.device)
indices = torch.arange(batch * seqlen, device=q.device)
else:
attn_mask = attn_mask.bool()
lengths = attn_mask.sum(dim=1, dtype=torch.int32)
indices = attn_mask.reshape(-1).nonzero(as_tuple=False).squeeze(1)

cu_seqlens = torch.zeros(batch + 1, dtype=torch.int32, device=q.device)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
q_unpad = q.reshape(batch * seqlen, heads, dim)[indices]
k_unpad = k.reshape(batch * seqlen, heads, dim)[indices]
v_unpad = v.reshape(batch * seqlen, heads, dim)[indices]

out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad,
k=k_unpad,
v=v_unpad,
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=int(lengths.max().item()),
max_seqlen_kv=int(lengths.max().item()),
)
Comment on lines +79 to +87

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The NpuFlashAttnWeight.apply implementation expects a 4D tensor when batch > 1 or when padding is present, and it reshapes the output using bs * max_seqlen_q. When q_unpad is passed as a 3D tensor, apply sets bs = 1, which causes a RuntimeError due to shape mismatch (total_unpadded_tokens vs max_seqlen_q) during the final reshape.

To resolve this, we can unsqueeze the unpadded tensors to 4D of shape (total_unpadded_tokens, 1, heads, dim) and set max_seqlen_q=1. This correctly restores the 3D shape inside apply while ensuring the output is reshaped to (total_unpadded_tokens, heads * dim) without any shape mismatch.

Suggested change
out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad,
k=k_unpad,
v=v_unpad,
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=int(lengths.max().item()),
max_seqlen_kv=int(lengths.max().item()),
)
out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad.unsqueeze(1),
k=k_unpad.unsqueeze(1),
v=v_unpad.unsqueeze(1),
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=1,
max_seqlen_kv=1,
)

out = torch.zeros(batch * seqlen, heads * dim, dtype=out_unpad.dtype, device=out_unpad.device)
out[indices] = out_unpad
return out.reshape(batch, seqlen, heads, dim)


@torch.compiler.disable
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float = 0.0, attn_mask: Optional[torch.Tensor] = None, causal: bool = False, attn_type: str = "flash_attn2"
Expand All @@ -62,11 +112,18 @@ def attention(
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.bool()
if attn_type == "flash_attn2":
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
if flash_attn_varlen_qkvpacked_func is None or pad_input is None or unpad_input is None:
x = _torch_attention(q, k, v, attn_mask=attn_mask, causal=causal)
else:
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "flash_attn3":
x = flash_attn_no_pad_v3(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "sage_attn2":
x = sage_attn_no_pad_v2(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "npu_flash_attn":
x = _npu_attention(q, k, v, attn_mask=attn_mask)
elif attn_type == "torch":
x = _torch_attention(q, k, v, attn_mask=attn_mask, causal=causal)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
Expand Down Expand Up @@ -218,7 +275,7 @@ def run_individual_token_refiner(self, weights, out, mask, c):
norm_x = block.norm1.apply(out.unsqueeze(0)).squeeze(0)
qkv = block.self_attn_qkv.apply(norm_x).unsqueeze(0)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = attention(q, k, v, attn_mask=mask, attn_type="flash_attn2").squeeze(0)
attn = attention(q, k, v, attn_mask=mask, attn_type=_get_token_refiner_attn_type()).squeeze(0)
out = out + apply_gate(block.self_attn_proj.apply(attn).unsqueeze(0), gate_msa).squeeze(0)
tmp = block.mlp_fc1.apply(block.norm2.apply(out))
tmp = torch.nn.functional.silu(tmp)
Expand Down
35 changes: 29 additions & 6 deletions lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,47 @@
from lightx2v.models.networks.wan.infer.module_io import GridOutput
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

_POSITION_FLOAT32_PLATFORMS = {
"ascend_npu",
"cambricon_mlu",
"metax_cuda",
}


def _position_math_dtype():
if PLATFORM in _POSITION_FLOAT32_PLATFORMS:
return torch.float32
return torch.float64


def _empty_device_cache():
device_module = getattr(torch, AI_DEVICE, None)
if device_module is not None and hasattr(device_module, "empty_cache"):
device_module.empty_cache()


def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
dtype = _position_math_dtype()
position = position.type(dtype)

# calculation
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half, device=position.device, dtype=dtype).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x


def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
dtype = _position_math_dtype()
freqs = torch.outer(
torch.arange(max_seq_len, dtype=dtype),
1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=dtype).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

Expand Down Expand Up @@ -97,12 +120,12 @@ def infer(self, weights, inputs, kv_start=0, kv_end=0):
context = weights.text_embedding_2.apply(out)
if self.clean_cuda_cache:
del out
torch.cuda.empty_cache()
_empty_device_cache()

if self.clean_cuda_cache:
if self.config.get("use_image_encoder", True):
del context_clip
torch.cuda.empty_cache()
_empty_device_cache()

grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

_USE_FLASH_ATTN_V3 = True
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_v2
try:
from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_v2
except ImportError:
flash_attn_func_v2 = None

_USE_FLASH_ATTN_V3 = False
from ...comm.communication import _All2All
Expand Down
1 change: 1 addition & 0 deletions lightx2v/models/runners/flux2/flux2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def run_pipeline(self, input_info):

latents, generator = self.run_dit()
images = self.run_vae_decoder(latents)
self.end_run()

if not input_info.return_result_tensor and is_main_process():
image = images[0]
Expand Down
16 changes: 16 additions & 0 deletions lightx2v/models/runners/ltx2/ltx2_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@ def _ltx2_parse_image_paths(image_path: str) -> list[str]:
return [p.strip() for p in image_path.split(",") if p.strip()]


def _ltx2_audio_to_stereo(audio: Audio) -> Audio:
waveform = audio.waveform
if waveform.dim() == 3:
if waveform.shape[1] == 1:
waveform = waveform.expand(waveform.shape[0], 2, waveform.shape[2]).contiguous()
elif waveform.shape[1] > 2:
waveform = waveform[:, :2, :].contiguous()
elif waveform.dim() == 2:
if waveform.shape[0] == 1:
waveform = waveform.expand(2, waveform.shape[1]).contiguous()
elif waveform.shape[0] > 2:
waveform = waveform[:2, :].contiguous()
return Audio(waveform=waveform, sampling_rate=audio.sampling_rate)


def _ltx2_normalize_image_strengths(image_strength, n: int) -> list[float]:
if not isinstance(image_strength, list):
return [float(image_strength)] * n
Expand Down Expand Up @@ -483,6 +498,7 @@ def _run_input_encoder_local_ltx2_s2v(self):
decoded = decode_audio_from_file(ap, enc_device, 0.0, max_duration)
if decoded is None:
raise ValueError(f"ltx2_s2v: failed to decode audio from {ap!r}.")
decoded = _ltx2_audio_to_stereo(decoded)

with torch.no_grad():
encoded = encode_audio(decoded, self.audio_vae.encoder)
Expand Down
Loading
Loading