Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,12 @@ def cuda(self, *args, **kwargs):
def to(self, *args, **kwargs):
from ..hooks.group_offloading import _is_group_offload_enabled

fp32_modules = self._keep_in_fp32_modules or []
if fp32_modules is not None:
Comment thread
sayakpaul marked this conversation as resolved.
Outdated
logger.debug(
f"There are modules in {self.__class__.__name__} that should be kept in float32. A bare `to()` might lead to inconsistent results."
)

device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs

Expand Down
42 changes: 41 additions & 1 deletion tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,28 @@ def test_keep_in_fp32_modules(self, tmp_path):
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"

def test_to_keep_in_fp32_modules_warns(self, caplog):
    """Check that a bare ``to()`` logs the keep-in-float32 notice.

    The test is skipped when the model class declares no
    ``_keep_in_fp32_modules``. Otherwise a model is built from
    ``get_init_dict()``, moved to ``torch_device`` while DEBUG records of the
    ``diffusers.models.modeling_utils`` logger are captured, and the captured
    text must contain the expected message.
    """
    pinned_modules = self.model_class._keep_in_fp32_modules
    if not pinned_modules:
        pytest.skip("Model does not have _keep_in_fp32_modules defined.")

    model = self.model_class(**self.get_init_dict())
    model_name = model.__class__.__name__
    expected_message = (
        f"There are modules in {model_name} that should be kept in float32. "
        "A bare `to()` might lead to inconsistent results."
    )

    # Propagation is switched on only for the duration of the capture and is
    # always switched back off, even if `to()` raises.
    # NOTE(review): presumably needed so the library logger's records reach
    # caplog's handler - confirm against the logging utilities.
    logging.enable_propagation()
    try:
        with caplog.at_level(logging.DEBUG, logger="diffusers.models.modeling_utils"):
            # Drop anything logged while the model was being constructed.
            caplog.clear()
            model.to(torch_device)
    finally:
        logging.disable_propagation()

    assert expected_message in caplog.text

@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
Expand All @@ -481,7 +503,25 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4,
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []

model.to(dtype).save_pretrained(tmp_path)
# Build the reference model with the same mixed-precision layout that `from_pretrained` enforces, so
# the comparison reflects real save/load fidelity:
# - `_keep_in_fp32_modules` stay in fp32 while everything else is cast to `dtype`;
# - non-persistent buffers (e.g. fp32 RoPE `inv_freq`) are left untouched, because they are not part
# of the checkpoint and are regenerated by `__init__` on load. Truncating them here would make the
# reference diverge from the reloaded model for reasons unrelated to save/load.
persistent_tensor_names = {name for name, _ in named_persistent_module_tensors(model, recurse=True)}

def keep_in_fp32(name):
return any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules)

for name, param in model.named_parameters():
param.data = param.data.to(torch.float32 if keep_in_fp32(name) else dtype)
for name, buf in model.named_buffers():
if not buf.is_floating_point() or name not in persistent_tensor_names:
continue
buf.data = buf.data.to(torch.float32 if keep_in_fp32(name) else dtype)

model.save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)

for name, param in model_loaded.named_parameters():
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_anyflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import AnyFlowTransformer3DModel
Expand Down Expand Up @@ -100,12 +99,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestAnyFlowTransformer3D(AnyFlowTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow Transformer 3D (bidirectional variant)."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowTransformer3DMemory(AnyFlowTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ def get_dummy_inputs(self) -> dict[str, "torch.Tensor"]:
class TestAnyFlowFARTransformer3D(AnyFlowFARTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for AnyFlow FAR causal Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestAnyFlowFARTransformer3DMemory(AnyFlowFARTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AnyFlow FAR Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_helios.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Helios Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Helios Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import Ideogram4Transformer2DModel
Expand Down Expand Up @@ -141,14 +140,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestIdeogram4Transformer(Ideogram4TransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Ideogram 4 Transformer."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: the non-persistent fp32 RoPE inv_freq buffer is truncated to fp16 by the in-memory
# .to(dtype) path but kept fp32 by from_pretrained, so the two outputs diverge well beyond any
# meaningful tolerance. Dtype preservation is already covered by test_from_save_pretrained_dtype
# and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestIdeogram4TransformerMemory(Ideogram4TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Ideogram 4 Transformer."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import JoyImageEditTransformer3DModel
Expand Down Expand Up @@ -86,9 +85,7 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:


class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
pytest.skip("Tolerance requirements too high for meaningful test")
pass


class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin):
Expand Down
7 changes: 0 additions & 7 deletions tests/models/transformers/test_models_transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers import WanTransformer3DModel
Expand Down Expand Up @@ -106,12 +105,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,6 @@ def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
Expand Down
6 changes: 0 additions & 6 deletions tests/models/transformers/test_models_transformer_wan_vace.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""

@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")

def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
Expand Down
Loading