diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index 16ff0d83b8c4..c6f6ff886a8d 100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -61,7 +61,7 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): - if not (torch.cuda.is_available() or torch.xpu.is_available()): + if not (torch.cuda.is_available() or torch.xpu.is_available() or torch.mps.is_available()): raise RuntimeError("No GPU found. A GPU is needed for quantization.") if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): raise ImportError( @@ -240,6 +240,8 @@ def update_device_map(self, device_map): if device_map is None: if torch.xpu.is_available(): current_device = f"xpu:{torch.xpu.current_device()}" + elif torch.mps.is_available(): + current_device = "mps" else: current_device = f"cuda:{torch.cuda.current_device()}" device_map = {"": current_device} @@ -411,6 +413,8 @@ def update_device_map(self, device_map): if device_map is None: if torch.xpu.is_available(): current_device = f"xpu:{torch.xpu.current_device()}" + elif torch.mps.is_available(): + current_device = "mps" else: current_device = f"cuda:{torch.cuda.current_device()}" device_map = {"": current_device}