diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2ece5cb3685..e5683d9d7f28 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -324,12 +324,12 @@ def set_use_xla_flash_attention( Specify the partition specification if using SPMD. Otherwise None. """ if use_xla_flash_attention: - if not is_torch_xla_available: - raise "torch_xla is not available" + if not is_torch_xla_available(): + raise ImportError("torch_xla is not available") elif is_torch_xla_version("<", "2.3"): - raise "flash attention pallas kernel is supported from torch_xla version 2.3" + raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3") elif is_spmd() and is_torch_xla_version("<", "2.4"): - raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" + raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4") else: if is_flux: processor = XLAFluxFlashAttnProcessor2_0(partition_spec)