From 5b3fda32d5cfb2f0d8716d3f1e94cc9e3b412e0f Mon Sep 17 00:00:00 2001 From: Nripesh Niketan Date: Fri, 29 May 2026 09:15:34 +0100 Subject: [PATCH] Add Apple Silicon (MPS) backend support via torch.mps Introduces a unified device abstraction layer (`devices.py`) and updates all packages (ltx-core, ltx-pipelines, ltx-trainer) to support Metal Performance Shaders on Apple Silicon Macs alongside CUDA. Handles MPS- specific limitations such as missing ops by providing fallback paths and adapts memory management, attention, quantization, and block streaming logic for the MPS backend. --- README.md | 8 + packages/ltx-core/README.md | 18 +-- .../src/ltx_core/block_streaming/builder.py | 41 +++-- .../src/ltx_core/block_streaming/disk.py | 8 +- .../src/ltx_core/block_streaming/pool.py | 6 +- .../src/ltx_core/block_streaming/provider.py | 69 +++++--- .../src/ltx_core/block_streaming/source.py | 11 +- .../src/ltx_core/block_streaming/wrapper.py | 13 +- packages/ltx-core/src/ltx_core/devices.py | 149 ++++++++++++++++++ .../src/ltx_core/loader/fuse_loras.py | 5 +- .../src/ltx_core/loader/primitives.py | 11 +- .../loader/single_gpu_model_builder.py | 18 ++- .../ltx_core/model/transformer/attention.py | 24 ++- .../gemma/encoders/base_encoder.py | 3 +- packages/ltx-pipelines/CLAUDE.md | 2 +- packages/ltx-pipelines/README.md | 11 +- .../src/ltx_pipelines/hdr_ic_lora.py | 6 +- .../ltx-pipelines/src/ltx_pipelines/retake.py | 2 +- .../src/ltx_pipelines/utils/gpu_model.py | 7 +- .../src/ltx_pipelines/utils/helpers.py | 15 +- packages/ltx-trainer/README.md | 5 +- packages/ltx-trainer/docs/quick-start.md | 5 +- .../ltx-trainer/scripts/caption_videos.py | 6 +- .../ltx-trainer/scripts/decode_latents.py | 9 +- packages/ltx-trainer/scripts/inference.py | 9 +- .../ltx-trainer/scripts/process_captions.py | 6 +- .../ltx-trainer/scripts/process_dataset.py | 4 +- .../ltx-trainer/scripts/process_videos.py | 9 +- .../ltx-trainer/src/ltx_trainer/captioning.py | 11 +- .../ltx-trainer/src/ltx_trainer/gemma_8bit.py | 8 +- .../ltx-trainer/src/ltx_trainer/gpu_utils.py | 30 ++-- .../src/ltx_trainer/model_loader.py | 11 +- .../src/ltx_trainer/quantization.py | 13 +- .../ltx-trainer/src/ltx_trainer/trainer.py | 33 +++- .../src/ltx_trainer/training_state.py | 1 + .../src/ltx_trainer/validation_sampler.py | 5 +- 36 files changed, 430 insertions(+), 162 deletions(-) create mode 100644 packages/ltx-core/src/ltx_core/devices.py diff --git a/README.md b/README.md index 362422d5..a283517c 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,14 @@ uv sync --frozen source .venv/bin/activate ``` +### Hardware Backends + +The code automatically selects the best available PyTorch backend in this order: CUDA, MPS, then CPU. On Apple Silicon +Macs with an MPS-enabled PyTorch install, pipeline and utility-script defaults now use `torch.device("mps")`; pass +`--device cpu`, `--device cuda`, or `--device mps` when you need to override auto-selection. CUDA-only optimizations +such as TensorRT-LLM FP8 scaled matrix multiplication, FlashAttention, xFormers, and bitsandbytes 8-bit loading remain +NVIDIA-specific. + ### Required Models Download the following models from the [LTX-2.3 HuggingFace repository](https://huggingface.co/Lightricks/LTX-2.3): diff --git a/packages/ltx-core/README.md b/packages/ltx-core/README.md index 33556ddc..988847de 100644 --- a/packages/ltx-core/README.md +++ b/packages/ltx-core/README.md @@ -58,7 +58,7 @@ pip install -e packages/ltx-core ### Loader -The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters. +The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters. When no device is supplied it automatically picks CUDA, then MPS, then CPU. #### Basic usage @@ -69,7 +69,7 @@ builder = SingleGPUModelBuilder( model_class_configurator=MyModelConfigurator, model_path="/path/to/model.safetensors", ) -model = builder.build(device=torch.device("cuda")) +model = builder.build() ``` #### Loading LoRA adapters @@ -89,27 +89,27 @@ builder = ( .lora("/path/to/lora_a.safetensors", 0.8, lora_sd_ops) .lora("/path/to/lora_b.safetensors", 0.5, lora_sd_ops) ) -model = builder.build(device=torch.device("cuda")) +model = builder.build() ``` #### Memory-efficient LoRA loading (`lora_load_device`) -By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). This means each LoRA adapter is kept in CPU memory and transferred to the GPU sequentially during weight fusion, which keeps peak GPU memory low even when fusing large adapters. +By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). This means each LoRA adapter is kept in CPU memory and transferred to the accelerator sequentially during weight fusion, which keeps peak accelerator memory low even when fusing large adapters. -If all adapters fit comfortably in GPU memory you can skip the CPU staging by setting `lora_load_device` to the target CUDA device: +If all adapters fit comfortably in accelerator memory you can skip the CPU staging by setting `lora_load_device` to the target CUDA or MPS device: ```python import torch from ltx_core.loader import SingleGPUModelBuilder -# Load LoRA weights directly onto the GPU (faster, but uses more GPU memory) +# Load LoRA weights directly onto the accelerator (faster, but uses more accelerator memory) builder = SingleGPUModelBuilder( model_class_configurator=MyModelConfigurator, model_path="/path/to/model.safetensors", - lora_load_device=torch.device("cuda"), + lora_load_device=torch.device("mps"), ).lora("/path/to/lora.safetensors", 1.0, lora_sd_ops) -model = builder.build(device=torch.device("cuda")) +model = builder.build(device=torch.device("mps")) ``` ### Quantization @@ -144,7 +144,7 @@ builder = SingleGPUModelBuilder( module_ops=policy.module_ops, fuse_rule=policy.fuse_rule, ) -model = builder.build(device=torch.device("cuda")) +model = builder.build() ``` #### FP8 Cast diff --git a/packages/ltx-core/src/ltx_core/block_streaming/builder.py b/packages/ltx-core/src/ltx_core/block_streaming/builder.py index 0930e4a0..11e2794a 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/builder.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/builder.py @@ -16,6 +16,7 @@ from ltx_core.block_streaming.source import DiskWeightSource, PinnedWeightSource, WeightSource from ltx_core.block_streaming.utils import allocate_layout_views, derive_layout, make_block_key, resolve_attr from ltx_core.block_streaming.wrapper import BlockStreamingWrapper +from ltx_core.devices import synchronize_device from ltx_core.loader.fuse_loras import FuseRule, bf16_fuse_rule, fuse_lora_weights from ltx_core.loader.helpers import create_meta_model, load_state_dict, read_model_config from ltx_core.loader.module_ops import ModuleOps @@ -102,7 +103,7 @@ def build( ) -> BlockStreamingWrapper: """Build and return a ready-to-use :class:`BlockStreamingWrapper`. Args: - target_device: GPU device for compute. + target_device: Accelerator device for compute. dtype: Weight dtype (e.g. ``torch.bfloat16``). cpu_slots_count: Number of pinned CPU buffer slots. ``None`` = RAM streaming (all blocks pre-loaded with LoRA fusion). @@ -112,6 +113,9 @@ def build( if not self.blocks_prefix: raise ValueError("blocks_prefix must be non-empty for streaming") + target_device = torch.device(target_device) + use_cuda_streaming = target_device.type == "cuda" + config = read_model_config(self.model_path, self.model_loader) meta_model: nn.Module = create_meta_model(self.model_class_configurator, config, self.module_ops) meta_model.eval() @@ -126,20 +130,37 @@ def build( if cpu_slots_count >= len(blocks): source, lora_sources = self._build_pinned_source( - meta_model, target_device, dtype, cpu_slots_count, block_key_map, non_block_keys + meta_model, target_device, dtype, cpu_slots_count, block_key_map, non_block_keys, use_cuda_streaming ) else: reader = DiskTensorReader(checkpoint_paths) source, lora_sources = self._build_disk_source( - meta_model, target_device, dtype, cpu_slots_count, reader, block_key_map, non_block_keys + meta_model, + target_device, + dtype, + cpu_slots_count, + reader, + block_key_map, + non_block_keys, + use_cuda_streaming, ) - copy_stream = torch.cuda.Stream(device=target_device) + copy_stream = torch.cuda.Stream(device=target_device) if use_cuda_streaming else None + if copy_stream is not None: + + def reuse_barrier(event: object) -> None: + copy_stream.wait_event(event) + + else: + + def reuse_barrier(_event: object) -> None: + return None + gpu_pool = WeightPool( source.block_layout, gpu_slots_count, target_device, - reuse_barrier=lambda event: copy_stream.wait_event(event), + reuse_barrier=reuse_barrier, ) provider = WeightsProvider( gpu_pool, @@ -165,6 +186,7 @@ def _build_pinned_source( cpu_slots_count: int, block_key_map: dict[int, list[tuple[str, str]]], non_block_keys: list[tuple[str, str]], + pin_memory: bool, ) -> tuple[WeightSource, list[LoraSource]]: """Pre-load all blocks into pinned CPU buffers with LoRA fusion.""" model_sd = load_state_dict( @@ -194,7 +216,7 @@ def _build_pinned_source( key = make_block_key(self.blocks_prefix, block_idx, param_name) block_tensors[key] = block_params[param_name] blocks_layout = derive_layout(block_tensors, dtype) - pinned_blocks = allocate_layout_views(blocks_layout, pin_memory=True) + pinned_blocks = allocate_layout_views(blocks_layout, pin_memory=pin_memory) should_sync = False for key, fused in fuse_lora_weights( @@ -207,7 +229,7 @@ def _build_pinned_source( else: model_sd.sd[key] = fused if should_sync: - torch.cuda.synchronize() + synchronize_device(target_device) # Fill remaining pinned keys from the source state dict. for key in blocks_layout: @@ -242,13 +264,14 @@ def _build_disk_source( reader: DiskTensorReader, block_key_map: dict[int, list[tuple[str, str]]], non_block_keys: list[tuple[str, str]], + pin_memory: bool, ) -> tuple[WeightSource, list[LoraSource]]: """Create a DiskWeightSource backed by a DiskBlockReader for lazy loading. Derives the shared pool layout from the meta model's block 0 - this relies on module_ops (e.g. fp8_cast) leaving the meta param dtype in sync with the post-sd_ops checkpoint dtype. """ - lora_sources = [LoraSource(lora.path, lora.sd_ops, lora.strength) for lora in self.loras] + lora_sources = [LoraSource(lora.path, lora.sd_ops, lora.strength, pin_memory=pin_memory) for lora in self.loras] self._load_non_block_weights( reader, @@ -269,7 +292,7 @@ def _build_disk_source( cpu_slots_count, torch.device("cpu"), reuse_barrier=lambda event: event.synchronize(), - pin_memory=True, + pin_memory=pin_memory, ) block_reader = DiskBlockReader( reader=reader, diff --git a/packages/ltx-core/src/ltx_core/block_streaming/disk.py b/packages/ltx-core/src/ltx_core/block_streaming/disk.py index 16e15d50..79b3ece4 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/disk.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/disk.py @@ -83,9 +83,9 @@ def cleanup(self) -> None: class LoraSource: - """Pinned-memory cache of matched LoRA A/B factors backed by a single buffer.""" + """CPU cache of matched LoRA A/B factors backed by a single buffer.""" - def __init__(self, path: str, sd_ops: SDOps | None, strength: float) -> None: + def __init__(self, path: str, sd_ops: SDOps | None, strength: float, pin_memory: bool = True) -> None: self.strength = strength self._pinned_ab: dict[str, tuple[torch.Tensor, torch.Tensor]] = {} @@ -117,7 +117,7 @@ def __init__(self, path: str, sd_ops: SDOps | None, strength: float) -> None: _SAFETENSORS_DTYPE_TO_TORCH[b_slice_view.get_dtype()], ) - all_views = allocate_layout_views(layout, pin_memory=True) + all_views = allocate_layout_views(layout, pin_memory=pin_memory) for prefix in matched_prefixes: a_view = all_views[f"{prefix}.A"] @@ -153,7 +153,7 @@ def get_ab( if pair is None: return None a, b = pair - if device is not None and device.type == "cuda": + if device is not None and device.type != "cpu": a = a.to(device=device, non_blocking=True) b = b.to(device=device, non_blocking=True) if dtype is not None: diff --git a/packages/ltx-core/src/ltx_core/block_streaming/pool.py b/packages/ltx-core/src/ltx_core/block_streaming/pool.py index 8dac072b..6851951a 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/pool.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/pool.py @@ -28,13 +28,13 @@ def __init__( buffer_layout: TensorLayout, capacity: int, device: torch.device, - reuse_barrier: Callable[[torch.cuda.Event], None], + reuse_barrier: Callable[[object], None], pin_memory: bool = False, ) -> None: self._buffer_layout = buffer_layout self._capacity = capacity self._free: deque[dict[str, torch.Tensor]] = deque() - self._events: dict[int, torch.cuda.Event] = {} + self._events: dict[int, object] = {} self._reuse_barrier = reuse_barrier memory_layout = { _make_key(slot, name): (shape, dtype) @@ -61,7 +61,7 @@ def acquire(self) -> dict[str, torch.Tensor]: self._reuse_barrier(event) return weights - def release(self, weights: dict[str, torch.Tensor], event: torch.cuda.Event | None = None) -> None: + def release(self, weights: dict[str, torch.Tensor], event: object | None = None) -> None: """Return a buffer to the free list. If *event* is given it is waited on the next :meth:`acquire` of this buffer, ensuring the prior operation has completed. diff --git a/packages/ltx-core/src/ltx_core/block_streaming/provider.py b/packages/ltx-core/src/ltx_core/block_streaming/provider.py index fb79e39e..37619ae4 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/provider.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/provider.py @@ -9,6 +9,7 @@ from ltx_core.block_streaming.disk import LoraSource from ltx_core.block_streaming.pool import WeightPool from ltx_core.block_streaming.source import WeightSource +from ltx_core.devices import synchronize_device from ltx_core.loader.fuse_loras import FuseRule, aggregate_lora_products, bf16_fuse_rule from ltx_core.loader.primitives import StateDict @@ -37,13 +38,14 @@ def _contiguous_byte_view(weights: dict[str, torch.Tensor]) -> torch.Tensor | No class WeightsProvider: - """Provides GPU-ready block weights via H2D copy from a pinned CPU weight source. + """Provides accelerator-ready block weights via copies from a CPU weight source. Args: - pool: Pre-allocated GPU weight buffer pool. - copy_stream: Dedicated CUDA stream for async H2D copies. - target_device: GPU device for compute. - source: Pinned CPU weight source. - lora_sources: LoRA adapters fused on H2D copy. + pool: Pre-allocated accelerator weight buffer pool. + copy_stream: Dedicated CUDA stream for async H2D copies. ``None`` uses + synchronous copies, which is used for MPS/CPU. + target_device: Accelerator device for compute. + source: CPU weight source. + lora_sources: LoRA adapters fused after copying. blocks_prefix: State-dict prefix for LoRA key matching. fuse_rule: Per-policy LoRA merge rule (must be streaming-compatible: no companion-key emission). Defaults to ``bf16_fuse_rule``. @@ -52,7 +54,7 @@ class WeightsProvider: def __init__( self, pool: WeightPool, - copy_stream: torch.cuda.Stream, + copy_stream: torch.cuda.Stream | None, target_device: torch.device, source: WeightSource, lora_sources: list[LoraSource] | None = None, @@ -62,7 +64,7 @@ def __init__( self._copy_stream = copy_stream self._pool = pool self._cache: OrderedDict[int, dict[str, torch.Tensor]] = OrderedDict() - self._events: dict[int, torch.cuda.Event] = {} + self._events: dict[int, object] = {} self._target_device = target_device self._source = source self._lora_sources = lora_sources or [] @@ -70,7 +72,7 @@ def __init__( self._fuse_rule = fuse_rule def get(self, idx: int) -> dict[str, torch.Tensor]: - """Return GPU weights for block *idx*. Does H2D copy on miss.""" + """Return accelerator weights for block *idx*. Copies from CPU on miss.""" if idx in self._cache: return self._cache[idx] @@ -82,30 +84,30 @@ def get(self, idx: int) -> dict[str, torch.Tensor]: gpu_weights = self._pool.acquire() cpu_weights = self._source.get(idx) - h2d_event = self._copy_to_gpu(idx, gpu_weights, cpu_weights) + h2d_event = self._copy_to_device(idx, gpu_weights, cpu_weights) self._source.release(idx, event=h2d_event) self._cache[idx] = gpu_weights return gpu_weights - def _copy_to_gpu( + def _copy_to_device( self, idx: int, gpu_weights: dict[str, torch.Tensor], cpu_weights: dict[str, torch.Tensor], - ) -> torch.cuda.Event: - """Enqueue H2D copy + LoRA fusion on the copy stream and wait on compute. + ) -> object | None: + """Copy weights to the target device and fuse LoRAs. The wait is intentionally inside this method so callers -- and instrumentation regions wrapping it -- observe the full transfer time. """ + if self._copy_stream is None: + self._copy_weights(gpu_weights, cpu_weights, non_blocking=False) + if self._lora_sources: + self._fuse_block_loras(idx, gpu_weights) + return None + with torch.cuda.stream(self._copy_stream): - gpu_view = _contiguous_byte_view(gpu_weights) - cpu_view = _contiguous_byte_view(cpu_weights) - if gpu_view is not None and cpu_view is not None and gpu_view.numel() == cpu_view.numel(): - gpu_view.copy_(cpu_view, non_blocking=True) - else: - for name, gpu_tensor in gpu_weights.items(): - gpu_tensor.copy_(cpu_weights[name], non_blocking=True) + self._copy_weights(gpu_weights, cpu_weights, non_blocking=True) if self._lora_sources: self._fuse_block_loras(idx, gpu_weights) h2d_event = torch.cuda.Event() @@ -114,14 +116,33 @@ def _copy_to_gpu( torch.cuda.current_stream(self._target_device).wait_event(h2d_event) return h2d_event - def release(self, idx: int, event: torch.cuda.Event) -> None: + @staticmethod + def _copy_weights( + gpu_weights: dict[str, torch.Tensor], + cpu_weights: dict[str, torch.Tensor], + *, + non_blocking: bool, + ) -> None: + gpu_view = _contiguous_byte_view(gpu_weights) + cpu_view = _contiguous_byte_view(cpu_weights) + if gpu_view is not None and cpu_view is not None and gpu_view.numel() == cpu_view.numel(): + gpu_view.copy_(cpu_view, non_blocking=non_blocking) + else: + for name, gpu_tensor in gpu_weights.items(): + gpu_tensor.copy_(cpu_weights[name], non_blocking=non_blocking) + + def release(self, idx: int, event: object | None = None) -> None: """Attach a compute-done event -- waited before this buffer is recycled.""" - self._events[idx] = event + if event is not None: + self._events[idx] = event def cleanup(self) -> None: """Synchronize streams and release all resources.""" - self._copy_stream.synchronize() - torch.cuda.current_stream(self._target_device).synchronize() + if self._copy_stream is not None: + self._copy_stream.synchronize() + torch.cuda.current_stream(self._target_device).synchronize() + else: + synchronize_device(self._target_device) self._cache.clear() self._events.clear() self._source.cleanup() diff --git a/packages/ltx-core/src/ltx_core/block_streaming/source.py b/packages/ltx-core/src/ltx_core/block_streaming/source.py index 62d0d424..19c87161 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/source.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/source.py @@ -26,7 +26,7 @@ def get(self, idx: int) -> dict[str, torch.Tensor]: """Return CPU weights for block *idx*.""" ... - def release(self, idx: int, event: torch.cuda.Event) -> None: + def release(self, idx: int, event: object | None = None) -> None: """Signal that an async operation using these weights is guarded by *event*.""" ... @@ -41,7 +41,7 @@ class DiskWeightSource(WeightSource): def __init__(self, pool: WeightPool, reader: DiskBlockReader) -> None: self._pool = pool self._cache: OrderedDict[int, dict[str, torch.Tensor]] = OrderedDict() - self._events: dict[int, torch.cuda.Event] = {} + self._events: dict[int, object] = {} self._reader = reader @property @@ -62,9 +62,10 @@ def get(self, idx: int) -> dict[str, torch.Tensor]: self._cache[idx] = weights return weights - def release(self, idx: int, event: torch.cuda.Event) -> None: + def release(self, idx: int, event: object | None = None) -> None: """Attach an H2D event -- waited before this buffer is recycled.""" - self._events[idx] = event + if event is not None: + self._events[idx] = event def cleanup(self) -> None: """Clear cache and close the disk reader.""" @@ -92,7 +93,7 @@ def block_layout(self) -> TensorLayout: def get(self, idx: int) -> dict[str, torch.Tensor]: return self._weights[idx] - def release(self, idx: int, event: torch.cuda.Event) -> None: + def release(self, idx: int, event: object | None = None) -> None: pass def cleanup(self) -> None: diff --git a/packages/ltx-core/src/ltx_core/block_streaming/wrapper.py b/packages/ltx-core/src/ltx_core/block_streaming/wrapper.py index c7fed140..c9e8b0f6 100644 --- a/packages/ltx-core/src/ltx_core/block_streaming/wrapper.py +++ b/packages/ltx-core/src/ltx_core/block_streaming/wrapper.py @@ -10,6 +10,7 @@ from ltx_core.block_streaming.provider import WeightsProvider from ltx_core.block_streaming.utils import assign_tensor_to_module +from ltx_core.devices import synchronize_device class BlockStreamingWrapper(nn.Module): @@ -55,10 +56,14 @@ def _pre_hook(self, block_idx: int) -> None: assign_tensor_to_module(block, name, gpu_weights[name]) def _post_hook(self, block_idx: int) -> None: - """Record a compute-done event and release the block weights.""" - compute_done = torch.cuda.Event() - compute_done.record(torch.cuda.current_stream(self._target_device)) - self._provider.release(block_idx, event=compute_done) + """Record a compute-done event when the backend exposes CUDA streams.""" + if self._target_device.type == "cuda": + compute_done = torch.cuda.Event() + compute_done.record(torch.cuda.current_stream(self._target_device)) + self._provider.release(block_idx, event=compute_done) + else: + synchronize_device(self._target_device) + self._provider.release(block_idx) def _register_hooks(self) -> None: for idx, block in enumerate(self._blocks): diff --git a/packages/ltx-core/src/ltx_core/devices.py b/packages/ltx-core/src/ltx_core/devices.py new file mode 100644 index 00000000..f29c2cdb --- /dev/null +++ b/packages/ltx-core/src/ltx_core/devices.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import gc +import logging +from collections.abc import Iterator +from contextlib import contextmanager + +import torch + +DeviceSpec = str | int | torch.device | None + +ACCELERATOR_DEVICE_TYPES = frozenset({"cuda", "mps"}) + + +def is_mps_available() -> bool: + """Return whether PyTorch can use the Apple Metal/MPS backend.""" + mps_backend = getattr(torch.backends, "mps", None) + return bool(mps_backend is not None and mps_backend.is_available()) + + +def get_preferred_device(local_rank: int | None = None) -> torch.device: + """Prefer CUDA, then MPS, then CPU. + + ``local_rank`` is only meaningful for CUDA multi-process launches. MPS exposes + a single logical device in PyTorch, so rank-based indexing is not used there. + """ + if torch.cuda.is_available(): + index = torch.cuda.current_device() if local_rank is None else local_rank + return torch.device("cuda", index) + if is_mps_available(): + return torch.device("mps") + return torch.device("cpu") + + +def resolve_device(device: DeviceSpec = None, *, local_rank: int | None = None) -> torch.device: + """Resolve ``None``/``auto`` to the best available accelerator.""" + if device is None: + return get_preferred_device(local_rank=local_rank) + if isinstance(device, int): + return torch.device("cuda", device) + if isinstance(device, str): + if device.lower() in {"auto", "accelerator", "gpu"}: + return get_preferred_device(local_rank=local_rank) + return torch.device(device) + return device + + +def is_accelerator_device(device: DeviceSpec) -> bool: + return resolve_device(device).type in ACCELERATOR_DEVICE_TYPES + + +def synchronize_device(device: DeviceSpec = None) -> None: + """Synchronize CUDA or MPS work if the selected backend supports it.""" + if device is None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + if is_mps_available() and hasattr(torch, "mps"): + torch.mps.synchronize() + return + + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + torch.cuda.synchronize(resolved) + elif resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + torch.mps.synchronize() + + +def empty_device_cache(device: DeviceSpec = None) -> None: + """Release cached allocator memory for CUDA or MPS.""" + if device is None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if is_mps_available() and hasattr(torch, "mps"): + torch.mps.empty_cache() + return + + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + torch.cuda.empty_cache() + elif resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + torch.mps.empty_cache() + + +def cleanup_accelerator_memory(device: DeviceSpec = None) -> None: + """Run Python GC and release CUDA/MPS allocator caches.""" + gc.collect() + empty_device_cache(device) + synchronize_device(device) + try: + if hasattr(torch._C, "_host_emptyCache"): + torch._C._host_emptyCache() + except Exception: + logging.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + + +def device_memory_allocated(device: DeviceSpec) -> int: + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + return torch.cuda.memory_allocated(resolved) + if resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + return torch.mps.current_allocated_memory() + return 0 + + +def device_memory_reserved(device: DeviceSpec) -> int: + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + return torch.cuda.memory_reserved(resolved) + if resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + if hasattr(torch.mps, "driver_allocated_memory"): + return torch.mps.driver_allocated_memory() + return torch.mps.current_allocated_memory() + return 0 + + +def device_memory_allocated_gb(device: DeviceSpec) -> float: + return device_memory_allocated(device) / 1024**3 + + +def get_accelerator_rng_state(device: DeviceSpec = None) -> torch.Tensor | None: + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + return torch.cuda.get_rng_state(resolved) + if resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + return torch.mps.get_rng_state() + return None + + +def set_accelerator_rng_state(state: torch.Tensor | None, device: DeviceSpec = None) -> None: + if state is None: + return + resolved = resolve_device(device) + if resolved.type == "cuda" and torch.cuda.is_available(): + torch.cuda.set_rng_state(state, resolved) + elif resolved.type == "mps" and is_mps_available() and hasattr(torch, "mps"): + torch.mps.set_rng_state(state) + + +@contextmanager +def fork_device_rng(device: DeviceSpec = None) -> Iterator[None]: + """Temporarily fork CPU plus selected CUDA/MPS RNG state.""" + resolved = resolve_device(device) + cpu_state = torch.random.get_rng_state() + accelerator_state = get_accelerator_rng_state(resolved) + try: + yield + finally: + torch.random.set_rng_state(cpu_state) + set_accelerator_rng_state(accelerator_state, resolved) diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py index e7a122b7..ed5cf18a 100644 --- a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -4,6 +4,7 @@ import torch +from ltx_core.devices import get_preferred_device from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict @@ -72,9 +73,7 @@ def _bf16_fuse( def _get_device() -> torch.device: - if torch.cuda.is_available(): - return torch.device("cuda", torch.cuda.current_device()) - return torch.device("cpu") + return get_preferred_device() def aggregate_lora_products( diff --git a/packages/ltx-core/src/ltx_core/loader/primitives.py b/packages/ltx-core/src/ltx_core/loader/primitives.py index fb490fd9..02d97e9c 100644 --- a/packages/ltx-core/src/ltx_core/loader/primitives.py +++ b/packages/ltx-core/src/ltx_core/loader/primitives.py @@ -5,6 +5,7 @@ import torch +from ltx_core.devices import DeviceSpec from ltx_core.loader.module_ops import ModuleOps from ltx_core.loader.sd_ops import SDOps from ltx_core.model.model_protocol import ModelType @@ -60,9 +61,7 @@ def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch class BuilderProtocol(Protocol[ModelType]): """Protocol for model builders that produce a model via ``build()``.""" - def build( - self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object - ) -> ModelType: ... + def build(self, device: DeviceSpec = None, dtype: torch.dtype | None = None, **kwargs: object) -> ModelType: ... class ModelBuilderProtocol(BuilderProtocol[ModelType], Protocol[ModelType]): @@ -107,7 +106,7 @@ def with_registry(self, registry: "Registry") -> "ModelBuilderProtocol[ModelType """Return a copy of this builder using the given weight registry for allocation.""" ... - def with_lora_load_device(self, device: torch.device) -> "ModelBuilderProtocol[ModelType]": + def with_lora_load_device(self, device: DeviceSpec) -> "ModelBuilderProtocol[ModelType]": """Return a copy of this builder that loads LoRA weights onto the given device.""" ... @@ -115,9 +114,7 @@ def with_fuse_rule(self, fuse_rule: "FuseRule") -> "ModelBuilderProtocol[ModelTy """Return a copy of this builder with the given LoRA fuse rule (e.g. from a quantization policy).""" ... - def build( - self, device: torch.device | None = None, dtype: torch.dtype | None = None, **kwargs: object - ) -> ModelType: + def build(self, device: DeviceSpec = None, dtype: torch.dtype | None = None, **kwargs: object) -> ModelType: """ Build the model Args: diff --git a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py index faf41dc1..7d91585b 100644 --- a/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py +++ b/packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py @@ -5,6 +5,7 @@ import torch from torch import nn +from ltx_core.devices import DeviceSpec, resolve_device from ltx_core.loader.fuse_loras import FuseRule, apply_loras, bf16_fuse_rule from ltx_core.loader.helpers import create_meta_model, load_state_dict, read_model_config from ltx_core.loader.module_ops import ModuleOps @@ -81,7 +82,7 @@ def _load_model_weights( @dataclass(frozen=True) class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol): """ - Builder for PyTorch models residing on a single GPU. + Builder for PyTorch models residing on a single accelerator. Attributes: model_class_configurator: Class responsible for constructing the model from a config dict. model_path: Path (or tuple of shard paths) to the model's `.safetensors` checkpoint(s). @@ -94,7 +95,7 @@ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], lora_load_device: Device used when loading LoRA weight tensors from disk. Defaults to ``torch.device("cpu")``, which keeps LoRA weights in CPU memory and transfers them to the target GPU sequentially during fusion, reducing peak GPU memory usage compared to - loading all LoRA weights directly onto the GPU at once. + loading all LoRA weights directly onto the accelerator at once. fuse_rule: Per-policy LoRA merge rule. Defaults to ``bf16_fuse_rule``; """ @@ -123,8 +124,8 @@ def with_loras(self, loras: tuple[LoraPathStrengthAndSDOps, ...]) -> "SingleGPUM def with_registry(self, registry: Registry) -> "SingleGPUModelBuilder": return replace(self, registry=registry) - def with_lora_load_device(self, device: torch.device) -> "SingleGPUModelBuilder": - return replace(self, lora_load_device=device) + def with_lora_load_device(self, device: DeviceSpec) -> "SingleGPUModelBuilder": + return replace(self, lora_load_device=resolve_device(device)) def with_fuse_rule(self, fuse_rule: FuseRule) -> "SingleGPUModelBuilder": return replace(self, fuse_rule=fuse_rule) @@ -136,9 +137,10 @@ def meta_model(self, config: dict, module_ops: tuple[ModuleOps, ...]) -> ModelTy return create_meta_model(self.model_class_configurator, config, module_ops) def load_sd( - self, paths: list[str], registry: Registry, device: torch.device | None, sd_ops: SDOps | None = None + self, paths: list[str], registry: Registry, device: DeviceSpec = None, sd_ops: SDOps | None = None ) -> StateDict: - return load_state_dict(paths, self.model_loader, registry, device, sd_ops) + resolved_device = resolve_device(device) if device is not None else None + return load_state_dict(paths, self.model_loader, registry, resolved_device, sd_ops) def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelType: uninitialized = _check_uninitialized(meta_model) @@ -149,11 +151,11 @@ def _return_model(self, meta_model: ModelType, device: torch.device) -> ModelTyp def build( self, - device: torch.device | None = None, + device: DeviceSpec = None, dtype: torch.dtype | None = None, **kwargs: object, # noqa: ARG002 ) -> ModelType: - device = torch.device("cuda") if device is None else device + device = resolve_device(device) config = self.model_config() meta_model = self.meta_model(config, self.module_ops) diff --git a/packages/ltx-core/src/ltx_core/model/transformer/attention.py b/packages/ltx-core/src/ltx_core/model/transformer/attention.py index ec49640c..427611bd 100644 --- a/packages/ltx-core/src/ltx_core/model/transformer/attention.py +++ b/packages/ltx-core/src/ltx_core/model/transformer/attention.py @@ -117,6 +117,8 @@ def __call__( ) -> torch.Tensor: if memory_efficient_attention is None: raise RuntimeError("XFormersAttention was selected but `xformers` is not installed.") + if q.device.type != "cuda": + raise RuntimeError("XFormersAttention requires CUDA. Use PyTorch SDPA on CPU or MPS.") b, _, dim_head = q.shape dim_head //= heads @@ -163,6 +165,8 @@ def __call__( ) -> torch.Tensor: if flash_attn_interface is None: raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.") + if q.device.type != "cuda": + raise RuntimeError("FlashAttention3 requires CUDA. Use PyTorch SDPA on CPU or MPS.") b, _, dim_head = q.shape dim_head //= heads @@ -186,6 +190,8 @@ def __call__( ) -> torch.Tensor: if flash_attn_4_func is None: raise RuntimeError("FlashAttention4 was selected but `flash-attn-4` is not installed.") + if q.device.type != "cuda": + raise RuntimeError("FlashAttention4 requires CUDA. Use PyTorch SDPA on CPU or MPS.") b, _, dim_head = q.shape dim_head //= heads @@ -282,7 +288,7 @@ def _select_masked_attention() -> MaskedAttentionCallable: """Pick a mask-aware attention. Prefers xFormers when installed; else SDPA with the full priority list (the dispatcher rejects FLASH automatically when a mask is present and walks past it).""" - if memory_efficient_attention is not None: + if torch.cuda.is_available() and memory_efficient_attention is not None: return XFormersAttention() return _sdpa_full_priority() @@ -335,7 +341,7 @@ class AttentionFunction(Enum): # :func:`automatic_attention`. Default for :class:`AttentionOps`. AUTOMATIC = "automatic" - def to_callable(self) -> AttentionCallable: # noqa: PLR0911 + def to_callable(self) -> AttentionCallable: # noqa: PLR0911, PLR0912 """Resolve to a concrete callable. Use this at module init time so that torch.compile can trace through the attention call without graph breaks. Every non-AUTOMATIC variant raises :class:`RuntimeError` when the backend @@ -353,18 +359,28 @@ def to_callable(self) -> AttentionCallable: # noqa: PLR0911 case AttentionFunction.XFORMERS: if memory_efficient_attention is None: raise RuntimeError("AttentionFunction.XFORMERS selected but `xformers` is not installed.") + if not torch.cuda.is_available(): + raise RuntimeError("AttentionFunction.XFORMERS requires CUDA. Use PyTorch SDPA on CPU or MPS.") return XFormersAttention() case AttentionFunction.FLASH_ATTENTION_3: if flash_attn_interface is None: raise RuntimeError( "AttentionFunction.FLASH_ATTENTION_3 selected but `flash-attn-3` is not installed." ) + if not torch.cuda.is_available(): + raise RuntimeError( + "AttentionFunction.FLASH_ATTENTION_3 requires CUDA. Use PyTorch SDPA on CPU or MPS." + ) return FlashAttention3() case AttentionFunction.FLASH_ATTENTION_4: if flash_attn_4_func is None: raise RuntimeError( "AttentionFunction.FLASH_ATTENTION_4 selected but `flash-attn-4` is not installed." ) + if not torch.cuda.is_available(): + raise RuntimeError( + "AttentionFunction.FLASH_ATTENTION_4 requires CUDA. Use PyTorch SDPA on CPU or MPS." + ) return FlashAttention4() case AttentionFunction.SDPA_MATH: return PytorchAttention(priority=[SDPBackend.MATH]) @@ -415,6 +431,10 @@ def to_callable(self) -> MaskedAttentionCallable: case MaskedAttentionFunction.XFORMERS: if memory_efficient_attention is None: raise RuntimeError("MaskedAttentionFunction.XFORMERS selected but `xformers` is not installed.") + if not torch.cuda.is_available(): + raise RuntimeError( + "MaskedAttentionFunction.XFORMERS requires CUDA. Use PyTorch SDPA on CPU or MPS." + ) return XFormersAttention() case MaskedAttentionFunction.SDPA_MATH: return PytorchAttention(priority=[SDPBackend.MATH]) diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py index cadde9e0..d5d7316c 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -4,6 +4,7 @@ import torch from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor +from ltx_core.devices import fork_device_rng from ltx_core.loader.module_ops import ModuleOps from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer from ltx_core.utils import find_matching_file @@ -65,7 +66,7 @@ def _enhance( pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0 model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id) - with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]): + with torch.inference_mode(), fork_device_rng(self.model.device): torch.manual_seed(seed) outputs = self.model.generate( **model_inputs, diff --git a/packages/ltx-pipelines/CLAUDE.md b/packages/ltx-pipelines/CLAUDE.md index 8b5f62ee..d34e1a66 100644 --- a/packages/ltx-pipelines/CLAUDE.md +++ b/packages/ltx-pipelines/CLAUDE.md @@ -57,7 +57,7 @@ Inference pipelines for LTX-2 audio-video generation. Depends on `ltx-core` for ### Memory management - **Model lifecycle**: All blocks build their model on call and free it on exit. `gpu_model()` moves params to `"meta"` device on exit, immediately releasing storage. No model persists between calls. -- **Block streaming**: When offloading is enabled, `DiffusionStage` wraps the transformer in `BlockStreamingWrapper`. Blocks live on pinned CPU memory; only 2 blocks are buffered on GPU at a time (one for compute, one for async H2D copy on a separate CUDA stream). +- **Block streaming**: When offloading is enabled, `DiffusionStage` wraps the transformer in `BlockStreamingWrapper`. Blocks live on CPU memory; only 2 blocks are buffered on the accelerator at a time. CUDA uses pinned memory plus an async H2D copy stream; MPS uses synchronous copies. - **Batch splitting**: `BatchSplitAdapter` wraps the transformer and splits inputs exceeding `max_batch_size` into sequential chunks. If guidance needs B=4 but `max_batch_size=1`, it runs 4 sequential B=1 passes. Higher `max_batch_size` reduces layer-streaming PCIe transfers at the cost of peak memory. ## Denoisers (`utils/denoisers.py`) diff --git a/packages/ltx-pipelines/README.md b/packages/ltx-pipelines/README.md index 79e1e705..36248121 100644 --- a/packages/ltx-pipelines/README.md +++ b/packages/ltx-pipelines/README.md @@ -343,19 +343,20 @@ audio_guider_params = MultiModalGuiderParams( **FP8 Quantization (Lower Memory Footprint):** -For smaller GPU memory footprint, use the `--quantization` flag and set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. +For smaller accelerator memory footprint, use the `--quantization` flag. On CUDA, also set +`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. Two quantization policies are available: | Policy | CLI Flag | Description | | ------ | -------- | ----------- | | **FP8 Cast** | `--quantization fp8-cast` | Downcasts transformer linear weights to FP8 during loading; upcasts on the fly during inference. No extra dependencies. | -| **FP8 Scaled MM** | `--quantization fp8-scaled-mm` | Uses FP8 scaled matrix multiplication via TensorRT-LLM (`tensorrt_llm` must be installed). Best performance on Hopper GPUs. | +| **FP8 Scaled MM** | `--quantization fp8-scaled-mm` | Uses FP8 scaled matrix multiplication via TensorRT-LLM (`tensorrt_llm` must be installed). CUDA-only; best performance on Hopper GPUs. | **CLI:** ```bash -# FP8 Cast (works on any GPU with FP8 support) +# FP8 Cast (works on CUDA GPUs with FP8 support) PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python -m ltx_pipelines.ti2vid_two_stages \ --quantization fp8-cast --checkpoint-path=... @@ -384,7 +385,7 @@ pipeline = TI2VidTwoStagesPipeline( pipeline(...) ``` -You still need to use `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` when launching: +On CUDA, use `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` when launching: ```bash PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python my_denoising_pipeline.py @@ -392,7 +393,7 @@ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python my_denoising_pipeline.py **Memory Cleanup Between Stages:** -By default, pipelines clean GPU memory (especially transformer weights) between stages. If you have enough memory, you can skip this cleanup to reduce running time: +By default, pipelines clean CUDA/MPS memory (especially transformer weights) between stages. If you have enough memory, you can skip this cleanup to reduce running time: ```python # In pipeline implementations, memory cleanup happens automatically diff --git a/packages/ltx-pipelines/src/ltx_pipelines/hdr_ic_lora.py b/packages/ltx-pipelines/src/ltx_pipelines/hdr_ic_lora.py index b961967a..29d1532b 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/hdr_ic_lora.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/hdr_ic_lora.py @@ -57,7 +57,7 @@ ) from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES from ltx_pipelines.utils.denoisers import SimpleDenoiser -from ltx_pipelines.utils.helpers import get_device, modality_from_latent_state +from ltx_pipelines.utils.helpers import cleanup_memory, get_device, modality_from_latent_state from ltx_pipelines.utils.media_io import ResizeMode, align_resolution, load_video_conditioning_hdr from ltx_pipelines.utils.quantization_factory import QuantizationKind from ltx_pipelines.utils.types import ModalitySpec, OffloadMode @@ -721,7 +721,7 @@ def _process_single_video( # noqa: PLR0913 del hdr_video gc.collect() - torch.cuda.empty_cache() + cleanup_memory() if not skip_mp4: # Wait for EXR saves to finish before encoding. @@ -818,6 +818,8 @@ def _build_arg_parser() -> "argparse.ArgumentParser": # noqa: F821 "and keeps every other frame for smoother output. ~2x slower.", ) return parser + + @torch.inference_mode() def main() -> None: """Batch HDR IC-LoRA inference: per-frame EXR + tonemapped ProRes .mov.""" diff --git a/packages/ltx-pipelines/src/ltx_pipelines/retake.py b/packages/ltx-pipelines/src/ltx_pipelines/retake.py index d6f11b48..dc02b1a6 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/retake.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/retake.py @@ -55,7 +55,7 @@ class RetakePipeline: loras : list[LoraPathStrengthAndSDOps] Optional LoRA configs applied to the transformer. device : torch.device - Target device (default: CUDA if available). + Target device (default: CUDA if available, then MPS, then CPU). quantization : QuantizationPolicy | None Optional quantization policy for the transformer. distilled : bool diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py index da8c0bf0..177c891e 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py @@ -4,6 +4,7 @@ import torch +from ltx_core.devices import synchronize_device from ltx_pipelines.utils.helpers import cleanup_memory _M = TypeVar("_M", bound=torch.nn.Module) @@ -14,7 +15,7 @@ def gpu_model(model: _M) -> Iterator[_M]: """Context manager that yields a model and releases its memory on exit. Moves all parameters and buffers to ``meta`` device on exit, which immediately releases the underlying storage on **both** GPU and CPU, - then runs ``cleanup_memory()`` to reclaim fragmented CUDA memory. + then runs ``cleanup_memory()`` to reclaim fragmented accelerator memory. Usage:: with gpu_model(build_encoder()) as encoder: ... # use encoder — typed as the concrete class @@ -23,8 +24,8 @@ def gpu_model(model: _M) -> Iterator[_M]: try: yield model finally: - torch.cuda.synchronize() + synchronize_device() # .to("meta") releases storage for all parameters/buffers regardless - # of their original device (CUDA or CPU). + # of their original device. model.to("meta") cleanup_memory() diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py index 99608710..7a6a70ec 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py @@ -1,4 +1,3 @@ -import gc import logging import torch @@ -9,6 +8,7 @@ VideoConditionByKeyframeIndex, VideoConditionByLatentIndex, ) +from ltx_core.devices import cleanup_accelerator_memory, get_preferred_device from ltx_core.model.audio_vae import encode_audio from ltx_core.model.transformer import Modality from ltx_core.model.video_vae import TilingConfig, VideoEncoder @@ -28,20 +28,11 @@ def get_device() -> torch.device: - if torch.cuda.is_available(): - return torch.device("cuda", torch.cuda.current_device()) - return torch.device("cpu") + return get_preferred_device() def cleanup_memory() -> None: - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() - try: - if hasattr(torch._C, "_host_emptyCache"): - torch._C._host_emptyCache() - except Exception: - logging.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + cleanup_accelerator_memory() def _conform_latent_length(latent: torch.Tensor, expected_frames_count: int) -> torch.Tensor: diff --git a/packages/ltx-trainer/README.md b/packages/ltx-trainer/README.md index 56b216d8..e04c1406 100644 --- a/packages/ltx-trainer/README.md +++ b/packages/ltx-trainer/README.md @@ -26,8 +26,9 @@ All detailed guides and technical documentation are in the [docs](./docs/) direc - **LTX-2 Model Checkpoint** - Local `.safetensors` file - **Gemma Text Encoder** - Local Gemma model directory (required for LTX-2) -- **Linux with CUDA** - CUDA 13+ recommended for optimal performance -- **Nvidia GPU with 80GB+ VRAM** - Recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090), +- **CUDA or Apple Silicon MPS** - CUDA 13+ is recommended for optimal training throughput; Apple Silicon Macs can use + the `mps` backend for local preprocessing, validation, and smaller single-process runs +- **NVIDIA GPU with 80GB+ VRAM** - Recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090), use the [low VRAM config](configs/ltx2_av_lora_low_vram.yaml) which enables INT8 quantization and other memory optimizations diff --git a/packages/ltx-trainer/docs/quick-start.md b/packages/ltx-trainer/docs/quick-start.md index da899648..638b26cd 100644 --- a/packages/ltx-trainer/docs/quick-start.md +++ b/packages/ltx-trainer/docs/quick-start.md @@ -10,8 +10,9 @@ Before you begin, ensure you have: Download `ltx-2-19b-dev.safetensors` from: [HuggingFace Hub](https://huggingface.co/Lightricks/LTX-2) 2. **Gemma Text Encoder** - A local directory containing the Gemma model (required for LTX-2). Download from: [HuggingFace Hub](https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized/) -3. **Linux with CUDA** - The trainer requires `triton` which is Linux-only -4. **GPU with sufficient VRAM** - 80GB recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090), +3. **CUDA or Apple Silicon MPS** - CUDA is recommended for full training performance; MPS is supported for local + preprocessing, validation, and smaller single-process runs +4. **GPU with sufficient memory** - 80GB recommended for the standard config. For GPUs with 32GB VRAM (e.g., RTX 5090), use the [low VRAM config](../configs/ltx2_av_lora_low_vram.yaml) which enables INT8 quantization and other memory optimizations diff --git a/packages/ltx-trainer/scripts/caption_videos.py b/packages/ltx-trainer/scripts/caption_videos.py index 92fefe94..d31de0e7 100755 --- a/packages/ltx-trainer/scripts/caption_videos.py +++ b/packages/ltx-trainer/scripts/caption_videos.py @@ -33,7 +33,6 @@ from enum import Enum from pathlib import Path -import torch import typer from rich.console import Console from rich.progress import ( @@ -47,6 +46,7 @@ ) from transformers.utils.logging import disable_progress_bar +from ltx_core.devices import resolve_device from ltx_trainer.captioning import CaptionerType, MediaCaptioningModel, create_captioner VIDEO_EXTENSIONS = ["mp4", "avi", "mov", "mkv", "webm"] @@ -375,7 +375,7 @@ def main( # noqa: PLR0913 None, "--device", "-d", - help="Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu'). Only for local models.", + help="Device to use for inference (e.g., 'auto', 'cuda', 'mps', 'cpu'). Only for local models.", ), use_8bit: bool = typer.Option( False, @@ -468,7 +468,7 @@ def main( # noqa: PLR0913 raise typer.Exit(code=1) # Determine device for local models - device_str = device or ("cuda" if torch.cuda.is_available() else "cpu") + device_str = str(resolve_device(device)) # Parse extensions ext_list = [ext.strip() for ext in extensions.split(",")] diff --git a/packages/ltx-trainer/scripts/decode_latents.py b/packages/ltx-trainer/scripts/decode_latents.py index 73a74f59..7589b24c 100755 --- a/packages/ltx-trainer/scripts/decode_latents.py +++ b/packages/ltx-trainer/scripts/decode_latents.py @@ -28,6 +28,7 @@ ) from transformers.utils.logging import disable_progress_bar +from ltx_core.devices import resolve_device from ltx_core.model.video_vae import SpatialTilingConfig, TemporalTilingConfig, TilingConfig from ltx_trainer import logger from ltx_trainer.model_loader import load_audio_vae_decoder, load_video_vae_decoder, load_vocoder @@ -51,7 +52,7 @@ class LatentsDecoder: def __init__( self, model_path: str, - device: str = "cuda", + device: str = "auto", vae_tiling: bool = False, with_audio: bool = False, ): @@ -62,7 +63,7 @@ def __init__( vae_tiling: Whether to enable VAE tiling for larger video resolutions with_audio: Whether to load audio VAE for audio decoding """ - self.device = torch.device(device) + self.device = resolve_device(device) self.model_path = model_path self.vae = None self.audio_vae = None @@ -306,8 +307,8 @@ def main( help="Path to LTX-2 checkpoint (.safetensors file)", ), device: str = typer.Option( - default="cuda", - help="Device to use for computation", + default="auto", + help="Device to use for computation ('auto', 'cuda', 'mps', or 'cpu')", ), vae_tiling: bool = typer.Option( default=False, diff --git a/packages/ltx-trainer/scripts/inference.py b/packages/ltx-trainer/scripts/inference.py index 13a8b768..6eb0bc76 100755 --- a/packages/ltx-trainer/scripts/inference.py +++ b/packages/ltx-trainer/scripts/inference.py @@ -36,6 +36,7 @@ from safetensors.torch import load_file from torchvision import transforms +from ltx_core.devices import resolve_device from ltx_trainer.model_loader import load_model from ltx_trainer.progress import StandaloneSamplingProgress from ltx_trainer.utils import open_image_as_srgb @@ -274,11 +275,12 @@ def main() -> None: # noqa: PLR0912, PLR0915 parser.add_argument( "--device", type=str, - default="cuda", - help="Device to run on (cuda/cpu)", + default="auto", + help="Device to run on (auto/cuda/mps/cpu)", ) args = parser.parse_args() + device = resolve_device(args.device) # Validate conditioning arguments if args.include_reference_in_output and args.reference_video is None: @@ -351,6 +353,7 @@ def main() -> None: # noqa: PLR0912, PLR0915 else: print("STG: disabled") print(f"Seed: {args.seed}") + print(f"Device: {device}") if args.lora_path: print(f"LoRA: {args.lora_path}") if condition_image is not None: @@ -400,7 +403,7 @@ def main() -> None: # noqa: PLR0912, PLR0915 ) video, audio = sampler.generate( config=gen_config, - device=args.device, + device=device, ) # Save video diff --git a/packages/ltx-trainer/scripts/process_captions.py b/packages/ltx-trainer/scripts/process_captions.py index 95232093..a912d3f1 100755 --- a/packages/ltx-trainer/scripts/process_captions.py +++ b/packages/ltx-trainer/scripts/process_captions.py @@ -232,7 +232,7 @@ def compute_captions_embeddings( # noqa: PLR0913 lora_trigger: str | None = None, remove_llm_prefixes: bool = False, batch_size: int = 8, - device: str = "cuda", + device: str = "auto", load_in_8bit: bool = False, overwrite: bool = False, ) -> None: @@ -417,8 +417,8 @@ def main( # noqa: PLR0913 help="Batch size for processing", ), device: str = typer.Option( - default="cuda", - help="Device to use for computation", + default="auto", + help="Device to use for computation ('auto', 'cuda', 'mps', or 'cpu')", ), lora_trigger: str | None = typer.Option( default=None, diff --git a/packages/ltx-trainer/scripts/process_dataset.py b/packages/ltx-trainer/scripts/process_dataset.py index fa827b3e..a063dbb7 100755 --- a/packages/ltx-trainer/scripts/process_dataset.py +++ b/packages/ltx-trainer/scripts/process_dataset.py @@ -216,8 +216,8 @@ def main( # noqa: PLR0913 help="Batch size for preprocessing", ), device: str = typer.Option( - default="cuda", - help="Device to use for computation", + default="auto", + help="Device to use for computation ('auto', 'cuda', 'mps', or 'cpu')", ), vae_tiling: bool = typer.Option( default=False, diff --git a/packages/ltx-trainer/scripts/process_videos.py b/packages/ltx-trainer/scripts/process_videos.py index c33a815d..0d3d301d 100755 --- a/packages/ltx-trainer/scripts/process_videos.py +++ b/packages/ltx-trainer/scripts/process_videos.py @@ -43,6 +43,7 @@ from torchvision.transforms.functional import crop, resize, to_tensor from transformers.utils.logging import disable_progress_bar +from ltx_core.devices import resolve_device from ltx_core.model.audio_vae import AudioProcessor from ltx_core.types import Audio from ltx_trainer import logger @@ -443,7 +444,7 @@ def compute_latents( # noqa: PLR0913, PLR0915 main_media_column: str | None = None, reshape_mode: str = "center", batch_size: int = 1, - device: str = "cuda", + device: str = "auto", vae_tiling: bool = False, with_audio: bool = False, audio_output_dir: str | None = None, @@ -475,7 +476,7 @@ def compute_latents( # noqa: PLR0913, PLR0915 raise ValueError("audio_output_dir must be provided when with_audio=True") console = Console() - torch_device = torch.device(device) + torch_device = resolve_device(device) dataset = MediaDataset( dataset_file=dataset_file, @@ -1015,8 +1016,8 @@ def main( # noqa: PLR0913 help="Batch size for processing", ), device: str = typer.Option( - default="cuda", - help="Device to use for computation", + default="auto", + help="Device to use for computation ('auto', 'cuda', 'mps', or 'cpu')", ), vae_tiling: bool = typer.Option( default=False, diff --git a/packages/ltx-trainer/src/ltx_trainer/captioning.py b/packages/ltx-trainer/src/ltx_trainer/captioning.py index 5e09c30e..b0a6ba93 100644 --- a/packages/ltx-trainer/src/ltx_trainer/captioning.py +++ b/packages/ltx-trainer/src/ltx_trainer/captioning.py @@ -17,6 +17,8 @@ import torch +from ltx_core.devices import resolve_device + # Instruction for audio-visual captioning (default) - includes speech transcription and sounds DEFAULT_CAPTION_INSTRUCTION = """\ Analyze this media and provide a detailed caption in the following EXACT format. Fill in ALL sections: @@ -133,11 +135,11 @@ def __init__( """ Initialize the Qwen2.5-Omni captioner. Args: - device: Device to use for inference (e.g., 'cuda', 'cuda:0', 'cpu') + device: Device to use for inference (e.g., 'auto', 'cuda', 'mps', 'cpu') use_8bit: Whether to use 8-bit quantization for reduced memory usage instruction: Custom instruction prompt. If None, uses the default instruction """ - self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) + self.device = resolve_device(device) self.instruction = instruction self._load_model(use_8bit=use_8bit) @@ -255,6 +257,9 @@ def _load_model(self, use_8bit: bool) -> None: Qwen2_5OmniThinkerForConditionalGeneration, ) + if use_8bit and self.device.type == "mps": + raise ValueError("8-bit Qwen-Omni captioning uses bitsandbytes and is not supported on MPS.") + quantization_config = BitsAndBytesConfig(load_in_8bit=True) if use_8bit else None # Use Thinker-only model for text generation (saves memory by not loading Talker) @@ -263,7 +268,7 @@ def _load_model(self, use_8bit: bool) -> None: dtype=torch.bfloat16, low_cpu_mem_usage=True, quantization_config=quantization_config, - device_map="auto", + device_map={"": self.device}, ) self.processor = Qwen2_5OmniProcessor.from_pretrained(self.MODEL_ID) diff --git a/packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py b/packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py index 83727907..a58321a0 100644 --- a/packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py +++ b/packages/ltx-trainer/src/ltx_trainer/gemma_8bit.py @@ -19,6 +19,7 @@ import torch +from ltx_core.devices import is_mps_available, resolve_device from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer @@ -35,7 +36,8 @@ def load_8bit_gemma( Args: gemma_model_path: Path to Gemma model directory dtype: Data type for non-quantized model weights - device: Device to place the quantized model on. When ``None`` (default), + device: Device to place the quantized model on. MPS is not supported + because this path depends on bitsandbytes. When ``None`` (default), the device is inferred from ``LOCAL_RANK`` if CUDA is available, so multi-process launches put each rank's encoder on its own GPU instead of all colliding on ``cuda:0``. @@ -59,9 +61,13 @@ def load_8bit_gemma( # in multi-process launches because every rank picks the same default device. device_map: str | dict[str, int | str | torch.device] if device is not None: + if resolve_device(device).type == "mps": + raise ValueError("8-bit Gemma loading uses bitsandbytes and is not supported on MPS.") device_map = {"": device} elif torch.cuda.is_available(): device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))} + elif is_mps_available(): + raise ValueError("8-bit Gemma loading uses bitsandbytes and is not supported on MPS.") else: device_map = "auto" diff --git a/packages/ltx-trainer/src/ltx_trainer/gpu_utils.py b/packages/ltx-trainer/src/ltx_trainer/gpu_utils.py index de385a1c..9dfd4ee3 100644 --- a/packages/ltx-trainer/src/ltx_trainer/gpu_utils.py +++ b/packages/ltx-trainer/src/ltx_trainer/gpu_utils.py @@ -7,23 +7,30 @@ import torch +from ltx_core.devices import ( + cleanup_accelerator_memory, + device_memory_allocated, + device_memory_allocated_gb, + device_memory_reserved, + get_preferred_device, +) from ltx_trainer import logger F = TypeVar("F", bound=Callable) def free_gpu_memory(log: bool = False) -> None: - """Free GPU memory by running garbage collection and emptying CUDA cache. + """Free accelerator memory by running garbage collection and emptying CUDA/MPS caches. Args: log: If True, log memory stats after clearing """ gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if log: - allocated = torch.cuda.memory_allocated() / 1024**3 - reserved = torch.cuda.memory_reserved() / 1024**3 - logger.debug(f"GPU memory freed. Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB") + cleanup_accelerator_memory() + if log: + device = get_preferred_device() + allocated = device_memory_allocated(device) / 1024**3 + reserved = device_memory_reserved(device) / 1024**3 + logger.debug(f"Accelerator memory freed. Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB") class free_gpu_memory_context: # noqa: N801 @@ -65,12 +72,15 @@ def wrapper(*args, **kwargs) -> object: def get_gpu_memory_gb(device: torch.device) -> float: - """Get current GPU memory usage in GB using nvidia-smi. + """Get current accelerator memory usage in GB. Args: device: torch.device to get memory usage for Returns: Current GPU memory usage in GB """ + if device.type != "cuda": + return device_memory_allocated_gb(device) + try: device_id = device.index if device.index is not None else 0 result = subprocess.check_output( @@ -85,6 +95,6 @@ def get_gpu_memory_gb(device: torch.device) -> float: ) return float(result.strip()) / 1024 # Convert MB to GB except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e: - logger.error(f"Failed to get GPU memory from nvidia-smi: {e}") + logger.debug(f"Failed to get GPU memory from nvidia-smi: {e}") # Fallback to torch - return torch.cuda.memory_allocated(device) / 1024**3 + return device_memory_allocated(device) / 1024**3 diff --git a/packages/ltx-trainer/src/ltx_trainer/model_loader.py b/packages/ltx-trainer/src/ltx_trainer/model_loader.py index f6aeba62..81e10e42 100644 --- a/packages/ltx-trainer/src/ltx_trainer/model_loader.py +++ b/packages/ltx-trainer/src/ltx_trainer/model_loader.py @@ -6,9 +6,9 @@ for training, using SingleGPUModelBuilder from ltx-core. Example usage: # Load individual components - vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="cuda") - vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="cuda") - text_encoder = load_text_encoder("/path/to/gemma", device="cuda") + vae_encoder = load_video_vae_encoder("/path/to/checkpoint.safetensors", device="auto") + vae_decoder = load_video_vae_decoder("/path/to/checkpoint.safetensors", device="auto") + text_encoder = load_text_encoder("/path/to/gemma", device="auto") # Load all components at once components = load_model("/path/to/checkpoint.safetensors", text_encoder_path="/path/to/gemma") """ @@ -21,6 +21,7 @@ import torch +from ltx_core.devices import resolve_device from ltx_trainer import logger # Type alias for device specification @@ -38,7 +39,7 @@ def _to_torch_device(device: Device) -> torch.device: """Convert device specification to torch.device.""" - return torch.device(device) if isinstance(device, str) else device + return resolve_device(device) # ============================================================================= @@ -306,7 +307,7 @@ def load_model( Args: checkpoint_path: Path to the safetensors checkpoint file text_encoder_path: Path to Gemma model directory (required if with_text_encoder=True) - device: Device to load models on ("cuda", "cpu", etc.) + device: Device to load models on ("auto", "cuda", "mps", "cpu", etc.) dtype: Data type for model weights with_video_vae_encoder: Whether to load the video VAE encoder (for preprocessing) with_video_vae_decoder: Whether to load the video VAE decoder (for inference/validation) diff --git a/packages/ltx-trainer/src/ltx_trainer/quantization.py b/packages/ltx-trainer/src/ltx_trainer/quantization.py index 31bc9bc6..76e646d2 100644 --- a/packages/ltx-trainer/src/ltx_trainer/quantization.py +++ b/packages/ltx-trainer/src/ltx_trainer/quantization.py @@ -5,6 +5,7 @@ import torch from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn +from ltx_core.devices import resolve_device from ltx_trainer import logger QuantizationOptions = Literal[ @@ -64,16 +65,16 @@ def quantize_model( model: The model to quantize. precision: The quantization precision (e.g. "int8-quanto", "fp8-quanto"). quantize_activations: Whether to quantize activations in addition to weights. - device: Device to use for quantization. If None, uses CUDA if available, else CPU. + device: Device to use for quantization. If None, uses CUDA, then MPS, then CPU. Returns: The quantized model. """ from optimum.quanto import freeze, quantize # noqa: PLC0415 - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - elif isinstance(device, str): - device = torch.device(device) + device = resolve_device(device) + + if device.type == "mps" and precision in ("fp8-quanto", "fp8uz-quanto"): + raise ValueError("FP8 quantization is not supported on MPS devices. Use int2, int4, or int8 instead.") weight_quant = _get_quanto_dtype(precision) @@ -185,8 +186,6 @@ def _get_quanto_dtype(precision: QuantizationOptions) -> torch.dtype: elif precision == "int8-quanto": return qint8 elif precision in ("fp8-quanto", "fp8uz-quanto"): - if torch.backends.mps.is_available(): - raise ValueError("FP8 quantization is not supported on MPS devices. Use int2, int4, or int8 instead.") if precision == "fp8-quanto": return qfloat8 elif precision == "fp8uz-quanto": diff --git a/packages/ltx-trainer/src/ltx_trainer/trainer.py b/packages/ltx-trainer/src/ltx_trainer/trainer.py index 92cfec42..bddb2db9 100644 --- a/packages/ltx-trainer/src/ltx_trainer/trainer.py +++ b/packages/ltx-trainer/src/ltx_trainer/trainer.py @@ -30,8 +30,14 @@ StepLR, ) from torch.utils.data import DataLoader -from torchvision.transforms import functional as F # noqa: N812 +from torchvision.transforms import functional as F +from ltx_core.devices import ( + get_accelerator_rng_state, + get_preferred_device, + is_accelerator_device, + set_accelerator_rng_state, +) from ltx_core.text_encoders.gemma import convert_to_additive_mask from ltx_trainer import logger from ltx_trainer.config import LtxTrainerConfig @@ -394,7 +400,7 @@ def _load_text_encoder_and_cache_embeddings(self) -> list[CachedPromptEmbeddings # Load text encoder (pure Gemma LLM) on GPU — LOCAL_RANK before Accelerator exists local_rank = int(os.environ.get("LOCAL_RANK", "0")) - init_device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + init_device = get_preferred_device(local_rank=local_rank) logger.debug("Loading text encoder...") text_encoder = load_text_encoder( @@ -667,8 +673,10 @@ def _restore_training_state(self, training_state: TrainingState) -> bool: else: if rng.torch_state is not None: torch.random.set_rng_state(rng.torch_state) - if rng.cuda_state is not None and torch.cuda.is_available(): - torch.cuda.set_rng_state(rng.cuda_state) + if rng.cuda_state is not None and self._accelerator.device.type == "cuda": + set_accelerator_rng_state(rng.cuda_state, self._accelerator.device) + if rng.mps_state is not None and self._accelerator.device.type == "mps": + set_accelerator_rng_state(rng.mps_state, self._accelerator.device) logger.debug("Restored RNG states") return True @@ -703,8 +711,8 @@ def _prepare_models_for_training(self) -> None: self._transformer = self._accelerator.prepare(self._transformer) # Log GPU memory usage after model preparation - vram_usage_gb = torch.cuda.memory_allocated() / 1024**3 - logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB") + vram_usage_gb = get_gpu_memory_gb(self._accelerator.device) + logger.debug(f"Accelerator memory usage after models preparation: {vram_usage_gb:.2f} GB") @staticmethod def _find_checkpoint(checkpoint_path: str | Path) -> Path | None: @@ -805,7 +813,7 @@ def _offloaded_optimizer_state(self) -> Iterator[None]: offloaded_bytes = 0 for state in self._optimizer.state.values(): for k, v in state.items(): - if isinstance(v, torch.Tensor) and v.is_cuda: + if isinstance(v, torch.Tensor) and is_accelerator_device(v.device): offloaded.append((state, k)) offloaded_bytes += v.nbytes if offloaded: @@ -1199,7 +1207,16 @@ def _save_training_state(self, save_dir: Path) -> None: ), rng_states=RngStates( torch_state=torch.random.get_rng_state(), - cuda_state=torch.cuda.get_rng_state() if torch.cuda.is_available() else None, + cuda_state=( + get_accelerator_rng_state(self._accelerator.device) + if self._accelerator.device.type == "cuda" + else None + ), + mps_state=( + get_accelerator_rng_state(self._accelerator.device) + if self._accelerator.device.type == "mps" + else None + ), ), lr_scheduler_state_dict=self._lr_scheduler.state_dict() if self._lr_scheduler is not None else None, optimizer_state_dict=optimizer_state, diff --git a/packages/ltx-trainer/src/ltx_trainer/training_state.py b/packages/ltx-trainer/src/ltx_trainer/training_state.py index 45dd2df9..b1ba0ba4 100644 --- a/packages/ltx-trainer/src/ltx_trainer/training_state.py +++ b/packages/ltx-trainer/src/ltx_trainer/training_state.py @@ -18,6 +18,7 @@ class RngStates(BaseModel): torch_state: torch.Tensor cuda_state: torch.Tensor | None = None + mps_state: torch.Tensor | None = None class TrainingState(BaseModel): diff --git a/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py index b756cf14..51fa2973 100644 --- a/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py +++ b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py @@ -19,6 +19,7 @@ get_pixel_coords, ) from ltx_core.components.schedulers import LTX2Scheduler +from ltx_core.devices import resolve_device from ltx_core.guidance.perturbations import ( BatchedPerturbationConfig, Perturbation, @@ -160,7 +161,7 @@ def __init__( def generate( self, config: GenerationConfig, - device: torch.device | str = "cuda", + device: torch.device | str = "auto", ) -> tuple[Tensor, Tensor | None]: """Generate a video (and optionally audio) sample. Args: @@ -171,7 +172,7 @@ def generate( - video: Video tensor [C, F, H, W] in [0, 1] (float32) - audio: Audio waveform tensor [C, samples] or None """ - device = torch.device(device) if isinstance(device, str) else device + device = resolve_device(device) self._validate_config(config) # Route to appropriate generation method