-
Notifications
You must be signed in to change notification settings - Fork 218
feat(npu): Add 4 Ascend NPU custom operators for sparse video inference #1152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
d29185f
273957e
b7fb73d
5f78da6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,262 @@ | ||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||
| from einops import rearrange | ||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||
| import mindiesd | ||||||||||||||||||||||||||||||||||
| from mindiesd.layers.flash_attn.attention_forward import attention_forward | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class Rainfusion_blockwise(nn.Module): | ||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| grid_size: list, | ||||||||||||||||||||||||||||||||||
| pool_size: int = 128, | ||||||||||||||||||||||||||||||||||
| sparsity: float = 0.9, | ||||||||||||||||||||||||||||||||||
| skip_timesteps: int = 0, | ||||||||||||||||||||||||||||||||||
| txt_len: int = 0, | ||||||||||||||||||||||||||||||||||
| txt_first: bool = False, | ||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| 参数: | ||||||||||||||||||||||||||||||||||
| grid_size (list): latents的THW网格大小。 | ||||||||||||||||||||||||||||||||||
| sparsity (float, optional): 稀疏度, 取值范围[0, 1],默认为 0.50。 | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Rainfusion_param | ||||||||||||||||||||||||||||||||||
| self.grid_size = grid_size | ||||||||||||||||||||||||||||||||||
| self.frame_num = self.grid_size[0] | ||||||||||||||||||||||||||||||||||
| self.num_tokens_per_frame = self.grid_size[1] * self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| self.first_frame_len = self.num_tokens_per_frame | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.pool_size = pool_size | ||||||||||||||||||||||||||||||||||
| self.sparsity = sparsity | ||||||||||||||||||||||||||||||||||
| self.skip_timesteps = skip_timesteps | ||||||||||||||||||||||||||||||||||
| self.text_len = txt_len | ||||||||||||||||||||||||||||||||||
| self.txt_first = txt_first | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||
| def get_grid_size(latent_size, patch_size): | ||||||||||||||||||||||||||||||||||
| t, h, w = latent_size[-3:] | ||||||||||||||||||||||||||||||||||
| return [t // patch_size[0], h // patch_size[1], w // patch_size[2]] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def avgpool(self, input_tensor, pool_size=128): # BSND in, BSND out | ||||||||||||||||||||||||||||||||||
| batch, seqlen, headnum, dim = input_tensor.shape | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| num_full_blocks = seqlen // pool_size | ||||||||||||||||||||||||||||||||||
| tail_size = seqlen % pool_size | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if num_full_blocks > 0: | ||||||||||||||||||||||||||||||||||
| full_blocks = input_tensor[:, :num_full_blocks * pool_size, :, :] | ||||||||||||||||||||||||||||||||||
| full_blocks_reshaped = full_blocks.view(batch, num_full_blocks, pool_size, headnum, dim) | ||||||||||||||||||||||||||||||||||
| full_pooled = full_blocks_reshaped.mean(dim=2) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| full_pooled = torch.empty(0, device=input_tensor.device) | ||||||||||||||||||||||||||||||||||
| if tail_size > 0: | ||||||||||||||||||||||||||||||||||
| tail_block = input_tensor[:, num_full_blocks * pool_size:, :, :] | ||||||||||||||||||||||||||||||||||
| tail_reshaped = tail_block.view(batch, 1, tail_size, headnum, dim) | ||||||||||||||||||||||||||||||||||
| tail_pooled = tail_reshaped.mean(dim=2) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tail_pooled = torch.empty(0, device=input_tensor.device) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if num_full_blocks > 0 and tail_size > 0: | ||||||||||||||||||||||||||||||||||
| output_tensor = torch.cat([full_pooled, tail_pooled], dim=1) | ||||||||||||||||||||||||||||||||||
| elif num_full_blocks > 0: | ||||||||||||||||||||||||||||||||||
| output_tensor = full_pooled | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| output_tensor = tail_pooled | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return output_tensor | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def get_mask_index(self, mask): | ||||||||||||||||||||||||||||||||||
| B, N, S, _ = mask.shape | ||||||||||||||||||||||||||||||||||
| device = mask.device | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # 1. 重塑维度 → (B*N)×S×S | ||||||||||||||||||||||||||||||||||
| mask_reshaped = mask.reshape(-1, S, S) | ||||||||||||||||||||||||||||||||||
| batch_size = mask_reshaped.shape[0] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # 2. 生成行索引 标记False位置为S(大于所有有效索引) | ||||||||||||||||||||||||||||||||||
| row_indices = torch.arange(S, device=device).expand(batch_size, S, -1) # (B*N, S, S) | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||||||||||||
| sorted_vals = torch.where(mask_reshaped, row_indices, S) | ||||||||||||||||||||||||||||||||||
| sorted_vals, _ = torch.sort(sorted_vals, dim=-1) | ||||||||||||||||||||||||||||||||||
| valid_count = mask_reshaped.sum(dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||
| keep_mask = row_indices < valid_count | ||||||||||||||||||||||||||||||||||
| result = torch.where(keep_mask, sorted_vals, -1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| pos_matrix = result.reshape(B, N, S, S) | ||||||||||||||||||||||||||||||||||
| return pos_matrix | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def get_blockwise_mask(self, score_matrix, sparsity): | ||||||||||||||||||||||||||||||||||
| batch_size, num_heads, rows, cols = score_matrix.shape | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| keep_len = math.ceil(cols * (1 - sparsity)) | ||||||||||||||||||||||||||||||||||
| topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1) | ||||||||||||||||||||||||||||||||||
|
liaopingbo marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||
| thresholds = topk_values[..., -1:] | ||||||||||||||||||||||||||||||||||
| mask = score_matrix >= thresholds | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| protect_len = (self.first_frame_len + self.text_len + self.pool_size - 1) // self.pool_size | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if protect_len > 0: | ||||||||||||||||||||||||||||||||||
| mask[:, :, -protect_len:, :] = True | ||||||||||||||||||||||||||||||||||
| mask[:, :, :, -protect_len:] = True | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| selectIdx = self.get_mask_index(mask) | ||||||||||||||||||||||||||||||||||
| selectIdx = selectIdx[0].transpose(0, 1) | ||||||||||||||||||||||||||||||||||
| selectNumIdx = mask[0].transpose(0, 1).sum(dim=-1) | ||||||||||||||||||||||||||||||||||
| return selectIdx, selectNumIdx | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def rearrange_with_remaining(self, tensor): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| h = self.grid_size[1] | ||||||||||||||||||||||||||||||||||
| w = self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| h_res_len, w_res_len = 0, 0 | ||||||||||||||||||||||||||||||||||
| first_frame_num = self.first_frame_len // h // w | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor, 'b (f h w) n d -> (b n) f h w d', f=self.frame_num - first_frame_num, h=h, w=w) | ||||||||||||||||||||||||||||||||||
| if h % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_h_r = torch.split(tensor_hwt, h - (h % 8), dim=2) | ||||||||||||||||||||||||||||||||||
| tensor_h_r = tensor_h_r.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| h_res_len = tensor_h_r.shape[1] | ||||||||||||||||||||||||||||||||||
| if w % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_w_r = torch.split(tensor_hwt, w - (w % 8), dim=3) | ||||||||||||||||||||||||||||||||||
| tensor_w_r = tensor_w_r.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| w_res_len = tensor_w_r.shape[1] | ||||||||||||||||||||||||||||||||||
| remaining_frames = self.frame_num - first_frame_num | ||||||||||||||||||||||||||||||||||
| if remaining_frames % 2 != 0: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"The number of remaining frames ({remaining_frames}) must be even to be rearranged with block size 2.") | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, 'b (fn fb) (hn hb) (wn wb) d -> b (fn hn wn fb hb wb) d', | ||||||||||||||||||||||||||||||||||
| fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+126
to
+127
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||||||||||||
| if h % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) | ||||||||||||||||||||||||||||||||||
| if w % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_w_r), dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) | ||||||||||||||||||||||||||||||||||
| return tensor_hwt, h_res_len, w_res_len | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| h = self.grid_size[1] | ||||||||||||||||||||||||||||||||||
| w = self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| h_sr, w_sr = h % 8, w % 8 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor = rearrange(tensor, 'b s n d->(b n) s d', b=b, n=n) | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, 'b (fn hn wn fb hb wb) d -> b (fn fb) (hn hb) (wn wb) d', | ||||||||||||||||||||||||||||||||||
| fn=(self.frame_num - 1) // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) | ||||||||||||||||||||||||||||||||||
| if w_res_len != 0: | ||||||||||||||||||||||||||||||||||
| tensor_w = tensor_w.reshape(b * n, self.frame_num - 1, h - h_sr, w_sr, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) | ||||||||||||||||||||||||||||||||||
| if h_sr != 0: | ||||||||||||||||||||||||||||||||||
| tensor_h = tensor_h.reshape(b * n, self.frame_num - 1, h_sr, w, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = tensor_hwt.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, '(b n) s d -> b s n d', b=b, n=n) | ||||||||||||||||||||||||||||||||||
| return tensor_hwt | ||||||||||||||||||||||||||||||||||
|
liaopingbo marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| if self.txt_first: | ||||||||||||||||||||||||||||||||||
| tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, | ||||||||||||||||||||||||||||||||||
| s - self.text_len - self.first_frame_len], dim=1) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tensor_f, tensor_i, tensor_t = torch.split(tensor, | ||||||||||||||||||||||||||||||||||
| [self.first_frame_len, s - self.text_len - self.first_frame_len, | ||||||||||||||||||||||||||||||||||
| self.text_len], dim=1) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+159
to
+165
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the sequence length
Suggested change
|
||||||||||||||||||||||||||||||||||
| tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i) | ||||||||||||||||||||||||||||||||||
| tensor = torch.concat((tensor_i_2, tensor_f, tensor_t), dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_pool = self.avgpool(tensor, pool_size=128) | ||||||||||||||||||||||||||||||||||
| return tensor, tensor_pool, h_res_len, w_res_len | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| tensor_i, tensor_f, tensor_t = torch.split(tensor, | ||||||||||||||||||||||||||||||||||
| [s - self.text_len - self.first_frame_len, self.first_frame_len, | ||||||||||||||||||||||||||||||||||
| self.text_len], dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if self.txt_first: | ||||||||||||||||||||||||||||||||||
| tensor = torch.concat((tensor_t, tensor_f, tensor_i), dim=1) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tensor = torch.concat((tensor_f, tensor_i, tensor_t), dim=1) | ||||||||||||||||||||||||||||||||||
| return tensor | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def do_tensor_pooling(self, tensor): | ||||||||||||||||||||||||||||||||||
| tensor_t = tensor[:, :self.text_len, :, :] | ||||||||||||||||||||||||||||||||||
| tensor_i = tensor[:, self.text_len:, :, :] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor_i_pool = self.avgpool(tensor_i, pool_size=128) | ||||||||||||||||||||||||||||||||||
| tensor_t_pool = self.avgpool(tensor_t, pool_size=128) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor_pool = torch.concat((tensor_t_pool, tensor_i_pool), dim=1) | ||||||||||||||||||||||||||||||||||
| return tensor_pool | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| q, # BSND | ||||||||||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||||||||||
| t_b_idx, | ||||||||||||||||||||||||||||||||||
| base_blockmask, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| t_idx = t_b_idx[0] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if t_idx < self.skip_timesteps: | ||||||||||||||||||||||||||||||||||
| base_blockmask = None | ||||||||||||||||||||||||||||||||||
| x = attention_forward(q, k, v, | ||||||||||||||||||||||||||||||||||
| opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| batch, qSeqlen, numHeads, headDim = q.shape | ||||||||||||||||||||||||||||||||||
| assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1." | ||||||||||||||||||||||||||||||||||
| _, kvSeqlen, _, _ = k.shape | ||||||||||||||||||||||||||||||||||
|
liaopingbo marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||
| blockShapeX, blockShapeY = self.pool_size, self.pool_size | ||||||||||||||||||||||||||||||||||
| scale = headDim ** -0.5 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| blockShape = [blockShapeX, blockShapeY] | ||||||||||||||||||||||||||||||||||
| actualSeqLengthsHost = [qSeqlen for _ in range(batch)] | ||||||||||||||||||||||||||||||||||
| actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| sparsity = self.sparsity | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| h_res_len, w_res_len = 0, 0 | ||||||||||||||||||||||||||||||||||
| qkv = torch.cat((q, k, v), dim=0) | ||||||||||||||||||||||||||||||||||
| qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) | ||||||||||||||||||||||||||||||||||
| q, k, v = torch.chunk(qkv, 3, dim=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if base_blockmask is None: | ||||||||||||||||||||||||||||||||||
| query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale | ||||||||||||||||||||||||||||||||||
| attn_scores_fake = torch.nn.functional.softmax(attn_scores_head, dim=-1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| selectIdx, selectNumIdx = self.get_blockwise_mask(attn_scores_fake, sparsity) | ||||||||||||||||||||||||||||||||||
| base_blockmask = [selectIdx, selectNumIdx] | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| selectIdx = base_blockmask[0] | ||||||||||||||||||||||||||||||||||
| selectNumIdx = base_blockmask[1] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| q_bnsd = q.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| k_bnsd = k.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| v_bnsd = v.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention( | ||||||||||||||||||||||||||||||||||
| q_bnsd, k_bnsd, v_bnsd, | ||||||||||||||||||||||||||||||||||
| scale=scale, | ||||||||||||||||||||||||||||||||||
| head_num=numHeads, | ||||||||||||||||||||||||||||||||||
| input_layout="BNSD", | ||||||||||||||||||||||||||||||||||
| select_idx=selectIdx, | ||||||||||||||||||||||||||||||||||
| select_num_idx=selectNumIdx, | ||||||||||||||||||||||||||||||||||
| blockshape=blockShape, | ||||||||||||||||||||||||||||||||||
| actual_seq_lengths=actualSeqLengthsHost, | ||||||||||||||||||||||||||||||||||
| actual_seq_lengths_kv=actualSeqLengthsKvHost | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| x = x.transpose(1, 2).view(batch, qSeqlen, numHeads, headDim) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len) | ||||||||||||||||||||||||||||||||||
| base_blockmask = None | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+255
to
+256
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Setting
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return x, base_blockmask | ||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .npu_rms_norm import NpuRmsNormWeight | ||
| from .npu_layer_norm import NpuLayerNormWeight | ||
|
|
||
| __all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import torch | ||
|
|
||
| from lightx2v_platform.ops.norm.norm_template import LayerNormWeightTemplate | ||
| from lightx2v_platform.registry_factory import PLATFORM_LAYERNORM_WEIGHT_REGISTER | ||
|
|
||
| try: | ||
| import torch_npu | ||
| except ImportError: | ||
| torch_npu = None | ||
|
|
||
|
|
||
| @PLATFORM_LAYERNORM_WEIGHT_REGISTER("npu_layer_norm") | ||
| class NpuLayerNormWeight(LayerNormWeightTemplate): | ||
| def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6, **kwargs): | ||
| super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) | ||
|
|
||
| def apply(self, input_tensor): | ||
| if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None: | ||
| out = torch_npu.npu_layer_norm( | ||
| input_tensor, | ||
| (input_tensor.shape[-1],), | ||
| self.weight, | ||
| self.bias, | ||
| self.eps, | ||
| ) | ||
| if isinstance(out, tuple): | ||
| out = out[0] | ||
| return out | ||
|
Comment on lines
+18
to
+28
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
if torch_npu is not None and hasattr(torch_npu, 'npu_layer_norm') and self.weight is not None and self.bias is not None:
if self.sensitive_layer_dtype != self.infer_dtype:
out = torch_npu.npu_layer_norm(
input_tensor.to(self.sensitive_layer_dtype),
(input_tensor.shape[-1],),
self.weight.to(self.sensitive_layer_dtype),
self.bias.to(self.sensitive_layer_dtype),
self.eps,
)
if isinstance(out, tuple):
out = out[0]
return out.to(self.infer_dtype)
out = torch_npu.npu_layer_norm(
input_tensor,
(input_tensor.shape[-1],),
self.weight,
self.bias,
self.eps,
)
if isinstance(out, tuple):
out = out[0]
return out |
||
| output = torch.nn.functional.layer_norm( | ||
| input_tensor, | ||
| (input_tensor.shape[-1],), | ||
| self.weight, | ||
| self.bias, | ||
| self.eps, | ||
| ) | ||
| return output | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,30 @@ | ||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate | ||||||||||||||||||||||||||
| from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||
| torch_npu = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @PLATFORM_RMS_WEIGHT_REGISTER("npu_rms_norm") | ||||||||||||||||||||||||||
| class NpuRmsNormWeight(RMSWeightTemplate): | ||||||||||||||||||||||||||
| def _norm(self, x): | ||||||||||||||||||||||||||
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def apply(self, input_tensor): | ||||||||||||||||||||||||||
| weight = self.weight | ||||||||||||||||||||||||||
| if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None: | ||||||||||||||||||||||||||
| if self.sensitive_layer_dtype != self.infer_dtype: | ||||||||||||||||||||||||||
| output_tensor, _ = torch_npu.npu_rms_norm( | ||||||||||||||||||||||||||
| input_tensor.float(), weight.float(), self.eps, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return output_tensor.to(self.infer_dtype) | ||||||||||||||||||||||||||
|
Comment on lines
+20
to
+24
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When
Suggested change
|
||||||||||||||||||||||||||
| output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps) | ||||||||||||||||||||||||||
| return output_tensor | ||||||||||||||||||||||||||
| x = self._norm(input_tensor.float()) | ||||||||||||||||||||||||||
| if weight is not None: | ||||||||||||||||||||||||||
| return (x.float() * weight.float()).to(self.infer_dtype) | ||||||||||||||||||||||||||
| return x.to(self.infer_dtype) | ||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .npu_rope import NpuRope | ||
|
|
||
| __all__ = ["NpuRope"] |
Uh oh!
There was an error while loading. Please reload this page.