mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 13:35:05 +00:00
Support full SD3 loras.
This commit is contained in:
@@ -252,15 +252,14 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map[diffusers_lora_key] = unet_key
|
||||
|
||||
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
|
||||
for i in range(model.model_config.unet_config.get("depth", 0)):
|
||||
k = "transformer.transformer_blocks.{}.attn.".format(i)
|
||||
qkv = "diffusion_model.joint_blocks.{}.x_block.attn.qkv.weight".format(i)
|
||||
proj = "diffusion_model.joint_blocks.{}.x_block.attn.proj.weight".format(i)
|
||||
if qkv in sd:
|
||||
offset = sd[qkv].shape[0] // 3
|
||||
key_map["{}to_q".format(k)] = (qkv, (0, 0, offset))
|
||||
key_map["{}to_k".format(k)] = (qkv, (0, offset, offset))
|
||||
key_map["{}to_v".format(k)] = (qkv, (0, offset * 2, offset))
|
||||
key_map["{}to_out.0".format(k)] = proj
|
||||
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
|
||||
key_map[key_lora] = to
|
||||
|
||||
return key_map
|
||||
|
Reference in New Issue
Block a user