diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index f65459dfc990..0a6bb59c2f81 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -1058,6 +1058,7 @@ def from_custom_block( inputs = [] model_inputs = [] outputs = [] + required_inputs = [] # Process block inputs for input_param in block.inputs: @@ -1066,7 +1067,8 @@ def from_custom_block( if input_param.name in input_types: input_param = copy.copy(input_param) input_param.metadata = {"mellon": input_types[input_param.name]} - print(f" processing input: {input_param.name}, metadata: {input_param.metadata}") + if input_param.required: + required_inputs.append(input_param.name) inputs.append(input_param_to_mellon_param(input_param)) # Process block outputs @@ -1090,7 +1092,7 @@ def from_custom_block( "inputs": inputs, "model_inputs": model_inputs, "outputs": outputs, - "required_inputs": [], + "required_inputs": required_inputs, "required_model_inputs": [], "block_name": "custom", } diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py index 315e16d7b260..66a09e3b67fb 100644 --- a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -229,6 +229,17 @@ def test_custom_block_supported_components(self): assert len(pipe.components) == 1 assert pipe.component_names[0] == "transformer" + def test_custom_block_mellon_config_preserves_required_inputs(self): + from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig + + custom_block = DummyCustomBlockSimple() + + mellon_config = MellonPipelineConfig.from_custom_block(custom_block) + custom_node = mellon_config.node_params["custom"] + + assert custom_node["required_inputs"] == ["prompt"] + assert custom_node["params"]["prompt"]["label"].endswith(" *") + def test_trust_remote_code_not_propagated_to_external_repo(self): """When a modular pipeline repo references a component from an external repo that has custom code (auto_map in config), calling load_components(trust_remote_code=True) should NOT