Skip to content

adapt ascend npu#1158

Open
Watebear wants to merge 7 commits into
mainfrom
ascend
Open

adapt ascend npu#1158
Watebear wants to merge 7 commits into
mainfrom
ascend

Conversation

@Watebear

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Ascend NPU platform across several models (Flux2, Hunyuan Video, LTX2, Qwen, and Wan), including new configuration files, shell scripts, and platform-specific adaptations such as NPU flash attention, float32 position embeddings, and BF16 replicate padding workarounds. Key feedback highlights a potential startup crash in wan_sf_runner.py due to module-level evaluation of AI_DEVICE, a shape mismatch issue in _npu_attention that can be resolved by unsqueezing tensors to 4D, and the presence of hardcoded absolute paths in the newly added configuration files.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +79 to +87
out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad,
k=k_unpad,
v=v_unpad,
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=int(lengths.max().item()),
max_seqlen_kv=int(lengths.max().item()),
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The NpuFlashAttnWeight.apply implementation expects a 4D tensor when batch > 1 or when padding is present, and it reshapes the output using bs * max_seqlen_q. When q_unpad is passed as a 3D tensor, apply sets bs = 1, which causes a RuntimeError due to shape mismatch (total_unpadded_tokens vs max_seqlen_q) during the final reshape.

To resolve this, we can unsqueeze the unpadded tensors to 4D of shape (total_unpadded_tokens, 1, heads, dim) and set max_seqlen_q=1. This correctly restores the 3D shape inside apply while ensuring the output is reshaped to (total_unpadded_tokens, heads * dim) without any shape mismatch.

Suggested change
out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad,
k=k_unpad,
v=v_unpad,
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=int(lengths.max().item()),
max_seqlen_kv=int(lengths.max().item()),
)
out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply(
q=q_unpad.unsqueeze(1),
k=k_unpad.unsqueeze(1),
v=v_unpad.unsqueeze(1),
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
max_seqlen_q=1,
max_seqlen_kv=1,
)

Comment on lines +17 to +19
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Importing AI_DEVICE directly from lightx2v_platform.base.global_var and evaluating getattr(torch, AI_DEVICE) at the module level (import time) will raise a TypeError because AI_DEVICE is initialized to None at import time.

To prevent this startup crash and support dynamic device updates, we should evaluate the device module lazily using a proxy object or a helper function.

from lightx2v_platform.base.global_var import AI_DEVICE

class TorchDeviceModuleProxy:
    def __getattr__(self, name):
        from lightx2v_platform.base.global_var import AI_DEVICE
        device = AI_DEVICE or "cuda"
        return getattr(getattr(torch, device), name)

torch_device_module = TorchDeviceModuleProxy()

"audio_mel_bins":16,
"double_precision_rope": false,
"use_tiling_vae": false,
"dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.

Suggested change
"dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors",
"dit_original_ckpt": "checkpoints/ltx-2.3-22b-dev.safetensors",

"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": false,
"dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The configuration contains a hardcoded absolute path /data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt which is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.

Suggested change
"dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt",
"dit_original_ckpt": "checkpoints/self_forcing_dmd.pt",

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant