Skip to content
2 changes: 2 additions & 0 deletions lightx2v_platform/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
elif PLATFORM == "ascend_npu":
from .attn.ascend_npu import *
from .mm.ascend_npu import *
from .norm.ascend_npu import *
from .rope.ascend_npu import *
elif PLATFORM == "metax_cuda":
from .attn.metax_cuda import *
elif PLATFORM == "enflame_gcu":
Expand Down
240 changes: 240 additions & 0 deletions lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import math

import mindiesd
import torch
import torch.nn as nn
from einops import rearrange
from mindiesd.layers.flash_attn.attention_forward import attention_forward


class Rainfusion_blockwise(nn.Module):
def __init__(
self,
grid_size: list,
pool_size: int = 128,
sparsity: float = 0.9,
skip_timesteps: int = 0,
txt_len: int = 0,
txt_first: bool = False,
) -> None:
"""
参数:
grid_size (list): latents的THW网格大小。
sparsity (float, optional): 稀疏度, 取值范围[0, 1],默认为 0.50。
"""
super().__init__()

# Rainfusion_param
self.grid_size = grid_size
self.frame_num = self.grid_size[0]
self.num_tokens_per_frame = self.grid_size[1] * self.grid_size[2]
self.first_frame_len = self.num_tokens_per_frame

self.pool_size = pool_size
self.sparsity = sparsity
self.skip_timesteps = skip_timesteps
self.text_len = txt_len
self.txt_first = txt_first

@staticmethod
def get_grid_size(latent_size, patch_size):
t, h, w = latent_size[-3:]
return [t // patch_size[0], h // patch_size[1], w // patch_size[2]]

def avgpool(self, input_tensor, pool_size=128):  # BSND in, BSND out
    """
    Average-pool a BSND tensor along the sequence axis in blocks of ``pool_size``.

    The sequence is cut into consecutive blocks of ``pool_size`` tokens; a
    shorter trailing block (if any) is averaged on its own.

    Args:
        input_tensor: 4-D tensor laid out as (batch, seq, heads, dim).
        pool_size (int): Number of tokens averaged per block; must be >= 1.

    Returns:
        Tensor of shape (batch, ceil(seq / pool_size), heads, dim). An input
        with a zero-length sequence yields an empty tensor of shape
        (batch, 0, heads, dim).

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """
    if pool_size < 1:
        raise ValueError(f"pool_size must be a positive integer, got {pool_size}")

    batch, seqlen, headnum, dim = input_tensor.shape
    if seqlen == 0:
        # Nothing to pool; torch.cat would fail on an empty list of pieces.
        return input_tensor.new_empty((batch, 0, headnum, dim))

    num_full_blocks, tail_size = divmod(seqlen, pool_size)
    full_len = num_full_blocks * pool_size

    pooled_tensors = []
    if num_full_blocks > 0:
        # reshape (rather than view) keeps this safe for any memory layout.
        full_blocks = input_tensor[:, :full_len].reshape(batch, num_full_blocks, pool_size, headnum, dim)
        pooled_tensors.append(full_blocks.mean(dim=2))
    if tail_size > 0:
        # The leftover tokens form one extra, shorter block.
        pooled_tensors.append(input_tensor[:, full_len:].mean(dim=1, keepdim=True))

    return torch.cat(pooled_tensors, dim=1)

def get_mask_index(self, mask):
    """
    Convert a boolean block mask into left-aligned index lists.

    For every row of ``mask`` the positions holding True are listed in
    ascending order at the start of the row; the remaining slots are
    filled with -1.

    Args:
        mask: Boolean tensor of shape (B, N, S, S).

    Returns:
        Integer tensor of shape (B, N, S, S) with the selected column
        indices per row, padded with -1.
    """
    B, N, S, _ = mask.shape

    # Fold batch and heads together -> (B*N, S, S)
    flat_mask = mask.reshape(-1, S, S)
    column_ids = torch.arange(S, device=mask.device).view(1, 1, S).expand(flat_mask.shape[0], S, S)

    # Unselected positions get the sentinel S, which sorts after every real index.
    candidates = torch.where(flat_mask, column_ids, S)
    ordered = torch.sort(candidates, dim=-1)[0]

    # Only the first `count` slots of each row hold real indices.
    selected_per_row = flat_mask.sum(dim=-1, keepdim=True)
    in_use = column_ids < selected_per_row
    packed = torch.where(in_use, ordered, -1)

    return packed.reshape(B, N, S, S)

def get_blockwise_mask(self, score_matrix, sparsity):
batch_size, num_heads, rows, cols = score_matrix.shape

keep_len = math.ceil(cols * (1 - sparsity))
keep_len = max(1, min(keep_len, cols))
topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1)
thresholds = topk_values[..., -1:]
mask = score_matrix >= thresholds

total_len = self.frame_num * self.num_tokens_per_frame + self.text_len
protected_start_idx = total_len - self.first_frame_len - self.text_len
protected_start_block = protected_start_idx // self.pool_size
protect_len = cols - protected_start_block

if protect_len > 0:
mask[:, :, -protect_len:, :] = True
mask[:, :, :, -protect_len:] = True

selectIdx = self.get_mask_index(mask)
selectIdx = selectIdx[0].transpose(0, 1)
selectNumIdx = mask[0].transpose(0, 1).sum(dim=-1)
return selectIdx, selectNumIdx

def rearrange_with_remaining(self, tensor):  # BSND in , BSND out
    """
    Reorder image tokens into 2x8x8 (frame, height, width) blocks.

    The token grid is cropped to multiples of 8 in height and width; the
    cropped core is regrouped so that each 2x8x8 block is contiguous in
    the sequence. The cropped-off height rows and width columns are
    appended after the core, height residual first.

    Args:
        tensor: BSND tensor whose sequence covers every frame except the
            first one, in (frame, height, width) order.

    Returns:
        (tensor, h_res_len, w_res_len): the reordered BSND tensor and the
        sequence lengths of the appended height / width residuals
        (0 when the corresponding dimension is a multiple of 8).

    Raises:
        ValueError: If the number of non-first frames is odd.
    """
    b, _, n, d = tensor.shape
    height, width = self.grid_size[1], self.grid_size[2]
    h_tail, w_tail = height % 8, width % 8
    body_frames = self.frame_num - self.first_frame_len // height // width
    h_res_len = w_res_len = 0

    grid = rearrange(tensor, "b (f h w) n d -> (b n) f h w d", f=body_frames, h=height, w=width)

    # Peel off the rows / columns that do not fill a complete 8-wide block.
    if h_tail != 0:
        grid, h_residual = torch.split(grid, height - h_tail, dim=2)
        h_residual = h_residual.reshape(b * n, -1, d)
        h_res_len = h_residual.shape[1]
    if w_tail != 0:
        grid, w_residual = torch.split(grid, width - w_tail, dim=3)
        w_residual = w_residual.reshape(b * n, -1, d)
        w_res_len = w_residual.shape[1]

    if body_frames % 2 != 0:
        raise ValueError(f"The number of remaining frames ({body_frames}) must be even to be rearranged with block size 2.")

    blocked = rearrange(
        grid,
        "bn (fn fb) (hn hb) (wn wb) d -> bn (fn hn wn fb hb wb) d",
        fn=body_frames // 2,
        fb=2,
        hb=8,
        wb=8,
        hn=height // 8,
        wn=width // 8,
    )

    # Sequence order: blocked core, then height residual, then width residual.
    pieces = [blocked]
    if h_tail != 0:
        pieces.append(h_residual)
    if w_tail != 0:
        pieces.append(w_residual)
    merged = torch.cat(pieces, dim=1)

    return rearrange(merged, "(b n) s d -> b s n d", b=b, n=n), h_res_len, w_res_len

def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len):  # BSND in , BSND out
    """
    Undo ``rearrange_with_remaining``.

    Splits the sequence into the 2x8x8-blocked core and the height / width
    residuals, restores the core to (frame, height, width) order and
    stitches the residual columns and rows back onto the grid.

    Args:
        tensor: BSND tensor in the blocked order.
        h_res_len (int): Sequence length of the height residual.
        w_res_len (int): Sequence length of the width residual.

    Returns:
        BSND tensor with tokens back in (frame, height, width) order.
    """
    b, s, n, d = tensor.shape
    height, width = self.grid_size[1], self.grid_size[2]
    h_tail, w_tail = height % 8, width % 8
    body_frames = self.frame_num - self.first_frame_len // (height * width)

    flat = rearrange(tensor, "b s n d->(b n) s d", b=b, n=n)
    core_len = s - h_res_len - w_res_len
    grid, h_residual, w_residual = torch.split(flat, [core_len, h_res_len, w_res_len], dim=1)

    grid = rearrange(
        grid,
        "bn (fn hn wn fb hb wb) d -> bn (fn fb) (hn hb) (wn wb) d",
        fn=body_frames // 2,
        fb=2,
        hb=8,
        wb=8,
        hn=height // 8,
        wn=width // 8,
    )

    # Columns go back first: the width residual only spans the cropped height.
    if w_res_len != 0:
        w_residual = w_residual.reshape(b * n, body_frames, height - h_tail, w_tail, d)
        grid = torch.cat((grid, w_residual), dim=3)
    if h_tail != 0:
        h_residual = h_residual.reshape(b * n, body_frames, h_tail, width, d)
        grid = torch.cat((grid, h_residual), dim=2)

    sequence = grid.reshape(b * n, -1, d)
    return rearrange(sequence, "(b n) s d -> b s n d", b=b, n=n)

def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out
b, s, n, d = tensor.shape
if s < self.text_len + self.first_frame_len:
raise ValueError(f"Sequence length {s} is too small for text_len {self.text_len} and first_frame_len {self.first_frame_len}")
if self.txt_first:
tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, s - self.text_len - self.first_frame_len], dim=1)
else:
tensor_f, tensor_i, tensor_t = torch.split(tensor, [self.first_frame_len, s - self.text_len - self.first_frame_len, self.text_len], dim=1)
tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i)
tensor = torch.cat((tensor_i_2, tensor_f, tensor_t), dim=1)
tensor_pool = self.avgpool(tensor, pool_size=128)
return tensor, tensor_pool, h_res_len, w_res_len

def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len):
    """
    Restore the original token order after block-wise attention.

    The input is expected in [image, first frame, text] order. The image
    part is passed through ``inv_rearrange_with_remaining`` and the three
    parts are concatenated back in the order given by ``self.txt_first``.

    Args:
        tensor: BSND tensor in the reordered layout.
        h_res_len (int): Height residual length from the forward reordering.
        w_res_len (int): Width residual length from the forward reordering.

    Returns:
        BSND tensor in the original sequence order.
    """
    _, seq_len, _, _ = tensor.shape
    image_len = seq_len - self.text_len - self.first_frame_len
    image_part, first_frame_part, text_part = torch.split(tensor, [image_len, self.first_frame_len, self.text_len], dim=1)
    image_part = self.inv_rearrange_with_remaining(image_part, h_res_len, w_res_len)

    if self.txt_first:
        ordered = (text_part, first_frame_part, image_part)
    else:
        ordered = (first_frame_part, image_part, text_part)
    return torch.cat(ordered, dim=1)

def forward(
self,
q, # BSND
k,
v,
t_b_idx,
base_blockmask,
):
t_idx = t_b_idx[0]

if t_idx < self.skip_timesteps:
base_blockmask = None
x = attention_forward(q, k, v, opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD")
else:
batch, qSeqlen, numHeads, headDim = q.shape
assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using assert statements for runtime input validation is discouraged because they can be optimized away when Python is run with the -O (optimize) flag. This would bypass the batch size check and could lead to unexpected behavior or crashes later in the execution. Please replace this with an explicit if check and raise a ValueError.

Suggested change
assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1."
if batch != 1:
raise ValueError("Rainfusion_blockwise currently only supports batch size 1.")

_, kvSeqlen, _, _ = k.shape
blockShapeX, blockShapeY = self.pool_size, self.pool_size
scale = headDim**-0.5

blockShape = [blockShapeX, blockShapeY]
actualSeqLengthsHost = [qSeqlen for _ in range(batch)]
actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)]

sparsity = self.sparsity

h_res_len, w_res_len = 0, 0
qkv = torch.cat((q, k, v), dim=0)
qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv)
q, k, v = torch.chunk(qkv, 3, dim=0)

if base_blockmask is None:
query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0)

attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale
attn_scores_fake = torch.nn.functional.softmax(attn_scores_head, dim=-1)

selectIdx, selectNumIdx = self.get_blockwise_mask(attn_scores_fake, sparsity)
base_blockmask = [selectIdx, selectNumIdx]
else:
selectIdx = base_blockmask[0]
selectNumIdx = base_blockmask[1]

q_bnsd = q.transpose(1, 2)
k_bnsd = k.transpose(1, 2)
v_bnsd = v.transpose(1, 2)
x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention(
q_bnsd,
k_bnsd,
v_bnsd,
scale=scale,
head_num=numHeads,
input_layout="BNSD",
select_idx=selectIdx,
select_num_idx=selectNumIdx,
blockshape=blockShape,
actual_seq_lengths=actualSeqLengthsHost,
actual_seq_lengths_kv=actualSeqLengthsKvHost,
)

x = x.transpose(1, 2).reshape(batch, qSeqlen, numHeads, headDim)
Comment thread
liaopingbo marked this conversation as resolved.

x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len)

return x, base_blockmask
4 changes: 4 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .npu_layer_norm import NpuLayerNormWeight
from .npu_rms_norm import NpuRmsNormWeight

__all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"]
47 changes: 47 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch

from lightx2v_platform.ops.norm.norm_template import LayerNormWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_LAYERNORM_WEIGHT_REGISTER

try:
import torch_npu
except ImportError:
torch_npu = None


@PLATFORM_LAYERNORM_WEIGHT_REGISTER("npu_layer_norm")
class NpuLayerNormWeight(LayerNormWeightTemplate):
    """
    LayerNorm weight registered under the name "npu_layer_norm".

    ``apply`` uses ``torch_npu.npu_layer_norm`` when torch_npu exposes it and
    both weight and bias are present; otherwise it falls back to
    ``torch.nn.functional.layer_norm``.
    """

    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6, **kwargs):
        # Extra keyword arguments are accepted but not passed on to the template.
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        """Layer-normalize ``input_tensor`` over its last dimension."""
        normalized_shape = (input_tensor.shape[-1],)
        npu_available = torch_npu is not None and hasattr(torch_npu, "npu_layer_norm")
        has_affine = self.weight is not None and self.bias is not None

        if not (npu_available and has_affine):
            # Portable fallback; weight / bias may be None here.
            return torch.nn.functional.layer_norm(
                input_tensor,
                normalized_shape,
                self.weight,
                self.bias,
                self.eps,
            )

        # When the sensitive dtype differs, compute in it and cast the result back.
        needs_cast = self.sensitive_layer_dtype != self.infer_dtype
        if needs_cast:
            x = input_tensor.to(self.sensitive_layer_dtype)
            weight = self.weight.to(self.sensitive_layer_dtype)
            bias = self.bias.to(self.sensitive_layer_dtype)
        else:
            x, weight, bias = input_tensor, self.weight, self.bias

        out = torch_npu.npu_layer_norm(x, normalized_shape, weight, bias, self.eps)
        # Keep only the first element when the op returns a tuple.
        if isinstance(out, tuple):
            out = out[0]
        return out.to(self.infer_dtype) if needs_cast else out
32 changes: 32 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER

try:
import torch_npu
except ImportError:
torch_npu = None


@PLATFORM_RMS_WEIGHT_REGISTER("npu_rms_norm")
class NpuRmsNormWeight(RMSWeightTemplate):
    """
    RMSNorm weight registered under the name "npu_rms_norm".

    ``apply`` uses ``torch_npu.npu_rms_norm`` when torch_npu exposes it and a
    weight is present; otherwise it computes the normalization in float32
    with plain torch ops.
    """

    def _norm(self, x):
        """Scale ``x`` by the reciprocal root-mean-square of its last dimension."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def apply(self, input_tensor):
        """RMS-normalize ``input_tensor`` over its last dimension."""
        weight = self.weight
        use_npu = torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None

        if not use_npu:
            # Fallback: normalize in float32, then cast to the inference dtype.
            normed = self._norm(input_tensor.float())
            if weight is None:
                return normed.to(self.infer_dtype)
            return (normed.float() * weight.float()).to(self.infer_dtype)

        if self.sensitive_layer_dtype == self.infer_dtype:
            output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps)
            return output_tensor

        # Compute in the sensitive dtype and cast the result back.
        output_tensor, _ = torch_npu.npu_rms_norm(
            input_tensor.to(self.sensitive_layer_dtype),
            weight.to(self.sensitive_layer_dtype),
            self.eps,
        )
        return output_tensor.to(self.infer_dtype)
3 changes: 3 additions & 0 deletions lightx2v_platform/ops/rope/ascend_npu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .npu_rope import NpuRope

__all__ = ["NpuRope"]
Loading
Loading