diff --git a/models/rf3/tests/test_af3_loss_symmetry.py b/models/rf3/tests/test_af3_loss_symmetry.py new file mode 100644 index 00000000..45977c64 --- /dev/null +++ b/models/rf3/tests/test_af3_loss_symmetry.py @@ -0,0 +1,136 @@ +"""Unit tests for the symmetry-resolution geometry in ``rf3.loss.af3_losses``. + +Both helpers below are the load-bearing machinery behind ``SubunitSymmetryResolution`` +and ``ResidueSymmetryResolution`` (used by ``rf3.symmetry.resolve`` / +``rf3.trainers.rf3``), which re-label the ground-truth coordinates to the symmetry +copy / automorphism that best matches the prediction before the loss is taken. + +- ``SubunitSymmetryResolution._rms_align`` is a batched Kabsch fit. Given predicted + coordinates ``X_fixed`` (``Nbatch x L x 3``) and candidate native copies ``X_moving`` + (``Nambig x L x 3``) it returns ``(u_moving, R, u_fixed)`` such that + ``(X_moving - u_moving) @ R + u_fixed`` lands the native on the prediction — the exact + transform ``_resolve_subunits`` then applies to the native centres of mass. The SVD is + sign-corrected so ``R`` is always a proper rotation (``det = +1``), never a reflection. +- ``ResidueSymmetryResolution._get_best`` picks, per model in the batch, the atom + automorphism (a permutation of a set of interchangeable atom indices) whose + intra-structure distance pattern best matches the prediction, then rewrites the native + coordinates / mask at those positions to that permutation. + +All tests run in float32 (production dtype); ``_rms_align`` builds its sign-correction +matrix with an un-typed ``torch.eye`` that only matches float32 inputs (see the roadmap). +""" + +import pytest +import torch +from rf3.loss.af3_losses import ResidueSymmetryResolution, SubunitSymmetryResolution + +# A non-degenerate, non-coplanar point cloud to align (L = 5). +_POINTS = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + [1.0, 1.0, 1.0], + [-1.0, 2.0, -1.0], + ] +) + + +def _rotation_z(theta: float) -> torch.Tensor: + """Proper rotation (det +1) about the z-axis by ``theta`` radians.""" + c, s = torch.cos(torch.tensor(theta)), torch.sin(torch.tensor(theta)) + return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]) + + +# --- SubunitSymmetryResolution._rms_align ----------------------------------- + + +def test_rms_align_recovers_a_rigid_transform(): + # X_moving is X_fixed pushed through a known rotation + translation; the returned + # transform must land it back on X_fixed. + fixed = _POINTS[None] # (Nbatch=1, L, 3) + moving = (_POINTS @ _rotation_z(0.7) + torch.tensor([2.0, -1.0, 3.0]))[None] + + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, moving) + + aligned = (moving[0] - u_moving[0, 0]) @ R[0, 0] + u_fixed[0, 0] + assert torch.allclose(aligned, _POINTS, atol=1e-4) + # Sign-corrected SVD → a proper rotation, not a reflection. + assert torch.linalg.det(R[0, 0]).item() == pytest.approx(1.0, abs=1e-4) + + +def test_rms_align_identity_when_moving_equals_fixed(): + fixed = _POINTS[None] + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, _POINTS[None]) + + assert torch.allclose(R[0, 0], torch.eye(3), atol=1e-4) + assert torch.allclose(u_moving[0, 0], u_fixed[0, 0], atol=1e-4) + aligned = (_POINTS - u_moving[0, 0]) @ R[0, 0] + u_fixed[0, 0] + assert torch.allclose(aligned, _POINTS, atol=1e-4) + + +def test_rms_align_corrects_reflection_to_proper_rotation(): + # A mirror image cannot be rotated onto the original; the sign correction must still + # return a proper rotation (det +1) rather than the optimal-but-improper reflection. + fixed = _POINTS[None] + reflected = (_POINTS @ torch.diag(torch.tensor([1.0, 1.0, -1.0])))[None] + + _, R, _ = SubunitSymmetryResolution()._rms_align(fixed, reflected) + + assert torch.linalg.det(R[0, 0]).item() == pytest.approx(1.0, abs=1e-4) + + +def test_rms_align_output_shapes_broadcast_over_ambig_and_batch(): + # u_moving is per-ambiguity, u_fixed is per-batch, R is the full cross product — the + # broadcast-ready shapes _resolve_subunits relies on. + fixed = _POINTS[None].repeat(3, 1, 1) # Nbatch = 3 + moving = _POINTS[None].repeat(2, 1, 1) # Nambig = 2 + + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, moving) + + assert u_moving.shape == (2, 1, 3) + assert R.shape == (2, 3, 3, 3) + assert u_fixed.shape == (1, 3, 3) + + +# --- ResidueSymmetryResolution._get_best ------------------------------------ + + +def _get_best_inputs(pred_sym, native_sym, context): + """Build (x_pred, x_native, x_native_mask, a_i) for a 1-model, 3-atom case. + + Atoms 0 and 1 are interchangeable (the automorphism set); atom 2 is fixed context + whose distances to atoms 0/1 break the tie. ``a_i`` offers the identity ordering + ``[0, 1]`` and the swap ``[1, 0]``. + """ + x_pred = torch.tensor([[*pred_sym, context]]) # (1, 3, 3) + x_native = torch.tensor([[*native_sym, context]]) + mask = torch.ones(1, 3, dtype=torch.bool) + a_i = torch.tensor([[0, 1], [1, 0]]) + return x_pred, x_native, mask, a_i + + +def test_get_best_selects_the_swap_that_matches_prediction(): + # Native has atoms 0/1 swapped relative to the prediction; _get_best must undo it so + # the native's interchangeable atoms line up with the prediction's arrangement. + pred_sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]] + native_sym = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]] # swapped + context = [0.0, 3.0, 0.0] + x_pred, x_native, mask, a_i = _get_best_inputs(pred_sym, native_sym, context) + + out_native, _ = ResidueSymmetryResolution()._get_best(x_pred, x_native, mask, a_i) + + assert torch.allclose(out_native[0, 0], torch.tensor([0.0, 0.0, 0.0])) + assert torch.allclose(out_native[0, 1], torch.tensor([2.0, 0.0, 0.0])) + + +def test_get_best_leaves_already_matching_native_unchanged(): + # Native already matches the prediction → the identity ordering wins → no rewrite. + sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]] + context = [0.0, 3.0, 0.0] + x_pred, x_native, mask, a_i = _get_best_inputs(sym, sym, context) + before = x_native.clone() + + out_native, _ = ResidueSymmetryResolution()._get_best(x_pred, x_native, mask, a_i) + + assert torch.allclose(out_native, before) diff --git a/models/rf3/tests/test_attention.py b/models/rf3/tests/test_attention.py new file mode 100644 index 00000000..7328782d --- /dev/null +++ b/models/rf3/tests/test_attention.py @@ -0,0 +1,58 @@ +"""Unit tests for rf3.model.layers.attention triangle-update blocks (vanilla path). + +Both blocks map a pair representation ``[B, L, L, d_pair]`` to the same shape (for a +residual add) and, on CPU, run their vanilla PyTorch path (cuEquivariance is off). + +- ``TriangleAttention`` zero-initialises its output projection ``to_out`` so the block is + the identity (output 0) at the start of training; ``start_node`` switches the attention + axis (rows vs. transposed columns) but preserves the output shape either way. +- ``TriangleMultiplication`` validates its ``direction`` (``"outgoing"``/``"incoming"``) + and, when the cuEquivariance kernel is requested, requires ``d_pair == d_hidden``; the + vanilla path lifts that constraint. +""" + +import pytest +import torch +from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication + +# --- TriangleAttention ------------------------------------------------------ + + +def test_triangle_attention_preserves_shape(): + pair = torch.randn(1, 6, 6, 8) + for start_node in (True, False): + layer = TriangleAttention(d_pair=8, n_head=2, d_hidden=4, start_node=start_node) + assert layer(pair).shape == pair.shape + + +def test_triangle_attention_zero_initialized(): + torch.manual_seed(0) + layer = TriangleAttention(d_pair=8, n_head=2, d_hidden=4) + # to_out is zero-initialised so the residual add starts as the identity. + assert bool((layer.to_out.weight == 0).all()) and bool( + (layer.to_out.bias == 0).all() + ) + assert bool((layer(torch.randn(1, 6, 6, 8)) == 0).all()) + + +# --- TriangleMultiplication ------------------------------------------------- + + +def test_triangle_multiplication_preserves_shape(): + pair = torch.randn(1, 6, 6, 8) + for direction in ("outgoing", "incoming"): + # Vanilla path lifts the d_pair == d_hidden cuEquivariance constraint. + layer = TriangleMultiplication( + d_pair=8, d_hidden=4, direction=direction, use_cuequivariance=False + ) + assert layer(pair).shape == pair.shape + + +def test_triangle_multiplication_rejects_invalid_direction(): + with pytest.raises(ValueError, match="direction must be 'outgoing' or 'incoming'"): + TriangleMultiplication(d_pair=8, direction="sideways", use_cuequivariance=False) + + +def test_triangle_multiplication_cuequivariance_requires_matching_dims(): + with pytest.raises(AssertionError, match="requires d_pair == d_hidden"): + TriangleMultiplication(d_pair=8, d_hidden=4, use_cuequivariance=True) diff --git a/models/rf3/tests/test_chiral.py b/models/rf3/tests/test_chiral.py index d4765d54..5d0b15bd 100644 --- a/models/rf3/tests/test_chiral.py +++ b/models/rf3/tests/test_chiral.py @@ -1,20 +1,30 @@ -"""Unit tests for rf3.metrics.chiral.calc_chiral_metrics_masked. - -The pure tensor core behind RF3's chirality metrics. The non-obvious contracts -pinned here: a chiral center is the dihedral of four atoms whose ideal angle is -the 5th column of ``chirals``; correctness is a *sign* match between the -predicted and ideal dihedral; a center counts only if all four of its atoms are -unmasked; duplicate rows that share a first-atom index (alternate orderings of -the same center) are collapsed to the first occurrence; and the empty cases -(no chirals, or a fully-masked structure) return ``{}``. +"""Unit tests for rf3.metrics.chiral. + +``calc_chiral_metrics_masked`` is the pure tensor core behind RF3's chirality +metrics. The non-obvious contracts pinned here: a chiral center is the dihedral +of four atoms whose ideal angle is the 5th column of ``chirals``; correctness is +a *sign* match between the predicted and ideal dihedral; a center counts only if +all four of its atoms are unmasked; duplicate rows that share a first-atom index +(alternate orderings of the same center) are collapsed to the first occurrence; +and the empty cases (no chirals, or a fully-masked structure) return ``{}``. + +``compute_chiral_metrics`` wraps that core: given predicted / ground-truth +``AtomArrayStack``s and (optionally) precomputed chiral features, it splits the +atoms into polymer vs non-polymer, drops ground-truth atoms with NaN +coordinates, and emits ``{category}_{n_chiral_centers,chiral_loss_mean, +percent_correct_chirality}`` keys only for categories that contain a scorable +center. Passing ``chiral_feats`` explicitly bypasses the rdkit/atomworks +feature generation, so the orchestration is exercised here on tiny fixtures. """ import math +import numpy as np import pytest import torch +from biotite.structure import AtomArrayStack from rf3.kinematics import get_dih -from rf3.metrics.chiral import calc_chiral_metrics_masked +from rf3.metrics.chiral import calc_chiral_metrics_masked, compute_chiral_metrics def _chiral_center_coords(theta: float) -> list[list[float]]: @@ -115,3 +125,58 @@ def test_batch_dimension_shapes(): assert result["chiral_loss_mean"].shape == (2,) assert result["percent_correct_chirality"].shape == (2,) + + +# --- compute_chiral_metrics ------------------------------------------------- + +# A single chiral center (atoms 0-3) whose predicted dihedral is +pi/2; ideal +1 -> a +# sign match, ideal -1 -> a mismatch. +_CENTER = [_chiral_center_coords(math.pi / 2)] # (1 model, 4 atoms, 3) +_FEATS_CORRECT = torch.tensor([[0.0, 1.0, 2.0, 3.0, 1.0]]) + + +def _stack(coords, is_polymer) -> AtomArrayStack: + """AtomArrayStack from (D, L, 3) coords + a per-atom is_polymer flag list.""" + coord = np.asarray(coords, dtype=np.float32) + stack = AtomArrayStack(depth=coord.shape[0], length=coord.shape[1]) + stack.coord = coord + stack.set_annotation("is_polymer", np.asarray(is_polymer, dtype=bool)) + return stack + + +def test_compute_chiral_metrics_polymer_center_correct(): + stack = _stack(_CENTER, [True] * 4) + out = compute_chiral_metrics(stack, stack, chiral_feats=_FEATS_CORRECT) + + assert out["polymer_n_chiral_centers"] == 1 + assert out["polymer_percent_correct_chirality"] == 1.0 + assert "polymer_chiral_loss_mean" in out + # No non-polymer atoms -> that category is absent entirely. + assert "non_polymer_n_chiral_centers" not in out + + +def test_compute_chiral_metrics_routes_to_non_polymer_category(): + stack = _stack(_CENTER, [False] * 4) + out = compute_chiral_metrics(stack, stack, chiral_feats=_FEATS_CORRECT) + + assert out["non_polymer_n_chiral_centers"] == 1 + assert "polymer_n_chiral_centers" not in out + + +def test_compute_chiral_metrics_wrong_sign_is_zero_percent(): + stack = _stack(_CENTER, [True] * 4) + feats_wrong = torch.tensor([[0.0, 1.0, 2.0, 3.0, -1.0]]) # ideal sign flipped + out = compute_chiral_metrics(stack, stack, chiral_feats=feats_wrong) + + assert out["polymer_percent_correct_chirality"] == 0.0 + + +def test_compute_chiral_metrics_nan_ground_truth_coord_drops_center(): + pred = _stack(_CENTER, [True] * 4) + gt = _stack(_CENTER, [True] * 4) + gt.coord[0, 3, :] = np.nan # one center atom unresolved in the ground truth + + out = compute_chiral_metrics(pred, gt, chiral_feats=_FEATS_CORRECT) + + # The center has an unresolved atom -> no scorable center in either category. + assert out == {} diff --git a/models/rf3/tests/test_inference_sampler.py b/models/rf3/tests/test_inference_sampler.py new file mode 100644 index 00000000..125b6ba2 --- /dev/null +++ b/models/rf3/tests/test_inference_sampler.py @@ -0,0 +1,107 @@ +"""Unit tests for rf3.diffusion_samplers.inference_sampler pure helpers. + +``SampleDiffusion`` runs the AF3 diffusion roll-out; the pure, network-free pieces +pinned here are the noise schedule and the initial-point-cloud sampler: + +- ``_construct_inference_noise_schedule`` builds the AF3 inference schedule + ``t_hat = sigma_data * (s_max**(1/p) + t*(s_min**(1/p) - s_max**(1/p)))**p`` over + ``num_timesteps`` points of ``t`` linearly spaced in ``[min_t, max_t]``. At ``t=0`` it + is ``sigma_data*s_max`` and at ``t=1`` it is ``sigma_data*s_min``, decreasing in between + (AF3 Supplement §3.7.1). +- ``SamplePartialDiffusion`` overrides the schedule to start the roll-out part-way + through, returning the full schedule's tail from index ``partial_t``. +- ``_get_initial_structure`` returns ``c0 * N(0,1) + coords`` — Gaussian noise scaled by + ``c0`` (derived from ``noise_schedule[0]``) added to the coords to be noised. +""" + +import pytest +import torch +from rf3.diffusion_samplers.inference_sampler import ( + SampleDiffusion, + SamplePartialDiffusion, +) + +# AF3 defaults (configs are the source of truth — no defaults in the constructor), +# with a short schedule so the tests stay tiny. +_KW = dict( + num_timesteps=8, + min_t=0, + max_t=1, + sigma_data=16, + s_min=4e-4, + s_max=160, + p=7, + gamma_0=0.8, + gamma_min=1.0, + noise_scale=1.003, + step_scale=1.5, + solver="af3", +) + +_CPU = torch.device("cpu") + + +# --- _construct_inference_noise_schedule ------------------------------------ + + +def test_noise_schedule_length_and_endpoints(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + assert sched.shape == (8,) + # t=0 -> sigma_data*s_max; t=1 -> sigma_data*s_min (the (1/p)/**p cancel at the ends). + assert sched[0].item() == pytest.approx(16 * 160, rel=1e-4) + assert sched[-1].item() == pytest.approx(16 * 4e-4, rel=1e-4) + + +def test_noise_schedule_monotonically_decreasing(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + assert bool((sched[1:] < sched[:-1]).all()) + + +def test_noise_schedule_matches_af3_formula(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + t = torch.linspace(0, 1, 8) + expected = 16 * (160 ** (1 / 7) + t * (4e-4 ** (1 / 7) - 160 ** (1 / 7))) ** 7 + assert torch.allclose(sched, expected, atol=1e-6) + + +def test_noise_schedule_is_flat_when_s_min_equals_s_max(): + # With s_min == s_max the t-dependent term vanishes, so every step is sigma_data*s. + sched = SampleDiffusion( + **{**_KW, "s_min": 5.0, "s_max": 5.0} + )._construct_inference_noise_schedule(_CPU) + assert torch.allclose(sched, torch.full((8,), 16 * 5.0), atol=1e-4) + + +# --- SamplePartialDiffusion._construct_inference_noise_schedule -------------- + + +def test_partial_schedule_is_tail_of_full(): + full = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + partial = SamplePartialDiffusion( + partial_t=3, **_KW + )._construct_inference_noise_schedule(_CPU) + assert partial.shape == (8 - 3,) + assert torch.equal(partial, full[3:]) + + +def test_partial_schedule_rejects_t_at_or_above_num_timesteps(): + sampler = SamplePartialDiffusion(partial_t=8, **_KW) + with pytest.raises(AssertionError, match="must be less than num_timesteps"): + sampler._construct_inference_noise_schedule(_CPU) + + +# --- _get_initial_structure ------------------------------------------------- + + +def test_initial_structure_shape_and_zero_scale(): + coords = torch.randn(4, 6, 3) + sampler = SampleDiffusion(**_KW) + out = sampler._get_initial_structure( + torch.tensor(2.0), D=4, L=6, coord_atom_lvl_to_be_noised=coords + ) + assert out.shape == (4, 6, 3) + # c0 == 0 zeroes the noise term, so the result is exactly the coords to be noised. + zero_scale = sampler._get_initial_structure( + torch.tensor(0.0), D=4, L=6, coord_atom_lvl_to_be_noised=coords + ) + assert torch.equal(zero_scale, coords) diff --git a/models/rf3/tests/test_layer_utils.py b/models/rf3/tests/test_layer_utils.py new file mode 100644 index 00000000..b187b712 --- /dev/null +++ b/models/rf3/tests/test_layer_utils.py @@ -0,0 +1,148 @@ +"""Unit tests for rf3.model.layers.layer_utils shape helpers. + +These are the structural building blocks the diffusion stack composes: + +- ``MultiDimLinear`` is an ``nn.Linear`` whose output is reshaped from a flat + ``prod(out_shape)`` vector back into ``x.shape[:-1] + out_shape``; its weight is + re-initialised with Xavier-uniform (overriding ``nn.Linear``'s default). +- ``Transition`` is a SwiGLU feed-forward block: ``linear_3(silu(linear_1(LN(X))) * + linear_2(LN(X)))``, all projections bias-free, output width equal to the input. +- ``AdaLN`` is adaptive layer-norm — affine-free LayerNorm of the content ``Ai`` + modulated by a sigmoid gain and a bias, both linear projections of LayerNorm'd + conditioning ``Si``: ``sigmoid(W_g·LN(Si)) * LN_affine_free(Ai) + W_b·LN(Si)``. +- ``create_batch_dimension_if_not_present(n)`` decorates a function expecting an + ``n``-dim batched arg so it also accepts an ``(n-1)``-dim unbatched arg, inserting a + singleton batch dim before the call and stripping it from the result afterwards. +""" + +import math + +import pytest +import torch +import torch.nn.functional as F +from rf3.model.layers.layer_utils import ( + AdaLN, + MultiDimLinear, + Transition, + create_batch_dimension_if_not_present, +) + +# --- MultiDimLinear --------------------------------------------------------- + + +def test_multidim_linear_reshapes_output_to_out_shape(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + # Leading dims of the input are preserved; the feature dim becomes out_shape. + assert layer(torch.randn(2, 5, 8)).shape == (2, 5, 3, 4) + assert layer(torch.randn(7, 8)).shape == (7, 3, 4) + # Underlying Linear projects to the flattened width. + assert layer.out_features == 12 + assert layer.weight.shape == (12, 8) + + +def test_multidim_linear_is_flat_linear_then_reshape(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + x = torch.randn(2, 5, 8) + expected = F.linear(x, layer.weight, layer.bias).reshape(2, 5, 3, 4) + assert torch.allclose(layer(x), expected, atol=1e-6) + + +def test_multidim_linear_weight_is_xavier_bounded(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + # Xavier-uniform draws from [-bound, bound], bound = sqrt(6 / (fan_in + fan_out)). + bound = math.sqrt(6.0 / (8 + 12)) + assert layer.weight.abs().max().item() <= bound + + +# --- Transition ------------------------------------------------------------- + + +def test_transition_preserves_channel_width(): + torch.manual_seed(0) + block = Transition(n=2, c=6) + assert block(torch.randn(2, 7, 6)).shape == (2, 7, 6) + # Projections are bias-free and the hidden width is n*c. + assert block.linear_1.bias is None + assert block.linear_3.bias is None + assert block.linear_1.weight.shape == (12, 6) + assert block.linear_3.weight.shape == (6, 12) + + +def test_transition_matches_swiglu_gating(): + torch.manual_seed(0) + block = Transition(n=2, c=6) + x = torch.randn(2, 7, 6) + ln = block.layer_norm_1(x) + expected = block.linear_3(F.silu(block.linear_1(ln)) * block.linear_2(ln)) + assert torch.allclose(block(x), expected, atol=1e-6) + + +# --- AdaLN ------------------------------------------------------------------ + + +def test_adaln_output_shape_and_affine_free_content_norm(): + block = AdaLN(c_a=6, c_s=4) + out = block(torch.randn(2, 5, 6), torch.randn(2, 5, 4)) + assert out.shape == (2, 5, 6) + # Content LayerNorm is affine-free; the conditioning LayerNorm drops its bias. + assert block.ln_a.weight is None and block.ln_a.bias is None + assert block.ln_s.bias is None + + +def test_adaln_matches_gain_bias_modulation(): + torch.manual_seed(0) + block = AdaLN(c_a=6, c_s=4) + Ai, Si = torch.randn(2, 5, 6), torch.randn(2, 5, 4) + s = block.ln_s(Si) + gain = block.to_gain(s) + expected = gain * block.ln_a(Ai) + block.to_bias(s) + assert torch.allclose(block(Ai, Si), expected, atol=1e-6) + # The gain is a sigmoid, so modulation is bounded to (0, 1). + assert (gain > 0).all() and (gain < 1).all() + + +# --- create_batch_dimension_if_not_present ---------------------------------- + + +def test_batch_dim_inserted_and_stripped_for_unbatched_arg(): + seen = {} + + @create_batch_dimension_if_not_present(3) + def double(z): + seen["ndim"] = z.ndim + return z * 2 + + x = torch.randn(5, 8) + out = double(x) + # The wrapped function sees a 3-D arg, but the singleton batch dim is stripped back off. + assert seen["ndim"] == 3 + assert out.shape == (5, 8) + assert torch.equal(out, x * 2) + + +def test_batch_dim_passes_through_already_batched_arg(): + seen = {} + + @create_batch_dimension_if_not_present(3) + def double(z): + seen["ndim"] = z.ndim + return z * 2 + + x = torch.randn(2, 5, 8) + out = double(x) + assert seen["ndim"] == 3 + assert out.shape == (2, 5, 8) + assert torch.equal(out, x * 2) + + +def test_batch_dim_rejects_wrong_rank(): + @create_batch_dimension_if_not_present(3) + def identity(z): + return z + + # ndim 1 is neither the batched (3) nor unbatched (2) rank. + with pytest.raises(Exception, match="must have 2 or 3 dimensions"): + identity(torch.randn(8)) diff --git a/models/rf3/tests/test_loss_grads.py b/models/rf3/tests/test_loss_grads.py new file mode 100644 index 00000000..193c03c8 --- /dev/null +++ b/models/rf3/tests/test_loss_grads.py @@ -0,0 +1,110 @@ +"""Unit tests for the closed-form chiral/dihedral gradients in rf3.loss.loss. + +``calc_ddihedralmse_dxyz`` returns the analytic gradient of the summed dihedral +loss ``sum_i (dihedral_i - true_dih_i)**2`` with respect to the four atoms +``a, b, c, d`` of each dihedral — a hand-derived replacement for autograd. Inputs +are ``(leading, K, 3)`` with one ``true_dih`` entry per dihedral ``K``; the leading +dim is broadcast over. ``calc_chiral_grads_flat_impl`` evaluates that gradient for +a set of chiral centres (each four atom indices into ``xyz``) and scatters the +per-centre gradients back onto the full atom tensor with ``index_add_`` — so atoms +shared between centres accumulate, and the optional ``no_grad_on_chiral_center`` +flag drops the gradient on each centre's first atom. + +All tests use float32, the production coordinate dtype: the closed form builds an +``eye(3)`` without an explicit dtype and raises on float64 inputs. +""" + +import torch +from rf3.loss.loss import calc_chiral_grads_flat_impl, calc_ddihedralmse_dxyz + + +def _dihedral(a, b, c, d, eps=1e-6): + """The forward that the closed-form gradient differentiates (same eps as the source).""" + b0, b1, b2 = a - b, c - b, d - c + b1n = b1 / (b1.norm(dim=-1, keepdim=True) + eps) + v = b0 - (b0 * b1n).sum(-1, keepdim=True) * b1n + w = b2 - (b2 * b1n).sum(-1, keepdim=True) * b1n + x = (v * w).sum(-1) + y = (torch.cross(b1n, v, dim=-1) * w).sum(-1) + return torch.atan2(y + eps, x + eps) + + +# --- calc_ddihedralmse_dxyz ------------------------------------------------- + + +def test_ddihedralmse_matches_autograd(): + torch.manual_seed(0) + a, b, c, d = (torch.randn(1, 4, 3, requires_grad=True) for _ in range(4)) + true = torch.randn(4) + loss = ((_dihedral(a, b, c, d) - true) ** 2).sum() + ga, gb, gc, gd = torch.autograd.grad(loss, [a, b, c, d]) + expected = torch.stack([ga, gb, gc, gd], dim=-2) # (1, K, 4 atoms, 3 coords) + grads = calc_ddihedralmse_dxyz(a.detach(), b.detach(), c.detach(), d.detach(), true) + torch.testing.assert_close(grads, expected, atol=1e-3, rtol=1e-3) + + +def test_ddihedralmse_zero_gradient_at_truth(): + torch.manual_seed(1) + a, b, c, d = (torch.randn(1, 3, 3) for _ in range(4)) + true = _dihedral(a, b, c, d).reshape(3) # the truth is the actual dihedral + grads = calc_ddihedralmse_dxyz(a, b, c, d, true) + # dmse/ddih = 2*(dih - true) is exactly 0, so every coordinate gradient is 0. + assert torch.all(grads == 0.0) + + +def test_ddihedralmse_preserves_leading_shape(): + a, b, c, d = (torch.randn(2, 5, 3) for _ in range(4)) + grads = calc_ddihedralmse_dxyz(a, b, c, d, torch.randn(5)) + assert grads.shape == (2, 5, 4, 3) # leading dims + (4 atoms, 3 coords) + + +# --- calc_chiral_grads_flat_impl -------------------------------------------- + + +def test_chiral_grads_empty_centers_returns_zeros(): + xyz = torch.randn(1, 7, 3) + grads = calc_chiral_grads_flat_impl( + xyz, torch.zeros(0, 4, dtype=torch.long), torch.zeros(0), False + ) + assert grads.shape == xyz.shape + assert torch.all(grads == 0.0) + + +def test_chiral_grads_only_center_atoms_receive_gradient(): + torch.manual_seed(2) + xyz = torch.randn(1, 7, 3) + centers = torch.tensor([[1, 3, 4, 5]]) + grads = calc_chiral_grads_flat_impl(xyz, centers, torch.tensor([0.7]), False) + # Atoms outside the centre stay exactly zero; the four centre atoms get gradient. + assert torch.all(grads[:, [0, 2, 6]] == 0.0) + assert grads[:, [1, 3, 4, 5]].abs().sum() > 0 + + +def test_chiral_grads_accumulate_for_shared_atoms(): + torch.manual_seed(3) + xyz = torch.randn(1, 8, 3) + c1 = torch.tensor([[0, 1, 2, 3]]) + c2 = torch.tensor([[2, 4, 5, 6]]) # shares atom 2 with c1 + a1, a2 = torch.tensor([0.3]), torch.tensor([1.1]) + both = calc_chiral_grads_flat_impl( + xyz, torch.cat([c1, c2]), torch.cat([a1, a2]), False + ) + g1 = calc_chiral_grads_flat_impl(xyz, c1, a1, False) + g2 = calc_chiral_grads_flat_impl(xyz, c2, a2, False) + # index_add_ accumulates, so the combined gradient is the per-centre sum — + # including the shared atom 2, which gets a contribution from both centres. + torch.testing.assert_close(both, g1 + g2) + + +def test_chiral_grads_no_grad_on_center_zeroes_first_atom(): + torch.manual_seed(4) + xyz = torch.randn(1, 7, 3) + centers = torch.tensor([[1, 3, 4, 5]]) # atom 1 is the chiral centre + angles = torch.tensor([0.7]) + with_grad = calc_chiral_grads_flat_impl(xyz, centers, angles, False) + no_grad = calc_chiral_grads_flat_impl(xyz, centers, angles, True) + # The flag zeroes the gradient on the centre atom (the first of the four)... + assert with_grad[:, 1].abs().sum() > 0 + assert torch.all(no_grad[:, 1] == 0.0) + # ...and leaves the other three atoms untouched. + torch.testing.assert_close(no_grad[:, [3, 4, 5]], with_grad[:, [3, 4, 5]]) diff --git a/models/rf3/tests/test_mlff.py b/models/rf3/tests/test_mlff.py new file mode 100644 index 00000000..1d0faf97 --- /dev/null +++ b/models/rf3/tests/test_mlff.py @@ -0,0 +1,45 @@ +"""Unit tests for rf3.model.layers.mlff.ConformerEmbeddingWeightedAverage. + +The module compresses per-conformer atom-level embeddings ``[n_conformers, n_atom, d]`` +into a single per-atom embedding ``[n_atom, c_atom]``: a shared MLP downcasts each +conformer's features to ``c_atompair``, the conformers are flattened into one vector per +atom, and a final (bias-free, zero-initialised) linear projects to ``c_atom``. The +zero-init makes the block a no-op at the start of training (output ≈ 0, for a clean +residual add). The forward also pins two input contracts: the conformer count must match +``n_conformers`` exactly, and an over-wide feature dim is truncated to +``atom_level_embedding_dim`` while an under-wide one is rejected. +""" + +import pytest +import torch +from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage + + +def _layer(): + return ConformerEmbeddingWeightedAverage( + atom_level_embedding_dim=16, c_atompair=4, c_atom=8, n_conformers=3 + ) + + +def test_output_shape_and_zero_initialized(): + layer = _layer() + out = layer(torch.randn(3, 5, 16)) # [n_conformers, n_atom, d] + assert out.shape == (5, 8) # [n_atom, c_atom] + # The final projection is zero-initialised, so the block contributes nothing at init. + assert bool((layer.conformers_to_atom_single_embedding[0].weight == 0).all()) + assert bool((out == 0).all()) + + +def test_subsets_oversized_feature_dim(): + # A feature dim wider than atom_level_embedding_dim is truncated to it, not rejected. + assert _layer()(torch.randn(3, 5, 24)).shape == (5, 8) + + +def test_rejects_undersized_feature_dim(): + with pytest.raises(ValueError, match="is less than the expected dimension"): + _layer()(torch.randn(3, 5, 8)) + + +def test_rejects_wrong_conformer_count(): + with pytest.raises(AssertionError, match="Number of conformers must be consistent"): + _layer()(torch.randn(2, 5, 16)) diff --git a/models/rf3/tests/test_outer_product.py b/models/rf3/tests/test_outer_product.py new file mode 100644 index 00000000..5f4a59e9 --- /dev/null +++ b/models/rf3/tests/test_outer_product.py @@ -0,0 +1,53 @@ +"""Unit tests for rf3.model.layers.outer_product. + +``OuterProductMean`` / ``OuterProductMean_AF3`` turn an MSA embedding ``[B, N, L, c]`` +into a pair representation ``[B, L, L, c_out]``: LayerNorm the MSA, project to a small +hidden width left/right, take the outer product of the two projections over the hidden +dims and average it across the ``N`` sequence rows (the ``/N`` and the ``einsum`` sum +over ``s`` together form the mean), then project to ``c_out``. ``OuterProductMean`` +zero-initialises ``proj_out`` (so it is a no-op at init, the AF-style "start from zero" +trick); the AF3 variant does not. +""" + +import torch +from rf3.model.layers.outer_product import OuterProductMean, OuterProductMean_AF3 + + +def test_outer_product_mean_zero_initialized_output(): + torch.manual_seed(0) + layer = OuterProductMean(d_msa=8, d_pair=5, d_hidden=4) + out = layer(torch.randn(2, 3, 6, 8)) # [B, N, L, d_msa] + assert out.shape == (2, 6, 6, 5) # [B, L, L, d_pair] + # proj_out is zero-initialised, so the block contributes nothing until trained. + assert bool((out == 0).all()) + assert bool((layer.proj_left.bias == 0).all()) + assert bool((layer.proj_right.bias == 0).all()) + + +def test_outer_product_af3_output_shape(): + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + assert layer(torch.randn(2, 3, 6, 8)).shape == (2, 6, 6, 5) + + +def test_outer_product_af3_matches_mean_einsum(): + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + msa = torch.randn(2, 3, 6, 8) + B, N, L = msa.shape[:3] + normed = layer.norm(msa) + left = layer.proj_left(normed) + right = layer.proj_right(normed) / float(N) + expected = layer.proj_out( + torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1) + ) + assert torch.allclose(layer(msa), expected, atol=1e-6) + + +def test_outer_product_mean_is_invariant_to_duplicate_rows(): + # The /N normalisation makes the block a true mean over sequence rows: N identical + # rows give the same output as a single copy of that row. + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + row = torch.randn(2, 1, 6, 8) + assert torch.allclose(layer(row), layer(row.repeat(1, 4, 1, 1)), atol=1e-6) diff --git a/models/rf3/tests/test_predicted_error.py b/models/rf3/tests/test_predicted_error.py index 806e636b..2553741a 100644 --- a/models/rf3/tests/test_predicted_error.py +++ b/models/rf3/tests/test_predicted_error.py @@ -1,7 +1,9 @@ """Unit tests for rf3.metrics.predicted_error pure helpers. -Covers ``compute_ptm`` (the PAE -> predicted-TM reduction) and the thin -``ComputePTM`` Metric wrapper that reshapes the per-batch scores into a dict. +Covers ``compute_ptm`` (the PAE -> predicted-TM reduction), the thin +``ComputePTM`` Metric wrapper that reshapes the per-batch scores into a dict, +and ``ComputeIPTM`` which scores the *inter-chain* interfaces (overall, plus +protein-protein / protein-ligand / ligand-ligand sub-interfaces). ``compute_ptm`` takes a per-pair distribution over distance-error bins ``pae`` of shape ``[D, I, I, n_bins]``, softmaxes over the bins, weights each @@ -9,11 +11,13 @@ nearest bin), averages over the columns selected by ``to_calculate``, and returns the per-token maximum -> ``[D]``. So a distribution concentrated on the nearest bin scores highest, a uniform one scores the mean weight, and one on -the farthest bin scores lowest; all scores lie in ``(0, 1]``. +the farthest bin scores lowest; all scores lie in ``(0, 1]``. An empty +``to_calculate`` (no selected columns) averages nothing and scores 0. """ +import pytest import torch -from rf3.metrics.predicted_error import ComputePTM, compute_ptm +from rf3.metrics.predicted_error import ComputeIPTM, ComputePTM, compute_ptm D, I, N_BINS = 2, 5, 64 @@ -78,3 +82,62 @@ def test_compute_ptm_metric_returns_per_batch_dict(): assert set(out) == {f"ptm_{i}" for i in range(D)} assert all(0.0 < v <= 1.0 for v in out.values()) + + +# --- ComputeIPTM ------------------------------------------------------------ + +# Two chains of two tokens each; only the inter-chain pairs (asym 0 vs asym 1) are scored. +_TWO_CHAINS = torch.tensor([0, 0, 1, 1]) +_IPTM_FAMILIES = ( + "iptm", + "iptm_protein_protein", + "iptm_protein_ligand", + "iptm_ligand_ligand", +) + + +def _uniform_pae(n: int) -> torch.Tensor: + """Flat `[D, n, n, N_BINS]` logits -> uniform per-pair bin distribution.""" + return torch.zeros(D, n, n, N_BINS) + + +def test_iptm_keys_cover_every_interface_type_per_model(): + out = ComputeIPTM().compute( + pae=_uniform_pae(4), asym_id=_TWO_CHAINS, is_ligand=torch.tensor([0, 0, 1, 1]) + ) + + assert set(out) == {f"{fam}_{i}" for fam in _IPTM_FAMILIES for i in range(D)} + + +def test_iptm_all_protein_zeroes_ligand_interfaces(): + # No ligand tokens -> the protein-ligand and ligand-ligand masks are empty (score 0), + # and the protein-protein interface covers exactly the inter-chain pairs == overall iPTM. + out = ComputeIPTM().compute( + pae=_uniform_pae(4), + asym_id=_TWO_CHAINS, + is_ligand=torch.zeros(4, dtype=torch.long), + ) + + for i in range(D): + assert out[f"iptm_protein_ligand_{i}"] == 0.0 + assert out[f"iptm_ligand_ligand_{i}"] == 0.0 + assert out[f"iptm_protein_protein_{i}"] == pytest.approx(out[f"iptm_{i}"]) + + +def test_iptm_single_chain_scores_zero(): + # One chain -> no inter-chain pairs -> nothing selected -> every interface scores 0. + out = ComputeIPTM().compute( + pae=_uniform_pae(4), + asym_id=torch.zeros(4, dtype=torch.long), + is_ligand=torch.tensor([0, 0, 1, 1]), + ) + + assert all(v == 0.0 for v in out.values()) + + +def test_iptm_values_in_unit_range(): + out = ComputeIPTM().compute( + pae=_uniform_pae(4), asym_id=_TWO_CHAINS, is_ligand=torch.tensor([0, 1, 0, 1]) + ) + + assert all(0.0 <= v <= 1.0 for v in out.values()) diff --git a/models/rf3/tests/test_util_module.py b/models/rf3/tests/test_util_module.py new file mode 100644 index 00000000..5cd0c102 --- /dev/null +++ b/models/rf3/tests/test_util_module.py @@ -0,0 +1,63 @@ +"""Unit tests for rf3.util_module helpers. + +``rbf`` expands distances into a radial-basis (Gaussian) feature vector over +``D_count`` evenly spaced centres ``D_min .. D_min + (D_count-1)*D_sigma``: the +feature at the centre nearest a distance peaks at 1.0 and falls off as a Gaussian +of width ``D_sigma`` (far centres underflow to 0). ``init_lecun_normal`` replaces a +module's weight with a truncated-normal sample (clamped to ±2 before scaling) +whose post-truncation standard deviation is the Lecun value ``sqrt(scale / fan_in)``. +""" + +import math + +import pytest +import torch +from rf3.util_module import init_lecun_normal, rbf + +# Std of a standard normal truncated to [-2, 2]; the source divides by it so the +# scaled sample's std lands at sqrt(scale / fan_in) rather than below it. +_TRUNC_NORMAL_STD = 0.87962566103423978 + + +# --- rbf -------------------------------------------------------------------- + + +def test_rbf_appends_feature_dim(): + assert rbf(torch.rand(5)).shape == (5, 64) + assert rbf(torch.rand(2, 3), D_count=16).shape == (2, 3, 16) + + +def test_rbf_gaussian_values_at_known_distance(): + # D == D_min lands exactly on centre 0; later centres are D_sigma apart, so the + # features at D=0 are exp(-k**2) for k = 0, 1, 2, ... + vals = rbf(torch.tensor([0.0]), D_min=0.0, D_count=64, D_sigma=0.5)[0] + assert vals[0].item() == 1.0 + assert vals[1].item() == pytest.approx(math.exp(-1)) + assert vals[2].item() == pytest.approx(math.exp(-4)) + # The nearest centre is the argmax (D=0.5 -> centre 1), and values stay in [0, 1]. + assert int(rbf(torch.tensor([0.5]), D_sigma=0.5).argmax()) == 1 + full = rbf(torch.rand(20) * 30) + assert (full >= 0).all() and (full <= 1).all() + + +# --- init_lecun_normal ------------------------------------------------------ + + +def test_init_lecun_normal_returns_module_and_bounded_weight(): + torch.manual_seed(0) + linear = torch.nn.Linear(128, 64, bias=False) + result = init_lecun_normal(linear) + stddev = math.sqrt(1.0 / 128) / _TRUNC_NORMAL_STD # fan_in = 128 + assert result is linear + assert isinstance(linear.weight, torch.nn.Parameter) + assert linear.weight.shape == (64, 128) + # The truncated normal is clamped to ±2 before the stddev scaling. + assert linear.weight.abs().max().item() <= 2 * stddev + + +def test_init_lecun_normal_scales_to_lecun_stddev(): + torch.manual_seed(0) + linear = torch.nn.Linear(512, 512, bias=False) + init_lecun_normal(linear) + # Lecun-normal: post-truncation std ~ sqrt(scale / fan_in). + assert linear.weight.std().item() == pytest.approx(math.sqrt(1 / 512), rel=0.05)