Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/autogluon/cloud/endpoint/tabular_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Any, Dict, Union

import pandas as pd

from autogluon.common.loaders import load_pd

from ..utils.serializers import AutoGluonSerializationWrapper
from ..utils.utils import split_pred_and_pred_proba
from .endpoint import Endpoint


class TabularEndpoint:
    """High-level endpoint for tabular foundation-model prediction.

    Each request carries both the labeled few-shot context (``train_data``) and the unlabeled
    rows to predict on (``test_data``); the endpoint fits a fresh predictor per request.

    Parameters
    ----------
    endpoint
        The low-level endpoint object that every request is forwarded to.
    """

    def __init__(self, endpoint: Endpoint):
        self._endpoint = endpoint

    @property
    def endpoint_name(self) -> str:
        """Name of the wrapped endpoint."""
        return self._endpoint.endpoint_name

    def predict(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        test_data: Union[str, Path, pd.DataFrame],
        label: str = "target",
        accept: str = "application/x-parquet",
    ) -> pd.Series:
        """Run real-time prediction. Returns the predicted label column.

        Parameters
        ----------
        train_data
            Labeled few-shot context, as a DataFrame or a path to a loadable file.
        test_data
            Unlabeled rows to predict on, as a DataFrame or a path to a loadable file.
        label
            Name of the target column in ``train_data``.
        accept
            Value sent as the ``Accept`` argument of the endpoint invocation.

        Raises
        ------
        ValueError
            If ``label`` is not a column of ``train_data``.
        """
        result = self._invoke(train_data, test_data, label=label, accept=accept)
        # A single-column response is already the prediction; nothing to split off.
        if result.shape[1] == 1:
            return result.iloc[:, 0]
        pred, _ = split_pred_and_pred_proba(result)
        return pred

    def predict_proba(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        test_data: Union[str, Path, pd.DataFrame],
        label: str = "target",
        include_predict: bool = False,
        accept: str = "application/x-parquet",
    ) -> Union[pd.DataFrame, "tuple"]:
        """Run real-time prediction. Returns class probabilities (classification only).

        If ``include_predict`` is True, returns ``(prediction, predict_probability)``.

        Raises
        ------
        ValueError
            If ``label`` is not a column of ``train_data``, or if the endpoint returned a
            single column (treated here as a regression response with no probabilities).
        """
        result = self._invoke(train_data, test_data, label=label, accept=accept)
        if result.shape[1] == 1:
            raise ValueError(
                "predict_proba is not supported for regression endpoints — only a single column was returned."
            )
        pred, pred_proba = split_pred_and_pred_proba(result)
        if include_predict:
            return pred, pred_proba
        return pred_proba

    def _invoke(
        self,
        train_data: Union[str, Path, pd.DataFrame],
        test_data: Union[str, Path, pd.DataFrame],
        *,
        label: str,
        accept: str,
    ) -> pd.DataFrame:
        """Load both datasets, wrap them into one payload and call the endpoint.

        Non-DataFrame inputs are converted to ``str`` and loaded with ``load_pd.load``.

        Raises
        ------
        ValueError
            If ``label`` is not a column of the loaded training data.
        """
        train_df = train_data if isinstance(train_data, pd.DataFrame) else load_pd.load(str(train_data))
        test_df = test_data if isinstance(test_data, pd.DataFrame) else load_pd.load(str(test_data))

        # Fail fast on a mistyped label instead of shipping the payload and getting a remote error back.
        if label not in train_df.columns:
            raise ValueError(
                f"Label column '{label}' not found in train_data; available columns: {list(train_df.columns)}."
            )

        inference_kwargs: Dict[str, Any] = {"label": label}
        payload = AutoGluonSerializationWrapper(
            data=test_df,
            inference_kwargs=inference_kwargs,
            train_data=train_df,
        )
        return self._endpoint.predict(payload, initial_args={"Accept": accept})

    def delete_endpoint(self) -> None:
        """Delete the endpoint and cleanup artifacts."""
        self._endpoint.delete_endpoint()
217 changes: 160 additions & 57 deletions src/autogluon/cloud/model/foundation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ..backend.backend_factory import BackendFactory
from ..backend.constant import SAGEMAKER, TABULAR_SAGEMAKER, TIMESERIES_SAGEMAKER
from ..endpoint.tabular_endpoint import TabularEndpoint
from ..endpoint.timeseries_endpoint import TimeSeriesEndpoint
from ..scripts.script_manager import ScriptManager
from ..utils.aws_utils import resolve_cloud_output_path
Expand Down Expand Up @@ -431,100 +432,202 @@ class TabularFoundationModel(FoundationModel):

@property
def _serve_script_path(self) -> str:
raise NotImplementedError("Tabular FM deploy is not yet supported")

def deploy(self, **kwargs):
raise NotImplementedError("Tabular FM deploy is not yet supported")
return ScriptManager.SAGEMAKER_TABULAR_FM_SERVE_SCRIPT_PATH

def _build_predictor_init_args(self, label: str = "target", **kwargs) -> Dict[str, Any]:
"""Map user kwargs to TabularPredictor init args."""
return {"label": label}

def predict(
def _build_predictor_fit_args(self, hyperparameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
model_name = self._config["model_name"]
merged_hp = self._get_hyperparameters("inference", hyperparameters)
return {
"hyperparameters": {model_name: merged_hp},
"fit_weighted_ensemble": False,
"calibrate_decision_threshold": False,
}

def deploy(
self,
train_data: Union[str, Path, pd.DataFrame],
test_data: Union[str, Path, pd.DataFrame],
label: str = "target",
hyperparameters: Optional[Dict[str, Any]] = None,
instance_type: Optional[str] = None,
endpoint_name: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
framework_version: str = "latest",
custom_image_uri: Optional[str] = None,
wait: bool = True,
**backend_kwargs,
) -> Optional[pd.DataFrame]:
) -> TabularEndpoint:
"""
Run batch prediction for tabular tasks.

For tabular foundation models (e.g., Mitra), train_data provides the few-shot
context and test_data contains the rows to predict on.
Deploy the foundation model to a real-time endpoint.

Parameters
----------
train_data
Labeled few-shot context for the foundation model.
test_data
Unlabeled data to predict on.
label
Target column name in train_data.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
instance_type
Instance type for the prediction job. If None, uses registry default.
Instance type for the endpoint. If None, uses registry default.
endpoint_name
Custom endpoint name. If None, auto-generated.
hyperparameters
Model hyperparameters for inference. Overrides constructor values.
framework_version
Container framework version.
custom_image_uri
Custom Docker image URI for the container.
Custom Docker image URI for the inference container.
wait
If True, block and return DataFrame. If False, return the job handle.
Whether to block until the endpoint is ready.
**backend_kwargs
Additional backend-specific arguments.
Backend-specific arguments.

Returns
-------
Optional[pd.DataFrame]
TabularEndpoint
"""
# TODO: requires fit_predict support for TabularCloudPredictor
raise NotImplementedError
self._deploy_backend(
instance_type=instance_type,
endpoint_name=endpoint_name,
hyperparameters=hyperparameters,
framework_version=framework_version,
custom_image_uri=custom_image_uri,
wait=wait,
**backend_kwargs,
)
return TabularEndpoint(self._backend.endpoint)

def predict_proba(
def _run_fit_predict_job(
self,
train_data: Union[str, Path, pd.DataFrame],
test_data: Union[str, Path, pd.DataFrame],
*,
label: str,
hyperparameters: Optional[Dict[str, Any]],
instance_type: Optional[str],
framework_version: str,
custom_image_uri: Optional[str],
wait: bool,
predictions_path: Optional[str],
**backend_kwargs,
) -> Optional[pd.DataFrame]:
"""Launch a single fit+predict SageMaker job.

For classification tasks the job emits a combined DataFrame ``[<label>, <class>_proba, ...]``
so :meth:`predict` and :meth:`predict_proba` can share one job. For regression it emits the
single-column predictions DataFrame.
"""
if instance_type is None:
instance_type = self._config["predict_instance_type"]

predictor_init_args = self._build_predictor_init_args(label=label)
predictor_fit_args = self._build_predictor_fit_args(hyperparameters)
data_channels = {
"train_data": train_data,
"test_data": test_data,
}

extra_ag_args: Dict[str, Any] = {"predict_after_fit": True}
if predictions_path is not None:
extra_ag_args["predictions_path"] = predictions_path

self._backend.fit(
predictor_init_args=predictor_init_args,
predictor_fit_args=predictor_fit_args,
data_channels=data_channels,
framework_version=framework_version,
instance_type=instance_type,
custom_image_uri=custom_image_uri,
wait=wait,
extra_ag_args=extra_ag_args,
**backend_kwargs,
)

if not wait:
return None

return self._backend.get_fit_predict_results()

def predict(
self,
train_data: Union[str, Path, pd.DataFrame],
test_data: Union[str, Path, pd.DataFrame],
label: str = "target",
hyperparameters: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None,
instance_type: Optional[str] = None,
framework_version: str = "latest",
custom_image_uri: Optional[str] = None,
wait: bool = True,
predictions_path: Optional[str] = None,
**backend_kwargs,
) -> Optional[pd.DataFrame]:
) -> Optional[pd.Series]:
"""
Run batch prediction returning class probabilities.
Run batch prediction for tabular tasks.

Parameters
----------
train_data
Labeled few-shot context for the foundation model.
test_data
Unlabeled data to predict on.
label
Target column name in train_data.
hyperparameters
Model hyperparameters for inference. Overrides values passed to the constructor.
Available hyperparameters for each model are listed in the AutoGluon documentation.
output_path
S3 path to store predictions.
If None, will auto-generate under cloud_output_path.
instance_type
Instance type for the prediction job.
If None, will use the default from the model registry.
wait
If True, block and return DataFrame. If False, return the job handle.
**backend_kwargs
Additional backend-specific arguments (e.g. job_name, custom_image_uri,
framework_version, volume_size).
For tabular foundation models (e.g., Mitra), ``train_data`` provides the few-shot context
and ``test_data`` contains the rows to predict on. Both are uploaded to a single SageMaker
training job that runs the in-context-learning fit and prediction in one pass.

Returns
-------
Optional[pd.DataFrame]
Optional[pd.Series]
Predicted labels (``None`` when ``wait`` is False).
"""
raise NotImplementedError
from ..utils.utils import split_pred_and_pred_proba

result = self._run_fit_predict_job(
train_data=train_data,
test_data=test_data,
label=label,
hyperparameters=hyperparameters,
instance_type=instance_type,
framework_version=framework_version,
custom_image_uri=custom_image_uri,
wait=wait,
predictions_path=predictions_path,
**backend_kwargs,
)
if result is None:
return None
if self._config["task"] == "regression":
return result.iloc[:, 0]
pred, _ = split_pred_and_pred_proba(result)
return pred

def predict_proba(
    self,
    train_data: Union[str, Path, pd.DataFrame],
    test_data: Union[str, Path, pd.DataFrame],
    label: str = "target",
    include_predict: bool = False,
    hyperparameters: Optional[Dict[str, Any]] = None,
    instance_type: Optional[str] = None,
    framework_version: str = "latest",
    custom_image_uri: Optional[str] = None,
    wait: bool = True,
    predictions_path: Optional[str] = None,
    **backend_kwargs,
) -> Optional[Union[pd.DataFrame, "tuple"]]:
    """
    Run batch prediction returning class probabilities. Only valid for classification tasks.

    Parameters mirror :meth:`predict`. If ``include_predict`` is True, returns a tuple of
    ``(prediction, predict_probability)``; otherwise returns ``predict_probability`` only.
    Returns ``None`` when the underlying job yields no result (e.g. ``wait`` is False).

    Raises
    ------
    ValueError
        If the configured task is not ``"classification"``.
    """
    from ..utils.utils import split_pred_and_pred_proba

    # Reject non-classification models up front, before any job is launched.
    if self._config["task"] != "classification":
        raise ValueError(f"predict_proba is only supported for classification, got task='{self._config['task']}'.")

    combined = self._run_fit_predict_job(
        train_data=train_data,
        test_data=test_data,
        label=label,
        hyperparameters=hyperparameters,
        instance_type=instance_type,
        framework_version=framework_version,
        custom_image_uri=custom_image_uri,
        wait=wait,
        predictions_path=predictions_path,
        **backend_kwargs,
    )
    if combined is None:
        return None

    labels, probabilities = split_pred_and_pred_proba(combined)
    return (labels, probabilities) if include_predict else probabilities
29 changes: 19 additions & 10 deletions src/autogluon/cloud/model/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,32 @@ class FoundationModelConfig(TypedDict):
"fit_instance_type": "ml.g5.xlarge",
"fine_tunable": True,
},
# TODO: Replace dummy configs with real values
"mitra-classification": {
"task": "classification",
"model_name": "Mitra",
"inference_hyperparameters": {"model_path": "TODO"},
"training_hyperparameters": {"model_path": "TODO"},
"predict_instance_type": "ml.m5.2xlarge",
"model_name": "MITRA",
"inference_hyperparameters": {"fine_tune": False},
"training_hyperparameters": {"fine_tune": True},
"predict_instance_type": "ml.g5.xlarge",
"deploy_instance_type": "ml.g5.xlarge",
"fit_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
"fine_tunable": True,
},
"mitra-regression": {
"task": "regression",
"model_name": "Mitra",
"inference_hyperparameters": {"model_path": "TODO"},
"training_hyperparameters": {"model_path": "TODO"},
"predict_instance_type": "ml.m5.2xlarge",
"model_name": "MITRA",
"inference_hyperparameters": {"fine_tune": False},
"training_hyperparameters": {"fine_tune": True},
"predict_instance_type": "ml.g5.xlarge",
"deploy_instance_type": "ml.g5.xlarge",
"fit_instance_type": "ml.g5.xlarge",
"fine_tunable": True,
},
"tabicl": {
"task": "classification",
"model_name": "TABICL",
"inference_hyperparameters": {},
"training_hyperparameters": {},
"predict_instance_type": "ml.g5.xlarge",
"deploy_instance_type": "ml.g5.xlarge",
"fit_instance_type": "ml.g5.xlarge",
"fine_tunable": False,
Expand Down
Loading
Loading