feat(mpnn): report per-design confidence in inference outputs #306
daylight-00 wants to merge 5 commits into
MPNN inference only reported sequence recovery, which requires a native sequence and is meaningless for de novo design. The original ProteinMPNN/LigandMPNN also report a per-sequence confidence (exp(-mean NLL)) used to rank designs. This adds it for parity.
- Add SampledNLL / SampledInterfaceNLL metrics that score the *sampled* sequence (reusing the existing NLL math + interface mask) instead of the native sequence used by the training-time NLL metric.
- Compute confidence from the un-temperatured, un-biased logits (log_softmax of raw logits), matching LigandMPNN's T=1.0 definition. Using decoder_features["log_probs"] directly would be wrong: it is temperature- and bias-scaled (default T=0.1), which pins confidence near 1.0 and makes it useless for ranking.
- Expose raw logits in the minimal-return decoder features.
- Wire overall_confidence + ligand_confidence (ligand_mpnn) per design into MPNNInferenceEngine, plus per-residue confidence.
- Write overall_confidence/ligand_confidence to FASTA headers and to the CIF mpnn_output category; write per-residue confidence into the standard b-factor column (_atom_site.B_iso_or_equiv), overwriting the inherited input b-factors as the original LigandMPNN does.

overall_confidence = exp(-mean_over_designed_residues(log_probs)), range (0, 1], higher = more confident; matches LigandMPNN README. For the value to fall in (0, 1], log_probs in this formula must stand for the per-residue NLL of the sampled residue, consistent with exp(-mean NLL) above.
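To illustrate the temperature point in the commit message above, here is a minimal pure-Python sketch (not the PR's code, which operates on torch tensors) of confidence as exp(-mean NLL) of the sampled residues over designed positions. The function names, the toy logits, and the 4-letter alphabet are assumptions made for illustration only.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one residue's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def confidence(logits_per_residue, sampled, designed_mask, temperature=1.0):
    """exp(-mean NLL) of the sampled residues, averaged over designed positions."""
    nll = [
        -log_softmax([x / temperature for x in logits])[aa]
        for logits, aa, keep in zip(logits_per_residue, sampled, designed_mask)
        if keep
    ]
    return math.exp(-sum(nll) / len(nll))


# Toy input (hypothetical): 3 residues, 4-letter alphabet, last position fixed.
logits = [[2.0, 1.0, 0.0, -1.0], [0.5, 1.5, 0.0, 0.0], [3.0, 0.0, 0.0, 0.0]]
sampled = [0, 1, 0]
mask = [True, True, False]

conf_t1 = confidence(logits, sampled, mask, temperature=1.0)    # raw logits
conf_t01 = confidence(logits, sampled, mask, temperature=0.1)   # sampling temperature
print(f"T=1.0: {conf_t1:.4f}  T=0.1: {conf_t01:.4f}")
```

With these toy numbers the T=1.0 value sits well inside (0, 1), while dividing the same logits by 0.1 pushes the value to within a hair of 1.0, which is why a score built from temperature-scaled log-probabilities cannot separate designs.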
Cover the SampledNLL / SampledInterfaceNLL metrics added in the previous commit:
- Assert the kwargs remapping reads the sampled sequence (S_sampled) and the raw logits, not the native sequence or temperature-scaled log_probs.
- Assert SampledNLL.compute equals the hand-computed NLL of the sampled sequence under log_softmax(logits), with masked positions zeroed.

These are self-contained (no structure fixtures) so they run without the test-data assets the integration suite requires. Also add Google-format docstrings to the overridden compute() methods.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds “confidence” metrics that score the sampled (designed) sequence using raw logits, and propagates those confidences through inference outputs (FASTA headers and structure B-factors) to mirror LigandMPNN’s reporting.
Changes:
- Introduces `SampledNLL`/`SampledInterfaceNLL` to compute NLL (and derived confidence) on the sampled sequence using `log_softmax(logits)`.
- Updates inference to emit `overall_confidence`/`ligand_confidence`, write them to FASTA headers, and store per-residue confidence in structure `b_factor`.
- Ensures logits are preserved through feature aggregation and adds unit tests for the new sampled-confidence wiring.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| models/mpnn/tests/test_metrics.py | Adds tests to ensure sampled-confidence metrics read sampled sequence and raw logits. |
| models/mpnn/src/mpnn/utils/inference.py | Adds confidence fields to output docs and FASTA header rendering. |
| models/mpnn/src/mpnn/transforms/feature_aggregation/user_settings.py | Ensures logits are kept in aggregated decoder features for confidence computation. |
| models/mpnn/src/mpnn/metrics/nll.py | Implements SampledNLL and SampledInterfaceNLL based on raw logits + sampled sequence. |
| models/mpnn/src/mpnn/inference_engines/mpnn.py | Computes confidences from new metrics, writes per-residue confidence into b_factor, and adds output fields. |
Address PR review feedback:
- Rename the `log_probs` compute parameter and its `kwargs_to_compute_args` key to `logits` in SampledNLL/SampledInterfaceNLL. The argument carries raw logits, so the name now matches the contents.
- Add a unit test for SampledInterfaceNLL that injects a known interface mask and asserts the interface NLL is restricted to interface positions and scores the sampled sequence (not native S) under log_softmax(logits).
Resolves #239
Thanks for implementing this! It will be useful to have the ProteinMPNN/LigandMPNN confidences included in the output.
I have two changes/questions:
- I would recommend adding `MPNNConfidence` and `MPNNInterfaceConfidence` metrics, corresponding to `overall_confidence` and `ligand_confidence` here. I think those names would make the metrics clearer and help distinguish them from other confidence metrics. These could be wrappers around the sampled NLL metrics you added, which would also move the NLL-to-confidence calculation out of the top-level `MPNNInferenceEngine`.
- I agree that it would be useful to store the per-residue confidences in the atom array. I have some changes in another branch that could enable writing these per-residue confidences to the output CIF. My only concern is that overwriting the b-factor may be undesirable if users want to retain the original b-factor annotation. Could we instead create a separate atom array field for the MPNN confidence? If the goal is to write this into the CIF output, we can also discuss how to merge this with the changes in my `fix/mpnn-fasta-output` branch.
Happy to discuss further!
-Andrew
…nce in a dedicated field

Address review feedback on RosettaCommons#306:
- Add MPNNConfidence / MPNNInterfaceConfidence wrappers around SampledNLL / SampledInterfaceNLL that expose confidence = exp(-NLL) directly, moving the NLL-to-confidence transform out of MPNNInferenceEngine into the metric layer (registered as overall_confidence / ligand_confidence).
- Store per-residue confidence in a dedicated `mpnn_confidence` AtomArray annotation (serialized to CIF as _atom_site.mpnn_confidence via extra_fields) instead of overwriting b_factor, preserving the input structure's B-factors.
- Scope SampledNLL/SampledInterfaceNLL docstrings to NLL only (confidence is documented on the wrappers); copy the parent kwargs mapping defensively (dict + pop) and mask per-residue confidence with torch.where.
- Add unit tests for the confidence wrappers.
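The wrapper arrangement described in this commit can be sketched roughly as follows. This is a simplified pure-Python stand-in, not the PR's implementation: the real metrics take tensors and use `torch.where`, and the `compute` signature and the returned dictionary keys shown here are assumptions.

```python
import math


class SampledNLL:
    """Hypothetical stand-in: mean and per-residue NLL of the sampled sequence."""

    def compute(self, per_residue_nll, mask):
        kept = [n for n, m in zip(per_residue_nll, mask) if m]
        # With no scored positions the mean is undefined, so report NaN.
        mean = sum(kept) / len(kept) if kept else float("nan")
        return {"nll": mean, "nll_per_residue": list(per_residue_nll)}


class MPNNConfidence(SampledNLL):
    """Thin wrapper that exposes confidence = exp(-NLL) instead of the NLL."""

    def compute(self, per_residue_nll, mask):
        out = super().compute(per_residue_nll, mask)
        # Masked (non-designed) positions get 0.0, mirroring a where() mask.
        per_residue = [
            math.exp(-n) if m else 0.0
            for n, m in zip(out["nll_per_residue"], mask)
        ]
        return {
            "confidence": math.exp(-out["nll"]),
            "confidence_per_residue": per_residue,
        }


result = MPNNConfidence().compute([math.log(2), math.log(4), 5.0], [True, True, False])
print(result)
```

Keeping the exponential inside the metric means the engine only reads ready-made values, which is the separation the review asked for.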
Thanks for the review, Andrew. I've addressed both points.
For CIF output, I added
For the FASTA output overlap, I also checked the refactored writer in your branch. Since the confidence values are now in
David
AndrewKubaney left a comment
Apologies for the delay in my review! This looks great; thank you for implementing this! I would suggest a few minor changes:
- For the `kwargs_to_compute_args` function in your new metrics classes, I think it would be more straightforward to construct the mapping dictionary from scratch, rather than relying on `super`. This will avoid the `.pop` logic, which might be hard to understand.
- I think `MPNNInterfaceConfidence` could be renamed to `MPNNLigandInterfaceConfidence`, so that people don't misinterpret this as applying to interfaces other than protein-ligand ones (for instance, protein-protein interfaces).
- I think `mpnn_confidence` is a good name for the atom array feature; what do you think about changing the final confidence names (the ones that end up in the FASTA) from `overall_confidence` to `confidence` or `mpnn_confidence` (and similarly `ligand_confidence` to `ligand_interface_confidence` or `mpnn_ligand_interface_confidence`)? I think those might be more descriptive, but they do deviate from the original ProteinMPNN/LigandMPNN norm.
- One edge case involves attempting to compute the ligand interface confidence when no ligand is present (if someone runs LigandMPNN on a protein monomer) or there are no valid ligand-interface residues. Could you add a test case to make sure this is handled cleanly?
Thanks again for your work on this! After addressing these, I think we should merge. The other branch can be merged at a later time.
- Andrew
…nterface edge case

Address @AndrewKubaney's second review on RosettaCommons#306:
- Build kwargs_to_compute_args explicitly in SampledNLL / SampledLigandInterfaceNLL instead of copying the parent mapping and popping "log_probs" (I'd used copy+pop to avoid mutating the parent, but the explicit literal reads more clearly).
- Rename MPNNInterfaceConfidence -> MPNNLigandInterfaceConfidence and SampledInterfaceNLL -> SampledLigandInterfaceNLL so the polymer-ligand interface scope is unambiguous (vs e.g. protein-protein interfaces).
- Rename reported fields overall_confidence -> confidence and ligand_confidence -> ligand_interface_confidence, mirroring the existing sequence_recovery / ligand_interface_sequence_recovery keys and Foundry's un-prefixed output-key convention (the mpnn_confidence atom-array field is unchanged).
- Handle the no-interface edge case: a ligand_mpnn run with no polymer-ligand interface residues yields an undefined (NaN) interface confidence; the engine now emits None (omitted) rather than NaN. Add a unit test.
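For readers comparing the two ways of building the mapping discussed in this commit, here is an illustrative sketch. The mapping keys and values (`"mask_for_loss"` in particular) and the `resolve` helper are hypothetical; only the names `S_sampled`, `logits`, and `log_probs` come from the discussion itself.

```python
# Hypothetical parent mapping: compute() argument name -> key in the metric kwargs.
PARENT_MAPPING = {"S": "S", "log_probs": "log_probs", "mask": "mask_for_loss"}


def mapping_via_copy_and_pop():
    """Earlier approach: copy the parent's mapping, then swap entries."""
    mapping = dict(PARENT_MAPPING)  # copy so the parent mapping is not mutated
    mapping.pop("log_probs")        # drop the temperature-scaled input
    mapping["logits"] = "logits"    # read the raw logits instead
    mapping["S"] = "S_sampled"      # score the sampled, not the native, sequence
    return mapping


def mapping_explicit():
    """Revised approach: spell out the whole mapping as a literal."""
    return {"S": "S_sampled", "logits": "logits", "mask": "mask_for_loss"}


def resolve(mapping, kwargs):
    """Pick compute() arguments out of the kwargs according to the mapping."""
    return {arg: kwargs[key] for arg, key in mapping.items()}


kwargs = {"S": "native", "S_sampled": "designed", "logits": [0.1], "log_probs": [-2.3], "mask_for_loss": [1]}
print(resolve(mapping_explicit(), kwargs))
```

Both functions yield the same mapping; the literal simply makes the three inputs visible at a glance without the reader having to know what the parent defines.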
Thanks, Andrew — all four addressed.
One small naming question: the sibling metric classes in `mpnn/metrics` are unprefixed (
Sounds good on merging this and leaving the `fix/mpnn-fasta-output` integration for later. Thanks for the thorough review!
David
Summary
MPNN inference currently reports only sequence recovery, which needs a native sequence and is meaningless for de novo design. The original ProteinMPNN/LigandMPNN additionally report a per-sequence confidence (exp(-mean NLL)), the standard metric for ranking/filtering designs. This PR brings that to the foundry re-implementation.
What's added
For each design, inference now reports:
`mpnn_confidence` atom-array field.
Output format
- `>name_b{b}_d{d}, confidence=..., ligand_interface_confidence=..., sequence_recovery=..., ligand_interface_sequence_recovery=...`
- `_mpnn_output.confidence` / `_mpnn_output.ligand_interface_confidence` (per-design scalars).
- `_atom_site.mpnn_confidence` column (via the existing `extra_fields` path, like `mpnn_temperature`); the input `b_factor` annotation is preserved.
Implementation
- `metrics/nll.py`:
  - `SampledNLL` / `SampledLigandInterfaceNLL`: score the sampled sequence (`S_sampled`) under `log_softmax(logits)`, reusing the existing NLL math and interface-mask machinery. Note: `decoder_features["log_probs"]` is temperature/bias-scaled (default `T=0.1`), which would pin confidence near 1.0, so the raw `logits` are used to match LigandMPNN's `T=1.0` definition.
  - `MPNNConfidence` / `MPNNLigandInterfaceConfidence`: thin wrappers that expose the derived confidence (`exp(-NLL)`), so the NLL-to-confidence transform lives in the metric layer rather than the engine. Registered as `confidence` / `ligand_interface_confidence`.
- `user_settings.py`: expose `logits` in the minimal-return decoder features.
- `inference_engines/mpnn.py`: register the confidence metrics, read the per-design / per-residue confidence directly, and store per-residue confidence in a dedicated `mpnn_confidence` atom-array annotation (leaving `b_factor` untouched). When there are no polymer-ligand interface residues the interface confidence is omitted (`None`) rather than written as NaN.
- `utils/inference.py`: emit the confidence header fields in `write_fasta` and add `mpnn_confidence` to the CIF `extra_fields`.
Tests
`tests/test_metrics.py`: self-contained unit tests for the new metrics (no structure fixtures, so they run without the test-data assets the integration suite needs). They pin:
- the metrics read the sampled sequence and the raw logits (not the native sequence / temperature-scaled log_probs);
- the confidence equals `exp(-NLL)` of the sampled sequence under `log_softmax(logits)`;
- the interface confidence is restricted to interface residues;
- the no-interface case is cleanly undefined (not a crash).
Verification
Ran `ligand_mpnn` (legacy weights `ligandmpnn_v_32_010_25.pt`) on a dsDNA binder backbone:
In the output CIF, the input `_atom_site.B_iso_or_equiv` is preserved and per-residue confidence is written to a separate `_atom_site.mpnn_confidence` column (range ~0.14-0.93 on designed positions, 0.0 on ligand/DNA/fixed atoms).
Notes / scope
- `mpnn_confidence` field rather than `b_factor`, so the input structure's B-factors are preserved. (The original LigandMPNN overwrites PDB b-factors; Foundry keeps them.)
- `ligand_interface_confidence` is omitted when the input has no polymer-ligand interface residues (e.g. ligand_mpnn on a monomer).
- `seed` is set.
- `fix/mpnn-fasta-output` branch; confidence values live in `output_dict`, so they thread into that branch's writer by adding the two fields to its header allowlist.
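As a rough illustration of the header layout and the omit-when-undefined rule described in this PR, the sketch below renders a FASTA header from a per-design metrics dictionary. It is not the PR's `write_fasta`; the helper names, the argument list, and the four-decimal formatting are assumptions.

```python
import math


def nan_to_none(value):
    """Map an undefined (NaN) metric to None so it can be omitted downstream."""
    return None if value is None or math.isnan(value) else value


def fasta_header(name, batch_idx, design_idx, metrics):
    """Render '>name_b{b}_d{d}, key=value, ...', skipping undefined metrics."""
    fields = [
        f"{key}={value:.4f}"
        for key, value in metrics.items()
        if nan_to_none(value) is not None
    ]
    return ", ".join([f">{name}_b{batch_idx}_d{design_idx}", *fields])


# Hypothetical design with no polymer-ligand interface residues.
header = fasta_header(
    "binder", 0, 1,
    {
        "confidence": 0.5,
        "ligand_interface_confidence": float("nan"),
        "sequence_recovery": 0.25,
    },
)
print(header)
```

Dropping the field entirely, rather than printing `nan`, keeps downstream parsers that cast every header value to a float from having to special-case undefined entries.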