Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions models/rf3/src/rf3/data/paired_msa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# mypy: ignore-errors
#
# This module does not type-check (and does not even import) against the installed
# atomworks: `MultiInputDatasetWrapper` below subclasses
# `atomworks.ml.datasets.StructuralDatasetWrapper`, which atomworks turned into a
# deprecated factory *function* — subclassing it raises `TypeError` at import time.
# Making it type-check requires a real refactor onto the `PandasDataset` API, validated
# on cluster data (see `.ai/roadmap.md`), not type annotations. The suppression lives
# here, in the file, rather than in `pyproject.toml` so it is visible to anyone reviving
# the module: when this file imports and type-checks cleanly again, delete this directive
# to restore mypy coverage (the module stays inside mypy's `files` scope).
import os
import socket
import time
Expand Down
9 changes: 6 additions & 3 deletions models/rf3/src/rf3/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.utils.io import apply_sharding_pattern
from atomworks.ml.utils.misc import hash_sequence
from beartype.typing import Literal
from beartype.typing import Literal, cast
from biotite.structure import AtomArray, AtomArrayStack, stack

from foundry.utils.alignment import weighted_rigid_align
Expand Down Expand Up @@ -165,9 +165,12 @@ def dump_trajectories(
trajectory_list[0].device
)
for step in range(n_steps - 1):
# `trajectory_list` is typed `Tensor | np.ndarray`, but every path here
# (alignment + the `torch.stack` below) is tensor-only, so the diffusion
# trajectories are always tensors when alignment runs.
trajectory_list[step] = weighted_rigid_align(
X_L=trajectory_list[-1],
X_gt_L=trajectory_list[step],
X_L=cast(torch.Tensor, trajectory_list[-1]),
X_gt_L=cast(torch.Tensor, trajectory_list[step]),
X_exists_L=X_exists_L,
w_L=w_L,
)
Expand Down
136 changes: 136 additions & 0 deletions models/rf3/tests/test_af3_loss_symmetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Unit tests for the symmetry-resolution geometry in ``rf3.loss.af3_losses``.

Both helpers below are the load-bearing machinery behind ``SubunitSymmetryResolution``
and ``ResidueSymmetryResolution`` (used by ``rf3.symmetry.resolve`` /
``rf3.trainers.rf3``), which re-label the ground-truth coordinates to the symmetry
copy / automorphism that best matches the prediction before the loss is taken.

- ``SubunitSymmetryResolution._rms_align`` is a batched Kabsch fit. Given predicted
coordinates ``X_fixed`` (``Nbatch x L x 3``) and candidate native copies ``X_moving``
(``Nambig x L x 3``) it returns ``(u_moving, R, u_fixed)`` such that
``(X_moving - u_moving) @ R + u_fixed`` lands the native on the prediction — the exact
transform ``_resolve_subunits`` then applies to the native centres of mass. The SVD is
sign-corrected so ``R`` is always a proper rotation (``det = +1``), never a reflection.
- ``ResidueSymmetryResolution._get_best`` picks, per model in the batch, the atom
automorphism (a permutation of a set of interchangeable atom indices) whose
intra-structure distance pattern best matches the prediction, then rewrites the native
coordinates / mask at those positions to that permutation.

All tests run in float32 (production dtype); ``_rms_align`` builds its sign-correction
matrix with an un-typed ``torch.eye`` that only matches float32 inputs (see the roadmap).
"""

import pytest
import torch
from rf3.loss.af3_losses import ResidueSymmetryResolution, SubunitSymmetryResolution

# A non-degenerate, non-coplanar point cloud to align: L = 5 points, shape (5, 3).
# Module-level fixture shared by the ``_rms_align`` tests below, which use it as the
# fixed coordinates and derive the moving coordinates from it.
_POINTS = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, 2.0, 0.0],
[0.0, 0.0, 3.0],
[1.0, 1.0, 1.0],
[-1.0, 2.0, -1.0],
]
)


def _rotation_z(theta: float) -> torch.Tensor:
    """Build the 3x3 proper rotation (det +1) about the z-axis.

    Args:
        theta: Rotation angle in radians.

    Returns:
        A ``(3, 3)`` tensor whose upper-left 2x2 block is the planar rotation by
        ``theta`` and whose z row / column are the identity.
    """
    angle = torch.tensor(theta)
    cos_t = torch.cos(angle)
    sin_t = torch.sin(angle)
    return torch.tensor(
        [
            [cos_t, -sin_t, 0.0],
            [sin_t, cos_t, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )


# --- SubunitSymmetryResolution._rms_align -----------------------------------


def test_rms_align_recovers_a_rigid_transform():
    """The fitted transform undoes a known rotation + translation of the cloud."""
    # The moving copy is the fixed cloud pushed through a known rigid transform.
    shift = torch.tensor([2.0, -1.0, 3.0])
    x_fixed = _POINTS[None]  # (Nbatch=1, L, 3)
    x_moving = (_POINTS @ _rotation_z(0.7) + shift)[None]

    resolver = SubunitSymmetryResolution()
    u_moving, R, u_fixed = resolver._rms_align(x_fixed, x_moving)

    # Apply (X_moving - u_moving) @ R + u_fixed and expect to land on the fixed cloud.
    centred = x_moving[0] - u_moving[0, 0]
    mapped = centred @ R[0, 0] + u_fixed[0, 0]
    assert torch.allclose(mapped, _POINTS, atol=1e-4)

    # Sign-corrected SVD -> a proper rotation, not a reflection.
    determinant = torch.linalg.det(R[0, 0]).item()
    assert determinant == pytest.approx(1.0, abs=1e-4)


def test_rms_align_identity_when_moving_equals_fixed():
    """Aligning the cloud onto itself yields the identity rotation and equal centres."""
    x_fixed = _POINTS[None]
    x_moving = _POINTS[None]

    u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(x_fixed, x_moving)
    rotation = R[0, 0]
    centre_moving = u_moving[0, 0]
    centre_fixed = u_fixed[0, 0]

    assert torch.allclose(rotation, torch.eye(3), atol=1e-4)
    assert torch.allclose(centre_moving, centre_fixed, atol=1e-4)
    # Applying the (identity) transform must leave the cloud where it was.
    mapped = (_POINTS - centre_moving) @ rotation + centre_fixed
    assert torch.allclose(mapped, _POINTS, atol=1e-4)


def test_rms_align_corrects_reflection_to_proper_rotation():
    """A mirrored cloud still produces a det +1 rotation.

    A mirror image cannot be rotated onto the original, so the sign correction has to
    return a proper rotation rather than the optimal-but-improper reflection.
    """
    mirror_z = torch.diag(torch.tensor([1.0, 1.0, -1.0]))
    x_fixed = _POINTS[None]
    x_mirrored = (_POINTS @ mirror_z)[None]

    _u_moving, rotation, _u_fixed = SubunitSymmetryResolution()._rms_align(
        x_fixed, x_mirrored
    )

    determinant = torch.linalg.det(rotation[0, 0]).item()
    assert determinant == pytest.approx(1.0, abs=1e-4)


def test_rms_align_output_shapes_broadcast_over_ambig_and_batch():
    """Check the broadcast-ready output shapes for Nbatch = 3 and Nambig = 2.

    u_moving is per-ambiguity, u_fixed is per-batch, and R is the full cross product;
    these are the shapes _resolve_subunits relies on.
    """
    n_batch, n_ambig = 3, 2
    x_fixed = _POINTS[None].repeat(n_batch, 1, 1)
    x_moving = _POINTS[None].repeat(n_ambig, 1, 1)

    resolver = SubunitSymmetryResolution()
    u_moving, R, u_fixed = resolver._rms_align(x_fixed, x_moving)

    assert u_moving.shape == (2, 1, 3)
    assert R.shape == (2, 3, 3, 3)
    assert u_fixed.shape == (1, 3, 3)


# --- ResidueSymmetryResolution._get_best ------------------------------------


def _get_best_inputs(pred_sym, native_sym, context):
    """Assemble ``(x_pred, x_native, x_native_mask, a_i)`` for one model with 3 atoms.

    Atoms 0 and 1 form the interchangeable (automorphism) set, while atom 2 is a fixed
    context atom whose distances to atoms 0/1 break the tie.

    Args:
        pred_sym: Coordinates of atoms 0 and 1 in the prediction (two xyz triples).
        native_sym: Coordinates of atoms 0 and 1 in the native (two xyz triples).
        context: xyz triple of the shared context atom, appended as atom 2 to both.

    Returns:
        ``x_pred`` and ``x_native`` of shape ``(1, 3, 3)``, an all-True boolean mask of
        shape ``(1, 3)``, and ``a_i`` listing the identity ordering ``[0, 1]`` followed
        by the swap ``[1, 0]``.
    """

    def _single_model(sym_atoms):
        # Append the context atom and add the leading model dimension.
        return torch.tensor([[*sym_atoms, context]])

    orderings = torch.tensor([[0, 1], [1, 0]])
    all_resolved = torch.ones(1, 3, dtype=torch.bool)
    return _single_model(pred_sym), _single_model(native_sym), all_resolved, orderings


def test_get_best_selects_the_swap_that_matches_prediction():
    """A native with atoms 0/1 swapped is re-labelled to match the prediction."""
    pred_sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
    native_sym = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]  # swapped relative to pred_sym
    context = [0.0, 3.0, 0.0]
    x_pred, x_native, mask, a_i = _get_best_inputs(pred_sym, native_sym, context)

    resolver = ResidueSymmetryResolution()
    resolved_native, _ = resolver._get_best(x_pred, x_native, mask, a_i)

    # After resolution the interchangeable atoms must sit where the prediction has them.
    expected_atom_0 = torch.tensor([0.0, 0.0, 0.0])
    expected_atom_1 = torch.tensor([2.0, 0.0, 0.0])
    assert torch.allclose(resolved_native[0, 0], expected_atom_0)
    assert torch.allclose(resolved_native[0, 1], expected_atom_1)


def test_get_best_leaves_already_matching_native_unchanged():
    """When native already matches the prediction, the identity ordering wins."""
    shared_sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
    context = [0.0, 3.0, 0.0]
    inputs = _get_best_inputs(shared_sym, shared_sym, context)
    x_pred, x_native, mask, a_i = inputs
    # Snapshot the native before the call so the comparison uses the original values.
    snapshot = x_native.clone()

    resolver = ResidueSymmetryResolution()
    resolved_native, _ = resolver._get_best(x_pred, x_native, mask, a_i)

    assert torch.allclose(resolved_native, snapshot)
58 changes: 58 additions & 0 deletions models/rf3/tests/test_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Unit tests for rf3.model.layers.attention triangle-update blocks (vanilla path).

Both blocks map a pair representation ``[B, L, L, d_pair]`` to the same shape (for a
residual add) and, on CPU, run their vanilla PyTorch path (cuEquivariance is off).

- ``TriangleAttention`` zero-initialises its output projection ``to_out`` so the block is
the identity (output 0) at the start of training; ``start_node`` switches the attention
axis (rows vs. transposed columns) but preserves the output shape either way.
- ``TriangleMultiplication`` validates its ``direction`` (``"outgoing"``/``"incoming"``)
and, when the cuEquivariance kernel is requested, requires ``d_pair == d_hidden``; the
vanilla path lifts that constraint.
"""

import pytest
import torch
from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication

# --- TriangleAttention ------------------------------------------------------


def test_triangle_attention_preserves_shape():
    """Output shape equals the input pair shape for both ``start_node`` settings."""
    pair_repr = torch.randn(1, 6, 6, 8)
    expected_shape = pair_repr.shape
    for from_start_node in (True, False):
        attention = TriangleAttention(
            d_pair=8, n_head=2, d_hidden=4, start_node=from_start_node
        )
        updated = attention(pair_repr)
        assert updated.shape == expected_shape


def test_triangle_attention_zero_initialized():
    """A freshly built block has a zeroed ``to_out`` and therefore outputs zeros."""
    torch.manual_seed(0)
    attention = TriangleAttention(d_pair=8, n_head=2, d_hidden=4)
    out_proj = attention.to_out

    # to_out is zero-initialised so the residual add starts as the identity.
    assert bool((out_proj.weight == 0).all())
    assert bool((out_proj.bias == 0).all())

    output = attention(torch.randn(1, 6, 6, 8))
    assert bool((output == 0).all())


# --- TriangleMultiplication -------------------------------------------------


def test_triangle_multiplication_preserves_shape():
    """Both directions map the pair representation to a tensor of the same shape."""
    pair_repr = torch.randn(1, 6, 6, 8)
    # Vanilla path lifts the d_pair == d_hidden cuEquivariance constraint.
    vanilla_kwargs = {"d_pair": 8, "d_hidden": 4, "use_cuequivariance": False}
    for direction in ("outgoing", "incoming"):
        multiplication = TriangleMultiplication(direction=direction, **vanilla_kwargs)
        updated = multiplication(pair_repr)
        assert updated.shape == pair_repr.shape


def test_triangle_multiplication_rejects_invalid_direction():
    """An unknown ``direction`` is rejected with a ``ValueError`` at construction."""
    expected_message = "direction must be 'outgoing' or 'incoming'"
    with pytest.raises(ValueError, match=expected_message):
        TriangleMultiplication(
            d_pair=8,
            direction="sideways",
            use_cuequivariance=False,
        )


def test_triangle_multiplication_cuequivariance_requires_matching_dims():
    """Requesting the cuEquivariance kernel with d_pair != d_hidden must assert."""
    expected_message = "requires d_pair == d_hidden"
    with pytest.raises(AssertionError, match=expected_message):
        TriangleMultiplication(
            d_pair=8,
            d_hidden=4,
            use_cuequivariance=True,
        )
85 changes: 75 additions & 10 deletions models/rf3/tests/test_chiral.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
"""Unit tests for rf3.metrics.chiral.calc_chiral_metrics_masked.

The pure tensor core behind RF3's chirality metrics. The non-obvious contracts
pinned here: a chiral center is the dihedral of four atoms whose ideal angle is
the 5th column of ``chirals``; correctness is a *sign* match between the
predicted and ideal dihedral; a center counts only if all four of its atoms are
unmasked; duplicate rows that share a first-atom index (alternate orderings of
the same center) are collapsed to the first occurrence; and the empty cases
(no chirals, or a fully-masked structure) return ``{}``.
"""Unit tests for rf3.metrics.chiral.

``calc_chiral_metrics_masked`` is the pure tensor core behind RF3's chirality
metrics. The non-obvious contracts pinned here: a chiral center is the dihedral
of four atoms whose ideal angle is the 5th column of ``chirals``; correctness is
a *sign* match between the predicted and ideal dihedral; a center counts only if
all four of its atoms are unmasked; duplicate rows that share a first-atom index
(alternate orderings of the same center) are collapsed to the first occurrence;
and the empty cases (no chirals, or a fully-masked structure) return ``{}``.

``compute_chiral_metrics`` wraps that core: given predicted / ground-truth
``AtomArrayStack``s and (optionally) precomputed chiral features, it splits the
atoms into polymer vs non-polymer, drops ground-truth atoms with NaN
coordinates, and emits ``{category}_{n_chiral_centers,chiral_loss_mean,
percent_correct_chirality}`` keys only for categories that contain a scorable
center. Passing ``chiral_feats`` explicitly bypasses the rdkit/atomworks
feature generation, so the orchestration is exercised here on tiny fixtures.
"""

import math

import numpy as np
import pytest
import torch
from biotite.structure import AtomArrayStack
from rf3.kinematics import get_dih
from rf3.metrics.chiral import calc_chiral_metrics_masked
from rf3.metrics.chiral import calc_chiral_metrics_masked, compute_chiral_metrics


def _chiral_center_coords(theta: float) -> list[list[float]]:
Expand Down Expand Up @@ -115,3 +125,58 @@ def test_batch_dimension_shapes():

assert result["chiral_loss_mean"].shape == (2,)
assert result["percent_correct_chirality"].shape == (2,)


# --- compute_chiral_metrics -------------------------------------------------

# A single chiral center (atoms 0-3) whose predicted dihedral is +pi/2; ideal +1 -> a
# sign match, ideal -1 -> a mismatch.
_CENTER = [_chiral_center_coords(math.pi / 2)] # (1 model, 4 atoms, 3)
# One chiral row: the four atom indices of the center, then the ideal value in the 5th
# column (+1 here, so its sign agrees with the +pi/2 dihedral of ``_CENTER``).
_FEATS_CORRECT = torch.tensor([[0.0, 1.0, 2.0, 3.0, 1.0]])


def _stack(coords, is_polymer) -> AtomArrayStack:
    """Build an ``AtomArrayStack`` fixture from raw coordinates.

    Args:
        coords: Nested ``(D, L, 3)`` coordinates (D models, L atoms); converted to a
            float32 array.
        is_polymer: Per-atom flags (length L) stored as the boolean ``is_polymer``
            annotation.

    Returns:
        A stack of depth D and length L carrying the coordinates and the annotation.
    """
    xyz = np.asarray(coords, dtype=np.float32)
    n_models = xyz.shape[0]
    n_atoms = xyz.shape[1]

    atoms = AtomArrayStack(depth=n_models, length=n_atoms)
    atoms.coord = xyz
    polymer_flags = np.asarray(is_polymer, dtype=bool)
    atoms.set_annotation("is_polymer", polymer_flags)
    return atoms


def test_compute_chiral_metrics_polymer_center_correct():
    """An all-polymer center with a matching sign is reported under ``polymer_*``."""
    atoms = _stack(_CENTER, [True] * 4)
    # The same stack serves as both prediction and ground truth.
    metrics = compute_chiral_metrics(atoms, atoms, chiral_feats=_FEATS_CORRECT)

    assert metrics["polymer_n_chiral_centers"] == 1
    assert metrics["polymer_percent_correct_chirality"] == 1.0
    assert "polymer_chiral_loss_mean" in metrics
    # No non-polymer atoms -> that category is absent entirely.
    assert "non_polymer_n_chiral_centers" not in metrics


def test_compute_chiral_metrics_routes_to_non_polymer_category():
    """A center made only of non-polymer atoms is reported under ``non_polymer_*``."""
    atoms = _stack(_CENTER, [False] * 4)
    # The same stack serves as both prediction and ground truth.
    metrics = compute_chiral_metrics(atoms, atoms, chiral_feats=_FEATS_CORRECT)

    assert metrics["non_polymer_n_chiral_centers"] == 1
    assert "polymer_n_chiral_centers" not in metrics


def test_compute_chiral_metrics_wrong_sign_is_zero_percent():
    """Flipping the ideal sign makes the single polymer center count as incorrect."""
    # Same atom indices as the correct features, but with the ideal sign flipped.
    flipped_feats = torch.tensor([[0.0, 1.0, 2.0, 3.0, -1.0]])
    atoms = _stack(_CENTER, [True] * 4)

    metrics = compute_chiral_metrics(atoms, atoms, chiral_feats=flipped_feats)

    assert metrics["polymer_percent_correct_chirality"] == 0.0


def test_compute_chiral_metrics_nan_ground_truth_coord_drops_center():
    """A NaN ground-truth coordinate on a center atom leaves nothing to score."""
    polymer_flags = [True] * 4
    predicted = _stack(_CENTER, polymer_flags)
    ground_truth = _stack(_CENTER, polymer_flags)
    # Mark one center atom as unresolved in the ground truth.
    ground_truth.coord[0, 3, :] = np.nan

    metrics = compute_chiral_metrics(
        predicted, ground_truth, chiral_feats=_FEATS_CORRECT
    )

    # The center has an unresolved atom -> no scorable center in either category.
    assert metrics == {}
Loading
Loading