diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index bf516abc825f..a29d74024c18 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2883,3 +2883,88 @@ def get_alpha_scales(down_weight, alpha_key): converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} return converted_state_dict + + +def _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict): + """ + Convert non-diffusers Ideogram4 LoRA state dict to diffusers format. + + Handles: + - `diffusion_model.` / `conditional_transformer.` prefix removal + - `lora_down`/`lora_up` (kohya) -> `lora_A`/`lora_B`, with `.alpha` folded into the weights + - fused `attention.qkv` -> split `to_q`/`to_k`/`to_v`; `attention.o` -> `to_out.0` + - `feed_forward.w1`/`w2`/`w3` and `adaln_modulation` map one-to-one + """ + for prefix in ("diffusion_model.", "conditional_transformer."): + if any(k.startswith(prefix) for k in state_dict): + state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items()} + break + + is_kohya = any(".lora_down.weight" in k for k in state_dict) + down_suffix = ".lora_down.weight" if is_kohya else ".lora_A.weight" + up_suffix = ".lora_up.weight" if is_kohya else ".lora_B.weight" + + def get_alpha_scales(down_weight, alpha_key): + rank = down_weight.shape[0] + alpha_tensor = state_dict.pop(alpha_key, None) + if alpha_tensor is None: + return 1.0, 1.0 + # LoRA is scaled by `alpha / rank` in the forward pass; split the factor between down and up. + scale = alpha_tensor.item() / rank + scale_down, scale_up = scale, 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + return scale_down, scale_up + + def pull(base): + """Pop the scaled (lora_A, lora_B) pair for a module path, or return None if absent.""" + down_key = base + down_suffix + if down_key not in state_dict: + return None + down = state_dict.pop(down_key) + up = state_dict.pop(base + up_suffix) + scale_down, scale_up = get_alpha_scales(down, base + ".alpha") + return down * scale_down, up * scale_up + + num_layers = 0 + for k in state_dict: + match = re.match(r"layers\.(\d+)\.", k) + if match: + num_layers = max(num_layers, int(match.group(1)) + 1) + + converted_state_dict = {} + for i in range(num_layers): + layer_prefix = f"layers.{i}" + + # Fused qkv -> split to_q / to_k / to_v (shared down/lora_A, chunk up/lora_B in thirds). + qkv = pull(f"{layer_prefix}.attention.qkv") + if qkv is not None: + down, up = qkv + up_q, up_k, up_v = torch.chunk(up, 3, dim=0) + for proj, up_proj in (("to_q", up_q), ("to_k", up_k), ("to_v", up_v)): + converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_A.weight"] = down.clone() + converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_B.weight"] = up_proj.contiguous() + + # attention.o -> attention.to_out.0 + out = pull(f"{layer_prefix}.attention.o") + if out is not None: + down, up = out + converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_A.weight"] = down + converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_B.weight"] = up + + # feed_forward.{w1,w2,w3} and adaln_modulation map one-to-one. + for module in ("feed_forward.w1", "feed_forward.w2", "feed_forward.w3", "adaln_modulation"): + pair = pull(f"{layer_prefix}.{module}") + if pair is not None: + down, up = pair + converted_state_dict[f"{layer_prefix}.{module}.lora_A.weight"] = down + converted_state_dict[f"{layer_prefix}.{module}.lora_B.weight"] = up + + if len(state_dict) > 0: + raise ValueError( + f"`state_dict` should be empty at this point but has {sorted(state_dict.keys())}. " + "This may be an unsupported Ideogram4 LoRA layout." + ) + + return {f"transformer.{k}": v for k, v in converted_state_dict.items()} diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f4bec3f15e1b..0abeba91e983 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -49,6 +49,7 @@ _convert_non_diffusers_anima_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, + _convert_non_diffusers_ideogram4_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_ltx2_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, @@ -6028,7 +6029,6 @@ class Ideogram4LoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], @@ -6078,6 +6078,14 @@ def lora_state_dict( logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + # ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused + # `attention.qkv` projection; convert those to the diffusers layout before loading. + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) or any( + ".attention.qkv." in k for k in state_dict + ) + if is_non_diffusers_format: + state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict) + out = (state_dict, metadata) if return_lora_metadata else state_dict return out