Skip to content

Commit 4367b85

Browse files
committed
Bug fixed
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 3801923 commit 4367b85

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -932,21 +932,34 @@ def _export_diffusers_checkpoint(
932932

933933
print(f" Saved to: {component_export_dir}")
934934

935-
# Step 5: For pipelines, also save the model_index.json
935+
# Step 5: For pipelines, also save model_index.json
936936
if is_diffusers_pipe:
937937
model_index_path = export_dir / "model_index.json"
938-
if hasattr(pipe, "config") and pipe.config is not None:
939-
# Save a simplified model_index.json that points to the exported components
938+
source_path = getattr(pipe, "name_or_path", None) or getattr(
939+
getattr(pipe, "config", None), "_name_or_path", None
940+
)
941+
942+
# Prefer preserving the original model_index.json when the source is local.
943+
if source_path:
944+
candidate_model_index = Path(source_path) / "model_index.json"
945+
if candidate_model_index.exists():
946+
with open(candidate_model_index) as file:
947+
model_index = json.load(file)
948+
with open(model_index_path, "w") as file:
949+
json.dump(model_index, file, indent=4)
950+
951+
# Fallback to Diffusers-native config serialization.
952+
if not model_index_path.exists() and hasattr(pipe, "save_config"):
953+
pipe.save_config(export_dir)
954+
955+
# Last resort: synthesize a minimal model_index.json from exported components.
956+
if not model_index_path.exists() and hasattr(pipe, "config") and pipe.config is not None:
940957
model_index = {
941958
"_class_name": type(pipe).__name__,
942959
"_diffusers_version": diffusers.__version__,
943960
}
944-
# Add component class names for all components
945-
# Use the base library name (e.g., "diffusers", "transformers") instead of
946-
# the full module path, as expected by diffusers pipeline loading
947961
for name, comp in all_components.items():
948962
module = type(comp).__module__
949-
# Extract base library name (first part of module path)
950963
library = module.split(".")[0]
951964
model_index[name] = [library, type(comp).__name__]
952965

0 commit comments

Comments
 (0)