mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Implement the USO subject identity lora. (#9674)
Use the LoRA with the FluxKontextMultiReferenceLatentMethod node set to "uso" and a ReferenceLatent node that supplies the reference image.
This commit is contained in:
@@ -233,12 +233,18 @@ class Flux(nn.Module):
|
|||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
if index_ref_method:
|
if ref_latents_method == "index":
|
||||||
index += 1
|
index += 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
w_offset = 0
|
w_offset = 0
|
||||||
|
elif ref_latents_method == "uso":
|
||||||
|
index = 0
|
||||||
|
h_offset = h_len * patch_size + h
|
||||||
|
w_offset = w_len * patch_size + w
|
||||||
|
h += ref.shape[-2]
|
||||||
|
w += ref.shape[-1]
|
||||||
else:
|
else:
|
||||||
index = 1
|
index = 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
|
@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||||
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
|
||||||
|
for k in sdk:
|
||||||
|
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
|
||||||
|
if k.endswith(".weight") and ".linear1." in k:
|
||||||
|
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
|
||||||
|
|
||||||
if isinstance(model, comfy.model_base.GenmoMochi):
|
if isinstance(model, comfy.model_base.GenmoMochi):
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
|
@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
|
|||||||
def convert_lora_wan_fun(sd): #Wan Fun loras
|
def convert_lora_wan_fun(sd): #Wan Fun loras
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
|
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
|
||||||
|
|
||||||
|
# Ordered (old, new) substring rewrites applied to every USO LoRA key.
# The order is significant and mirrors the original chained .replace() calls:
# the numbered ".qkv_lora1."/".qkv_lora2." and ".proj_lora1."/".proj_lora2."
# forms are handled before the un-numbered ".qkv_lora."/".proj_lora." forms,
# and the ".processor." path component is dropped last.
_USO_KEY_REWRITES = (
    (".down.weight", ".lora_down.weight"),
    (".up.weight", ".lora_up.weight"),
    (".qkv_lora2.", ".txt_attn.qkv."),
    (".qkv_lora1.", ".img_attn.qkv."),
    (".proj_lora1.", ".img_attn.proj."),
    (".proj_lora2.", ".txt_attn.proj."),
    (".qkv_lora.", ".linear1_qkv."),
    (".proj_lora.", ".linear2."),
    (".processor.", "."),
)


def convert_uso_lora(sd):
    """Return a copy of a USO LoRA state dict with its keys renamed.

    Each key is passed through the substring rewrites in
    ``_USO_KEY_REWRITES`` (in order) and then prefixed with
    ``"diffusion_model."``. Values are carried over untouched and the input
    mapping is not modified.

    Args:
        sd: Mapping of USO-style key strings (e.g.
            ``"single_blocks.37.processor.qkv_lora.up.weight"``) to their
            values (presumably weight tensors; the values are never
            inspected here).

    Returns:
        A new dict mapping the renamed keys to the same value objects.
    """
    sd_out = {}
    for k, tensor in sd.items():
        k_to = k
        for old, new in _USO_KEY_REWRITES:
            k_to = k_to.replace(old, new)
        sd_out["diffusion_model.{}".format(k_to)] = tensor
    return sd_out
|
||||||
|
|
||||||
|
|
||||||
def convert_lora(sd):
|
def convert_lora(sd):
|
||||||
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
|
||||||
return convert_lora_bfl_control(sd)
|
return convert_lora_bfl_control(sd)
|
||||||
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
|
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
|
||||||
return convert_lora_wan_fun(sd)
|
return convert_lora_wan_fun(sd)
|
||||||
|
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
|
||||||
|
return convert_uso_lora(sd)
|
||||||
return sd
|
return sd
|
||||||
|
@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {
|
return {"required": {
|
||||||
"conditioning": ("CONDITIONING", ),
|
"conditioning": ("CONDITIONING", ),
|
||||||
"reference_latents_method": (("offset", "index"), ),
|
"reference_latents_method": (("offset", "index", "uso"), ),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
Reference in New Issue
Block a user