Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions models/mpnn/src/mpnn/inference_engines/mpnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from atomworks.ml.utils.token import get_token_starts, spread_token_wise
from biotite.structure import AtomArray
from mpnn.collate.feature_collator import FeatureCollator
from mpnn.metrics.nll import MPNNConfidence, MPNNLigandInterfaceConfidence
from mpnn.metrics.sequence_recovery import (
InterfaceSequenceRecovery,
SequenceRecovery,
Expand Down Expand Up @@ -193,11 +194,18 @@ def _build_metrics_manager(self) -> MetricManager:
# Construct metrics dict.
metrics: dict[str, Any] = {
"sequence_recovery": SequenceRecovery(return_per_example_metrics=True),
"confidence": MPNNConfidence(
return_per_example_metrics=True,
return_per_residue_metrics=True,
),
}
if self.model_type == "ligand_mpnn":
metrics["interface_sequence_recovery"] = InterfaceSequenceRecovery(
return_per_example_metrics=True
)
metrics["ligand_interface_confidence"] = MPNNLigandInterfaceConfidence(
return_per_example_metrics=True
)

# Construct the MetricManager.
metric_manager = MetricManager.from_metrics(metrics, raise_errors=True)
Expand Down Expand Up @@ -387,6 +395,33 @@ def _run_batch(
else:
interface_sequence_recovery_per_design = None

# Per-design confidence (= exp(-NLL) of the *sampled* sequence), computed
# by the MPNNConfidence metrics so the NLL-to-confidence transform lives
# in the metric layer. Per-residue confidence is zeroed at non-designed
# positions and written into the output structure below.
confidence_per_design = (
metrics_output["confidence.confidence_per_example"].detach().cpu().numpy()
)
confidence_per_residue = (
metrics_output["confidence.confidence_per_residue"].detach().cpu().numpy()
)

# Ligand-interface confidence is undefined (NaN) when there are no
# interface residues (e.g. ligand_mpnn run on a ligand-free input); such
# values are converted to None per design below so they are omitted from
# the outputs rather than written as NaN.
if self.model_type == "ligand_mpnn":
ligand_interface_confidence_per_design = (
metrics_output[
"ligand_interface_confidence.interface_confidence_per_example"
]
.detach()
.cpu()
.numpy()
)
else:
ligand_interface_confidence_per_design = None

# Grab the index to token mapping from the model.
idx_to_token = MPNN_TOKEN_ENCODING.idx_to_token

Expand Down Expand Up @@ -436,6 +471,18 @@ def _run_batch(
# Overwrite with designed residue names.
design_atom_array.set_annotation("res_name", full_resnames)

# Spread per-residue confidence (token-level) to atom level over the
# non-atomized subset and store it in a dedicated 'mpnn_confidence'
# annotation (so the original b-factor is preserved). Non-designed
# positions are set to 0.
design_confidence_atom = spread_token_wise(
design_non_atomized_array,
confidence_per_residue[design_idx],
)
full_confidence = np.zeros(len(design_atom_array), dtype=np.float32)
full_confidence[~design_atom_array.atomize] = design_confidence_atom
design_atom_array.set_annotation("mpnn_confidence", full_confidence)

# We need to remove any non-atomized residue atoms that no
# longer belong (i.e. old side chain atoms). We want to keep any
# atom that is atomized, any atom that is a backbone atom, and
Expand Down Expand Up @@ -478,6 +525,13 @@ def _run_batch(
"batch_idx": batch_idx,
"design_idx": design_idx,
"designed_sequence": one_letter_seq,
"confidence": float(confidence_per_design[design_idx]),
"ligand_interface_confidence": (
float(ligand_interface_confidence_per_design[design_idx])
if ligand_interface_confidence_per_design is not None
and not np.isnan(ligand_interface_confidence_per_design[design_idx])
else None
),
"sequence_recovery": sequence_recovery,
"ligand_interface_sequence_recovery": (
ligand_interface_sequence_recovery
Expand Down
168 changes: 168 additions & 0 deletions models/mpnn/src/mpnn/metrics/nll.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,171 @@ def compute(self, log_probs, S, mask_for_loss, atom_array, **kwargs):
interface_metrics[f"interface_{key}"] = value

return interface_metrics


class SampledNLL(NLL):
    """Negative log-likelihood of the *sampled* (designed) sequence.

    Where :class:`NLL` scores the native sequence as a training metric, this
    variant scores the sequence the model sampled, using log-probabilities
    obtained by applying ``log_softmax`` to the raw logits (i.e. without
    temperature scaling or bias). :class:`MPNNConfidence` builds the
    ``exp(-NLL)`` confidence on top of this class.
    """

    @property
    def kwargs_to_compute_args(self):
        """Return the key paths used to populate the ``compute`` arguments.

        Each value is a tuple of nested keys; how the paths are resolved is
        up to the metric framework (not visible here -- confirm against the
        base metric class).
        """
        # The sampled sequence (`S_sampled`) is scored against the raw
        # `logits`; both sit under the decoder features of the network output.
        decoder_path = ("network_output", "decoder_features")
        input_path = ("network_output", "input_features")
        return {
            "logits": (*decoder_path, "logits"),
            "S": (*decoder_path, "S_sampled"),
            "mask_for_loss": (*input_path, "mask_for_loss"),
        }

    def compute(self, logits, S, mask_for_loss, **kwargs):
        """Normalise raw logits into log-probabilities, then defer to ``NLL``.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Extra arguments passed through to ``NLL.compute``.

        Returns:
            dict: Whatever ``NLL.compute`` returns for these inputs.
        """
        # A softmax over the last (vocabulary) axis of the untouched logits
        # gives temperature-1.0, bias-free log-probabilities, in line with
        # LigandMPNN's confidence definition.
        sampled_log_probs = torch.log_softmax(logits, dim=-1)
        parent_metrics = super().compute(
            log_probs=sampled_log_probs,
            S=S,
            mask_for_loss=mask_for_loss,
            **kwargs,
        )
        return parent_metrics
Comment thread
daylight-00 marked this conversation as resolved.


class SampledLigandInterfaceNLL(InterfaceNLL):
    """Interface-restricted NLL of the *sampled* (designed) sequence.

    Counterpart of :class:`SampledNLL` limited to polymer-ligand interface
    residues. :class:`MPNNLigandInterfaceConfidence` derives the
    ``exp(-NLL)`` confidence from this class.
    """

    @property
    def kwargs_to_compute_args(self):
        """Return the key paths used to populate the ``compute`` arguments.

        Each value is a tuple of nested keys; how the paths are resolved is
        up to the metric framework (not visible here -- confirm against the
        base metric class).
        """
        # The sampled sequence (`S_sampled`) is scored against the raw
        # `logits`; `atom_array` is supplied so the polymer-ligand interface
        # mask can be derived from it.
        decoder_path = ("network_output", "decoder_features")
        input_path = ("network_output", "input_features")
        return {
            "logits": (*decoder_path, "logits"),
            "S": (*decoder_path, "S_sampled"),
            "mask_for_loss": (*input_path, "mask_for_loss"),
            "atom_array": ("network_input", "atom_array"),
        }

    def compute(self, logits, S, mask_for_loss, atom_array, **kwargs):
        """Normalise raw logits into log-probabilities, then defer to ``InterfaceNLL``.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Atom array(s) from which the polymer-ligand
                interface mask is derived.
            **kwargs: Extra arguments passed through to
                ``InterfaceNLL.compute``.

        Returns:
            dict: Whatever ``InterfaceNLL.compute`` returns for these inputs
                (``interface_``-prefixed keys).
        """
        # A softmax over the last (vocabulary) axis of the untouched logits
        # gives temperature-1.0, bias-free log-probabilities, in line with
        # LigandMPNN's confidence definition.
        sampled_log_probs = torch.log_softmax(logits, dim=-1)
        parent_metrics = super().compute(
            log_probs=sampled_log_probs,
            S=S,
            mask_for_loss=mask_for_loss,
            atom_array=atom_array,
            **kwargs,
        )
        return parent_metrics
Comment thread
daylight-00 marked this conversation as resolved.
Comment thread
daylight-00 marked this conversation as resolved.


class MPNNConfidence(SampledNLL):
    """Confidence of the sampled (designed) sequence.

    Extends :class:`SampledNLL` by attaching the LigandMPNN-style confidence,
    ``confidence = exp(-NLL)``, to the returned metrics so that callers never
    have to perform the NLL-to-confidence conversion themselves. Values fall
    in ``(0, 1]``, with larger values meaning the model is more confident in
    the sequence.
    """

    def compute(self, logits, S, mask_for_loss, **kwargs):
        """Compute sampled-sequence NLL metrics and add confidence entries.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            **kwargs: Extra arguments passed through to
                ``SampledNLL.compute``.

        Returns:
            dict: The ``SampledNLL`` metrics, extended with
                ``confidence_per_example`` (= exp(-NLL)) when
                ``nll_per_example`` is present, and with
                ``confidence_per_residue`` (= exp(-NLL) per position, zero
                wherever ``per_residue_mask`` is false) when
                ``nll_per_residue`` is present.
        """
        results = super().compute(
            logits=logits,
            S=S,
            mask_for_loss=mask_for_loss,
            **kwargs,
        )

        # Per-design confidence: only derivable if the parent reported it.
        if "nll_per_example" in results:
            example_nll = results["nll_per_example"]
            results["confidence_per_example"] = torch.exp(-example_nll)

        # Per-residue confidence, zeroed outside the per-residue mask.
        if "nll_per_residue" in results:
            residue_confidence = torch.exp(-results["nll_per_residue"])
            # NOTE(review): assumes the parent always emits
            # `per_residue_mask` alongside `nll_per_residue` -- confirm.
            keep = results["per_residue_mask"].bool()
            zeros = torch.zeros_like(residue_confidence)
            results["confidence_per_residue"] = torch.where(
                keep, residue_confidence, zeros
            )

        return results


class MPNNLigandInterfaceConfidence(SampledLigandInterfaceNLL):
    """Confidence of the sampled sequence over polymer-ligand interface residues.

    Interface analogue of :class:`MPNNConfidence`: it extends
    :class:`SampledLigandInterfaceNLL` and attaches ``exp(-NLL)`` confidence
    values under ``interface_``-prefixed keys, covering only the
    polymer-ligand interface residues.
    """

    def compute(self, logits, S, mask_for_loss, atom_array, **kwargs):
        """Compute interface NLL metrics and add interface confidence entries.

        Args:
            logits (torch.Tensor): [B, L, vocab_size] - raw model logits.
            S (torch.Tensor): [B, L] - the sampled (designed) sequence.
            mask_for_loss (torch.Tensor): [B, L] - mask for loss.
            atom_array: Atom array(s) from which the interface mask is
                derived.
            **kwargs: Extra arguments passed through to
                ``SampledLigandInterfaceNLL.compute``.

        Returns:
            dict: The ``SampledLigandInterfaceNLL`` metrics, extended with
                ``interface_confidence_per_example`` when
                ``interface_nll_per_example`` is present, and with
                ``interface_confidence_per_residue`` (zero wherever
                ``interface_per_residue_mask`` is false) when
                ``interface_nll_per_residue`` is present.
        """
        results = super().compute(
            logits=logits,
            S=S,
            mask_for_loss=mask_for_loss,
            atom_array=atom_array,
            **kwargs,
        )

        # Per-design interface confidence: only if the parent reported it.
        if "interface_nll_per_example" in results:
            example_nll = results["interface_nll_per_example"]
            results["interface_confidence_per_example"] = torch.exp(-example_nll)

        # Per-residue interface confidence, zeroed outside the interface mask.
        if "interface_nll_per_residue" in results:
            residue_confidence = torch.exp(-results["interface_nll_per_residue"])
            # NOTE(review): assumes the parent always emits
            # `interface_per_residue_mask` with the per-residue NLL -- confirm.
            keep = results["interface_per_residue_mask"].bool()
            zeros = torch.zeros_like(residue_confidence)
            results["interface_confidence_per_residue"] = torch.where(
                keep, residue_confidence, zeros
            )

        return results
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,13 @@ def forward(self, data: dict[str, Any]) -> dict[str, Any]:
"input_features": [
"mask_for_loss",
],
"decoder_features": ["log_probs", "S_sampled", "S_argmax"],
# 'logits' are needed to compute un-temperatured confidence.
"decoder_features": [
"log_probs",
"S_sampled",
"S_argmax",
"logits",
],
}

# Save the scalar settings.
Expand Down
16 changes: 16 additions & 0 deletions models/mpnn/src/mpnn/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,8 @@ class MPNNInferenceOutput:
- 'batch_idx'
- 'design_idx'
- 'designed_sequence'
- 'confidence'
- 'ligand_interface_confidence'
- 'sequence_recovery'
- 'ligand_interface_sequence_recovery'
- 'model_type'
Expand Down Expand Up @@ -2292,6 +2294,7 @@ def write_structure(
"mpnn_temperature",
"mpnn_symmetry_equivalence_group",
"mpnn_symmetry_weight",
"mpnn_confidence",
]

# Limit to fields actually present in the atom array.
Expand Down Expand Up @@ -2367,6 +2370,19 @@ def write_fasta(
decorated_name = "_".join(name_fields)
header_fields.append(decorated_name)

# Construct the confidence fields for the header (exp(-NLL) of the
# sampled sequence; mirrors LigandMPNN's overall/ligand confidence).
confidence = self.output_dict.get("confidence")
ligand_interface_confidence = self.output_dict.get(
"ligand_interface_confidence"
)
if confidence is not None:
header_fields.append(f"confidence={float(confidence):.4f}")
if ligand_interface_confidence is not None:
header_fields.append(
f"ligand_interface_confidence={float(ligand_interface_confidence):.4f}"
)

# Construct the recovery fields for the header.
if sequence_recovery is not None:
header_fields.append(f"sequence_recovery={float(sequence_recovery):.4f}")
Expand Down
Loading
Loading