Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions .github/workflows/test_integration.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Runs the RF3 integration test suite (pytest marker: `integration`) against a
# real model checkpoint that is downloaded once and then served from the
# Actions cache.
name: integration-tests

on:
  push:
    branches: [main, production]
  pull_request:
  workflow_dispatch:

# The job only needs to read the repository; do not hand the default
# (potentially write-capable) GITHUB_TOKEN scopes to test code.
permissions:
  contents: read

# One run per workflow + ref: a newer push to the same ref cancels the run
# that is still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  rf3-integration:
    name: pytest (rf3 integration)
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"
          cache: pip
          cache-dependency-path: pyproject.toml

      - name: Install package + dev extras
        run: pip install -e ".[dev]"

      # The cache key is a fixed string naming the checkpoint, so the cached
      # file is reused until this key is changed by hand.
      - name: Cache RF3 checkpoint
        id: cache-ckpt
        uses: actions/cache@v4
        with:
          path: ~/.cache/rf3_checkpoints
          key: rf3-ckpt-rf3_foundry_01_24_latest_remapped

      # Only download on a cache miss. Fetch over HTTPS: the checkpoint is
      # stored in the cache and loaded by the tests, so a tampered download
      # over plain HTTP would persist across runs.
      - name: Download RF3 checkpoint
        if: steps.cache-ckpt.outputs.cache-hit != 'true'
        run: |
          mkdir -p ~/.cache/rf3_checkpoints
          wget -q \
            -O ~/.cache/rf3_checkpoints/rf3_foundry_01_24_latest_remapped.ckpt \
            https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt

      # RF3_CKPT_PATH is set inline (not via `env:`) so that $HOME is expanded
      # by the shell.
      - name: Run integration tests
        run: |
          RF3_CKPT_PATH="$HOME/.cache/rf3_checkpoints/rf3_foundry_01_24_latest_remapped.ckpt" \
          pytest models/rf3/tests/integration/ \
            -m integration \
            -v \
            --tb=short
28 changes: 15 additions & 13 deletions models/rf3/src/rf3/inference_engines/rf3.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,22 +572,24 @@ def run(
# Handle early stopping
if network_output.get("early_stopped", False):
ranked_logger.warning(
f"Early stopping triggered for {input_spec.example_id} "
f"with mean pLDDT {network_output['mean_plddt']:.2f} < "
f"{self.early_stopping_plddt_threshold:.2f}!"
f"Early stopping triggered for {input_spec.example_id}: "
f"mean pLDDT {network_output['mean_plddt']:.2f} is below threshold "
f"{self.early_stopping_plddt_threshold:.2f}. "
f"No structure will be written for this input."
)

if out_dir:
# Save early stop info to disk
dict_to_save = {
k: v for k, v in network_output.items() if v is not None
}
df_to_save = pd.DataFrame([dict_to_save])
df_to_save.to_csv(example_out_dir / "score.csv", index=False)

df_to_save = pd.DataFrame([metrics_output])
df_to_save.to_csv(
example_out_dir / f"{input_spec.example_id}_metrics.csv",
# Write a ranking_scores.csv that records the early-stop outcome so
# downstream tooling (and skip_existing) always finds a consistent file.
pd.DataFrame(
[
{
"early_stopped": True,
"mean_plddt": network_output.get("mean_plddt"),
}
]
).to_csv(
example_out_dir / f"{input_spec.example_id}_ranking_scores.csv",
index=False,
)
else:
Expand Down
12 changes: 8 additions & 4 deletions models/rf3/src/rf3/model/RF3.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,14 @@ def forward(
return result | early_stop_data

# (We use `deque` with maxlen=1 to ensure that we only keep the last output in memory)
try:
recycling_outputs = deque(recycling_output_generator, maxlen=1).pop()
except IndexError:
# Handle the case where the generator is empty
remaining = deque(recycling_output_generator, maxlen=1)
if remaining:
recycling_outputs = remaining.pop()
elif should_early_stop_fn:
# n_recycles=1: the single recycle was already consumed by next() for
# the early-stop check; reuse it as the final recycling output.
recycling_outputs = first_recycle_outputs
else:
raise RuntimeError("Recycling generator produced no outputs")

# Predict the distogram from the pair representation
Expand Down
2 changes: 1 addition & 1 deletion models/rf3/src/rf3/utils/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def example_exists(example_id: str) -> bool:
example_dir = get_sharded_output_path(
example_id, existing_outputs_dir, sharding_pattern
)
return (example_dir / f"{example_id}_metrics.csv").exists()
return (example_dir / f"{example_id}_ranking_scores.csv").exists()

inference_inputs = []

Expand Down
Loading
Loading