Skip to content

Commit 9101e11

Browse files
committed
Update
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 5361673 commit 9101e11

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -940,21 +940,25 @@ def _export_diffusers_checkpoint(
940940
# Step 5: For pipelines, also save model_index.json
941941
if is_diffusers_pipe:
942942
model_index_path = export_dir / "model_index.json"
943-
source_path = getattr(pipe, "name_or_path", None) or getattr(
944-
getattr(pipe, "config", None), "_name_or_path", None
945-
)
943+
is_partial_export = components is not None
946944

947-
# Prefer preserving the original model_index.json when the source is local.
948-
if source_path:
949-
candidate_model_index = Path(source_path) / "model_index.json"
950-
if candidate_model_index.exists():
951-
with open(candidate_model_index) as file:
952-
model_index = json.load(file)
953-
with open(model_index_path, "w") as file:
954-
json.dump(model_index, file, indent=4)
955-
956-
# Fallback to Diffusers-native config serialization.
957-
if not model_index_path.exists() and hasattr(pipe, "save_config"):
945+
# For full export, preserve original model_index.json when possible.
946+
# For partial export, skip this to avoid listing non-exported components.
947+
if not is_partial_export:
948+
source_path = getattr(pipe, "name_or_path", None) or getattr(
949+
getattr(pipe, "config", None), "_name_or_path", None
950+
)
951+
if source_path:
952+
candidate_model_index = Path(source_path) / "model_index.json"
953+
if candidate_model_index.exists():
954+
with open(candidate_model_index) as file:
955+
model_index = json.load(file)
956+
with open(model_index_path, "w") as file:
957+
json.dump(model_index, file, indent=4)
958+
959+
# Full-export fallback to Diffusers-native config serialization.
960+
# Partial export skips this for the same reason as above.
961+
if not is_partial_export and not model_index_path.exists() and hasattr(pipe, "save_config"):
958962
pipe.save_config(export_dir)
959963

960964
# Last resort: synthesize a minimal model_index.json from exported components.

0 commit comments

Comments
 (0)