Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,3 +2883,88 @@ def get_alpha_scales(down_weight, alpha_key):

converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict


def _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict):
"""
Convert non-diffusers Ideogram4 LoRA state dict to diffusers format.

Handles:
- `diffusion_model.` / `conditional_transformer.` prefix removal
- `lora_down`/`lora_up` (kohya) -> `lora_A`/`lora_B`, with `.alpha` folded into the weights
- fused `attention.qkv` -> split `to_q`/`to_k`/`to_v`; `attention.o` -> `to_out.0`
- `feed_forward.w1`/`w2`/`w3` and `adaln_modulation` map one-to-one
"""
for prefix in ("diffusion_model.", "conditional_transformer."):
if any(k.startswith(prefix) for k in state_dict):
state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items()}
break

is_kohya = any(".lora_down.weight" in k for k in state_dict)
down_suffix = ".lora_down.weight" if is_kohya else ".lora_A.weight"
up_suffix = ".lora_up.weight" if is_kohya else ".lora_B.weight"

def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha_tensor = state_dict.pop(alpha_key, None)
if alpha_tensor is None:
return 1.0, 1.0
# LoRA is scaled by `alpha / rank` in the forward pass; split the factor between down and up.
scale = alpha_tensor.item() / rank
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up

def pull(base):
"""Pop the scaled (lora_A, lora_B) pair for a module path, or return None if absent."""
down_key = base + down_suffix
if down_key not in state_dict:
return None
down = state_dict.pop(down_key)
up = state_dict.pop(base + up_suffix)
scale_down, scale_up = get_alpha_scales(down, base + ".alpha")
return down * scale_down, up * scale_up

num_layers = 0
for k in state_dict:
match = re.match(r"layers\.(\d+)\.", k)
if match:
num_layers = max(num_layers, int(match.group(1)) + 1)

converted_state_dict = {}
for i in range(num_layers):
layer_prefix = f"layers.{i}"

# Fused qkv -> split to_q / to_k / to_v (shared down/lora_A, chunk up/lora_B in thirds).
qkv = pull(f"{layer_prefix}.attention.qkv")
if qkv is not None:
down, up = qkv
up_q, up_k, up_v = torch.chunk(up, 3, dim=0)
for proj, up_proj in (("to_q", up_q), ("to_k", up_k), ("to_v", up_v)):
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_A.weight"] = down.clone()
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_B.weight"] = up_proj.contiguous()

# attention.o -> attention.to_out.0
out = pull(f"{layer_prefix}.attention.o")
if out is not None:
down, up = out
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_A.weight"] = down
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_B.weight"] = up

# feed_forward.{w1,w2,w3} and adaln_modulation map one-to-one.
for module in ("feed_forward.w1", "feed_forward.w2", "feed_forward.w3", "adaln_modulation"):
pair = pull(f"{layer_prefix}.{module}")
if pair is not None:
down, up = pair
converted_state_dict[f"{layer_prefix}.{module}.lora_A.weight"] = down
converted_state_dict[f"{layer_prefix}.{module}.lora_B.weight"] = up

if len(state_dict) > 0:
raise ValueError(
f"`state_dict` should be empty at this point but has {sorted(state_dict.keys())}. "
"This may be an unsupported Ideogram4 LoRA layout."
)

return {f"transformer.{k}": v for k, v in converted_state_dict.items()}
10 changes: 9 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
_convert_non_diffusers_anima_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_ideogram4_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
Expand Down Expand Up @@ -6028,7 +6029,6 @@ class Ideogram4LoraLoaderMixin(LoraBaseMixin):

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
Expand Down Expand Up @@ -6078,6 +6078,14 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused
# `attention.qkv` projection; convert those to the diffusers layout before loading.
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) or any(
".attention.qkv." in k for k in state_dict
)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

Expand Down
Loading