Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions docs/en/advanced/mooncake-dataproto-transfer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Mooncake DataProto Rollout Transfer

slime can transfer rollout data through Mooncake instead of Ray object references. This is useful when the rollout producer and actor consumer run on different nodes and Mooncake Store is configured for the cluster transport.

The default transfer backend remains Ray. Enable Mooncake DataProto transfer explicitly:

```bash
python3 train.py \
--transfer-backend mooncake_dataproto \
--mooncake-dataproto-store-init-kwargs '{"setup_method":"setup"}'
```

## What is transferred

The Mooncake path keeps slime's rollout data layout unchanged:

- per-rank rollout partitions are still selected by slime before actor consumption;
- tensor fields such as `tokens` and `loss_masks` are stored as Mooncake remote tensor batches;
- non-tensor rollout fields and metadata stay in the `DataProto` wrapper;
- cleanup keys are tracked in metadata and removed after actor-side materialization.

## Options

| Option | Default | Meaning |
| --- | --- | --- |
| `--transfer-backend` | `ray` | Set to `mooncake_dataproto` to enable Mooncake rollout transfer. |
| `--mooncake-dataproto-store-init-kwargs` | `null` | JSON object used to initialize the Mooncake store. Required when Mooncake transfer is enabled; leaving it unset raises an error during argument validation. Use `{"setup_method":"setup"}` for real Mooncake Store setup and `{"setup_method":"setup_dummy"}` for local unit tests. |
| `--mooncake-dataproto-hard-pin` | `true` | Hard-pin remote tensor data to the producer segment when publishing tensor batches. Pass `--no-mooncake-dataproto-hard-pin` to disable. |

For performance runs, configure Mooncake Store with the production transport (for example, RDMA), and measure buffer registration or prewarm costs separately from online transfer latency.
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Start by Use Case
advanced/pd-disaggregation.md
advanced/external-rollout-engines.md
advanced/delta-weight-sync.md
advanced/mooncake-dataproto-transfer.md
advanced/sglang-config.md
advanced/megatron-config.md
advanced/arch-support-beyond-megatron.md
Expand Down
15 changes: 15 additions & 0 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,21 @@ def _split_train_data_by_dp(self, data):
rollout_indices=data["rollout_ids"],
)

if getattr(self.args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}:
from slime.utils.rollout_dataproto import split_rollout_data_by_dp_dataproto

dynamic_global_batch_size = getattr(self, "_dynamic_global_batch_size", None)
return split_rollout_data_by_dp_dataproto(
self.args,
data,
dp_size,
partitions,
dynamic_global_batch_size,
micro_batch_indices,
num_microbatches,
global_batch_sizes,
)

# Package per-rank rollout_data
rollout_data_refs = []
for r in range(dp_size):
Expand Down
35 changes: 35 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,24 @@ def add_rollout_arguments(parser):
"This is used to shuffle the prompts and also for the random sampling of the prompts."
),
)
parser.add_argument(
"--transfer-backend",
choices=["ray", "mooncake", "mooncake_dataproto"],
default="ray",
help="Rollout data transfer backend. Keep ray as the default; mooncake is experimental.",
)
parser.add_argument(
"--mooncake-dataproto-hard-pin",
action=argparse.BooleanOptionalAction,
default=True,
help="Hard-pin Mooncake rollout tensors to the producer segment for Mooncake transfer.",
)
parser.add_argument(
"--mooncake-dataproto-store-init-kwargs",
type=json.loads,
default=None,
help="JSON kwargs used to initialize MooncakeDistributedStore for Mooncake transfer.",
)

# sampling
parser.add_argument(
Expand Down Expand Up @@ -1748,6 +1766,23 @@ def _validate_update_weight_args(args) -> None:
def slime_validate_args(args):
args.eval_datasets = _resolve_eval_datasets(args)

if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto":
args.transfer_backend = "mooncake"

if getattr(args, "transfer_backend", "ray") == "mooncake":
from slime.utils.remote_batch import normalize_store_init_kwargs

args.mooncake_dataproto_store_init_kwargs = normalize_store_init_kwargs(
args.mooncake_dataproto_store_init_kwargs
)

if args.use_slime_router:
logger.warning(
"--use-slime-router is deprecated and ignored. slime now always uses sglang_router "
"built from https://github.com/zhuzilin/sgl-router."
)
args.use_slime_router = False

if args.kl_coef != 0 or args.use_kl_loss:
if not os.path.exists(args.ref_load):
raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.")
Expand Down
8 changes: 6 additions & 2 deletions slime/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
import random
import re

import numpy as np
import ray

Expand Down Expand Up @@ -291,7 +290,12 @@ def __len__(self):

def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
assert len(rollout_data_ref) == dp_size
rollout_data = ray.get(rollout_data_ref[dp_rank].inner)
if getattr(args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}:
from slime.utils.rollout_dataproto import materialize_dataproto_rollout_data

rollout_data = materialize_dataproto_rollout_data(args, rollout_data_ref[dp_rank])
else:
rollout_data = ray.get(rollout_data_ref[dp_rank].inner)

partition = rollout_data.pop("partition")
total_lengths = rollout_data["total_lengths"]
Expand Down
132 changes: 132 additions & 0 deletions slime/utils/remote_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import os
from typing import Any, Mapping

import torch

# Names accepted for the "setup_method" entry of the store init kwargs.
ALLOWED_SETUP_METHODS = {"setup", "setup_dummy"}
# Module-level cache of created stores, keyed by the sorted (name, repr(value))
# pairs of the normalized init kwargs.
_STORE_CACHE: dict[tuple[tuple[str, str], ...], Any] = {}


def normalize_store_init_kwargs(store_init_kwargs: dict[str, Any] | None) -> dict[str, Any]:
if store_init_kwargs is None:
raise ValueError("mooncake transfer requires --mooncake-dataproto-store-init-kwargs")
if not store_init_kwargs:
return {"setup_method": "setup"}
setup_method = store_init_kwargs.get("setup_method", "setup")
if setup_method not in ALLOWED_SETUP_METHODS:
raise ValueError(f"unsupported Mooncake store setup_method {setup_method!r}; allowed: {sorted(ALLOWED_SETUP_METHODS)}")
return dict(store_init_kwargs)


def create_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any:
    """Build and set up a Mooncake store from the given init kwargs.

    ``None`` is treated like an empty dict. With ``setup_method == "setup_dummy"``
    an ``InMemoryMooncakeStore`` is returned and Mooncake is never imported.
    Otherwise a ``MooncakeDistributedStore`` is created and the named setup
    method is called with the remaining kwargs.

    Raises:
        RuntimeError: If the setup call returns a non-zero code.
    """
    options = normalize_store_init_kwargs(store_init_kwargs or {})
    method_name = options.get("setup_method", "setup")
    if method_name == "setup_dummy":
        return InMemoryMooncakeStore()

    from mooncake.store import MooncakeDistributedStore  # type: ignore

    store = MooncakeDistributedStore()
    # Everything except the method selector is forwarded to the setup call.
    forwarded = dict(options)
    forwarded.pop("setup_method", None)
    bound_setup = getattr(store, method_name)
    try:
        status = bound_setup(**forwarded)
    except TypeError:
        if method_name == "setup":
            # Retry with one positional config dict: environment-derived defaults,
            # overridden by any explicitly supplied kwargs.
            status = bound_setup(_env_store_config() | forwarded)
        else:
            raise
    if status != 0:
        raise RuntimeError(f"Mooncake store {method_name} failed with retcode {status}")
    return store


def get_cached_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any:
    """Return the store cached for these init kwargs, creating it on first use.

    The cache key is the sorted tuple of ``(name, repr(value))`` pairs of the
    normalized kwargs, so equal configurations share one store object.
    """
    normalized = normalize_store_init_kwargs(store_init_kwargs)
    fingerprint = tuple(sorted((name, repr(value)) for name, value in normalized.items()))
    if fingerprint in _STORE_CACHE:
        return _STORE_CACHE[fingerprint]
    _STORE_CACHE[fingerprint] = create_mooncake_store(normalized)
    return _STORE_CACHE[fingerprint]


def put_mooncake_dataproto(
    data: Mapping[str, Any],
    store: Any,
    *,
    key_prefix: str,
    namespace: str = "slime",
    partition: str = "rollout",
) -> dict[str, Any]:
    """Publish ``data`` through the Mooncake transfer helper and return its handle.

    The reference returned by ``put_dataproto`` (always called with
    ``stage="rollout"``) is exported to a handle, and ``key_prefix`` is recorded
    in it under ``"slime_key_prefix"``.
    """
    helpers = _import_mooncake_dataproto_helpers()
    bundle_transfer_cls, export_ref = helpers[0], helpers[1]
    bundle_transfer = bundle_transfer_cls(store, key_prefix=key_prefix)
    published_ref = bundle_transfer.put_dataproto(
        data,
        namespace=namespace,
        partition=partition,
        stage="rollout",
    )
    exported = export_ref(published_ref)
    exported["slime_key_prefix"] = key_prefix
    return exported


def get_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> dict[str, Any]:
    """Fetch the data referenced by ``handle`` via the Mooncake transfer helper.

    The transfer helper is built with the handle's ``"slime_key_prefix"``
    (empty string when absent).
    """
    bundle_transfer_cls = _import_mooncake_dataproto_helpers()[0]
    prefix = handle.get("slime_key_prefix", "")
    bundle_transfer = bundle_transfer_cls(store, key_prefix=prefix)
    return bundle_transfer.get_dataproto(handle)


def cleanup_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> None:
    """Ask the Mooncake transfer helper to clean up what ``handle`` refers to.

    The transfer helper is built with the handle's ``"slime_key_prefix"``
    (empty string when absent).
    """
    bundle_transfer_cls = _import_mooncake_dataproto_helpers()[0]
    prefix = handle.get("slime_key_prefix", "")
    bundle_transfer = bundle_transfer_cls(store, key_prefix=prefix)
    bundle_transfer.cleanup_dataproto(handle)


def is_mooncake_dataproto_handle(value: Any) -> bool:
    """Return the Mooncake helper's verdict on whether ``value`` is a DataProto ref handle.

    Raises:
        ImportError: If the Mooncake DataProto helpers cannot be imported.
    """
    handle_predicate = _import_mooncake_dataproto_helpers()[2]
    return handle_predicate(value)


def _env_store_config() -> dict[str, Any]:
    """Assemble a Mooncake store config from ``MOONCAKE_*`` environment variables.

    Each entry falls back to a built-in default when its variable is unset. The
    two size entries are converted with ``int`` and default to 16 GiB.
    """
    default_size = str(16 * 1024 * 1024 * 1024)
    env = os.getenv

    config: dict[str, Any] = {}
    config["local_hostname"] = env("MOONCAKE_LOCAL_HOSTNAME", "localhost")
    config["metadata_server"] = env("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE")
    config["global_segment_size"] = int(env("MOONCAKE_GLOBAL_SEGMENT_SIZE", default_size))
    config["local_buffer_size"] = int(env("MOONCAKE_LOCAL_BUFFER_SIZE", default_size))
    config["protocol"] = env("MOONCAKE_PROTOCOL", "rdma")
    config["rdma_devices"] = env("MOONCAKE_DEVICE", "")
    config["master_server_addr"] = env("MOONCAKE_MASTER", "127.0.0.1:50051")
    return config


class InMemoryMooncakeStore:
    """Dictionary-backed stand-in store, returned for ``setup_method == "setup_dummy"``.

    Byte payloads and tensors are kept in two separate dicts. Every mutating
    method returns ``0``.
    """

    def __init__(self) -> None:
        # Byte payloads written with put().
        self.objects: dict[str, bytes] = {}
        # Tensor snapshots written with put_tensor().
        self.tensors: dict[str, torch.Tensor] = {}

    def put(self, key: str, value: Any) -> int:
        """Store ``bytes(value)`` under ``key``, replacing any previous payload."""
        payload = bytes(value)
        self.objects[key] = payload
        return 0

    def get(self, key: str) -> bytes:
        """Return the payload stored under ``key``; raises ``KeyError`` if absent."""
        payload = self.objects[key]
        return payload

    def remove(self, key: str, force: bool = False) -> int:
        """Drop ``key`` from both dicts if present; ``force`` is accepted but unused."""
        for table in (self.objects, self.tensors):
            table.pop(key, None)
        return 0

    def put_tensor(self, key: str, tensor: torch.Tensor) -> int:
        """Store a detached CPU copy of ``tensor`` under ``key``."""
        snapshot = tensor.detach().cpu().clone()
        self.tensors[key] = snapshot
        return 0

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return a clone of the tensor stored under ``key``; raises ``KeyError`` if absent."""
        stored = self.tensors[key]
        return stored.clone()


def _import_mooncake_dataproto_helpers():
    """Import the Mooncake DataProto helpers lazily.

    Returns:
        The tuple ``(MooncakeBundleTransfer, export_dataproto_ref, is_dataproto_ref_handle)``
        from ``mooncake.structured_object_store``.

    Raises:
        ImportError: If the module or any of the three names cannot be imported;
            the original error is chained as the cause.
    """
    try:
        from mooncake.structured_object_store import MooncakeBundleTransfer as transfer_cls
        from mooncake.structured_object_store import export_dataproto_ref as export_ref
        from mooncake.structured_object_store import is_dataproto_ref_handle as is_ref_handle
    except ImportError as import_error:
        message = "Mooncake structured object DataProto helpers are required for mooncake_dataproto transfer"
        raise ImportError(message) from import_error
    else:
        return transfer_cls, export_ref, is_ref_handle
Loading
Loading