@@ -29,7 +29,6 @@ def __init__(self, name, filename):
2929
def read_metadata():
    """Return the raw metadata dict embedded in this network's safetensors file.

    Uses ``filename`` from the enclosing scope and delegates the actual
    parsing to ``sd_models.read_metadata_from_safetensors``.

    NOTE(review): this revision no longer pops 'ssmd_cover_images' (large
    base64 cover images) from the dict — confirm they are filtered
    elsewhere before the metadata is rendered as text in the UI.
    """
    metadata = sd_models.read_metadata_from_safetensors(filename)

    return metadata
3534
@@ -117,6 +116,12 @@ def __init__(self, net: Network, weights: NetworkWeights):
117116
118117 if hasattr (self .sd_module , 'weight' ):
119118 self .shape = self .sd_module .weight .shape
119+ elif isinstance (self .sd_module , nn .MultiheadAttention ):
120+ # For now, only self-attn use Pytorch's MHA
121+ # So assume all qkvo proj have same shape
122+ self .shape = self .sd_module .out_proj .weight .shape
123+ else :
124+ self .shape = None
120125
121126 self .ops = None
122127 self .extra_kwargs = {}
@@ -146,6 +151,9 @@ def __init__(self, net: Network, weights: NetworkWeights):
146151 self .alpha = weights .w ["alpha" ].item () if "alpha" in weights .w else None
147152 self .scale = weights .w ["scale" ].item () if "scale" in weights .w else None
148153
154+ self .dora_scale = weights .w .get ("dora_scale" , None )
155+ self .dora_norm_dims = len (self .shape ) - 1
156+
149157 def multiplier (self ):
150158 if 'transformer' in self .sd_key [:20 ]:
151159 return self .network .te_multiplier
@@ -160,6 +168,27 @@ def calc_scale(self):
160168
161169 return 1.0
162170
def apply_weight_decompose(self, updown, orig_weight):
    """Apply DoRA-style weight decomposition to a computed delta.

    Merges ``updown`` into ``orig_weight``, rescales each input channel of
    the merged weight so its norm matches the stored ``self.dora_scale``
    magnitude vector, and returns the delta relative to ``orig_weight``
    that produces that rescaled weight.

    ``self.dora_norm_dims`` is ``len(weight.shape) - 1``, i.e. every axis
    except the input-channel axis (dim 1) is folded into the norm.
    """
    # Match the device/dtype so mixed-precision checkpoints combine cleanly.
    orig_weight = orig_weight.to(updown.dtype)
    dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
    updown = updown.to(orig_weight.device)

    merged_scale1 = updown + orig_weight

    # Per-input-channel L2 norm: move dim 1 to the front, flatten the rest,
    # take the norm, then reshape/transpose back so it broadcasts against
    # the merged weight.
    merged_scale1_norm = (
        merged_scale1.transpose(0, 1)
        .reshape(merged_scale1.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
        .transpose(0, 1)
    )

    # Rescale each channel to the learned magnitude, then express the
    # result as a delta on top of the original weight.
    dora_merged = (
        merged_scale1 * (dora_scale / merged_scale1_norm)
    )
    final_updown = dora_merged - orig_weight
    return final_updown
191+
163192 def finalize_updown (self , updown , orig_weight , output_shape , ex_bias = None ):
164193 if self .bias is not None :
165194 updown = updown .reshape (self .bias .shape )
@@ -175,6 +204,9 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
175204 if ex_bias is not None :
176205 ex_bias = ex_bias * self .multiplier ()
177206
207+ if self .dora_scale is not None :
208+ updown = self .apply_weight_decompose (updown , orig_weight )
209+
178210 return updown * self .calc_scale () * self .multiplier (), ex_bias
179211
180212 def calc_updown (self , target ):
0 commit comments