Code Review
This pull request introduces support for the Ascend NPU platform across several models (Flux2, Hunyuan Video, LTX2, Qwen, and Wan), including new configuration files, shell scripts, and platform-specific adaptations such as NPU flash attention, float32 position embeddings, and BF16 replicate padding workarounds. Key feedback highlights a potential startup crash in wan_sf_runner.py due to module-level evaluation of AI_DEVICE, a shape mismatch issue in _npu_attention that can be resolved by unsqueezing tensors to 4D, and the presence of hardcoded absolute paths in the newly added configuration files.
    out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
        q=q_unpad,
        k=k_unpad,
        v=v_unpad,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=int(lengths.max().item()),
        max_seqlen_kv=int(lengths.max().item()),
    )
The NpuFlashAttnWeight.apply implementation expects a 4D tensor when batch > 1 or when padding is present, and it reshapes the output using bs * max_seqlen_q. When q_unpad is passed as a 3D tensor, apply sets bs = 1, which causes a RuntimeError due to shape mismatch (total_unpadded_tokens vs max_seqlen_q) during the final reshape.
To resolve this, we can unsqueeze the unpadded tensors to 4D of shape (total_unpadded_tokens, 1, heads, dim) and set max_seqlen_q=1. This correctly restores the 3D shape inside apply while ensuring the output is reshaped to (total_unpadded_tokens, heads * dim) without any shape mismatch.
Suggested change:

    - out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
    -     q=q_unpad,
    -     k=k_unpad,
    -     v=v_unpad,
    -     cu_seqlens_q=cu_seqlens,
    -     cu_seqlens_kv=cu_seqlens,
    -     max_seqlen_q=int(lengths.max().item()),
    -     max_seqlen_kv=int(lengths.max().item()),
    - )
    + out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
    +     q=q_unpad.unsqueeze(1),
    +     k=k_unpad.unsqueeze(1),
    +     v=v_unpad.unsqueeze(1),
    +     cu_seqlens_q=cu_seqlens,
    +     cu_seqlens_kv=cu_seqlens,
    +     max_seqlen_q=1,
    +     max_seqlen_kv=1,
    + )
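The shape arithmetic behind this comment can be checked with a small stand-alone sketch. It does not call the real `NpuFlashAttnWeight.apply`; `final_reshape_rows` is a hypothetical helper that only models the behavior described above (a 3D input is treated as `bs = 1`, and the output is reshaped to `bs * max_seqlen_q` rows). The sequence lengths, head count, and head dimension are made-up values.

```python
from math import prod


def final_reshape_rows(q_shape, max_seqlen_q):
    """Hypothetical model of the final reshape described in the review.

    Not the real apply(): it only reproduces the row-count arithmetic.
    """
    # A 4D input carries an explicit batch dimension; a 3D input is
    # treated as a single batch (bs = 1), as the comment describes.
    bs = q_shape[0] if len(q_shape) == 4 else 1
    heads, dim = q_shape[-2], q_shape[-1]
    rows = bs * max_seqlen_q
    if rows * heads * dim != prod(q_shape):
        raise RuntimeError(
            f"cannot reshape {prod(q_shape)} elements into ({rows}, {heads * dim})"
        )
    return (rows, heads * dim)


# Two sequences of length 3 and 5: 8 unpadded tokens, longest sequence 5.
lengths = [3, 5]
total_tokens, heads, dim = sum(lengths), 4, 16

# 4D input with max_seqlen_q=1: rows == total_tokens, so the reshape is valid.
print(final_reshape_rows((total_tokens, 1, heads, dim), 1))  # (8, 64)

# 3D input with max_seqlen_q=max(lengths): rows == 5, but there are 8 tokens.
try:
    final_reshape_rows((total_tokens, heads, dim), max(lengths))
except RuntimeError as exc:
    print("mismatch:", exc)
```

Under this model the mismatch only appears when the unpadded batch holds more than one sequence; with a single sequence the total token count equals the maximum length and the 3D call happens to work.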
    from lightx2v_platform.base.global_var import AI_DEVICE

    torch_device_module = getattr(torch, AI_DEVICE)
Importing AI_DEVICE directly from lightx2v_platform.base.global_var and evaluating getattr(torch, AI_DEVICE) at the module level (import time) will raise a TypeError because AI_DEVICE is initialized to None at import time.
To prevent this startup crash and support dynamic device updates, we should evaluate the device module lazily using a proxy object or a helper function.
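The helper-function alternative mentioned above might look like the following minimal sketch. To keep it runnable without `torch` or `lightx2v_platform`, both are replaced by stand-in namespaces; `global_var`, `fake_torch`, `get_device_module`, and `eager_lookup` are illustrative names, not part of the project.

```python
from types import SimpleNamespace

# Stand-ins (assumptions) for lightx2v_platform.base.global_var and torch.
global_var = SimpleNamespace(AI_DEVICE=None)
fake_torch = SimpleNamespace(
    cuda=SimpleNamespace(name="cuda"),
    npu=SimpleNamespace(name="npu"),
)


def eager_lookup():
    """Mimics the module-level getattr while AI_DEVICE is still None."""
    return getattr(fake_torch, global_var.AI_DEVICE)


def get_device_module(default="cuda"):
    """Resolve the backend module at call time instead of import time."""
    # Reading the attribute through the module object picks up later updates;
    # a top-level `from ... import AI_DEVICE` would freeze the initial None.
    device = global_var.AI_DEVICE or default
    return getattr(fake_torch, device)


try:
    eager_lookup()
except TypeError as exc:
    print("eager lookup fails:", exc)

print(get_device_module().name)   # falls back to the default backend
global_var.AI_DEVICE = "npu"      # platform initialization happens later
print(get_device_module().name)   # now resolves to the NPU backend
```

Compared with a proxy object, a helper function requires call sites to change from `torch_device_module.foo` to `get_device_module().foo`, but it makes the deferred lookup explicit.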
    from lightx2v_platform.base.global_var import AI_DEVICE


    class TorchDeviceModuleProxy:
        def __getattr__(self, name):
            from lightx2v_platform.base.global_var import AI_DEVICE

            device = AI_DEVICE or "cuda"
            return getattr(getattr(torch, device), name)


    torch_device_module = TorchDeviceModuleProxy()

    "audio_mel_bins":16,
    "double_precision_rope": false,
    "use_tiling_vae": false,
    "dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",
The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.
Suggested change:

    - "dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",
    + "dit_original_ckpt": "checkpoints/ltx-2.3-22b-dev.safetensors",
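If an environment variable is preferred over a fixed relative path, the lookup could be sketched as follows. The variable name `LIGHTX2V_MODEL_ROOT` and the helper `resolve_ckpt` are assumptions for illustration; nothing in the pull request defines them.

```python
import os
from pathlib import Path


def resolve_ckpt(relative_path, env_var="LIGHTX2V_MODEL_ROOT"):
    """Join a config-relative checkpoint path onto a user-supplied root.

    `env_var` is a hypothetical name; when it is unset, the path is
    resolved relative to the current working directory.
    """
    root = os.environ.get(env_var, ".")
    return str(Path(root) / relative_path)


# The config keeps only the portable part of the path.
config = {"dit_original_ckpt": "checkpoints/ltx-2.3-22b-dev.safetensors"}
print(resolve_ckpt(config["dit_original_ckpt"]))
```

This keeps machine-specific locations out of the committed JSON while still letting each user point at their own model directory.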
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": false,
    "dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",
The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.
Suggested change:

    - "dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",
    + "dit_original_ckpt": "checkpoints/self_forcing_dmd.pt",