# CAPE: Causality-Induced Positional Encoding for Transformer-Based Representation Learning of Non-Sequential Features
Official implementation for the CAPE paper.
CAPE is a positional encoding method for transformer-based representation learning of non-sequential features. It learns causality-aware feature positions from data and integrates them into transformer self-attention through a rotary positional encoding form. The current code supports CAPE integration with scBERT and scGPT, with cell type annotation (CTA) provided as an application workflow.
- CAPE positional encoding for non-sequential, causally related features.
- Current backbone support for `scgpt` and `scbert`.
- CTA application workflow for `.h5ad` AnnData inputs.
- Optional local pretrained assets with Hugging Face fallback.
- YAML-based experiment configuration with reusable defaults.
- Standard outputs for metrics, predictions, resolved configs, probabilities, logs, and CAPE position artifacts.
The manuscript is available at `docs/CAPE.pdf`.
We provide organized scBERT and scGPT pretrained assets on Hugging Face for use with this pipeline:
- scBERT: `kaichenxu/cape_scbert`
- scGPT: `kaichenxu/cape_scgpt`
Each repository includes the model weights and companion files expected by the
CAPE wrappers. The example configs reference these IDs through
`model.hf_repo_id`, and the pipeline downloads them automatically when
`model.path` does not point to an existing local asset directory.
```text
configs/   Experiment configs and shared CAPE defaults
docs/      Paper and project figures
scripts/   Convenience scripts for application workflows
src/       CAPE modules, model wrappers, data utilities, pipelines
tests/     Smoke and unit tests
```
Create and activate a Python environment, then install the package in editable mode:
```bash
python -m venv .venv
source .venv/bin/activate
pip install -e .
```

The project requires Python 3.10 or newer. GPU training is supported through PyTorch when a CUDA-enabled installation is available.
The provided CTA workflow expects an AnnData .h5ad file. At minimum, the file
must include:
- `adata.X`: expression matrix, or set `data.input_layer` to use a layer from `adata.layers`.
- `adata.obs[data.label_column]`: cell type labels.
- Gene identifiers in `adata.var_names`, or in `adata.var[data.gene_column]` when `data.gene_column` is set.
Optional fields include a batch column in adata.obs and a split column for
predefined train/validation/test partitions.
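Before launching a long run, it can help to check these requirements up front. The sketch below is a hypothetical pre-flight check, not part of the CAPE codebase; it only duck-types the AnnData interface (`X`, `obs`, `var`, `layers`), so the field names match the README but the helper itself is illustrative.

```python
# Hypothetical pre-flight check for a CTA input; illustrative only,
# not part of the CAPE pipeline. Works on any AnnData-like object.
def validate_cta_inputs(adata, label_column, gene_column=None, input_layer=None):
    """Return a list of problems found in an AnnData-like object."""
    problems = []
    if input_layer is not None:
        # data.input_layer must name an existing entry in adata.layers.
        if input_layer not in getattr(adata, "layers", {}):
            problems.append(f"layer {input_layer!r} missing from adata.layers")
    elif getattr(adata, "X", None) is None:
        problems.append("adata.X is empty and no data.input_layer is set")
    if label_column not in adata.obs:
        problems.append(f"label column {label_column!r} missing from adata.obs")
    if gene_column is not None and gene_column not in adata.var:
        problems.append(f"gene column {gene_column!r} missing from adata.var")
    return problems
```

An empty list means the file satisfies the minimum requirements listed above.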
Starter configs are provided for both supported backends:
- `configs/CTA/scgpt_CTA.yaml`
- `configs/CTA/scbert_CTA.yaml`
Before running, update at least the dataset path and label column:
```yaml
data:
  path: /path/to/dataset.h5ad
  label_column: celltype
  gene_column: null
  input_layer: null
```

Pretrained assets are resolved from `model.path` when that directory exists.
If it does not exist, the pipeline uses `model.hf_repo_id`, for example
`kaichenxu/cape_scgpt` or `kaichenxu/cape_scbert`.
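The local-first resolution rule can be sketched as follows. The function name and return convention here are illustrative, not the pipeline's actual API:

```python
import os

# Sketch of the local-first asset resolution described above; the helper
# name and return tuple are illustrative, not the pipeline's real API.
def resolve_pretrained_source(model_path, hf_repo_id):
    """Prefer a local asset directory; otherwise fall back to a HF repo ID."""
    if model_path and os.path.isdir(model_path):
        return ("local", model_path)
    if hf_repo_id:
        return ("huggingface", hf_repo_id)
    raise ValueError("no local model.path directory and no model.hf_repo_id set")
```

When the Hugging Face branch is taken, the repository can be fetched with `huggingface_hub.snapshot_download(repo_id)`.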
The default configs use a stratified split:
```yaml
data:
  split:
    mode: stratified
    ratios:
      train: 0.8
      val: 0.1
      test: 0.1
```

To use a predefined split column, set `mode: column` and provide the split
column plus label values for train, validation, and test.
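A stratified split with these default ratios can be sketched as below. This is a minimal illustration of the idea (shuffling within each label before carving out 0.8/0.1/0.1 partitions); the pipeline's actual splitter may differ in details such as rounding and tie handling.

```python
import random
from collections import defaultdict

# Minimal sketch of a stratified 0.8/0.1/0.1 split; illustrative only,
# the pipeline's actual splitter may differ in rounding details.
def stratified_split(labels, ratios=(0.8, 0.1, 0.1), seed=0):
    """Assign each index to 'train'/'val'/'test', preserving label balance."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    assignment = {}
    for idxs in by_label.values():
        rng.shuffle(idxs)  # shuffle within each label group
        n_train = int(round(ratios[0] * len(idxs)))
        n_val = int(round(ratios[1] * len(idxs)))
        for pos, idx in enumerate(idxs):
            if pos < n_train:
                assignment[idx] = "train"
            elif pos < n_train + n_val:
                assignment[idx] = "val"
            else:
                assignment[idx] = "test"
    return assignment
```

Because the ratios are applied per label, each cell type keeps roughly the same train/val/test proportions as the dataset overall.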
Run scGPT CTA:
```bash
python -m src.main --config configs/CTA/scgpt_CTA.yaml
```

Run scBERT CTA:

```bash
python -m src.main --config configs/CTA/scbert_CTA.yaml
```

Equivalent convenience scripts are also available:

```bash
bash scripts/run_scgpt_CTA.sh
bash scripts/run_scbert_CTA.sh
```

Set `run.device` to `auto`, `cpu`, `cuda`, or a specific CUDA device string
supported by PyTorch.
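The `auto` behavior amounts to a small dispatch rule. `pick_device` below is a hypothetical helper for illustration; a real run would pass `torch.cuda.is_available()` as the second argument:

```python
# Illustrative resolution of the run.device setting; `pick_device` is a
# hypothetical helper, and real runs would pass torch.cuda.is_available().
def pick_device(device_setting, cuda_available):
    """Map 'auto' to 'cuda' when available; pass explicit settings through."""
    if device_setting == "auto":
        return "cuda" if cuda_available else "cpu"
    return device_setting  # e.g. "cpu", "cuda", or "cuda:1"
```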
For the CTA workflow, a run named `scgpt_cta_run` writes outputs under
`results/CTA/scgpt/scgpt_cta_run/`.
Standard artifacts include:
- `metrics.json`: test accuracy, macro F1, and weighted F1.
- `predictions.csv`: cell IDs, predicted labels, and true labels.
- `probabilities.npy`: class probabilities when `save_probabilities` is true.
- `label_mapping.json`: label-to-ID mapping learned from the training split.
- `config_resolved.yaml`: fully merged run configuration.
- `summary.json`: compact run summary and artifact paths.
- `cape/`: selected features, token IDs, priority scores, and rank positions.
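Since `metrics.json` is plain JSON, results from finished runs can be collected with a few lines of standard-library code. The helper below is illustrative; only the file layout comes from the README:

```python
import json
from pathlib import Path

# Illustrative helper for reading a finished run's metrics; the
# metrics.json layout follows the README, the function is not part of CAPE.
def load_run_metrics(run_dir):
    """Read metrics.json from a run directory and return it as a dict."""
    metrics_path = Path(run_dir) / "metrics.json"
    with metrics_path.open() as fh:
        return json.load(fh)
```

For example, `load_run_metrics("results/CTA/scgpt/scgpt_cta_run")` would return the test accuracy and F1 scores for that run.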
Logs are written under `logs/CTA/<model_name>/<run_name>.log`.
Run the test suite with:
```bash
pytest
```

The tests cover config loading, data preprocessing helpers, pretrained model source resolution, and pipeline output smoke checks.
