-
Notifications
You must be signed in to change notification settings - Fork 218
feat(npu): Add 4 Ascend NPU custom operators for sparse video inference #1157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d4d0d37
3fb61d9
969b41a
525d729
06dec8d
bdd48bb
b6d2a63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,240 @@ | ||||||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import mindiesd | ||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||
| from einops import rearrange | ||||||||||||||||||||||||||||||||||
| from mindiesd.layers.flash_attn.attention_forward import attention_forward | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class Rainfusion_blockwise(nn.Module): | ||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| grid_size: list, | ||||||||||||||||||||||||||||||||||
| pool_size: int = 128, | ||||||||||||||||||||||||||||||||||
| sparsity: float = 0.9, | ||||||||||||||||||||||||||||||||||
| skip_timesteps: int = 0, | ||||||||||||||||||||||||||||||||||
| txt_len: int = 0, | ||||||||||||||||||||||||||||||||||
| txt_first: bool = False, | ||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| 参数: | ||||||||||||||||||||||||||||||||||
| grid_size (list): latents的THW网格大小。 | ||||||||||||||||||||||||||||||||||
| sparsity (float, optional): 稀疏度, 取值范围[0, 1],默认为 0.50。 | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Rainfusion_param | ||||||||||||||||||||||||||||||||||
| self.grid_size = grid_size | ||||||||||||||||||||||||||||||||||
| self.frame_num = self.grid_size[0] | ||||||||||||||||||||||||||||||||||
| self.num_tokens_per_frame = self.grid_size[1] * self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| self.first_frame_len = self.num_tokens_per_frame | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| self.pool_size = pool_size | ||||||||||||||||||||||||||||||||||
| self.sparsity = sparsity | ||||||||||||||||||||||||||||||||||
| self.skip_timesteps = skip_timesteps | ||||||||||||||||||||||||||||||||||
| self.text_len = txt_len | ||||||||||||||||||||||||||||||||||
| self.txt_first = txt_first | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||
| def get_grid_size(latent_size, patch_size): | ||||||||||||||||||||||||||||||||||
| t, h, w = latent_size[-3:] | ||||||||||||||||||||||||||||||||||
| return [t // patch_size[0], h // patch_size[1], w // patch_size[2]] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def avgpool(self, input_tensor, pool_size=128): # BSND in, BSND out | ||||||||||||||||||||||||||||||||||
| batch, seqlen, headnum, dim = input_tensor.shape | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| num_full_blocks = seqlen // pool_size | ||||||||||||||||||||||||||||||||||
| tail_size = seqlen % pool_size | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| pooled_tensors = [] | ||||||||||||||||||||||||||||||||||
| if num_full_blocks > 0: | ||||||||||||||||||||||||||||||||||
| full_blocks = input_tensor[:, : num_full_blocks * pool_size, :, :] | ||||||||||||||||||||||||||||||||||
| full_blocks_reshaped = full_blocks.view(batch, num_full_blocks, pool_size, headnum, dim) | ||||||||||||||||||||||||||||||||||
| pooled_tensors.append(full_blocks_reshaped.mean(dim=2)) | ||||||||||||||||||||||||||||||||||
| if tail_size > 0: | ||||||||||||||||||||||||||||||||||
| tail_block = input_tensor[:, num_full_blocks * pool_size :, :, :] | ||||||||||||||||||||||||||||||||||
| tail_reshaped = tail_block.view(batch, 1, tail_size, headnum, dim) | ||||||||||||||||||||||||||||||||||
| pooled_tensors.append(tail_reshaped.mean(dim=2)) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+51
to
+58
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return torch.cat(pooled_tensors, dim=1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def get_mask_index(self, mask): | ||||||||||||||||||||||||||||||||||
| B, N, S, _ = mask.shape | ||||||||||||||||||||||||||||||||||
| device = mask.device | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # 1. 重塑维度 → (B*N)×S×S | ||||||||||||||||||||||||||||||||||
| mask_reshaped = mask.reshape(-1, S, S) | ||||||||||||||||||||||||||||||||||
| batch_size = mask_reshaped.shape[0] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # 2. 生成行索引 标记False位置为S(大于所有有效索引) | ||||||||||||||||||||||||||||||||||
| row_indices = torch.arange(S, device=device).view(1, 1, S).expand(batch_size, S, S) # (B*N, S, S) | ||||||||||||||||||||||||||||||||||
| sorted_vals = torch.where(mask_reshaped, row_indices, S) | ||||||||||||||||||||||||||||||||||
| sorted_vals, _ = torch.sort(sorted_vals, dim=-1) | ||||||||||||||||||||||||||||||||||
| valid_count = mask_reshaped.sum(dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||
| keep_mask = row_indices < valid_count | ||||||||||||||||||||||||||||||||||
| result = torch.where(keep_mask, sorted_vals, -1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| pos_matrix = result.reshape(B, N, S, S) | ||||||||||||||||||||||||||||||||||
| return pos_matrix | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def get_blockwise_mask(self, score_matrix, sparsity): | ||||||||||||||||||||||||||||||||||
| batch_size, num_heads, rows, cols = score_matrix.shape | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| keep_len = math.ceil(cols * (1 - sparsity)) | ||||||||||||||||||||||||||||||||||
| keep_len = max(1, min(keep_len, cols)) | ||||||||||||||||||||||||||||||||||
| topk_values, _ = torch.topk(score_matrix, k=keep_len, dim=-1) | ||||||||||||||||||||||||||||||||||
| thresholds = topk_values[..., -1:] | ||||||||||||||||||||||||||||||||||
| mask = score_matrix >= thresholds | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| total_len = self.frame_num * self.num_tokens_per_frame + self.text_len | ||||||||||||||||||||||||||||||||||
| protected_start_idx = total_len - self.first_frame_len - self.text_len | ||||||||||||||||||||||||||||||||||
| protected_start_block = protected_start_idx // self.pool_size | ||||||||||||||||||||||||||||||||||
| protect_len = cols - protected_start_block | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if protect_len > 0: | ||||||||||||||||||||||||||||||||||
| mask[:, :, -protect_len:, :] = True | ||||||||||||||||||||||||||||||||||
| mask[:, :, :, -protect_len:] = True | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| selectIdx = self.get_mask_index(mask) | ||||||||||||||||||||||||||||||||||
| selectIdx = selectIdx[0].transpose(0, 1) | ||||||||||||||||||||||||||||||||||
| selectNumIdx = mask[0].transpose(0, 1).sum(dim=-1) | ||||||||||||||||||||||||||||||||||
| return selectIdx, selectNumIdx | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def rearrange_with_remaining(self, tensor): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| h = self.grid_size[1] | ||||||||||||||||||||||||||||||||||
| w = self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| h_res_len, w_res_len = 0, 0 | ||||||||||||||||||||||||||||||||||
| first_frame_num = self.first_frame_len // h // w | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor, "b (f h w) n d -> (b n) f h w d", f=self.frame_num - first_frame_num, h=h, w=w) | ||||||||||||||||||||||||||||||||||
| if h % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_h_r = torch.split(tensor_hwt, h - (h % 8), dim=2) | ||||||||||||||||||||||||||||||||||
| tensor_h_r = tensor_h_r.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| h_res_len = tensor_h_r.shape[1] | ||||||||||||||||||||||||||||||||||
| if w % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_w_r = torch.split(tensor_hwt, w - (w % 8), dim=3) | ||||||||||||||||||||||||||||||||||
| tensor_w_r = tensor_w_r.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| w_res_len = tensor_w_r.shape[1] | ||||||||||||||||||||||||||||||||||
| remaining_frames = self.frame_num - first_frame_num | ||||||||||||||||||||||||||||||||||
| if remaining_frames % 2 != 0: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"The number of remaining frames ({remaining_frames}) must be even to be rearranged with block size 2.") | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, "bn (fn fb) (hn hb) (wn wb) d -> bn (fn hn wn fb hb wb) d", fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) | ||||||||||||||||||||||||||||||||||
| if h % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_h_r), dim=1) | ||||||||||||||||||||||||||||||||||
| if w % 8 != 0: | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_w_r), dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, "(b n) s d -> b s n d", b=b, n=n) | ||||||||||||||||||||||||||||||||||
| return tensor_hwt, h_res_len, w_res_len | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def inv_rearrange_with_remaining(self, tensor, h_res_len, w_res_len): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| h = self.grid_size[1] | ||||||||||||||||||||||||||||||||||
| w = self.grid_size[2] | ||||||||||||||||||||||||||||||||||
| h_sr, w_sr = h % 8, w % 8 | ||||||||||||||||||||||||||||||||||
| first_frame_num = self.first_frame_len // (h * w) | ||||||||||||||||||||||||||||||||||
| remaining_frames = self.frame_num - first_frame_num | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| tensor = rearrange(tensor, "b s n d->(b n) s d", b=b, n=n) | ||||||||||||||||||||||||||||||||||
| tensor_hwt, tensor_h, tensor_w = torch.split(tensor, [s - h_res_len - w_res_len, h_res_len, w_res_len], dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, "bn (fn hn wn fb hb wb) d -> bn (fn fb) (hn hb) (wn wb) d", fn=remaining_frames // 2, fb=2, hb=8, wb=8, hn=h // 8, wn=w // 8) | ||||||||||||||||||||||||||||||||||
| if w_res_len != 0: | ||||||||||||||||||||||||||||||||||
| tensor_w = tensor_w.reshape(b * n, remaining_frames, h - h_sr, w_sr, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_w), dim=3) | ||||||||||||||||||||||||||||||||||
| if h_sr != 0: | ||||||||||||||||||||||||||||||||||
| tensor_h = tensor_h.reshape(b * n, remaining_frames, h_sr, w, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = torch.cat((tensor_hwt, tensor_h), dim=2) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = tensor_hwt.reshape(b * n, -1, d) | ||||||||||||||||||||||||||||||||||
| tensor_hwt = rearrange(tensor_hwt, "(b n) s d -> b s n d", b=b, n=n) | ||||||||||||||||||||||||||||||||||
| return tensor_hwt | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def do_tensor_rearrange_pooling(self, tensor): # BSND in , BSND out | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| if s < self.text_len + self.first_frame_len: | ||||||||||||||||||||||||||||||||||
| raise ValueError(f"Sequence length {s} is too small for text_len {self.text_len} and first_frame_len {self.first_frame_len}") | ||||||||||||||||||||||||||||||||||
| if self.txt_first: | ||||||||||||||||||||||||||||||||||
| tensor_t, tensor_f, tensor_i = torch.split(tensor, [self.text_len, self.first_frame_len, s - self.text_len - self.first_frame_len], dim=1) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tensor_f, tensor_i, tensor_t = torch.split(tensor, [self.first_frame_len, s - self.text_len - self.first_frame_len, self.text_len], dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_i_2, h_res_len, w_res_len = self.rearrange_with_remaining(tensor_i) | ||||||||||||||||||||||||||||||||||
| tensor = torch.cat((tensor_i_2, tensor_f, tensor_t), dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_pool = self.avgpool(tensor, pool_size=128) | ||||||||||||||||||||||||||||||||||
| return tensor, tensor_pool, h_res_len, w_res_len | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def do_tensor_inv_rearrange(self, tensor, h_res_len, w_res_len): | ||||||||||||||||||||||||||||||||||
| b, s, n, d = tensor.shape | ||||||||||||||||||||||||||||||||||
| tensor_i, tensor_f, tensor_t = torch.split(tensor, [s - self.text_len - self.first_frame_len, self.first_frame_len, self.text_len], dim=1) | ||||||||||||||||||||||||||||||||||
| tensor_i = self.inv_rearrange_with_remaining(tensor_i, h_res_len, w_res_len) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if self.txt_first: | ||||||||||||||||||||||||||||||||||
| tensor = torch.cat((tensor_t, tensor_f, tensor_i), dim=1) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tensor = torch.cat((tensor_f, tensor_i, tensor_t), dim=1) | ||||||||||||||||||||||||||||||||||
| return tensor | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def forward( | ||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||
| q, # BSND | ||||||||||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||||||||||
| t_b_idx, | ||||||||||||||||||||||||||||||||||
| base_blockmask, | ||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||
| t_idx = t_b_idx[0] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if t_idx < self.skip_timesteps: | ||||||||||||||||||||||||||||||||||
| base_blockmask = None | ||||||||||||||||||||||||||||||||||
| x = attention_forward(q, k, v, opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| batch, qSeqlen, numHeads, headDim = q.shape | ||||||||||||||||||||||||||||||||||
| assert batch == 1, "Rainfusion_blockwise currently only supports batch size 1." | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||||||||||||||||||||||
| _, kvSeqlen, _, _ = k.shape | ||||||||||||||||||||||||||||||||||
| blockShapeX, blockShapeY = self.pool_size, self.pool_size | ||||||||||||||||||||||||||||||||||
| scale = headDim**-0.5 | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| blockShape = [blockShapeX, blockShapeY] | ||||||||||||||||||||||||||||||||||
| actualSeqLengthsHost = [qSeqlen for _ in range(batch)] | ||||||||||||||||||||||||||||||||||
| actualSeqLengthsKvHost = [kvSeqlen for _ in range(batch)] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| sparsity = self.sparsity | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| h_res_len, w_res_len = 0, 0 | ||||||||||||||||||||||||||||||||||
| qkv = torch.cat((q, k, v), dim=0) | ||||||||||||||||||||||||||||||||||
| qkv, qkv_pool, h_res_len, w_res_len = self.do_tensor_rearrange_pooling(qkv) | ||||||||||||||||||||||||||||||||||
| q, k, v = torch.chunk(qkv, 3, dim=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if base_blockmask is None: | ||||||||||||||||||||||||||||||||||
| query_pool, key_pool, value_pool = torch.chunk(qkv_pool, 3, dim=0) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| attn_scores_head = torch.einsum("blnd,bsnd->bnls", query_pool, key_pool) * scale | ||||||||||||||||||||||||||||||||||
| attn_scores_fake = torch.nn.functional.softmax(attn_scores_head, dim=-1) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| selectIdx, selectNumIdx = self.get_blockwise_mask(attn_scores_fake, sparsity) | ||||||||||||||||||||||||||||||||||
| base_blockmask = [selectIdx, selectNumIdx] | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| selectIdx = base_blockmask[0] | ||||||||||||||||||||||||||||||||||
| selectNumIdx = base_blockmask[1] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| q_bnsd = q.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| k_bnsd = k.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| v_bnsd = v.transpose(1, 2) | ||||||||||||||||||||||||||||||||||
| x = mindiesd.layers.flash_attn.sparse_flash_attn_rf_v2.rain_fusion_attention( | ||||||||||||||||||||||||||||||||||
| q_bnsd, | ||||||||||||||||||||||||||||||||||
| k_bnsd, | ||||||||||||||||||||||||||||||||||
| v_bnsd, | ||||||||||||||||||||||||||||||||||
| scale=scale, | ||||||||||||||||||||||||||||||||||
| head_num=numHeads, | ||||||||||||||||||||||||||||||||||
| input_layout="BNSD", | ||||||||||||||||||||||||||||||||||
| select_idx=selectIdx, | ||||||||||||||||||||||||||||||||||
| select_num_idx=selectNumIdx, | ||||||||||||||||||||||||||||||||||
| blockshape=blockShape, | ||||||||||||||||||||||||||||||||||
| actual_seq_lengths=actualSeqLengthsHost, | ||||||||||||||||||||||||||||||||||
| actual_seq_lengths_kv=actualSeqLengthsKvHost, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| x = x.transpose(1, 2).reshape(batch, qSeqlen, numHeads, headDim) | ||||||||||||||||||||||||||||||||||
|
liaopingbo marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| x = self.do_tensor_inv_rearrange(x, h_res_len, w_res_len) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| return x, base_blockmask | ||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .npu_layer_norm import NpuLayerNormWeight | ||
| from .npu_rms_norm import NpuRmsNormWeight | ||
|
|
||
| __all__ = ["NpuRmsNormWeight", "NpuLayerNormWeight"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import torch | ||
|
|
||
| from lightx2v_platform.ops.norm.norm_template import LayerNormWeightTemplate | ||
| from lightx2v_platform.registry_factory import PLATFORM_LAYERNORM_WEIGHT_REGISTER | ||
|
|
||
| try: | ||
| import torch_npu | ||
| except ImportError: | ||
| torch_npu = None | ||
|
|
||
|
|
||
| @PLATFORM_LAYERNORM_WEIGHT_REGISTER("npu_layer_norm") | ||
| class NpuLayerNormWeight(LayerNormWeightTemplate): | ||
| def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6, **kwargs): | ||
| super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) | ||
|
|
||
| def apply(self, input_tensor): | ||
| if torch_npu is not None and hasattr(torch_npu, "npu_layer_norm") and self.weight is not None and self.bias is not None: | ||
| if self.sensitive_layer_dtype != self.infer_dtype: | ||
| out = torch_npu.npu_layer_norm( | ||
| input_tensor.to(self.sensitive_layer_dtype), | ||
| (input_tensor.shape[-1],), | ||
| self.weight.to(self.sensitive_layer_dtype), | ||
| self.bias.to(self.sensitive_layer_dtype), | ||
| self.eps, | ||
| ) | ||
| if isinstance(out, tuple): | ||
| out = out[0] | ||
| return out.to(self.infer_dtype) | ||
| out = torch_npu.npu_layer_norm( | ||
| input_tensor, | ||
| (input_tensor.shape[-1],), | ||
| self.weight, | ||
| self.bias, | ||
| self.eps, | ||
| ) | ||
| if isinstance(out, tuple): | ||
| out = out[0] | ||
| return out | ||
| output = torch.nn.functional.layer_norm( | ||
| input_tensor, | ||
| (input_tensor.shape[-1],), | ||
| self.weight, | ||
| self.bias, | ||
| self.eps, | ||
| ) | ||
| return output |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| import torch | ||
|
|
||
| from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate | ||
| from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER | ||
|
|
||
| try: | ||
| import torch_npu | ||
| except ImportError: | ||
| torch_npu = None | ||
|
|
||
|
|
||
| @PLATFORM_RMS_WEIGHT_REGISTER("npu_rms_norm") | ||
| class NpuRmsNormWeight(RMSWeightTemplate): | ||
| def _norm(self, x): | ||
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) | ||
|
|
||
| def apply(self, input_tensor): | ||
| weight = self.weight | ||
| if torch_npu is not None and hasattr(torch_npu, "npu_rms_norm") and weight is not None: | ||
| if self.sensitive_layer_dtype != self.infer_dtype: | ||
| output_tensor, _ = torch_npu.npu_rms_norm( | ||
| input_tensor.to(self.sensitive_layer_dtype), | ||
| weight.to(self.sensitive_layer_dtype), | ||
| self.eps, | ||
| ) | ||
| return output_tensor.to(self.infer_dtype) | ||
| output_tensor, _ = torch_npu.npu_rms_norm(input_tensor, weight, self.eps) | ||
| return output_tensor | ||
| x = self._norm(input_tensor.float()) | ||
| if weight is not None: | ||
| return (x.float() * weight.float()).to(self.infer_dtype) | ||
| return x.to(self.infer_dtype) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .npu_rope import NpuRope | ||
|
|
||
| __all__ = ["NpuRope"] |
Uh oh!
There was an error while loading. Please reload this page.