diff --git a/docs/en/advanced/mooncake-dataproto-transfer.md b/docs/en/advanced/mooncake-dataproto-transfer.md new file mode 100644 index 0000000000..572146e43b --- /dev/null +++ b/docs/en/advanced/mooncake-dataproto-transfer.md @@ -0,0 +1,30 @@ +# Mooncake DataProto Rollout Transfer + +slime can transfer rollout data through Mooncake instead of Ray object references. This is useful when the rollout producer and actor consumer run on different nodes and Mooncake Store is configured for the cluster transport. + +The default transfer backend remains Ray. Enable Mooncake DataProto transfer explicitly: + +```bash +python3 train.py \ + --transfer-backend mooncake_dataproto \ + --mooncake-dataproto-store-init-kwargs '{"setup_method":"setup"}' +``` + +## What is transferred + +The Mooncake path keeps slime's rollout data layout unchanged: + +- per-rank rollout partitions are still selected by slime before actor consumption; +- tensor fields such as `tokens` and `loss_masks` are stored as Mooncake remote tensor batches; +- non-tensor rollout fields and metadata stay in the `DataProto` wrapper; +- cleanup keys are tracked in metadata and removed after actor-side materialization. + +## Options + +| Option | Default | Meaning | +| --- | --- | --- | +| `--transfer-backend` | `ray` | Set to `mooncake_dataproto` to enable Mooncake rollout transfer. | +| `--mooncake-dataproto-store-init-kwargs` | `null` | JSON arguments used to initialize the Mooncake store. Use `{"setup_method":"setup"}` for real Mooncake Store setup and `{"setup_method":"setup_dummy"}` for local unit tests. | +| `--mooncake-dataproto-hard-pin` | `true` | Hard-pin remote tensor data to the producer segment when publishing tensor batches. | + +For performance runs, configure Mooncake Store with the production transport, for example RDMA, and keep buffer registration or prewarm costs separate from online transfer latency. diff --git a/docs/en/index.rst b/docs/en/index.rst index f3401ce3a7..4eff4a4764 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -84,6 +84,7 @@ Start by Use Case advanced/pd-disaggregation.md advanced/external-rollout-engines.md advanced/delta-weight-sync.md + advanced/mooncake-dataproto-transfer.md advanced/sglang-config.md advanced/megatron-config.md advanced/arch-support-beyond-megatron.md diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index ff571101af..225ec69428 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -850,6 +850,21 @@ def _split_train_data_by_dp(self, data): rollout_indices=data["rollout_ids"], ) + if getattr(self.args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}: + from slime.utils.rollout_dataproto import split_rollout_data_by_dp_dataproto + + dynamic_global_batch_size = getattr(self, "_dynamic_global_batch_size", None) + return split_rollout_data_by_dp_dataproto( + self.args, + data, + dp_size, + partitions, + dynamic_global_batch_size, + micro_batch_indices, + num_microbatches, + global_batch_sizes, + ) + # Package per-rank rollout_data rollout_data_refs = [] for r in range(dp_size): diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d5cac9d44b..b8867dc021 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -402,6 +402,24 @@ def add_rollout_arguments(parser): "This is used to shuffle the prompts and also for the random sampling of the prompts." ), ) + parser.add_argument( + "--transfer-backend", + choices=["ray", "mooncake", "mooncake_dataproto"], + default="ray", + help="Rollout data transfer backend. Keep ray as the default; mooncake is experimental.", + ) + parser.add_argument( + "--mooncake-dataproto-hard-pin", + action=argparse.BooleanOptionalAction, + default=True, + help="Hard-pin Mooncake rollout tensors to the producer segment for Mooncake transfer.", + ) + parser.add_argument( + "--mooncake-dataproto-store-init-kwargs", + type=json.loads, + default=None, + help="JSON kwargs used to initialize MooncakeDistributedStore for Mooncake transfer.", + ) # sampling parser.add_argument( @@ -1748,6 +1766,23 @@ def _validate_update_weight_args(args) -> None: def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) + if getattr(args, "transfer_backend", "ray") == "mooncake_dataproto": + args.transfer_backend = "mooncake" + + if getattr(args, "transfer_backend", "ray") == "mooncake": + from slime.utils.remote_batch import normalize_store_init_kwargs + + args.mooncake_dataproto_store_init_kwargs = normalize_store_init_kwargs( + args.mooncake_dataproto_store_init_kwargs + ) + + if args.use_slime_router: + logger.warning( + "--use-slime-router is deprecated and ignored. slime now always uses sglang_router " + "built from https://github.com/zhuzilin/sgl-router." + ) + args.use_slime_router = False + if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") diff --git a/slime/utils/data.py b/slime/utils/data.py index 0d26b6dda5..1d2b64a339 100644 --- a/slime/utils/data.py +++ b/slime/utils/data.py @@ -4,7 +4,6 @@ import os import random import re - import numpy as np import ray @@ -291,7 +290,12 @@ def __len__(self): def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size): assert len(rollout_data_ref) == dp_size - rollout_data = ray.get(rollout_data_ref[dp_rank].inner) + if getattr(args, "transfer_backend", "ray") in {"mooncake", "mooncake_dataproto"}: + from slime.utils.rollout_dataproto import materialize_dataproto_rollout_data + + rollout_data = materialize_dataproto_rollout_data(args, rollout_data_ref[dp_rank]) + else: + rollout_data = ray.get(rollout_data_ref[dp_rank].inner) partition = rollout_data.pop("partition") total_lengths = rollout_data["total_lengths"] diff --git a/slime/utils/remote_batch.py b/slime/utils/remote_batch.py new file mode 100644 index 0000000000..2ab36ca784 --- /dev/null +++ b/slime/utils/remote_batch.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import os +from typing import Any, Mapping + +import torch + +ALLOWED_SETUP_METHODS = {"setup", "setup_dummy"} +_STORE_CACHE: dict[tuple[tuple[str, str], ...], Any] = {} + + +def normalize_store_init_kwargs(store_init_kwargs: dict[str, Any] | None) -> dict[str, Any]: + if store_init_kwargs is None: + raise ValueError("mooncake transfer requires --mooncake-dataproto-store-init-kwargs") + if not store_init_kwargs: + return {"setup_method": "setup"} + setup_method = store_init_kwargs.get("setup_method", "setup") + if setup_method not in ALLOWED_SETUP_METHODS: + raise ValueError(f"unsupported Mooncake store setup_method {setup_method!r}; allowed: {sorted(ALLOWED_SETUP_METHODS)}") + return dict(store_init_kwargs) + + +def create_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any: + kwargs = normalize_store_init_kwargs(store_init_kwargs or {}) + setup_method = kwargs.get("setup_method", "setup") + if setup_method == "setup_dummy": + return InMemoryMooncakeStore() + + from mooncake.store import MooncakeDistributedStore # type: ignore + + store = MooncakeDistributedStore() + setup_kwargs = {key: val for key, val in kwargs.items() if key != "setup_method"} + setup = getattr(store, setup_method) + try: + ret = setup(**setup_kwargs) + except TypeError: + if setup_method != "setup": + raise + ret = setup(_env_store_config() | setup_kwargs) + if ret != 0: + raise RuntimeError(f"Mooncake store {setup_method} failed with retcode {ret}") + return store + + +def get_cached_mooncake_store(store_init_kwargs: dict[str, Any] | None = None) -> Any: + kwargs = normalize_store_init_kwargs(store_init_kwargs) + cache_key = tuple(sorted((key, repr(val)) for key, val in kwargs.items())) + if cache_key not in _STORE_CACHE: + _STORE_CACHE[cache_key] = create_mooncake_store(kwargs) + return _STORE_CACHE[cache_key] + + +def put_mooncake_dataproto( + data: Mapping[str, Any], + store: Any, + *, + key_prefix: str, + namespace: str = "slime", + partition: str = "rollout", +) -> dict[str, Any]: + transfer_cls, export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=key_prefix) + ref = transfer.put_dataproto(data, namespace=namespace, partition=partition, stage="rollout") + handle = export_ref(ref) + handle["slime_key_prefix"] = key_prefix + return handle + + +def get_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> dict[str, Any]: + transfer_cls, _export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=handle.get("slime_key_prefix", "")) + return transfer.get_dataproto(handle) + + +def cleanup_mooncake_dataproto(handle: Mapping[str, Any], store: Any) -> None: + transfer_cls, _export_ref, _is_ref_handle = _import_mooncake_dataproto_helpers() + transfer = transfer_cls(store, key_prefix=handle.get("slime_key_prefix", "")) + transfer.cleanup_dataproto(handle) + + +def is_mooncake_dataproto_handle(value: Any) -> bool: + _transfer_cls, _export_ref, is_ref_handle = _import_mooncake_dataproto_helpers() + return is_ref_handle(value) + + +def _env_store_config() -> dict[str, Any]: + return { + "local_hostname": os.getenv("MOONCAKE_LOCAL_HOSTNAME", "localhost"), + "metadata_server": os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"), + "global_segment_size": int(os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", str(16 * 1024 * 1024 * 1024))), + "local_buffer_size": int(os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", str(16 * 1024 * 1024 * 1024))), + "protocol": os.getenv("MOONCAKE_PROTOCOL", "rdma"), + "rdma_devices": os.getenv("MOONCAKE_DEVICE", ""), + "master_server_addr": os.getenv("MOONCAKE_MASTER", "127.0.0.1:50051"), + } + + +class InMemoryMooncakeStore: + def __init__(self) -> None: + self.objects: dict[str, bytes] = {} + self.tensors: dict[str, torch.Tensor] = {} + + def put(self, key: str, value: Any) -> int: + self.objects[key] = bytes(value) + return 0 + + def get(self, key: str) -> bytes: + return self.objects[key] + + def remove(self, key: str, force: bool = False) -> int: + self.objects.pop(key, None) + self.tensors.pop(key, None) + return 0 + + def put_tensor(self, key: str, tensor: torch.Tensor) -> int: + self.tensors[key] = tensor.detach().cpu().clone() + return 0 + + def get_tensor(self, key: str) -> torch.Tensor: + return self.tensors[key].clone() + + +def _import_mooncake_dataproto_helpers(): + try: + from mooncake.structured_object_store import ( + MooncakeBundleTransfer, + export_dataproto_ref, + is_dataproto_ref_handle, + ) + except ImportError as exc: + raise ImportError("Mooncake structured object DataProto helpers are required for mooncake_dataproto transfer") from exc + return MooncakeBundleTransfer, export_dataproto_ref, is_dataproto_ref_handle diff --git a/slime/utils/rollout_dataproto.py b/slime/utils/rollout_dataproto.py new file mode 100644 index 0000000000..a527a4b731 --- /dev/null +++ b/slime/utils/rollout_dataproto.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import uuid +from typing import Any, Mapping + +import numpy as np +import torch + +from slime.utils.remote_batch import ( + cleanup_mooncake_dataproto, + get_cached_mooncake_store, + get_mooncake_dataproto, + is_mooncake_dataproto_handle, + normalize_store_init_kwargs, + put_mooncake_dataproto, +) + +PARTITIONED_KEYS = ( + "tokens", + "multimodal_train_inputs", + "response_lengths", + "rewards", + "truncated", + "loss_masks", + "round_number", + "sample_indices", + "rollout_ids", + "rollout_mask_sums", + "rollout_log_probs", + "rollout_top_p_token_ids", + "rollout_top_p_token_offsets", + "rollout_routed_experts", + "prompt", + "teacher_log_probs", +) +GLOBAL_KEYS = ( + "raw_reward", + "total_lengths", + "global_batch_sizes", + "num_microbatches", + "micro_batch_indices", + "dynamic_global_batch_size", +) +_ROLLOUT_DATA_TENSOR_DTYPES = { + "tokens": torch.long, + "loss_masks": torch.int, + "rollout_log_probs": torch.float32, + "rollout_top_p_token_ids": torch.int32, + "rollout_top_p_token_offsets": torch.int32, + "teacher_log_probs": torch.float32, + "rollout_routed_experts": None, +} + + +def split_rollout_data_by_dp_dataproto( + args: Any, + data: dict, + dp_size: int, + partitions: list, + dynamic_global_batch_size: int | None = None, + micro_batch_indices: list | None = None, + num_microbatches: list | None = None, + global_batch_sizes: list | None = None, +) -> list[dict[str, Any]]: + if len(partitions) != dp_size: + raise ValueError(f"expected {dp_size} partitions, got {len(partitions)}") + store_init_kwargs = _store_init_kwargs(args) + store = get_cached_mooncake_store(store_init_kwargs) + transfer_id = uuid.uuid4().hex + refs = [] + try: + for dp_rank, partition in enumerate(partitions): + rollout_data = _build_rank_rollout_data( + data, + [int(idx) for idx in partition], + micro_batch_indices[dp_rank] if micro_batch_indices is not None else None, + num_microbatches, + global_batch_sizes, + dynamic_global_batch_size, + ) + _tensorize_rollout_data_for_training(rollout_data) + ref = put_mooncake_dataproto( + _rollout_data_to_dataproto_envelope(rollout_data), + store, + key_prefix=f"slime-rollout/{transfer_id}/dp{dp_rank}", + ) + refs.append(ref) + except Exception: + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) + raise + return refs + + +def materialize_dataproto_rollout_data(args: Any, ref: Mapping[str, Any]) -> dict: + store_init_kwargs = _store_init_kwargs(args) + store = get_cached_mooncake_store(store_init_kwargs) + if not is_mooncake_dataproto_handle(ref): + raise TypeError(f"expected Mooncake DataProto handle, got {type(ref).__name__}") + envelope = get_mooncake_dataproto(ref, store) + rollout_data = dict(envelope.get("batch", {})) + rollout_data.update( + { + key: _non_tensor_value_to_legacy(value) + for key, value in envelope.get("non_tensor_batch", {}).items() + } + ) + rollout_data.update(envelope.get("meta_info", {})) + _tensorize_rollout_data_for_training(rollout_data) + return rollout_data + + +def maybe_cleanup_dataproto_refs(args: Any, refs: list[Mapping[str, Any]], suppress_errors: bool = False) -> None: + if getattr(args, "transfer_backend", "ray") not in {"mooncake", "mooncake_dataproto"}: + return + store_init_kwargs = _store_init_kwargs(args) + if not suppress_errors: + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) + return + try: + cleanup_dataproto_refs(refs, store_init_kwargs=store_init_kwargs) + except Exception: + return + + +def cleanup_dataproto_refs( + refs: list[Mapping[str, Any]], + store_init_kwargs: dict[str, Any] | None = None, +) -> None: + if not refs: + return + store = get_cached_mooncake_store(store_init_kwargs or {"setup_method": "setup"}) + for ref in refs: + cleanup_mooncake_dataproto(ref, store) + + +def _build_rank_rollout_data( + data: dict, + partition: list[int], + micro_batch_indices: list | None, + num_microbatches: list | None, + global_batch_sizes: list | None, + dynamic_global_batch_size: int | None, +) -> dict: + rollout_data = {"partition": partition} + for key in PARTITIONED_KEYS: + if key in data: + rollout_data[key] = [data[key][idx] for idx in partition] + for key in ("raw_reward", "total_lengths"): + if key in data: + rollout_data[key] = data[key] + if global_batch_sizes is not None: + rollout_data["global_batch_sizes"] = global_batch_sizes + if num_microbatches is not None: + rollout_data["num_microbatches"] = num_microbatches + if micro_batch_indices is not None: + rollout_data["micro_batch_indices"] = micro_batch_indices + if dynamic_global_batch_size is not None: + rollout_data["dynamic_global_batch_size"] = dynamic_global_batch_size + return rollout_data + + +def _rollout_data_to_dataproto_envelope(rollout_data: dict) -> dict[str, dict[str, Any]]: + batch_size = len(rollout_data["partition"]) + batch = {} + non_tensor_batch = {} + meta_info = {} + for key, value in rollout_data.items(): + if key in GLOBAL_KEYS: + meta_info[key] = _json_safe_metadata(value) + elif isinstance(value, torch.Tensor) and _is_row_aligned(value, batch_size): + batch[key] = value + elif _is_row_aligned(value, batch_size): + non_tensor_batch[key] = _to_object_array(value) + else: + meta_info[key] = _json_safe_metadata(value) + return {"batch": batch, "non_tensor_batch": non_tensor_batch, "meta_info": meta_info} + + +def _is_row_aligned(value: Any, batch_size: int) -> bool: + if isinstance(value, torch.Tensor): + return value.ndim > 0 and value.shape[0] == batch_size + if isinstance(value, np.ndarray): + return value.ndim > 0 and value.shape[0] == batch_size + if isinstance(value, (list, tuple)): + return len(value) == batch_size + return False + + +def _to_object_array(value: Any) -> np.ndarray: + if isinstance(value, np.ndarray): + return value + if isinstance(value, torch.Tensor): + return np.asarray([item.detach().cpu() for item in value], dtype=object) + return np.asarray(value, dtype=object) + + +def _non_tensor_value_to_legacy(value: Any) -> Any: + if isinstance(value, np.ndarray): + return value.tolist() + return value + + +def _tensorize_rollout_data_for_training(rollout_data: dict[str, Any]) -> None: + for key, dtype in _ROLLOUT_DATA_TENSOR_DTYPES.items(): + if key in rollout_data: + rollout_data[key] = [_cpu_tensor(value, dtype=dtype) for value in rollout_data[key]] + + if "multimodal_train_inputs" in rollout_data: + rollout_data["multimodal_train_inputs"] = [ + ( + { + key: _cpu_tensor(value) if isinstance(value, (np.ndarray, torch.Tensor)) else value + for key, value in mm_dict.items() + } + if mm_dict is not None + else None + ) + for mm_dict in rollout_data["multimodal_train_inputs"] + ] + + if "rollout_mask_sums" in rollout_data: + rollout_data["rollout_mask_sums"] = _cpu_tensor( + rollout_data["rollout_mask_sums"], + dtype=torch.float32, + ) + + +def _cpu_tensor(value: Any, dtype: torch.dtype | None = None) -> torch.Tensor: + if isinstance(value, np.ndarray) and not value.flags.writeable: + value = value.copy() + tensor = torch.as_tensor(value, dtype=dtype) if dtype is not None else torch.as_tensor(value) + return tensor.detach().cpu().contiguous() + + +def _json_safe_metadata(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, np.generic): + return value.item() + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, Mapping): + return {str(key): _json_safe_metadata(val) for key, val in value.items()} + if isinstance(value, (list, tuple)): + return [_json_safe_metadata(item) for item in value] + return value + + +def _store_init_kwargs(args: Any) -> dict[str, Any]: + kwargs = getattr(args, "mooncake_dataproto_store_init_kwargs", None) + return normalize_store_init_kwargs(kwargs) diff --git a/tests/utils/test_dataproto_transfer.py b/tests/utils/test_dataproto_transfer.py new file mode 100644 index 0000000000..ce6fba823c --- /dev/null +++ b/tests/utils/test_dataproto_transfer.py @@ -0,0 +1,147 @@ +import pickle +import sys +import types + +import numpy as np +import torch + +from slime.utils import remote_batch +from slime.utils.remote_batch import create_mooncake_store, normalize_store_init_kwargs +from slime.utils.rollout_dataproto import ( + cleanup_dataproto_refs, + materialize_dataproto_rollout_data, + split_rollout_data_by_dp_dataproto, +) + + +class FakeRef: + def __init__(self, manifest_key, batch_size): + self.manifest_key = manifest_key + self.batch_size = batch_size + + +class FakeMooncakeBundleTransfer: + def __init__(self, store, key_prefix=""): + self.store = store + self.key_prefix = key_prefix or "default" + + def put_dataproto(self, data, namespace="default", partition="default", stage="default"): + key = f"{self.key_prefix}/{namespace}/{partition}/{stage}/manifest" + self.store.put(key, pickle.dumps(data)) + if data["non_tensor_batch"]: + batch_size = len(next(iter(data["non_tensor_batch"].values()))) + else: + batch_size = len(next(iter(data["batch"].values()))) + return FakeRef(key, batch_size) + + def get_dataproto(self, handle): + return pickle.loads(self.store.get(handle["manifest_key"])) + + def cleanup_dataproto(self, handle): + self.store.remove(handle["manifest_key"], True) + + +def fake_export_dataproto_ref(ref): + return { + "type": "mooncake_dataproto_ref", + "version": 1, + "kind": "bundle_stages", + "manifest_key": ref.manifest_key, + "batch_size": ref.batch_size, + } + + +def fake_is_dataproto_ref_handle(value): + return isinstance(value, dict) and value.get("type") == "mooncake_dataproto_ref" + + +def install_fake_mooncake(monkeypatch): + mooncake_module = types.ModuleType("mooncake") + structured_module = types.ModuleType("mooncake.structured_object_store") + structured_module.MooncakeBundleTransfer = FakeMooncakeBundleTransfer + structured_module.export_dataproto_ref = fake_export_dataproto_ref + structured_module.is_dataproto_ref_handle = fake_is_dataproto_ref_handle + monkeypatch.setitem(sys.modules, "mooncake", mooncake_module) + monkeypatch.setitem(sys.modules, "mooncake.structured_object_store", structured_module) + remote_batch._STORE_CACHE.clear() + + +def test_rollout_dict_roundtrips_through_mooncake_handle(monkeypatch): + install_fake_mooncake(monkeypatch) + args = types.SimpleNamespace(mooncake_dataproto_store_init_kwargs={"setup_method": "setup_dummy"}) + data = { + "tokens": [[1, 2], [3, 4, 5], [6]], + "loss_masks": [[1, 1], [1, 1, 1], [1]], + "response_lengths": [2, 3, 1], + "rewards": [1.0, 2.0, 3.0], + "rollout_ids": [10, 11, 12], + "rollout_mask_sums": [2.0, 3.0, 1.0], + "total_lengths": [2, 3, 1], + "raw_reward": [1.0, 2.0, 3.0], + } + + refs = split_rollout_data_by_dp_dataproto( + args, + data, + 2, + [[0, 2], [1]], + micro_batch_indices=[[[0], [1]], [[0]]], + num_microbatches=[2, 1], + global_batch_sizes=[2, 1], + ) + + assert all(ref["type"] == "mooncake_dataproto_ref" for ref in refs) + rollout_data = materialize_dataproto_rollout_data(args, refs[0]) + + assert rollout_data["partition"] == [0, 2] + assert [row.tolist() for row in rollout_data["tokens"]] == [[1, 2], [6]] + assert [row.tolist() for row in rollout_data["loss_masks"]] == [[1, 1], [1]] + assert rollout_data["response_lengths"] == [2, 1] + assert rollout_data["rollout_ids"] == [10, 12] + assert rollout_data["rollout_mask_sums"].tolist() == [2.0, 1.0] + assert rollout_data["total_lengths"] == [2, 3, 1] + assert rollout_data["micro_batch_indices"] == [[0], [1]] + assert rollout_data["num_microbatches"] == [2, 1] + assert rollout_data["global_batch_sizes"] == [2, 1] + + store = remote_batch.get_cached_mooncake_store({"setup_method": "setup_dummy"}) + cleanup_dataproto_refs(refs, store_init_kwargs={"setup_method": "setup_dummy"}) + assert store.objects == {} + + +def test_rollout_transfer_rejects_partition_mismatch(monkeypatch): + install_fake_mooncake(monkeypatch) + args = types.SimpleNamespace(mooncake_dataproto_store_init_kwargs={"setup_method": "setup_dummy"}) + try: + split_rollout_data_by_dp_dataproto(args, {}, 2, [[]]) + except ValueError as exc: + assert "expected 2 partitions" in str(exc) + else: + raise AssertionError("partition mismatch should be rejected") + + +def test_normalizes_empty_mooncake_setup_kwargs_to_setup(): + assert normalize_store_init_kwargs({}) == {"setup_method": "setup"} + + +def test_rejects_unsafe_mooncake_setup_method(): + try: + normalize_store_init_kwargs({"setup_method": "remove"}) + except ValueError as exc: + assert "unsupported Mooncake store setup_method" in str(exc) + else: + raise AssertionError("unsafe setup_method should be rejected") + + +def test_create_store_normalizes_none_for_default_call(monkeypatch): + class Store: + def setup(self): + return 0 + + mooncake_module = types.ModuleType("mooncake") + store_module = types.ModuleType("mooncake.store") + store_module.MooncakeDistributedStore = Store + monkeypatch.setitem(sys.modules, "mooncake", mooncake_module) + monkeypatch.setitem(sys.modules, "mooncake.store", store_module) + + assert isinstance(create_mooncake_store(), Store) diff --git a/train.py b/train.py index 620f7e8d70..02392d0fe5 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,7 @@ import ray +from slime.utils.rollout_dataproto import maybe_cleanup_dataproto_refs + from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from slime.utils.arguments import parse_args from slime.utils.logging_utils import configure_logger, finish_tracking, init_tracking @@ -71,14 +73,19 @@ def save(rollout_id): actor_trains_this_step = (not args.use_critic) or rollout_id >= args.num_critic_only_steps - if args.use_critic: - value_refs = critic_model.async_train(rollout_id, rollout_data_ref) - if actor_trains_this_step: - ray.get(actor_model.async_train(rollout_id, rollout_data_ref, external_data=value_refs)) + train_succeeded = False + try: + if args.use_critic: + value_refs = critic_model.async_train(rollout_id, rollout_data_ref) + if actor_trains_this_step: + ray.get(actor_model.async_train(rollout_id, rollout_data_ref, external_data=value_refs)) + else: + ray.get(value_refs) else: - ray.get(value_refs) - else: - ray.get(actor_model.async_train(rollout_id, rollout_data_ref)) + ray.get(actor_model.async_train(rollout_id, rollout_data_ref)) + train_succeeded = True + finally: + maybe_cleanup_dataproto_refs(args, rollout_data_ref, suppress_errors=not train_succeeded) if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): save(rollout_id) diff --git a/train_async.py b/train_async.py index 9d4c9b6473..b29a7c1741 100644 --- a/train_async.py +++ b/train_async.py @@ -1,5 +1,7 @@ import ray +from slime.utils.rollout_dataproto import maybe_cleanup_dataproto_refs + from slime.ray.placement_group import create_placement_groups, create_rollout_manager, create_training_models from slime.utils.arguments import parse_args from slime.utils.logging_utils import configure_logger, finish_tracking, init_tracking @@ -38,15 +40,21 @@ def train(args): if rollout_id + 1 < args.num_rollout: rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1) - if args.use_critic: - actor_trains_this_step = rollout_id >= args.num_critic_only_steps - value_refs = critic_model.async_train(rollout_id, rollout_data_curr_ref) - if actor_trains_this_step: - ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref, external_data=value_refs)) + train_succeeded = False + try: + if args.use_critic: + actor_trains_this_step = rollout_id >= args.num_critic_only_steps + value_refs = critic_model.async_train(rollout_id, rollout_data_curr_ref) + if actor_trains_this_step: + ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref, external_data=value_refs)) + else: + ray.get(value_refs) else: - ray.get(value_refs) - else: - ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) + actor_trains_this_step = True + ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref)) + train_succeeded = True + finally: + maybe_cleanup_dataproto_refs(args, rollout_data_curr_ref, suppress_errors=not train_succeeded) if should_run_periodic_action(rollout_id, args.save_interval, num_rollout_per_epoch, args.num_rollout): if (not args.use_critic) or rollout_id >= args.num_critic_only_steps: