Skip to content

Commit e6e4efd

Browse files
authored
[0.5/3] Diffusion ckpt export for NVFP4 & FP8 (#783)
See #781 This is the MR that only includes the refactoring of the llm export, please ignore the change on quantize.py from the diffusion example. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added `--hf-ckpt-dir` CLI option to save checkpoints in HuggingFace format * Enabled support for exporting Diffusers-based pipelines * Unified export system now handles both transformer and diffusion model architectures <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 849a350 commit e6e4efd

4 files changed

Lines changed: 323 additions & 128 deletions

File tree

examples/diffusers/quantization/quantize.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767
import modelopt.torch.opt as mto
6868
import modelopt.torch.quantization as mtq
69+
from modelopt.torch.export import export_hf_checkpoint
6970

7071

7172
class ModelType(str, Enum):
@@ -348,6 +349,7 @@ class ExportConfig:
348349

349350
quantized_torch_ckpt_path: Path | None = None
350351
onnx_dir: Path | None = None
352+
hf_ckpt_dir: Path | None = None
351353
restore_from: Path | None = None
352354

353355
def validate(self) -> None:
@@ -363,6 +365,9 @@ def validate(self) -> None:
363365
if self.onnx_dir and not self.onnx_dir.exists():
364366
self.onnx_dir.mkdir(parents=True, exist_ok=True)
365367

368+
if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists():
369+
self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True)
370+
366371

367372
def setup_logging(verbose: bool = False) -> logging.Logger:
368373
"""
@@ -862,6 +867,20 @@ def restore_checkpoint(self, backbone: nn.Module) -> None:
862867
mto.restore(backbone, str(self.config.restore_from))
863868
self.logger.info("Model restored successfully")
864869

870+
def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None:
871+
"""
872+
Export quantized model to HuggingFace checkpoint format.
873+
874+
Args:
875+
pipe: Diffusion pipeline containing the quantized model
876+
"""
877+
if not self.config.hf_ckpt_dir:
878+
return
879+
880+
self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}")
881+
export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir)
882+
self.logger.info("HuggingFace checkpoint export completed successfully")
883+
865884

866885
def create_argument_parser() -> argparse.ArgumentParser:
867886
"""
@@ -994,6 +1013,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
9941013
help="Path to save quantized PyTorch checkpoint",
9951014
)
9961015
export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export")
1016+
export_group.add_argument(
1017+
"--hf-ckpt-dir",
1018+
type=str,
1019+
help="Directory for HuggingFace checkpoint export",
1020+
)
9971021
export_group.add_argument(
9981022
"--restore-from", type=str, help="Path to restore from previous checkpoint"
9991023
)
@@ -1070,6 +1094,7 @@ def main() -> None:
10701094
if args.quantized_torch_ckpt_save_path
10711095
else None,
10721096
onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None,
1097+
hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None,
10731098
restore_from=Path(args.restore_from) if args.restore_from else None,
10741099
)
10751100

@@ -1125,6 +1150,9 @@ def forward_loop(mod):
11251150
model_config.model_type,
11261151
quant_config.format,
11271152
)
1153+
1154+
export_manager.export_hf_ckpt(pipe)
1155+
11281156
logger.info(
11291157
f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"
11301158
)

examples/llm_ptq/multinode_ptq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import modelopt.torch.quantization as mtq
3737
from modelopt.torch.export import get_model_type
3838
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
39-
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
39+
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
4040
from modelopt.torch.quantization.config import need_calibration
4141
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
4242
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets
@@ -243,7 +243,7 @@ def export_model(
243243
export_dir = Path(export_path)
244244
export_dir.mkdir(parents=True, exist_ok=True)
245245

246-
post_state_dict, hf_quant_config = _export_hf_checkpoint(
246+
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
247247
model, torch.bfloat16, accelerator=accelerator
248248
)
249249

examples/llm_qat/export.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import modelopt.torch.opt as mto
2525
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
26-
from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
26+
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
2727
from modelopt.torch.opt.conversion import restore_from_modelopt_state
2828
from modelopt.torch.quantization.utils import set_quantizer_state_dict
2929
from modelopt.torch.utils import print_rank_0
@@ -81,7 +81,9 @@ def main(args):
8181
base_model_dir = export_dir
8282

8383
try:
84-
post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)
84+
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
85+
model, is_modelopt_qlora=is_qlora
86+
)
8587

8688
with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
8789
json.dump(hf_quant_config, file, indent=4)

0 commit comments

Comments (0)