DeepTab: Tabular Deep Learning Made Simple

DeepTab is a Python library for deep learning on tabular data, built on PyTorch and Lightning with a scikit-learn compatible API. It offers 15 neural architectures, from Mamba-inspired state space models and Transformers to tree ensembles and MLP baselines, each available as a classifier, regressor, or distributional (LSS) model. One fit/predict/evaluate workflow covers everyday modeling, architecture research, and production deployment.

Why DeepTab?

  • Familiar interface. A scikit-learn fit/predict/evaluate API that drops into existing pipelines, including GridSearchCV.
  • Automatic preprocessing. Feature-type detection, encoding, scaling, and missing-value handling are built in.
  • One model, three tasks. Every architecture ships as a classifier, a regressor, and a distributional (LSS) variant for uncertainty quantification.
  • A broad model zoo. 15 stable architectures plus experimental models, all behind the same interface, with selection guidance.
  • Built for real data. Mixed feature types, class imbalance, GPU acceleration, and early stopping work out of the box.

⚡ What's New in v2.0

v2.0 is a ground-up restructuring of DeepTab. The high-level estimator API (MambularClassifier().fit(...)) is largely unchanged, but the internal package layout, configuration objects, and import paths have moved.

⚠️ Upgrading from v1? Packages were reorganised, the Default<Arch>Config classes were renamed to <Arch>Config, and the data modules were renamed to TabularDataModule / TabularDataset. Code that only uses the high-level estimators mostly keeps working; code that imported internal modules needs updating. See the FAQ for v1 support and upgrade notes.

Configuration and data

  • Split-config API: The model, preprocessing, and training each have their own configuration object, so you can tune one concern without disturbing the others. This is the first thing you reach for in v2.
  • Typed data layer: TabularDataset, TabularDataModule, and FeatureSchema give the data pipeline an explicit, inspectable contract, with stratified splitting controlled through TrainerConfig.
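
Stratified splitting keeps each class at the same share in the training and validation sets, which matters for imbalanced targets. The following is a minimal pure-Python sketch of the idea only; `stratified_split` is a hypothetical helper written for illustration, not DeepTab's implementation, and the TrainerConfig option that enables stratification is not shown here.

```python
import random
from collections import defaultdict


def stratified_split(labels, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) with per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Take the same fraction from every class; classes with a single
        # sample stay entirely in the training set.
        n_val = max(1, round(len(indices) * val_fraction)) if len(indices) > 1 else 0
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


labels = [0] * 90 + [1] * 10  # made-up target with a 10% positive class
train_idx, val_idx = stratified_split(labels, val_fraction=0.2)
val_positive_rate = sum(labels[i] for i in val_idx) / len(val_idx)
```

With this toy target, the validation set holds 20 rows and keeps the 10% positive rate of the full data.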

Models

  • New stable models: AutoInt, ENODE, and TabR.
  • New experimental models: Tangos, Trompt, and ModernNCA, under evaluation for promotion.

Training and evaluation

  • Observability and experiment tracking: ObservabilityConfig adds structured lifecycle logging via structlog and one-line MLflow or TensorBoard tracking, with every run saved to an organised directory tree. It is opt-in and silent by default.
  • Registry-driven training: Every torch.optim optimizer, learning-rate scheduler, and loss is selectable by name through TrainerConfig, and you can register your own at runtime.
  • Unified metrics: deeptab.metrics ships 25+ metric classes for regression, classification, and distributional models, auto-selected per task through a registry.
  • Reproducibility: set_seed and seed_context seed Python, NumPy, and PyTorch across CPU, CUDA, and MPS, including the DataLoader and sampler generators.
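
To illustrate what a seeding context manager does, here is a small standard-library sketch that seeds only Python's `random` module. The name `seeded` is hypothetical, and the choice to restore the previous RNG state on exit is an assumption of this sketch; DeepTab's own `seed_context` also covers NumPy and PyTorch and may behave differently on exit.

```python
import random
from contextlib import contextmanager


@contextmanager
def seeded(seed):
    """Seed Python's global RNG inside the block, then restore the prior state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restoring the state keeps the block from affecting later draws.
        random.setstate(saved_state)


with seeded(42):
    first = [random.random() for _ in range(3)]
with seeded(42):
    second = [random.random() for _ in range(3)]
```

Both blocks draw the same three numbers, and random draws made outside the blocks are unaffected by them.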

Deployment

  • Deployment-safe inference: InferenceModel wraps a fitted estimator in a read-only prediction surface with schema validation and task-type enforcement. Training methods are deliberately absent, so a served model cannot be re-fitted by accident.
  • Self-describing artifacts: save and load go through a single .deeptab format that bundles the architecture, feature schema, preprocessing, task type, and package versions alongside the weights, so a saved model carries everything needed to reload it.
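
The read-only idea behind deployment-safe inference can be sketched in a few lines. Everything below is hypothetical and for illustration only: `ReadOnlyPredictor` and `ToyModel` are not DeepTab classes, and the real InferenceModel performs richer schema and task-type checks than this column-name comparison.

```python
class ReadOnlyPredictor:
    """Expose predict() only, and check column names before predicting."""

    def __init__(self, fitted_model, expected_columns):
        self._model = fitted_model
        self._expected = list(expected_columns)

    def predict(self, rows):
        for row in rows:
            # Reject rows whose columns (names and order) differ from the schema.
            if list(row) != self._expected:
                raise ValueError(f"expected columns {self._expected}, got {list(row)}")
        return self._model.predict(rows)


class ToyModel:
    """Stand-in for a fitted estimator that has both fit() and predict()."""

    def fit(self, rows):
        raise RuntimeError("training should not happen at serving time")

    def predict(self, rows):
        return [row["age"] > 40 for row in rows]


served = ReadOnlyPredictor(ToyModel(), expected_columns=["age", "plan"])
predictions = served.predict([{"age": 52, "plan": "pro"}, {"age": 31, "plan": "free"}])
has_fit = hasattr(served, "fit")  # False: the wrapper has no training method
```

Because the wrapper simply does not define `fit`, calling it fails immediately instead of silently retraining a served model.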

🏃 Quickstart

from deeptab.models import MambularClassifier

# Initialize and fit (sklearn-compatible)
model = MambularClassifier()
model.fit(X_train, y_train, max_epochs=50)

# Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)

That's it! DeepTab handles preprocessing, batching, and training automatically.

Works with pandas & numpy: Pass DataFrames or arrays, and DeepTab auto-detects feature types.

Available Models

DeepTab provides 15 stable architectures across five families: State Space Models (Mambular, MambaTab, MambAttention), Transformers (FTTransformer, TabTransformer, SAINT, AutoInt), residual networks (ResNet, TabR), tree-inspired models (NODE, ENODE, NDTF), and general baselines (MLP, TabM, TabulaRNN). Three experimental models (ModernNCA, Tangos, Trompt) are under evaluation for promotion.

See the Model Zoo for detailed comparisons, complexity analysis, and selection guidance.

Stable Models

| Category | Model | Architecture | Best For |
|---|---|---|---|
| State Space Models | Mambular | Stacked Mamba over feature tokens | General-purpose tabular modeling |
| State Space Models | MambaTab | Lightweight Mamba SSM | Small datasets and fast training |
| State Space Models | MambAttention | Mamba with feature attention | Feature-interaction-heavy data |
| Transformers | FTTransformer | Feature Tokenizer + Transformer | Strong attention-based baseline |
| Transformers | TabTransformer | Transformer over categorical tokens | Categorical-heavy data |
| Transformers | SAINT | Row and column attention | Small or label-scarce datasets |
| Transformers | AutoInt | Self-attentive feature interactions | Automatic high-order interactions |
| Residual Networks | ResNet | Residual MLP | Fast dense baseline |
| Residual Networks | TabR | Retrieval-augmented MLP/kNN | Large datasets with neighbor signal |
| Tree-Inspired | NODE | Neural oblivious decision ensembles | Differentiable tree inductive bias |
| Tree-Inspired | ENODE | Embedded NODE-style soft trees | Tree-inspired modeling with embeddings |
| Tree-Inspired | NDTF | Neural decision tree forest | Differentiable forest experiments |
| Other | MLP | Feedforward dense network | Fastest baseline |
| Other | TabM | Parameter-efficient ensemble MLP | Strong efficient baseline |
| Other | TabulaRNN | Recurrent feature-sequence model | Sequential feature modeling |

Experimental Models ⚠️

⚠️ API Not Stable: Experimental models may change in minor releases. Always pin an exact version, for example deeptab==x.y.z.

  • ModernNCA: Neighborhood Component Analysis (metric learning)
  • Tangos: Gradient orthogonalization approach
  • Trompt: Prompt-based learning for tabular data

Task Variants

All models come in three variants:

  • *Classifier: Classification (binary & multi-class)
  • *Regressor: Regression (point estimates)
  • *LSS: Distributional regression (full distribution prediction)

Consistent API: All models use the same interface, so you can swap architectures without changing code.

📚 Documentation

Full documentation: deeptab.readthedocs.io

🛠️ Installation

Basic installation:

pip install deeptab

With experiment tracking and structured logging:

pip install 'deeptab[tracking]'   # MLflow + TensorBoard loggers
pip install 'deeptab[logs]'       # structured logging via structlog
pip install 'deeptab[all]'        # every optional backend

Faster Mamba models (optional CUDA kernels):

pip install mamba-ssm

Mamba kernels are optional: They give a 20-30% speedup for Mamba-based models on a compatible NVIDIA GPU (CUDA 11.6+). If the install fails or no GPU is present, DeepTab falls back to a pure-PyTorch implementation automatically.
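
If you want to know which situation applies on your machine, you can check whether the optional package is importable. This is a standard-library sketch, and `has_module` is a hypothetical helper; it only reports whether `mamba_ssm` is installed and says nothing about how DeepTab selects its backend internally.

```python
from importlib.util import find_spec


def has_module(name):
    """Return True if a top-level module is installed, without importing it."""
    return find_spec(name) is not None


if has_module("mamba_ssm"):
    print("mamba_ssm is installed; the optional CUDA kernels may be usable.")
else:
    print("mamba_ssm is not installed; expect the pure-PyTorch fallback.")
```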

Lightweight by default: Tracking backends are optional and imported lazily, so a plain pip install deeptab stays small. Install only the extras you actually use.

Requirements: Python 3.10+, PyTorch 2.2+, Lightning 2.3.3+

GPU Support: See installation guide for CUDA setup.

Usage

Basic Workflow

from deeptab.models import MambularClassifier
from deeptab.configs import MambularConfig, PreprocessingConfig, TrainerConfig

# 1. Initialize with configuration (optional - defaults work well!)
model_config = MambularConfig(d_model=64, n_layers=6)
prep_config = PreprocessingConfig(numerical_preprocessing="quantile")
trainer_config = TrainerConfig(lr=1e-4, batch_size=256)

model = MambularClassifier(
    model_config=model_config,
    preprocessing_config=prep_config,
    trainer_config=trainer_config
)

# 2. Fit (X can be pandas DataFrame or numpy array)
model.fit(X_train, y_train, max_epochs=50)

# 3. Predict
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)

# 4. Evaluate
metrics = model.evaluate(X_test, y_test)
# Regression:      {"rmse": …, "mae": …, "r2": …}
# Classification:  {"accuracy": …, "auroc": …, "log_loss": …}
# LSS (normal):    {"crps": …, "rmse": …, "mae": …}

💡 Tip: Start with defaults (MambularClassifier()) and tune only if needed. See Recommended Configs for guidance.

Hyperparameter Tuning

DeepTab models are sklearn-compatible, so you can use GridSearchCV:

from sklearn.model_selection import GridSearchCV
from deeptab.models import MambularClassifier

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__n_layers": [4, 6, 8],
    "trainer_config__lr": [1e-4, 5e-4, 1e-3],
}

search = GridSearchCV(
    MambularClassifier(),
    param_grid,
    cv=5,
    scoring="accuracy"
)
search.fit(X_train, y_train)
print(f"Best params: {search.best_params_}")
print(f"Best score: {search.best_score_}")

Built-in HPO: Every estimator exposes optimize_hparams(), which runs Gaussian process Bayesian optimization (via scikit-optimize) over a search space derived from the model config. See the HPO Tutorial.
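
Before launching a grid search, it helps to count how many trainings it implies. The sketch below uses only the standard library to expand the same grid as the GridSearchCV example above and to show how the double-underscore names split into a nested object and a field. It mirrors scikit-learn's documented conventions rather than calling DeepTab.

```python
from itertools import product

param_grid = {
    "model_config__d_model": [64, 128, 256],
    "model_config__n_layers": [4, 6, 8],
    "trainer_config__lr": [1e-4, 5e-4, 1e-3],
}
cv_folds = 5

# Every combination of values is one candidate configuration.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]

# Each candidate is trained once per fold; with GridSearchCV's default
# refit=True, one more training on the full data follows at the end.
total_fits = len(candidates) * cv_folds

# The "__" separator addresses a field inside a nested parameter object.
nested_target, field_name = "model_config__d_model".split("__", 1)
```

This grid has 27 candidates, so 5-fold cross-validation trains 135 networks before the final refit. For deep models that cost is a good reason to start with a smaller grid or fewer folds.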

Distributional Regression (LSS)

Predict a full distribution instead of a single point estimate:

from deeptab.models import MambularLSS

# Choose a distribution family when you fit
model = MambularLSS()
model.fit(X_train, y_train, family="normal", max_epochs=50)

# predict() returns the estimated distribution parameters per sample
# (for "normal", that is the location and scale)
params = model.predict(X_test)

# Evaluate with proper scoring rules selected for the family
metrics = model.evaluate(X_test, y_test)

Available families: normal, lognormal, studentt, gamma, beta, tweedie, poisson, zip, negativebinom, dirichlet, mog, quantile, and more. Each family auto-selects appropriate evaluation metrics (CRPS, deviances, NLL).
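
For intuition about one of these metrics, the CRPS of a normal forecast has a well-known closed form. The function below is a standalone sketch of that formula using the standard library; `crps_normal` is a hypothetical name, and how deeptab.metrics computes CRPS internally is not shown here.

```python
import math


def crps_normal(y, loc, scale):
    """Closed-form CRPS of a Normal(loc, scale) forecast for observation y."""
    z = (y - loc) / scale
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # standard normal CDF
    return scale * (z * (2.0 * cdf - 1.0) + 2.0 * pdf - 1.0 / math.sqrt(math.pi))


wide = crps_normal(y=0.0, loc=0.0, scale=1.0)
sharp = crps_normal(y=0.0, loc=0.0, scale=0.5)
```

Lower is better: when both forecasts are centered on the observed value, the sharper one scores lower. As the scale shrinks toward zero, the score approaches the absolute error |y - loc|.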

Prediction intervals: Turn the predicted parameters into calibrated intervals as shown in the Uncertainty Quantification tutorial.
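
As a minimal sketch for the normal family, a central prediction interval follows directly from the location and scale. The parameter values below are made up, and the assumption that each prediction arrives as a (location, scale) pair is for illustration; check the tutorial for the actual output layout. `normal_interval` is a hypothetical helper.

```python
from statistics import NormalDist


def normal_interval(loc, scale, coverage=0.95):
    """Return the central interval of Normal(loc, scale) with the given coverage."""
    tail = (1.0 - coverage) / 2.0
    dist = NormalDist(mu=loc, sigma=scale)
    return dist.inv_cdf(tail), dist.inv_cdf(1.0 - tail)


predicted_params = [(10.0, 2.0), (3.5, 0.5)]  # made-up (location, scale) pairs
intervals = [normal_interval(loc, scale) for loc, scale in predicted_params]
```

The intervals are only as trustworthy as the fitted distribution, so it is worth checking empirical coverage on held-out data.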

Advanced Features

Preprocessing

DeepTab includes comprehensive preprocessing powered by PreTab:

from deeptab.configs import PreprocessingConfig
from deeptab.models import MambularClassifier

prep_config = PreprocessingConfig(
    numerical_preprocessing="ple",  # Piecewise linear encoding
    n_bins=50                       # Number of bins for the encoding
)

model = MambularClassifier(preprocessing_config=prep_config)
model.fit(X_train, y_train, max_epochs=50)

Features:

  • Automatic detection: Feature types detected from data
  • Type-aware: Separate strategies for numerical and categorical features
  • Methods: PLE, quantile transform, splines, standardization, min-max, and robust scaling
  • Pre-trained encodings: Transfer learning for categorical features

Learn more: Preprocessing is driven by PreprocessingConfig; see the Config System guide and the PreTab project.
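
To make piecewise linear encoding concrete, the sketch below encodes one numeric value against a fixed set of bin edges. It is a simplified variant that clips values outside the bin range, written with the standard library only; `piecewise_linear_encode` is a hypothetical helper, and PreTab's actual implementation, including how it chooses the edges, may differ.

```python
def piecewise_linear_encode(x, edges):
    """Encode scalar x against sorted bin edges as len(edges) - 1 values in [0, 1]."""
    encoding = []
    for left, right in zip(edges[:-1], edges[1:]):
        # Bins fully below x become 1, bins fully above x become 0, and the
        # bin containing x holds the fractional position inside that bin.
        fraction = (x - left) / (right - left)
        encoding.append(min(1.0, max(0.0, fraction)))
    return encoding


edges = [0.0, 10.0, 20.0, 40.0]  # three bins, made-up edges
encoded = piecewise_linear_encode(25.0, edges)
```

Here 25.0 lies a quarter of the way through the last bin, so the encoding is [1.0, 1.0, 0.25]. With n_bins=50, each numeric feature would become a 50-dimensional vector of this kind.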

Observability & Experiment Tracking

DeepTab can record what happens during training without you writing any callbacks. Pass an ObservabilityConfig when you build a model, and each run captures its hyperparameters, lifecycle events, and final metrics in one self-contained folder.

from deeptab.core.observability import ObservabilityConfig
from deeptab.models import MambularClassifier

obs = ObservabilityConfig(
    experiment_name="churn_baseline",
    structured_logging=True,          # human-readable console + JSON event log
    experiment_trackers=["mlflow"],   # also supports "tensorboard"
)

model = MambularClassifier(observability_config=obs)
model.fit(X_train, y_train, max_epochs=50)

Every fit produces a tidy, reproducible run directory:

deeptab_runs/
  runs/churn_baseline/20260611_174830_8f3a2c/
    config.yaml       # estimator hyperparameters
    lifecycle.jsonl   # structured event log
    summary.json      # final metrics
    checkpoints/best.ckpt
  tensorboard/...
  mlflow/...

Tune the noise: verbosity controls how much is emitted (0 silent, 1 milestones, 2 detailed, 3 debug). The default keeps notebooks quiet.

🔬 For researchers: Lifecycle events such as fit.started, model.created, and train.completed carry structured metadata (sample counts, parameter counts, best validation loss), so you can script experiment sweeps and compare runs programmatically.
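
A JSON Lines event log can be read with the standard library alone. The sketch below parses a made-up log held in a string so that it runs anywhere. The key names (`event`, `n_samples`, `n_parameters`, `best_val_loss`) are assumptions for illustration, so inspect a real lifecycle.jsonl for the actual schema.

```python
import json

# Made-up example lines; the real keys in lifecycle.jsonl may differ.
sample_log = "\n".join([
    json.dumps({"event": "fit.started", "n_samples": 8000}),
    json.dumps({"event": "model.created", "n_parameters": 125000}),
    json.dumps({"event": "train.completed", "best_val_loss": 0.412}),
])


def read_events(text):
    """Parse JSON Lines text into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


events = read_events(sample_log)
completed = [e for e in events if e.get("event") == "train.completed"]
best_val_loss = completed[-1]["best_val_loss"] if completed else None
```

To compare runs, apply the same function to the text of each run directory's lifecycle.jsonl and collect the values of interest into one table.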

📖 Learn more: Observability

Custom Models

Implement your own architecture with DeepTab's base classes. A model is three small pieces: a dataclass config (subclassing BaseModelConfig), a PyTorch architecture (subclassing BaseModel), and one estimator per task that binds them via _model_cls / _config_cls:

from dataclasses import dataclass, field

import torch
import torch.nn as nn

from deeptab.configs import BaseModelConfig, TrainerConfig
from deeptab.core import BaseModel, get_feature_dimensions
from deeptab.models import SklearnBaseRegressor


@dataclass
class MyCustomConfig(BaseModelConfig):
    layer_sizes: list = field(default_factory=lambda: [128, 64])
    dropout: float = 0.1


class MyCustomModel(BaseModel):
    def __init__(
        self,
        feature_information: tuple,  # (num_info, cat_info, embedding_info)
        num_classes: int = 1,
        config: MyCustomConfig = MyCustomConfig(),  # noqa: B008
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.save_hyperparameters(ignore=["feature_information"])

        # Input width is derived from the data, never hard-coded.
        input_dim = get_feature_dimensions(*feature_information)

        layers: list[nn.Module] = []
        prev = input_dim
        for size in self.hparams.layer_sizes:
            layers += [nn.Linear(prev, size), nn.ReLU(), nn.Dropout(self.hparams.dropout)]
            prev = size
        layers.append(nn.Linear(prev, num_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, *data) -> torch.Tensor:
        # data == (num_features, cat_features, embeddings)
        x = torch.cat([t for group in data for t in group], dim=1)
        return self.layers(x)


class MyRegressor(SklearnBaseRegressor):
    _model_cls = MyCustomModel
    _config_cls = MyCustomConfig


# Use like any other DeepTab model
model = MyRegressor(
    model_config=MyCustomConfig(layer_sizes=[256, 128]),
    trainer_config=TrainerConfig(lr=1e-3),
)
model.fit(X_train, y_train, max_epochs=50)

📖 Learn more: Custom Models walks through configs, embeddings, and the *Classifier / *Regressor / *LSS variants.

🛠️ Developer Guide: See Contributing for architecture guidelines.

🏷️ Citation

If you use DeepTab in your research, please cite:

@article{thielmann2024mambular,
  title={Mambular: A Sequential Model for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Kumar, Manish and Weisser, Christoph and Reuter, Arik and S{\"a}fken, Benjamin and Samiee, Soheila},
  journal={arXiv preprint arXiv:2408.06291},
  year={2024}
}

@article{thielmann2024efficiency,
  title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Samiee, Soheila},
  journal={arXiv preprint arXiv:2411.17207},
  year={2024}
}

📄 License

DeepTab is licensed under the MIT License. See LICENSE for details.

🤝 Contributing

Contributions are welcome. See the Contributing Guide to get started, and please follow our Code of Conduct.
