Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ uv sync --frozen
source .venv/bin/activate
```

### Hardware Backends

The code automatically selects the best available PyTorch backend in this order: CUDA, MPS, then CPU. On Apple Silicon
Macs with an MPS-enabled PyTorch install, the pipelines and utility scripts default to `torch.device("mps")`; pass
`--device cpu`, `--device cuda`, or `--device mps` when you need to override auto-selection. CUDA-only optimizations
such as TensorRT-LLM FP8 scaled matrix multiplication, FlashAttention, xFormers, and bitsandbytes 8-bit loading remain
NVIDIA-specific.

### Required Models

Download the following models from the [LTX-2.3 HuggingFace repository](https://huggingface.co/Lightricks/LTX-2.3):
Expand Down
18 changes: 9 additions & 9 deletions packages/ltx-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pip install -e packages/ltx-core

### Loader

The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters.
The `loader/` module provides `SingleGPUModelBuilder`, a frozen dataclass that loads a PyTorch model from `.safetensors` checkpoints and optionally fuses one or more LoRA adapters. When no device is supplied it automatically picks CUDA, then MPS, then CPU.

#### Basic usage

Expand All @@ -69,7 +69,7 @@ builder = SingleGPUModelBuilder(
model_class_configurator=MyModelConfigurator,
model_path="/path/to/model.safetensors",
)
model = builder.build(device=torch.device("cuda"))
model = builder.build()
```

#### Loading LoRA adapters
Expand All @@ -89,27 +89,27 @@ builder = (
.lora("/path/to/lora_a.safetensors", 0.8, lora_sd_ops)
.lora("/path/to/lora_b.safetensors", 0.5, lora_sd_ops)
)
model = builder.build(device=torch.device("cuda"))
model = builder.build()
```

#### Memory-efficient LoRA loading (`lora_load_device`)

By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). This means each LoRA adapter is kept in CPU memory and transferred to the GPU sequentially during weight fusion, which keeps peak GPU memory low even when fusing large adapters.
By default, LoRA weights are loaded onto the **CPU** (`lora_load_device=torch.device("cpu")`). This means each LoRA adapter is kept in CPU memory and transferred to the accelerator sequentially during weight fusion, which keeps peak accelerator memory low even when fusing large adapters.

If all adapters fit comfortably in GPU memory you can skip the CPU staging by setting `lora_load_device` to the target CUDA device:
If all adapters fit comfortably in accelerator memory you can skip the CPU staging by setting `lora_load_device` to the target CUDA or MPS device:

```python
import torch
from ltx_core.loader import SingleGPUModelBuilder

# Load LoRA weights directly onto the GPU (faster, but uses more GPU memory)
# Load LoRA weights directly onto the accelerator (faster, but uses more accelerator memory)
builder = SingleGPUModelBuilder(
model_class_configurator=MyModelConfigurator,
model_path="/path/to/model.safetensors",
lora_load_device=torch.device("cuda"),
lora_load_device=torch.device("mps"),
).lora("/path/to/lora.safetensors", 1.0, lora_sd_ops)

model = builder.build(device=torch.device("cuda"))
model = builder.build(device=torch.device("mps"))
```

### Quantization
Expand Down Expand Up @@ -144,7 +144,7 @@ builder = SingleGPUModelBuilder(
module_ops=policy.module_ops,
fuse_rule=policy.fuse_rule,
)
model = builder.build(device=torch.device("cuda"))
model = builder.build()
```

#### FP8 Cast
Expand Down
41 changes: 32 additions & 9 deletions packages/ltx-core/src/ltx_core/block_streaming/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ltx_core.block_streaming.source import DiskWeightSource, PinnedWeightSource, WeightSource
from ltx_core.block_streaming.utils import allocate_layout_views, derive_layout, make_block_key, resolve_attr
from ltx_core.block_streaming.wrapper import BlockStreamingWrapper
from ltx_core.devices import synchronize_device
from ltx_core.loader.fuse_loras import FuseRule, bf16_fuse_rule, fuse_lora_weights
from ltx_core.loader.helpers import create_meta_model, load_state_dict, read_model_config
from ltx_core.loader.module_ops import ModuleOps
Expand Down Expand Up @@ -102,7 +103,7 @@ def build(
) -> BlockStreamingWrapper:
"""Build and return a ready-to-use :class:`BlockStreamingWrapper`.
Args:
target_device: GPU device for compute.
target_device: Accelerator device for compute.
dtype: Weight dtype (e.g. ``torch.bfloat16``).
cpu_slots_count: Number of pinned CPU buffer slots.
``None`` = RAM streaming (all blocks pre-loaded with LoRA fusion).
Expand All @@ -112,6 +113,9 @@ def build(
if not self.blocks_prefix:
raise ValueError("blocks_prefix must be non-empty for streaming")

target_device = torch.device(target_device)
use_cuda_streaming = target_device.type == "cuda"

config = read_model_config(self.model_path, self.model_loader)
meta_model: nn.Module = create_meta_model(self.model_class_configurator, config, self.module_ops)
meta_model.eval()
Expand All @@ -126,20 +130,37 @@ def build(

if cpu_slots_count >= len(blocks):
source, lora_sources = self._build_pinned_source(
meta_model, target_device, dtype, cpu_slots_count, block_key_map, non_block_keys
meta_model, target_device, dtype, cpu_slots_count, block_key_map, non_block_keys, use_cuda_streaming
)
else:
reader = DiskTensorReader(checkpoint_paths)
source, lora_sources = self._build_disk_source(
meta_model, target_device, dtype, cpu_slots_count, reader, block_key_map, non_block_keys
meta_model,
target_device,
dtype,
cpu_slots_count,
reader,
block_key_map,
non_block_keys,
use_cuda_streaming,
)

copy_stream = torch.cuda.Stream(device=target_device)
copy_stream = torch.cuda.Stream(device=target_device) if use_cuda_streaming else None
if copy_stream is not None:

def reuse_barrier(event: object) -> None:
copy_stream.wait_event(event)

else:

def reuse_barrier(_event: object) -> None:
return None

gpu_pool = WeightPool(
source.block_layout,
gpu_slots_count,
target_device,
reuse_barrier=lambda event: copy_stream.wait_event(event),
reuse_barrier=reuse_barrier,
)
provider = WeightsProvider(
gpu_pool,
Expand All @@ -165,6 +186,7 @@ def _build_pinned_source(
cpu_slots_count: int,
block_key_map: dict[int, list[tuple[str, str]]],
non_block_keys: list[tuple[str, str]],
pin_memory: bool,
) -> tuple[WeightSource, list[LoraSource]]:
"""Pre-load all blocks into pinned CPU buffers with LoRA fusion."""
model_sd = load_state_dict(
Expand Down Expand Up @@ -194,7 +216,7 @@ def _build_pinned_source(
key = make_block_key(self.blocks_prefix, block_idx, param_name)
block_tensors[key] = block_params[param_name]
blocks_layout = derive_layout(block_tensors, dtype)
pinned_blocks = allocate_layout_views(blocks_layout, pin_memory=True)
pinned_blocks = allocate_layout_views(blocks_layout, pin_memory=pin_memory)

should_sync = False
for key, fused in fuse_lora_weights(
Expand All @@ -207,7 +229,7 @@ def _build_pinned_source(
else:
model_sd.sd[key] = fused
if should_sync:
torch.cuda.synchronize()
synchronize_device(target_device)

# Fill remaining pinned keys from the source state dict.
for key in blocks_layout:
Expand Down Expand Up @@ -242,13 +264,14 @@ def _build_disk_source(
reader: DiskTensorReader,
block_key_map: dict[int, list[tuple[str, str]]],
non_block_keys: list[tuple[str, str]],
pin_memory: bool,
) -> tuple[WeightSource, list[LoraSource]]:
"""Create a DiskWeightSource backed by a DiskBlockReader for lazy loading.
Derives the shared pool layout from the meta model's block 0 - this
relies on module_ops (e.g. fp8_cast) leaving the meta param dtype in
sync with the post-sd_ops checkpoint dtype.
"""
lora_sources = [LoraSource(lora.path, lora.sd_ops, lora.strength) for lora in self.loras]
lora_sources = [LoraSource(lora.path, lora.sd_ops, lora.strength, pin_memory=pin_memory) for lora in self.loras]

self._load_non_block_weights(
reader,
Expand All @@ -269,7 +292,7 @@ def _build_disk_source(
cpu_slots_count,
torch.device("cpu"),
reuse_barrier=lambda event: event.synchronize(),
pin_memory=True,
pin_memory=pin_memory,
)
block_reader = DiskBlockReader(
reader=reader,
Expand Down
8 changes: 4 additions & 4 deletions packages/ltx-core/src/ltx_core/block_streaming/disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def cleanup(self) -> None:


class LoraSource:
"""Pinned-memory cache of matched LoRA A/B factors backed by a single buffer."""
"""CPU cache of matched LoRA A/B factors backed by a single buffer."""

def __init__(self, path: str, sd_ops: SDOps | None, strength: float) -> None:
def __init__(self, path: str, sd_ops: SDOps | None, strength: float, pin_memory: bool = True) -> None:
self.strength = strength
self._pinned_ab: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}

Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(self, path: str, sd_ops: SDOps | None, strength: float) -> None:
_SAFETENSORS_DTYPE_TO_TORCH[b_slice_view.get_dtype()],
)

all_views = allocate_layout_views(layout, pin_memory=True)
all_views = allocate_layout_views(layout, pin_memory=pin_memory)

for prefix in matched_prefixes:
a_view = all_views[f"{prefix}.A"]
Expand Down Expand Up @@ -153,7 +153,7 @@ def get_ab(
if pair is None:
return None
a, b = pair
if device is not None and device.type == "cuda":
if device is not None and device.type != "cpu":
a = a.to(device=device, non_blocking=True)
b = b.to(device=device, non_blocking=True)
if dtype is not None:
Expand Down
6 changes: 3 additions & 3 deletions packages/ltx-core/src/ltx_core/block_streaming/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def __init__(
buffer_layout: TensorLayout,
capacity: int,
device: torch.device,
reuse_barrier: Callable[[torch.cuda.Event], None],
reuse_barrier: Callable[[object], None],
pin_memory: bool = False,
) -> None:
self._buffer_layout = buffer_layout
self._capacity = capacity
self._free: deque[dict[str, torch.Tensor]] = deque()
self._events: dict[int, torch.cuda.Event] = {}
self._events: dict[int, object] = {}
self._reuse_barrier = reuse_barrier
memory_layout = {
_make_key(slot, name): (shape, dtype)
Expand All @@ -61,7 +61,7 @@ def acquire(self) -> dict[str, torch.Tensor]:
self._reuse_barrier(event)
return weights

def release(self, weights: dict[str, torch.Tensor], event: torch.cuda.Event | None = None) -> None:
def release(self, weights: dict[str, torch.Tensor], event: object | None = None) -> None:
"""Return a buffer to the free list.
If *event* is given it is waited on the next :meth:`acquire`
of this buffer, ensuring the prior operation has completed.
Expand Down
69 changes: 45 additions & 24 deletions packages/ltx-core/src/ltx_core/block_streaming/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ltx_core.block_streaming.disk import LoraSource
from ltx_core.block_streaming.pool import WeightPool
from ltx_core.block_streaming.source import WeightSource
from ltx_core.devices import synchronize_device
from ltx_core.loader.fuse_loras import FuseRule, aggregate_lora_products, bf16_fuse_rule
from ltx_core.loader.primitives import StateDict

Expand Down Expand Up @@ -37,13 +38,14 @@ def _contiguous_byte_view(weights: dict[str, torch.Tensor]) -> torch.Tensor | No


class WeightsProvider:
"""Provides GPU-ready block weights via H2D copy from a pinned CPU weight source.
"""Provides accelerator-ready block weights via copies from a CPU weight source.
Args:
pool: Pre-allocated GPU weight buffer pool.
copy_stream: Dedicated CUDA stream for async H2D copies.
target_device: GPU device for compute.
source: Pinned CPU weight source.
lora_sources: LoRA adapters fused on H2D copy.
pool: Pre-allocated accelerator weight buffer pool.
copy_stream: Dedicated CUDA stream for async H2D copies. Pass ``None`` to
perform synchronous copies instead, as is done for MPS and CPU targets.
target_device: Accelerator device for compute.
source: CPU weight source.
lora_sources: LoRA adapters fused after copying.
blocks_prefix: State-dict prefix for LoRA key matching.
fuse_rule: Per-policy LoRA merge rule (must be streaming-compatible:
no companion-key emission). Defaults to ``bf16_fuse_rule``.
Expand All @@ -52,7 +54,7 @@ class WeightsProvider:
def __init__(
self,
pool: WeightPool,
copy_stream: torch.cuda.Stream,
copy_stream: torch.cuda.Stream | None,
target_device: torch.device,
source: WeightSource,
lora_sources: list[LoraSource] | None = None,
Expand All @@ -62,15 +64,15 @@ def __init__(
self._copy_stream = copy_stream
self._pool = pool
self._cache: OrderedDict[int, dict[str, torch.Tensor]] = OrderedDict()
self._events: dict[int, torch.cuda.Event] = {}
self._events: dict[int, object] = {}
self._target_device = target_device
self._source = source
self._lora_sources = lora_sources or []
self._blocks_prefix = blocks_prefix
self._fuse_rule = fuse_rule

def get(self, idx: int) -> dict[str, torch.Tensor]:
"""Return GPU weights for block *idx*. Does H2D copy on miss."""
"""Return accelerator weights for block *idx*. Copies from CPU on miss."""
if idx in self._cache:
return self._cache[idx]

Expand All @@ -82,30 +84,30 @@ def get(self, idx: int) -> dict[str, torch.Tensor]:
gpu_weights = self._pool.acquire()
cpu_weights = self._source.get(idx)

h2d_event = self._copy_to_gpu(idx, gpu_weights, cpu_weights)
h2d_event = self._copy_to_device(idx, gpu_weights, cpu_weights)
self._source.release(idx, event=h2d_event)

self._cache[idx] = gpu_weights
return gpu_weights

def _copy_to_gpu(
def _copy_to_device(
self,
idx: int,
gpu_weights: dict[str, torch.Tensor],
cpu_weights: dict[str, torch.Tensor],
) -> torch.cuda.Event:
"""Enqueue H2D copy + LoRA fusion on the copy stream and wait on compute.
) -> object | None:
"""Copy weights to the target device and fuse LoRAs.
The wait is intentionally inside this method so callers -- and
instrumentation regions wrapping it -- observe the full transfer time.
"""
if self._copy_stream is None:
self._copy_weights(gpu_weights, cpu_weights, non_blocking=False)
if self._lora_sources:
self._fuse_block_loras(idx, gpu_weights)
return None

with torch.cuda.stream(self._copy_stream):
gpu_view = _contiguous_byte_view(gpu_weights)
cpu_view = _contiguous_byte_view(cpu_weights)
if gpu_view is not None and cpu_view is not None and gpu_view.numel() == cpu_view.numel():
gpu_view.copy_(cpu_view, non_blocking=True)
else:
for name, gpu_tensor in gpu_weights.items():
gpu_tensor.copy_(cpu_weights[name], non_blocking=True)
self._copy_weights(gpu_weights, cpu_weights, non_blocking=True)
if self._lora_sources:
self._fuse_block_loras(idx, gpu_weights)
h2d_event = torch.cuda.Event()
Expand All @@ -114,14 +116,33 @@ def _copy_to_gpu(
torch.cuda.current_stream(self._target_device).wait_event(h2d_event)
return h2d_event

def release(self, idx: int, event: torch.cuda.Event) -> None:
@staticmethod
def _copy_weights(
    gpu_weights: dict[str, torch.Tensor],
    cpu_weights: dict[str, torch.Tensor],
    *,
    non_blocking: bool,
) -> None:
    """Copy every tensor in *cpu_weights* into the matching *gpu_weights* tensor.

    A single bulk copy is attempted first: when ``_contiguous_byte_view``
    yields a view for both dictionaries and the two views have the same
    number of elements, one ``copy_`` moves everything at once. Otherwise
    the tensors are copied one by one, keyed by the names in *gpu_weights*.

    Args:
        gpu_weights: Destination tensors, written in place.
        cpu_weights: Source tensors; must contain every key of
            *gpu_weights* when the per-tensor fallback is taken.
        non_blocking: Forwarded unchanged to ``Tensor.copy_``.
    """
    dst_bytes = _contiguous_byte_view(gpu_weights)
    src_bytes = _contiguous_byte_view(cpu_weights)

    # Bulk path is only valid when both sides expose a view of equal size.
    bulk_copy_ok = (
        dst_bytes is not None
        and src_bytes is not None
        and dst_bytes.numel() == src_bytes.numel()
    )
    if bulk_copy_ok:
        dst_bytes.copy_(src_bytes, non_blocking=non_blocking)
        return

    # Fallback: copy tensor by tensor, driven by the destination's keys.
    for param_name, dst_tensor in gpu_weights.items():
        dst_tensor.copy_(cpu_weights[param_name], non_blocking=non_blocking)

def release(self, idx: int, event: object | None = None) -> None:
"""Attach a compute-done event -- waited before this buffer is recycled."""
self._events[idx] = event
if event is not None:
self._events[idx] = event

def cleanup(self) -> None:
"""Synchronize streams and release all resources."""
self._copy_stream.synchronize()
torch.cuda.current_stream(self._target_device).synchronize()
if self._copy_stream is not None:
self._copy_stream.synchronize()
torch.cuda.current_stream(self._target_device).synchronize()
else:
synchronize_device(self._target_device)
self._cache.clear()
self._events.clear()
self._source.cleanup()
Expand Down
Loading