Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lightx2v_platform/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
elif PLATFORM == "ascend_npu":
from .attn.ascend_npu import *
from .mm.ascend_npu import *
from .norm.ascend_npu import *
from .rope.ascend_npu import *
elif PLATFORM == "metax_cuda":
from .attn.metax_cuda import *
elif PLATFORM == "enflame_gcu":
Expand Down
262 changes: 262 additions & 0 deletions lightx2v_platform/ops/attn/ascend_npu/rainfusion_blockwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import os
import math
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import time
import mindiesd
from mindiesd.layers.flash_attn.attention_forward import attention_forward


class Rainfusion_blockwise(nn.Module):
def __init__(
    self,
    grid_size: list,
    pool_size: int = 128,
    sparsity: float = 0.9,
    skip_timesteps: int = 0,
    txt_len: int = 0,
    txt_first: bool = False,
) -> None:
    """Store the latent grid geometry and the sparse-attention settings.

    Args:
        grid_size: THW grid size of the latents, indexed as
            ``[frames, height, width]``.
        pool_size: Block edge length; ``forward`` uses it for both entries
            of the attention block shape. Defaults to 128.
        sparsity: Sparsity ratio, expected in ``[0, 1]``. Defaults to 0.9.
        skip_timesteps: ``forward`` runs dense attention while the step
            index is below this value. Defaults to 0.
        txt_len: Length of the text segment of the sequence; stored as
            ``self.text_len``. Defaults to 0.
        txt_first: Whether the text segment sits at the start of the
            incoming sequence (otherwise it is expected at the end).
            Defaults to False.
    """
    super().__init__()

    # Sparse-attention settings, stored exactly as given.
    self.pool_size = pool_size
    self.sparsity = sparsity
    self.skip_timesteps = skip_timesteps
    self.text_len = txt_len
    self.txt_first = txt_first

    # Latent grid geometry (T, H, W).
    frames, height, width = grid_size[0], grid_size[1], grid_size[2]
    self.grid_size = grid_size
    self.frame_num = frames
    self.num_tokens_per_frame = height * width
    # The first-frame segment spans exactly one frame's worth of tokens.
    self.first_frame_len = self.num_tokens_per_frame

@staticmethod
def get_grid_size(latent_size, patch_size):
    """Return the ``[t, h, w]`` patch grid for a latent of the given size.

    The last three entries of ``latent_size`` are floor-divided, in order,
    by ``patch_size[0]``, ``patch_size[1]`` and ``patch_size[2]``.
    """
    frames, height, width = latent_size[-3:]
    return [
        frames // patch_size[0],
        height // patch_size[1],
        width // patch_size[2],
    ]

def avgpool(self, input_tensor, pool_size=128):  # BSND in, BSND out
    """Average-pool a BSND tensor along the sequence axis in fixed-size blocks.

    The sequence is cut into consecutive blocks of ``pool_size`` tokens and
    each block is replaced by its mean. A shorter trailing block, if any, is
    averaged over its own length and appended as the last pooled token.

    Args:
        input_tensor: 4-D tensor laid out as (batch, seq, heads, dim).
        pool_size: Number of tokens averaged into one pooled token.

    Returns:
        Tensor of shape (batch, ceil(seq / pool_size), heads, dim). For an
        empty sequence the result has shape (batch, 0, heads, dim) and keeps
        the input's dtype and device, so it can still be concatenated along
        the sequence axis.
    """
    batch, seqlen, headnum, dim = input_tensor.shape
    num_full_blocks, tail_size = divmod(seqlen, pool_size)

    pooled = []
    if num_full_blocks > 0:
        full_blocks = input_tensor[:, :num_full_blocks * pool_size]
        # reshape (not view) so strided slices of a larger tensor also work.
        full_blocks = full_blocks.reshape(batch, num_full_blocks, pool_size, headnum, dim)
        pooled.append(full_blocks.mean(dim=2))
    if tail_size > 0:
        tail_block = input_tensor[:, num_full_blocks * pool_size:]
        pooled.append(tail_block.mean(dim=1, keepdim=True))

    if not pooled:
        # Nothing to pool: return a correctly shaped and typed empty tensor
        # instead of a 1-D float32 placeholder.
        return input_tensor.new_empty((batch, 0, headnum, dim))
    if len(pooled) == 1:
        return pooled[0]
    return torch.cat(pooled, dim=1)
Comment thread
liaopingbo marked this conversation as resolved.
Outdated

def get_mask_index(self, mask):
B, N, S, _ = mask.shape
device = mask.device

# 1. 重塑维度 → (B*N)×S×S
mask_reshaped = mask.reshape(-1, S, S)
batch_size = mask_reshaped.shape[0]

# 2. 生成行索引 标记False位置为S(大于所有有效索引)
row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using .expand(batch_size, S, -1) on a 1D tensor of shape (S,) relies on implicit dimension prepending and can be fragile across different PyTorch versions or backends. Explicitly reshaping with .view(1, 1, S) before expanding is much more robust and readable.

Suggested change
row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S)
row_indices = torch.arange(S, device=device).view(1, 1, S).expand(batch_size, S, S) # (B*N, S, S)

sorted_vals = torch.where(mask_reshaped, row_indices, S)
sorted_vals, _ = torch.sort(sorted_vals, dim=-1)
valid_count = mask_reshaped.sum(dim=-1, keepdim=True)
keep_mask = row_indices < valid_count
result = torch.where(keep_mask, sorted_vals, -1)

pos_matrix = result.reshape(B, N, S, S)
return pos_matrix

def get_blockwise_mask(self, score_matrix, sparsity):
    """Pick the key blocks each query block keeps and return them as indices.

    For every row of ``score_matrix`` the ``ceil(cols * (1 - sparsity))``
    highest scores are kept (entries tied with the threshold are kept too).
    The trailing ``protect_len`` block rows and block columns are always
    kept, where ``protect_len`` is the number of ``pool_size`` blocks needed
    to hold ``first_frame_len + text_len`` tokens.

    Args:
        score_matrix: 4-D block scores (batch, heads, rows, cols).
            ``get_mask_index`` reshapes the mask as square, so rows must
            equal cols.
        sparsity: Fraction of columns to drop per row.

    Returns:
        ``(selectIdx, selectNumIdx)`` for batch entry 0 only. ``selectIdx``
        is the ``get_mask_index`` output with its head and row axes swapped;
        ``selectNumIdx`` is the number of kept columns per row in the same
        axis order.
    """
    cols = score_matrix.shape[-1]

    # Clamp to [1, cols]: sparsity == 1.0 would otherwise give k == 0 and an
    # empty threshold slice, and sparsity < 0 would ask topk for more than
    # cols entries.
    keep_len = min(max(math.ceil(cols * (1 - sparsity)), 1), cols)
    topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1)
    # The smallest of the top-k scores is the per-row keep threshold.
    thresholds = topk_values[..., -1:]
    mask = score_matrix >= thresholds

    protect_len = (self.first_frame_len + self.text_len + self.pool_size - 1) // self.pool_size

    if protect_len > 0:
        # Always keep the trailing blocks, both as queries and as keys.
        mask[:, :, -protect_len:, :] = True
        mask[:, :, :, -protect_len:] = True

    selectIdx = self.get_mask_index(mask)
    selectIdx = selectIdx[0].transpose(0, 1)
    selectNumIdx = mask[0].transpose(0, 1).sum(dim=-1)
    return selectIdx, selectNumIdx

def rearrange_with_remaining(self, tensor): # BSND in , BSND out
b, s, n, d = tensor.shape
h = self.grid_size[1]
w = self.grid_size[2]
h_res_len, w_res_len = 0, 0
first_frame_num = self.first_frame_len // h // w

tensor_hwt = rearrange(tensor, 'b (f h w) n d -> (b n) f h w d', f=self.frame_num - first_frame_num, h=h, w=w)
if h % 8 != 0:
tensor_hwt, tensor_h_r = torch.split(tensor_hwt, h - (h % 8), dim=2)
tensor_h_r = tensor_h_r.reshape(b * n, -1, d)
h_res_len = tensor_h_r.shape[1]
if w % 8 != 0:
tensor_hwt, tensor_w_r = torch.split(tensor_hwt, w - (w % 8), dim=3)
tensor_w_r = tensor_w_r.reshape(b * n, -1, d)
w_res_len = tensor_w_r.shape[1]
remaining_frames = self.frame_num - first_frame_num
if remaining_frames % 2 != 0:
raise ValueError(f"The number of remaining frames ({remaining_frames}) must be even to be rearranged with block size 2.")
tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d',
fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8)
Comment on lines +126 to +127

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using b as the first dimension name in the einops pattern is highly confusing because b in the outer scope represents the batch size, whereas the actual size of the first dimension here is b * n. Using a distinct name like bn avoids any confusion.

Suggested change
tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d',
fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8)
tensor_hwt = rearrange(tensor_hwt, 'bn (fn fb) (hn hb) (wn wb) d -> bn (fn hn wn fb hb wb) d',
fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8)

if h % 8 != 0:
tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1)
if w % 8 != 0:
tensor_hwt = torch.cat((tensor_hwt, tensor_w_r), dim=1)
tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n)
return tensor_hwt, h_res_len, w_res_len

def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len):  # BSND in , BSND out
    """Undo ``rearrange_with_remaining`` and restore frame-major token order.

    The input sequence is expected to be laid out as
    ``[2x8x8-blocked tokens | leftover rows | leftover columns]``, which is
    the order ``rearrange_with_remaining`` concatenates them in.

    Args:
        tensor: 4-D tensor laid out as (batch, seq, heads, dim).
        h_res_len: Number of sequence positions holding the leftover rows
            (grid height not divisible by 8); 0 if there are none.
        w_res_len: Number of sequence positions holding the leftover columns
            (grid width not divisible by 8); 0 if there are none.

    Returns:
        Tensor of the same shape with tokens back in (frame, h, w) order.
    """
    b, s, n, d = tensor.shape
    h = self.grid_size[1]
    w = self.grid_size[2]
    h_sr, w_sr = h % 8, w % 8
    # Derive the frame count the same way rearrange_with_remaining does, so
    # the two stay inverses of each other whatever first_frame_len is.
    frames = self.frame_num - self.first_frame_len // h // w

    tensor = rearrange(tensor, 'b s n d -> (b n) s d', b=b, n=n)
    tensor_hwt, tensor_h, tensor_w = torch.split(
        tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1
    )
    # Unfold the (frame-pair, 8x8 tile) blocks back onto the cropped grid.
    tensor_hwt = rearrange(
        tensor_hwt,
        'bn (fn hn wn fb hb wb) d -> bn (fn fb) (hn hb) (wn wb) d',
        fn=frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8,
    )
    if w_res_len != 0:
        # Re-attach the leftover columns on the right of the cropped rows.
        tensor_w = tensor_w.reshape(b * n, frames, h - h_sr, w_sr, d)
        tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3)
    if h_sr != 0:
        # Re-attach the leftover rows (full width) at the bottom.
        tensor_h = tensor_h.reshape(b * n, frames, h_sr, w, d)
        tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2)
    tensor_hwt = tensor_hwt.reshape(b * n, -1, d)
    tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n)
    return tensor_hwt
Comment thread
liaopingbo marked this conversation as resolved.

def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out
b, s, n, d = tensor.shape
if self.txt_first:
tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len,
s - self.text_len - self.first_frame_len], dim=1)
else:
tensor_f, tensor_i, tensor_t = torch.split(tensor,
[self.first_frame_len, s - self.text_len - self.first_frame_len,
self.text_len], dim=1)
Comment on lines +159 to +165

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If the sequence length s is smaller than self.text_len + self.first_frame_len, torch.split will receive a negative split size and raise a cryptic PyTorch runtime error. Adding a defensive check prevents this.

Suggested change
if self.txt_first:
tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len,
s - self.text_len - self.first_frame_len], dim=1)
else:
tensor_f, tensor_i, tensor_t = torch.split(tensor,
[self.first_frame_len, s - self.text_len - self.first_frame_len,
self.text_len], dim=1)
if s < self.text_len + self.first_frame_len:
raise ValueError(f"Sequence length {s} is too small for text_len {self.text_len} and first_frame_len {self.first_frame_len}")
if self.txt_first:
tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len,
s - self.text_len - self.first_frame_len], dim=1)
else:
tensor_f, tensor_i, tensor_t = torch.split(tensor,
[self.first_frame_len, s - self.text_len - self.first_frame_len,
self.text_len], dim=1)

tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i)
tensor = torch.concat((tensor_i_2, tensor_f, tensor_t), dim=1)
tensor_pool = self.avgpool(tensor, pool_size=128)
return tensor, tensor_pool, h_res_len, w_res_len

def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len):
    """Restore the original segment order of a rearranged BSND tensor.

    The input sequence is split, in order, into a leading segment, a
    ``first_frame_len`` segment and a trailing ``text_len`` segment. The
    leading segment is passed through ``inv_rearrange_with_remaining``, then
    the three parts are re-joined text-first or text-last according to
    ``self.txt_first``.

    Args:
        tensor: 4-D tensor laid out as (batch, seq, heads, dim).
        h_res_len: Forwarded to ``inv_rearrange_with_remaining``.
        w_res_len: Forwarded to ``inv_rearrange_with_remaining``.

    Returns:
        Tensor with the same sequence length, segments back in input order.
    """
    _, seq_len, _, _ = tensor.shape
    lead_len = seq_len - self.text_len - self.first_frame_len
    lead_part, first_frame_part, text_part = torch.split(
        tensor, [lead_len, self.first_frame_len, self.text_len], dim=1
    )
    lead_part = self.inv_rearrange_with_remaining(lead_part, h_res_len, w_res_len)

    if self.txt_first:
        ordered = (text_part, first_frame_part, lead_part)
    else:
        ordered = (first_frame_part, lead_part, text_part)
    return torch.concat(ordered, dim=1)

def do_tensor_pooling(self, tensor):
    """Average-pool the two segments of a BSND tensor separately.

    The first ``self.text_len`` sequence positions and the remaining
    positions are pooled independently with ``avgpool`` so that no pooled
    token mixes the two segments, then the results are concatenated in the
    same order.

    Args:
        tensor: 4-D tensor laid out as (batch, seq, heads, dim).

    Returns:
        The pooled tensor, concatenated along the sequence axis.
    """
    tensor_t = tensor[:, :self.text_len, :, :]
    tensor_i = tensor[:, self.text_len:, :, :]

    # Pool with the configured block size rather than a literal 128, so the
    # pooled length matches the block geometry used elsewhere in the class.
    tensor_t_pool = self.avgpool(tensor_t, pool_size=self.pool_size)
    tensor_i_pool = self.avgpool(tensor_i, pool_size=self.pool_size)

    return torch.concat((tensor_t_pool, tensor_i_pool), dim=1)

def forward(
self,
q, # BSND
k,
v,
t_b_idx,
base_blockmask,
):
t_idx = t_b_idx[0]

if t_idx < self.skip_timesteps:
base_blockmask = None
x = attention_forward(q, k, v,
opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD")
else:
batch, qSeqlen, numHeads, headDim = q.shape
assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1."
_, kvSeqlen, _, _ = k.shape
Comment thread
liaopingbo marked this conversation as resolved.
blockShapeX, blockShapeY = self.pool_size, self.pool_size
scale = headDim ** -0.5

blockShape = [blockShapeX, blockShapeY]
actualSeqLengthsHost = [qSeqlen for _ in range(batch)]
actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)]

sparsity = self.sparsity

h_res_len, w_res_len = 0, 0
qkv = torch.cat((q, k, v), dim=0)
qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv)
q, k, v = torch.chunk(qkv, 3, dim=0)

if base_blockmask is None:
query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0)

attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale
attn_scores_fake = torch.nn.functional.softmax(attn_scores_head, dim=-1)

selectIdx, selectNumIdx = self.get_blockwise_mask(attn_scores_fake, sparsity)
base_blockmask = [selectIdx, selectNumIdx]
else:
selectIdx = base_blockmask[0]
selectNumIdx = base_blockmask[1]

q_bnsd = q.transpose(1, 2)
k_bnsd = k.transpose(1, 2)
v_bnsd = v.transpose(1, 2)
x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention(
q_bnsd, k_bnsd, v_bnsd,
scale=scale,
head_num=numHeads,
input_layout="BNSD",
select_idx=selectIdx,
select_num_idx=selectNumIdx,
blockshape=blockShape,
actual_seq_lengths=actualSeqLengthsHost,
actual_seq_lengths_kv=actualSeqLengthsKvHost
)

x = x.transpose(1, 2).view(batch, qSeqlen, numHeads, headDim)

x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len)
base_blockmask = None
Comment on lines +255 to +256

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting base_blockmask = None right before returning completely overwrites the computed block mask, meaning the caller will always receive None and can never cache or reuse the mask. This defeats the caching mechanism and causes redundant mask computations on every step.

Suggested change
x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len)
base_blockmask = None
x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len)


return x, base_blockmask
4 changes: 4 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .npu_rms_norm import NpuRmsNormWeight
from .npu_layer_norm import NpuLayerNormWeight

__all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"]
36 changes: 36 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/npu_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch

from lightx2v_platform.ops.norm.norm_template import LayerNormWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_LAYERNORM_WEIGHT_REGISTER

try:
import torch_npu
except ImportError:
torch_npu = None


@PLATFORM_LAYERNORM_WEIGHT_REGISTER("npu_layer_norm")
class NpuLayerNormWeight(LayerNormWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6, **kwargs):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

def apply(self, input_tensor):
if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None:
out = torch_npu.npu_layer_norm(
input_tensor,
(input_tensor.shape[-1],),
self.weight,
self.bias,
self.eps,
)
if isinstance(out, tuple):
out = out[0]
return out
Comment on lines +18 to +28

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

NpuLayerNormWeight does not respect self.sensitive_layer_dtype when it differs from self.infer_dtype, unlike NpuRmsNormWeight and NpuRope. It should be updated to cast inputs, weight, and bias to self.sensitive_layer_dtype for consistency and correctness.

        if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None:
            if self.sensitive_layer_dtype != self.infer_dtype:
                out = torch_npu.npu_layer_norm(
                    input_tensor.to(self.sensitive_layer_dtype),
                    (input_tensor.shape[-1],),
                    self.weight.to(self.sensitive_layer_dtype),
                    self.bias.to(self.sensitive_layer_dtype),
                    self.eps,
                )
                if isinstance(out, tuple):
                    out = out[0]
                return out.to(self.infer_dtype)
            out = torch_npu.npu_layer_norm(
                input_tensor,
                (input_tensor.shape[-1],),
                self.weight,
                self.bias,
                self.eps,
            )
            if isinstance(out, tuple):
                out = out[0]
            return out

output = torch.nn.functional.layer_norm(
input_tensor,
(input_tensor.shape[-1],),
self.weight,
self.bias,
self.eps,
)
return output
30 changes: 30 additions & 0 deletions lightx2v_platform/ops/norm/ascend_npu/npu_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch

from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER

try:
import torch_npu
except ImportError:
torch_npu = None


@PLATFORM_RMS_WEIGHT_REGISTER("npu_rms_norm")
class NpuRmsNormWeight(RMSWeightTemplate):
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

def apply(self, input_tensor):
weight = self.weight
if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None:
if self.sensitive_layer_dtype != self.infer_dtype:
output_tensor, _ = torch_npu.npu_rms_norm(
input_tensor.float(), weight.float(), self.eps,
)
return output_tensor.to(self.infer_dtype)
Comment on lines +20 to +24

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When self.sensitive_layer_dtype != self.infer_dtype, the code casts the inputs to .float() (which is float32) instead of self.sensitive_layer_dtype. To respect the configured sensitive layer dtype (which could be float16 or bfloat16), cast to self.sensitive_layer_dtype instead.

Suggested change
if self.sensitive_layer_dtype != self.infer_dtype:
output_tensor, _ = torch_npu.npu_rms_norm(
input_tensor.float(), weight.float(), self.eps,
)
return output_tensor.to(self.infer_dtype)
if self.sensitive_layer_dtype != self.infer_dtype:
output_tensor, _ = torch_npu.npu_rms_norm(
input_tensor.to(self.sensitive_layer_dtype),
weight.to(self.sensitive_layer_dtype),
self.eps,
)
return output_tensor.to(self.infer_dtype)

output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps)
return output_tensor
x = self._norm(input_tensor.float())
if weight is not None:
return (x.float() * weight.float()).to(self.infer_dtype)
return x.to(self.infer_dtype)
3 changes: 3 additions & 0 deletions lightx2v_platform/ops/rope/ascend_npu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .npu_rope import NpuRope

__all__ = ["NpuRope"]
Loading
Loading