From d4d0d3796e4052d2423ed060e5f8d4667679edd0 Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 14:38:07 +0800 Subject: [PATCH 1/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- lightx2v_platform/ops/__init__.py | 2 + .../attn/ascend_npu/rainfusion_blockwise.py | 270 ++++++++++++++++++ .../ops/norm/ascend_npu/__init__.py | 4 + .../ops/norm/ascend_npu/npu_layer_norm.py | 36 +++ .../ops/norm/ascend_npu/npu_rms_norm.py | 28 ++ .../ops/rope/ascend_npu/__init__.py | 3 + .../ops/rope/ascend_npu/npu_rope.py | 80 ++++++ 7 files changed, 423 insertions(+) create mode 100644 lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py create mode 100644 lightx2v_platform/ops/norm/ascend_npu/__init__.py create mode 100644 lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py create mode 100644 lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py create mode 100644 lightx2v_platform/ops/rope/ascend_npu/__init__.py create mode 100644 lightx2v_platform/ops/rope/ascend_npu/npu_rope.py diff --git a/lightx2v_platform/ops/__init__.py b/lightx2v_platform/ops/__init__.py index b56ad794a..163b91baa 100755 --- a/lightx2v_platform/ops/__init__.py +++ b/lightx2v_platform/ops/__init__.py @@ -15,6 +15,8 @@ elif PLATFORM == "ascend_npu": from .attn.ascend_npu import * from .mm.ascend_npu import * + from .norm.ascend_npu import * + from .rope.ascend_npu import * elif PLATFORM == "metax_cuda": from .attn.metax_cuda import * elif PLATFORM == "enflame_gcu": diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py new file mode 100644 index 000000000..382404300 --- /dev/null +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -0,0 +1,270 @@ +import os +import math +import torch +import torch_npu +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import time +import mindiesd +from mindiesd.layers.flash_attn.attention_forward import attention_forward + + +class Rainfusion_blockwise(nn.Module): + def __init__( + self, + grid_size: list, + pool_size: int = 128, + sparsity: float = 0.9, + skip_timesteps: int = 0, + txt_len: int = 0, + txt_first: bool = False, + ) -> None: + """ + 参数: + grid_size (list): latents的THW网格大小。 + sparsity (float, optional): 稀疏度, 取值范围[0, 1],默认为 0.50。 + """ + super().__init__() + + # Rainfusion_param + self.grid_size = grid_size + self.frame_num = self.grid_size[0] + self.num_tokens_per_frame = self.grid_size[1] * self.grid_size[2] + self.first_frame_len = self.num_tokens_per_frame + + self.pool_size = pool_size + self.sparsity = sparsity + self.skip_timesteps = skip_timesteps + self.text_len = txt_len + self.txt_first = txt_first + + @staticmethod + def get_grid_size(latent_size, patch_size): + t, h, w = latent_size[-3:] + return [t // patch_size[0], h // patch_size[1], w // patch_size[2]] + + def avgpool(self, input_tensor, pool_size=128): # BSND in, BSND out + batch, seqlen, headnum, dim = input_tensor.shape + + num_full_blocks = seqlen // pool_size + tail_size = seqlen % pool_size + + if num_full_blocks > 0: + full_blocks = input_tensor[:, :num_full_blocks * pool_size, :, :] + full_blocks_reshaped = full_blocks.view(batch, num_full_blocks, pool_size, headnum, dim) + full_pooled = full_blocks_reshaped.mean(dim=2) + else: + full_pooled = torch.empty(0, device=input_tensor.device) + if tail_size > 0: + tail_block = input_tensor[:, num_full_blocks * pool_size:, :, :] + tail_reshaped = tail_block.view(batch, 1, tail_size, headnum, dim) + tail_pooled = tail_reshaped.mean(dim=2) + else: + tail_pooled = torch.empty(0, device=input_tensor.device) + + if num_full_blocks > 0 and tail_size > 0: + output_tensor = torch.cat([full_pooled, tail_pooled], dim=1) + elif num_full_blocks > 0: + output_tensor = full_pooled + else: + output_tensor = tail_pooled + + return output_tensor + + def get_mask_index(self, mask): + B, N, S, _ = mask.shape + device = mask.device + + # 1. 重塑维度 → (B*N)×S×S + mask_reshaped = mask.reshape(-1, S, S) + batch_size = mask_reshaped.shape[0] + + # 2. 生成行索引 标记False位置为S(大于所有有效索引) + row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S) + sorted_vals = torch.where(mask_reshaped, row_indices, 1e9).to(torch.float32) + sorted_vals, _ = torch.sort(sorted_vals, dim=-1) + valid_count = mask_reshaped.sum(dim=-1, keepdim=True) # 每行True的个数 + keep_mask = row_indices < valid_count # 前valid_count个位置保留索引,其余填-1 + result = torch.where(keep_mask, sorted_vals, -1) + + pos_matrix = result.reshape(B, N, S, S).to(torch.int64) + return pos_matrix + + def get_blockwise_mask(self, score_matrix, sparsity): + batch_size, num_heads, rows, cols = score_matrix.shape + + keep_len = math.ceil(cols * (1 - sparsity)) + topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1) + thresholds = topk_values[..., -1:] + mask = score_matrix >= thresholds + + protect_len = (self.first_frame_len + self.text_len + self.pool_size - 1) // self.pool_size + + if protect_len > 0: + mask[:, :, -protect_len:, :] = True + mask[:, :, :, -protect_len:] = True + + selectIdx = self.get_mask_index(mask) + selectIdx = selectIdx[0].transpose(0, 1) + selectNumIdx = mask[0].transpose(0, 1).sum(dim=-1) + return selectIdx, selectNumIdx + + def rearrange_with_remaining(self, tensor): # BSND in , BSND out + b, s, n, d = tensor.shape + h = self.grid_size[1] + w = self.grid_size[2] + h_res_len, w_res_len = 0, 0 + first_frame_num = self.first_frame_len // h // w + + tensor_hwt = rearrange(tensor, 'b (f h w) n d -> (b n) f h w d', f=self.frame_num - first_frame_num, h=h, w=w) + if h % 8 != 0: + tensor_hwt, tensor_h_r = torch.split(tensor_hwt, h - (h % 8), dim=2) + tensor_h_r = tensor_h_r.reshape(b * n, -1, d) + h_res_len = tensor_h_r.shape[1] + if w % 8 != 0: + tensor_hwt, tensor_w_r = torch.split(tensor_hwt, w - (w % 8), dim=3) + tensor_w_r = tensor_w_r.reshape(b * n, -1, d) + w_res_len = tensor_w_r.shape[1] + tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d', + fn=self.frame_num // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + if h % 8 != 0: + tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) + if w % 8 != 0: + tensor_hwt = torch.cat((tensor_hwt, tensor_w_r), dim=1) + tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) + return tensor_hwt, h_res_len, w_res_len + + def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in , BSND out + b, s, n, d = tensor.shape + h = self.grid_size[1] + w = self.grid_size[2] + h_sr, w_sr = h % 8, w % 8 + + tensor = rearrange(tensor, 'b s n d->(b n) s d', b=b, n=n) + tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) + tensor_hwt = rearrange(tensor_hwt, 'b (fn hn wn fb hb wb) d -> b (fn fb) (hn hb) (wn wb) d', + fn=self.frame_num // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + if w_res_len != 0: + tensor_w = tensor_w.reshape(b * n, self.frame_num - 1, h - h_sr, w_sr, d) + tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) + if h_sr != 0: + tensor_h = tensor_h.reshape(b * n, self.frame_num - 1, h_sr, w, d) + tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2) + tensor_hwt = tensor_hwt.reshape(b * n, -1, d) + tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) + return tensor_hwt + + def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out + b, s, n, d = tensor.shape + if self.txt_first: + tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, + s - self.text_len - self.first_frame_len], dim=1) + else: + tensor_f, tensor_i, tensor_t = torch.split(tensor, + [self.first_frame_len, s - self.text_len - self.first_frame_len, + self.text_len], dim=1) + tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i) + tensor = torch.concat((tensor_i_2, tensor_f, tensor_t), dim=1) + tensor_pool = self.avgpool(tensor, pool_size=128) + return tensor, tensor_pool, h_res_len, w_res_len + + def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): + b, s, n, d = tensor.shape + tensor_i, tensor_f, tensor_t = torch.split(tensor, + [s - self.text_len - self.first_frame_len, self.first_frame_len, + self.text_len], dim=1) + tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) + tensor = torch.concat((tensor_t, tensor_i), dim=1) + + if self.txt_first: + tensor = torch.concat((tensor_t, tensor_f, tensor_i), dim=1) + else: + tensor = torch.concat((tensor_f, tensor_i, tensor_t), dim=1) + return tensor + + def do_tensor_pooling(self, tensor): + tensor_t = tensor[:, :self.text_len, :, :] + tensor_i = tensor[:, self.text_len:, :, :] + + tensor_i_pool = self.avgpool(tensor_i, pool_size=128) + tensor_t_pool = self.avgpool(tensor_t, pool_size=128) + + tensor_pool = torch.concat((tensor_t_pool, tensor_i_pool), dim=1) + return tensor_pool + + def forward( + self, + q, # BSND + k, + v, + t_b_idx, + base_blockmask, + ): + t_idx = t_b_idx[0] + b_idx = t_b_idx[1] + device = q.device + + if t_idx < self.skip_timesteps: + base_blockmask = None + x = attention_forward(q, k, v, + opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + else: + batch, qSeqlen, numHeads, headDim = q.shape + _, kvSeqlen, _, _ = k.shape + blockShapeX, blockShapeY = self.pool_size, self.pool_size + scale = headDim ** -0.5 + + totalQTokens = batch * qSeqlen + totalKvTokens = batch * kvSeqlen + qBlockNum = math.ceil(qSeqlen / blockShapeX) + kvBlockNum = math.ceil(kvSeqlen / blockShapeY) + totalQBlocks = qBlockNum + maxKvBlockNum = kvBlockNum + blockShape = [blockShapeX, blockShapeY] + actualSeqLengthsHost = [qSeqlen for _ in range(batch)] + actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)] + + sparsity = self.sparsity + + h_res_len, w_res_len = 0, 0 + if base_blockmask is None: + qkv = torch.cat((q, k, v), dim=0) + qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) + q, k, v = torch.chunk(qkv, 3, dim=0) + # qkv_pool = self.do_tensor_pooling(qkv) + query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0) + + attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale + attn_scores_fake = torch.nn.functional.softmax(attn_scores_head, dim=-1) + + selectIdx, selectNumIdx = self.get_blockwise_mask(attn_scores_fake, sparsity) + base_blockmask = [selectIdx, selectNumIdx] + else: + selectIdx = base_blockmask[0] + selectNumIdx = base_blockmask[1] + qkv = torch.cat((q, k, v), dim=0) + qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) + q, k, v = torch.chunk(qkv, 3, dim=0) + + q_bnsd = q.transpose(1, 2) + k_bnsd = k.transpose(1, 2) + v_bnsd = v.transpose(1, 2) + x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention( + q_bnsd, k_bnsd, v_bnsd, + scale=scale, + head_num=numHeads, + input_layout="BNSD", + select_idx=selectIdx, + select_num_idx=selectNumIdx, + blockshape=blockShape, + actual_seq_lengths=actualSeqLengthsHost, + actual_seq_lengths_kv=actualSeqLengthsKvHost + ) + + x = x.transpose(1, 2).view(batch, qSeqlen, numHeads, headDim) + + x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len) + base_blockmask = None + + return x, base_blockmask \ No newline at end of file diff --git a/lightx2v_platform/ops/norm/ascend_npu/__init__.py b/lightx2v_platform/ops/norm/ascend_npu/__init__.py new file mode 100644 index 000000000..e522b460b --- /dev/null +++ b/lightx2v_platform/ops/norm/ascend_npu/__init__.py @@ -0,0 +1,4 @@ +from .npu_rms_norm import NpuRmsNormWeight +from .npu_layer_norm import NpuLayerNormWeight + +__all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"] \ No newline at end of file diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py new file mode 100644 index 000000000..871c5d28f --- /dev/null +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py @@ -0,0 +1,36 @@ +import torch + +from lightx2v_platform.ops.norm.norm_template import LayerNormWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_LAYERNORM_WEIGHT_REGISTER + +try: + import torch_npu +except ImportError: + torch_npu = None + + +@PLATFORM_LAYERNORM_WEIGHT_REGISTER("npu_layer_norm") +class NpuLayerNormWeight(LayerNormWeightTemplate): + def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6, **kwargs): + super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + + def apply(self, input_tensor): + if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm'): + out = torch_npu.npu_layer_norm( + input_tensor, + (input_tensor.shape[-1],), + self.weight, + self.bias, + self.eps, + ) + if isinstance(out, tuple): + out = out[0] + return out + output = torch.nn.functional.layer_norm( + input_tensor, + (input_tensor.shape[-1],), + self.weight, + self.bias, + self.eps, + ) + return output \ No newline at end of file diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py new file mode 100644 index 000000000..bf01fefc7 --- /dev/null +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py @@ -0,0 +1,28 @@ +import torch + +from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER + +try: + import torch_npu +except ImportError: + torch_npu = None + + +@PLATFORM_RMS_WEIGHT_REGISTER("npu_rms_norm") +class NpuRmsNormWeight(RMSWeightTemplate): + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + def apply(self, input_tensor): + weight = self.weight + if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm"): + if self.sensitive_layer_dtype != self.infer_dtype: + output_tensor, _ = torch_npu.npu_rms_norm( + input_tensor.float(), weight.float(), self.eps, + ) + return output_tensor.to(self.infer_dtype) + output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps) + return output_tensor + x = self._norm(input_tensor.float()) + return (x.float() * weight.float()).to(self.infer_dtype) \ No newline at end of file diff --git a/lightx2v_platform/ops/rope/ascend_npu/__init__.py b/lightx2v_platform/ops/rope/ascend_npu/__init__.py new file mode 100644 index 000000000..ba664c3de --- /dev/null +++ b/lightx2v_platform/ops/rope/ascend_npu/__init__.py @@ -0,0 +1,3 @@ +from .npu_rope import NpuRope + +__all__ = ["NpuRope"] diff --git a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py new file mode 100644 index 000000000..2d205eb83 --- /dev/null +++ b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py @@ -0,0 +1,80 @@ +import torch + +from lightx2v_platform.ops.rope.rope_template import GET_DTYPE, RopeTemplate +from lightx2v_platform.registry_factory import PLATFORM_ROPE_REGISTER + +try: + import torch_npu +except ImportError: + torch_npu = None + + +def GET_SENSITIVE_DTYPE(): + import os + DTYPE_MAP = { + "BF16": torch.bfloat16, + "FP16": torch.float16, + "FP32": torch.float32, + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + } + flag = os.getenv("SENSITIVE_LAYER_DTYPE", "None") + if flag == "None": + return GET_DTYPE() + return DTYPE_MAP[flag] + + +@PLATFORM_ROPE_REGISTER("npu_rope") +class NpuRope(RopeTemplate): + def __init__(self): + super().__init__() + self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE() + + def _apply_rope_fp32(self, xq, xk, cos_sin_cache): + n = xq.size(1) + seq_len = cos_sin_cache.size(0) + xq_fp32 = torch.view_as_complex( + xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2) + ) + xk_fp32 = torch.view_as_complex( + xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2) + ) + xq_rot = torch.view_as_real(xq_fp32 * cos_sin_cache).flatten(2) + xk_rot = torch.view_as_real(xk_fp32 * cos_sin_cache).flatten(2) + if xq.size(0) > seq_len: + xq_rot = torch.cat([xq_rot, xq[seq_len:]], dim=0) + xk_rot = torch.cat([xk_rot, xk[seq_len:]], dim=0) + return xq_rot.to(self.infer_dtype), xk_rot.to(self.infer_dtype) + + def apply(self, xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor): + s, n, d = xq.shape + seq_len = cos_sin_cache.size(0) + cos = cos_sin_cache.real + sin = cos_sin_cache.imag + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + xq_part = xq[:seq_len] + xk_part = xk[:seq_len] + + if torch_npu is not None and hasattr(torch_npu, "npu_rotary_mul"): + if self.sensitive_layer_dtype != self.infer_dtype: + xq_part = xq_part.float() + xk_part = xk_part.float() + cos = cos.float() if cos.dtype != torch.float32 else cos + sin = sin.float() if sin.dtype != torch.float32 else sin + if not xq_part.is_contiguous(): + xq_part = xq_part.contiguous() + if not xk_part.is_contiguous(): + xk_part = xk_part.contiguous() + xq_rotated = torch_npu.npu_rotary_mul(xq_part, cos, sin, "interleave") + xk_rotated = torch_npu.npu_rotary_mul(xk_part, cos, sin, "interleave") + if s > seq_len: + xq = torch.cat([xq_rotated, xq[seq_len:]], dim=0) + xk = torch.cat([xk_rotated, xk[seq_len:]], dim=0) + else: + xq = xq_rotated + xk = xk_rotated + return xq.to(self.infer_dtype), xk.to(self.infer_dtype) + + return self._apply_rope_fp32(xq, xk, cos_sin_cache) From 3fb61d9a12fe9931006b2dd04ca303ec9b54536f Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 15:58:18 +0800 Subject: [PATCH 2/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- .../attn/ascend_npu/rainfusion_blockwise.py | 22 ++++++------------- .../ops/norm/ascend_npu/npu_layer_norm.py | 2 +- .../ops/norm/ascend_npu/npu_rms_norm.py | 2 +- .../ops/rope/ascend_npu/npu_rope.py | 3 +++ 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index 382404300..c8dd5d027 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -82,13 +82,13 @@ def get_mask_index(self, mask): # 2. 生成行索引 标记False位置为S(大于所有有效索引) row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S) - sorted_vals = torch.where(mask_reshaped, row_indices, 1e9).to(torch.float32) + sorted_vals = torch.where(mask_reshaped, row_indices, S) sorted_vals, _ = torch.sort(sorted_vals, dim=-1) - valid_count = mask_reshaped.sum(dim=-1, keepdim=True) # 每行True的个数 - keep_mask = row_indices < valid_count # 前valid_count个位置保留索引,其余填-1 + valid_count = mask_reshaped.sum(dim=-1, keepdim=True) + keep_mask = row_indices < valid_count result = torch.where(keep_mask, sorted_vals, -1) - pos_matrix = result.reshape(B, N, S, S).to(torch.int64) + pos_matrix = result.reshape(B, N, S, S) return pos_matrix def get_blockwise_mask(self, score_matrix, sparsity): @@ -127,7 +127,7 @@ def rearrange_with_remaining(self, tensor): # BSND in , BSND out tensor_w_r = tensor_w_r.reshape(b * n, -1, d) w_res_len = tensor_w_r.shape[1] tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d', - fn=self.frame_num // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + fn=(self.frame_num - first_frame_num) // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if h % 8 != 0: tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) if w % 8 != 0: @@ -144,7 +144,7 @@ def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in tensor = rearrange(tensor, 'b s n d->(b n) s d', b=b, n=n) tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) tensor_hwt = rearrange(tensor_hwt, 'b (fn hn wn fb hb wb) d -> b (fn fb) (hn hb) (wn wb) d', - fn=self.frame_num // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + fn=(self.frame_num - 1) // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if w_res_len != 0: tensor_w = tensor_w.reshape(b * n, self.frame_num - 1, h - h_sr, w_sr, d) tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) @@ -175,7 +175,6 @@ def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): [s - self.text_len - self.first_frame_len, self.first_frame_len, self.text_len], dim=1) tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) - tensor = torch.concat((tensor_t, tensor_i), dim=1) if self.txt_first: tensor = torch.concat((tensor_t, tensor_f, tensor_i), dim=1) @@ -202,8 +201,6 @@ def forward( base_blockmask, ): t_idx = t_b_idx[0] - b_idx = t_b_idx[1] - device = q.device if t_idx < self.skip_timesteps: base_blockmask = None @@ -211,16 +208,11 @@ def forward( opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") else: batch, qSeqlen, numHeads, headDim = q.shape + assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1." _, kvSeqlen, _, _ = k.shape blockShapeX, blockShapeY = self.pool_size, self.pool_size scale = headDim ** -0.5 - totalQTokens = batch * qSeqlen - totalKvTokens = batch * kvSeqlen - qBlockNum = math.ceil(qSeqlen / blockShapeX) - kvBlockNum = math.ceil(kvSeqlen / blockShapeY) - totalQBlocks = qBlockNum - maxKvBlockNum = kvBlockNum blockShape = [blockShapeX, blockShapeY] actualSeqLengthsHost = [qSeqlen for _ in range(batch)] actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)] diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py index 871c5d28f..465424d61 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py @@ -15,7 +15,7 @@ def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, c super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) def apply(self, input_tensor): - if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm'): + if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None: out = torch_npu.npu_layer_norm( input_tensor, (input_tensor.shape[-1],), diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py index bf01fefc7..2f94674f9 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py @@ -16,7 +16,7 @@ def _norm(self, x): def apply(self, input_tensor): weight = self.weight - if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm"): + if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None: if self.sensitive_layer_dtype != self.infer_dtype: output_tensor, _ = torch_npu.npu_rms_norm( input_tensor.float(), weight.float(), self.eps, diff --git a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py index 2d205eb83..a4188d4b3 100644 --- a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py +++ b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py @@ -18,6 +18,9 @@ def GET_SENSITIVE_DTYPE(): "bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32, + "torch.bfloat16": torch.bfloat16, + "torch.float16": torch.float16, + "torch.float32": torch.float32, } flag = os.getenv("SENSITIVE_LAYER_DTYPE", "None") if flag == "None": From 969b41a872fa505f300076d31eb397b801e7ffc1 Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 16:38:07 +0800 Subject: [PATCH 3/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- .../ops/attn/ascend_npu/rainfusion_blockwise.py | 16 ++++++++-------- .../ops/norm/ascend_npu/npu_rms_norm.py | 4 +++- .../ops/rope/ascend_npu/npu_rope.py | 10 +++++----- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index c8dd5d027..7f10b156e 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -126,8 +126,11 @@ def rearrange_with_remaining(self, tensor): # BSND in , BSND out tensor_hwt, tensor_w_r = torch.split(tensor_hwt, w - (w % 8), dim=3) tensor_w_r = tensor_w_r.reshape(b * n, -1, d) w_res_len = tensor_w_r.shape[1] + remaining_frames = self.frame_num - first_frame_num + if remaining_frames % 2 != 0: + raise ValueError(f"The number of remaining frames ({remaining_frames}) must be even to be rearranged with block size 2.") tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d', - fn=(self.frame_num - first_frame_num) // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if h % 8 != 0: tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) if w % 8 != 0: @@ -220,11 +223,11 @@ def forward( sparsity = self.sparsity h_res_len, w_res_len = 0, 0 + qkv = torch.cat((q, k, v), dim=0) + qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) + q, k, v = torch.chunk(qkv, 3, dim=0) + if base_blockmask is None: - qkv = torch.cat((q, k, v), dim=0) - qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) - q, k, v = torch.chunk(qkv, 3, dim=0) - # qkv_pool = self.do_tensor_pooling(qkv) query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0) attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale @@ -235,9 +238,6 @@ def forward( else: selectIdx = base_blockmask[0] selectNumIdx = base_blockmask[1] - qkv = torch.cat((q, k, v), dim=0) - qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) - q, k, v = torch.chunk(qkv, 3, dim=0) q_bnsd = q.transpose(1, 2) k_bnsd = k.transpose(1, 2) diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py index 2f94674f9..fcac7669e 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py @@ -25,4 +25,6 @@ def apply(self, input_tensor): output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps) return output_tensor x = self._norm(input_tensor.float()) - return (x.float() * weight.float()).to(self.infer_dtype) \ No newline at end of file + if weight is not None: + return (x.float() * weight.float()).to(self.infer_dtype) + return x.to(self.infer_dtype) \ No newline at end of file diff --git a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py index a4188d4b3..ec01dc50e 100644 --- a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py +++ b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py @@ -73,11 +73,11 @@ def apply(self, xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor) xq_rotated = torch_npu.npu_rotary_mul(xq_part, cos, sin, "interleave") xk_rotated = torch_npu.npu_rotary_mul(xk_part, cos, sin, "interleave") if s > seq_len: - xq = torch.cat([xq_rotated, xq[seq_len:]], dim=0) - xk = torch.cat([xk_rotated, xk[seq_len:]], dim=0) + xq = torch.cat([xq_rotated.to(self.infer_dtype), xq[seq_len:]], dim=0) + xk = torch.cat([xk_rotated.to(self.infer_dtype), xk[seq_len:]], dim=0) else: - xq = xq_rotated - xk = xk_rotated - return xq.to(self.infer_dtype), xk.to(self.infer_dtype) + xq = xq_rotated.to(self.infer_dtype) + xk = xk_rotated.to(self.infer_dtype) + return xq, xk return self._apply_rope_fp32(xq, xk, cos_sin_cache) From 525d729c05c7052e571f6fd90531a51a1b57179c Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 17:15:16 +0800 Subject: [PATCH 4/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- .../attn/ascend_npu/rainfusion_blockwise.py | 22 ++++++++----------- .../ops/rope/ascend_npu/npu_rope.py | 4 ++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index 7f10b156e..d0267f829 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -55,22 +55,15 @@ def avgpool(self, input_tensor, pool_size=128): # BSND in, BSND out full_blocks_reshaped = full_blocks.view(batch, num_full_blocks, pool_size, headnum, dim) full_pooled = full_blocks_reshaped.mean(dim=2) else: - full_pooled = torch.empty(0, device=input_tensor.device) + full_pooled = torch.empty(batch, 0, headnum, dim, device=input_tensor.device) if tail_size > 0: tail_block = input_tensor[:, num_full_blocks * pool_size:, :, :] tail_reshaped = tail_block.view(batch, 1, tail_size, headnum, dim) tail_pooled = tail_reshaped.mean(dim=2) else: - tail_pooled = torch.empty(0, device=input_tensor.device) + tail_pooled = torch.empty(batch, 0, headnum, dim, device=input_tensor.device) - if num_full_blocks > 0 and tail_size > 0: - output_tensor = torch.cat([full_pooled, tail_pooled], dim=1) - elif num_full_blocks > 0: - output_tensor = full_pooled - else: - output_tensor = tail_pooled - - return output_tensor + return torch.cat([full_pooled, tail_pooled], dim=1) def get_mask_index(self, mask): B, N, S, _ = mask.shape @@ -95,6 +88,7 @@ def get_blockwise_mask(self, score_matrix, sparsity): batch_size, num_heads, rows, cols = score_matrix.shape keep_len = math.ceil(cols * (1 - sparsity)) + keep_len = max(1, min(keep_len, cols)) topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1) thresholds = topk_values[..., -1:] mask = score_matrix >= thresholds @@ -143,16 +137,18 @@ def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in h = self.grid_size[1] w = self.grid_size[2] h_sr, w_sr = h % 8, w % 8 + first_frame_num = self.first_frame_len // (h * w) + remaining_frames = self.frame_num - first_frame_num tensor = rearrange(tensor, 'b s n d->(b n) s d', b=b, n=n) tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) tensor_hwt = rearrange(tensor_hwt, 'b (fn hn wn fb hb wb) d -> b (fn fb) (hn hb) (wn wb) d', - fn=(self.frame_num - 1) // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if w_res_len != 0: - tensor_w = tensor_w.reshape(b * n, self.frame_num - 1, h - h_sr, w_sr, d) + tensor_w = tensor_w.reshape(b * n, remaining_frames, h - h_sr, w_sr, d) tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) if h_sr != 0: - tensor_h = tensor_h.reshape(b * n, self.frame_num - 1, h_sr, w, d) + tensor_h = tensor_h.reshape(b * n, remaining_frames, h_sr, w, d) tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2) tensor_hwt = tensor_hwt.reshape(b * n, -1, d) tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) diff --git a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py index ec01dc50e..56daaf4fc 100644 --- a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py +++ b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py @@ -1,4 +1,5 @@ import torch +from functools import lru_cache from lightx2v_platform.ops.rope.rope_template import GET_DTYPE, RopeTemplate from lightx2v_platform.registry_factory import PLATFORM_ROPE_REGISTER @@ -9,6 +10,7 @@ torch_npu = None +@lru_cache(maxsize=None) def GET_SENSITIVE_DTYPE(): import os DTYPE_MAP = { @@ -25,6 +27,8 @@ def GET_SENSITIVE_DTYPE(): flag = os.getenv("SENSITIVE_LAYER_DTYPE", "None") if flag == "None": return GET_DTYPE() + if flag not in DTYPE_MAP: + raise ValueError(f"Unsupported SENSITIVE_LAYER_DTYPE: {flag}. Expected one of {list(DTYPE_MAP.keys())}") return DTYPE_MAP[flag] From 06dec8d2cce0688cb73a09f13e50e94fb7e01c8f Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 17:53:23 +0800 Subject: [PATCH 5/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- .../attn/ascend_npu/rainfusion_blockwise.py | 95 ++++++++----------- .../ops/norm/ascend_npu/__init__.py | 4 +- .../ops/norm/ascend_npu/npu_layer_norm.py | 15 ++- .../ops/norm/ascend_npu/npu_rms_norm.py | 6 +- .../ops/rope/ascend_npu/npu_rope.py | 20 ++-- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index d0267f829..277132ffe 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -1,24 +1,21 @@ -import os import math + +import mindiesd import torch -import torch_npu import torch.nn as nn -import torch.nn.functional as F from einops import rearrange -import time -import mindiesd from mindiesd.layers.flash_attn.attention_forward import attention_forward class Rainfusion_blockwise(nn.Module): def __init__( - self, - grid_size: list, - pool_size: int = 128, - sparsity: float = 0.9, - skip_timesteps: int = 0, - txt_len: int = 0, - txt_first: bool = False, + self, + grid_size: list, + pool_size: int = 128, + sparsity: float = 0.9, + skip_timesteps: int = 0, + txt_len: int = 0, + txt_first: bool = False, ) -> None: """ 参数: @@ -50,20 +47,17 @@ def avgpool(self, input_tensor, pool_size=128): # BSND in, BSND out num_full_blocks = seqlen // pool_size tail_size = seqlen % pool_size + pooled_tensors = [] if num_full_blocks > 0: - full_blocks = input_tensor[:, :num_full_blocks * pool_size, :, :] + full_blocks = input_tensor[:, : num_full_blocks * pool_size, :, :] full_blocks_reshaped = full_blocks.view(batch, num_full_blocks, pool_size, headnum, dim) - full_pooled = full_blocks_reshaped.mean(dim=2) - else: - full_pooled = torch.empty(batch, 0, headnum, dim, device=input_tensor.device) + pooled_tensors.append(full_blocks_reshaped.mean(dim=2)) if tail_size > 0: - tail_block = input_tensor[:, num_full_blocks * pool_size:, :, :] + tail_block = input_tensor[:, num_full_blocks * pool_size :, :, :] tail_reshaped = tail_block.view(batch, 1, tail_size, headnum, dim) - tail_pooled = tail_reshaped.mean(dim=2) - else: - tail_pooled = torch.empty(batch, 0, headnum, dim, device=input_tensor.device) + pooled_tensors.append(tail_reshaped.mean(dim=2)) - return torch.cat([full_pooled, tail_pooled], dim=1) + return torch.cat(pooled_tensors, dim=1) def get_mask_index(self, mask): B, N, S, _ = mask.shape @@ -74,7 +68,7 @@ def get_mask_index(self, mask): batch_size = mask_reshaped.shape[0] # 2. 生成行索引 标记False位置为S(大于所有有效索引) - row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S) + row_indices = torch.arange(S, device=device).view(1, 1, S).expand(batch_size, S, S) # (B*N, S, S) sorted_vals = torch.where(mask_reshaped, row_indices, S) sorted_vals, _ = torch.sort(sorted_vals, dim=-1) valid_count = mask_reshaped.sum(dim=-1, keepdim=True) @@ -111,7 +105,7 @@ def rearrange_with_remaining(self, tensor): # BSND in , BSND out h_res_len, w_res_len = 0, 0 first_frame_num = self.first_frame_len // h // w - tensor_hwt = rearrange(tensor, 'b (f h w) n d -> (b n) f h w d', f=self.frame_num - first_frame_num, h=h, w=w) + tensor_hwt = rearrange(tensor, "b (f h w) n d -> (b n) f h w d", f=self.frame_num - first_frame_num, h=h, w=w) if h % 8 != 0: tensor_hwt, tensor_h_r = torch.split(tensor_hwt, h - (h % 8), dim=2) tensor_h_r = tensor_h_r.reshape(b * n, -1, d) @@ -123,13 +117,12 @@ def rearrange_with_remaining(self, tensor): # BSND in , BSND out remaining_frames = self.frame_num - first_frame_num if remaining_frames % 2 != 0: raise ValueError(f"The number of remaining frames ({remaining_frames}) must be even to be rearranged with block size 2.") - tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d', - fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + tensor_hwt = rearrange(tensor_hwt, "bn (fn fb) (hn hb) (wn wb) d -> bn (fn hn wn fb hb wb) d", fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if h % 8 != 0: tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) if w % 8 != 0: tensor_hwt = torch.cat((tensor_hwt, tensor_w_r), dim=1) - tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) + tensor_hwt = rearrange(tensor_hwt, "(b n) s d -> b s n d", b=b, n=n) return tensor_hwt, h_res_len, w_res_len def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in , BSND out @@ -140,10 +133,9 @@ def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in first_frame_num = self.first_frame_len // (h * w) remaining_frames = self.frame_num - first_frame_num - tensor = rearrange(tensor, 'b s n d->(b n) s d', b=b, n=n) + tensor = rearrange(tensor, "b s n d->(b n) s d", b=b, n=n) tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) - tensor_hwt = rearrange(tensor_hwt, 'b (fn hn wn fb hb wb) d -> b (fn fb) (hn hb) (wn wb) d', - fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) + tensor_hwt = rearrange(tensor_hwt, "bn (fn hn wn fb hb wb) d -> bn (fn fb) (hn hb) (wn wb) d", fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) if w_res_len != 0: tensor_w = tensor_w.reshape(b * n, remaining_frames, h - h_sr, w_sr, d) tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) @@ -151,18 +143,17 @@ def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in tensor_h = tensor_h.reshape(b * n, remaining_frames, h_sr, w, d) tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2) tensor_hwt = tensor_hwt.reshape(b * n, -1, d) - tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) + tensor_hwt = rearrange(tensor_hwt, "(b n) s d -> b s n d", b=b, n=n) return tensor_hwt def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out b, s, n, d = tensor.shape + if s < self.text_len + self.first_frame_len: + raise ValueError(f"Sequence length {s} is too small for text_len {self.text_len} and first_frame_len {self.first_frame_len}") if self.txt_first: - tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, - s - self.text_len - self.first_frame_len], dim=1) + tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, s - self.text_len - self.first_frame_len], dim=1) else: - tensor_f, tensor_i, tensor_t = torch.split(tensor, - [self.first_frame_len, s - self.text_len - self.first_frame_len, - self.text_len], dim=1) + tensor_f, tensor_i, tensor_t = torch.split(tensor, [self.first_frame_len, s - self.text_len - self.first_frame_len, self.text_len], dim=1) tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i) tensor = torch.concat((tensor_i_2, tensor_f, tensor_t), dim=1) tensor_pool = self.avgpool(tensor, pool_size=128) @@ -170,9 +161,7 @@ def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): b, s, n, d = tensor.shape - tensor_i, tensor_f, tensor_t = torch.split(tensor, - [s - self.text_len - self.first_frame_len, self.first_frame_len, - self.text_len], dim=1) + tensor_i, tensor_f, tensor_t = torch.split(tensor, [s - self.text_len - self.first_frame_len, self.first_frame_len, self.text_len], dim=1) tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) if self.txt_first: @@ -182,8 +171,8 @@ def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): return tensor def do_tensor_pooling(self, tensor): - tensor_t = tensor[:, :self.text_len, :, :] - tensor_i = tensor[:, self.text_len:, :, :] + tensor_t = tensor[:, : self.text_len, :, :] + tensor_i = tensor[:, self.text_len :, :, :] tensor_i_pool = self.avgpool(tensor_i, pool_size=128) tensor_t_pool = self.avgpool(tensor_t, pool_size=128) @@ -192,25 +181,24 @@ def do_tensor_pooling(self, tensor): return tensor_pool def forward( - self, - q, # BSND - k, - v, - t_b_idx, - base_blockmask, + self, + q, # BSND + k, + v, + t_b_idx, + base_blockmask, ): t_idx = t_b_idx[0] if t_idx < self.skip_timesteps: base_blockmask = None - x = attention_forward(q, k, v, - opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") + x = attention_forward(q, k, v, opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") else: batch, qSeqlen, numHeads, headDim = q.shape assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1." _, kvSeqlen, _, _ = k.shape blockShapeX, blockShapeY = self.pool_size, self.pool_size - scale = headDim ** -0.5 + scale = headDim**-0.5 blockShape = [blockShapeX, blockShapeY] actualSeqLengthsHost = [qSeqlen for _ in range(batch)] @@ -239,7 +227,9 @@ def forward( k_bnsd = k.transpose(1, 2) v_bnsd = v.transpose(1, 2) x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention( - q_bnsd, k_bnsd, v_bnsd, + q_bnsd, + k_bnsd, + v_bnsd, scale=scale, head_num=numHeads, input_layout="BNSD", @@ -247,12 +237,11 @@ def forward( select_num_idx=selectNumIdx, blockshape=blockShape, actual_seq_lengths=actualSeqLengthsHost, - actual_seq_lengths_kv=actualSeqLengthsKvHost + actual_seq_lengths_kv=actualSeqLengthsKvHost, ) x = x.transpose(1, 2).view(batch, qSeqlen, numHeads, headDim) x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len) - base_blockmask = None - return x, base_blockmask \ No newline at end of file + return x, base_blockmask diff --git a/lightx2v_platform/ops/norm/ascend_npu/__init__.py b/lightx2v_platform/ops/norm/ascend_npu/__init__.py index e522b460b..91fa05c54 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/__init__.py +++ b/lightx2v_platform/ops/norm/ascend_npu/__init__.py @@ -1,4 +1,4 @@ -from .npu_rms_norm import NpuRmsNormWeight from .npu_layer_norm import NpuLayerNormWeight +from .npu_rms_norm import NpuRmsNormWeight -__all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"] \ No newline at end of file +__all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"] diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py index 465424d61..136c9a19e 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py @@ -15,7 +15,18 @@ def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, c super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) def apply(self, input_tensor): - if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None: + if torch_npu is not None and hasattr(torch_npu, "npu_layer_norm") and self.weight is not None and self.bias is not None: + if self.sensitive_layer_dtype != self.infer_dtype: + out = torch_npu.npu_layer_norm( + input_tensor.to(self.sensitive_layer_dtype), + (input_tensor.shape[-1],), + self.weight.to(self.sensitive_layer_dtype), + self.bias.to(self.sensitive_layer_dtype), + self.eps, + ) + if isinstance(out, tuple): + out = out[0] + return out.to(self.infer_dtype) out = torch_npu.npu_layer_norm( input_tensor, (input_tensor.shape[-1],), @@ -33,4 +44,4 @@ def apply(self, input_tensor): self.bias, self.eps, ) - return output \ No newline at end of file + return output diff --git a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py index fcac7669e..d2e78a2fd 100644 --- a/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py +++ b/lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py @@ -19,7 +19,9 @@ def apply(self, input_tensor): if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None: if self.sensitive_layer_dtype != self.infer_dtype: output_tensor, _ = torch_npu.npu_rms_norm( - input_tensor.float(), weight.float(), self.eps, + input_tensor.to(self.sensitive_layer_dtype), + weight.to(self.sensitive_layer_dtype), + self.eps, ) return output_tensor.to(self.infer_dtype) output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps) @@ -27,4 +29,4 @@ def apply(self, input_tensor): x = self._norm(input_tensor.float()) if weight is not None: return (x.float() * weight.float()).to(self.infer_dtype) - return x.to(self.infer_dtype) \ No newline at end of file + return x.to(self.infer_dtype) diff --git a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py index 56daaf4fc..b08116811 100644 --- a/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py +++ b/lightx2v_platform/ops/rope/ascend_npu/npu_rope.py @@ -1,6 +1,7 @@ -import torch from functools import lru_cache +import torch + from lightx2v_platform.ops.rope.rope_template import GET_DTYPE, RopeTemplate from lightx2v_platform.registry_factory import PLATFORM_ROPE_REGISTER @@ -13,6 +14,7 @@ @lru_cache(maxsize=None) def GET_SENSITIVE_DTYPE(): import os + DTYPE_MAP = { "BF16": torch.bfloat16, "FP16": torch.float16, @@ -41,12 +43,8 @@ def __init__(self): def _apply_rope_fp32(self, xq, xk, cos_sin_cache): n = xq.size(1) seq_len = cos_sin_cache.size(0) - xq_fp32 = torch.view_as_complex( - xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2) - ) - xk_fp32 = torch.view_as_complex( - xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2) - ) + xq_fp32 = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2)) + xk_fp32 = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2)) xq_rot = torch.view_as_real(xq_fp32 * cos_sin_cache).flatten(2) xk_rot = torch.view_as_real(xk_fp32 * cos_sin_cache).flatten(2) if xq.size(0) > seq_len: @@ -66,10 +64,10 @@ def apply(self, xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor) if torch_npu is not None and hasattr(torch_npu, "npu_rotary_mul"): if self.sensitive_layer_dtype != self.infer_dtype: - xq_part = xq_part.float() - xk_part = xk_part.float() - cos = cos.float() if cos.dtype != torch.float32 else cos - sin = sin.float() if sin.dtype != torch.float32 else sin + xq_part = xq_part.to(self.sensitive_layer_dtype) + xk_part = xk_part.to(self.sensitive_layer_dtype) + cos = cos.to(self.sensitive_layer_dtype) + sin = sin.to(self.sensitive_layer_dtype) if not xq_part.is_contiguous(): xq_part = xq_part.contiguous() if not xk_part.is_contiguous(): From bdd48bb6845d5b3f897b9620e7971d87df28912a Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 18:04:37 +0800 Subject: [PATCH 6/7] Update lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index 277132ffe..b9dfd619f 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -240,7 +240,7 @@ def forward( actual_seq_lengths_kv=actualSeqLengthsKvHost, ) - x = x.transpose(1, 2).view(batch, qSeqlen, numHeads, headDim) + x = x.transpose(1, 2).reshape(batch, qSeqlen, numHeads, headDim) x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len) From b6d2a638ca4e23fc20f87018e30ed40654e0cc0d Mon Sep 17 00:00:00 2001 From: liaopingbo Date: Tue, 16 Jun 2026 18:05:29 +0800 Subject: [PATCH 7/7] feat(npu): Add 4 Ascend NPU custom operators for sparse video inference --- .../attn/ascend_npu/rainfusion_blockwise.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py index b9dfd619f..ec9077942 100644 --- a/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py +++ b/lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py @@ -87,7 +87,10 @@ def get_blockwise_mask(self, score_matrix, sparsity): thresholds = topk_values[..., -1:] mask = score_matrix >= thresholds - protect_len = (self.first_frame_len + self.text_len + self.pool_size - 1) // self.pool_size + total_len = self.frame_num * self.num_tokens_per_frame + self.text_len + protected_start_idx = total_len - self.first_frame_len - self.text_len + protected_start_block = protected_start_idx // self.pool_size + protect_len = cols - protected_start_block if protect_len > 0: mask[:, :, -protect_len:, :] = True @@ -155,7 +158,7 @@ def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out else: tensor_f, tensor_i, tensor_t = torch.split(tensor, [self.first_frame_len, s - self.text_len - self.first_frame_len, self.text_len], dim=1) tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i) - tensor = torch.concat((tensor_i_2, tensor_f, tensor_t), dim=1) + tensor = torch.cat((tensor_i_2, tensor_f, tensor_t), dim=1) tensor_pool = self.avgpool(tensor, pool_size=128) return tensor, tensor_pool, h_res_len, w_res_len @@ -165,21 +168,11 @@ def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) if self.txt_first: - tensor = torch.concat((tensor_t, tensor_f, tensor_i), dim=1) + tensor = torch.cat((tensor_t, tensor_f, tensor_i), dim=1) else: - tensor = torch.concat((tensor_f, tensor_i, tensor_t), dim=1) + tensor = torch.cat((tensor_f, tensor_i, tensor_t), dim=1) return tensor - def do_tensor_pooling(self, tensor): - tensor_t = tensor[:, : self.text_len, :, :] - tensor_i = tensor[:, self.text_len :, :, :] - - tensor_i_pool = self.avgpool(tensor_i, pool_size=128) - tensor_t_pool = self.avgpool(tensor_t, pool_size=128) - - tensor_pool = torch.concat((tensor_t_pool, tensor_i_pool), dim=1) - return tensor_pool - def forward( self, q, # BSND