From f51f881e132b6116073c20aee60f4c7cb06b9da0 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 03:25:34 +0000 Subject: [PATCH 01/10] test(rf3): add CPU unit tests for closed-form loss gradients and util_module Cover the pure numeric helpers in rf3.loss.loss and rf3.util_module: - calc_ddihedralmse_dxyz: the hand-derived closed-form dihedral-loss gradient, pinned against torch.autograd of a mirrored forward, plus zero-at-truth and leading-dim shape preservation. - calc_chiral_grads_flat_impl: empty-centres, scatter routing onto only the centre atoms, index_add_ accumulation for shared atoms, and the no_grad_on_chiral_center flag. - rbf: Gaussian distance-encoding (feature dim, exact values, [0,1]). - init_lecun_normal: module returned, Parameter weight, +/-2*stddev bound, and Lecun std sqrt(scale/fan_in). Test-only; both modules are already mypy-clean. Tests run in float32 (the production coordinate dtype). Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_loss_grads.py | 110 +++++++++++++++++++++++++++ models/rf3/tests/test_util_module.py | 63 +++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 models/rf3/tests/test_loss_grads.py create mode 100644 models/rf3/tests/test_util_module.py diff --git a/models/rf3/tests/test_loss_grads.py b/models/rf3/tests/test_loss_grads.py new file mode 100644 index 00000000..193c03c8 --- /dev/null +++ b/models/rf3/tests/test_loss_grads.py @@ -0,0 +1,110 @@ +"""Unit tests for the closed-form chiral/dihedral gradients in rf3.loss.loss. + +``calc_ddihedralmse_dxyz`` returns the analytic gradient of the summed dihedral +loss ``sum_i (dihedral_i - true_dih_i)**2`` with respect to the four atoms +``a, b, c, d`` of each dihedral — a hand-derived replacement for autograd. Inputs +are ``(leading, K, 3)`` with one ``true_dih`` entry per dihedral ``K``; the leading +dim is broadcast over. ``calc_chiral_grads_flat_impl`` evaluates that gradient for +a set of chiral centres (each four atom indices into ``xyz``) and scatters the +per-centre gradients back onto the full atom tensor with ``index_add_`` — so atoms +shared between centres accumulate, and the optional ``no_grad_on_chiral_center`` +flag drops the gradient on each centre's first atom. + +All tests use float32, the production coordinate dtype: the closed form builds an +``eye(3)`` without an explicit dtype and raises on float64 inputs. +""" + +import torch +from rf3.loss.loss import calc_chiral_grads_flat_impl, calc_ddihedralmse_dxyz + + +def _dihedral(a, b, c, d, eps=1e-6): + """The forward that the closed-form gradient differentiates (same eps as the source).""" + b0, b1, b2 = a - b, c - b, d - c + b1n = b1 / (b1.norm(dim=-1, keepdim=True) + eps) + v = b0 - (b0 * b1n).sum(-1, keepdim=True) * b1n + w = b2 - (b2 * b1n).sum(-1, keepdim=True) * b1n + x = (v * w).sum(-1) + y = (torch.cross(b1n, v, dim=-1) * w).sum(-1) + return torch.atan2(y + eps, x + eps) + + +# --- calc_ddihedralmse_dxyz ------------------------------------------------- + + +def test_ddihedralmse_matches_autograd(): + torch.manual_seed(0) + a, b, c, d = (torch.randn(1, 4, 3, requires_grad=True) for _ in range(4)) + true = torch.randn(4) + loss = ((_dihedral(a, b, c, d) - true) ** 2).sum() + ga, gb, gc, gd = torch.autograd.grad(loss, [a, b, c, d]) + expected = torch.stack([ga, gb, gc, gd], dim=-2) # (1, K, 4 atoms, 3 coords) + grads = calc_ddihedralmse_dxyz(a.detach(), b.detach(), c.detach(), d.detach(), true) + torch.testing.assert_close(grads, expected, atol=1e-3, rtol=1e-3) + + +def test_ddihedralmse_zero_gradient_at_truth(): + torch.manual_seed(1) + a, b, c, d = (torch.randn(1, 3, 3) for _ in range(4)) + true = _dihedral(a, b, c, d).reshape(3) # the truth is the actual dihedral + grads = calc_ddihedralmse_dxyz(a, b, c, d, true) + # dmse/ddih = 2*(dih - true) is exactly 0, so every coordinate gradient is 0. + assert torch.all(grads == 0.0) + + +def test_ddihedralmse_preserves_leading_shape(): + a, b, c, d = (torch.randn(2, 5, 3) for _ in range(4)) + grads = calc_ddihedralmse_dxyz(a, b, c, d, torch.randn(5)) + assert grads.shape == (2, 5, 4, 3) # leading dims + (4 atoms, 3 coords) + + +# --- calc_chiral_grads_flat_impl -------------------------------------------- + + +def test_chiral_grads_empty_centers_returns_zeros(): + xyz = torch.randn(1, 7, 3) + grads = calc_chiral_grads_flat_impl( + xyz, torch.zeros(0, 4, dtype=torch.long), torch.zeros(0), False + ) + assert grads.shape == xyz.shape + assert torch.all(grads == 0.0) + + +def test_chiral_grads_only_center_atoms_receive_gradient(): + torch.manual_seed(2) + xyz = torch.randn(1, 7, 3) + centers = torch.tensor([[1, 3, 4, 5]]) + grads = calc_chiral_grads_flat_impl(xyz, centers, torch.tensor([0.7]), False) + # Atoms outside the centre stay exactly zero; the four centre atoms get gradient. + assert torch.all(grads[:, [0, 2, 6]] == 0.0) + assert grads[:, [1, 3, 4, 5]].abs().sum() > 0 + + +def test_chiral_grads_accumulate_for_shared_atoms(): + torch.manual_seed(3) + xyz = torch.randn(1, 8, 3) + c1 = torch.tensor([[0, 1, 2, 3]]) + c2 = torch.tensor([[2, 4, 5, 6]]) # shares atom 2 with c1 + a1, a2 = torch.tensor([0.3]), torch.tensor([1.1]) + both = calc_chiral_grads_flat_impl( + xyz, torch.cat([c1, c2]), torch.cat([a1, a2]), False + ) + g1 = calc_chiral_grads_flat_impl(xyz, c1, a1, False) + g2 = calc_chiral_grads_flat_impl(xyz, c2, a2, False) + # index_add_ accumulates, so the combined gradient is the per-centre sum — + # including the shared atom 2, which gets a contribution from both centres. + torch.testing.assert_close(both, g1 + g2) + + +def test_chiral_grads_no_grad_on_center_zeroes_first_atom(): + torch.manual_seed(4) + xyz = torch.randn(1, 7, 3) + centers = torch.tensor([[1, 3, 4, 5]]) # atom 1 is the chiral centre + angles = torch.tensor([0.7]) + with_grad = calc_chiral_grads_flat_impl(xyz, centers, angles, False) + no_grad = calc_chiral_grads_flat_impl(xyz, centers, angles, True) + # The flag zeroes the gradient on the centre atom (the first of the four)... + assert with_grad[:, 1].abs().sum() > 0 + assert torch.all(no_grad[:, 1] == 0.0) + # ...and leaves the other three atoms untouched. + torch.testing.assert_close(no_grad[:, [3, 4, 5]], with_grad[:, [3, 4, 5]]) diff --git a/models/rf3/tests/test_util_module.py b/models/rf3/tests/test_util_module.py new file mode 100644 index 00000000..5cd0c102 --- /dev/null +++ b/models/rf3/tests/test_util_module.py @@ -0,0 +1,63 @@ +"""Unit tests for rf3.util_module helpers. + +``rbf`` expands distances into a radial-basis (Gaussian) feature vector over +``D_count`` evenly spaced centres ``D_min .. D_min + (D_count-1)*D_sigma``: the +feature at the centre nearest a distance peaks at 1.0 and falls off as a Gaussian +of width ``D_sigma`` (far centres underflow to 0). ``init_lecun_normal`` replaces a +module's weight with a truncated-normal sample (clamped to ±2 before scaling) +whose post-truncation standard deviation is the Lecun value ``sqrt(scale / fan_in)``. +""" + +import math + +import pytest +import torch +from rf3.util_module import init_lecun_normal, rbf + +# Std of a standard normal truncated to [-2, 2]; the source divides by it so the +# scaled sample's std lands at sqrt(scale / fan_in) rather than below it. +_TRUNC_NORMAL_STD = 0.87962566103423978 + + +# --- rbf -------------------------------------------------------------------- + + +def test_rbf_appends_feature_dim(): + assert rbf(torch.rand(5)).shape == (5, 64) + assert rbf(torch.rand(2, 3), D_count=16).shape == (2, 3, 16) + + +def test_rbf_gaussian_values_at_known_distance(): + # D == D_min lands exactly on centre 0; later centres are D_sigma apart, so the + # features at D=0 are exp(-k**2) for k = 0, 1, 2, ... + vals = rbf(torch.tensor([0.0]), D_min=0.0, D_count=64, D_sigma=0.5)[0] + assert vals[0].item() == 1.0 + assert vals[1].item() == pytest.approx(math.exp(-1)) + assert vals[2].item() == pytest.approx(math.exp(-4)) + # The nearest centre is the argmax (D=0.5 -> centre 1), and values stay in [0, 1]. + assert int(rbf(torch.tensor([0.5]), D_sigma=0.5).argmax()) == 1 + full = rbf(torch.rand(20) * 30) + assert (full >= 0).all() and (full <= 1).all() + + +# --- init_lecun_normal ------------------------------------------------------ + + +def test_init_lecun_normal_returns_module_and_bounded_weight(): + torch.manual_seed(0) + linear = torch.nn.Linear(128, 64, bias=False) + result = init_lecun_normal(linear) + stddev = math.sqrt(1.0 / 128) / _TRUNC_NORMAL_STD # fan_in = 128 + assert result is linear + assert isinstance(linear.weight, torch.nn.Parameter) + assert linear.weight.shape == (64, 128) + # The truncated normal is clamped to ±2 before the stddev scaling. + assert linear.weight.abs().max().item() <= 2 * stddev + + +def test_init_lecun_normal_scales_to_lecun_stddev(): + torch.manual_seed(0) + linear = torch.nn.Linear(512, 512, bias=False) + init_lecun_normal(linear) + # Lecun-normal: post-truncation std ~ sqrt(scale / fan_in). + assert linear.weight.std().item() == pytest.approx(math.sqrt(1 / 512), rel=0.05) From 43c9319d8f1a558c0b1b9a2d2e4e692ab4ca6745 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 18:23:03 +0000 Subject: [PATCH 02/10] test(rf3): add CPU unit tests for layer_utils shape helpers Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_layer_utils.py | 148 +++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 models/rf3/tests/test_layer_utils.py diff --git a/models/rf3/tests/test_layer_utils.py b/models/rf3/tests/test_layer_utils.py new file mode 100644 index 00000000..b187b712 --- /dev/null +++ b/models/rf3/tests/test_layer_utils.py @@ -0,0 +1,148 @@ +"""Unit tests for rf3.model.layers.layer_utils shape helpers. + +These are the structural building blocks the diffusion stack composes: + +- ``MultiDimLinear`` is an ``nn.Linear`` whose output is reshaped from a flat + ``prod(out_shape)`` vector back into ``x.shape[:-1] + out_shape``; its weight is + re-initialised with Xavier-uniform (overriding ``nn.Linear``'s default). +- ``Transition`` is a SwiGLU feed-forward block: ``linear_3(silu(linear_1(LN(X))) * + linear_2(LN(X)))``, all projections bias-free, output width equal to the input. +- ``AdaLN`` is adaptive layer-norm — affine-free LayerNorm of the content ``Ai`` + modulated by a sigmoid gain and a bias, both linear projections of LayerNorm'd + conditioning ``Si``: ``sigmoid(W_g·LN(Si)) * LN_affine_free(Ai) + W_b·LN(Si)``. +- ``create_batch_dimension_if_not_present(n)`` decorates a function expecting an + ``n``-dim batched arg so it also accepts an ``(n-1)``-dim unbatched arg, inserting a + singleton batch dim before the call and stripping it from the result afterwards. +""" + +import math + +import pytest +import torch +import torch.nn.functional as F +from rf3.model.layers.layer_utils import ( + AdaLN, + MultiDimLinear, + Transition, + create_batch_dimension_if_not_present, +) + +# --- MultiDimLinear --------------------------------------------------------- + + +def test_multidim_linear_reshapes_output_to_out_shape(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + # Leading dims of the input are preserved; the feature dim becomes out_shape. + assert layer(torch.randn(2, 5, 8)).shape == (2, 5, 3, 4) + assert layer(torch.randn(7, 8)).shape == (7, 3, 4) + # Underlying Linear projects to the flattened width. + assert layer.out_features == 12 + assert layer.weight.shape == (12, 8) + + +def test_multidim_linear_is_flat_linear_then_reshape(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + x = torch.randn(2, 5, 8) + expected = F.linear(x, layer.weight, layer.bias).reshape(2, 5, 3, 4) + assert torch.allclose(layer(x), expected, atol=1e-6) + + +def test_multidim_linear_weight_is_xavier_bounded(): + torch.manual_seed(0) + layer = MultiDimLinear(8, (3, 4)) + # Xavier-uniform draws from [-bound, bound], bound = sqrt(6 / (fan_in + fan_out)). + bound = math.sqrt(6.0 / (8 + 12)) + assert layer.weight.abs().max().item() <= bound + + +# --- Transition ------------------------------------------------------------- + + +def test_transition_preserves_channel_width(): + torch.manual_seed(0) + block = Transition(n=2, c=6) + assert block(torch.randn(2, 7, 6)).shape == (2, 7, 6) + # Projections are bias-free and the hidden width is n*c. + assert block.linear_1.bias is None + assert block.linear_3.bias is None + assert block.linear_1.weight.shape == (12, 6) + assert block.linear_3.weight.shape == (6, 12) + + +def test_transition_matches_swiglu_gating(): + torch.manual_seed(0) + block = Transition(n=2, c=6) + x = torch.randn(2, 7, 6) + ln = block.layer_norm_1(x) + expected = block.linear_3(F.silu(block.linear_1(ln)) * block.linear_2(ln)) + assert torch.allclose(block(x), expected, atol=1e-6) + + +# --- AdaLN ------------------------------------------------------------------ + + +def test_adaln_output_shape_and_affine_free_content_norm(): + block = AdaLN(c_a=6, c_s=4) + out = block(torch.randn(2, 5, 6), torch.randn(2, 5, 4)) + assert out.shape == (2, 5, 6) + # Content LayerNorm is affine-free; the conditioning LayerNorm drops its bias. + assert block.ln_a.weight is None and block.ln_a.bias is None + assert block.ln_s.bias is None + + +def test_adaln_matches_gain_bias_modulation(): + torch.manual_seed(0) + block = AdaLN(c_a=6, c_s=4) + Ai, Si = torch.randn(2, 5, 6), torch.randn(2, 5, 4) + s = block.ln_s(Si) + gain = block.to_gain(s) + expected = gain * block.ln_a(Ai) + block.to_bias(s) + assert torch.allclose(block(Ai, Si), expected, atol=1e-6) + # The gain is a sigmoid, so modulation is bounded to (0, 1). + assert (gain > 0).all() and (gain < 1).all() + + +# --- create_batch_dimension_if_not_present ---------------------------------- + + +def test_batch_dim_inserted_and_stripped_for_unbatched_arg(): + seen = {} + + @create_batch_dimension_if_not_present(3) + def double(z): + seen["ndim"] = z.ndim + return z * 2 + + x = torch.randn(5, 8) + out = double(x) + # The wrapped function sees a 3-D arg, but the singleton batch dim is stripped back off. + assert seen["ndim"] == 3 + assert out.shape == (5, 8) + assert torch.equal(out, x * 2) + + +def test_batch_dim_passes_through_already_batched_arg(): + seen = {} + + @create_batch_dimension_if_not_present(3) + def double(z): + seen["ndim"] = z.ndim + return z * 2 + + x = torch.randn(2, 5, 8) + out = double(x) + assert seen["ndim"] == 3 + assert out.shape == (2, 5, 8) + assert torch.equal(out, x * 2) + + +def test_batch_dim_rejects_wrong_rank(): + @create_batch_dimension_if_not_present(3) + def identity(z): + return z + + # ndim 1 is neither the batched (3) nor unbatched (2) rank. + with pytest.raises(Exception, match="must have 2 or 3 dimensions"): + identity(torch.randn(8)) From bd36cfaf296d52b67142d1b43c6aab29e6a7244d Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 19:17:12 +0000 Subject: [PATCH 03/10] test(rf3): add CPU unit tests for the diffusion sampler and outer-product layer Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_inference_sampler.py | 107 +++++++++++++++++++++ models/rf3/tests/test_outer_product.py | 53 ++++++++++ 2 files changed, 160 insertions(+) create mode 100644 models/rf3/tests/test_inference_sampler.py create mode 100644 models/rf3/tests/test_outer_product.py diff --git a/models/rf3/tests/test_inference_sampler.py b/models/rf3/tests/test_inference_sampler.py new file mode 100644 index 00000000..125b6ba2 --- /dev/null +++ b/models/rf3/tests/test_inference_sampler.py @@ -0,0 +1,107 @@ +"""Unit tests for rf3.diffusion_samplers.inference_sampler pure helpers. + +``SampleDiffusion`` runs the AF3 diffusion roll-out; the pure, network-free pieces +pinned here are the noise schedule and the initial-point-cloud sampler: + +- ``_construct_inference_noise_schedule`` builds the AF3 inference schedule + ``t_hat = sigma_data * (s_max**(1/p) + t*(s_min**(1/p) - s_max**(1/p)))**p`` over + ``num_timesteps`` points of ``t`` linearly spaced in ``[min_t, max_t]``. At ``t=0`` it + is ``sigma_data*s_max`` and at ``t=1`` it is ``sigma_data*s_min``, decreasing in between + (AF3 Supplement §3.7.1). +- ``SamplePartialDiffusion`` overrides the schedule to start the roll-out part-way + through, returning the full schedule's tail from index ``partial_t``. +- ``_get_initial_structure`` returns ``c0 * N(0,1) + coords`` — Gaussian noise scaled by + ``c0`` (derived from ``noise_schedule[0]``) added to the coords to be noised. +""" + +import pytest +import torch +from rf3.diffusion_samplers.inference_sampler import ( + SampleDiffusion, + SamplePartialDiffusion, +) + +# AF3 defaults (configs are the source of truth — no defaults in the constructor), +# with a short schedule so the tests stay tiny. +_KW = dict( + num_timesteps=8, + min_t=0, + max_t=1, + sigma_data=16, + s_min=4e-4, + s_max=160, + p=7, + gamma_0=0.8, + gamma_min=1.0, + noise_scale=1.003, + step_scale=1.5, + solver="af3", +) + +_CPU = torch.device("cpu") + + +# --- _construct_inference_noise_schedule ------------------------------------ + + +def test_noise_schedule_length_and_endpoints(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + assert sched.shape == (8,) + # t=0 -> sigma_data*s_max; t=1 -> sigma_data*s_min (the (1/p)/**p cancel at the ends). + assert sched[0].item() == pytest.approx(16 * 160, rel=1e-4) + assert sched[-1].item() == pytest.approx(16 * 4e-4, rel=1e-4) + + +def test_noise_schedule_monotonically_decreasing(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + assert bool((sched[1:] < sched[:-1]).all()) + + +def test_noise_schedule_matches_af3_formula(): + sched = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + t = torch.linspace(0, 1, 8) + expected = 16 * (160 ** (1 / 7) + t * (4e-4 ** (1 / 7) - 160 ** (1 / 7))) ** 7 + assert torch.allclose(sched, expected, atol=1e-6) + + +def test_noise_schedule_is_flat_when_s_min_equals_s_max(): + # With s_min == s_max the t-dependent term vanishes, so every step is sigma_data*s. + sched = SampleDiffusion( + **{**_KW, "s_min": 5.0, "s_max": 5.0} + )._construct_inference_noise_schedule(_CPU) + assert torch.allclose(sched, torch.full((8,), 16 * 5.0), atol=1e-4) + + +# --- SamplePartialDiffusion._construct_inference_noise_schedule -------------- + + +def test_partial_schedule_is_tail_of_full(): + full = SampleDiffusion(**_KW)._construct_inference_noise_schedule(_CPU) + partial = SamplePartialDiffusion( + partial_t=3, **_KW + )._construct_inference_noise_schedule(_CPU) + assert partial.shape == (8 - 3,) + assert torch.equal(partial, full[3:]) + + +def test_partial_schedule_rejects_t_at_or_above_num_timesteps(): + sampler = SamplePartialDiffusion(partial_t=8, **_KW) + with pytest.raises(AssertionError, match="must be less than num_timesteps"): + sampler._construct_inference_noise_schedule(_CPU) + + +# --- _get_initial_structure ------------------------------------------------- + + +def test_initial_structure_shape_and_zero_scale(): + coords = torch.randn(4, 6, 3) + sampler = SampleDiffusion(**_KW) + out = sampler._get_initial_structure( + torch.tensor(2.0), D=4, L=6, coord_atom_lvl_to_be_noised=coords + ) + assert out.shape == (4, 6, 3) + # c0 == 0 zeroes the noise term, so the result is exactly the coords to be noised. + zero_scale = sampler._get_initial_structure( + torch.tensor(0.0), D=4, L=6, coord_atom_lvl_to_be_noised=coords + ) + assert torch.equal(zero_scale, coords) diff --git a/models/rf3/tests/test_outer_product.py b/models/rf3/tests/test_outer_product.py new file mode 100644 index 00000000..5f4a59e9 --- /dev/null +++ b/models/rf3/tests/test_outer_product.py @@ -0,0 +1,53 @@ +"""Unit tests for rf3.model.layers.outer_product. + +``OuterProductMean`` / ``OuterProductMean_AF3`` turn an MSA embedding ``[B, N, L, c]`` +into a pair representation ``[B, L, L, c_out]``: LayerNorm the MSA, project to a small +hidden width left/right, take the outer product of the two projections over the hidden +dims and average it across the ``N`` sequence rows (the ``/N`` and the ``einsum`` sum +over ``s`` together form the mean), then project to ``c_out``. ``OuterProductMean`` +zero-initialises ``proj_out`` (so it is a no-op at init, the AF-style "start from zero" +trick); the AF3 variant does not. +""" + +import torch +from rf3.model.layers.outer_product import OuterProductMean, OuterProductMean_AF3 + + +def test_outer_product_mean_zero_initialized_output(): + torch.manual_seed(0) + layer = OuterProductMean(d_msa=8, d_pair=5, d_hidden=4) + out = layer(torch.randn(2, 3, 6, 8)) # [B, N, L, d_msa] + assert out.shape == (2, 6, 6, 5) # [B, L, L, d_pair] + # proj_out is zero-initialised, so the block contributes nothing until trained. + assert bool((out == 0).all()) + assert bool((layer.proj_left.bias == 0).all()) + assert bool((layer.proj_right.bias == 0).all()) + + +def test_outer_product_af3_output_shape(): + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + assert layer(torch.randn(2, 3, 6, 8)).shape == (2, 6, 6, 5) + + +def test_outer_product_af3_matches_mean_einsum(): + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + msa = torch.randn(2, 3, 6, 8) + B, N, L = msa.shape[:3] + normed = layer.norm(msa) + left = layer.proj_left(normed) + right = layer.proj_right(normed) / float(N) + expected = layer.proj_out( + torch.einsum("bsli,bsmj->blmij", left, right).reshape(B, L, L, -1) + ) + assert torch.allclose(layer(msa), expected, atol=1e-6) + + +def test_outer_product_mean_is_invariant_to_duplicate_rows(): + # The /N normalisation makes the block a true mean over sequence rows: N identical + # rows give the same output as a single copy of that row. + torch.manual_seed(0) + layer = OuterProductMean_AF3(c_msa_embed=8, c_outer_product=4, c_out=5) + row = torch.randn(2, 1, 6, 8) + assert torch.allclose(layer(row), layer(row.repeat(1, 4, 1, 1)), atol=1e-6) From 6b1fc18fdf8b065302290a726d82dcedf47d83bf Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 20:01:00 +0000 Subject: [PATCH 04/10] test(rf3): add CPU unit tests for the mlff and triangle-attention layers Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_attention.py | 58 ++++++++++++++++++++++++++++++ models/rf3/tests/test_mlff.py | 45 +++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 models/rf3/tests/test_attention.py create mode 100644 models/rf3/tests/test_mlff.py diff --git a/models/rf3/tests/test_attention.py b/models/rf3/tests/test_attention.py new file mode 100644 index 00000000..7328782d --- /dev/null +++ b/models/rf3/tests/test_attention.py @@ -0,0 +1,58 @@ +"""Unit tests for rf3.model.layers.attention triangle-update blocks (vanilla path). + +Both blocks map a pair representation ``[B, L, L, d_pair]`` to the same shape (for a +residual add) and, on CPU, run their vanilla PyTorch path (cuEquivariance is off). + +- ``TriangleAttention`` zero-initialises its output projection ``to_out`` so the block is + the identity (output 0) at the start of training; ``start_node`` switches the attention + axis (rows vs. transposed columns) but preserves the output shape either way. +- ``TriangleMultiplication`` validates its ``direction`` (``"outgoing"``/``"incoming"``) + and, when the cuEquivariance kernel is requested, requires ``d_pair == d_hidden``; the + vanilla path lifts that constraint. +""" + +import pytest +import torch +from rf3.model.layers.attention import TriangleAttention, TriangleMultiplication + +# --- TriangleAttention ------------------------------------------------------ + + +def test_triangle_attention_preserves_shape(): + pair = torch.randn(1, 6, 6, 8) + for start_node in (True, False): + layer = TriangleAttention(d_pair=8, n_head=2, d_hidden=4, start_node=start_node) + assert layer(pair).shape == pair.shape + + +def test_triangle_attention_zero_initialized(): + torch.manual_seed(0) + layer = TriangleAttention(d_pair=8, n_head=2, d_hidden=4) + # to_out is zero-initialised so the residual add starts as the identity. + assert bool((layer.to_out.weight == 0).all()) and bool( + (layer.to_out.bias == 0).all() + ) + assert bool((layer(torch.randn(1, 6, 6, 8)) == 0).all()) + + +# --- TriangleMultiplication ------------------------------------------------- + + +def test_triangle_multiplication_preserves_shape(): + pair = torch.randn(1, 6, 6, 8) + for direction in ("outgoing", "incoming"): + # Vanilla path lifts the d_pair == d_hidden cuEquivariance constraint. + layer = TriangleMultiplication( + d_pair=8, d_hidden=4, direction=direction, use_cuequivariance=False + ) + assert layer(pair).shape == pair.shape + + +def test_triangle_multiplication_rejects_invalid_direction(): + with pytest.raises(ValueError, match="direction must be 'outgoing' or 'incoming'"): + TriangleMultiplication(d_pair=8, direction="sideways", use_cuequivariance=False) + + +def test_triangle_multiplication_cuequivariance_requires_matching_dims(): + with pytest.raises(AssertionError, match="requires d_pair == d_hidden"): + TriangleMultiplication(d_pair=8, d_hidden=4, use_cuequivariance=True) diff --git a/models/rf3/tests/test_mlff.py b/models/rf3/tests/test_mlff.py new file mode 100644 index 00000000..1d0faf97 --- /dev/null +++ b/models/rf3/tests/test_mlff.py @@ -0,0 +1,45 @@ +"""Unit tests for rf3.model.layers.mlff.ConformerEmbeddingWeightedAverage. + +The module compresses per-conformer atom-level embeddings ``[n_conformers, n_atom, d]`` +into a single per-atom embedding ``[n_atom, c_atom]``: a shared MLP downcasts each +conformer's features to ``c_atompair``, the conformers are flattened into one vector per +atom, and a final (bias-free, zero-initialised) linear projects to ``c_atom``. The +zero-init makes the block a no-op at the start of training (output ≈ 0, for a clean +residual add). The forward also pins two input contracts: the conformer count must match +``n_conformers`` exactly, and an over-wide feature dim is truncated to +``atom_level_embedding_dim`` while an under-wide one is rejected. +""" + +import pytest +import torch +from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage + + +def _layer(): + return ConformerEmbeddingWeightedAverage( + atom_level_embedding_dim=16, c_atompair=4, c_atom=8, n_conformers=3 + ) + + +def test_output_shape_and_zero_initialized(): + layer = _layer() + out = layer(torch.randn(3, 5, 16)) # [n_conformers, n_atom, d] + assert out.shape == (5, 8) # [n_atom, c_atom] + # The final projection is zero-initialised, so the block contributes nothing at init. + assert bool((layer.conformers_to_atom_single_embedding[0].weight == 0).all()) + assert bool((out == 0).all()) + + +def test_subsets_oversized_feature_dim(): + # A feature dim wider than atom_level_embedding_dim is truncated to it, not rejected. + assert _layer()(torch.randn(3, 5, 24)).shape == (5, 8) + + +def test_rejects_undersized_feature_dim(): + with pytest.raises(ValueError, match="is less than the expected dimension"): + _layer()(torch.randn(3, 5, 8)) + + +def test_rejects_wrong_conformer_count(): + with pytest.raises(AssertionError, match="Number of conformers must be consistent"): + _layer()(torch.randn(2, 5, 16)) From 4bc0dbb808a67e2806c9aee1d9f60758922d4f9d Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 20:44:07 +0000 Subject: [PATCH 05/10] test(rf3): add CPU unit tests for the af3_losses symmetry-resolution helpers Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_af3_loss_symmetry.py | 136 +++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 models/rf3/tests/test_af3_loss_symmetry.py diff --git a/models/rf3/tests/test_af3_loss_symmetry.py b/models/rf3/tests/test_af3_loss_symmetry.py new file mode 100644 index 00000000..45977c64 --- /dev/null +++ b/models/rf3/tests/test_af3_loss_symmetry.py @@ -0,0 +1,136 @@ +"""Unit tests for the symmetry-resolution geometry in ``rf3.loss.af3_losses``. + +Both helpers below are the load-bearing machinery behind ``SubunitSymmetryResolution`` +and ``ResidueSymmetryResolution`` (used by ``rf3.symmetry.resolve`` / +``rf3.trainers.rf3``), which re-label the ground-truth coordinates to the symmetry +copy / automorphism that best matches the prediction before the loss is taken. + +- ``SubunitSymmetryResolution._rms_align`` is a batched Kabsch fit. Given predicted + coordinates ``X_fixed`` (``Nbatch x L x 3``) and candidate native copies ``X_moving`` + (``Nambig x L x 3``) it returns ``(u_moving, R, u_fixed)`` such that + ``(X_moving - u_moving) @ R + u_fixed`` lands the native on the prediction — the exact + transform ``_resolve_subunits`` then applies to the native centres of mass. The SVD is + sign-corrected so ``R`` is always a proper rotation (``det = +1``), never a reflection. +- ``ResidueSymmetryResolution._get_best`` picks, per model in the batch, the atom + automorphism (a permutation of a set of interchangeable atom indices) whose + intra-structure distance pattern best matches the prediction, then rewrites the native + coordinates / mask at those positions to that permutation. + +All tests run in float32 (production dtype); ``_rms_align`` builds its sign-correction +matrix with an un-typed ``torch.eye`` that only matches float32 inputs (see the roadmap). +""" + +import pytest +import torch +from rf3.loss.af3_losses import ResidueSymmetryResolution, SubunitSymmetryResolution + +# A non-degenerate, non-coplanar point cloud to align (L = 5). +_POINTS = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, 3.0], + [1.0, 1.0, 1.0], + [-1.0, 2.0, -1.0], + ] +) + + +def _rotation_z(theta: float) -> torch.Tensor: + """Proper rotation (det +1) about the z-axis by ``theta`` radians.""" + c, s = torch.cos(torch.tensor(theta)), torch.sin(torch.tensor(theta)) + return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]) + + +# --- SubunitSymmetryResolution._rms_align ----------------------------------- + + +def test_rms_align_recovers_a_rigid_transform(): + # X_moving is X_fixed pushed through a known rotation + translation; the returned + # transform must land it back on X_fixed. + fixed = _POINTS[None] # (Nbatch=1, L, 3) + moving = (_POINTS @ _rotation_z(0.7) + torch.tensor([2.0, -1.0, 3.0]))[None] + + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, moving) + + aligned = (moving[0] - u_moving[0, 0]) @ R[0, 0] + u_fixed[0, 0] + assert torch.allclose(aligned, _POINTS, atol=1e-4) + # Sign-corrected SVD → a proper rotation, not a reflection. + assert torch.linalg.det(R[0, 0]).item() == pytest.approx(1.0, abs=1e-4) + + +def test_rms_align_identity_when_moving_equals_fixed(): + fixed = _POINTS[None] + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, _POINTS[None]) + + assert torch.allclose(R[0, 0], torch.eye(3), atol=1e-4) + assert torch.allclose(u_moving[0, 0], u_fixed[0, 0], atol=1e-4) + aligned = (_POINTS - u_moving[0, 0]) @ R[0, 0] + u_fixed[0, 0] + assert torch.allclose(aligned, _POINTS, atol=1e-4) + + +def test_rms_align_corrects_reflection_to_proper_rotation(): + # A mirror image cannot be rotated onto the original; the sign correction must still + # return a proper rotation (det +1) rather than the optimal-but-improper reflection. + fixed = _POINTS[None] + reflected = (_POINTS @ torch.diag(torch.tensor([1.0, 1.0, -1.0])))[None] + + _, R, _ = SubunitSymmetryResolution()._rms_align(fixed, reflected) + + assert torch.linalg.det(R[0, 0]).item() == pytest.approx(1.0, abs=1e-4) + + +def test_rms_align_output_shapes_broadcast_over_ambig_and_batch(): + # u_moving is per-ambiguity, u_fixed is per-batch, R is the full cross product — the + # broadcast-ready shapes _resolve_subunits relies on. + fixed = _POINTS[None].repeat(3, 1, 1) # Nbatch = 3 + moving = _POINTS[None].repeat(2, 1, 1) # Nambig = 2 + + u_moving, R, u_fixed = SubunitSymmetryResolution()._rms_align(fixed, moving) + + assert u_moving.shape == (2, 1, 3) + assert R.shape == (2, 3, 3, 3) + assert u_fixed.shape == (1, 3, 3) + + +# --- ResidueSymmetryResolution._get_best ------------------------------------ + + +def _get_best_inputs(pred_sym, native_sym, context): + """Build (x_pred, x_native, x_native_mask, a_i) for a 1-model, 3-atom case. + + Atoms 0 and 1 are interchangeable (the automorphism set); atom 2 is fixed context + whose distances to atoms 0/1 break the tie. ``a_i`` offers the identity ordering + ``[0, 1]`` and the swap ``[1, 0]``. + """ + x_pred = torch.tensor([[*pred_sym, context]]) # (1, 3, 3) + x_native = torch.tensor([[*native_sym, context]]) + mask = torch.ones(1, 3, dtype=torch.bool) + a_i = torch.tensor([[0, 1], [1, 0]]) + return x_pred, x_native, mask, a_i + + +def test_get_best_selects_the_swap_that_matches_prediction(): + # Native has atoms 0/1 swapped relative to the prediction; _get_best must undo it so + # the native's interchangeable atoms line up with the prediction's arrangement. + pred_sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]] + native_sym = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]] # swapped + context = [0.0, 3.0, 0.0] + x_pred, x_native, mask, a_i = _get_best_inputs(pred_sym, native_sym, context) + + out_native, _ = ResidueSymmetryResolution()._get_best(x_pred, x_native, mask, a_i) + + assert torch.allclose(out_native[0, 0], torch.tensor([0.0, 0.0, 0.0])) + assert torch.allclose(out_native[0, 1], torch.tensor([2.0, 0.0, 0.0])) + + +def test_get_best_leaves_already_matching_native_unchanged(): + # Native already matches the prediction → the identity ordering wins → no rewrite. + sym = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]] + context = [0.0, 3.0, 0.0] + x_pred, x_native, mask, a_i = _get_best_inputs(sym, sym, context) + before = x_native.clone() + + out_native, _ = ResidueSymmetryResolution()._get_best(x_pred, x_native, mask, a_i) + + assert torch.allclose(out_native, before) From 9cdee60f652ac1ff28be2d2b5f4c8062e3b1d9a7 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 21:42:55 +0000 Subject: [PATCH 06/10] test(rf3): add CPU unit tests for chiral and iPTM metric orchestration Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/tests/test_chiral.py | 85 +++++++++++++++++++++--- models/rf3/tests/test_predicted_error.py | 71 ++++++++++++++++++-- 2 files changed, 142 insertions(+), 14 deletions(-) diff --git a/models/rf3/tests/test_chiral.py b/models/rf3/tests/test_chiral.py index d4765d54..5d0b15bd 100644 --- a/models/rf3/tests/test_chiral.py +++ b/models/rf3/tests/test_chiral.py @@ -1,20 +1,30 @@ -"""Unit tests for rf3.metrics.chiral.calc_chiral_metrics_masked. - -The pure tensor core behind RF3's chirality metrics. The non-obvious contracts -pinned here: a chiral center is the dihedral of four atoms whose ideal angle is -the 5th column of ``chirals``; correctness is a *sign* match between the -predicted and ideal dihedral; a center counts only if all four of its atoms are -unmasked; duplicate rows that share a first-atom index (alternate orderings of -the same center) are collapsed to the first occurrence; and the empty cases -(no chirals, or a fully-masked structure) return ``{}``. +"""Unit tests for rf3.metrics.chiral. + +``calc_chiral_metrics_masked`` is the pure tensor core behind RF3's chirality +metrics. The non-obvious contracts pinned here: a chiral center is the dihedral +of four atoms whose ideal angle is the 5th column of ``chirals``; correctness is +a *sign* match between the predicted and ideal dihedral; a center counts only if +all four of its atoms are unmasked; duplicate rows that share a first-atom index +(alternate orderings of the same center) are collapsed to the first occurrence; +and the empty cases (no chirals, or a fully-masked structure) return ``{}``. + +``compute_chiral_metrics`` wraps that core: given predicted / ground-truth +``AtomArrayStack``s and (optionally) precomputed chiral features, it splits the +atoms into polymer vs non-polymer, drops ground-truth atoms with NaN +coordinates, and emits ``{category}_{n_chiral_centers,chiral_loss_mean, +percent_correct_chirality}`` keys only for categories that contain a scorable +center. Passing ``chiral_feats`` explicitly bypasses the rdkit/atomworks +feature generation, so the orchestration is exercised here on tiny fixtures. """ import math +import numpy as np import pytest import torch +from biotite.structure import AtomArrayStack from rf3.kinematics import get_dih -from rf3.metrics.chiral import calc_chiral_metrics_masked +from rf3.metrics.chiral import calc_chiral_metrics_masked, compute_chiral_metrics def _chiral_center_coords(theta: float) -> list[list[float]]: @@ -115,3 +125,58 @@ def test_batch_dimension_shapes(): assert result["chiral_loss_mean"].shape == (2,) assert result["percent_correct_chirality"].shape == (2,) + + +# --- compute_chiral_metrics ------------------------------------------------- + +# A single chiral center (atoms 0-3) whose predicted dihedral is +pi/2; ideal +1 -> a +# sign match, ideal -1 -> a mismatch. +_CENTER = [_chiral_center_coords(math.pi / 2)] # (1 model, 4 atoms, 3) +_FEATS_CORRECT = torch.tensor([[0.0, 1.0, 2.0, 3.0, 1.0]]) + + +def _stack(coords, is_polymer) -> AtomArrayStack: + """AtomArrayStack from (D, L, 3) coords + a per-atom is_polymer flag list.""" + coord = np.asarray(coords, dtype=np.float32) + stack = AtomArrayStack(depth=coord.shape[0], length=coord.shape[1]) + stack.coord = coord + stack.set_annotation("is_polymer", np.asarray(is_polymer, dtype=bool)) + return stack + + +def test_compute_chiral_metrics_polymer_center_correct(): + stack = _stack(_CENTER, [True] * 4) + out = compute_chiral_metrics(stack, stack, chiral_feats=_FEATS_CORRECT) + + assert out["polymer_n_chiral_centers"] == 1 + assert out["polymer_percent_correct_chirality"] == 1.0 + assert "polymer_chiral_loss_mean" in out + # No non-polymer atoms -> that category is absent entirely. + assert "non_polymer_n_chiral_centers" not in out + + +def test_compute_chiral_metrics_routes_to_non_polymer_category(): + stack = _stack(_CENTER, [False] * 4) + out = compute_chiral_metrics(stack, stack, chiral_feats=_FEATS_CORRECT) + + assert out["non_polymer_n_chiral_centers"] == 1 + assert "polymer_n_chiral_centers" not in out + + +def test_compute_chiral_metrics_wrong_sign_is_zero_percent(): + stack = _stack(_CENTER, [True] * 4) + feats_wrong = torch.tensor([[0.0, 1.0, 2.0, 3.0, -1.0]]) # ideal sign flipped + out = compute_chiral_metrics(stack, stack, chiral_feats=feats_wrong) + + assert out["polymer_percent_correct_chirality"] == 0.0 + + +def test_compute_chiral_metrics_nan_ground_truth_coord_drops_center(): + pred = _stack(_CENTER, [True] * 4) + gt = _stack(_CENTER, [True] * 4) + gt.coord[0, 3, :] = np.nan # one center atom unresolved in the ground truth + + out = compute_chiral_metrics(pred, gt, chiral_feats=_FEATS_CORRECT) + + # The center has an unresolved atom -> no scorable center in either category. + assert out == {} diff --git a/models/rf3/tests/test_predicted_error.py b/models/rf3/tests/test_predicted_error.py index 806e636b..2553741a 100644 --- a/models/rf3/tests/test_predicted_error.py +++ b/models/rf3/tests/test_predicted_error.py @@ -1,7 +1,9 @@ """Unit tests for rf3.metrics.predicted_error pure helpers. -Covers ``compute_ptm`` (the PAE -> predicted-TM reduction) and the thin -``ComputePTM`` Metric wrapper that reshapes the per-batch scores into a dict. +Covers ``compute_ptm`` (the PAE -> predicted-TM reduction), the thin +``ComputePTM`` Metric wrapper that reshapes the per-batch scores into a dict, +and ``ComputeIPTM`` which scores the *inter-chain* interfaces (overall, plus +protein-protein / protein-ligand / ligand-ligand sub-interfaces). ``compute_ptm`` takes a per-pair distribution over distance-error bins ``pae`` of shape ``[D, I, I, n_bins]``, softmaxes over the bins, weights each @@ -9,11 +11,13 @@ nearest bin), averages over the columns selected by ``to_calculate``, and returns the per-token maximum -> ``[D]``. So a distribution concentrated on the nearest bin scores highest, a uniform one scores the mean weight, and one on -the farthest bin scores lowest; all scores lie in ``(0, 1]``. +the farthest bin scores lowest; all scores lie in ``(0, 1]``. An empty +``to_calculate`` (no selected columns) averages nothing and scores 0. """ +import pytest import torch -from rf3.metrics.predicted_error import ComputePTM, compute_ptm +from rf3.metrics.predicted_error import ComputeIPTM, ComputePTM, compute_ptm D, I, N_BINS = 2, 5, 64 @@ -78,3 +82,62 @@ def test_compute_ptm_metric_returns_per_batch_dict(): assert set(out) == {f"ptm_{i}" for i in range(D)} assert all(0.0 < v <= 1.0 for v in out.values()) + + +# --- ComputeIPTM ------------------------------------------------------------ + +# Two chains of two tokens each; only the inter-chain pairs (asym 0 vs asym 1) are scored. +_TWO_CHAINS = torch.tensor([0, 0, 1, 1]) +_IPTM_FAMILIES = ( + "iptm", + "iptm_protein_protein", + "iptm_protein_ligand", + "iptm_ligand_ligand", +) + + +def _uniform_pae(n: int) -> torch.Tensor: + """Flat `[D, n, n, N_BINS]` logits -> uniform per-pair bin distribution.""" + return torch.zeros(D, n, n, N_BINS) + + +def test_iptm_keys_cover_every_interface_type_per_model(): + out = ComputeIPTM().compute( + pae=_uniform_pae(4), asym_id=_TWO_CHAINS, is_ligand=torch.tensor([0, 0, 1, 1]) + ) + + assert set(out) == {f"{fam}_{i}" for fam in _IPTM_FAMILIES for i in range(D)} + + +def test_iptm_all_protein_zeroes_ligand_interfaces(): + # No ligand tokens -> the protein-ligand and ligand-ligand masks are empty (score 0), + # and the protein-protein interface covers exactly the inter-chain pairs == overall iPTM. + out = ComputeIPTM().compute( + pae=_uniform_pae(4), + asym_id=_TWO_CHAINS, + is_ligand=torch.zeros(4, dtype=torch.long), + ) + + for i in range(D): + assert out[f"iptm_protein_ligand_{i}"] == 0.0 + assert out[f"iptm_ligand_ligand_{i}"] == 0.0 + assert out[f"iptm_protein_protein_{i}"] == pytest.approx(out[f"iptm_{i}"]) + + +def test_iptm_single_chain_scores_zero(): + # One chain -> no inter-chain pairs -> nothing selected -> every interface scores 0. + out = ComputeIPTM().compute( + pae=_uniform_pae(4), + asym_id=torch.zeros(4, dtype=torch.long), + is_ligand=torch.tensor([0, 0, 1, 1]), + ) + + assert all(v == 0.0 for v in out.values()) + + +def test_iptm_values_in_unit_range(): + out = ComputeIPTM().compute( + pae=_uniform_pae(4), asym_id=_TWO_CHAINS, is_ligand=torch.tensor([0, 1, 0, 1]) + ) + + assert all(0.0 <= v <= 1.0 for v in out.values()) From 30fedf2a23a8d0648b10e4fcb075f5bb95fd1e2d Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Tue, 9 Jun 2026 23:23:25 +0000 Subject: [PATCH 07/10] chore(rf3): move paired_msa mypy suppression into the file Replace the central pyproject ignore_errors override for rf3.data.paired_msa with a file-level '# mypy: ignore-errors' directive in the module itself. The module is broken against the installed atomworks (subclasses a now-function) and needs a PandasDataset-API refactor to clear honestly; keeping it in mypy's files scope with an in-file directive makes the suppression visible at the point of breakage and re-enables checking the moment the directive is removed. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- models/rf3/src/rf3/data/paired_msa.py | 11 +++++++++++ pyproject.toml | 21 +++++---------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/models/rf3/src/rf3/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py index a368a880..8e7b81c8 100644 --- a/models/rf3/src/rf3/data/paired_msa.py +++ b/models/rf3/src/rf3/data/paired_msa.py @@ -1,3 +1,14 @@ +# mypy: ignore-errors +# +# This module does not type-check (and does not even import) against the installed +# atomworks: `MultiInputDatasetWrapper` below subclasses +# `atomworks.ml.datasets.StructuralDatasetWrapper`, which atomworks turned into a +# deprecated factory *function* — subclassing it raises `TypeError` at import time. +# Making it type-check requires a real refactor onto the `PandasDataset` API, validated +# on cluster data (see `.ai/roadmap.md`), not type annotations. The suppression lives +# here, in the file, rather than in `pyproject.toml` so it is visible to anyone reviving +# the module: when this file imports and type-checks cleanly again, delete this directive +# to restore mypy coverage (the module stays inside mypy's `files` scope). import os import socket import time diff --git a/pyproject.toml b/pyproject.toml index 3becc089..d0f9535e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,22 +248,11 @@ module = [ ] ignore_errors = true -# rf3 enablement ratchet. `models/rf3` was brought into mypy's scope in 0014; -# these modules had pre-existing type errors at that point and are exempted until -# annotated. Fix the errors and remove the entry to enable type-checking for that -# module (same playbook as the rfd3 ratchet above). Do NOT add modules. -# -# `rf3.data.paired_msa` is NOT a type-only clear: it subclasses the atomworks -# `StructuralDatasetWrapper`, which the installed atomworks turned into a deprecated -# factory *function*, so the module fails to import (subclassing a function), and it -# also calls an unimported `save_failed_example_to_disk`. Clearing it needs a real -# refactor to the `PandasDataset` API (cluster-data-validated), not annotations. -# Tracked in `.ai/roadmap.md`. -[[tool.mypy.overrides]] -module = [ - "rf3.data.paired_msa", -] -ignore_errors = true +# NOTE: the rf3 enablement ratchet (0014) is fully cleared — there is no rf3 module-level +# mypy exemption here. The one module that cannot type-check, `rf3.data.paired_msa` +# (broken against the installed atomworks; needs a `PandasDataset`-API refactor), carries +# a file-level `# mypy: ignore-errors` directive in the module itself, so the suppression +# is visible where the code is and the module stays inside mypy's `files` scope. # Testing ---------------------------------------------------------------------------- [tool.pytest.ini_options] From 3dfbda37c9846ae02dfe6d1f2376fb1dbbec7ed3 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 10 Jun 2026 00:54:15 +0000 Subject: [PATCH 08/10] chore(foundry): annotate rigid.py for strict mypy and add geometry unit tests Bring src/foundry/utils/rigid.py under a per-module disallow_untyped_defs + check_untyped_defs override (the first Track 1 direction-(b) strictness slice) by annotating its ~47 previously-untyped Rotation/Rigid methods and quaternion/matrix helpers. Annotation-only, no behaviour change; self-references use string forward-refs and the identity helpers' shape params are widened to Tuple[int, ...]. Add tests/test_rigid.py: 33 fixture-backed CPU unit tests pinning the pure geometry (quat<->matrix round-trips, compose/invert, apply, the 4x4/7-vector encodings, and from_3_points) against independent references. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 8 + src/foundry/utils/rigid.py | 109 ++++++------ tests/test_rigid.py | 340 +++++++++++++++++++++++++++++++++++++ 3 files changed, 406 insertions(+), 51 deletions(-) create mode 100644 tests/test_rigid.py diff --git a/pyproject.toml b/pyproject.toml index d0f9535e..0cb1f652 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,6 +254,14 @@ ignore_errors = true # a file-level `# mypy: ignore-errors` directive in the module itself, so the suppression # is visible where the code is and the module stays inside mypy's `files` scope. +# Per-module strictness ratchet (direction (b)). The global baseline above leaves +# disallow_untyped_defs / check_untyped_defs off; fully-annotated modules opt into strict +# checking here one at a time, so any new untyped def in them fails the gate. +[[tool.mypy.overrides]] +module = ["foundry.utils.rigid"] +disallow_untyped_defs = true +check_untyped_defs = true + # Testing ---------------------------------------------------------------------------- [tool.pytest.ini_options] # Shared-layer tests live in the top-level `tests/`; model-specific tests live in each diff --git a/src/foundry/utils/rigid.py b/src/foundry/utils/rigid.py index f8f3bf26..77406ec3 100644 --- a/src/foundry/utils/rigid.py +++ b/src/foundry/utils/rigid.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple, cast +from typing import Any, Callable, Optional, Tuple, cast import numpy as np import torch @@ -101,7 +101,7 @@ def rot_vec_mul(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: def identity_rot_mats( - batch_dims: Tuple[int], + batch_dims: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, @@ -126,7 +126,7 @@ def identity_trans( def identity_quats( - batch_dims: Tuple[int], + batch_dims: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, @@ -146,7 +146,7 @@ def identity_quats( _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)} -def _to_mat(pairs): +def _to_mat(pairs: list[tuple[str, int]]) -> np.ndarray: mat = np.zeros((4, 4)) for pair in pairs: key, value = pair @@ -193,7 +193,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor: def rot_to_quat( rot: torch.Tensor, -): +) -> torch.Tensor: if rot.shape[-2:] != (3, 3): raise ValueError("Input rotation is incorrectly shaped") @@ -245,7 +245,7 @@ def rot_to_quat( _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] -def quat_multiply(quat1, quat2): +def quat_multiply(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor: """Multiply a quaternion by another quaternion.""" mat = quat1.new_tensor(_QUAT_MULTIPLY) reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape) @@ -255,7 +255,7 @@ def quat_multiply(quat1, quat2): ) -def quat_multiply_by_vec(quat, vec): +def quat_multiply_by_vec(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """Multiply a quaternion by a pure-vector quaternion.""" mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC) reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape) @@ -264,11 +264,11 @@ def quat_multiply_by_vec(quat, vec): ) -def invert_rot_mat(rot_mat: torch.Tensor): +def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor: return rot_mat.transpose(-1, -2) -def invert_quat(quat: torch.Tensor): +def invert_quat(quat: torch.Tensor) -> torch.Tensor: quat_prime = quat.clone() quat_prime[..., 1:] *= -1 inv = quat_prime / torch.sum(quat**2, dim=-1, keepdim=True) @@ -327,12 +327,12 @@ def __init__( @staticmethod def identity( - shape, + shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, fmt: str = "quat", - ): + ) -> "Rotation": """ Returns an identity Rotation. @@ -369,7 +369,7 @@ def identity( # Magic methods - def __getitem__(self, index: Any): + def __getitem__(self, index: Any) -> "Rotation": """ Allows torch-style indexing over the virtual shape of the rotation object. See documentation for the shape property. @@ -392,7 +392,7 @@ def __getitem__(self, index: Any): else: raise ValueError("Both rotations are None") - def __setitem__(self, index: Any, new: Any): + def __setitem__(self, index: Any, new: Any) -> None: if not isinstance(index, tuple): index = (index,) @@ -406,7 +406,7 @@ def __setitem__(self, index: Any, new: Any): def __mul__( self, right: torch.Tensor, - ): + ) -> "Rotation": """ Pointwise left multiplication of the rotation with a tensor. Can be used to e.g. mask the Rotation. @@ -432,7 +432,7 @@ def __mul__( def __rmul__( self, left: torch.Tensor, - ): + ) -> "Rotation": """ Reverse pointwise multiplication of the rotation with a tensor. @@ -559,7 +559,7 @@ def get_cur_rot(self) -> torch.Tensor: else: raise ValueError("Both rotations are None") - def get_rotvec(self, eps=1e-4) -> torch.Tensor: + def get_rotvec(self, eps: float = 1e-4) -> torch.Tensor: """ Return the underlying axis-angle rotation vector. @@ -594,7 +594,7 @@ def compose_q_update_vec( q_update_vec: torch.Tensor, normalize_quats: bool = True, update_mask: torch.Tensor | None = None, - ): + ) -> "Rotation": """ Returns a new quaternion Rotation after updating the current object's underlying rotation with a quaternion update, formatted @@ -621,7 +621,7 @@ def compose_q_update_vec( normalize_quats=normalize_quats, ) - def compose_r(self, r): + def compose_r(self, r: "Rotation") -> "Rotation": """ Compose the rotation matrices of the current Rotation object with those of another. @@ -637,7 +637,7 @@ def compose_r(self, r): new_rot_mats = rot_matmul(r1, r2) return Rotation(rot_mats=new_rot_mats, quats=None) - def compose_q(self, r, normalize_quats: bool = True): + def compose_q(self, r: "Rotation", normalize_quats: bool = True) -> "Rotation": """ Compose the quaternions of the current Rotation object with those of another. @@ -684,7 +684,7 @@ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: inv_rot_mats = invert_rot_mat(rot_mats) return rot_vec_mul(inv_rot_mats, pts) - def invert(self): + def invert(self) -> "Rotation": """ Returns the inverse of the current Rotation. @@ -707,7 +707,7 @@ def invert(self): def unsqueeze( self, dim: int, - ): + ) -> "Rotation": """ Analogous to torch.unsqueeze. The dimension is relative to the shape of the Rotation object. @@ -731,9 +731,9 @@ def unsqueeze( @staticmethod def cat( - rs, + rs: list["Rotation"], dim: int, - ): + ) -> "Rotation": """ Concatenates rotations along one of the batch dimensions. Analogous to torch.cat(). @@ -755,7 +755,7 @@ def cat( return Rotation(rot_mats=cat_rot_mats, quats=None) - def map_tensor_fn(self, fn): + def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Rotation": """ Apply a Tensor -> Tensor function to underlying rotation tensors, mapping over the rotation dimension(s). Can be used e.g. to sum out @@ -782,7 +782,7 @@ def map_tensor_fn(self, fn): else: raise ValueError("Both rotations are None") - def cuda(self): + def cuda(self) -> "Rotation": """ Analogous to the cuda() method of torch Tensors @@ -798,7 +798,9 @@ def cuda(self): else: raise ValueError("Both rotations are None") - def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]): + def to( + self, device: Optional[torch.device], dtype: Optional[torch.dtype] + ) -> "Rotation": """ Analogous to the to() method of torch Tensors @@ -824,7 +826,7 @@ def to(self, device: Optional[torch.device], dtype: Optional[torch.dtype]): else: raise ValueError("Both rotations are None") - def detach(self): + def detach(self) -> "Rotation": """ Returns a copy of the Rotation whose underlying Tensor has been detached from its torch graph. @@ -905,12 +907,12 @@ def __init__( @staticmethod def identity( - shape: Tuple[int], + shape: Tuple[int, ...], dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, requires_grad: bool = True, fmt: str = "quat", - ): + ) -> "Rigid": """ Constructs an identity transformation. @@ -934,7 +936,7 @@ def identity( def __getitem__( self, index: Any, - ): + ) -> "Rigid": """ Indexes the affine transformation with PyTorch-style indices. The index is applied to the shared dimensions of both the rotation @@ -966,7 +968,7 @@ def __getitem__( def __mul__( self, right: torch.Tensor, - ): + ) -> "Rigid": """ Pointwise left multiplication of the transformation with a tensor. Can be used to e.g. mask the Rigid. @@ -988,7 +990,7 @@ def __mul__( def __rmul__( self, left: torch.Tensor, - ): + ) -> "Rigid": """ Reverse pointwise multiplication of the transformation with a tensor. @@ -1045,7 +1047,7 @@ def compose_q_update_vec( self, q_update_vec: torch.Tensor, update_mask: torch.Tensor | None = None, - ): + ) -> "Rigid": """ Composes the transformation with a quaternion update vector of shape [*, 6], where the final 6 columns represent the x, y, and @@ -1071,7 +1073,7 @@ def compose_tran_update_vec( self, t_vec: torch.Tensor, update_mask: torch.Tensor | None = None, - ): + ) -> "Rigid": """ Composes the transformation with a quaternion update vector of shape [*, 3], where columns represent a 3D translation. @@ -1090,8 +1092,8 @@ def compose_tran_update_vec( def compose( self, - r, - ): + r: "Rigid", + ) -> "Rigid": """ Composes the current rigid object with another. @@ -1105,7 +1107,7 @@ def compose( new_trans = self._rots.apply(r._trans) + self._trans return Rigid(new_rot, new_trans) - def compose_r(self, rot, order="right"): + def compose_r(self, rot: Rotation, order: str = "right") -> "Rigid": """ Composes the current rigid object with another. @@ -1152,7 +1154,7 @@ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: pts = pts - self._trans return self._rots.invert_apply(pts) - def invert(self): + def invert(self) -> "Rigid": """ Inverts the transformation. @@ -1164,7 +1166,7 @@ def invert(self): return Rigid(rot_inv, -1 * trn_inv) - def map_tensor_fn(self, fn): + def map_tensor_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Rigid": """ Apply a Tensor -> Tensor function to underlying translation and rotation tensors, mapping over the translation/rotation dimensions @@ -1197,7 +1199,7 @@ def to_tensor_4x4(self) -> torch.Tensor: return tensor @staticmethod - def from_tensor_4x4(t: torch.Tensor): + def from_tensor_4x4(t: torch.Tensor) -> "Rigid": """ Constructs a transformation from a homogenous transformation tensor. @@ -1233,7 +1235,7 @@ def to_tensor_7(self) -> torch.Tensor: def from_tensor_7( t: torch.Tensor, normalize_quats: bool = False, - ): + ) -> "Rigid": if t.shape[-1] != 7: raise ValueError("Incorrectly shaped input tensor") @@ -1249,7 +1251,7 @@ def from_3_points( origin: torch.Tensor, p_xy_plane: torch.Tensor, eps: float = 1e-4, - ): + ) -> "Rigid": """ Implements algorithm 21. Constructs transformations from sets of 3 points using the Gram-Schmidt algorithm. @@ -1291,7 +1293,7 @@ def from_3_points( def unsqueeze( self, dim: int, - ): + ) -> "Rigid": """ Analogous to torch.unsqueeze. The dimension is relative to the shared dimensions of the rotation/translation. @@ -1310,9 +1312,9 @@ def unsqueeze( @staticmethod def cat( - ts, + ts: list["Rigid"], dim: int, - ): + ) -> "Rigid": """ Concatenates transformations along a new dimension. @@ -1330,7 +1332,7 @@ def cat( return Rigid(rots, trans) - def apply_rot_fn(self, fn): + def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> "Rigid": """ Applies a Rotation -> Rotation function to the stored rotation object. @@ -1342,7 +1344,7 @@ def apply_rot_fn(self, fn): """ return Rigid(fn(self._rots), self._trans) - def apply_trans_fn(self, fn): + def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> "Rigid": """ Applies a Tensor -> Tensor function to the stored translation. @@ -1355,7 +1357,7 @@ def apply_trans_fn(self, fn): """ return Rigid(self._rots, fn(self._trans)) - def scale_translation(self, trans_scale_factor: float): + def scale_translation(self, trans_scale_factor: float) -> "Rigid": """ Scales the translation by a constant factor. @@ -1368,7 +1370,7 @@ def scale_translation(self, trans_scale_factor: float): fn = lambda t: t * trans_scale_factor # noqa: E731 return self.apply_trans_fn(fn) - def stop_rot_gradient(self): + def stop_rot_gradient(self) -> "Rigid": """ Detaches the underlying rotation object @@ -1379,7 +1381,12 @@ def stop_rot_gradient(self): return self.apply_rot_fn(fn) @staticmethod - def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + def make_transform_from_reference( + n_xyz: torch.Tensor, + ca_xyz: torch.Tensor, + c_xyz: torch.Tensor, + eps: float = 1e-20, + ) -> "Rigid": """ Returns a transformation object from reference coordinates. @@ -1449,7 +1456,7 @@ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): return Rigid(rot_obj, translation) - def cuda(self): + def cuda(self) -> "Rigid": """ Moves the transformation object to GPU memory diff --git a/tests/test_rigid.py b/tests/test_rigid.py new file mode 100644 index 00000000..435cef6a --- /dev/null +++ b/tests/test_rigid.py @@ -0,0 +1,340 @@ +"""Unit tests for foundry.utils.rigid. + +`Rotation` and `Rigid` are the OpenFold-derived SE(3) frame primitives used across +the models for backbone/atom frames and IPA-style structure updates. Their contracts +(quaternion <-> rotation-matrix round-trips, composition order, inversion, and the +homogeneous/7-vector tensor encodings) are not obvious from the signatures, so the +tests below pin them on small CPU inputs. + +Both classes force float32 internally, so reference values are computed in float32 and +compared with a loose tolerance (the quaternion path goes through torch.linalg.eigh). +""" + +import pytest +import torch + +from foundry.utils.rigid import ( + Rigid, + Rotation, + identity_quats, + identity_rot_mats, + identity_trans, + invert_quat, + invert_rot_mat, + quat_multiply, + quat_to_rot, + rot_matmul, + rot_to_quat, + rot_vec_mul, +) + +ATOL = 1e-4 + + +def _rodrigues(axis: torch.Tensor, angle: float) -> torch.Tensor: + """Reference proper rotation about `axis` by `angle` (radians), independent of rigid.py.""" + u = axis / torch.linalg.norm(axis) + a = torch.tensor(angle, dtype=torch.float32) + K = torch.tensor( + [ + [0.0, -u[2], u[1]], + [u[2], 0.0, -u[0]], + [-u[1], u[0], 0.0], + ] + ) + eye = torch.eye(3) + return eye + torch.sin(a) * K + (1 - torch.cos(a)) * (K @ K) + + +def _axis_angle_quat(axis: torch.Tensor, angle: float) -> torch.Tensor: + """Unit quaternion (w, x, y, z) for a rotation about `axis` by `angle` (radians).""" + u = axis / torch.linalg.norm(axis) + half = torch.tensor(angle / 2.0, dtype=torch.float32) + return torch.cat([torch.cos(half).reshape(1), torch.sin(half) * u]) + + +def _random_rotations(n: int, seed: int) -> torch.Tensor: + """`n` random proper rotation matrices [n, 3, 3] via QR with a determinant fix-up.""" + torch.manual_seed(seed) + q, _ = torch.linalg.qr(torch.randn(n, 3, 3)) + # Multiplying one column by det(q) = +-1 forces det -> +1 (a proper rotation). + det = torch.linalg.det(q) + q[..., -1] = q[..., -1] * det.unsqueeze(-1) + return q + + +# --- module-level helpers -------------------------------------------------------------- + + +def test_rot_matmul_matches_matrix_product(): + a = _random_rotations(4, 0) + b = _random_rotations(4, 1) + assert torch.allclose(rot_matmul(a, b), a @ b, atol=ATOL) + + +def test_rot_vec_mul_matches_matvec(): + r = _random_rotations(5, 2) + torch.manual_seed(3) + v = torch.randn(5, 3) + expected = torch.einsum("...ij,...j->...i", r, v) + assert torch.allclose(rot_vec_mul(r, v), expected, atol=ATOL) + + +def test_quat_to_rot_identity_quat_is_identity_matrix(): + quat = torch.tensor([1.0, 0.0, 0.0, 0.0]) + assert torch.allclose(quat_to_rot(quat), torch.eye(3), atol=ATOL) + + +def test_quat_to_rot_matches_axis_angle(): + axis = torch.tensor([0.3, -0.7, 0.5]) + angle = 1.1 + quat = _axis_angle_quat(axis, angle) + assert torch.allclose(quat_to_rot(quat), _rodrigues(axis, angle), atol=ATOL) + + +def test_rot_to_quat_roundtrips_through_quat_to_rot(): + r = _random_rotations(6, 4) + recovered = quat_to_rot(rot_to_quat(r)) + assert torch.allclose(recovered, r, atol=ATOL) + + +def test_rot_to_quat_rejects_bad_shape(): + with pytest.raises(ValueError): + rot_to_quat(torch.randn(3, 2)) + + +def test_quat_multiply_composes_rotations(): + q1 = _axis_angle_quat(torch.tensor([0.0, 0.0, 1.0]), 0.6) + q2 = _axis_angle_quat(torch.tensor([1.0, 0.0, 0.0]), -0.9) + composed = quat_to_rot(quat_multiply(q1, q2)) + assert torch.allclose(composed, quat_to_rot(q1) @ quat_to_rot(q2), atol=ATOL) + + +def test_invert_rot_mat_is_transpose_and_true_inverse(): + r = _random_rotations(3, 5) + assert torch.allclose(invert_rot_mat(r), r.transpose(-1, -2), atol=ATOL) + prod = rot_matmul(r, invert_rot_mat(r)) + assert torch.allclose(prod, torch.eye(3).expand(3, 3, 3), atol=ATOL) + + +def test_invert_quat_yields_identity_quaternion(): + q = _axis_angle_quat(torch.tensor([0.2, 0.5, -0.4]), 0.8) + prod = quat_multiply(q, invert_quat(q)) + assert torch.allclose(prod, torch.tensor([1.0, 0.0, 0.0, 0.0]), atol=ATOL) + + +def test_identity_helpers(): + assert torch.allclose(identity_rot_mats((2,)), torch.eye(3).expand(2, 3, 3)) + quats = identity_quats((2,)) + assert torch.allclose(quats, torch.tensor([1.0, 0.0, 0.0, 0.0]).expand(2, 4)) + assert torch.count_nonzero(identity_trans((2,))) == 0 + + +# --- Rotation -------------------------------------------------------------------------- + + +def test_rotation_quat_and_matrix_formats_agree(): + quat = _axis_angle_quat(torch.tensor([0.1, 0.2, 0.9]), 1.3) + from_quat = Rotation(quats=quat) + from_mat = Rotation(rot_mats=quat_to_rot(quat)) + # get_rot_mats() of the quat-format object matches the matrix it encodes... + assert torch.allclose(from_quat.get_rot_mats(), from_mat.get_rot_mats(), atol=ATOL) + # ...and get_quats() of the matrix-format object round-trips back to the same rotation. + assert torch.allclose( + quat_to_rot(from_mat.get_quats()), from_mat.get_rot_mats(), atol=ATOL + ) + + +def test_rotation_requires_exactly_one_input(): + with pytest.raises(ValueError): + Rotation() + with pytest.raises(ValueError): + Rotation(rot_mats=torch.eye(3), quats=torch.tensor([1.0, 0.0, 0.0, 0.0])) + + +def test_rotation_apply_and_invert_apply_roundtrip(): + rot = Rotation(rot_mats=_random_rotations(4, 6)) + torch.manual_seed(7) + pts = torch.randn(4, 3) + assert torch.allclose( + rot.apply(pts), rot_vec_mul(rot.get_rot_mats(), pts), atol=ATOL + ) + assert torch.allclose(rot.invert_apply(rot.apply(pts)), pts, atol=ATOL) + + +def test_rotation_compose_r_matches_sequential_apply(): + r1 = Rotation(rot_mats=_random_rotations(3, 8)) + r2 = Rotation(rot_mats=_random_rotations(3, 9)) + torch.manual_seed(10) + pts = torch.randn(3, 3) + composed = r1.compose_r(r2).apply(pts) + assert torch.allclose(composed, r1.apply(r2.apply(pts)), atol=ATOL) + + +def test_rotation_compose_q_matches_compose_r(): + q1 = Rotation(quats=_axis_angle_quat(torch.tensor([0.0, 1.0, 0.0]), 0.7)) + q2 = Rotation(quats=_axis_angle_quat(torch.tensor([1.0, 0.0, 1.0]), -0.5)) + via_q = q1.compose_q(q2).get_rot_mats() + via_r = q1.compose_r(q2).get_rot_mats() + assert torch.allclose(via_q, via_r, atol=ATOL) + + +def test_rotation_invert_matrix_and_quat_formats(): + torch.manual_seed(11) + pts = torch.randn(2, 3) + rot_mat = Rotation(rot_mats=_random_rotations(2, 12)) + assert torch.allclose(rot_mat.invert().apply(rot_mat.apply(pts)), pts, atol=ATOL) + rot_quat = Rotation(quats=_axis_angle_quat(torch.tensor([0.4, 0.4, 0.8]), 1.0)) + single = torch.randn(3) + assert torch.allclose( + rot_quat.invert().apply(rot_quat.apply(single)), single, atol=ATOL + ) + + +@pytest.mark.parametrize("fmt", ["quat", "rot_mat"]) +def test_rotation_identity_apply_is_noop(fmt): + rot = Rotation.identity((5,), fmt=fmt) + torch.manual_seed(13) + pts = torch.randn(5, 3) + assert torch.allclose(rot.apply(pts), pts, atol=ATOL) + + +def test_rotation_getitem_slices_virtual_shape(): + mats = _random_rotations(4, 14) + rot = Rotation(rot_mats=mats) + sliced = rot[1:3] + assert sliced.shape == (2,) + assert torch.allclose(sliced.get_rot_mats(), mats[1:3], atol=ATOL) + + +def test_rotation_mul_by_mask_zeroes_entries(): + mats = _random_rotations(3, 15) + rot = Rotation(rot_mats=mats) + mask = torch.tensor([1.0, 0.0, 1.0]) + masked = (rot * mask).get_rot_mats() + assert torch.allclose(masked, mats * mask[..., None, None], atol=ATOL) + + +def test_get_rotvec_returns_axis_times_angle(): + axis = torch.tensor([0.0, 0.0, 1.0]) + angle = 1.2 + rot = Rotation(quats=_axis_angle_quat(axis, angle)) + assert torch.allclose(rot.get_rotvec(), axis * angle, atol=1e-3) + identity = Rotation.identity((), fmt="quat") + assert torch.allclose(identity.get_rotvec(), torch.zeros(3), atol=ATOL) + + +# --- Rigid ----------------------------------------------------------------------------- + + +def test_rigid_apply_is_rotate_then_translate(): + rot = Rotation(rot_mats=_random_rotations(4, 16)) + trans = torch.randn(4, 3) + rigid = Rigid(rot, trans) + torch.manual_seed(17) + pts = torch.randn(4, 3) + expected = rot.apply(pts) + trans + assert torch.allclose(rigid.apply(pts), expected, atol=ATOL) + + +def test_rigid_invert_apply_roundtrip(): + rigid = Rigid(Rotation(rot_mats=_random_rotations(3, 18)), torch.randn(3, 3)) + torch.manual_seed(19) + pts = torch.randn(3, 3) + assert torch.allclose(rigid.invert_apply(rigid.apply(pts)), pts, atol=ATOL) + + +def test_rigid_compose_matches_sequential_apply(): + t1 = Rigid(Rotation(rot_mats=_random_rotations(2, 20)), torch.randn(2, 3)) + t2 = Rigid(Rotation(rot_mats=_random_rotations(2, 21)), torch.randn(2, 3)) + torch.manual_seed(22) + pts = torch.randn(2, 3) + assert torch.allclose(t1.compose(t2).apply(pts), t1.apply(t2.apply(pts)), atol=ATOL) + + +def test_rigid_invert_inverts_apply(): + rigid = Rigid(Rotation(rot_mats=_random_rotations(2, 23)), torch.randn(2, 3)) + torch.manual_seed(24) + pts = torch.randn(2, 3) + assert torch.allclose(rigid.invert().apply(rigid.apply(pts)), pts, atol=ATOL) + + +def test_rigid_identity_apply_is_noop(): + rigid = Rigid.identity((4,)) + torch.manual_seed(25) + pts = torch.randn(4, 3) + assert torch.allclose(rigid.apply(pts), pts, atol=ATOL) + + +def test_rigid_to_tensor_4x4_structure_and_roundtrip(): + rot = Rotation(rot_mats=_random_rotations(3, 26)) + trans = torch.randn(3, 3) + rigid = Rigid(rot, trans) + t = rigid.to_tensor_4x4() + assert t.shape == (3, 4, 4) + assert torch.allclose(t[..., :3, :3], rot.get_rot_mats(), atol=ATOL) + assert torch.allclose(t[..., :3, 3], trans, atol=ATOL) + assert torch.allclose(t[..., 3, 3], torch.ones(3), atol=ATOL) + torch.manual_seed(27) + pts = torch.randn(3, 3) + rebuilt = Rigid.from_tensor_4x4(t) + assert torch.allclose(rebuilt.apply(pts), rigid.apply(pts), atol=ATOL) + + +def test_rigid_from_tensor_4x4_rejects_bad_shape(): + with pytest.raises(ValueError): + Rigid.from_tensor_4x4(torch.randn(3, 3)) + + +def test_rigid_to_from_tensor_7_roundtrip(): + rot = Rotation(quats=_axis_angle_quat(torch.tensor([0.3, 0.6, 0.2]), 0.9)) + trans = torch.randn(3) + rigid = Rigid(rot, trans) + t = rigid.to_tensor_7() + assert t.shape == (7,) + rebuilt = Rigid.from_tensor_7(t) + torch.manual_seed(28) + pts = torch.randn(3) + assert torch.allclose(rebuilt.apply(pts), rigid.apply(pts), atol=ATOL) + + +def test_rigid_from_tensor_7_rejects_bad_shape(): + with pytest.raises(ValueError): + Rigid.from_tensor_7(torch.randn(6)) + + +def test_rigid_compose_q_update_vec_zero_update_is_noop(): + rigid = Rigid( + Rotation(quats=_axis_angle_quat(torch.tensor([0.1, 0.2, 0.3]), 0.5)), + torch.randn(3), + ) + updated = rigid.compose_q_update_vec(torch.zeros(6)) + torch.manual_seed(29) + pts = torch.randn(3) + assert torch.allclose(updated.apply(pts), rigid.apply(pts), atol=ATOL) + + +def test_rigid_from_3_points_builds_orthonormal_frame(): + torch.manual_seed(30) + p_neg_x = torch.randn(3) + origin = torch.randn(3) + p_xy = torch.randn(3) + rigid = Rigid.from_3_points(p_neg_x, origin, p_xy) + rot = rigid.get_rots().get_rot_mats() + # Proper orthonormal rotation. + assert torch.allclose(rot @ rot.transpose(-1, -2), torch.eye(3), atol=ATOL) + assert torch.allclose(torch.linalg.det(rot), torch.tensor(1.0), atol=ATOL) + # The origin maps to the frame origin, and p_neg_x lies on the frame's negative x-axis. + assert torch.allclose(rigid.invert_apply(origin), torch.zeros(3), atol=ATOL) + local_neg_x = rigid.invert_apply(p_neg_x) + assert local_neg_x[0] < 0 + assert torch.allclose(local_neg_x[1:], torch.zeros(2), atol=ATOL) + + +def test_rigid_cat_and_unsqueeze_shapes(): + a = Rigid(Rotation(rot_mats=_random_rotations(3, 31)), torch.randn(3, 3)) + b = Rigid(Rotation(rot_mats=_random_rotations(3, 32)), torch.randn(3, 3)) + cat = Rigid.cat([a, b], dim=0) + assert cat.shape == (6,) + assert torch.allclose(cat[:3].get_trans(), a.get_trans(), atol=ATOL) + assert a.unsqueeze(0).shape == (1, 3) From faa515635d3da8f51ad0f60b15c8ea90a443969c Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 10 Jun 2026 01:36:39 +0000 Subject: [PATCH 09/10] chore(foundry): annotate components.py for strict mypy and add parsing tests Bring src/foundry/utils/components.py under the per-module disallow_untyped_defs + check_untyped_defs override by annotating the contig parsers and mask getters. Annotation-only apart from two behaviour-preserving refactors (collapse a str->int variable reuse in split_contig; store fixed_parts as tuples to avoid object-typed heterogeneous lists). Also widen get_name_mask's query_names to str | list[str] to match the isinstance branch and docstring. Add tests/test_components.py: 25 unit tests pinning the contig/component parsing grammar (split_contig, extract_pn_unit_info, get_motif_components_and_breaks, get_design_pattern_with_constraints) and the get_name_mask atom selector. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 2 +- src/foundry/utils/components.py | 41 +++---- tests/test_components.py | 186 ++++++++++++++++++++++++++++++++ 3 files changed, 210 insertions(+), 19 deletions(-) create mode 100644 tests/test_components.py diff --git a/pyproject.toml b/pyproject.toml index 0cb1f652..0372f650 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,7 +258,7 @@ ignore_errors = true # disallow_untyped_defs / check_untyped_defs off; fully-annotated modules opt into strict # checking here one at a time, so any new untyped def in them fails the gate. [[tool.mypy.overrides]] -module = ["foundry.utils.rigid"] +module = ["foundry.utils.rigid", "foundry.utils.components"] disallow_untyped_defs = true check_untyped_defs = true diff --git a/src/foundry/utils/components.py b/src/foundry/utils/components.py index 5a258a0e..64932dad 100644 --- a/src/foundry/utils/components.py +++ b/src/foundry/utils/components.py @@ -41,15 +41,14 @@ def __init__( class ComponentStr(str): """Component identifier, e.g. "A1" for residues, "B12", etc. Previously named `contig_string`""" - def split_component(v): + def split_component(v) -> list: return split_contig(v) -def split_contig(x): +def split_contig(x: str) -> list: try: chain = str(x[0]) - idx = x[1:] - idx = int(idx) + idx = int(x[1:]) if idx < 0: raise ComponentValidationError( "Residue index must be a non-negative integer.", component=str(x) @@ -62,7 +61,7 @@ def split_contig(x): return [chain, idx] -def extract_pn_unit_info(contig): +def extract_pn_unit_info(contig: str) -> tuple[str, int, int]: """ Convert substring like A20-21 or A20 to separate terms: A, 20, 21. """ @@ -81,7 +80,9 @@ def extract_pn_unit_info(contig): ) -def get_design_pattern_with_constraints(contig, length=None): +def get_design_pattern_with_constraints( + contig: str, length: str | None = None +) -> list[str]: """ Convert the contig string to separate modules. e.g. '1-5,A20-21,1-5,A25-25,1-5,A30-30,/0,1-5' with length = 10-10 may be converted to [2, A20, A21, 2, A25, 3, A30, /0, 3] @@ -91,8 +92,8 @@ def get_design_pattern_with_constraints(contig, length=None): contig_parts = contig.split(",") # Separate fixed segments (e.g., "A1051-1051") and variable ranges (e.g., "0-40") - variable_ranges = [] - fixed_parts = [] + variable_ranges: list[list[int]] = [] + fixed_parts: list[tuple[str, int, int]] = [] pos_to_put_motif = [] suff = [] # suffixes for diffused regions P(optional),R,D @@ -111,7 +112,7 @@ def get_design_pattern_with_constraints(contig, length=None): elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part) - fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end]) + fixed_parts.append((pn_unit_id, pn_unit_start, pn_unit_end)) pos_to_put_motif.append(1) elif part == "/0": pos_to_put_motif.append(2) @@ -141,7 +142,7 @@ def get_design_pattern_with_constraints(contig, length=None): remaining_length_min = length_min remaining_length_max = length_max - num_free_atoms = [] + num_free_atoms: list[int] = [] for range_limits in variable_ranges: min_value = range_limits[0] max_value = range_limits[1] @@ -187,7 +188,9 @@ def get_design_pattern_with_constraints(contig, length=None): return atoms_with_motif -def get_motif_components_and_breaks(unindexed_contig, index_all=False): +def get_motif_components_and_breaks( + unindexed_contig: str, index_all: bool = False +) -> tuple[list[str], list[bool | None]]: """ Convert a contig string into its components and breaks in motif This way you can specify in your contigs where the breaks in the motif should be, so that, @@ -203,8 +206,8 @@ def get_motif_components_and_breaks(unindexed_contig, index_all=False): index_all: No breaks are used, allows for full indexing of concatenated tokens Can use cleanup if this is the desired way to provide motif tokens. """ - components = [] - breaks = [] + components: list[str] = [] + breaks: list[bool | None] = [] contig_parts = unindexed_contig.split(",") for part in contig_parts: @@ -249,8 +252,10 @@ def get_motif_components_and_breaks(unindexed_contig, index_all=False): def get_name_mask( - source_names: np.ndarray, query_names: str, source_resname: str | None = None -): + source_names: np.ndarray, + query_names: str | list[str], + source_resname: str | None = None, +) -> np.ndarray: """ Args: source_names: list of all names to match in current token @@ -344,7 +349,7 @@ def get_name_mask( return mask -def fetch_mask_from_idx(contig_str, *, atom_array): +def fetch_mask_from_idx(contig_str: str, *, atom_array: AtomArray) -> np.ndarray: """ contig_str: A11 returns: @@ -360,7 +365,7 @@ def fetch_mask_from_idx(contig_str, *, atom_array): return mask -def fetch_mask_from_name(name, *, atom_array): +def fetch_mask_from_name(name: str, *, atom_array: AtomArray) -> np.ndarray: """ name: LIG_NAME returns: @@ -379,7 +384,7 @@ def fetch_mask_from_name(name, *, atom_array): return mask -def fetch_mask_from_component(component, *, atom_array): +def fetch_mask_from_component(component: str, *, atom_array: AtomArray) -> np.ndarray: """ Catch-all function for fetching a component by non-protein name or contig component: A11 or LIG_NAME diff --git a/tests/test_components.py b/tests/test_components.py new file mode 100644 index 00000000..d5647376 --- /dev/null +++ b/tests/test_components.py @@ -0,0 +1,186 @@ +"""Unit tests for foundry.utils.components contig/component parsing. + +These string parsers turn user-facing contig specifications (e.g. "A14-15,A16", +"5-10,A20-21") into the component/break/free-residue structures the rfd3/rfd3na +inference paths consume. Their grammar is non-obvious from the signatures — leading +chain letter, optional ranges, the "/0" chain-break token, and the R/D/P diffusion +suffixes — so the tests below pin the documented behaviour and the validation errors +on small inputs. `get_name_mask` is the pure atom-name selector (ALL/BKBN/explicit). +""" + +import random + +import numpy as np +import pytest + +from foundry.utils.components import ( + ComponentValidationError, + extract_pn_unit_info, + get_design_pattern_with_constraints, + get_motif_components_and_breaks, + get_name_mask, + split_contig, +) + +# --- split_contig ---------------------------------------------------------------------- + + +def test_split_contig_parses_chain_and_index(): + assert split_contig("A20") == ["A", 20] + assert split_contig("B0") == ["B", 0] + + +def test_split_contig_rejects_negative_index(): + with pytest.raises(ComponentValidationError): + split_contig("A-5") + + +def test_split_contig_rejects_malformed(): + with pytest.raises(ComponentValidationError): + split_contig("AB") + + +# --- extract_pn_unit_info -------------------------------------------------------------- + + +def test_extract_pn_unit_info_range(): + assert extract_pn_unit_info("A20-21") == ("A", 20, 21) + + +def test_extract_pn_unit_info_single_residue_duplicates_bound(): + assert extract_pn_unit_info("Z5") == ("Z", 5, 5) + + +def test_extract_pn_unit_info_rejects_missing_chain(): + with pytest.raises(ComponentValidationError): + extract_pn_unit_info("123") + + +# --- get_motif_components_and_breaks --------------------------------------------------- + + +def test_motif_breaks_all_single_residues_break_between_each(): + # Documented example: each comma-separated residue is its own component, all broken. + components, breaks = get_motif_components_and_breaks("A14,A15,A16") + assert components == ["A14", "A15", "A16"] + assert breaks == [True, True, True] + + +def test_motif_breaks_range_keeps_interior_glued(): + # Documented example: a range stays glued (break only at its start). + components, breaks = get_motif_components_and_breaks("A14-15,A16") + assert components == ["A14", "A15", "A16"] + assert breaks == [True, False, True] + + +def test_motif_breaks_index_all_drops_internal_breaks(): + components, breaks = get_motif_components_and_breaks("A14,A15,A16", index_all=True) + assert components == ["A14", "A15", "A16"] + assert breaks == [False, False, False] + + +def test_motif_breaks_chain_break_token_has_none_break(): + components, breaks = get_motif_components_and_breaks("A14,/0") + assert components == ["A14", "/0"] + assert breaks == [True, None] + + +def test_motif_breaks_rejects_partial_unindexing_range(): + with pytest.raises(ComponentValidationError): + get_motif_components_and_breaks("A14,5-6") + + +# --- get_design_pattern_with_constraints ----------------------------------------------- + + +def test_design_pattern_free_then_fixed_motif(): + random.seed(0) + # "5-5" is a fixed-width free segment (-> "5P"); "A20-21" expands to two fixed residues. + assert get_design_pattern_with_constraints("5-5,A20-21") == ["5P", "A20", "A21"] + + +def test_design_pattern_rna_suffix_preserved(): + random.seed(0) + # A trailing "R"/"D" marks a non-fixed RNA/DNA segment and the suffix is carried through. + assert get_design_pattern_with_constraints("3-3R,A5") == ["3R", "A5"] + + +def test_design_pattern_chain_break_token_passthrough(): + random.seed(0) + assert get_design_pattern_with_constraints("A1,/0,A2") == ["A1", "/0", "A2"] + + +def test_design_pattern_threads_total_length(): + random.seed(0) + # length="6" minus the 1 motif residue leaves exactly 5 free -> the "5-5" segment. + assert get_design_pattern_with_constraints("5-5,A20", length="6") == ["5P", "A20"] + + +def test_design_pattern_raises_when_length_infeasible(): + random.seed(0) + # 12 total - 1 motif = 11 free required, but the segment caps at 10. + with pytest.raises(ComponentValidationError): + get_design_pattern_with_constraints("5-10,A20", length="12") + + +# --- get_name_mask --------------------------------------------------------------------- + + +def _atom_names(*names): + return np.array(names) + + +def test_get_name_mask_all(): + names = _atom_names("N", "CA", "C", "O", "CB") + assert get_name_mask(names, "ALL").tolist() == [True] * 5 + + +def test_get_name_mask_backbone_excludes_cb(): + names = _atom_names("N", "CA", "C", "O", "CB") + assert get_name_mask(names, "BKBN").tolist() == [True, True, True, True, False] + + +def test_get_name_mask_explicit_comma_string(): + names = _atom_names("N", "CA", "C", "O", "CB") + assert get_name_mask(names, "N,CA").tolist() == [True, True, False, False, False] + + +def test_get_name_mask_accepts_list_of_names(): + names = _atom_names("N", "CA", "C", "O", "CB") + assert get_name_mask(names, ["N", "CA"]).tolist() == [ + True, + True, + False, + False, + False, + ] + + +def test_get_name_mask_empty_string_selects_nothing(): + names = _atom_names("N", "CA", "C") + assert get_name_mask(names, "").tolist() == [False, False, False] + + +def test_get_name_mask_rejects_duplicate_names(): + names = _atom_names("N", "CA", "C") + with pytest.raises(ComponentValidationError): + get_name_mask(names, "N,N") + + +def test_get_name_mask_rejects_missing_names(): + names = _atom_names("N", "CA", "C") + with pytest.raises(ComponentValidationError): + get_name_mask(names, "XYZ") + + +def test_get_name_mask_tip_requires_resname(): + names = _atom_names("N", "CA", "C") + with pytest.raises(ComponentValidationError): + get_name_mask(names, "TIP", source_resname=None) + + +def test_get_name_mask_rejects_non_multiple_atom_count(): + # Two N's but one CA: the match count (3) is not a multiple of the requested names (2). + names = _atom_names("N", "N", "CA") + with pytest.raises(ComponentValidationError): + get_name_mask(names, "N,CA") From 6632ba55c95ad312b1bf2aff8a64c566a065147e Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 10 Jun 2026 02:52:55 +0000 Subject: [PATCH 10/10] chore(foundry): annotate torch utils for strict mypy and add unit tests Bring src/foundry/utils/torch.py under the per-module disallow_untyped_defs + check_untyped_defs override by annotating map_to's kwargs, the tracer-warning contextmanager, assert_shape, and the Timer/Timers methods. Annotation-only; the two warnings.filters mutations carry a documented type: ignore[attr-defined] (typeshed types it immutable, but it is a list at runtime). Extend tests/test_torch_utils.py with 12 tests for the untested pure helpers: scatter_mean, assert_shape/assert_same_shape, device_of, and Timer/Timers. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 2 +- src/foundry/utils/torch.py | 30 ++++++----- tests/test_torch_utils.py | 103 ++++++++++++++++++++++++++++++++++++- 3 files changed, 119 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0372f650..42f57087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,7 +258,7 @@ ignore_errors = true # disallow_untyped_defs / check_untyped_defs off; fully-annotated modules opt into strict # checking here one at a time, so any new untyped def in them fails the gate. [[tool.mypy.overrides]] -module = ["foundry.utils.rigid", "foundry.utils.components"] +module = ["foundry.utils.rigid", "foundry.utils.components", "foundry.utils.torch"] disallow_untyped_defs = true check_untyped_defs = true diff --git a/src/foundry/utils/torch.py b/src/foundry/utils/torch.py index 34cacd64..43ad41a4 100755 --- a/src/foundry/utils/torch.py +++ b/src/foundry/utils/torch.py @@ -14,7 +14,7 @@ import numpy as np import torch -from beartype.typing import Any, Sequence +from beartype.typing import Any, Iterator, Sequence from toolz import valmap from torch import Tensor from torch._prims_common import DeviceLikeType @@ -30,7 +30,7 @@ def map_to( device: DeviceLikeType | None = None, dtype: _dtype | None = None, non_blocking: bool = False, - **to_kwargs, + **to_kwargs: Any, ) -> Any: """ Recursively applies the `.to()` method to all tensors in a nested structure. @@ -135,7 +135,7 @@ def _assert_no_nans(x: Any, *, msg: str = "", fail_if_not_tensor: bool = False) @contextmanager -def _suppress_tracer_warnings(): +def _suppress_tracer_warnings() -> Iterator[None]: """ Context manager to temporarily suppress known warnings in torch.jit.trace(). Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 @@ -144,12 +144,14 @@ def _suppress_tracer_warnings(): - https://github.com/NVlabs/edm2/blob/main/torch_utils/misc.py """ tracer_warning_filter = ("ignore", None, torch.jit.TracerWarning, None, 0) - warnings.filters.insert(0, tracer_warning_filter) + # warnings.filters is typed as an immutable Sequence in typeshed but is a mutable + # list at runtime, so insert/remove are valid. + warnings.filters.insert(0, tracer_warning_filter) # type: ignore[attr-defined] yield - warnings.filters.remove(tracer_warning_filter) + warnings.filters.remove(tracer_warning_filter) # type: ignore[attr-defined] -def assert_shape(tensor: Tensor, ref_shape: Sequence[int | None]): +def assert_shape(tensor: Tensor, ref_shape: Sequence[int | None]) -> None: """ Assert that the shape of a tensor matches the given list of integers. None indicates that the size of a dimension is allowed to vary. @@ -276,7 +278,7 @@ class Timer: use_barrier (bool, optional): Whether to use synchronization barriers. Defaults to True. """ - def __init__(self, name, use_barrier: bool = True): + def __init__(self, name: str, use_barrier: bool = True): self.name_ = name self.elapsed_ = 0.0 self.started_ = False @@ -336,30 +338,30 @@ class Timers: timers (dict): A dictionary of Timer objects, keyed by their names. """ - def __init__(self): - self.timers = {} + def __init__(self) -> None: + self.timers: dict[str, Timer] = {} - def __call__(self, name, use_barrier: bool = True) -> Timer: + def __call__(self, name: str, use_barrier: bool = True) -> Timer: """Get or create a Timer object.""" if name not in self.timers: self.timers[name] = Timer(name, use_barrier=use_barrier) return self.timers[name] - def start(self, *names) -> None: + def start(self, *names: str) -> None: """Start the specified timers.""" for name in names: self(name).start() - def stop(self, *names) -> None: + def stop(self, *names: str) -> None: """Stop the specified timers.""" for name in names: self.timers[name].stop() - def reset(self, *names) -> None: + def reset(self, *names: str) -> None: """Reset the specified timers.""" for name in names: self.timers[name].reset() - def elapsed(self, *names, reset: bool = True) -> dict[str, float]: + def elapsed(self, *names: str, reset: bool = True) -> dict[str, float]: """Get the elapsed time for the specified timers.""" return {name: self.timers[name].elapsed(reset=reset) for name in names} diff --git a/tests/test_torch_utils.py b/tests/test_torch_utils.py index fb5c26e2..2e4d5741 100644 --- a/tests/test_torch_utils.py +++ b/tests/test_torch_utils.py @@ -4,7 +4,16 @@ import torch os.environ["NAN_CHECKING"] = "True" -from foundry.utils.torch import assert_no_nans, map_to +from foundry.utils.torch import ( + Timer, + Timers, + assert_no_nans, + assert_same_shape, + assert_shape, + device_of, + map_to, + scatter_mean, +) def test_map_to(): @@ -128,5 +137,97 @@ def test_assert_no_nans(): assert_no_nans({"a": torch.tensor([1.0, float("nan")])}, msg="custom") +def test_scatter_mean_averages_by_index(): + # Rows 0 and 1 of source map to output row 0 (averaged); row 2 maps to output row 2; + # output row 1 receives nothing and stays at its (zero) self value (include_self=False). + zeros = torch.zeros(3, 2) + index = torch.tensor([0, 0, 2]) + source = torch.tensor([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]) + out = scatter_mean(zeros, 0, index, source) + expected = torch.tensor([[2.0, 2.0], [0.0, 0.0], [5.0, 5.0]]) + assert torch.allclose(out, expected) + + +def test_scatter_mean_matches_index_reduce(): + torch.manual_seed(0) + zeros = torch.zeros(4, 3) + index = torch.tensor([0, 1, 1, 3, 3, 3]) + source = torch.randn(6, 3) + out = scatter_mean(zeros, 0, index, source) + expected = zeros.index_reduce(0, index, source, "mean", include_self=False) + assert torch.allclose(out, expected) + + +def test_scatter_mean_does_not_mutate_input(): + zeros = torch.zeros(2, 2) + index = torch.tensor([0, 1]) + source = torch.tensor([[1.0, 1.0], [2.0, 2.0]]) + scatter_mean(zeros, 0, index, source) + assert torch.all(zeros == 0) + + +def test_assert_shape_matches_exact_and_wildcard(): + t = torch.zeros(2, 3, 4) + assert_shape(t, [2, 3, 4]) # exact match, no raise + assert_shape(t, [2, None, 4]) # None leaves that dimension free + assert_shape(t, [None, None, None]) + + +def test_assert_shape_wrong_ndim_raises(): + with pytest.raises(AssertionError): + assert_shape(torch.zeros(2, 3), [2, 3, 4]) + + +def test_assert_shape_wrong_size_raises(): + with pytest.raises(AssertionError): + assert_shape(torch.zeros(2, 3), [2, 4]) + + +def test_assert_same_shape(): + assert_same_shape(torch.zeros(2, 3), torch.ones(2, 3)) # no raise + with pytest.raises(AssertionError): + assert_same_shape(torch.zeros(2, 3), torch.zeros(2, 4)) + + +def test_device_of_tensor_and_module(): + assert device_of(torch.zeros(3)).type == "cpu" + assert device_of(torch.nn.Linear(2, 2)).type == "cpu" + + +def test_device_of_unsupported_raises(): + with pytest.raises(ValueError): + device_of(42) + + +def test_timer_start_stop_guards(): + timer = Timer("t", use_barrier=False) + timer.start() + with pytest.raises(AssertionError): + timer.start() # already started + timer.stop() + with pytest.raises(AssertionError): + timer.stop() # not started + + +def test_timer_reset_zeroes_elapsed(): + timer = Timer("t", use_barrier=False) + timer.start() + timer.stop() + assert timer.elapsed(reset=False) >= 0.0 + timer.reset() + assert timer.elapsed() == 0.0 + + +def test_timers_dispatch_and_elapsed_dict(): + timers = Timers() + a = timers("a", use_barrier=False) + assert timers("a") is a # same name returns the same Timer + timers.start("a") + timers.stop("a") + result = timers.elapsed("a") + assert set(result.keys()) == {"a"} + assert result["a"] >= 0.0 + + if __name__ == "__main__": pytest.main(["-v", __file__])