diff --git a/models/rf3/configs/datasets/pdb_and_distillation.yaml b/models/rf3/configs/datasets/pdb_and_distillation.yaml index 505513cc..d85c46e3 100644 --- a/models/rf3/configs/datasets/pdb_and_distillation.yaml +++ b/models/rf3/configs/datasets/pdb_and_distillation.yaml @@ -9,7 +9,6 @@ defaults: - monomer_distillation - na_complex_distillation - disorder_distillation - # - domain_distillation # - rna_monomer_distillation - val/af3_validation@val.af3_validation - val/af3_validation@val.quick_af3_validation_with_templating @@ -28,8 +27,6 @@ train: probability: 0.02 disorder_distillation: probability: 0.02 - # multidomain_distillation: - # probability: 0.06 # rna_monomer_distillation: # probability: 0.04 diff --git a/models/rf3/configs/datasets/train/domain_distillation.yaml b/models/rf3/configs/datasets/train/domain_distillation.yaml deleted file mode 100644 index 73924aa6..00000000 --- a/models/rf3/configs/datasets/train/domain_distillation.yaml +++ /dev/null @@ -1,50 +0,0 @@ -# TODO: Inherit from common config with default Transform pipeline - -multidomain_distillation: - dataset: - _target_: rf3.data.paired_msa.MultiInputDatasetWrapper - save_failed_examples_to_dir: null - - # cif parser - cif_parser_args: - #assume_residues_all_resolved: true - cache_dir: null - load_from_cache: false - save_to_cache: false - - # metadata parser - dataset_parser: - _target_: rf3.data.paired_msa.MultidomainDFParser - - # metadata dataset - dataset: - _target_: atomworks.ml.datasets.PandasDataset - name: multidomain_distillation - id_column: example_id - data: /projects/ml/datahub/dfs/domain_domain/domain_domain_dataset.DIGS.parquet - columns_to_load: - - example_id - - pdb_path - - msa_path - transform: - _target_: ${datasets.pipeline_target} - is_inference: False - input_contains_explicit_msa: True - protein_msa_dirs: [] - rna_msa_dirs: [] - n_recycles: ${datasets.n_recycles_train} - crop_size: ${datasets.crop_size} - n_msa: ${datasets.n_msa} - diffusion_batch_size: ${datasets.diffusion_batch_size_train} - max_atoms_in_crop: ${datasets.max_atoms_in_crop} - crop_contiguous_probability: 0.25 - crop_spatial_probability: 0.75 - run_confidence_head: ${datasets.run_confidence_head} - take_first_chiral_subordering: ${datasets.take_first_chiral_subordering} - use_element_for_atom_names_of_atomized_tokens: ${datasets.use_element_for_atom_names_of_atomized_tokens} - mirror_prob: 0.0 - atomization_prob: ${datasets.atomization_prob} - ligand_dropout_prob: 0.0 - p_unconditional: ${datasets.p_unconditional} - p_dropout_atom_level_embeddings: ${datasets.p_dropout_atom_level_embeddings} - add_residue_is_paired_feature: ${datasets.add_residue_is_paired_feature} diff --git a/models/rf3/src/rf3/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py deleted file mode 100644 index 8e7b81c8..00000000 --- a/models/rf3/src/rf3/data/paired_msa.py +++ /dev/null @@ -1,217 +0,0 @@ -# mypy: ignore-errors -# -# This module does not type-check (and does not even import) against the installed -# atomworks: `MultiInputDatasetWrapper` below subclasses -# `atomworks.ml.datasets.StructuralDatasetWrapper`, which atomworks turned into a -# deprecated factory *function* — subclassing it raises `TypeError` at import time. -# Making it type-check requires a real refactor onto the `PandasDataset` API, validated -# on cluster data (see `.ai/roadmap.md`), not type annotations. The suppression lives -# here, in the file, rather than in `pyproject.toml` so it is visible to anyone reviving -# the module: when this file imports and type-checks cleanly again, delete this directive -# to restore mypy coverage (the module stays inside mypy's `files` scope). -import os -import socket -import time -from pathlib import Path -from typing import Any - -import numpy as np -from atomworks.common import exists -from atomworks.enums import ChainType -from atomworks.ml.datasets import StructuralDatasetWrapper, logger -from atomworks.ml.datasets.parsers import ( - MetadataRowParser, - load_example_from_metadata_row, -) -from atomworks.ml.transforms._checks import ( - check_contains_keys, - check_is_instance, - check_nonzero_length, -) -from atomworks.ml.transforms.base import Transform, TransformedDict -from atomworks.ml.transforms.msa._msa_loading_utils import load_msa_data_from_path -from atomworks.ml.utils.rng import capture_rng_states -from biotite.structure import AtomArray, concatenate - - -# input data wrapper that allows multiple input files separated by ':' -# data is loaded as concatentation of all inputs -class MultiInputDatasetWrapper(StructuralDatasetWrapper): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def __getitem__(self, idx: int) -> Any: - # Capture example ID & current rng state (for reproducibility & debugging) - if hasattr(self, "idx_to_id"): - # ...if the dataset has a custom idx_to_id method, use it (e.g., for a PandasDataset) - example_id = self.idx_to_id(idx) - else: - # ...otherwise, fallback to a the `id_column` or a string representation of the index - example_id = ( - self.dataset[idx][self.id_column] if self.id_column else f"row_{idx}" - ) - - # Get process id and hostname (for debugging) - logger.debug( - f"({socket.gethostname()}:{os.getpid()}) Processing example ID: {example_id}" - ) - - # Load the row, using the __getitem__ method of the dataset - row = self.dataset[idx] - pdb_path = row["pdb_path"].split(":") - - # Process the row into a transform-ready dictionary with the given CIF and dataset parsers - # We require the "data" dictionary output from `load_example_from_metadata_row` to contain, at a minimum: - # (a) An "id" key, which uniquely identifies the example within the dataframe; and, - # (b) The "path" key, which is the path to the CIF file - _start_parse_time = time.time() - data = None - assert len(pdb_path) <= 2 - - for pdb_i in pdb_path: - row_i = {"example_id": row["example_id"], "path": pdb_i} - data_i = load_example_from_metadata_row( - row_i, self.dataset_parser, cif_parser_args=self.cif_parser_args - ) - - if data is None: - data = data_i - else: - data_i["atom_array"].pn_unit_id = np.full( - len(data_i["atom_array"]), "B_1" - ) # unique pn unit id - data_i["atom_array"].pn_unit_iid = np.full( - len(data_i["atom_array"]), "B_1" - ) # unique pn unit iid - data_i["atom_array"].chain_id = np.full( - len(data_i["atom_array"]), "B" - ) # unique chain id - data_i["atom_array"].chain_iid = np.full( - len(data_i["atom_array"]), "B" - ) # unique chain iid - data["atom_array"] = concatenate( - [data["atom_array"], data_i["atom_array"]] - ) - data["atom_array_stack"] = concatenate( - [data["atom_array_stack"], data_i["atom_array_stack"]] - ) - data["chain_info"]["B"] = data_i["chain_info"]["A"] - - # 'example_id', 'path', 'assembly_id', 'query_pn_unit_iids', - data["path"] = row["pdb_path"] - data["msa_path"] = Path(row["msa_path"]) # save msa - _stop_parse_time = time.time() - - # Manually add timing for cif-parsing - data = TransformedDict(data) - data.__transform_history__.append( - dict( - name="load_example_from_metadata_row", - instance=hex(id(load_example_from_metadata_row)), - start_time=_start_parse_time, - end_time=_stop_parse_time, - processing_time=_stop_parse_time - _start_parse_time, - ) - ) - - # Apply the transformation pipeline to the data - if exists(self.transform): - try: - rng_state_dict = capture_rng_states(include_cuda=False) - data = self.transform(data) - except KeyboardInterrupt as e: - raise e - except Exception as e: - # Log the error and save the failed example to disk (optional) - logger.info(f"Error processing row {idx} ({example_id}): {e}") - - if exists(self.save_failed_examples_to_dir): - save_failed_example_to_disk( - example_id=example_id, - error_msg=e, - rng_state_dict=rng_state_dict, - data={}, # We do not save the data, since it may be large. - fail_dir=self.save_failed_examples_to_dir, - ) - raise e - - return data - - -class MultidomainDFParser(MetadataRowParser): - """Parser for Qian's multidomain data""" - - def __init__( - self, - example_id_colname: str = "example_id", - path_colname: str = "path", - ): - self.example_id_colname = example_id_colname - self.path_colname = path_colname - - def _parse(self, row: dict) -> dict[str, Any]: - query_pn_unit_iids = None - assembly_id = "1" - - return { - "example_id": row[self.example_id_colname], - "path": Path(row[self.path_colname]), - "assembly_id": assembly_id, - "query_pn_unit_iids": query_pn_unit_iids, - "extra_info": row, - } - - -class LoadPairedMSAs(Transform): - """ - LoadPairedMSAs adds paired MSAs from disk, overwriting previously paired MSA data. - """ - - def check_input(self, data: dict[str, Any]): - check_contains_keys(data, ["atom_array", "msa_path"]) - check_is_instance(data, "atom_array", AtomArray) - check_nonzero_length(data, "atom_array") - - def forward(self, data: dict[str, Any]) -> dict[str, Any]: - atom_array = data["atom_array"] - msa_file_path = data["msa_path"] - chain_type = data["chain_info"]["A"]["chain_type"] - max_msa_sequences = 10000 - - msa_data = load_msa_data_from_path( - msa_file_path=msa_file_path, - chain_type=chain_type, - max_msa_sequences=max_msa_sequences, - ) - - # split into chains - start_idx = 0 - allpolymerchains = np.unique( - atom_array.chain_id[ - np.isin(atom_array.chain_type, ChainType.get_polymers()) - ] - ) - - data["polymer_msas_by_chain_id"] = {} # nuke old version - for chain_id in allpolymerchains: - sequence = data["chain_info"][chain_id][ - "processed_entity_non_canonical_sequence" - ] - stop_idx = start_idx + len(sequence) - - data["polymer_msas_by_chain_id"][chain_id] = {} - - # trim all msa info to this chain only - for mkey in msa_data.keys(): - data["polymer_msas_by_chain_id"][chain_id][mkey] = msa_data[mkey][ - ..., start_idx:stop_idx - ] - - # mock msa_is_padded_mask (all 0s) - data["polymer_msas_by_chain_id"][chain_id]["msa_is_padded_mask"] = np.zeros( - data["polymer_msas_by_chain_id"][chain_id]["msa"].shape, dtype=bool - ) - - start_idx = stop_idx - - return data diff --git a/pyproject.toml b/pyproject.toml index e9ff2c44..072c846a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -258,10 +258,7 @@ module = [ ignore_errors = true # NOTE: the rf3 enablement ratchet (0014) is fully cleared — there is no rf3 module-level -# mypy exemption here. The one module that cannot type-check, `rf3.data.paired_msa` -# (broken against the installed atomworks; needs a `PandasDataset`-API refactor), carries -# a file-level `# mypy: ignore-errors` directive in the module itself, so the suppression -# is visible where the code is and the module stays inside mypy's `files` scope. +# mypy exemption here, and every rf3 module type-checks with no in-file suppressions. # Per-module strictness ratchet (direction (b)). The global baseline above leaves # disallow_untyped_defs / check_untyped_defs off; fully-annotated modules opt into strict @@ -274,6 +271,33 @@ module = [ "foundry.utils.alignment", "foundry.utils.weights", "foundry.utils.rotation_augmentation", + "foundry.utils.instantiators", + "foundry.utils.ddp", + "foundry.utils.squashfs", + "foundry.utils.logging", + "foundry.utils.datasets", + "foundry.training.schedulers", + "foundry.training.EMA", + "foundry.training.checkpoint", + "foundry.common", + "foundry.constants", + "foundry.metrics.metric", + "foundry.metrics.losses", + "foundry.callbacks.callback", + "foundry.callbacks.timing_logging", + "foundry.callbacks.metrics_logging", + "foundry.callbacks.train_logging", + "foundry.callbacks.health_logging", + "foundry.inference_engines.base", + "foundry.inference_engines.checkpoint_registry", + "foundry.hydra.resolvers", + "foundry.model.layers.blocks", + "foundry.utils.xpu.single_xpu_strategy", + "foundry.utils.xpu.xpu_accelerator", + "foundry.utils.xpu.xpu_precision", + "foundry.testing.fixtures", + "foundry.testing.pytest_hooks", + "foundry_cli.download_checkpoints", ] disallow_untyped_defs = true check_untyped_defs = true diff --git a/src/foundry/callbacks/callback.py b/src/foundry/callbacks/callback.py index ded81735..e2923f29 100755 --- a/src/foundry/callbacks/callback.py +++ b/src/foundry/callbacks/callback.py @@ -25,42 +25,46 @@ class BaseCallback(ABC): """ # Epoch loops - def on_fit_start(self, trainer: Any): + def on_fit_start(self, trainer: Any) -> None: """Called at the start of the training""" pass - def on_fit_end(self, trainer: Any): + def on_fit_end(self, trainer: Any) -> None: """Called at the end of the training""" pass # Training loop - def on_train_epoch_start(self, trainer: Any): + def on_train_epoch_start(self, trainer: Any) -> None: """Called at the start of each training epoch""" pass - def on_after_train_loader_iter(self, trainer: Any, **kwargs): + def on_after_train_loader_iter(self, trainer: Any, **kwargs: Any) -> None: """Called after 'iter(train_loader)' is called, but before the first batch is yielded""" pass - def on_before_train_loader_next(self, trainer: Any, **kwargs): + def on_before_train_loader_next(self, trainer: Any, **kwargs: Any) -> None: """Called after each batch is yielded from the train_loader 'next(train_iter)' call""" pass - def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int): + def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int) -> None: """Called at the start of each training batch""" pass def on_train_batch_end( self, trainer: Any, outputs: Any, batch: Any, batch_idx: int - ): + ) -> None: """Called after each training batch, but before the optimizer.step""" pass - def on_before_optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer): + def on_before_optimizer_step( + self, trainer: Any, optimizer: _FabricOptimizer + ) -> None: """Called before each optimizer.step""" pass - def on_after_optimizer_step(self, optimizer: _FabricOptimizer, **kwargs): + def on_after_optimizer_step( + self, optimizer: _FabricOptimizer, **kwargs: Any + ) -> None: """Called after each optimizer.step. NOTE: This hook is called internally by Fabric when optimizer.step() executes. @@ -68,18 +72,18 @@ def on_after_optimizer_step(self, optimizer: _FabricOptimizer, **kwargs): """ pass - def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer): + def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer) -> None: """Called after optimizer.step completes. Unlike on_after_optimizer_step, this hook is called explicitly by the trainer and receives trainer access. """ pass - def on_train_epoch_end(self, trainer: Any): + def on_train_epoch_end(self, trainer: Any) -> None: """Called at the end of each training epoch""" pass # Validation loop - def on_validation_epoch_start(self, trainer: Any): + def on_validation_epoch_start(self, trainer: Any) -> None: """Called at the start of each validation epoch""" pass @@ -90,7 +94,7 @@ def on_validation_batch_start( batch_idx: int, num_batches: int, dataset_name: str | None = None, - ): + ) -> None: """Called at the start of each validation batch""" pass @@ -102,15 +106,15 @@ def on_validation_batch_end( batch_idx: int, num_batches: int, dataset_name: str | None = None, - ): + ) -> None: """Called after each validation batch""" pass - def on_validation_epoch_end(self, trainer: Any): + def on_validation_epoch_end(self, trainer: Any) -> None: """Called at the end of each validation epoch""" pass # Saving and Loading - def on_save_checkpoint(self, trainer: Any, state: dict[str, Any]): + def on_save_checkpoint(self, trainer: Any, state: dict[str, Any]) -> None: """Called when saving a checkpoint""" pass diff --git a/src/foundry/callbacks/health_logging.py b/src/foundry/callbacks/health_logging.py index d9881599..36ef5575 100644 --- a/src/foundry/callbacks/health_logging.py +++ b/src/foundry/callbacks/health_logging.py @@ -93,7 +93,7 @@ def __init__( self.log_histograms = {} @rank_zero_only - def on_fit_start(self, trainer): + def on_fit_start(self, trainer: Any) -> None: """Initialize the callback and register activation hooks.""" # Check that we either have loggers attached or keep_cache is True, otherwise the # data will be computed but not logged. @@ -104,7 +104,7 @@ def on_fit_start(self, trainer): ) @rank_zero_only - def on_train_batch_start(self, trainer, batch: Any, batch_idx: int): + def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int) -> None: step = trainer.state["global_step"] model = trainer.state["model"] if (self.log_activations or "activations" in self.log_histograms) and ( @@ -113,7 +113,9 @@ def on_train_batch_start(self, trainer, batch: Any, batch_idx: int): self._register_activation_hooks(model, step) @rank_zero_only - def on_before_optimizer_step(self, trainer, optimizer: _FabricOptimizer, **kwargs): + def on_before_optimizer_step( + self, trainer: Any, optimizer: _FabricOptimizer, **kwargs: Any + ) -> None: """Log gradients, weights, and activations before optimizer step.""" step = trainer.state["global_step"] @@ -143,25 +145,27 @@ def on_before_optimizer_step(self, trainer, optimizer: _FabricOptimizer, **kwarg if key.endswith("hist"): self._cache[key].append(value) - def on_train_batch_end(self, trainer, **kwargs): + # Fabric dispatches hooks by keyword, so absorbing the base's positional + # outputs/batch/batch_idx via **kwargs is intentional; see TimingCallback. + def on_train_batch_end(self, trainer: Any, **kwargs: Any) -> None: # type: ignore[override] """Called at the end of a training batch - clear temporary cache.""" self._temp_cache["scalars"].clear() self._temp_cache["histograms"].clear() self._remove_activation_hooks() - def on_fit_end(self, trainer, **kwargs): + def on_fit_end(self, trainer: Any, **kwargs: Any) -> None: """Clean up activation hooks at the end of training.""" self._remove_activation_hooks() - def on_validation_epoch_start(self, trainer): + def on_validation_epoch_start(self, trainer: Any) -> None: # Temporarily remove any hooks for validation self._remove_activation_hooks() @rank_zero_only - def on_save_checkpoint(self, trainer, state: dict[str, Any]): + def on_save_checkpoint(self, trainer: Any, state: dict[str, Any]) -> None: self._remove_activation_hooks() - def _collect_parameter_stats(self, model, step: int): + def _collect_parameter_stats(self, model: nn.Module, step: int) -> None: """Collect gradient and weight statistics in a single parameter iteration.""" cache = self._temp_cache # alias @@ -213,12 +217,12 @@ def _should_track_activation(self, name: str, module_type: type[nn.Module]) -> b return True return self.filter_activations(name, module_type) - def _register_activation_hooks(self, model, step: int): + def _register_activation_hooks(self, model: nn.Module, step: int) -> None: """Register forward hooks to accumulate activations.""" cache = self._temp_cache # alias - def create_activation_hook(name): - def hook(module, input, output): + def create_activation_hook(name: str) -> Callable[..., None]: + def hook(module: nn.Module, input: Any, output: Any) -> None: if isinstance(output, torch.Tensor) and (step % self.log_freq == 0): output = output.detach() for stat_name, stat_fn in self.log_activations.items(): @@ -238,13 +242,13 @@ def hook(module, input, output): hook = module.register_forward_hook(create_activation_hook(name)) self._hooks.append(hook) - def _remove_activation_hooks(self): + def _remove_activation_hooks(self) -> None: """Remove activation hooks.""" for hook in self._hooks: hook.remove() self._hooks.clear() - def __del__(self): + def __del__(self) -> None: self._remove_activation_hooks() del self._temp_cache del self._cache @@ -326,7 +330,7 @@ def plot_tensor_stats( norm: Float[Tensor, "N"] | None = None, name: str = "", height_ratios: tuple[float, float] = (5, 1), -): +) -> plt.Figure: """ Plot comprehensive statistics with mean/std/min/max in top panel and norm in bottom panel. diff --git a/src/foundry/callbacks/metrics_logging.py b/src/foundry/callbacks/metrics_logging.py index 63a6b66b..94ad7759 100644 --- a/src/foundry/callbacks/metrics_logging.py +++ b/src/foundry/callbacks/metrics_logging.py @@ -24,7 +24,7 @@ def __init__( self.save_dir = Path(save_dir) self.metrics_to_save = metrics_to_save - def _save_dataframe_for_rank(self, rank: int, epoch: int): + def _save_dataframe_for_rank(self, rank: int, epoch: int) -> None: """Saves per-GPU output dataframe of metrics to a rank-specific CSV.""" self.save_dir.mkdir(parents=True, exist_ok=True) file_path = self.save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv" @@ -39,18 +39,18 @@ def _save_dataframe_for_rank(self, rank: int, epoch: int): f"Saved validation outputs to {file_path} for rank {rank}, epoch {epoch}" ) - def on_validation_epoch_start(self, trainer): + def on_validation_epoch_start(self, trainer: Any) -> None: self.per_gpu_outputs_df = pd.DataFrame() def on_validation_batch_end( self, - trainer, + trainer: Any, outputs: dict, batch: Any, batch_idx: int, num_batches: int, dataset_name: str | None = None, - ): + ) -> None: """Build a flattened DataFrame from the metrics output and accumulate with the prior batches""" assert "metrics_output" in outputs, "Validation outputs must contain metrics." metrics_output = deepcopy(outputs["metrics_output"]) @@ -71,7 +71,7 @@ def on_validation_batch_end( def _build_row_from_flattened_dict( dict_to_flatten: dict, prefix: str, example_id: str - ): + ) -> dict[str, Any]: """Helper function to build a DataFrame row""" flattened_dict = nested_dict.flatten(dict_to_flatten, fuse_keys=".") row_data = {"example_id": example_id} @@ -121,7 +121,7 @@ def _build_row_from_flattened_dict( f"Validation Progress: {100 * (batch_idx + 1) / num_batches:.0f}% for {dataset_name}" ) - def on_validation_epoch_end(self, trainer): + def on_validation_epoch_end(self, trainer: Any) -> None: """Aggregate and log the validation metrics at the end of the epoch. Each rank writes out its partial CSV. Then rank 0 aggregates them, logs grouped metrics by dataset, @@ -201,7 +201,7 @@ def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame: # Concatenate dataframes, filling missing columns with NaN return pd.concat(final_dataframes, axis=0, ignore_index=True, sort=False) - def _cleanup_temp_files(self): + def _cleanup_temp_files(self) -> None: """Remove temporary files used to store individual rank outputs.""" all_files = list(self.save_dir.rglob("validation_output_rank_*_epoch_*.csv")) for file in all_files: diff --git a/src/foundry/callbacks/timing_logging.py b/src/foundry/callbacks/timing_logging.py index 62f7b363..a329f7a7 100644 --- a/src/foundry/callbacks/timing_logging.py +++ b/src/foundry/callbacks/timing_logging.py @@ -1,5 +1,9 @@ import pandas as pd +from beartype.typing import Any from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.wrappers import ( + _FabricOptimizer, +) from foundry.callbacks.callback import BaseCallback from foundry.utils.logging import print_df_as_table @@ -7,46 +11,57 @@ class TimingCallback(BaseCallback): - """Fabric callback to print timing metrics.""" + """Fabric callback to print timing metrics. - def __init__(self, log_every_n: int = 100): + The hooks that the base declares with explicit positional params + (``on_train_batch_start``/``on_train_batch_end``/``on_before_optimizer_step``) + are overridden here with ``**kwargs`` because Fabric always dispatches hooks by + keyword (``fabric.call(name, trainer=..., batch=..., ...)``), so the unused + arguments are simply absorbed. mypy flags the narrower signature as an + incompatible override; the ``# type: ignore[override]`` documents that this is + intentional and safe given the keyword-only dispatch. + """ + + def __init__(self, log_every_n: int = 100) -> None: super().__init__() self.log_every_n = log_every_n self.timers = Timers() self.n_steps_since_last_log = 0 @rank_zero_only - def on_train_epoch_start(self, trainer, **kwargs): + def on_train_epoch_start(self, trainer: Any, **kwargs: Any) -> None: self.timers.start("train_loader_iter") @rank_zero_only - def on_after_train_loader_iter(self, trainer, **kwargs): + def on_after_train_loader_iter(self, trainer: Any, **kwargs: Any) -> None: self.timers.stop("train_loader_iter") @rank_zero_only - def on_before_train_loader_next(self, trainer, **kwargs): + def on_before_train_loader_next(self, trainer: Any, **kwargs: Any) -> None: self.timers.start("train_step", "train_loader_next") @rank_zero_only - def on_train_batch_start(self, trainer, **kwargs): + def on_train_batch_start(self, trainer: Any, **kwargs: Any) -> None: # type: ignore[override] self.timers.start("forward_loss_backward") self.timers.stop("train_loader_next") @rank_zero_only - def on_train_batch_end(self, trainer, **kwargs): + def on_train_batch_end(self, trainer: Any, **kwargs: Any) -> None: # type: ignore[override] self.timers.stop("forward_loss_backward") self.timers.stop("train_step") @rank_zero_only - def on_before_optimizer_step(self, trainer, **kwargs): + def on_before_optimizer_step(self, trainer: Any, **kwargs: Any) -> None: # type: ignore[override] self.timers.start("optimizer_step") @rank_zero_only - def on_after_optimizer_step(self, optimizer, **kwargs): + def on_after_optimizer_step( + self, optimizer: _FabricOptimizer, **kwargs: Any + ) -> None: self.timers.stop("optimizer_step") @rank_zero_only - def optimizer_step(self, trainer, optimizer): + def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer) -> None: step = trainer.state["global_step"] self.n_steps_since_last_log += 1 if step % self.log_every_n == 0: @@ -59,7 +74,7 @@ def optimizer_step(self, trainer, optimizer): if trainer.fabric.is_global_zero: self._print_timings(timings) - def _print_timings(self, timings: dict[str, float]): + def _print_timings(self, timings: dict[str, float]) -> None: df = pd.DataFrame(timings.items(), columns=["Step", "Time (s)"]) print_df_as_table( df, title=f"Timing stats (over {self.n_steps_since_last_log} steps)" diff --git a/src/foundry/callbacks/train_logging.py b/src/foundry/callbacks/train_logging.py index 3fbe0128..f4efee47 100755 --- a/src/foundry/callbacks/train_logging.py +++ b/src/foundry/callbacks/train_logging.py @@ -26,7 +26,7 @@ class LogModelParametersCallback(BaseCallback): """Print a table of the total and trainable parameters of the model at the start of training.""" - def on_fit_start(self, trainer): + def on_fit_start(self, trainer: Any) -> None: print_model_parameters(trainer.state["model"]) @@ -39,7 +39,7 @@ class PrintExampleIDBeforeForwardPassCallback(BaseCallback): def __init__(self, rank_zero_only: bool = True): self.logger = RankedLogger(__name__, rank_zero_only=rank_zero_only) - def on_train_batch_start(self, trainer, batch: Any, batch_idx: int): + def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int) -> None: example_id = batch[0]["example_id"] # Prepare the formatted strings with colors @@ -58,17 +58,17 @@ def on_train_batch_start(self, trainer, batch: Any, batch_idx: int): class LogDatasetSamplingRatiosCallback(BaseCallback): """Monitor the sampling ratios of the datasets and log after each epoch.""" - def on_fit_start(self, trainer): - self.dataset_sampling_counts = defaultdict(int) + def on_fit_start(self, trainer: Any) -> None: + self.dataset_sampling_counts: defaultdict[str, int] = defaultdict(int) - def on_train_batch_start(self, trainer, batch, batch_idx): + def on_train_batch_start(self, trainer: Any, batch: Any, batch_idx: int) -> None: example_id = batch[0]["example_id"] if trainer.fabric.is_global_zero: dataset_string = "/".join(parse_example_id(example_id)["datasets"]) self.dataset_sampling_counts[dataset_string] += 1 - def on_train_epoch_end(self, trainer): + def on_train_epoch_end(self, trainer: Any) -> None: if trainer.fabric.is_global_zero: total_samples = sum(self.dataset_sampling_counts.values()) @@ -100,7 +100,7 @@ class LogLearningRateCallback(BaseCallback): def __init__(self, log_every_n: int): self.log_every_n = log_every_n - def optimizer_step(self, trainer, optimizer: _FabricOptimizer): + def optimizer_step(self, trainer: Any, optimizer: _FabricOptimizer) -> None: # Get the current global step current_step = trainer.state["global_step"] @@ -140,17 +140,19 @@ def __init__( self.log_every_n = log_every_n self.log_full_batch_losses = log_full_batch_losses - self.start_time = None + self.start_time: float | None = None self.logger = RankedLogger(__name__, rank_zero_only=True) # This dict will store key -> MeanMetric() for each loss self.loss_trackers: dict[str, MeanMetric] = {} - def on_train_epoch_start(self, trainer): + def on_train_epoch_start(self, trainer: Any) -> None: # Record the start time of the epoch self.start_time = time.time() - def on_train_batch_end(self, trainer, outputs: Any, batch: Any, batch_idx: int): + def on_train_batch_end( + self, trainer: Any, outputs: Any, batch: Any, batch_idx: int + ) -> None: mean_loss_dict = {} if "loss_dict" in outputs: mean_loss_dict.update(mean_losses(outputs["loss_dict"])) @@ -230,12 +232,15 @@ def on_train_batch_end(self, trainer, outputs: Any, batch: Any, batch_idx: int): safe_print(combined_content) - def on_train_epoch_end(self, trainer): + def on_train_epoch_end(self, trainer: Any) -> None: # Gather final epoch means (must be run on all ranks) final_means = { k: tracker.compute().item() for k, tracker in self.loss_trackers.items() } + # on_train_epoch_start always runs first in the training loop, so start_time is set. + assert self.start_time is not None + # Calculate elapsed time and number of batches (from the total_loss tracker, if available) elapsed_time = time.time() - self.start_time num_batches = ( diff --git a/src/foundry/hydra/resolvers.py b/src/foundry/hydra/resolvers.py index 25bfd63e..095d1462 100644 --- a/src/foundry/hydra/resolvers.py +++ b/src/foundry/hydra/resolvers.py @@ -7,7 +7,7 @@ import importlib from atomworks.enums import ChainType, ChainTypeInfo -from beartype.typing import Any +from beartype.typing import Any, Callable from omegaconf import OmegaConf from ..common import run_once @@ -15,8 +15,8 @@ # (Custom resolvers) @run_once -def register_resolvers(): - resolvers = { +def register_resolvers() -> None: + resolvers: dict[str, Callable[..., Any]] = { "resolve_import": resolve_import, "chain_type_info_to_regex": chain_type_info_to_regex, } @@ -48,7 +48,7 @@ def resolve_import(module_path: str, attribute_path: str | None = None) -> Any: return module -def chain_type_info_to_regex(*args) -> Any: +def chain_type_info_to_regex(*args: str) -> Any: """Convert a combination of ChainType or ChainTypeInfo attributes to a regex string. Primarily used for filtering a dataset by chain type prior to training/validation. diff --git a/src/foundry/inference_engines/base.py b/src/foundry/inference_engines/base.py index 81ef213e..574cc837 100644 --- a/src/foundry/inference_engines/base.py +++ b/src/foundry/inference_engines/base.py @@ -2,6 +2,7 @@ import os from os import PathLike from pathlib import Path +from types import TracebackType from typing import Any, Dict import hydra @@ -25,7 +26,7 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True) -def merge(cfg, overrides: dict): +def merge(cfg: Any, overrides: dict) -> Any: return OmegaConf.merge(cfg, OmegaConf.create(overrides)) @@ -41,9 +42,9 @@ def __init__( num_nodes: int = 1, devices_per_node: int = 1, # Config overrides - transform_overrides={}, - inference_sampler_overrides={}, - trainer_overrides={}, + transform_overrides: dict[str, Any] = {}, + inference_sampler_overrides: dict[str, Any] = {}, + trainer_overrides: dict[str, Any] = {}, # Debug verbose: bool = False, seed: int | None = None, @@ -125,7 +126,7 @@ def __init__( # Required subclasss methods ################################################################################### - def initialize(self): + def initialize(self) -> Any: if self.initialized_: return getattr(self, "cfg", None) @@ -149,7 +150,7 @@ def run( inputs: ( Dict[str, dict] | AtomArray | list[AtomArray] | PathLike | list[PathLike] ), - *_, + *_: Any, ) -> dict[str, dict] | None: self.initialize() raise NotImplementedError( @@ -160,12 +161,12 @@ def run( # Util methods ################################################################################### - def _override_checkpoint_config(self, cfg): + def _override_checkpoint_config(self, cfg: Any) -> Any: cfg = merge(cfg, self.overrides) cfg = set_accelerator_based_on_availability(cfg) return cfg - def _construct_trainer(self, cfg, checkpoint=None): + def _construct_trainer(self, cfg: Any, checkpoint: Any = None) -> None: """ Sets attr self.trainer """ @@ -209,7 +210,7 @@ def _assign_override(self, dotted_key: str, value: Any) -> None: target = target[key] target[keys[-1]] = value - def _construct_pipeline(self, cfg): + def _construct_pipeline(self, cfg: Any) -> None: """ Sets attr self.pipeline """ @@ -232,17 +233,22 @@ def _construct_pipeline(self, cfg): self.pipeline = hydra.utils.instantiate(transform) # aliases for run - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: Any) -> dict[str, dict] | None: return self.run(*args, **kwargs) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> dict[str, dict] | None: return self.run(*args, **kwargs) # for use as a context manager: e.g. `with BaseInferenceEngine(...) as engine:` to automatically cleanup - def __enter__(self): + def __enter__(self) -> "BaseInferenceEngine": return self - def __exit__(self, exc_type, exc, tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: self.trainer = None self.pipeline = None self.initialized_ = False diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 3f1af3db..ed2cea21 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -68,7 +68,7 @@ class RegisteredCheckpoint: description: str sha256: None = None # Optional: add checksum for verification - def get_default_path(self): + def get_default_path(self) -> Path: checkpoint_dirs = get_default_checkpoint_dirs() for checkpoint_dir in checkpoint_dirs: candidate = checkpoint_dir / self.filename diff --git a/src/foundry/metrics/losses.py b/src/foundry/metrics/losses.py index 02cc4cc1..dfa5f255 100644 --- a/src/foundry/metrics/losses.py +++ b/src/foundry/metrics/losses.py @@ -1,10 +1,12 @@ import hydra +import torch import torch.nn as nn +from beartype.typing import Any, cast from omegaconf import DictConfig class Loss(nn.Module): - def __init__(self, **losses): + def __init__(self, **losses: Any) -> None: super().__init__() self.to_compute = [] for loss_name, loss in losses.items(): @@ -16,15 +18,20 @@ def __init__(self, **losses): def forward( self, - network_input, - network_output, - loss_input, - ): - loss_dict = {} + network_input: dict[str, Any], + network_output: dict[str, Any], + loss_input: dict[str, Any], + ) -> tuple[torch.Tensor, dict[str, Any]]: + loss_dict: dict[str, Any] = {} + # Start the accumulator as the int 0 (not a 0-d tensor): the first `+=` + # then adopts the device/dtype of the child losses via scalar promotion. + # A `torch.zeros(())` here would sit on the CPU and break GPU training on + # a device mismatch. After the (always non-empty) loop `loss` is a Tensor. loss = 0 for loss_fn in self.to_compute: loss_, loss_dict_ = loss_fn(network_input, network_output, loss_input) loss += loss_ loss_dict.update(loss_dict_) - loss_dict["total_loss"] = loss.detach() - return loss, loss_dict + total_loss = cast(torch.Tensor, loss) + loss_dict["total_loss"] = total_loss.detach() + return total_loss, loss_dict diff --git a/src/foundry/model/layers/blocks.py b/src/foundry/model/layers/blocks.py index a266430c..002bbac0 100644 --- a/src/foundry/model/layers/blocks.py +++ b/src/foundry/model/layers/blocks.py @@ -10,7 +10,7 @@ class FourierEmbedding(nn.Module): w: torch.Tensor b: torch.Tensor - def __init__(self, c): + def __init__(self, c: int) -> None: super().__init__() self.c = c self.register_buffer("w", torch.zeros(c, dtype=torch.float32)) @@ -24,14 +24,14 @@ def reset_parameters(self) -> None: def forward( self, - t, # [D] - ): + t: torch.Tensor, # [D] + ) -> torch.Tensor: return torch.cos(2 * pi * (t[..., None] * self.w + self.b)) class Dropout(nn.Module): # Dropout entire row or column - def __init__(self, broadcast_dim=None, p_drop=0.15): + def __init__(self, broadcast_dim: int | None = None, p_drop: float = 0.15) -> None: super(Dropout, self).__init__() # give ones with probability of 1-p_drop / zeros with p_drop self.sampler = torch.distributions.bernoulli.Bernoulli( @@ -40,7 +40,7 @@ def __init__(self, broadcast_dim=None, p_drop=0.15): self.broadcast_dim = broadcast_dim self.p_drop = p_drop - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: # no drophead during evaluation mode return x shape = list(x.shape) diff --git a/src/foundry/testing/fixtures.py b/src/foundry/testing/fixtures.py index 9e67f282..f8974f2b 100644 --- a/src/foundry/testing/fixtures.py +++ b/src/foundry/testing/fixtures.py @@ -7,7 +7,7 @@ @pytest.fixture(scope="session") -def gpu(): +def gpu() -> bool: """Fixture to check GPU availability for tests that require CUDA or MPS.""" if not torch.cuda.is_available() and not torch.backends.mps.is_available(): pytest.skip("GPU not available") diff --git a/src/foundry/testing/pytest_hooks.py b/src/foundry/testing/pytest_hooks.py index e05854bc..ec4b348c 100644 --- a/src/foundry/testing/pytest_hooks.py +++ b/src/foundry/testing/pytest_hooks.py @@ -1,10 +1,11 @@ """Shared pytest configuration hooks for foundry tests.""" +import pytest import rootutils from dotenv import load_dotenv -def configure_pytest(config, conftest_file: str) -> None: +def configure_pytest(config: pytest.Config, conftest_file: str) -> None: """Configure pytest for foundry tests. Sets up project root and environment variables. """ diff --git a/src/foundry/training/EMA.py b/src/foundry/training/EMA.py index f6fffd06..efcf7936 100644 --- a/src/foundry/training/EMA.py +++ b/src/foundry/training/EMA.py @@ -1,5 +1,6 @@ from collections import OrderedDict from copy import deepcopy +from typing import Any import torch import torch.nn as nn @@ -28,7 +29,7 @@ def __init__(self, model: nn.Module, decay: float): param.detach_() @torch.no_grad() - def update(self): + def update(self) -> None: """Update the shadow model using the weight of the original model and the decay rate.""" if not self.training: raise RuntimeError("EMA update should only be called during training") @@ -59,7 +60,7 @@ def update(self): # ... copy the buffers from the model to the shadow shadow_buffers[name].copy_(buffer) - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: Any) -> Any: """Dynamic dispatch to the correct model (model or shadow).""" if self.training: return self.model(*args, **kwargs) diff --git a/src/foundry/training/checkpoint.py b/src/foundry/training/checkpoint.py index 79458c3d..6b8332c5 100644 --- a/src/foundry/training/checkpoint.py +++ b/src/foundry/training/checkpoint.py @@ -6,11 +6,15 @@ .. _PyTorch Checkpoint Documentation: https://pytorch.org/docs/stable/checkpoint.html """ +from typing import Any, Callable + import torch from torch.utils.checkpoint import checkpoint -def create_custom_forward(module, **kwargs): +def create_custom_forward( + module: Callable[..., Any], **kwargs: Any +) -> Callable[..., Any]: """Create a custom forward function for gradient checkpointing with fixed kwargs. Enables passing keyword arguments to a module when using PyTorch's checkpoint function, @@ -25,13 +29,13 @@ def create_custom_forward(module, **kwargs): with the fixed kwargs to the original module. """ - def custom_forward(*inputs): + def custom_forward(*inputs: Any) -> Any: return module(*inputs, **kwargs) return custom_forward -def activation_checkpointing(function): +def activation_checkpointing(function: Callable[..., Any]) -> Callable[..., Any]: """Decorator to enable gradient checkpointing for a function during training. Args: @@ -51,7 +55,7 @@ def forward(self, x, mask=None): Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions. """ - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if torch.is_grad_enabled(): return checkpoint( create_custom_forward(function, **kwargs), *args, use_reentrant=False diff --git a/src/foundry/utils/datasets.py b/src/foundry/utils/datasets.py index 444a772b..7ebd82d5 100755 --- a/src/foundry/utils/datasets.py +++ b/src/foundry/utils/datasets.py @@ -61,10 +61,12 @@ def wrap_dataset_and_sampler_with_fallbacks( Returns: tuple[Dataset, Sampler]: The wrapped dataset and sampler with fallbacks. """ - # Instantiate a new fallback sampler to avoid scaling issues - if "weights" in sampler_to_fallback_to: - # `.weights` lives on weighted samplers only; the base `Sampler` type doesn't expose it. - fallback_weights = sampler_to_fallback_to.weights # type: ignore[attr-defined] + # Instantiate a new fallback sampler to avoid scaling issues. + # `hasattr`, not `"weights" in sampler`: `in` on a Sampler has no `__contains__`, so it + # iterates the sampler's indices (ints) and never matches the string "weights" — the + # weighted branch was dead and the membership test needlessly drew samples. + if hasattr(sampler_to_fallback_to, "weights"): + fallback_weights = sampler_to_fallback_to.weights else: # torch types `Dataset` without `__len__`, but map-style datasets provide it. fallback_weights = torch.ones(len(dataset_to_fallback_to)) # type: ignore[arg-type] diff --git a/src/foundry/utils/ddp.py b/src/foundry/utils/ddp.py index fef47bd3..0dcbe16f 100644 --- a/src/foundry/utils/ddp.py +++ b/src/foundry/utils/ddp.py @@ -79,7 +79,7 @@ def __init__( self.rank_zero_only = rank_zero_only def log( # type: ignore[override] # deliberately extends LoggerAdapter.log with a `rank` parameter - self, level: int, msg: str, rank: int | None = None, *args, **kwargs + self, level: int, msg: str, rank: int | None = None, *args: Any, **kwargs: Any ) -> None: """ Delegate a log call to the underlying logger, after prefixing its message with the rank diff --git a/src/foundry/utils/logging.py b/src/foundry/utils/logging.py index 95b580e3..e6a8b3bc 100755 --- a/src/foundry/utils/logging.py +++ b/src/foundry/utils/logging.py @@ -3,7 +3,7 @@ from contextlib import contextmanager import pandas as pd -from beartype.typing import Any +from beartype.typing import Any, Iterator from lightning.fabric.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf from rich.console import Console @@ -20,14 +20,14 @@ class CachedDataFilter(logging.Filter): """Filter to suppress atomworks cached data logging messages.""" - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: # Filter out "Cached data not found" messages if "Cached data not found" in record.getMessage(): return False return True -def silence_warnings(): +def silence_warnings() -> None: """Silence common warnings that appear during foundry execution.""" warnings.filterwarnings( "ignore", message="All-NaN slice encountered", category=RuntimeWarning @@ -67,7 +67,7 @@ def silence_warnings(): @contextmanager -def suppress_warnings(is_inference: bool = False): +def suppress_warnings(is_inference: bool = False) -> Iterator[None]: """Context manager to suppress specific warnings within its scope. Args: @@ -178,7 +178,7 @@ def print_model_parameters(model: nn.Module, name: str = "") -> None: def log_hyperparameters_with_all_loggers( trainer: Any, cfg: dict | DictConfig, model: Any -): +) -> None: """Logs hyperparameters using all loggers in the trainer. Args: @@ -260,7 +260,7 @@ def table_from_df(df: pd.DataFrame, title: str) -> Table: return table -def safe_print(obj: Any, console_width=100, logger: Any | None = None) -> None: +def safe_print(obj: Any, console_width: int = 100, logger: Any | None = None) -> None: """Print a Rich object in a console- and logger-safe manner.""" console = Console(force_terminal=False, color_system=None, width=console_width) diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index 05821c1c..b82a83a6 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -156,7 +156,7 @@ def install( force: bool = typer.Option( False, "--force", "-f", help="Overwrite existing checkpoints" ), -): +) -> None: """Install model checkpoints for foundry. Examples: foundry install all @@ -186,7 +186,7 @@ def install( @app.command(name="list-available") -def list_available(): +def list_available() -> None: """List available model checkpoints.""" console.print("[bold]Available models:[/bold]\n") for name, info in REGISTERED_CHECKPOINTS.items(): @@ -194,7 +194,7 @@ def list_available(): @app.command(name="list-installed") -def list_installed(): +def list_installed() -> None: """List installed checkpoints and their sizes.""" checkpoint_dirs = _resolve_checkpoint_dirs(None) @@ -214,7 +214,7 @@ def list_installed(): raise typer.Exit(0) console.print("[bold]Installed checkpoints:[/bold]\n") - total_size = 0 + total_size = 0.0 for ckpt, size in sorted(checkpoint_files, key=lambda item: str(item[0])): total_size += size console.print(f" {ckpt} {size:8.2f} GB") @@ -227,7 +227,7 @@ def clean( confirm: bool = typer.Option( True, "--confirm/--no-confirm", help="Ask for confirmation before deleting" ), -): +) -> None: """Remove all downloaded checkpoints.""" checkpoint_dirs = _resolve_checkpoint_dirs(None) diff --git a/tests/test_blocks.py b/tests/test_blocks.py new file mode 100644 index 00000000..6c34d21e --- /dev/null +++ b/tests/test_blocks.py @@ -0,0 +1,62 @@ +"""Unit tests for foundry.model.layers.blocks. + +``FourierEmbedding`` and ``Dropout`` are small CPU nn.Modules. Their +deterministic, reachable behaviours are pinned here: the Fourier features are +cosines (so bounded to [-1, 1]) of the right shape, and ``Dropout`` is an +identity in eval mode, scales surviving entries by ``1/(1-p)``, and (with a +``broadcast_dim``) applies one mask value across that whole dimension — i.e. +drops entire rows/columns rather than individual entries. +""" + +import torch + +from foundry.model.layers.blocks import Dropout, FourierEmbedding + + +def test_fourier_embedding_shape_and_cosine_range(): + embed = FourierEmbedding(c=8) + t = torch.arange(5, dtype=torch.float32) + + out = embed(t) + + assert out.shape == (5, 8) + assert torch.all(out <= 1.0) and torch.all(out >= -1.0) + + +def test_dropout_is_identity_in_eval_mode(): + dropout = Dropout(p_drop=0.5) + dropout.eval() + x = torch.randn(4, 6) + + assert dropout(x) is x + + +def test_dropout_scales_survivors_by_keep_probability(): + # p_drop=0 -> Bernoulli(1.0) always keeps, so output equals input exactly. + dropout = Dropout(p_drop=0.0) + dropout.train() + x = torch.randn(3, 5) + + assert torch.allclose(dropout(x), x) + + +def test_dropout_broadcasts_one_mask_value_across_the_dimension(): + torch.manual_seed(0) + dropout = Dropout(broadcast_dim=1, p_drop=0.5) + dropout.train() + x = torch.ones(2, 3, 4) + + out = dropout(x) + + # Each surviving entry is scaled by 1/(1-0.5) = 2; dropped entries are 0. + assert torch.all((out == 0.0) | (out == 2.0)) + # broadcast_dim=1 means the same mask is applied across that dim: every + # slice along dim 1 is identical. + assert torch.equal(out[:, 0, :], out[:, 1, :]) + assert torch.equal(out[:, 1, :], out[:, 2, :]) + + +if __name__ == "__main__": + import pytest + + pytest.main(["-v", __file__]) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 00000000..acb94593 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,108 @@ +"""Unit tests for the pure logic in foundry.callbacks. + +Most of the callbacks package is side-effecting training/validation glue (Rich +console output, fabric logging, forward-hook registration, matplotlib plotting) +and is intentionally not unit-tested. The one piece with a non-obvious, +environment-independent contract is pinned here: + +``StoreValidationMetricsInDFCallback._load_and_concatenate_csvs`` merges the +per-rank validation CSVs written for one epoch, de-duplicating rows by the +``example_id``/``dataset`` pair (the same example may be validated on more than +one rank) and skipping empty rank files. +""" + +import pandas as pd +import pytest + +from foundry.callbacks.metrics_logging import StoreValidationMetricsInDFCallback + + +def _write_rank_csv(save_dir, rank: int, epoch: int, rows: list[dict]) -> None: + path = save_dir / f"validation_output_rank_{rank}_epoch_{epoch}.csv" + pd.DataFrame(rows, columns=["example_id", "dataset", "lddt"]).to_csv( + path, index=False + ) + + +def _callback(tmp_path) -> StoreValidationMetricsInDFCallback: + return StoreValidationMetricsInDFCallback(save_dir=tmp_path) + + +def test_single_rank_returns_all_rows_without_temp_key(tmp_path): + _write_rank_csv( + tmp_path, + rank=0, + epoch=3, + rows=[ + {"example_id": "e1", "dataset": "d1", "lddt": 0.9}, + {"example_id": "e2", "dataset": "d1", "lddt": 0.8}, + ], + ) + + merged = _callback(tmp_path)._load_and_concatenate_csvs(epoch=3) + + assert sorted(merged["example_id"]) == ["e1", "e2"] + assert "_example_key" not in merged.columns + + +def test_duplicate_example_across_ranks_is_deduplicated(tmp_path): + _write_rank_csv( + tmp_path, + rank=0, + epoch=1, + rows=[ + {"example_id": "e1", "dataset": "d1", "lddt": 0.9}, + {"example_id": "e2", "dataset": "d1", "lddt": 0.8}, + ], + ) + _write_rank_csv( + tmp_path, + rank=1, + epoch=1, + rows=[ + {"example_id": "e2", "dataset": "d1", "lddt": 0.8}, + {"example_id": "e3", "dataset": "d1", "lddt": 0.7}, + ], + ) + + merged = _callback(tmp_path)._load_and_concatenate_csvs(epoch=1) + + # e2 appears on both ranks but is kept once; e1 and e3 once each. + assert sorted(merged["example_id"]) == ["e1", "e2", "e3"] + + +def test_same_example_different_dataset_is_kept(tmp_path): + """De-duplication is keyed on example_id AND dataset, not example_id alone.""" + _write_rank_csv( + tmp_path, + rank=0, + epoch=2, + rows=[ + {"example_id": "e1", "dataset": "d1", "lddt": 0.9}, + {"example_id": "e1", "dataset": "d2", "lddt": 0.5}, + ], + ) + + merged = _callback(tmp_path)._load_and_concatenate_csvs(epoch=2) + + assert len(merged) == 2 + assert sorted(merged["dataset"]) == ["d1", "d2"] + + +def test_empty_rank_csv_is_skipped(tmp_path): + """A rank that validated no examples writes a header-only (empty) CSV.""" + _write_rank_csv( + tmp_path, + rank=0, + epoch=5, + rows=[{"example_id": "e1", "dataset": "d1", "lddt": 0.9}], + ) + _write_rank_csv(tmp_path, rank=1, epoch=5, rows=[]) + + merged = _callback(tmp_path)._load_and_concatenate_csvs(epoch=5) + + assert merged["example_id"].tolist() == ["e1"] + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 00000000..dd666779 --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,72 @@ +"""Unit tests for foundry.training.checkpoint. + +`create_custom_forward` adapts a kwargs-taking callable into the positional-only +shape `torch.utils.checkpoint` requires, by binding the kwargs into a closure. +`activation_checkpointing` decorates a function so it runs through gradient +checkpointing when grad is enabled and calls through directly otherwise. The +tests pin the kwarg binding, both branches, and that gradients still flow (and +match the non-checkpointed result) on the checkpointed path. +""" + +import pytest +import torch + +from foundry.training.checkpoint import activation_checkpointing, create_custom_forward + + +def test_create_custom_forward_binds_fixed_kwargs(): + """Bound kwargs are supplied to the wrapped callable on each call.""" + forward = create_custom_forward(lambda a, b: a + b, b=10) + assert forward(5) == 15 + + +def test_create_custom_forward_forwards_positional_inputs(): + """All positional inputs pass through in order, alongside the fixed kwargs.""" + forward = create_custom_forward(lambda a, b, c: (a, b, c), c=3) + assert forward(1, 2) == (1, 2, 3) + + +def test_activation_checkpointing_no_grad_calls_directly(): + """With grad disabled the decorator just calls the function.""" + + @activation_checkpointing + def double(x: torch.Tensor) -> torch.Tensor: + return x * 2 + + with torch.no_grad(): + out = double(torch.tensor([1.0, 2.0])) + assert torch.allclose(out, torch.tensor([2.0, 4.0])) + + +def test_activation_checkpointing_grad_enabled_matches_and_backprops(): + """The checkpointed path returns the same value and propagates gradients.""" + + def square_sum(x: torch.Tensor) -> torch.Tensor: + return (x**2).sum() + + checkpointed = activation_checkpointing(square_sum) + x = torch.tensor([3.0], requires_grad=True) + out = checkpointed(x) + out.backward() + + assert torch.allclose(out, torch.tensor(9.0)) + assert torch.allclose(x.grad, torch.tensor([6.0])) # d/dx x^2 = 2x + + +def test_activation_checkpointing_forwards_kwargs_through_checkpoint(): + """Keyword arguments reach the function via the checkpointed path.""" + + def scale_sum(x: torch.Tensor, *, scale: float) -> torch.Tensor: + return (x * scale).sum() + + checkpointed = activation_checkpointing(scale_sum) + x = torch.tensor([2.0], requires_grad=True) + out = checkpointed(x, scale=4.0) + out.backward() + + assert torch.allclose(out, torch.tensor(8.0)) + assert torch.allclose(x.grad, torch.tensor([4.0])) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_checkpoint_registry.py b/tests/test_checkpoint_registry.py new file mode 100644 index 00000000..aeadfa15 --- /dev/null +++ b/tests/test_checkpoint_registry.py @@ -0,0 +1,94 @@ +"""Unit tests for the pure path-resolution logic in +foundry.inference_engines.checkpoint_registry. + +The download/dotenv-writing helpers are side-effecting glue. The path search +order is load-bearing (it decides which checkpoint directory wins) and is pinned +here, including the fact that directories from ``FOUNDRY_CHECKPOINT_DIRS`` are +searched *before* the default ``~/.foundry/checkpoints`` directory. +""" + +from pathlib import Path + +import pytest + +from foundry.inference_engines import checkpoint_registry as cr +from foundry.inference_engines.checkpoint_registry import ( + RegisteredCheckpoint, + _normalize_paths, + get_default_checkpoint_dir, + get_default_checkpoint_dirs, +) + + +def test_normalize_paths_absolutizes_and_dedupes_preserving_order(tmp_path): + a, b = tmp_path / "a", tmp_path / "b" + result = _normalize_paths([a, b, a]) + + assert result == [a, b] # duplicate 'a' dropped, order preserved + assert all(p.is_absolute() for p in result) + + +def test_normalize_paths_treats_equivalent_relative_paths_as_one(): + result = _normalize_paths([Path("x"), Path("./x")]) + assert result == [Path("x").absolute()] + + +def _clear_env(monkeypatch): + monkeypatch.delenv("FOUNDRY_CHECKPOINT_DIRS", raising=False) + monkeypatch.delenv("FOUNDRY_CHECKPOINTS_DIR", raising=False) + + +def test_default_dirs_is_just_default_when_env_unset(tmp_path, monkeypatch): + _clear_env(monkeypatch) + default = tmp_path / "default" + monkeypatch.setattr(cr, "DEFAULT_CHECKPOINT_DIR", default) + + assert get_default_checkpoint_dirs() == [default] + + +def test_env_dirs_are_searched_before_the_default(tmp_path, monkeypatch): + default = tmp_path / "default" + monkeypatch.setattr(cr, "DEFAULT_CHECKPOINT_DIR", default) + a, b = tmp_path / "a", tmp_path / "b" + monkeypatch.setenv("FOUNDRY_CHECKPOINT_DIRS", f"{a}:{b}") + + assert get_default_checkpoint_dirs() == [a, b, default] + # The "primary" dir is the first env dir, not the default. + assert get_default_checkpoint_dir() == a + + +def test_legacy_env_var_is_used_when_new_one_is_unset(tmp_path, monkeypatch): + monkeypatch.delenv("FOUNDRY_CHECKPOINT_DIRS", raising=False) + default = tmp_path / "default" + monkeypatch.setattr(cr, "DEFAULT_CHECKPOINT_DIR", default) + legacy = tmp_path / "legacy" + monkeypatch.setenv("FOUNDRY_CHECKPOINTS_DIR", str(legacy)) + + assert get_default_checkpoint_dirs() == [legacy, default] + + +def test_get_default_path_returns_first_existing_file(tmp_path, monkeypatch): + _clear_env(monkeypatch) + default = tmp_path / "default" + monkeypatch.setattr(cr, "DEFAULT_CHECKPOINT_DIR", default) + extra = tmp_path / "extra" + extra.mkdir() + (extra / "model.ckpt").write_text("weights") + monkeypatch.setenv("FOUNDRY_CHECKPOINT_DIRS", str(extra)) + + ckpt = RegisteredCheckpoint(url="u", filename="model.ckpt", description="d") + assert ckpt.get_default_path() == extra / "model.ckpt" + + +def test_get_default_path_falls_back_to_primary_dir_when_missing(tmp_path, monkeypatch): + _clear_env(monkeypatch) + default = tmp_path / "default" + monkeypatch.setattr(cr, "DEFAULT_CHECKPOINT_DIR", default) + + ckpt = RegisteredCheckpoint(url="u", filename="absent.ckpt", description="d") + # Nothing exists anywhere -> the primary (first) dir / filename. + assert ckpt.get_default_path() == default / "absent.ckpt" + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 00000000..4c6f544a --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,113 @@ +"""Unit tests for foundry.common. + +These are small, pervasively-used helpers with no prior coverage. The contracts +worth pinning are the ones where behaviour is easy to get subtly wrong: +`exists`/`default` treat *only* `None` as absent (so `0`/`""`/`[]` are present), +`run_once` fires its wrapped function exactly once per process, `concat_dicts` +zips same-keyed dicts into lists, the `listmap` family materialises a list, and +`ensure_dtype` is a no-op (same object) when the dtype already matches. +""" + +import pytest +import torch + +from foundry.common import ( + at_least_one_exists, + concat_dicts, + default, + do_nothing, + ensure_dtype, + exactly_one_exists, + exists, + listmap, + listmap_with_idx, + run_once, +) + + +def test_run_once_executes_only_once(): + calls = [] + + @run_once + def record() -> str: + calls.append(1) + return "ran" + + assert record() == "ran" + assert record() is None # second call short-circuits + assert record() is None + assert calls == [1] + + +def test_do_nothing_returns_none_for_any_args(): + assert do_nothing() is None + assert do_nothing(1, 2, key="value") is None + + +def test_exists_treats_only_none_as_absent(): + assert exists(None) is False + assert exists(0) is True + assert exists("") is True + assert exists([]) is True + assert exists(False) is True + + +def test_default_falls_back_only_on_none(): + assert default(None, 5) == 5 + assert default(3, 5) == 3 + assert default(0, 5) == 0 # 0 exists, so it is kept + + +def test_exactly_one_exists(): + assert exactly_one_exists(1, None) is True + assert exactly_one_exists(1) is True + assert exactly_one_exists(1, 2) is False + assert exactly_one_exists(None, None) is False + + +def test_at_least_one_exists(): + assert at_least_one_exists(None, 1) is True + assert at_least_one_exists(None, None) is False + assert at_least_one_exists() is False + + +def test_concat_dicts_zips_same_keys_into_lists(): + assert concat_dicts({"a": 1, "b": 2}, {"a": 3, "b": 4}) == { + "a": [1, 3], + "b": [2, 4], + } + + +def test_concat_dicts_single_dict_wraps_values(): + assert concat_dicts({"a": 1, "b": 2}) == {"a": [1], "b": [2]} + + +def test_listmap_applies_and_materialises(): + assert listmap(lambda x: x + 1, [1, 2, 3]) == [2, 3, 4] + assert listmap(str, (i for i in range(3))) == ["0", "1", "2"] # consumes iterables + assert listmap(lambda x: x, []) == [] + + +def test_listmap_with_idx_passes_index_and_value(): + assert listmap_with_idx(lambda i, x: f"{i}_{x}", ["a", "b", "c"]) == [ + "0_a", + "1_b", + "2_c", + ] + + +def test_ensure_dtype_noop_when_already_matching(): + """Matching dtype returns the same tensor object (no copy).""" + t = torch.ones(3, dtype=torch.float32) + assert ensure_dtype(t, torch.float32) is t + + +def test_ensure_dtype_converts_when_mismatched(): + t = torch.ones(3, dtype=torch.float32) + out = ensure_dtype(t, torch.float64) + assert out.dtype == torch.float64 + assert torch.allclose(out, t.double()) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 00000000..4874d4c1 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,300 @@ +"""Unit tests for foundry.utils.datasets. + +These helpers build the training/validation dataloading stack from hydra configs: +selecting a sampler from config keys, wrapping a dataset+sampler with fallback +dataloading, and converting a non-distributed sampler into a distributed one +before assembling the final ``DataLoader``. The contracts worth pinning are the +control flow and the input-validation guards, not the heavy atomworks dataset +machinery, so the fixtures are tiny map-style datasets and stock torch samplers. +""" + +import pandas as pd +import pytest +import torch +from atomworks.ml.datasets import FallbackDatasetWrapper +from atomworks.ml.samplers import ( + DistributedMixedSampler, + FallbackSamplerWrapper, + MixedSampler, +) +from omegaconf import OmegaConf +from torch.utils.data import ( + DataLoader, + Dataset, + RandomSampler, + Sampler, + SequentialSampler, + Subset, + WeightedRandomSampler, +) +from torch.utils.data.distributed import DistributedSampler + +from foundry.utils.datasets import ( + assemble_distributed_loader, + instantiate_single_dataset_and_sampler, + wrap_dataset_and_sampler_with_fallbacks, +) + + +class _Tiny(Dataset): + """Minimal map-style dataset: indexable, with a ``__len__`` and a ``.data`` frame.""" + + def __init__(self, n: int = 4): + self._n = n + self.data = pd.DataFrame({"x": range(n)}) + + def __len__(self) -> int: + return self._n + + def __getitem__(self, i: int) -> int: + return i + + +class _StubSampler(Sampler): + """A sampler with no required init args, to verify the ``sampler`` config branch.""" + + def __iter__(self): + return iter(()) + + def __len__(self) -> int: + return 0 + + +def _weights_from_df(dataset_df: pd.DataFrame) -> torch.Tensor: + """Hydra target for the ``weights`` config key; receives ``dataset_df`` like the real ones.""" + return torch.tensor([0.5, 0.3, 0.2]) + + +# fully-qualified targets for hydra to import (module is already in sys.modules under __name__) +_DATASET_TARGET = f"{__name__}._Tiny" +_WEIGHTS_TARGET = f"{__name__}._weights_from_df" +_SAMPLER_TARGET = f"{__name__}._StubSampler" + + +# -------------------------------------------------------------------------------------- +# wrap_dataset_and_sampler_with_fallbacks +# -------------------------------------------------------------------------------------- + + +def test_wrap_uses_fallback_sampler_weights_when_present(): + """A weighted fallback sampler's own weights are reused (the `hasattr` fix). + + Regression for the latent bug where `"weights" in sampler` iterated the sampler's + integer indices and never matched the string, so a weighted sampler silently fell + back to uniform weights. + """ + dataset, fallback = _Tiny(3), _Tiny(3) + weighted = WeightedRandomSampler( + weights=torch.tensor([0.1, 0.2, 0.7]), num_samples=3, replacement=True + ) + + _, wrapped_sampler = wrap_dataset_and_sampler_with_fallbacks( + dataset, SequentialSampler(dataset), fallback, weighted, n_fallback_retries=2 + ) + + assert wrapped_sampler.fallback_sampler.weights.tolist() == pytest.approx( + [0.1, 0.2, 0.7] + ) + + +def test_wrap_uses_uniform_weights_when_sampler_has_no_weights(): + """A sampler without `.weights` yields uniform weights sized to the fallback dataset.""" + dataset, fallback = _Tiny(3), _Tiny(5) + + _, wrapped_sampler = wrap_dataset_and_sampler_with_fallbacks( + dataset, + SequentialSampler(dataset), + fallback, + SequentialSampler(fallback), + n_fallback_retries=2, + ) + + assert wrapped_sampler.fallback_sampler.weights.tolist() == [1.0] * 5 + + +def test_wrap_returns_fallback_wrapper_types(): + dataset, fallback = _Tiny(3), _Tiny(3) + + wrapped_dataset, wrapped_sampler = wrap_dataset_and_sampler_with_fallbacks( + dataset, + SequentialSampler(dataset), + fallback, + SequentialSampler(fallback), + n_fallback_retries=3, + ) + + assert isinstance(wrapped_dataset, FallbackDatasetWrapper) + assert isinstance(wrapped_sampler, FallbackSamplerWrapper) + assert wrapped_sampler.n_fallback_retries == 3 + + +# -------------------------------------------------------------------------------------- +# instantiate_single_dataset_and_sampler +# -------------------------------------------------------------------------------------- + + +def test_instantiate_weights_only_builds_weighted_sampler(): + """`weights` without `sampler` -> WeightedRandomSampler from the instantiated weights.""" + cfg = OmegaConf.create( + { + "dataset": {"_target_": _DATASET_TARGET, "n": 3}, + "weights": {"_target_": _WEIGHTS_TARGET}, + } + ) + result = instantiate_single_dataset_and_sampler(cfg) + + assert isinstance(result["sampler"], WeightedRandomSampler) + assert result["sampler"].num_samples == 3 + assert result["sampler"].weights.tolist() == pytest.approx([0.5, 0.3, 0.2]) + + +def test_instantiate_sampler_only_uses_provided_sampler(): + """`sampler` without `weights` -> that sampler is instantiated verbatim.""" + cfg = OmegaConf.create( + { + "dataset": {"_target_": _DATASET_TARGET, "n": 3}, + "sampler": {"_target_": _SAMPLER_TARGET}, + } + ) + result = instantiate_single_dataset_and_sampler(cfg) + + assert isinstance(result["sampler"], _StubSampler) + + +def test_instantiate_neither_key_falls_back_to_uniform(): + """Neither `weights` nor `sampler` -> uniform WeightedRandomSampler over the dataset.""" + cfg = OmegaConf.create({"dataset": {"_target_": _DATASET_TARGET, "n": 4}}) + result = instantiate_single_dataset_and_sampler(cfg) + + assert isinstance(result["sampler"], WeightedRandomSampler) + assert result["sampler"].num_samples == 4 + assert result["sampler"].weights.tolist() == [1.0] * 4 + + +def test_instantiate_both_keys_falls_back_to_uniform(): + """Providing BOTH `weights` and `sampler` falls through to uniform weights (not either one).""" + cfg = OmegaConf.create( + { + "dataset": {"_target_": _DATASET_TARGET, "n": 4}, + "weights": {"_target_": _WEIGHTS_TARGET}, + "sampler": {"_target_": _SAMPLER_TARGET}, + } + ) + result = instantiate_single_dataset_and_sampler(cfg) + + assert isinstance(result["sampler"], WeightedRandomSampler) + assert result["sampler"].weights.tolist() == [1.0] * 4 + + +# -------------------------------------------------------------------------------------- +# assemble_distributed_loader +# -------------------------------------------------------------------------------------- + + +@pytest.mark.parametrize("sampler_cls", [RandomSampler, SequentialSampler]) +def test_assemble_random_sequential_requires_rank_world_size(sampler_cls): + dataset = _Tiny(4) + with pytest.raises(AssertionError, match="Rank and world_size must be provided"): + assemble_distributed_loader(dataset, sampler=sampler_cls(dataset)) + + +def test_assemble_converts_sequential_to_distributed_sampler(): + dataset = _Tiny(4) + loader = assemble_distributed_loader( + dataset, sampler=SequentialSampler(dataset), rank=0, world_size=1 + ) + + assert isinstance(loader, DataLoader) + assert isinstance(loader.sampler, DistributedSampler) + + +def test_assemble_mixed_sampler_requires_distributed_args(): + dataset = _Tiny(4) + mixed = MixedSampler( + datasets_info=[ + { + "sampler": SequentialSampler(dataset), + "dataset": dataset, + "probability": 1.0, + } + ], + n_examples_per_epoch=None, + ) + with pytest.raises(AssertionError, match="must be provided for MixedSampler"): + assemble_distributed_loader(dataset, sampler=mixed) + + +def test_assemble_converts_mixed_to_distributed_mixed_sampler(): + dataset = _Tiny(4) + mixed = MixedSampler( + datasets_info=[ + { + "sampler": SequentialSampler(dataset), + "dataset": dataset, + "probability": 1.0, + } + ], + n_examples_per_epoch=None, + ) + loader = assemble_distributed_loader( + dataset, sampler=mixed, rank=0, world_size=1, n_examples_per_epoch=4 + ) + + assert isinstance(loader.sampler, DistributedMixedSampler) + + +def test_assemble_rejects_unknown_sampler_type(): + """A non-distributed sampler that isn't Mixed/Random/Sequential is rejected.""" + dataset = _Tiny(3) + bare = WeightedRandomSampler(weights=torch.ones(3), num_samples=3, replacement=True) + with pytest.raises(AssertionError, match="Invalid sampler type"): + assemble_distributed_loader(dataset, sampler=bare) + + +def test_assemble_rejects_rank_with_already_distributed_sampler(): + dataset = _Tiny(4) + dist = DistributedSampler(dataset, num_replicas=1, rank=0) + with pytest.raises(AssertionError, match="will have no effect"): + assemble_distributed_loader(dataset, sampler=dist, rank=0, world_size=1) + + +def test_assemble_passes_through_distributed_sampler(): + dataset = _Tiny(4) + dist = DistributedSampler(dataset, num_replicas=1, rank=0) + loader = assemble_distributed_loader(dataset, sampler=dist) + + assert loader.sampler is dist + assert loader.dataset is dataset + + +def test_assemble_subset_with_no_sampler(): + """A pre-subset dataset with sampler=None is loaded as-is (no distributed sampler).""" + subset = Subset(_Tiny(4), [0, 1]) + loader = assemble_distributed_loader(subset, sampler=None) + + assert loader.dataset is subset + + +def test_assemble_wraps_with_fallbacks_when_configured(): + dataset = _Tiny(4) + dist = DistributedSampler(dataset, num_replicas=1, rank=0) + loader = assemble_distributed_loader( + dataset, sampler=dist, loader_cfg={"n_fallback_retries": 2} + ) + + assert isinstance(loader.sampler, FallbackSamplerWrapper) + assert isinstance(loader.dataset, FallbackDatasetWrapper) + + +def test_assemble_forwards_dataloader_params(): + dataset = _Tiny(4) + dist = DistributedSampler(dataset, num_replicas=1, rank=0) + loader = assemble_distributed_loader( + dataset, sampler=dist, loader_cfg={"dataloader_params": {"batch_size": 2}} + ) + + assert loader.batch_size == 2 + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_download_checkpoints.py b/tests/test_download_checkpoints.py new file mode 100644 index 00000000..de5a85e7 --- /dev/null +++ b/tests/test_download_checkpoints.py @@ -0,0 +1,80 @@ +"""Unit tests for foundry_cli.download_checkpoints. + +The actual download path is network/file I/O glue and is not tested. The pieces +with a non-obvious, deterministic contract are pinned here: ``_resolve_checkpoint_dirs`` +puts the user-requested directory first (inserting or moving it to the front), +and the ``list-available`` / ``list-installed`` commands report correctly on an +empty vs populated checkpoint directory. +""" + +import pytest +from typer.testing import CliRunner + +from foundry_cli import download_checkpoints as dc + +runner = CliRunner() + + +@pytest.fixture +def no_env_persistence(monkeypatch): + """Stop _resolve_checkpoint_dirs from touching a real .env file.""" + monkeypatch.setattr(dc, "append_checkpoint_to_env", lambda dirs: False) + + +def test_resolve_dirs_returns_defaults_when_none(monkeypatch, tmp_path): + base = [tmp_path / "a", tmp_path / "b"] + monkeypatch.setattr(dc, "get_default_checkpoint_dirs", lambda: list(base)) + + assert dc._resolve_checkpoint_dirs(None) == base + + +def test_resolve_dirs_prepends_a_new_directory( + monkeypatch, tmp_path, no_env_persistence +): + base = [tmp_path / "a", tmp_path / "b"] + monkeypatch.setattr(dc, "get_default_checkpoint_dirs", lambda: list(base)) + extra = (tmp_path / "extra").absolute() + + assert dc._resolve_checkpoint_dirs(extra) == [extra, *base] + + +def test_resolve_dirs_moves_an_existing_directory_to_front( + monkeypatch, tmp_path, no_env_persistence +): + a, b = (tmp_path / "a").absolute(), (tmp_path / "b").absolute() + monkeypatch.setattr(dc, "get_default_checkpoint_dirs", lambda: [a, b]) + + # Requesting 'b' (already present) moves it to the front without duplicating. + assert dc._resolve_checkpoint_dirs(b) == [b, a] + + +def test_list_available_lists_registered_models(): + result = runner.invoke(dc.app, ["list-available"]) + + assert result.exit_code == 0 + assert "Available models" in result.stdout + assert "rf3" in result.stdout + + +def test_list_installed_reports_empty_directory(monkeypatch, tmp_path): + monkeypatch.setattr(dc, "get_default_checkpoint_dirs", lambda: [tmp_path]) + + result = runner.invoke(dc.app, ["list-installed"]) + + assert result.exit_code == 0 + assert "No checkpoint files found" in result.stdout + + +def test_list_installed_totals_populated_directory(monkeypatch, tmp_path): + (tmp_path / "model.ckpt").write_bytes(b"x" * 1024) + monkeypatch.setattr(dc, "get_default_checkpoint_dirs", lambda: [tmp_path]) + + result = runner.invoke(dc.app, ["list-installed"]) + + assert result.exit_code == 0 + assert "No checkpoint files found" not in result.stdout + assert "Total:" in result.stdout + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 00000000..ca4f6d51 --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,113 @@ +"""Unit tests for foundry.training.EMA. + +`EMA` keeps a shadow copy of a model whose parameters track the live model via +the exponential-moving-average update +``shadow -= (1 - decay) * (shadow - param)``. The contracts worth pinning are +numeric and behavioural: the update applies that exact formula, only touches +parameters that require grad, copies buffers verbatim (not EMA'd), refuses to +run outside training mode, and `forward` dispatches to the live model while +training and to the shadow while evaluating. +""" + +import pytest +import torch +import torch.nn as nn + +from foundry.training.EMA import EMA + + +class _TinyModel(nn.Module): + """Minimal module with two parameters and a buffer for exercising EMA.""" + + def __init__(self) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(2, 3)) + self.bias = nn.Parameter(torch.zeros(2)) + self.register_buffer("counter", torch.zeros(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x @ self.weight.t() + self.bias + + +def test_shadow_is_detached_at_init(): + """The shadow starts as a copy whose parameters are off the autograd graph.""" + ema = EMA(_TinyModel(), decay=0.9) + assert all(not p.requires_grad for p in ema.shadow.parameters()) + assert all(p.requires_grad for p in ema.model.parameters()) + + +def test_update_applies_ema_formula(): + """shadow moves toward the live param by exactly (1 - decay) of the gap.""" + model = _TinyModel() + with torch.no_grad(): + model.weight.fill_(1.0) + ema = EMA(model, decay=0.9) # shadow.weight captured at 1.0 + with torch.no_grad(): + model.weight.fill_(2.0) + + ema.train() + ema.update() + + # 1.0 - (1 - 0.9) * (1.0 - 2.0) = 1.0 + 0.1 = 1.1 + assert torch.allclose(ema.shadow.weight, torch.full((2, 3), 1.1)) + + +def test_update_skips_frozen_params(): + """A parameter with requires_grad=False is left untouched by the update.""" + model = _TinyModel() + with torch.no_grad(): + model.weight.fill_(1.0) + model.bias.fill_(1.0) + model.bias.requires_grad_(False) + ema = EMA(model, decay=0.5) + with torch.no_grad(): + model.weight.fill_(3.0) + model.bias.fill_(3.0) + + ema.train() + ema.update() + + # weight is trainable: 1.0 - 0.5 * (1.0 - 3.0) = 2.0 + assert torch.allclose(ema.shadow.weight, torch.full((2, 3), 2.0)) + # bias is frozen: unchanged from its captured value + assert torch.allclose(ema.shadow.bias, torch.full((2,), 1.0)) + + +def test_update_copies_buffers_verbatim(): + """Buffers are copied, not exponentially averaged.""" + model = _TinyModel() + ema = EMA(model, decay=0.5) # shadow.counter captured at 0.0 + with torch.no_grad(): + model.counter.fill_(7.0) + + ema.train() + ema.update() + + # A copy gives 7.0; an EMA with decay 0.5 from 0.0 would give 3.5. + assert torch.allclose(ema.shadow.counter, torch.full((1,), 7.0)) + + +def test_update_raises_outside_training(): + ema = EMA(_TinyModel(), decay=0.9) + ema.eval() + with pytest.raises(RuntimeError, match="during training"): + ema.update() + + +def test_forward_dispatches_model_in_train_shadow_in_eval(): + """Training routes to the live model; evaluation routes to the shadow.""" + model = _TinyModel() + ema = EMA(model, decay=0.9) + with torch.no_grad(): + ema.shadow.bias.fill_(5.0) # make the shadow differ from the model + x = torch.zeros(4, 3) + + ema.train() + assert torch.allclose(ema(x), torch.zeros(4, 2)) # model bias is 0 + + ema.eval() + assert torch.allclose(ema(x), torch.full((4, 2), 5.0)) # shadow bias is 5 + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_instantiators.py b/tests/test_instantiators.py new file mode 100644 index 00000000..59c378fd --- /dev/null +++ b/tests/test_instantiators.py @@ -0,0 +1,81 @@ +"""Unit tests for foundry.utils.instantiators. + +These helpers turn a hydra config group into a list of instantiated objects. +The contract worth pinning is the control flow, not the object type: a missing / +empty config yields an empty list, each sub-config is instantiated via its +``_target_``, and a sub-config that is not an instantiable ``DictConfig`` (no +``_target_`` key) raises ``InstantiationError``. The functions do not themselves +check that the result is a callback / logger, so the tests use a lightweight +stdlib target (``types.SimpleNamespace``) to exercise that flow directly. +""" + +from types import SimpleNamespace + +import pytest +from omegaconf import OmegaConf + +from foundry.utils.instantiators import ( + InstantiationError, + _can_be_instantiated, + instantiate_callbacks, + instantiate_loggers, +) + +_TARGET = "types.SimpleNamespace" + + +def test_can_be_instantiated_true_with_target(): + assert _can_be_instantiated(OmegaConf.create({"_target_": _TARGET})) is True + + +def test_can_be_instantiated_false_without_target(): + assert _can_be_instantiated(OmegaConf.create({"x": 1})) is False + + +def test_can_be_instantiated_false_for_non_dictconfig(): + """A plain dict is not a DictConfig, so it is not instantiable.""" + assert _can_be_instantiated({"_target_": _TARGET}) is False + + +def test_instantiate_callbacks_none_returns_empty(): + assert instantiate_callbacks(None) == [] + + +def test_instantiate_callbacks_empty_config_returns_empty(): + assert instantiate_callbacks(OmegaConf.create({})) == [] + + +def test_instantiate_callbacks_builds_each_target_in_order(): + cfg = OmegaConf.create( + { + "first": {"_target_": _TARGET, "x": 1}, + "second": {"_target_": _TARGET, "x": 2}, + } + ) + result = instantiate_callbacks(cfg) + assert result == [SimpleNamespace(x=1), SimpleNamespace(x=2)] + + +def test_instantiate_callbacks_raises_on_missing_target(): + cfg = OmegaConf.create({"bad": {"x": 1}}) + with pytest.raises(InstantiationError): + instantiate_callbacks(cfg) + + +def test_instantiate_loggers_none_returns_empty(): + assert instantiate_loggers(None) == [] + + +def test_instantiate_loggers_builds_target(): + cfg = OmegaConf.create({"logger": {"_target_": _TARGET, "name": "run"}}) + assert instantiate_loggers(cfg) == [SimpleNamespace(name="run")] + + +def test_instantiate_loggers_raises_on_missing_target(): + cfg = OmegaConf.create({"bad": {"name": "run"}}) + with pytest.raises(InstantiationError): + instantiate_loggers(cfg) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..e81285d4 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,76 @@ +"""Unit tests for the pure helpers in foundry.utils.logging. + +The display/configuration functions in this module are side-effecting glue +(Rich console output, warning filters, logger levels). The two pieces with a +non-obvious, environment-independent contract are pinned here: + +- ``CachedDataFilter`` suppresses a specific atomworks log line by substring. +- ``condense_count_columns_of_grouped_df`` collapses the repeated per-metric + ``count`` columns of a grouped (MultiIndex-column) DataFrame into one ``Count`` + column — but only when the count is identical across metrics in every row, + and only for a MultiIndex frame with both ``count`` and ``mean`` sub-levels. +""" + +import logging + +import pandas as pd +import pytest + +from foundry.utils.logging import ( + CachedDataFilter, + condense_count_columns_of_grouped_df, +) + + +def _record(msg: str) -> logging.LogRecord: + return logging.LogRecord("test", logging.INFO, __file__, 1, msg, None, None) + + +def test_cached_data_filter_suppresses_cached_data_message(): + assert ( + CachedDataFilter().filter(_record("Cached data not found at /tmp/x")) is False + ) + + +def test_cached_data_filter_keeps_unrelated_message(): + assert CachedDataFilter().filter(_record("Loaded 12 structures")) is True + + +def _grouped(rows: list[list[float]]) -> pd.DataFrame: + """Frame with MultiIndex columns (metric, {count,mean}) for two metrics.""" + cols = pd.MultiIndex.from_tuples( + [("a", "count"), ("a", "mean"), ("b", "count"), ("b", "mean")] + ) + return pd.DataFrame(rows, columns=cols) + + +def test_condense_returns_non_multiindex_frame_unchanged(): + df = pd.DataFrame({"x": [1, 2], "y": [3, 4]}) + assert condense_count_columns_of_grouped_df(df) is df + + +def test_condense_collapses_consistent_counts(): + df = _grouped([[5, 1.0, 5, 2.0], [3, 0.5, 3, 1.5]]) + result = condense_count_columns_of_grouped_df(df) + + assert list(result.columns) == ["a (mean)", "b (mean)", "Count"] + assert result["Count"].tolist() == [5, 3] + assert result["a (mean)"].tolist() == [1.0, 0.5] + assert result["b (mean)"].tolist() == [2.0, 1.5] + + +def test_condense_leaves_frame_when_counts_disagree_within_a_row(): + """Row 0's metrics have counts 5 vs 6, so the frame is returned untouched.""" + df = _grouped([[5, 1.0, 6, 2.0]]) + assert condense_count_columns_of_grouped_df(df) is df + + +def test_condense_leaves_frame_without_a_count_sublevel(): + """MultiIndex columns lacking a 'count' level raise KeyError -> returned as-is.""" + cols = pd.MultiIndex.from_tuples([("a", "total"), ("a", "mean")]) + df = pd.DataFrame([[5, 1.0]], columns=cols) + assert condense_count_columns_of_grouped_df(df) is df + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_losses.py b/tests/test_losses.py new file mode 100644 index 00000000..cad39bd3 --- /dev/null +++ b/tests/test_losses.py @@ -0,0 +1,62 @@ +"""Unit tests for foundry.metrics.losses. + +`Loss` aggregates a set of child loss functions: its `forward` sums their scalar +losses, merges their per-loss dicts, and records the (detached) running total +under `total_loss` while still returning the grad-carrying sum. The child losses +are normally Hydra-instantiated; here we set `to_compute` directly with stubs to +exercise the aggregation logic without a config. +""" + +import pytest +import torch + +from foundry.metrics.losses import Loss + + +def _stub_loss(value: float, extra: dict, requires_grad: bool = False): + """A child loss returning a fixed scalar tensor and a fixed loss dict.""" + tensor = torch.tensor(value, requires_grad=requires_grad) + + def loss_fn(network_input, network_output, loss_input): + return tensor, dict(extra) + + return loss_fn + + +def test_empty_loss_has_no_children(): + assert Loss().to_compute == [] + + +def test_forward_sums_children_and_merges_dicts(): + loss = Loss() + loss.to_compute = [ + _stub_loss(1.0, {"a": 10}), + _stub_loss(2.0, {"b": 20}), + ] + + total, loss_dict = loss({}, {}, {}) + + assert torch.allclose(total, torch.tensor(3.0)) + assert loss_dict["a"] == 10 + assert loss_dict["b"] == 20 + assert torch.allclose(loss_dict["total_loss"], torch.tensor(3.0)) + + +def test_forward_total_loss_is_detached_but_returned_loss_keeps_grad(): + loss = Loss() + loss.to_compute = [ + _stub_loss(2.0, {}, requires_grad=True), + _stub_loss(3.0, {}, requires_grad=True), + ] + + total, loss_dict = loss({}, {}, {}) + + # The returned aggregate still carries grad for the backward pass... + assert total.requires_grad + # ...while the logged copy is detached. + assert not loss_dict["total_loss"].requires_grad + assert torch.allclose(loss_dict["total_loss"], torch.tensor(5.0)) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_metric.py b/tests/test_metric.py new file mode 100644 index 00000000..0e24bc92 --- /dev/null +++ b/tests/test_metric.py @@ -0,0 +1,160 @@ +"""Unit tests for foundry.metrics.metric. + +`Metric` + `MetricManager` are the introspection machinery every model's +validation metrics ride on. The contracts worth pinning are non-obvious from +the signatures: `Metric.__init__` rejects required/prohibited tag conflicts; +`compute_from_kwargs` either forwards kwargs verbatim (no mapping) or pulls each +compute argument from a nested-key path, treating `optional_kwargs` as +present-only; and `MetricManager.__call__` extracts `example_id`, prefixes each +result key with the metric name, and skips metrics whose tag requirements the +batch does not satisfy. +""" + +from functools import cached_property + +import pytest + +from foundry.metrics.metric import Metric, MetricManager + + +class _SumMetric(Metric): + """No key mapping: receives kwargs verbatim; absorbs extras via **kwargs.""" + + def compute(self, x, y, **kwargs): + return {"value": x + y} + + +class _MappedMetric(Metric): + """Pulls compute args from nested key paths in the incoming kwargs.""" + + @cached_property + def kwargs_to_compute_args(self): + return {"x": ("a", "b"), "y": ("c",)} + + def compute(self, x, y): + return {"value": x + y} + + +class _OptionalMetric(Metric): + @cached_property + def kwargs_to_compute_args(self): + return {"x": ("x",), "opt": ("opt",)} + + @property + def optional_kwargs(self): + return frozenset(["opt"]) + + def compute(self, x, opt="default"): + return {"x": x, "opt": opt} + + +class _ListMetric(Metric): + def compute(self, **kwargs): + return [{"row": 1}, {"row": 2}] + + +class _BoomMetric(Metric): + def compute(self, **kwargs): + raise ValueError("boom") + + +# --- Metric base ----------------------------------------------------------- + + +def test_tag_conflict_raises(): + with pytest.raises(ValueError, match="disjoint"): + _SumMetric(required_tags_all=["a"], prohibited_tags=["a"]) + + +def test_required_compute_args_read_from_signature(): + assert _MappedMetric().required_compute_args == frozenset({"x", "y"}) + + +# --- compute_from_kwargs --------------------------------------------------- + + +def test_compute_from_kwargs_passes_through_without_mapping(): + assert _SumMetric().compute_from_kwargs(x=1, y=2) == {"value": 3} + + +def test_compute_from_kwargs_remaps_nested_keys(): + result = _MappedMetric().compute_from_kwargs(a={"b": 10}, c=5) + assert result == {"value": 15} + + +def test_compute_from_kwargs_optional_absent_uses_default(): + assert _OptionalMetric().compute_from_kwargs(x=1) == {"x": 1, "opt": "default"} + + +def test_compute_from_kwargs_optional_present_is_passed(): + assert _OptionalMetric().compute_from_kwargs(x=1, opt=99) == {"x": 1, "opt": 99} + + +# --- MetricManager --------------------------------------------------------- + + +def test_manager_prefixes_keys_and_defaults_example_id_to_none(): + manager = MetricManager({"sum": _SumMetric()}) + assert manager(x=1, y=2) == {"example_id": None, "sum.value": 3} + + +def test_manager_extracts_example_id_from_extra_info(): + manager = MetricManager({"sum": _SumMetric()}) + result = manager(x=1, y=2, extra_info={"example_id": "abc"}) + assert result["example_id"] == "abc" + assert result["sum.value"] == 3 + + +def test_manager_required_tags_all_must_all_be_present(): + manager = MetricManager({"sum": _SumMetric(required_tags_all=["needed"])}) + missing = manager(x=1, y=2, extra_info={"metrics_tags": ["other"]}) + assert "sum.value" not in missing + present = manager(x=1, y=2, extra_info={"metrics_tags": ["needed", "other"]}) + assert present["sum.value"] == 3 + + +def test_manager_required_tags_any_needs_one(): + manager = MetricManager({"sum": _SumMetric(required_tags_any=["p", "q"])}) + missing = manager(x=1, y=2, extra_info={"metrics_tags": ["other"]}) + assert "sum.value" not in missing + present = manager(x=1, y=2, extra_info={"metrics_tags": ["q"]}) + assert present["sum.value"] == 3 + + +def test_manager_prohibited_tags_block_computation(): + manager = MetricManager({"sum": _SumMetric(prohibited_tags=["skip"])}) + blocked = manager(x=1, y=2, extra_info={"metrics_tags": ["skip"]}) + assert "sum.value" not in blocked + + +def test_manager_list_result_stored_under_metric_name(): + manager = MetricManager({"rows": _ListMetric()}) + result = manager(anything=1) + assert result["rows"] == [{"row": 1}, {"row": 2}] + + +def test_manager_swallows_failure_when_raise_errors_false(): + manager = MetricManager({"boom": _BoomMetric()}, raise_errors=False) + result = manager(x=1) # the failing metric is skipped, no exception + assert "boom" not in result + assert result == {"example_id": None} + + +def test_manager_propagates_failure_when_raise_errors_true(): + manager = MetricManager({"boom": _BoomMetric()}, raise_errors=True) + with pytest.raises(ValueError, match="boom"): + manager(x=1) + + +def test_from_metrics_accepts_list_of_tuples(): + manager = MetricManager.from_metrics([("sum", _SumMetric())]) + assert manager(x=1, y=2)["sum.value"] == 3 + + +def test_from_metrics_rejects_non_metric(): + with pytest.raises(TypeError, match="must be a Metric"): + MetricManager.from_metrics({"bad": object()}) + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_resolvers.py b/tests/test_resolvers.py new file mode 100644 index 00000000..5c4557f8 --- /dev/null +++ b/tests/test_resolvers.py @@ -0,0 +1,51 @@ +"""Unit tests for the pure Hydra resolver helpers in foundry.hydra.resolvers. + +``register_resolvers`` is one-shot global registration with OmegaConf (side +effecting) and is not tested. The two resolver functions have a non-obvious +contract pinned here: ``resolve_import`` walks a dotted attribute path, and +``chain_type_info_to_regex`` builds an alternation regex from ChainType / +ChainTypeInfo enum members. +""" + +import os + +import pytest +from atomworks.enums import ChainType, ChainTypeInfo + +from foundry.hydra.resolvers import chain_type_info_to_regex, resolve_import + + +def test_resolve_import_returns_the_module_when_no_attribute(): + assert resolve_import("os") is os + + +def test_resolve_import_walks_a_dotted_attribute_path(): + # os.path.join is reached by splitting "path.join" and chaining getattr. + assert resolve_import("os", "path.join") is os.path.join + + +def test_resolve_import_resolves_a_single_attribute(): + assert resolve_import("os", "sep") == os.sep + + +def test_chain_type_info_to_regex_uses_chain_type_value(): + assert chain_type_info_to_regex("DNA") == str(ChainType.DNA.value) + + +def test_chain_type_info_to_regex_expands_a_chain_type_info_group(): + expected = "|".join(str(ct.value) for ct in ChainTypeInfo.PROTEINS) + assert chain_type_info_to_regex("PROTEINS") == expected + + +def test_chain_type_info_to_regex_joins_multiple_args_with_pipe(): + result = chain_type_info_to_regex("DNA", "RNA") + assert result == f"{ChainType.DNA.value}|{ChainType.RNA.value}" + + +def test_chain_type_info_to_regex_rejects_unknown_attribute(): + with pytest.raises(ValueError, match="Attribute not found"): + chain_type_info_to_regex("NOT_A_CHAIN_TYPE") + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py new file mode 100644 index 00000000..c6d66395 --- /dev/null +++ b/tests/test_schedulers.py @@ -0,0 +1,97 @@ +"""Unit tests for foundry.training.schedulers. + +`AF3Scheduler` implements the AF-3 two-phase learning-rate schedule: a linear +warmup from 0 to `base_lr` over `warmup_steps`, then a geometric decay by +`decay_factor` every `decay_steps`. The tests pin those phase boundaries on a +small, exactly-computable schedule. `SchedulerConfig` is a thin Lightning-style +config wrapper whose state-dict round-trip is also pinned. +""" + +import pytest +import torch +from torch.optim import SGD + +from foundry.training.schedulers import AF3Scheduler, SchedulerConfig + +# Small, exact schedule: base_lr=1.0 keeps the expected LR values trivial. +_KW = dict(base_lr=1.0, warmup_steps=10, decay_factor=0.5, decay_steps=20) + + +def _single_param_optimizer() -> SGD: + return SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + + +def _lr_after_steps(n: int) -> float: + """LR reported after advancing a fresh AF3Scheduler `n` times.""" + opt = _single_param_optimizer() + scheduler = AF3Scheduler(opt, **_KW) + for _ in range(n): + opt.step() # documented order: optimizer before scheduler + scheduler.step() + return scheduler.get_last_lr()[0] + + +def test_initial_lr_is_zero(): + """Construction steps once to last_epoch=0, the start of warmup (LR 0).""" + opt = _single_param_optimizer() + scheduler = AF3Scheduler(opt, **_KW) + assert scheduler.get_last_lr()[0] == pytest.approx(0.0) + + +def test_linear_warmup_midpoint(): + # last_epoch=5, warmup_steps=10 -> 1.0 * 5/10 + assert _lr_after_steps(5) == pytest.approx(0.5) + + +def test_lr_reaches_base_at_end_of_warmup(): + # last_epoch=10: warmup is exclusive (10 < 10 is False), so decay branch + # with num_decays=0 -> base_lr * 0.5**0 = 1.0 + assert _lr_after_steps(10) == pytest.approx(1.0) + + +def test_geometric_decay_after_warmup(): + # last_epoch=30 -> num_decays=(30-10)//20=1 -> 1.0 * 0.5 + assert _lr_after_steps(30) == pytest.approx(0.5) + # last_epoch=50 -> num_decays=(50-10)//20=2 -> 1.0 * 0.25 + assert _lr_after_steps(50) == pytest.approx(0.25) + + +def test_all_param_groups_share_one_lr(): + """get_lr emits the same value for every param group.""" + p1 = torch.nn.Parameter(torch.zeros(1)) + p2 = torch.nn.Parameter(torch.zeros(1)) + opt = SGD([{"params": [p1]}, {"params": [p2]}], lr=1.0) + scheduler = AF3Scheduler(opt, **_KW) + for _ in range(5): + opt.step() + scheduler.step() + lrs = scheduler.get_last_lr() + assert len(lrs) == 2 + assert lrs[0] == lrs[1] == pytest.approx(0.5) + + +def test_scheduler_config_state_dict_roundtrip(): + """load_state_dict restores interval, frequency, and the wrapped scheduler.""" + opt = _single_param_optimizer() + scheduler = AF3Scheduler(opt, **_KW) + cfg = SchedulerConfig(scheduler=scheduler, interval="epoch", frequency=3) + for _ in range(7): + opt.step() + scheduler.step() + state = cfg.state_dict() + assert state["interval"] == "epoch" + assert state["frequency"] == 3 + + opt2 = _single_param_optimizer() + fresh = AF3Scheduler(opt2, **_KW) + cfg2 = SchedulerConfig(scheduler=fresh, interval="step", frequency=1) + cfg2.load_state_dict(state) + + assert cfg2.interval == "epoch" + assert cfg2.frequency == 3 + assert cfg2.scheduler.last_epoch == scheduler.last_epoch + assert cfg2.scheduler.get_last_lr() == scheduler.get_last_lr() + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_testing_helpers.py b/tests/test_testing_helpers.py new file mode 100644 index 00000000..de53bb45 --- /dev/null +++ b/tests/test_testing_helpers.py @@ -0,0 +1,25 @@ +"""Unit tests for the pure helper in foundry.testing.fixtures. + +The ``gpu`` fixture and ``configure_pytest`` hook are environment/side-effecting +glue (GPU detection, project-root setup, dotenv loading) and are not tested. +``get_test_data_dir`` has a small but real contract: the test ``data`` directory +sits next to the conftest file that calls it. +""" + +import pytest + +from foundry.testing.fixtures import get_test_data_dir + + +def test_get_test_data_dir_is_data_next_to_conftest(tmp_path): + conftest = tmp_path / "conftest.py" + assert get_test_data_dir(str(conftest)) == tmp_path.resolve() / "data" + + +def test_get_test_data_dir_tracks_the_files_directory(tmp_path): + nested = tmp_path / "sub" / "conftest.py" + assert get_test_data_dir(str(nested)) == (tmp_path / "sub").resolve() / "data" + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_xpu.py b/tests/test_xpu.py new file mode 100644 index 00000000..f127c3ad --- /dev/null +++ b/tests/test_xpu.py @@ -0,0 +1,71 @@ +"""Unit tests for the Intel-XPU Lightning plugins in foundry.utils.xpu. + +These plugins target Intel GPUs (``torch.xpu``), which are absent in CI, so the +parts that actually touch XPU hardware (``setup_device`` success, autocast +contexts, device counts) are not exercised. The device-independent contracts are +pinned here: device parsing, the no-XPU guards (which must raise on a CPU box), +and the precision→dtype mapping + tensor conversion. +""" + +import pytest +import torch + +from foundry.utils.xpu.single_xpu_strategy import SingleXPUStrategy +from foundry.utils.xpu.xpu_accelerator import XPUAccelerator +from foundry.utils.xpu.xpu_precision import XPUMixedPrecision + + +def test_accelerator_name(): + assert XPUAccelerator.name() == "xpu" + + +def test_accelerator_not_available_on_non_xpu_host(): + assert XPUAccelerator.is_available() is False + + +def test_parse_devices_passes_lists_through_and_wraps_scalars(): + assert XPUAccelerator.parse_devices([0, 1]) == [0, 1] + assert XPUAccelerator.parse_devices(0) == [0] + + +def test_get_parallel_devices_builds_xpu_devices(): + devices = XPUAccelerator.get_parallel_devices([0, 1]) + assert devices == [torch.device("xpu", 0), torch.device("xpu", 1)] + + +def test_get_device_stats_is_empty(): + assert XPUAccelerator.get_device_stats("xpu:0") == {} + + +def test_setup_device_rejects_non_xpu_device(): + with pytest.raises(RuntimeError, match="Device should be xpu"): + XPUAccelerator.setup_device(torch.device("cpu")) + + +def test_single_xpu_strategy_requires_xpu(): + with pytest.raises(RuntimeError, match="requires XPU devices"): + SingleXPUStrategy() + + +def test_mixed_precision_maps_precision_to_dtype(): + assert XPUMixedPrecision("16-mixed")._desired_input_dtype == torch.float16 + assert XPUMixedPrecision("bf16-mixed")._desired_input_dtype == torch.bfloat16 + + +def test_mixed_precision_rejects_invalid_precision(): + with pytest.raises(ValueError, match="Invalid precision"): + XPUMixedPrecision("32-true") + + +def test_mixed_precision_converts_only_float_tensors(): + plugin = XPUMixedPrecision("bf16-mixed") + + converted = plugin.convert_input(torch.ones(3, dtype=torch.float32)) + assert converted.dtype == torch.bfloat16 + + untouched = torch.ones(3, dtype=torch.int64) + assert plugin.convert_input(untouched).dtype == torch.int64 + + +if __name__ == "__main__": + pytest.main(["-v", __file__])