
Commit 945ee02

[1/3] Add the fastvideo support (#804)
## What does this PR do?

**Type of change:** new feature

**Overview:** FastVideo is a new diffusion-focused framework that we plan to integrate with. This PR adds initial support for WAN 2.2 5B in FastVideo, targeting the text-to-video use case. For the Conv layer type, we currently use a straightforward direct convolution call; implicit GEMM quantization is intentionally omitted from this first PR and will be addressed in a follow-up.

- [x] [1/3] Add support for the WAN 2.2 DiT + VAE layer types.
- [ ] [2/3] Add calibration support in the example script, plus test cases, README, and docs.
- [ ] [3/3] Submit an MR to FastVideo to enable quantization-aware training.

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

## Summary by CodeRabbit

* **New Features**
  * Added FastVideo plugin support to the quantization framework. Users can now apply quantization to FastVideo-specific layers with specialized weight quantization handling, optimized input processing, and caching features for improved inference performance.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
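The plugin works by registering quantized replacement classes for FastVideo module types, so that conversion can swap a module's class in place. The mechanism can be illustrated with a minimal pure-Python sketch (all names here are hypothetical stand-ins, not the real modelopt API):

```python
# Minimal sketch of a class-swap registry, loosely modeled on how
# QuantModuleRegistry.register is used in this PR. Names are hypothetical.
class ModuleRegistry:
    def __init__(self):
        self._registry = {}

    def register(self, mapping):
        # mapping: {OriginalClass: "display name"} -> decorator that records
        # the quantized class as the replacement for each original class.
        def decorator(quant_cls):
            for original_cls in mapping:
                self._registry[original_cls] = quant_cls
            return quant_cls
        return decorator

    def convert(self, module):
        # Swap the module's class for its registered quantized counterpart.
        quant_cls = self._registry.get(type(module))
        if quant_cls is not None:
            module.__class__ = quant_cls
        return module


registry = ModuleRegistry()


class FakeConv3d:  # stand-in for a layer like WanCausalConv3d
    pass


@registry.register({FakeConv3d: "FakeConv3d"})
class QuantFakeConv3d(FakeConv3d):
    def quantized(self):
        return True


m = registry.convert(FakeConv3d())
print(type(m).__name__)  # QuantFakeConv3d
```

Because the quantized class subclasses the original, reassigning `__class__` keeps the module's parameters and buffers intact while overriding `forward` with the quantizer-aware version.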
1 parent 668b8a1 commit 945ee02

2 files changed

Lines changed: 67 additions & 0 deletions

File tree

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -75,3 +75,6 @@

 with import_plugin("trl"):
     from .trl import *
+
+with import_plugin("fastvideo"):
+    from .fastvideo import *
```
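The `import_plugin("fastvideo")` guard makes the plugin optional: if FastVideo is not installed, the import is skipped rather than crashing the package. A minimal sketch of this guarded-import pattern (this is an illustration, not the actual modelopt implementation):

```python
# Sketch of a guarded-import context manager in the spirit of import_plugin:
# an ImportError raised inside the `with` body is swallowed, so an optional
# dependency (e.g. fastvideo) that is missing simply skips its registrations.
from contextlib import contextmanager


@contextmanager
def import_plugin(name):
    try:
        yield
    except ImportError:
        pass  # plugin dependency `name` not installed; skip it silently


flag = []
with import_plugin("fastvideo"):
    import definitely_missing_module_xyz  # noqa: F401 -- raises ModuleNotFoundError
    flag.append("unreachable")

print("plugin import skipped:", flag == [])  # plugin import skipped: True
```

`ModuleNotFoundError` subclasses `ImportError`, so the missing module is caught by the guard and execution continues after the `with` block. The real helper may additionally log or track which plugins loaded.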
Lines changed: 64 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Support quantization for FastVideo layers."""

import torch
import torch.nn.functional as F
from fastvideo.layers.linear import ReplicatedLinear
from fastvideo.models.vaes.wanvae import WanCausalConv3d

from ..nn import QuantLinearConvBase, QuantModuleRegistry
from ..nn.modules.quant_conv import _QuantConv3d
from ..nn.modules.quant_linear import _QuantLinear
from ..utils import is_torch_export_mode


@QuantModuleRegistry.register({WanCausalConv3d: "WanCausalConv3d"})
class _QuantWanCausalConv3d(_QuantConv3d):
    @staticmethod
    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
        """Quantize weight in linear format for proper block-wise FP4 quantization."""
        if module._enable_weight_quantization or is_torch_export_mode():
            # Quantize in linear format (block-wise quantization works correctly here)
            return module.weight_quantizer(weight)
        return weight

    def forward(self, x, cache_x=None):
        from fastvideo.platforms import current_platform

        with self.quantize_weight():
            padding = list(self._padding)
            if cache_x is not None and self._padding[4] > 0:
                cache_x = cache_x.to(x.device)
                x = torch.cat([cache_x, x], dim=2)
                padding[4] -= cache_x.shape[2]
            x = F.pad(x, padding)
            x = (
                x.to(self.weight.dtype) if current_platform.is_mps() else x
            )  # casting needed for mps since amp isn't supported

            input = self.input_quantizer(x)
            output = super(WanCausalConv3d, self).forward(input)

            if isinstance(output, tuple):
                return (self.output_quantizer(output[0]), *output[1:])
            return self.output_quantizer(output)


@QuantModuleRegistry.register({ReplicatedLinear: "ReplicatedLinear"})
class _QuantReplicatedLinear(_QuantLinear):
    pass
```
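The `padding[4] -= cache_x.shape[2]` bookkeeping in the forward pass deserves a note. For a 5-D `(N, C, T, H, W)` tensor, `F.pad` takes its padding as `(W_left, W_right, H_left, H_right, T_front, T_back)`, so index 4 is the front (causal) temporal padding. When cached frames from the previous chunk are concatenated in front of `x`, they stand in for that padding, so the front pad shrinks accordingly. A pure-Python sketch of this arithmetic (hypothetical helper name; the sketch also clamps at zero, which the original relies on the caller to guarantee):

```python
# Sketch of the temporal-cache padding adjustment used by the causal conv:
# cached frames prepended on the T axis replace front temporal padding.
def adjusted_padding(padding, cache_frames):
    """Return the F.pad-style padding list after prepending `cache_frames`
    frames on the temporal axis (index 4 = T_front), clamped at zero."""
    padding = list(padding)
    if cache_frames > 0 and padding[4] > 0:
        padding[4] = max(0, padding[4] - cache_frames)
    return padding


# Example: a causal conv with 2 frames of front padding receives 2 cached
# frames, so no explicit front padding is needed anymore.
print(adjusted_padding([1, 1, 1, 1, 2, 0], cache_frames=2))  # [1, 1, 1, 1, 0, 0]
```

This keeps the output temporal length identical whether a chunk is processed with real cached context or with zero-padding, which is what makes chunked causal inference consistent with full-sequence inference.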
