-
Notifications
You must be signed in to change notification settings - Fork 218
adapt ascend npu #1158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
adapt ascend npu #1158
Changes from all commits
f6a96bf
fae813f
93d2fe2
a989bf2
7074e5f
5c205f2
9b051de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| { | ||
| "model_cls": "flux2_dev", | ||
| "task": "t2i", | ||
| "infer_steps": 50, | ||
| "aspect_ratio": "16:9", | ||
| "sample_guide_scale": 4.0, | ||
| "vae_scale_factor": 16, | ||
| "feature_caching": "None", | ||
| "enable_cfg": false, | ||
| "patch_size": 2, | ||
| "tokenizer_max_length": 512, | ||
| "text_encoder_out_layers": [10, 20, 30], | ||
| "attn_type": "npu_flash_attn", | ||
| "cpu_offload": true, | ||
| "offload_granularity": "block", | ||
| "modulate_type": "torch", | ||
| "rope_type": "torch", | ||
| "layer_norm_type": "torch", | ||
| "rms_norm_type": "torch" | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| { | ||
| "infer_steps": 50, | ||
| "transformer_model_name": "480p_i2v", | ||
| "fps": 24, | ||
| "target_video_length": 121, | ||
| "vae_stride": [4, 16, 16], | ||
| "sample_shift": 5.0, | ||
| "sample_guide_scale": 6.0, | ||
| "aspect_ratio": "16:9", | ||
| "enable_cfg": true, | ||
| "attn_type": "npu_flash_attn", | ||
| "modulate_type": "torch", | ||
| "rope_type": "torch", | ||
| "layer_norm_type": "torch", | ||
| "rms_norm_type": "torch" | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| { | ||
| "infer_steps": 30, | ||
| "target_video_length": 121, | ||
| "target_height": 512, | ||
| "target_width": 768, | ||
| "attn_type": "npu_flash_attn", | ||
| "sample_guide_scale": 3.0, | ||
| "sample_shift": [2.05, 0.95], | ||
| "enable_cfg": true, | ||
| "cpu_offload": true, | ||
| "offload_granularity": "model", | ||
| "norm_modulate_backend": "torch", | ||
| "modulate_type": "torch", | ||
| "rope_type": "torch", | ||
| "layer_norm_type": "torch", | ||
| "rms_norm_type": "torch", | ||
| "num_channels_latents": 128, | ||
| "fps": 24, | ||
| "audio_fps": 24000, | ||
| "audio_mel_bins":16, | ||
| "double_precision_rope": false, | ||
| "use_tiling_vae": false, | ||
| "dit_original_ckpt": "/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensors", | ||
| "caption_proj_before_connector": true, | ||
| "cross_attention_adaln": true, | ||
| "apply_gated_attention": true, | ||
| "mm_guider": { | ||
| "enabled": true, | ||
| "video": { | ||
| "cfg_scale": 3.0, | ||
| "stg_scale": 1.0, | ||
| "stg_blocks": [28], | ||
| "modality_scale": 3.0, | ||
| "rescale_scale": 0.7, | ||
| "skip_step": 0 | ||
| }, | ||
| "audio": { | ||
| "cfg_scale": 7.0, | ||
| "stg_scale": 1.0, | ||
| "stg_blocks": [28], | ||
| "modality_scale": 3.0, | ||
| "rescale_scale": 0.7, | ||
| "skip_step": 0 | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,27 @@ | ||||||
| { | ||||||
| "infer_steps": 4, | ||||||
| "target_video_length": 81, | ||||||
| "text_len": 512, | ||||||
| "target_height": 480, | ||||||
| "target_width": 832, | ||||||
| "self_attn_1_type": "npu_flash_attn", | ||||||
| "cross_attn_1_type": "npu_flash_attn", | ||||||
| "cross_attn_2_type": "npu_flash_attn", | ||||||
| "sample_guide_scale": 1, | ||||||
| "sample_shift": 5.0, | ||||||
| "enable_cfg": false, | ||||||
| "cpu_offload": false, | ||||||
| "dit_original_ckpt": "/data/wushuo1/to5_models/Self-Forcing/checkpoints/self_forcing_dmd.pt", | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The configuration contains a hardcoded absolute path
Suggested change
|
||||||
| "causal_rope_type": "torch", | ||||||
| "modulate_type": "torch", | ||||||
| "rope_type": "torch", | ||||||
| "layer_norm_type": "torch", | ||||||
| "rms_norm_type": "torch", | ||||||
| "ar_config": { | ||||||
| "local_attn_size": -1, | ||||||
| "num_frame_per_chunk": 3, | ||||||
| "timesteps_index": [0, 179, 358, 679], | ||||||
| "kv_offload": false, | ||||||
| "async_vae_decode": false | ||||||
| } | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -7,9 +7,10 @@ | |||||||||||||||||||||||||||||||||||||
| from torch.nn import functional as F | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from lightx2v.utils.envs import * | ||||||||||||||||||||||||||||||||||||||
| from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER | ||||||||||||||||||||||||||||||||||||||
| from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2 | ||||||||||||||||||||||||||||||||||||||
| from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, flash_attn_varlen_qkvpacked_func, pad_input, sage_attn_no_pad_v2, unpad_input | ||||||||||||||||||||||||||||||||||||||
| from .module_io import HunyuanVideo15InferModuleOutput | ||||||||||||||||||||||||||||||||||||||
| from .posemb_layers import get_nd_rotary_pos_embed | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -21,6 +22,15 @@ | |||||||||||||||||||||||||||||||||||||
| TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| _TOKEN_REFINER_ATTN_TYPE_BY_PLATFORM = { | ||||||||||||||||||||||||||||||||||||||
| "ascend_npu": "npu_flash_attn", | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _get_token_refiner_attn_type(): | ||||||||||||||||||||||||||||||||||||||
| return _TOKEN_REFINER_ATTN_TYPE_BY_PLATFORM.get(PLATFORM, "flash_attn2") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def apply_gate(x, gate=None, tanh=False): | ||||||||||||||||||||||||||||||||||||||
| """AI is creating summary for apply_gate | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
@@ -40,6 +50,46 @@ def apply_gate(x, gate=None, tanh=False): | |||||||||||||||||||||||||||||||||||||
| return x * gate.unsqueeze(1) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _torch_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, causal: bool = False) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| q = q.transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||
| k = k.transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||
| v = v.transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||
| if attn_mask is not None: | ||||||||||||||||||||||||||||||||||||||
| attn_mask = attn_mask[:, None, None, :].expand(-1, 1, q.shape[-2], -1) | ||||||||||||||||||||||||||||||||||||||
| out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=causal) | ||||||||||||||||||||||||||||||||||||||
| return out.transpose(1, 2) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _npu_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||
| batch, seqlen, heads, dim = q.shape | ||||||||||||||||||||||||||||||||||||||
| if attn_mask is None: | ||||||||||||||||||||||||||||||||||||||
| lengths = torch.full((batch,), seqlen, dtype=torch.int32, device=q.device) | ||||||||||||||||||||||||||||||||||||||
| indices = torch.arange(batch * seqlen, device=q.device) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| attn_mask = attn_mask.bool() | ||||||||||||||||||||||||||||||||||||||
| lengths = attn_mask.sum(dim=1, dtype=torch.int32) | ||||||||||||||||||||||||||||||||||||||
| indices = attn_mask.reshape(-1).nonzero(as_tuple=False).squeeze(1) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| cu_seqlens = torch.zeros(batch + 1, dtype=torch.int32, device=q.device) | ||||||||||||||||||||||||||||||||||||||
| cu_seqlens[1:] = torch.cumsum(lengths, dim=0) | ||||||||||||||||||||||||||||||||||||||
| q_unpad = q.reshape(batch * seqlen, heads, dim)[indices] | ||||||||||||||||||||||||||||||||||||||
| k_unpad = k.reshape(batch * seqlen, heads, dim)[indices] | ||||||||||||||||||||||||||||||||||||||
| v_unpad = v.reshape(batch * seqlen, heads, dim)[indices] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| out_unpad = ATTN_WEIGHT_REGISTER["npu_flash_attn"]().apply( | ||||||||||||||||||||||||||||||||||||||
| q=q_unpad, | ||||||||||||||||||||||||||||||||||||||
| k=k_unpad, | ||||||||||||||||||||||||||||||||||||||
| v=v_unpad, | ||||||||||||||||||||||||||||||||||||||
| cu_seqlens_q=cu_seqlens, | ||||||||||||||||||||||||||||||||||||||
| cu_seqlens_kv=cu_seqlens, | ||||||||||||||||||||||||||||||||||||||
| max_seqlen_q=int(lengths.max().item()), | ||||||||||||||||||||||||||||||||||||||
| max_seqlen_kv=int(lengths.max().item()), | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+79
to
+87
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The To resolve this, we can unsqueeze the unpadded tensors to 4D of shape
Suggested change
|
||||||||||||||||||||||||||||||||||||||
| out = torch.zeros(batch * seqlen, heads * dim, dtype=out_unpad.dtype, device=out_unpad.device) | ||||||||||||||||||||||||||||||||||||||
| out[indices] = out_unpad | ||||||||||||||||||||||||||||||||||||||
| return out.reshape(batch, seqlen, heads, dim) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| @torch.compiler.disable | ||||||||||||||||||||||||||||||||||||||
| def attention( | ||||||||||||||||||||||||||||||||||||||
| q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float = 0.0, attn_mask: Optional[torch.Tensor] = None, causal: bool = False, attn_type: str = "flash_attn2" | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -62,11 +112,18 @@ def attention( | |||||||||||||||||||||||||||||||||||||
| if attn_mask is not None and attn_mask.dtype != torch.bool: | ||||||||||||||||||||||||||||||||||||||
| attn_mask = attn_mask.bool() | ||||||||||||||||||||||||||||||||||||||
| if attn_type == "flash_attn2": | ||||||||||||||||||||||||||||||||||||||
| x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) | ||||||||||||||||||||||||||||||||||||||
| if flash_attn_varlen_qkvpacked_func is None or pad_input is None or unpad_input is None: | ||||||||||||||||||||||||||||||||||||||
| x = _torch_attention(q, k, v, attn_mask=attn_mask, causal=causal) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) | ||||||||||||||||||||||||||||||||||||||
| elif attn_type == "flash_attn3": | ||||||||||||||||||||||||||||||||||||||
| x = flash_attn_no_pad_v3(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) | ||||||||||||||||||||||||||||||||||||||
| elif attn_type == "sage_attn2": | ||||||||||||||||||||||||||||||||||||||
| x = sage_attn_no_pad_v2(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None) | ||||||||||||||||||||||||||||||||||||||
| elif attn_type == "npu_flash_attn": | ||||||||||||||||||||||||||||||||||||||
| x = _npu_attention(q, k, v, attn_mask=attn_mask) | ||||||||||||||||||||||||||||||||||||||
| elif attn_type == "torch": | ||||||||||||||||||||||||||||||||||||||
| x = _torch_attention(q, k, v, attn_mask=attn_mask, causal=causal) | ||||||||||||||||||||||||||||||||||||||
| b, s, a, d = x.shape | ||||||||||||||||||||||||||||||||||||||
| out = x.reshape(b, s, -1) | ||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -218,7 +275,7 @@ def run_individual_token_refiner(self, weights, out, mask, c): | |||||||||||||||||||||||||||||||||||||
| norm_x = block.norm1.apply(out.unsqueeze(0)).squeeze(0) | ||||||||||||||||||||||||||||||||||||||
| qkv = block.self_attn_qkv.apply(norm_x).unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||
| q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) | ||||||||||||||||||||||||||||||||||||||
| attn = attention(q, k, v, attn_mask=mask, attn_type="flash_attn2").squeeze(0) | ||||||||||||||||||||||||||||||||||||||
| attn = attention(q, k, v, attn_mask=mask, attn_type=_get_token_refiner_attn_type()).squeeze(0) | ||||||||||||||||||||||||||||||||||||||
| out = out + apply_gate(block.self_attn_proj.apply(attn).unsqueeze(0), gate_msa).squeeze(0) | ||||||||||||||||||||||||||||||||||||||
| tmp = block.mlp_fc1.apply(block.norm2.apply(out)) | ||||||||||||||||||||||||||||||||||||||
| tmp = torch.nn.functional.silu(tmp) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The configuration contains a hardcoded absolute path
/data/wushuo1/to5_models/LTX-2.3/ltx-2.3-22b-dev.safetensorswhich is specific to a single user's environment. This makes the configuration non-portable and will fail on other machines or in production. Consider using a relative path, an environment variable, or a placeholder.