
Commit 668b8a1

[1/3] Diffusion ckpt export for NVFP4 & FP8 (#781)
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR adds support for exporting quantized diffusers models (DiT, Flux, SD3, UNet, etc.) to the HuggingFace checkpoint format, enabling deployment to inference frameworks such as SGLang, vLLM, and TensorRT-LLM.

**Changes**

New file: `diffusers_utils.py`
- Dummy input generation for various diffusion models
- Pipeline component extraction helpers
- QKV projection detection and grouping
- `hide_quantizers_from_state_dict()` context manager for clean saves

Refactored: `unified_export_hf.py`
- New `_fuse_qkv_linears_diffusion()` for QKV amax fusion
- `_export_diffusers_checkpoint()` to export full pipelines (models, tokenizers, schedulers, etc.)

**Plans**
- [x] [1/3] Add the basic functionality to support a limited set of image models with NVFP4 + FP8, with some refactoring of the previous LLM code and the diffusers example. PIC: @jingyu-ml
- [ ] [2/3] Add support for more video-generation models. PIC: @jingyu-ml
- [ ] [3/3] Add test cases and refactor the docs and all related READMEs. PIC: @jingyu-ml

## Usage

```python
mtq.quantize(pipe, quant_config, forward_call)
export_hf_checkpoint(pipe, export_dir=hf_ckpt_dir)
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Summary by CodeRabbit

### New Features
* Added HuggingFace checkpoint export support for quantized diffusion models with a configurable output directory
* Introduced a new `--hf-ckpt-dir` CLI argument for specifying the checkpoint export destination
* Extended export functionality to support selective component exports from diffusion pipelines
* Enhanced quantized model export with improved component handling and multi-stage checkpoint generation

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
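To make the Usage snippet above concrete, here is a minimal end-to-end sketch. It assumes the public ModelOpt entry points `modelopt.torch.quantization.quantize` and `modelopt.torch.export.export_hf_checkpoint`; the model choice, calibration loop, and config name are illustrative assumptions, not taken from this PR.

```python
# Hedged sketch: quantize a Flux pipeline to NVFP4 and export it as an HF checkpoint.
# The config name and calibration loop below are illustrative assumptions.
import torch
from diffusers import FluxPipeline

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")


def forward_call(_model):
    # Run a few prompts end to end so the quantizers collect calibration statistics.
    for prompt in ["a photo of an astronaut riding a horse on mars"]:
        pipe(prompt, num_inference_steps=20, guidance_scale=3.5)


quant_config = mtq.NVFP4_DEFAULT_CFG  # or mtq.FP8_DEFAULT_CFG for FP8
mtq.quantize(pipe, quant_config, forward_call)
export_hf_checkpoint(pipe, export_dir="flux.1-dev-nvfp4-hf")
```

In the diffusers example, the same export path is presumably reached through the new `--hf-ckpt-dir` flag mentioned in the summary above.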
1 parent 563a1e0 commit 668b8a1

8 files changed: 1081 additions & 203 deletions


Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable
from enum import Enum
from typing import Any

from diffusers import (
    DiffusionPipeline,
    FluxPipeline,
    LTXConditionPipeline,
    StableDiffusion3Pipeline,
    WanPipeline,
)
from utils import (
    filter_func_default,
    filter_func_flux_dev,
    filter_func_ltx_video,
    filter_func_wan_video,
)


class ModelType(str, Enum):
    """Supported model types."""

    SDXL_BASE = "sdxl-1.0"
    SDXL_TURBO = "sdxl-turbo"
    SD3_MEDIUM = "sd3-medium"
    SD35_MEDIUM = "sd3.5-medium"
    FLUX_DEV = "flux-dev"
    FLUX_SCHNELL = "flux-schnell"
    LTX_VIDEO_DEV = "ltx-video-dev"
    WAN22_T2V_14b = "wan2.2-t2v-14b"
    WAN22_T2V_5b = "wan2.2-t2v-5b"


def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
    """
    Get the appropriate filter function for a given model type.

    Args:
        model_type: The model type enum

    Returns:
        A filter function appropriate for the model type
    """
    filter_func_map = {
        ModelType.FLUX_DEV: filter_func_flux_dev,
        ModelType.FLUX_SCHNELL: filter_func_default,
        ModelType.SDXL_BASE: filter_func_default,
        ModelType.SDXL_TURBO: filter_func_default,
        ModelType.SD3_MEDIUM: filter_func_default,
        ModelType.SD35_MEDIUM: filter_func_default,
        ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
        ModelType.WAN22_T2V_14b: filter_func_wan_video,
        ModelType.WAN22_T2V_5b: filter_func_wan_video,
    }

    return filter_func_map.get(model_type, filter_func_default)


# Model registry with HuggingFace model IDs
MODEL_REGISTRY: dict[ModelType, str] = {
    ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0",
    ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo",
    ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
    ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
    ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
    ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
    ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
    ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = {
    ModelType.SDXL_BASE: DiffusionPipeline,
    ModelType.SDXL_TURBO: DiffusionPipeline,
    ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
    ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
    ModelType.FLUX_DEV: FluxPipeline,
    ModelType.FLUX_SCHNELL: FluxPipeline,
    ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
    ModelType.WAN22_T2V_14b: WanPipeline,
    ModelType.WAN22_T2V_5b: WanPipeline,
}

# Shared dataset configurations
_SD_PROMPTS_DATASET = {
    "name": "Gustavosta/Stable-Diffusion-Prompts",
    "split": "train",
    "column": "Prompt",
}

_OPENVID_DATASET = {
    "name": "nkp37/OpenVid-1M",
    "split": "train",
    "column": "caption",
}

# Model family base configurations
_SDXL_BASE_CONFIG: dict[str, Any] = {
    "backbone": "unet",
    "dataset": _SD_PROMPTS_DATASET,
}

_SD3_BASE_CONFIG: dict[str, Any] = {
    "backbone": "transformer",
    "dataset": _SD_PROMPTS_DATASET,
}

_FLUX_BASE_CONFIG: dict[str, Any] = {
    "backbone": "transformer",
    "dataset": _SD_PROMPTS_DATASET,
    "inference_extra_args": {
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "max_sequence_length": 512,
    },
}

_WAN_BASE_CONFIG: dict[str, Any] = {
    "backbone": "transformer",
    "dataset": _OPENVID_DATASET,
}

# Model-specific default arguments for calibration
MODEL_DEFAULTS: dict[ModelType, dict[str, Any]] = {
    ModelType.SDXL_BASE: _SDXL_BASE_CONFIG,
    ModelType.SDXL_TURBO: _SDXL_BASE_CONFIG,
    ModelType.SD3_MEDIUM: _SD3_BASE_CONFIG,
    ModelType.SD35_MEDIUM: _SD3_BASE_CONFIG,
    ModelType.FLUX_DEV: _FLUX_BASE_CONFIG,
    ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG,
    ModelType.LTX_VIDEO_DEV: {
        "backbone": "transformer",
        "dataset": _SD_PROMPTS_DATASET,
        "inference_extra_args": {
            "height": 512,
            "width": 704,
            "num_frames": 121,
            "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
        },
    },
    ModelType.WAN22_T2V_14b: {
        **_WAN_BASE_CONFIG,
        "from_pretrained_extra_args": {
            "boundary_ratio": 0.875,
        },
        "inference_extra_args": {
            "height": 720,
            "width": 1280,
            "num_frames": 81,
            "fps": 16,
            "guidance_scale": 4.0,
            "guidance_scale_2": 3.0,
            "negative_prompt": (
                "vivid colors, overexposed, static, blurry details, subtitles, style, "
                "work of art, painting, picture, still, overall grayish, worst quality, "
                "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, "
                "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, "
                "static image, cluttered background, three legs, many people in the background, "
                "walking backwards"
            ),
        },
    },
    ModelType.WAN22_T2V_5b: {
        **_WAN_BASE_CONFIG,
        "inference_extra_args": {
            "height": 512,
            "width": 768,
            "num_frames": 81,
            "fps": 16,
            "guidance_scale": 5.0,
            "negative_prompt": (
                "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留"  # noqa: RUF001
                ",丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,"  # noqa: RUF001
                "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"  # noqa: RUF001
            ),
        },
    },
}
```
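For context, a short illustrative driver for the registry above. Only `ModelType`, `MODEL_REGISTRY`, `MODEL_PIPELINE`, `MODEL_DEFAULTS`, the dataset entries, and `get_model_filter_func` come from this file; the surrounding loading and calibration code is a sketch, not part of the PR.

```python
# Illustrative use of the registry tables defined above (not part of this PR).
import torch
from datasets import load_dataset

model_type = ModelType.FLUX_DEV
defaults = MODEL_DEFAULTS[model_type]

# Load the pipeline class and weights recorded in the registry tables.
pipe = MODEL_PIPELINE[model_type].from_pretrained(
    MODEL_REGISTRY[model_type],
    torch_dtype=torch.bfloat16,
    **defaults.get("from_pretrained_extra_args", {}),
).to("cuda")

# The quantization target is the configured backbone, e.g. pipe.transformer or pipe.unet.
backbone = getattr(pipe, defaults["backbone"])
filter_func = get_model_filter_func(model_type)  # identifies layers to skip during quantization

# Pull a few calibration prompts from the configured dataset.
ds_cfg = defaults["dataset"]
prompts = load_dataset(ds_cfg["name"], split=ds_cfg["split"])[ds_cfg["column"]][:8]

# Run calibration/inference with the per-model extra arguments.
for prompt in prompts:
    pipe(prompt, **defaults.get("inference_extra_args", {}))
```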
