Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions models/mpnn/src/mpnn/inference_engines/mpnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
from biotite.structure import AtomArray
from mpnn.collate.feature_collator import FeatureCollator
from mpnn.metrics.nll import SampledInterfaceNLL, SampledNLL
from mpnn.metrics.sequence_recovery import (
InterfaceSequenceRecovery,
SequenceRecovery,
Expand Down Expand Up @@ -193,11 +194,18 @@ def _build_metrics_manager(self) -> MetricManager:
# Construct metrics dict.
metrics: dict[str, Any] = {
"sequence_recovery": SequenceRecovery(return_per_example_metrics=True),
"overall_confidence": SampledNLL(
return_per_example_metrics=True,
return_per_residue_metrics=True,
),
}
if self.model_type == "ligand_mpnn":
metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery(
return_per_example_metrics=True
)
metrics["ligand_confidence"] = SampledInterfaceNLL(
return_per_example_metrics=True
)

# Construct the MetricManager.
metric_manager = MetricManager.from_metrics(metrics, raise_errors=True)
Expand Down Expand Up @@ -387,6 +395,32 @@ def _run_batch(
else:
interface_sequence_recovery_per_design = None

# Per-design overall confidence = exp(-NLL) of the *sampled* sequence
# under the model's un-temperatured logits, mirroring LigandMPNN's
# overall/ligand confidence. Per-residue confidence (exp(-NLL) per
# position, zeroed at non-designed positions) is written into the
# output structure below.
overall_confidence_per_design = np.exp(
-metrics_output["overall_confidence.nll_per_example"].detach().cpu().numpy()
)
nll_per_residue = (
metrics_output["overall_confidence.nll_per_residue"].detach().cpu().numpy()
)
confidence_per_residue_mask = (
metrics_output["overall_confidence.per_residue_mask"].detach().cpu().numpy()
)
confidence_per_residue = np.exp(-nll_per_residue) * confidence_per_residue_mask

if self.model_type == "ligand_mpnn":
ligand_confidence_per_design = np.exp(
-metrics_output["ligand_confidence.interface_nll_per_example"]
.detach()
.cpu()
.numpy()
)
else:
ligand_confidence_per_design = None

# Grab the index to token mapping from the model.
idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token

Expand Down Expand Up @@ -436,6 +470,19 @@ def _run_batch(
# Overwrite with designed residue names.
design_atom_array.set_annotation("res_name", full_resnames)

# Spread per-residue confidence (token-level) to atom level over the
# non-atomized subset and overwrite the b-factor with it, so it is
# written to the standard CIF '_atom_site.B_iso_or_equiv' column and
# picked up by viewers (analogous to LigandMPNN's per-residue
# confidence b-factors). Non-designed positions are set to 0.
design_confidence_atom = spread_token_wise(
design_non_atomized_array,
confidence_per_residue[design_idx],
)
full_confidence = np.zeros(len(design_atom_array), dtype=np.float32)
full_confidence[~design_atom_array.atomize] = design_confidence_atom
design_atom_array.set_annotation("b_factor", full_confidence)
Comment thread
daylight-00 marked this conversation as resolved.
Outdated

# We need to remove any non-atomized residue atoms that no
# longer belong (i.e. old side chain atoms). We want to keep any
# atom that is atomized, any atom that is a backbone atom, and
Expand Down Expand Up @@ -478,6 +525,12 @@ def _run_batch(
"batch_idx": batch_idx,
"design_idx": design_idx,
"designed_sequence": one_letter_seq,
"overall_confidence": float(overall_confidence_per_design[design_idx]),
"ligand_confidence": (
float(ligand_confidence_per_design[design_idx])
if ligand_confidence_per_design is not None
else None
),
"sequence_recovery": sequence_recovery,
"ligand_interface_sequence_recovery": (
ligand_interface_sequence_recovery
Expand Down
80 changes: 80 additions & 0 deletions models/mpnn/src/mpnn/metrics/nll.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,83 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
interface_metrics[f"interface_{key}"] = value

return interface_metrics


class SampledNLL(NLL):
"""NLL / confidence of the *sampled* (designed) sequence.

Unlike :class:`NLL` (which scores the native sequence as a training
metric), this scores the sampled sequence under the model's
un-temperatured, un-biased log-probabilities (``log_softmax`` of the raw
logits). This matches the "overall_confidence" reported by the original
LigandMPNN, where ``confidence = exp(-mean_nll)``.
"""

@property
def kwargs_to_compute_args(self):
mapping = super().kwargs_to_compute_args
mapping["S"] = ("network_output", "decoder_features", "S_sampled")
mapping["log_probs"] = ("network_output", "decoder_features", "logits")
return mapping
Comment thread
daylight-00 marked this conversation as resolved.
Outdated

def compute(self, log_probs, S, mask_for_loss, **kwargs):
Comment thread
daylight-00 marked this conversation as resolved.
Outdated
"""Convert raw logits to log-probabilities and delegate to ``NLL``.

Args:
log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits
(mapped from ``decoder_features.logits``), not log-probs.
S (torch.Tensor): [B, L] - the sampled (designed) sequence.
mask_for_loss (torch.Tensor): [B, L] - mask for loss.
**kwargs: Additional arguments forwarded to ``NLL.compute``.

Returns:
dict: The NLL / confidence metrics from ``NLL.compute``.
"""
# Raw logits -> true log-probs at temperature 1.0 (no bias), matching
# LigandMPNN's confidence definition.
log_probs = torch.log_softmax(log_probs, dim=-1)
return super().compute(
log_probs=log_probs, S=S, mask_for_loss=mask_for_loss, **kwargs
)
Comment thread
daylight-00 marked this conversation as resolved.


class SampledInterfaceNLL(InterfaceNLL):
"""Interface NLL / confidence of the *sampled* (designed) sequence.

Mirrors LigandMPNN's "ligand_confidence" (polymer-ligand interface
residues only), computed on the un-temperatured logits of the sampled
sequence.
"""
Comment thread
daylight-00 marked this conversation as resolved.
Outdated

@property
def kwargs_to_compute_args(self):
mapping = super().kwargs_to_compute_args
mapping["S"] = ("network_output", "decoder_features", "S_sampled")
mapping["log_probs"] = ("network_output", "decoder_features", "logits")
return mapping
Comment thread
daylight-00 marked this conversation as resolved.
Outdated

def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
Comment thread
daylight-00 marked this conversation as resolved.
Outdated
Comment thread
daylight-00 marked this conversation as resolved.
Outdated
"""Convert raw logits to log-probabilities and delegate to ``InterfaceNLL``.

Args:
log_probs (torch.Tensor): [B, L, vocab_size] - raw model logits
(mapped from ``decoder_features.logits``), not log-probs.
S (torch.Tensor): [B, L] - the sampled (designed) sequence.
mask_for_loss (torch.Tensor): [B, L] - mask for loss.
atom_array: Atom array(s) used to derive the polymer-ligand
interface mask.
**kwargs: Additional arguments forwarded to ``InterfaceNLL.compute``.

Returns:
dict: The interface NLL / confidence metrics (``interface_`` prefix).
"""
# Raw logits -> true log-probs at temperature 1.0 (no bias), matching
# LigandMPNN's confidence definition.
log_probs = torch.log_softmax(log_probs, dim=-1)
return super().compute(
log_probs=log_probs,
S=S,
mask_for_loss=mask_for_loss,
atom_array=atom_array,
**kwargs,
)
Comment thread
daylight-00 marked this conversation as resolved.
Comment thread
daylight-00 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,13 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]:
"input_features": [
"mask_for_loss",
],
"decoder_features": ["log_probs", "S_sampled", "S_argmax"],
# 'logits' are needed to compute un-temperatured confidence.
"decoder_features": [
"log_probs",
"S_sampled",
"S_argmax",
"logits",
],
}

# Save the scalar settings.
Expand Down
11 changes: 11 additions & 0 deletions models/mpnn/src/mpnn/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,8 @@ class MPNNInferenceOutput:
- 'batch_idx'
- 'design_idx'
- 'designed_sequence'
- 'overall_confidence'
- 'ligand_confidence'
- 'sequence_recovery'
- 'ligand_interface_sequence_recovery'
- 'model_type'
Expand Down Expand Up @@ -2367,6 +2369,15 @@ def write_fasta(
decorated_name = "_".join(name_fields)
header_fields.append(decorated_name)

# Construct the confidence fields for the header (exp(-NLL) of the
# sampled sequence; mirrors LigandMPNN's overall/ligand confidence).
overall_confidence = self.output_dict.get("overall_confidence")
ligand_confidence = self.output_dict.get("ligand_confidence")
if overall_confidence is not None:
header_fields.append(f"overall_confidence={float(overall_confidence):.4f}")
if ligand_confidence is not None:
header_fields.append(f"ligand_confidence={float(ligand_confidence):.4f}")

# Construct the recovery fields for the header.
if sequence_recovery is not None:
header_fields.append(f"sequence_recovery={float(sequence_recovery):.4f}")
Expand Down
58 changes: 57 additions & 1 deletion models/mpnn/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""

import pytest
import torch
from atomworks.ml.utils.testing import cached_parse
from mpnn.metrics.nll import NLL, InterfaceNLL
from mpnn.metrics.nll import NLL, InterfaceNLL, SampledInterfaceNLL, SampledNLL
from mpnn.metrics.sequence_recovery import InterfaceSequenceRecovery, SequenceRecovery
from mpnn.pipelines.mpnn import build_mpnn_transform_pipeline
from test_utils import (
Expand Down Expand Up @@ -90,3 +91,58 @@ def test_metrics_comprehensive(self, pdb_id, model_type, is_inference):
return_per_example=True,
return_per_residue=True,
)

def test_sampled_confidence_metrics_read_sampled_logits(self):
"""SampledNLL/SampledInterfaceNLL must score the *sampled* sequence
using the raw model logits (not the native sequence or the
temperature-scaled log_probs)."""
for metric in (SampledNLL(), SampledInterfaceNLL()):
mapping = metric.kwargs_to_compute_args
assert mapping["S"] == ("network_output", "decoder_features", "S_sampled")
assert mapping["log_probs"] == (
"network_output",
"decoder_features",
"logits",
)
# The interface variant additionally needs the atom array for masking.
assert SampledInterfaceNLL().kwargs_to_compute_args["atom_array"] == (
"network_input",
"atom_array",
)

def test_sampled_nll_equals_log_softmax_of_logits_on_sampled_sequence(self):
"""SampledNLL.compute must equal the hand-computed NLL of the sampled
sequence under log_softmax(logits), and must ignore the native
sequence."""
batch, length, vocab = 1, 4, 21
torch.manual_seed(0)
logits = torch.randn(batch, length, vocab)
sampled = torch.tensor([[1, 5, 5, 10]])
native = torch.tensor([[0, 0, 0, 0]]) # deliberately != sampled
mask = torch.tensor([[True, True, True, False]])

network_output = {
"decoder_features": {"logits": logits, "S_sampled": sampled},
"input_features": {"mask_for_loss": mask, "S": native},
}

metric = SampledNLL(
return_per_example_metrics=True, return_per_residue_metrics=True
)
out = metric.compute_from_kwargs(network_output=network_output)

# Expected: mean over the (3) masked-in positions of -log_softmax(logits).
log_probs = torch.log_softmax(logits, dim=-1)
per_res = -log_probs[0, torch.arange(length), sampled[0]]
expected_nll = per_res[:3].sum() / 3.0

assert torch.allclose(out["nll_per_example"][0], expected_nll, atol=1e-6)
# Per-residue NLL is zeroed at the masked-out position.
assert torch.allclose(
out["nll_per_residue"][0],
torch.tensor([per_res[0], per_res[1], per_res[2], 0.0]),
atol=1e-6,
)
# Must NOT coincide with the NLL of the (different) native sequence.
native_nll = (-log_probs[0, torch.arange(length), native[0]])[:3].sum() / 3.0
assert not torch.allclose(out["nll_per_example"][0], native_nll, atol=1e-4)
Loading