From 3f976e126e59ef9000e5b244437396e6f1aaf932 Mon Sep 17 00:00:00 2001 From: daylight-00 Date: Fri, 5 Jun 2026 16:27:25 +0900 Subject: [PATCH 1/5] feat(mpnn): add overall/ligand confidence to inference outputs MPNN inference only reported sequence recovery, which requires a native sequence and is meaningless for de novo design. The original ProteinMPNN/LigandMPNN also report a per-sequence confidence (exp(-mean NLL)) used to rank designs. This adds that to parity. - Add SampledNLL / SampledInterfaceNLL metrics that score the *sampled* sequence (reusing the existing NLL math + interface mask) instead of the native sequence used by the training-time NLL metric. - Compute confidence from the un-temperatured, un-biased logits (log_softmax of raw logits), matching LigandMPNN's T=1.0 definition. Using decoder_features["log_probs"] directly would be wrong: it is temperature- and bias-scaled (default T=0.1), pinning confidence near 1.0 and making it useless for ranking. - Expose raw logits in the minimal-return decoder features. - Wire overall_confidence + ligand_confidence (ligand_mpnn) per design into MPNNInferenceEngine, plus per-residue confidence. - Write overall_confidence/ligand_confidence to FASTA headers and to the CIF mpnn_output category; write per-residue confidence into the standard b-factor column (_atom_site.B_iso_or_equiv), overwriting the inherited input b-factors as the original LigandMPNN does. overall_confidence = exp(-mean_over_designed_residues(log_probs)), range (0, 1], higher = more confident; matches LigandMPNN README. --- .../mpnn/src/mpnn/inference_engines/mpnn.py | 53 +++++++++++++++++++ models/mpnn/src/mpnn/metrics/nll.py | 52 ++++++++++++++++++ .../feature_aggregation/user_settings.py | 8 ++- models/mpnn/src/mpnn/utils/inference.py | 11 ++++ 4 files changed, 123 insertions(+), 1 deletion(-) diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index 9f53c1d5..c2e8d90c 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -12,6 +12,7 @@ from atomworks.ml.utils.token import get_token_starts, spread_token_wise from biotite.structure import AtomArray from mpnn.collate.feature_collator import FeatureCollator +from mpnn.metrics.nll import SampledInterfaceNLL, SampledNLL from mpnn.metrics.sequence_recovery import ( InterfaceSequenceRecovery, SequenceRecovery, @@ -193,11 +194,18 @@ def _build_metrics_manager(self) -> MetricManager: # Construct metrics dict. metrics: dict[str, Any] = { "sequence_recovery": SequenceRecovery(return_per_example_metrics=True), + "overall_confidence": SampledNLL( + return_per_example_metrics=True, + return_per_residue_metrics=True, + ), } if self.model_type == "ligand_mpnn": metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery( return_per_example_metrics=True ) + metrics["ligand_confidence"] = SampledInterfaceNLL( + return_per_example_metrics=True + ) # Construct the MetricManager. metric_manager = MetricManager.from_metrics(metrics, raise_errors=True) @@ -387,6 +395,32 @@ def _run_batch( else: interface_sequence_recovery_per_design = None + # Per-design overall confidence = exp(-NLL) of the *sampled* sequence + # under the model's un-temperatured logits, mirroring LigandMPNN's + # overall/ligand confidence. Per-residue confidence (exp(-NLL) per + # position, zeroed at non-designed positions) is written into the + # output structure below. + overall_confidence_per_design = np.exp( + -metrics_output["overall_confidence.nll_per_example"].detach().cpu().numpy() + ) + nll_per_residue = ( + metrics_output["overall_confidence.nll_per_residue"].detach().cpu().numpy() + ) + confidence_per_residue_mask = ( + metrics_output["overall_confidence.per_residue_mask"].detach().cpu().numpy() + ) + confidence_per_residue = np.exp(-nll_per_residue) * confidence_per_residue_mask + + if self.model_type == "ligand_mpnn": + ligand_confidence_per_design = np.exp( + -metrics_output["ligand_confidence.interface_nll_per_example"] + .detach() + .cpu() + .numpy() + ) + else: + ligand_confidence_per_design = None + # Grab the index to token mapping from the model. idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token @@ -436,6 +470,19 @@ def _run_batch( # Overwrite with designed residue names. design_atom_array.set_annotation("res_name", full_resnames) + # Spread per-residue confidence (token-level) to atom level over the + # non-atomized subset and overwrite the b-factor with it, so it is + # written to the standard CIF '_atom_site.B_iso_or_equiv' column and + # picked up by viewers (analogous to LigandMPNN's per-residue + # confidence b-factors). Non-designed positions are set to 0. + design_confidence_atom = spread_token_wise( + design_non_atomized_array, + confidence_per_residue[design_idx], + ) + full_confidence = np.zeros(len(design_atom_array), dtype=np.float32) + full_confidence[~design_atom_array.atomize] = design_confidence_atom + design_atom_array.set_annotation("b_factor", full_confidence) + # We need to remove any non-atomized residue atoms that no # longer belong (i.e. old side chain atoms). We want to keep any # atom that is atomized, any atom that is a backbone atom, and @@ -478,6 +525,12 @@ def _run_batch( "batch_idx": batch_idx, "design_idx": design_idx, "designed_sequence": one_letter_seq, + "overall_confidence": float(overall_confidence_per_design[design_idx]), + "ligand_confidence": ( + float(ligand_confidence_per_design[design_idx]) + if ligand_confidence_per_design is not None + else None + ), "sequence_recovery": sequence_recovery, "ligand_interface_sequence_recovery": ( ligand_interface_sequence_recovery diff --git a/models/mpnn/src/mpnn/metrics/nll.py b/models/mpnn/src/mpnn/metrics/nll.py index 1e28b7fb..7fe249e8 100644 --- a/models/mpnn/src/mpnn/metrics/nll.py +++ b/models/mpnn/src/mpnn/metrics/nll.py @@ -367,3 +367,55 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): interface_metrics[f"interface_{key}"] = value return interface_metrics + + +class SampledNLL(NLL): + """NLL / confidence of the *sampled* (designed) sequence. + + Unlike :class:`NLL` (which scores the native sequence as a training + metric), this scores the sampled sequence under the model's + un-temperatured, un-biased log-probabilities (``log_softmax`` of the raw + logits). This matches the "overall_confidence" reported by the original + LigandMPNN, where ``confidence = exp(-mean_nll)``. + """ + + @property + def kwargs_to_compute_args(self): + mapping = super().kwargs_to_compute_args + mapping["S"] = ("network_output", "decoder_features", "S_sampled") + mapping["log_probs"] = ("network_output", "decoder_features", "logits") + return mapping + + def compute(self, log_probs, S, mask_for_loss, **kwargs): + # 'log_probs' is actually the raw logits; convert to true log-probs at + # temperature 1.0 (no bias) to match LigandMPNN's confidence definition. + log_probs = torch.log_softmax(log_probs, dim=-1) + return super().compute( + log_probs=log_probs, S=S, mask_for_loss=mask_for_loss, **kwargs + ) + + +class SampledInterfaceNLL(InterfaceNLL): + """Interface NLL / confidence of the *sampled* (designed) sequence. + + Mirrors LigandMPNN's "ligand_confidence" (polymer-ligand interface + residues only), computed on the un-temperatured logits of the sampled + sequence. + """ + + @property + def kwargs_to_compute_args(self): + mapping = super().kwargs_to_compute_args + mapping["S"] = ("network_output", "decoder_features", "S_sampled") + mapping["log_probs"] = ("network_output", "decoder_features", "logits") + return mapping + + def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): + log_probs = torch.log_softmax(log_probs, dim=-1) + return super().compute( + log_probs=log_probs, + S=S, + mask_for_loss=mask_for_loss, + atom_array=atom_array, + **kwargs, + ) diff --git a/models/mpnn/src/mpnn/transforms/feature_aggregation/user_settings.py b/models/mpnn/src/mpnn/transforms/feature_aggregation/user_settings.py index ae756503..3738d3b8 100644 --- a/models/mpnn/src/mpnn/transforms/feature_aggregation/user_settings.py +++ b/models/mpnn/src/mpnn/transforms/feature_aggregation/user_settings.py @@ -236,7 +236,13 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]: "input_features": [ "mask_for_loss", ], - "decoder_features": ["log_probs", "S_sampled", "S_argmax"], + # 'logits' are needed to compute un-temperatured confidence. + "decoder_features": [ + "log_probs", + "S_sampled", + "S_argmax", + "logits", + ], } # Save the scalar settings. diff --git a/models/mpnn/src/mpnn/utils/inference.py b/models/mpnn/src/mpnn/utils/inference.py index a4fc50bc..52908343 100644 --- a/models/mpnn/src/mpnn/utils/inference.py +++ b/models/mpnn/src/mpnn/utils/inference.py @@ -2210,6 +2210,8 @@ class MPNNInferenceOutput: - 'batch_idx' - 'design_idx' - 'designed_sequence' + - 'overall_confidence' + - 'ligand_confidence' - 'sequence_recovery' - 'ligand_interface_sequence_recovery' - 'model_type' @@ -2367,6 +2369,15 @@ def write_fasta( decorated_name = "_".join(name_fields) header_fields.append(decorated_name) + # Construct the confidence fields for the header (exp(-NLL) of the + # sampled sequence; mirrors LigandMPNN's overall/ligand confidence). + overall_confidence = self.output_dict.get("overall_confidence") + ligand_confidence = self.output_dict.get("ligand_confidence") + if overall_confidence is not None: + header_fields.append(f"overall_confidence={float(overall_confidence):.4f}") + if ligand_confidence is not None: + header_fields.append(f"ligand_confidence={float(ligand_confidence):.4f}") + # Construct the recovery fields for the header. if sequence_recovery is not None: header_fields.append(f"sequence_recovery={float(sequence_recovery):.4f}") From f34d2212b2695263c241fc93768b9c359091630f Mon Sep 17 00:00:00 2001 From: daylight-00 Date: Fri, 5 Jun 2026 16:43:51 +0900 Subject: [PATCH 2/5] test(mpnn): add unit tests and docstrings for sampled confidence metrics Cover the SampledNLL / SampledInterfaceNLL metrics added in the previous commit: - Assert the kwargs remapping reads the sampled sequence (S_sampled) and the raw logits, not the native sequence or temperature-scaled log_probs. - Assert SampledNLL.compute equals the hand-computed NLL of the sampled sequence under log_softmax(logits), with masked positions zeroed. These are self-contained (no structure fixtures) so they run without the test-data assets the integration suite requires. Also add Google-format docstrings to the overridden compute() methods. --- models/mpnn/src/mpnn/metrics/nll.py | 32 +++++++++++++++- models/mpnn/tests/test_metrics.py | 58 ++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/models/mpnn/src/mpnn/metrics/nll.py b/models/mpnn/src/mpnn/metrics/nll.py index 7fe249e8..0a3a0d29 100644 --- a/models/mpnn/src/mpnn/metrics/nll.py +++ b/models/mpnn/src/mpnn/metrics/nll.py @@ -387,8 +387,20 @@ def kwargs_to_compute_args(self): return mapping def compute(self, log_probs, S, mask_for_loss, **kwargs): - # 'log_probs' is actually the raw logits; convert to true log-probs at - # temperature 1.0 (no bias) to match LigandMPNN's confidence definition. + """Convert raw logits to log-probabilities and delegate to ``NLL``. + + Args: + log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits + (mapped from ``decoder_features.logits``), not log-probs. + S (torch.Tensor): [B, L] - the sampled (designed) sequence. + mask_for_loss (torch.Tensor): [B, L] - mask for loss. + **kwargs: Additional arguments forwarded to ``NLL.compute``. + + Returns: + dict: The NLL / confidence metrics from ``NLL.compute``. + """ + # Raw logits -> true log-probs at temperature 1.0 (no bias), matching + # LigandMPNN's confidence definition. log_probs = torch.log_softmax(log_probs, dim=-1) return super().compute( log_probs=log_probs, S=S, mask_for_loss=mask_for_loss, **kwargs @@ -411,6 +423,22 @@ def kwargs_to_compute_args(self): return mapping def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): + """Convert raw logits to log-probabilities and delegate to ``InterfaceNLL``. + + Args: + log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits + (mapped from ``decoder_features.logits``), not log-probs. + S (torch.Tensor): [B, L] - the sampled (designed) sequence. + mask_for_loss (torch.Tensor): [B, L] - mask for loss. + atom_array: Atom array(s) used to derive the polymer-ligand + interface mask. + **kwargs: Additional arguments forwarded to ``InterfaceNLL.compute``. + + Returns: + dict: The interface NLL / confidence metrics (``interface_`` prefix). + """ + # Raw logits -> true log-probs at temperature 1.0 (no bias), matching + # LigandMPNN's confidence definition. log_probs = torch.log_softmax(log_probs, dim=-1) return super().compute( log_probs=log_probs, diff --git a/models/mpnn/tests/test_metrics.py b/models/mpnn/tests/test_metrics.py index d5a68398..67a43a05 100644 --- a/models/mpnn/tests/test_metrics.py +++ b/models/mpnn/tests/test_metrics.py @@ -6,8 +6,9 @@ """ import pytest +import torch from atomworks.ml.utils.testing import cached_parse -from mpnn.metrics.nll import NLL, InterfaceNLL +from mpnn.metrics.nll import NLL, InterfaceNLL, SampledInterfaceNLL, SampledNLL from mpnn.metrics.sequence_recovery import InterfaceSequenceRecovery, SequenceRecovery from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline from test_utils import ( @@ -90,3 +91,58 @@ def test_metrics_comprehensive(self, pdb_id, model_type, is_inference): return_per_example=True, return_per_residue=True, ) + + def test_sampled_confidence_metrics_read_sampled_logits(self): + """SampledNLL/SampledInterfaceNLL must score the *sampled* sequence + using the raw model logits (not the native sequence or the + temperature-scaled log_probs).""" + for metric in (SampledNLL(), SampledInterfaceNLL()): + mapping = metric.kwargs_to_compute_args + assert mapping["S"] == ("network_output", "decoder_features", "S_sampled") + assert mapping["log_probs"] == ( + "network_output", + "decoder_features", + "logits", + ) + # The interface variant additionally needs the atom array for masking. + assert SampledInterfaceNLL().kwargs_to_compute_args["atom_array"] == ( + "network_input", + "atom_array", + ) + + def test_sampled_nll_equals_log_softmax_of_logits_on_sampled_sequence(self): + """SampledNLL.compute must equal the hand-computed NLL of the sampled + sequence under log_softmax(logits), and must ignore the native + sequence.""" + batch, length, vocab = 1, 4, 21 + torch.manual_seed(0) + logits = torch.randn(batch, length, vocab) + sampled = torch.tensor([[1, 5, 5, 10]]) + native = torch.tensor([[0, 0, 0, 0]]) # deliberately != sampled + mask = torch.tensor([[True, True, True, False]]) + + network_output = { + "decoder_features": {"logits": logits, "S_sampled": sampled}, + "input_features": {"mask_for_loss": mask, "S": native}, + } + + metric = SampledNLL( + return_per_example_metrics=True, return_per_residue_metrics=True + ) + out = metric.compute_from_kwargs(network_output=network_output) + + # Expected: mean over the (3) masked-in positions of -log_softmax(logits). + log_probs = torch.log_softmax(logits, dim=-1) + per_res = -log_probs[0, torch.arange(length), sampled[0]] + expected_nll = per_res[:3].sum() / 3.0 + + assert torch.allclose(out["nll_per_example"][0], expected_nll, atol=1e-6) + # Per-residue NLL is zeroed at the masked-out position. + assert torch.allclose( + out["nll_per_residue"][0], + torch.tensor([per_res[0], per_res[1], per_res[2], 0.0]), + atol=1e-6, + ) + # Must NOT coincide with the NLL of the (different) native sequence. + native_nll = (-log_probs[0, torch.arange(length), native[0]])[:3].sum() / 3.0 + assert not torch.allclose(out["nll_per_example"][0], native_nll, atol=1e-4) From ace5a79ad259a1ebdb0f9214783f49220c0fe676 Mon Sep 17 00:00:00 2001 From: daylight-00 Date: Fri, 5 Jun 2026 17:16:05 +0900 Subject: [PATCH 3/5] refactor(mpnn): rename confidence input to `logits`; add interface test Address PR review feedback: - Rename the `log_probs` compute parameter and its `kwargs_to_compute_args` key to `logits` in SampledNLL/SampledInterfaceNLL. The argument carries raw logits, so the name now matches the contents. - Add a unit test for SampledInterfaceNLL that injects a known interface mask and asserts the interface NLL is restricted to interface positions and scores the sampled sequence (not native S) under log_softmax(logits). --- models/mpnn/src/mpnn/metrics/nll.py | 24 +++++++------ models/mpnn/tests/test_metrics.py | 54 ++++++++++++++++++++++++++++- 2 files changed, 67 insertions(+), 11 deletions(-) diff --git a/models/mpnn/src/mpnn/metrics/nll.py b/models/mpnn/src/mpnn/metrics/nll.py index 0a3a0d29..ffd12d63 100644 --- a/models/mpnn/src/mpnn/metrics/nll.py +++ b/models/mpnn/src/mpnn/metrics/nll.py @@ -381,17 +381,19 @@ class SampledNLL(NLL): @property def kwargs_to_compute_args(self): + # Score the *sampled* sequence using the raw logits (the parent's + # 'log_probs' kwarg is dropped in favour of 'logits'). mapping = super().kwargs_to_compute_args + del mapping["log_probs"] mapping["S"] = ("network_output", "decoder_features", "S_sampled") - mapping["log_probs"] = ("network_output", "decoder_features", "logits") + mapping["logits"] = ("network_output", "decoder_features", "logits") return mapping - def compute(self, log_probs, S, mask_for_loss, **kwargs): + def compute(self, logits, S, mask_for_loss, **kwargs): """Convert raw logits to log-probabilities and delegate to ``NLL``. Args: - log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits - (mapped from ``decoder_features.logits``), not log-probs. + logits (torch.Tensor): [B, L, vocab_size] - raw model logits. S (torch.Tensor): [B, L] - the sampled (designed) sequence. mask_for_loss (torch.Tensor): [B, L] - mask for loss. **kwargs: Additional arguments forwarded to ``NLL.compute``. @@ -401,7 +403,7 @@ def compute(self, log_probs, S, mask_for_loss, **kwargs): """ # Raw logits -> true log-probs at temperature 1.0 (no bias), matching # LigandMPNN's confidence definition. - log_probs = torch.log_softmax(log_probs, dim=-1) + log_probs = torch.log_softmax(logits, dim=-1) return super().compute( log_probs=log_probs, S=S, mask_for_loss=mask_for_loss, **kwargs ) @@ -417,17 +419,19 @@ class SampledInterfaceNLL(InterfaceNLL): @property def kwargs_to_compute_args(self): + # Score the *sampled* sequence using the raw logits (the parent's + # 'log_probs' kwarg is dropped in favour of 'logits'). mapping = super().kwargs_to_compute_args + del mapping["log_probs"] mapping["S"] = ("network_output", "decoder_features", "S_sampled") - mapping["log_probs"] = ("network_output", "decoder_features", "logits") + mapping["logits"] = ("network_output", "decoder_features", "logits") return mapping - def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): + def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): """Convert raw logits to log-probabilities and delegate to ``InterfaceNLL``. Args: - log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits - (mapped from ``decoder_features.logits``), not log-probs. + logits (torch.Tensor): [B, L, vocab_size] - raw model logits. S (torch.Tensor): [B, L] - the sampled (designed) sequence. mask_for_loss (torch.Tensor): [B, L] - mask for loss. atom_array: Atom array(s) used to derive the polymer-ligand @@ -439,7 +443,7 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): """ # Raw logits -> true log-probs at temperature 1.0 (no bias), matching # LigandMPNN's confidence definition. - log_probs = torch.log_softmax(log_probs, dim=-1) + log_probs = torch.log_softmax(logits, dim=-1) return super().compute( log_probs=log_probs, S=S, diff --git a/models/mpnn/tests/test_metrics.py b/models/mpnn/tests/test_metrics.py index 67a43a05..83508aaa 100644 --- a/models/mpnn/tests/test_metrics.py +++ b/models/mpnn/tests/test_metrics.py @@ -99,11 +99,13 @@ def test_sampled_confidence_metrics_read_sampled_logits(self): for metric in (SampledNLL(), SampledInterfaceNLL()): mapping = metric.kwargs_to_compute_args assert mapping["S"] == ("network_output", "decoder_features", "S_sampled") - assert mapping["log_probs"] == ( + assert mapping["logits"] == ( "network_output", "decoder_features", "logits", ) + # The parent's native-sequence log_probs input must not leak through. + assert "log_probs" not in mapping # The interface variant additionally needs the atom array for masking. assert SampledInterfaceNLL().kwargs_to_compute_args["atom_array"] == ( "network_input", @@ -146,3 +148,53 @@ def test_sampled_nll_equals_log_softmax_of_logits_on_sampled_sequence(self): # Must NOT coincide with the NLL of the (different) native sequence. native_nll = (-log_probs[0, torch.arange(length), native[0]])[:3].sum() / 3.0 assert not torch.allclose(out["nll_per_example"][0], native_nll, atol=1e-4) + + def test_sampled_interface_nll_restricts_to_interface_and_uses_sampled( + self, monkeypatch + ): + """SampledInterfaceNLL must restrict the NLL to the interface mask and + score the sampled sequence under log_softmax(logits). The interface-mask + derivation itself is inherited from InterfaceNLL (covered by the + integration test); here the structure-derived mask is injected so the + new sampled+logits+masking contract can be checked numerically.""" + batch, length, vocab = 1, 4, 21 + torch.manual_seed(1) + logits = torch.randn(batch, length, vocab) + sampled = torch.tensor([[2, 7, 3, 9]]) + native = torch.tensor([[0, 0, 0, 0]]) # deliberately != sampled + mask_for_loss = torch.ones(batch, length, dtype=torch.bool) + # Pretend only positions 1 and 2 are at the polymer-ligand interface. + interface_mask = torch.tensor([[False, True, True, False]]) + + metric = SampledInterfaceNLL( + return_per_example_metrics=True, return_per_residue_metrics=True + ) + # Bypass the structure-derived interface mask with a known one. + monkeypatch.setattr( + metric, "get_per_residue_mask", lambda mask_for_loss, **kw: interface_mask + ) + + network_output = { + "decoder_features": {"logits": logits, "S_sampled": sampled}, + "input_features": {"mask_for_loss": mask_for_loss, "S": native}, + } + network_input = {"atom_array": None} # unused; mask is injected + + out = metric.compute_from_kwargs( + network_input=network_input, network_output=network_output + ) + + log_probs = torch.log_softmax(logits, dim=-1) + per_res = -log_probs[0, torch.arange(length), sampled[0]] + # Only the two interface positions contribute. + expected = (per_res[1] + per_res[2]) / 2.0 + assert torch.allclose(out["interface_nll_per_example"][0], expected, atol=1e-6) + # Non-interface positions are excluded from the per-residue NLL. + assert out["interface_nll_per_residue"][0, 0] == 0.0 + assert out["interface_nll_per_residue"][0, 3] == 0.0 + # Must NOT coincide with the native sequence over the same mask. + native_res = -log_probs[0, torch.arange(length), native[0]] + native_expected = (native_res[1] + native_res[2]) / 2.0 + assert not torch.allclose( + out["interface_nll_per_example"][0], native_expected, atol=1e-4 + ) From 59fed1803113b3eaed57a49b50446afc2f201d7b Mon Sep 17 00:00:00 2001 From: daylight-00 Date: Sun, 7 Jun 2026 14:02:54 +0900 Subject: [PATCH 4/5] refactor(mpnn): wrap confidence metrics and store per-residue confidence in a dedicated field Address review feedback on #306: - Add MPNNConfidence / MPNNInterfaceConfidence wrappers around SampledNLL / SampledInterfaceNLL that expose confidence = exp(-NLL) directly, moving the NLL-to-confidence transform out of MPNNInferenceEngine into the metric layer (registered as overall_confidence / ligand_confidence). - Store per-residue confidence in a dedicated `mpnn_confidence` AtomArray annotation (serialized to CIF as _atom_site.mpnn_confidence via extra_fields) instead of overwriting b_factor, preserving the input structure's B-factors. - Scope SampledNLL/SampledInterfaceNLL docstrings to NLL only (confidence is documented on the wrappers); copy the parent kwargs mapping defensively (dict + pop) and mask per-residue confidence with torch.where. - Add unit tests for the confidence wrappers. --- .../mpnn/src/mpnn/inference_engines/mpnn.py | 46 +++---- models/mpnn/src/mpnn/metrics/nll.py | 115 +++++++++++++++--- models/mpnn/src/mpnn/utils/inference.py | 1 + models/mpnn/tests/test_metrics.py | 78 +++++++++++- 4 files changed, 200 insertions(+), 40 deletions(-) diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index c2e8d90c..965bad7e 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -12,7 +12,7 @@ from atomworks.ml.utils.token import get_token_starts, spread_token_wise from biotite.structure import AtomArray from mpnn.collate.feature_collator import FeatureCollator -from mpnn.metrics.nll import SampledInterfaceNLL, SampledNLL +from mpnn.metrics.nll import MPNNConfidence, MPNNInterfaceConfidence from mpnn.metrics.sequence_recovery import ( InterfaceSequenceRecovery, SequenceRecovery, @@ -194,7 +194,7 @@ def _build_metrics_manager(self) -> MetricManager: # Construct metrics dict. metrics: dict[str, Any] = { "sequence_recovery": SequenceRecovery(return_per_example_metrics=True), - "overall_confidence": SampledNLL( + "overall_confidence": MPNNConfidence( return_per_example_metrics=True, return_per_residue_metrics=True, ), @@ -203,7 +203,7 @@ def _build_metrics_manager(self) -> MetricManager: metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery( return_per_example_metrics=True ) - metrics["ligand_confidence"] = SampledInterfaceNLL( + metrics["ligand_confidence"] = MPNNInterfaceConfidence( return_per_example_metrics=True ) @@ -395,25 +395,26 @@ def _run_batch( else: interface_sequence_recovery_per_design = None - # Per-design overall confidence = exp(-NLL) of the *sampled* sequence - # under the model's un-temperatured logits, mirroring LigandMPNN's - # overall/ligand confidence. Per-residue confidence (exp(-NLL) per - # position, zeroed at non-designed positions) is written into the - # output structure below. - overall_confidence_per_design = np.exp( - -metrics_output["overall_confidence.nll_per_example"].detach().cpu().numpy() - ) - nll_per_residue = ( - metrics_output["overall_confidence.nll_per_residue"].detach().cpu().numpy() + # Per-design confidence (= exp(-NLL) of the *sampled* sequence), computed + # by the MPNNConfidence metrics so the NLL-to-confidence transform lives + # in the metric layer. Per-residue confidence is zeroed at non-designed + # positions and written into the output structure below. + overall_confidence_per_design = ( + metrics_output["overall_confidence.confidence_per_example"] + .detach() + .cpu() + .numpy() ) - confidence_per_residue_mask = ( - metrics_output["overall_confidence.per_residue_mask"].detach().cpu().numpy() + confidence_per_residue = ( + metrics_output["overall_confidence.confidence_per_residue"] + .detach() + .cpu() + .numpy() ) - confidence_per_residue = np.exp(-nll_per_residue) * confidence_per_residue_mask if self.model_type == "ligand_mpnn": - ligand_confidence_per_design = np.exp( - -metrics_output["ligand_confidence.interface_nll_per_example"] + ligand_confidence_per_design = ( + metrics_output["ligand_confidence.interface_confidence_per_example"] .detach() .cpu() .numpy() @@ -471,17 +472,16 @@ def _run_batch( design_atom_array.set_annotation("res_name", full_resnames) # Spread per-residue confidence (token-level) to atom level over the - # non-atomized subset and overwrite the b-factor with it, so it is - # written to the standard CIF '_atom_site.B_iso_or_equiv' column and - # picked up by viewers (analogous to LigandMPNN's per-residue - # confidence b-factors). Non-designed positions are set to 0. + # non-atomized subset and store it in a dedicated 'mpnn_confidence' + # annotation (so the original b-factor is preserved). Non-designed + # positions are set to 0. design_confidence_atom = spread_token_wise( design_non_atomized_array, confidence_per_residue[design_idx], ) full_confidence = np.zeros(len(design_atom_array), dtype=np.float32) full_confidence[~design_atom_array.atomize] = design_confidence_atom - design_atom_array.set_annotation("b_factor", full_confidence) + design_atom_array.set_annotation("mpnn_confidence", full_confidence) # We need to remove any non-atomized residue atoms that no # longer belong (i.e. old side chain atoms). We want to keep any diff --git a/models/mpnn/src/mpnn/metrics/nll.py b/models/mpnn/src/mpnn/metrics/nll.py index ffd12d63..7dd82c20 100644 --- a/models/mpnn/src/mpnn/metrics/nll.py +++ b/models/mpnn/src/mpnn/metrics/nll.py @@ -370,21 +370,21 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs): class SampledNLL(NLL): - """NLL / confidence of the *sampled* (designed) sequence. + """NLL of the *sampled* (designed) sequence. - Unlike :class:`NLL` (which scores the native sequence as a training - metric), this scores the sampled sequence under the model's - un-temperatured, un-biased log-probabilities (``log_softmax`` of the raw - logits). This matches the "overall_confidence" reported by the original - LigandMPNN, where ``confidence = exp(-mean_nll)``. + Unlike :class:`NLL`, which scores the native sequence as a training metric, + this scores the sampled sequence under the model's un-temperatured, + un-biased log-probabilities (``log_softmax`` of the raw logits). See + :class:`MPNNConfidence` for the derived ``exp(-NLL)`` confidence. """ @property def kwargs_to_compute_args(self): # Score the *sampled* sequence using the raw logits (the parent's - # 'log_probs' kwarg is dropped in favour of 'logits'). - mapping = super().kwargs_to_compute_args - del mapping["log_probs"] + # 'log_probs' kwarg is replaced with 'logits'). Copy first so the + # parent's mapping is never mutated. + mapping = dict(super().kwargs_to_compute_args) + mapping.pop("log_probs", None) mapping["S"] = ("network_output", "decoder_features", "S_sampled") mapping["logits"] = ("network_output", "decoder_features", "logits") return mapping @@ -410,19 +410,20 @@ def compute(self, logits, S, mask_for_loss, **kwargs): class SampledInterfaceNLL(InterfaceNLL): - """Interface NLL / confidence of the *sampled* (designed) sequence. + """Interface NLL of the *sampled* (designed) sequence. - Mirrors LigandMPNN's "ligand_confidence" (polymer-ligand interface - residues only), computed on the un-temperatured logits of the sampled - sequence. + Like :class:`SampledNLL` but restricted to polymer-ligand interface + residues. See :class:`MPNNInterfaceConfidence` for the derived + ``exp(-NLL)`` confidence. """ @property def kwargs_to_compute_args(self): # Score the *sampled* sequence using the raw logits (the parent's - # 'log_probs' kwarg is dropped in favour of 'logits'). - mapping = super().kwargs_to_compute_args - del mapping["log_probs"] + # 'log_probs' kwarg is replaced with 'logits'). Copy first so the + # parent's mapping is never mutated. + mapping = dict(super().kwargs_to_compute_args) + mapping.pop("log_probs", None) mapping["S"] = ("network_output", "decoder_features", "S_sampled") mapping["logits"] = ("network_output", "decoder_features", "logits") return mapping @@ -451,3 +452,85 @@ def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): atom_array=atom_array, **kwargs, ) + + +class MPNNConfidence(SampledNLL): + """Per-design confidence of the sampled (designed) sequence. + + Thin wrapper around :class:`SampledNLL` that additionally exposes the + LigandMPNN-style confidence (``confidence = exp(-NLL)``) so callers do not + need to apply the NLL-to-confidence transform themselves. Confidence lies + in ``(0, 1]``; higher means the model is more confident in the sequence. + """ + + def compute(self, logits, S, mask_for_loss, **kwargs): + """Compute the sampled-sequence NLL and derived confidence. + + Args: + logits (torch.Tensor): [B, L, vocab_size] - raw model logits. + S (torch.Tensor): [B, L] - the sampled (designed) sequence. + mask_for_loss (torch.Tensor): [B, L] - mask for loss. + **kwargs: Additional arguments forwarded to ``SampledNLL.compute``. + + Returns: + dict: The ``SampledNLL`` metrics plus, when the corresponding NLL + outputs are present, ``confidence_per_example`` (= exp(-NLL)) and + ``confidence_per_residue`` (= exp(-NLL) per position, 0 at + non-designed positions). + """ + metrics = super().compute( + logits=logits, S=S, mask_for_loss=mask_for_loss, **kwargs + ) + if "nll_per_example" in metrics: + metrics["confidence_per_example"] = torch.exp(-metrics["nll_per_example"]) + if "nll_per_residue" in metrics: + confidence = torch.exp(-metrics["nll_per_residue"]) + mask = metrics["per_residue_mask"].bool() + metrics["confidence_per_residue"] = torch.where( + mask, confidence, torch.zeros_like(confidence) + ) + return metrics + + +class MPNNInterfaceConfidence(SampledInterfaceNLL): + """Per-design confidence over polymer-ligand interface residues. + + Interface counterpart of :class:`MPNNConfidence`; wraps + :class:`SampledInterfaceNLL` and exposes ``exp(-NLL)`` confidence computed + over the interface residues only (``interface_`` prefixed keys). + """ + + def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): + """Compute the interface NLL and derived interface confidence. + + Args: + logits (torch.Tensor): [B, L, vocab_size] - raw model logits. + S (torch.Tensor): [B, L] - the sampled (designed) sequence. + mask_for_loss (torch.Tensor): [B, L] - mask for loss. + atom_array: Atom array(s) used to derive the interface mask. + **kwargs: Additional arguments forwarded to + ``SampledInterfaceNLL.compute``. + + Returns: + dict: The ``SampledInterfaceNLL`` metrics plus, when present, + ``interface_confidence_per_example`` and + ``interface_confidence_per_residue``. + """ + metrics = super().compute( + logits=logits, + S=S, + mask_for_loss=mask_for_loss, + atom_array=atom_array, + **kwargs, + ) + if "interface_nll_per_example" in metrics: + metrics["interface_confidence_per_example"] = torch.exp( + -metrics["interface_nll_per_example"] + ) + if "interface_nll_per_residue" in metrics: + confidence = torch.exp(-metrics["interface_nll_per_residue"]) + mask = metrics["interface_per_residue_mask"].bool() + metrics["interface_confidence_per_residue"] = torch.where( + mask, confidence, torch.zeros_like(confidence) + ) + return metrics diff --git a/models/mpnn/src/mpnn/utils/inference.py b/models/mpnn/src/mpnn/utils/inference.py index 52908343..65aa95ab 100644 --- a/models/mpnn/src/mpnn/utils/inference.py +++ b/models/mpnn/src/mpnn/utils/inference.py @@ -2294,6 +2294,7 @@ def write_structure( "mpnn_temperature", "mpnn_symmetry_equivalence_group", "mpnn_symmetry_weight", + "mpnn_confidence", ] # Limit to fields actually present in the atom array. diff --git a/models/mpnn/tests/test_metrics.py b/models/mpnn/tests/test_metrics.py index 83508aaa..f636bc1c 100644 --- a/models/mpnn/tests/test_metrics.py +++ b/models/mpnn/tests/test_metrics.py @@ -8,7 +8,14 @@ import pytest import torch from atomworks.ml.utils.testing import cached_parse -from mpnn.metrics.nll import NLL, InterfaceNLL, SampledInterfaceNLL, SampledNLL +from mpnn.metrics.nll import ( + NLL, + InterfaceNLL, + MPNNConfidence, + MPNNInterfaceConfidence, + SampledInterfaceNLL, + SampledNLL, +) from mpnn.metrics.sequence_recovery import InterfaceSequenceRecovery, SequenceRecovery from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline from test_utils import ( @@ -198,3 +205,72 @@ def test_sampled_interface_nll_restricts_to_interface_and_uses_sampled( assert not torch.allclose( out["interface_nll_per_example"][0], native_expected, atol=1e-4 ) + + def test_mpnn_confidence_exposes_exp_neg_nll(self): + """MPNNConfidence must expose confidence = exp(-NLL) directly, so the + engine no longer needs to apply the transform itself.""" + batch, length, vocab = 1, 4, 21 + torch.manual_seed(2) + logits = torch.randn(batch, length, vocab) + sampled = torch.tensor([[3, 1, 8, 4]]) + mask = torch.tensor([[True, True, False, True]]) + + network_output = { + "decoder_features": {"logits": logits, "S_sampled": sampled}, + "input_features": { + "mask_for_loss": mask, + "S": torch.zeros_like(sampled), + }, + } + metric = MPNNConfidence( + return_per_example_metrics=True, return_per_residue_metrics=True + ) + out = metric.compute_from_kwargs(network_output=network_output) + + # confidence_per_example is exactly exp(-nll_per_example). + assert torch.allclose( + out["confidence_per_example"], + torch.exp(-out["nll_per_example"]), + atol=1e-6, + ) + # Per-residue confidence is in (0, 1] on designed positions and 0 on the + # masked-out position. + conf_res = out["confidence_per_residue"][0] + assert conf_res[2] == 0.0 + designed = torch.tensor([0, 1, 3]) + assert torch.all(conf_res[designed] > 0.0) + assert torch.all(conf_res[designed] <= 1.0) + + def test_mpnn_interface_confidence_exposes_exp_neg_nll(self, monkeypatch): + """MPNNInterfaceConfidence must expose interface confidence = exp(-NLL) + over the interface residues only.""" + batch, length, vocab = 1, 4, 21 + torch.manual_seed(3) + logits = torch.randn(batch, length, vocab) + sampled = torch.tensor([[2, 7, 3, 9]]) + mask_for_loss = torch.ones(batch, length, dtype=torch.bool) + interface_mask = torch.tensor([[False, True, True, False]]) + + metric = MPNNInterfaceConfidence( + return_per_example_metrics=True, return_per_residue_metrics=True + ) + monkeypatch.setattr( + metric, "get_per_residue_mask", lambda mask_for_loss, **kw: interface_mask + ) + + network_output = { + "decoder_features": {"logits": logits, "S_sampled": sampled}, + "input_features": { + "mask_for_loss": mask_for_loss, + "S": torch.zeros_like(sampled), + }, + } + out = metric.compute_from_kwargs( + network_input={"atom_array": None}, network_output=network_output + ) + + assert torch.allclose( + out["interface_confidence_per_example"], + torch.exp(-out["interface_nll_per_example"]), + atol=1e-6, + ) From 3b51b82db05ebd2d1350302f7c95b429d47449b6 Mon Sep 17 00:00:00 2001 From: daylight-00 Date: Thu, 18 Jun 2026 18:57:52 +0900 Subject: [PATCH 5/5] =?UTF-8?q?refactor(mpnn):=20apply=20review=20feedback?= =?UTF-8?q?=20=E2=80=94=20naming,=20explicit=20kwargs,=20no-interface=20ed?= =?UTF-8?q?ge=20case?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address @AndrewKubaney's second review on #306: - Build kwargs_to_compute_args explicitly in SampledNLL / SampledLigandInterfaceNLL instead of copying the parent mapping and popping "log_probs" (I'd used copy+pop to avoid mutating the parent, but the explicit literal reads more clearly). - Rename MPNNInterfaceConfidence -> MPNNLigandInterfaceConfidence and SampledInterfaceNLL -> SampledLigandInterfaceNLL so the polymer-ligand interface scope is unambiguous (vs e.g. protein-protein interfaces). - Rename reported fields overall_confidence -> confidence and ligand_confidence -> ligand_interface_confidence, mirroring the existing sequence_recovery / ligand_interface_sequence_recovery keys and Foundry's un-prefixed output-key convention (the mpnn_confidence atom-array field is unchanged). - Handle the no-interface edge case: a ligand_mpnn run with no polymer-ligand interface residues yields an undefined (NaN) interface confidence; the engine now emits None (omitted) rather than NaN. Add a unit test. --- .../mpnn/src/mpnn/inference_engines/mpnn.py | 39 ++++++------- models/mpnn/src/mpnn/metrics/nll.py | 49 ++++++++-------- models/mpnn/src/mpnn/utils/inference.py | 20 ++++--- models/mpnn/tests/test_metrics.py | 56 ++++++++++++++++--- 4 files changed, 104 insertions(+), 60 deletions(-) diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index 965bad7e..475d25a2 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -12,7 +12,7 @@ from atomworks.ml.utils.token import get_token_starts, spread_token_wise from biotite.structure import AtomArray from mpnn.collate.feature_collator import FeatureCollator -from mpnn.metrics.nll import MPNNConfidence, MPNNInterfaceConfidence +from mpnn.metrics.nll import MPNNConfidence, MPNNLigandInterfaceConfidence from mpnn.metrics.sequence_recovery import ( InterfaceSequenceRecovery, SequenceRecovery, @@ -194,7 +194,7 @@ def _build_metrics_manager(self) -> MetricManager: # Construct metrics dict. metrics: dict[str, Any] = { "sequence_recovery": SequenceRecovery(return_per_example_metrics=True), - "overall_confidence": MPNNConfidence( + "confidence": MPNNConfidence( return_per_example_metrics=True, return_per_residue_metrics=True, ), @@ -203,7 +203,7 @@ def _build_metrics_manager(self) -> MetricManager: metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery( return_per_example_metrics=True ) - metrics["ligand_confidence"] = MPNNInterfaceConfidence( + metrics["ligand_interface_confidence"] = MPNNLigandInterfaceConfidence( return_per_example_metrics=True ) @@ -399,28 +399,28 @@ def _run_batch( # by the MPNNConfidence metrics so the NLL-to-confidence transform lives # in the metric layer. Per-residue confidence is zeroed at non-designed # positions and written into the output structure below. - overall_confidence_per_design = ( - metrics_output["overall_confidence.confidence_per_example"] - .detach() - .cpu() - .numpy() + confidence_per_design = ( + metrics_output["confidence.confidence_per_example"].detach().cpu().numpy() ) confidence_per_residue = ( - metrics_output["overall_confidence.confidence_per_residue"] - .detach() - .cpu() - .numpy() + metrics_output["confidence.confidence_per_residue"].detach().cpu().numpy() ) + # Ligand-interface confidence is undefined (NaN) when there are no + # interface residues (e.g. ligand_mpnn run on a ligand-free input); such + # values are converted to None per design below so they are omitted from + # the outputs rather than written as NaN. if self.model_type == "ligand_mpnn": - ligand_confidence_per_design = ( - metrics_output["ligand_confidence.interface_confidence_per_example"] + ligand_interface_confidence_per_design = ( + metrics_output[ + "ligand_interface_confidence.interface_confidence_per_example" + ] .detach() .cpu() .numpy() ) else: - ligand_confidence_per_design = None + ligand_interface_confidence_per_design = None # Grab the index to token mapping from the model. idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token @@ -525,10 +525,11 @@ def _run_batch( "batch_idx": batch_idx, "design_idx": design_idx, "designed_sequence": one_letter_seq, - "overall_confidence": float(overall_confidence_per_design[design_idx]), - "ligand_confidence": ( - float(ligand_confidence_per_design[design_idx]) - if ligand_confidence_per_design is not None + "confidence": float(confidence_per_design[design_idx]), + "ligand_interface_confidence": ( + float(ligand_interface_confidence_per_design[design_idx]) + if ligand_interface_confidence_per_design is not None + and not np.isnan(ligand_interface_confidence_per_design[design_idx]) else None ), "sequence_recovery": sequence_recovery, diff --git a/models/mpnn/src/mpnn/metrics/nll.py b/models/mpnn/src/mpnn/metrics/nll.py index 7dd82c20..f74e9e2c 100644 --- a/models/mpnn/src/mpnn/metrics/nll.py +++ b/models/mpnn/src/mpnn/metrics/nll.py @@ -380,14 +380,13 @@ class SampledNLL(NLL): @property def kwargs_to_compute_args(self): - # Score the *sampled* sequence using the raw logits (the parent's - # 'log_probs' kwarg is replaced with 'logits'). Copy first so the - # parent's mapping is never mutated. - mapping = dict(super().kwargs_to_compute_args) - mapping.pop("log_probs", None) - mapping["S"] = ("network_output", "decoder_features", "S_sampled") - mapping["logits"] = ("network_output", "decoder_features", "logits") - return mapping + # Build the mapping explicitly: score the *sampled* sequence + # (`S_sampled`) using the raw `logits`. + return { + "logits": ("network_output", "decoder_features", "logits"), + "S": ("network_output", "decoder_features", "S_sampled"), + "mask_for_loss": ("network_output", "input_features", "mask_for_loss"), + } def compute(self, logits, S, mask_for_loss, **kwargs): """Convert raw logits to log-probabilities and delegate to ``NLL``. @@ -409,24 +408,25 @@ def compute(self, logits, S, mask_for_loss, **kwargs): ) -class SampledInterfaceNLL(InterfaceNLL): +class SampledLigandInterfaceNLL(InterfaceNLL): """Interface NLL of the *sampled* (designed) sequence. Like :class:`SampledNLL` but restricted to polymer-ligand interface - residues. See :class:`MPNNInterfaceConfidence` for the derived + residues. See :class:`MPNNLigandInterfaceConfidence` for the derived ``exp(-NLL)`` confidence. """ @property def kwargs_to_compute_args(self): - # Score the *sampled* sequence using the raw logits (the parent's - # 'log_probs' kwarg is replaced with 'logits'). Copy first so the - # parent's mapping is never mutated. - mapping = dict(super().kwargs_to_compute_args) - mapping.pop("log_probs", None) - mapping["S"] = ("network_output", "decoder_features", "S_sampled") - mapping["logits"] = ("network_output", "decoder_features", "logits") - return mapping + # Build the mapping explicitly: score the *sampled* sequence + # (`S_sampled`) using the raw `logits`; `atom_array` derives the + # polymer-ligand interface mask. + return { + "logits": ("network_output", "decoder_features", "logits"), + "S": ("network_output", "decoder_features", "S_sampled"), + "mask_for_loss": ("network_output", "input_features", "mask_for_loss"), + "atom_array": ("network_input", "atom_array"), + } def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): """Convert raw logits to log-probabilities and delegate to ``InterfaceNLL``. @@ -492,12 +492,13 @@ def compute(self, logits, S, mask_for_loss, **kwargs): return metrics -class MPNNInterfaceConfidence(SampledInterfaceNLL): +class MPNNLigandInterfaceConfidence(SampledLigandInterfaceNLL): """Per-design confidence over polymer-ligand interface residues. - Interface counterpart of :class:`MPNNConfidence`; wraps - :class:`SampledInterfaceNLL` and exposes ``exp(-NLL)`` confidence computed - over the interface residues only (``interface_`` prefixed keys). + Ligand-interface counterpart of :class:`MPNNConfidence`; wraps + :class:`SampledLigandInterfaceNLL` and exposes ``exp(-NLL)`` confidence + computed over the polymer-ligand interface residues only (``interface_`` + prefixed keys). """ def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): @@ -509,10 +510,10 @@ def compute(self, logits, S, mask_for_loss, atom_array, **kwargs): mask_for_loss (torch.Tensor): [B, L] - mask for loss. atom_array: Atom array(s) used to derive the interface mask. **kwargs: Additional arguments forwarded to - ``SampledInterfaceNLL.compute``. + ``SampledLigandInterfaceNLL.compute``. Returns: - dict: The ``SampledInterfaceNLL`` metrics plus, when present, + dict: The ``SampledLigandInterfaceNLL`` metrics plus, when present, ``interface_confidence_per_example`` and ``interface_confidence_per_residue``. """ diff --git a/models/mpnn/src/mpnn/utils/inference.py b/models/mpnn/src/mpnn/utils/inference.py index 65aa95ab..8d5661fd 100644 --- a/models/mpnn/src/mpnn/utils/inference.py +++ b/models/mpnn/src/mpnn/utils/inference.py @@ -2210,8 +2210,8 @@ class MPNNInferenceOutput: - 'batch_idx' - 'design_idx' - 'designed_sequence' - - 'overall_confidence' - - 'ligand_confidence' + - 'confidence' + - 'ligand_interface_confidence' - 'sequence_recovery' - 'ligand_interface_sequence_recovery' - 'model_type' @@ -2372,12 +2372,16 @@ def write_fasta( # Construct the confidence fields for the header (exp(-NLL) of the # sampled sequence; mirrors LigandMPNN's overall/ligand confidence). - overall_confidence = self.output_dict.get("overall_confidence") - ligand_confidence = self.output_dict.get("ligand_confidence") - if overall_confidence is not None: - header_fields.append(f"overall_confidence={float(overall_confidence):.4f}") - if ligand_confidence is not None: - header_fields.append(f"ligand_confidence={float(ligand_confidence):.4f}") + confidence = self.output_dict.get("confidence") + ligand_interface_confidence = self.output_dict.get( + "ligand_interface_confidence" + ) + if confidence is not None: + header_fields.append(f"confidence={float(confidence):.4f}") + if ligand_interface_confidence is not None: + header_fields.append( + f"ligand_interface_confidence={float(ligand_interface_confidence):.4f}" + ) # Construct the recovery fields for the header. if sequence_recovery is not None: diff --git a/models/mpnn/tests/test_metrics.py b/models/mpnn/tests/test_metrics.py index f636bc1c..dc271f26 100644 --- a/models/mpnn/tests/test_metrics.py +++ b/models/mpnn/tests/test_metrics.py @@ -12,8 +12,8 @@ NLL, InterfaceNLL, MPNNConfidence, - MPNNInterfaceConfidence, - SampledInterfaceNLL, + MPNNLigandInterfaceConfidence, + SampledLigandInterfaceNLL, SampledNLL, ) from mpnn.metrics.sequence_recovery import InterfaceSequenceRecovery, SequenceRecovery @@ -100,10 +100,10 @@ def test_metrics_comprehensive(self, pdb_id, model_type, is_inference): ) def test_sampled_confidence_metrics_read_sampled_logits(self): - """SampledNLL/SampledInterfaceNLL must score the *sampled* sequence + """SampledNLL/SampledLigandInterfaceNLL must score the *sampled* sequence using the raw model logits (not the native sequence or the temperature-scaled log_probs).""" - for metric in (SampledNLL(), SampledInterfaceNLL()): + for metric in (SampledNLL(), SampledLigandInterfaceNLL()): mapping = metric.kwargs_to_compute_args assert mapping["S"] == ("network_output", "decoder_features", "S_sampled") assert mapping["logits"] == ( @@ -114,7 +114,7 @@ def test_sampled_confidence_metrics_read_sampled_logits(self): # The parent's native-sequence log_probs input must not leak through. assert "log_probs" not in mapping # The interface variant additionally needs the atom array for masking. - assert SampledInterfaceNLL().kwargs_to_compute_args["atom_array"] == ( + assert SampledLigandInterfaceNLL().kwargs_to_compute_args["atom_array"] == ( "network_input", "atom_array", ) @@ -159,7 +159,7 @@ def test_sampled_nll_equals_log_softmax_of_logits_on_sampled_sequence(self): def test_sampled_interface_nll_restricts_to_interface_and_uses_sampled( self, monkeypatch ): - """SampledInterfaceNLL must restrict the NLL to the interface mask and + """SampledLigandInterfaceNLL must restrict the NLL to the interface mask and score the sampled sequence under log_softmax(logits). The interface-mask derivation itself is inherited from InterfaceNLL (covered by the integration test); here the structure-derived mask is injected so the @@ -173,7 +173,7 @@ def test_sampled_interface_nll_restricts_to_interface_and_uses_sampled( # Pretend only positions 1 and 2 are at the polymer-ligand interface. interface_mask = torch.tensor([[False, True, True, False]]) - metric = SampledInterfaceNLL( + metric = SampledLigandInterfaceNLL( return_per_example_metrics=True, return_per_residue_metrics=True ) # Bypass the structure-derived interface mask with a known one. @@ -242,7 +242,7 @@ def test_mpnn_confidence_exposes_exp_neg_nll(self): assert torch.all(conf_res[designed] <= 1.0) def test_mpnn_interface_confidence_exposes_exp_neg_nll(self, monkeypatch): - """MPNNInterfaceConfidence must expose interface confidence = exp(-NLL) + """MPNNLigandInterfaceConfidence must expose interface confidence = exp(-NLL) over the interface residues only.""" batch, length, vocab = 1, 4, 21 torch.manual_seed(3) @@ -251,7 +251,7 @@ def test_mpnn_interface_confidence_exposes_exp_neg_nll(self, monkeypatch): mask_for_loss = torch.ones(batch, length, dtype=torch.bool) interface_mask = torch.tensor([[False, True, True, False]]) - metric = MPNNInterfaceConfidence( + metric = MPNNLigandInterfaceConfidence( return_per_example_metrics=True, return_per_residue_metrics=True ) monkeypatch.setattr( @@ -274,3 +274,41 @@ def test_mpnn_interface_confidence_exposes_exp_neg_nll(self, monkeypatch): torch.exp(-out["interface_nll_per_example"]), atol=1e-6, ) + + def test_mpnn_ligand_interface_confidence_no_interface_residues(self, monkeypatch): + """With no polymer-ligand interface residues (e.g. ligand_mpnn run on a + ligand-free input), the interface confidence must be cleanly undefined + (NaN, flagged by valid_examples_mask) rather than crashing. The + inference engine converts this NaN to None so it is omitted from the + outputs.""" + batch, length, vocab = 1, 4, 21 + torch.manual_seed(4) + logits = torch.randn(batch, length, vocab) + sampled = torch.tensor([[2, 7, 3, 9]]) + mask_for_loss = torch.ones(batch, length, dtype=torch.bool) + # No residues at the interface (empty mask). + empty_interface = torch.zeros(batch, length, dtype=torch.bool) + + metric = MPNNLigandInterfaceConfidence( + return_per_example_metrics=True, return_per_residue_metrics=True + ) + monkeypatch.setattr( + metric, "get_per_residue_mask", lambda mask_for_loss, **kw: empty_interface + ) + + network_output = { + "decoder_features": {"logits": logits, "S_sampled": sampled}, + "input_features": { + "mask_for_loss": mask_for_loss, + "S": torch.zeros_like(sampled), + }, + } + # Must not raise. + out = metric.compute_from_kwargs( + network_input={"atom_array": None}, network_output=network_output + ) + + assert bool(out["interface_valid_examples_mask"][0]) is False + assert torch.isnan(out["interface_confidence_per_example"][0]) + # Per-residue confidence is all zero (no interface positions). + assert torch.all(out["interface_confidence_per_residue"] == 0.0)