Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/diffusers/modular_pipelines/mellon_node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def from_custom_block(
inputs = []
model_inputs = []
outputs = []
required_inputs = []

# Process block inputs
for input_param in block.inputs:
Expand All @@ -1066,7 +1067,8 @@ def from_custom_block(
if input_param.name in input_types:
input_param = copy.copy(input_param)
input_param.metadata = {"mellon": input_types[input_param.name]}
print(f" processing input: {input_param.name}, metadata: {input_param.metadata}")
if input_param.required:
required_inputs.append(input_param.name)
inputs.append(input_param_to_mellon_param(input_param))

# Process block outputs
Expand All @@ -1090,7 +1092,7 @@ def from_custom_block(
"inputs": inputs,
"model_inputs": model_inputs,
"outputs": outputs,
"required_inputs": [],
"required_inputs": required_inputs,
"required_model_inputs": [],
"block_name": "custom",
}
Expand Down
11 changes: 11 additions & 0 deletions tests/modular_pipelines/test_modular_pipelines_custom_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,17 @@ def test_custom_block_supported_components(self):
assert len(pipe.components) == 1
assert pipe.component_names[0] == "transformer"

def test_custom_block_mellon_config_preserves_required_inputs(self):
from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig

custom_block = DummyCustomBlockSimple()

mellon_config = MellonPipelineConfig.from_custom_block(custom_block)
custom_node = mellon_config.node_params["custom"]

assert custom_node["required_inputs"] == ["prompt"]
assert custom_node["params"]["prompt"]["label"].endswith(" *")

def test_trust_remote_code_not_propagated_to_external_repo(self):
"""When a modular pipeline repo references a component from an external repo that has custom
code (auto_map in config), calling load_components(trust_remote_code=True) should NOT
Expand Down
Loading