Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions models/mpnn/src/mpnn/inference_engines/mpnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
from biotite.structure import AtomArray
from mpnn.collate.feature_collator import FeatureCollator
from mpnn.metrics.nll import MPNNConfidence, MPNNInterfaceConfidence
from mpnn.metrics.sequence_recovery import (
InterfaceSequenceRecovery,
SequenceRecovery,
Expand Down Expand Up @@ -193,11 +194,18 @@ def _build_metrics_manager(self) -> MetricManager:
# Construct metrics dict.
metrics: dict[str, Any] = {
"sequence_recovery": SequenceRecovery(return_per_example_metrics=True),
"overall_confidence": MPNNConfidence(
return_per_example_metrics=True,
return_per_residue_metrics=True,
),
}
if self.model_type == "ligand_mpnn":
metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery(
return_per_example_metrics=True
)
metrics["ligand_confidence"] = MPNNInterfaceConfidence(
return_per_example_metrics=True
)

# Construct the MetricManager.
metric_manager = MetricManager.from_metrics(metrics, raise_errors=True)
Expand Down Expand Up @@ -387,6 +395,33 @@ def _run_batch(
else:
interface_sequence_recovery_per_design = None

# Per-design confidence (= exp(-NLL) of the *sampled* sequence), computed
# by the MPNNConfidence metrics so the NLL-to-confidence transform lives
# in the metric layer. Per-residue confidence is zeroed at non-designed
# positions and written into the output structure below.
overall_confidence_per_design = (
metrics_output["overall_confidence.confidence_per_example"]
.detach()
.cpu()
.numpy()
)
confidence_per_residue = (
metrics_output["overall_confidence.confidence_per_residue"]
.detach()
.cpu()
.numpy()
)

if self.model_type == "ligand_mpnn":
ligand_confidence_per_design = (
metrics_output["ligand_confidence.interface_confidence_per_example"]
.detach()
.cpu()
.numpy()
)
else:
ligand_confidence_per_design = None

# Grab the index to token mapping from the model.
idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token

Expand Down Expand Up @@ -436,6 +471,18 @@ def _run_batch(
# Overwrite with designed residue names.
design_atom_array.set_annotation("res_name", full_resnames)

# Spread per-residue confidence (token-level) to atom level over the
# non-atomized subset and store it in a dedicated 'mpnn_confidence'
# annotation (so the original b-factor is preserved). Non-designed
# positions are set to 0.
design_confidence_atom = spread_token_wise(
design_non_atomized_array,
confidence_per_residue[design_idx],
)
full_confidence = np.zeros(len(design_atom_array), dtype=np.float32)
full_confidence[~design_atom_array.atomize] = design_confidence_atom
design_atom_array.set_annotation("mpnn_confidence", full_confidence)

# We need to remove any non-atomized residue atoms that no
# longer belong (i.e. old side chain atoms). We want to keep any
# atom that is atomized, any atom that is a backbone atom, and
Expand Down Expand Up @@ -478,6 +525,12 @@ def _run_batch(
"batch_idx": batch_idx,
"design_idx": design_idx,
"designed_sequence": one_letter_seq,
"overall_confidence": float(overall_confidence_per_design[design_idx]),
"ligand_confidence": (
float(ligand_confidence_per_design[design_idx])
if ligand_confidence_per_design is not None
else None
),
"sequence_recovery": sequence_recovery,
"ligand_interface_sequence_recovery": (
ligand_interface_sequence_recovery
Expand Down
167 changes: 167 additions & 0 deletions models/mpnn/src/mpnn/metrics/nll.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,170 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
interface_metrics[f"interface_{key}"] = value

return interface_metrics


class SampledNLL(NLL):
    """Negative log-likelihood of the *sampled* (designed) sequence.

    :class:`NLL` is fed the model's ``log_probs`` directly. This subclass
    instead takes the raw ``logits`` together with ``S_sampled`` and derives
    the log-probabilities itself via ``log_softmax``, i.e. without any
    temperature or bias applied. :class:`MPNNConfidence` builds the
    ``exp(-NLL)`` confidence on top of this class.
    """

    @property
    def kwargs_to_compute_args(self):
        """Return the parent's argument mapping, rewired for sampled scoring.

        A fresh dict is built, so the mapping returned by the parent is never
        modified in place. The ``log_probs`` entry is dropped, ``S`` is
        pointed at the sampled sequence, and a ``logits`` entry is added.
        """
        inherited = dict(super().kwargs_to_compute_args)
        remapped = {
            name: path for name, path in inherited.items() if name != "log_probs"
        }
        remapped["S"] = ("network_output", "decoder_features", "S_sampled")
        remapped["logits"] = ("network_output", "decoder_features", "logits")
        return remapped

    def compute(self, logits, S, mask_for_loss, **kwargs):
        """Normalise the raw logits and hand off to ``NLL.compute``.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Extra arguments passed through to ``NLL.compute``.

        Returns:
            dict: Whatever ``NLL.compute`` returns for these inputs.
        """
        # log_softmax over the vocabulary axis gives the log-probabilities
        # at temperature 1.0 with no bias.
        return super().compute(
            log_probs=torch.log_softmax(logits, dim=-1),
            S=S,
            mask_for_loss=mask_for_loss,
            **kwargs,
        )
Comment thread
daylight-00 marked this conversation as resolved.


class SampledInterfaceNLL(InterfaceNLL):
    """Interface-restricted NLL of the *sampled* (designed) sequence.

    Interface counterpart of :class:`SampledNLL`: the raw ``logits`` and
    ``S_sampled`` are scored through :class:`InterfaceNLL`, which restricts
    the metric to polymer-ligand interface residues.
    :class:`MPNNInterfaceConfidence` derives the ``exp(-NLL)`` confidence
    from this class.
    """

    @property
    def kwargs_to_compute_args(self):
        """Return the parent's argument mapping, rewired for sampled scoring.

        A fresh dict is built, so the mapping returned by the parent is never
        modified in place. The ``log_probs`` entry is dropped, ``S`` is
        pointed at the sampled sequence, and a ``logits`` entry is added.
        """
        inherited = dict(super().kwargs_to_compute_args)
        remapped = {
            name: path for name, path in inherited.items() if name != "log_probs"
        }
        remapped["S"] = ("network_output", "decoder_features", "S_sampled")
        remapped["logits"] = ("network_output", "decoder_features", "logits")
        return remapped

    def compute(self, logits, S, mask_for_loss, atom_array, **kwargs):
        """Normalise the raw logits and hand off to ``InterfaceNLL.compute``.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Atom array(s) forwarded to ``InterfaceNLL.compute``,
                which uses them to derive the polymer-ligand interface mask.
            **kwargs: Extra arguments passed through to
                ``InterfaceNLL.compute``.

        Returns:
            dict: Whatever ``InterfaceNLL.compute`` returns (``interface_``
                prefixed keys).
        """
        # log_softmax over the vocabulary axis gives the log-probabilities
        # at temperature 1.0 with no bias.
        return super().compute(
            log_probs=torch.log_softmax(logits, dim=-1),
            S=S,
            mask_for_loss=mask_for_loss,
            atom_array=atom_array,
            **kwargs,
        )
Comment thread
daylight-00 marked this conversation as resolved.
Comment thread
daylight-00 marked this conversation as resolved.


class MPNNConfidence(SampledNLL):
    """Confidence (``exp(-NLL)``) of the sampled (designed) sequence.

    Extends :class:`SampledNLL` by adding confidence entries next to the NLL
    entries it already returns, so callers do not have to apply the
    NLL-to-confidence transform themselves.
    """

    def compute(self, logits, S, mask_for_loss, **kwargs):
        """Compute the sampled-sequence NLL metrics and add confidences.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Extra arguments passed through to
                ``SampledNLL.compute``.

        Returns:
            dict: The ``SampledNLL`` metrics, extended with
                ``confidence_per_example`` when ``nll_per_example`` is
                present, and with ``confidence_per_residue`` when
                ``nll_per_residue`` is present (set to 0 wherever
                ``per_residue_mask`` is false).
        """
        scores = super().compute(
            logits=logits, S=S, mask_for_loss=mask_for_loss, **kwargs
        )

        if "nll_per_example" in scores:
            example_nll = scores["nll_per_example"]
            scores["confidence_per_example"] = torch.exp(-example_nll)

        # Nothing more to add unless per-residue NLL was produced.
        if "nll_per_residue" not in scores:
            return scores

        residue_confidence = torch.exp(-scores["nll_per_residue"])
        keep = scores["per_residue_mask"].bool()
        # Zero the confidence at every position the mask excludes.
        scores["confidence_per_residue"] = torch.where(
            keep, residue_confidence, torch.zeros_like(residue_confidence)
        )
        return scores


class MPNNInterfaceConfidence(SampledInterfaceNLL):
    """Confidence (``exp(-NLL)``) over polymer-ligand interface residues.

    Interface counterpart of :class:`MPNNConfidence`: extends
    :class:`SampledInterfaceNLL` by adding ``interface_`` prefixed confidence
    entries next to the interface NLL entries it already returns.
    """

    def compute(self, logits, S, mask_for_loss, atom_array, **kwargs):
        """Compute the interface NLL metrics and add interface confidences.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Atom array(s) forwarded to
                ``SampledInterfaceNLL.compute``.
            **kwargs: Extra arguments passed through to
                ``SampledInterfaceNLL.compute``.

        Returns:
            dict: The ``SampledInterfaceNLL`` metrics, extended with
                ``interface_confidence_per_example`` when
                ``interface_nll_per_example`` is present, and with
                ``interface_confidence_per_residue`` when
                ``interface_nll_per_residue`` is present (set to 0 wherever
                ``interface_per_residue_mask`` is false).
        """
        scores = super().compute(
            logits=logits,
            S=S,
            mask_for_loss=mask_for_loss,
            atom_array=atom_array,
            **kwargs,
        )

        if "interface_nll_per_example" in scores:
            example_nll = scores["interface_nll_per_example"]
            scores["interface_confidence_per_example"] = torch.exp(-example_nll)

        # Nothing more to add unless per-residue interface NLL was produced.
        if "interface_nll_per_residue" not in scores:
            return scores

        residue_confidence = torch.exp(-scores["interface_nll_per_residue"])
        keep = scores["interface_per_residue_mask"].bool()
        # Zero the confidence at every position the mask excludes.
        scores["interface_confidence_per_residue"] = torch.where(
            keep, residue_confidence, torch.zeros_like(residue_confidence)
        )
        return scores
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,13 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]:
"input_features": [
"mask_for_loss",
],
"decoder_features": ["log_probs", "S_sampled", "S_argmax"],
# 'logits' are needed to compute un-temperatured confidence.
"decoder_features": [
"log_probs",
"S_sampled",
"S_argmax",
"logits",
],
}

# Save the scalar settings.
Expand Down
12 changes: 12 additions & 0 deletions models/mpnn/src/mpnn/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,8 @@ class MPNNInferenceOutput:
- 'batch_idx'
- 'design_idx'
- 'designed_sequence'
- 'overall_confidence'
- 'ligand_confidence'
- 'sequence_recovery'
- 'ligand_interface_sequence_recovery'
- 'model_type'
Expand Down Expand Up @@ -2292,6 +2294,7 @@ def write_structure(
"mpnn_temperature",
"mpnn_symmetry_equivalence_group",
"mpnn_symmetry_weight",
"mpnn_confidence",
]

# Limit to fields actually present in the atom array.
Expand Down Expand Up @@ -2367,6 +2370,15 @@ def write_fasta(
decorated_name = "_".join(name_fields)
header_fields.append(decorated_name)

# Construct the confidence fields for the header (exp(-NLL) of the
# sampled sequence; mirrors LigandMPNN's overall/ligand confidence).
overall_confidence = self.output_dict.get("overall_confidence")
ligand_confidence = self.output_dict.get("ligand_confidence")
if overall_confidence is not None:
header_fields.append(f"overall_confidence={float(overall_confidence):.4f}")
if ligand_confidence is not None:
header_fields.append(f"ligand_confidence={float(ligand_confidence):.4f}")

# Construct the recovery fields for the header.
if sequence_recovery is not None:
header_fields.append(f"sequence_recovery={float(sequence_recovery):.4f}")
Expand Down
Loading
Loading