Implement the USO subject identity lora. (#9674)

Use the lora with the FluxKontextMultiReferenceLatentMethod node set to "uso"
and a ReferenceLatent node with the reference image.
Author: comfyanonymous
Date: 2025-09-01 15:54:02 -07:00
Committed by: GitHub
Parent: 9b15155972
Commit: 27e067ce50
4 changed files with 32 additions and 3 deletions
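
In workflow terms the two nodes only annotate the conditioning: ReferenceLatent attaches the reference image's latent, and FluxKontextMultiReferenceLatentMethod, set to "uso", records how those latents should be positioned. A rough sketch of the metadata they contribute per conditioning entry (key names mirror the hunks below; the exact conditioning keys are assumptions):

# Illustrative sketch of the conditioning extras the two nodes add;
# reference_latent_samples stands in for the reference image's latent tensor.
reference_latent_samples = None  # placeholder
cond_extras = {
    "reference_latents": [reference_latent_samples],  # from ReferenceLatent (assumed key)
    "reference_latents_method": "uso",                # from FluxKontextMultiReferenceLatentMethod
}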


@@ -233,12 +233,18 @@ class Flux(nn.Module):
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
+            ref_latents_method = kwargs.get("ref_latents_method", "offset")
             for ref in ref_latents:
-                if index_ref_method:
+                if ref_latents_method == "index":
                     index += 1
                     h_offset = 0
                     w_offset = 0
+                elif ref_latents_method == "uso":
+                    index = 0
+                    h_offset = h_len * patch_size + h
+                    w_offset = w_len * patch_size + w
+                    h += ref.shape[-2]
+                    w += ref.shape[-1]
                 else:
                     index = 1
                     h_offset = 0
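
In the hunk above, "index" gives each reference its own index and no spatial offset, while the new "uso" branch keeps index 0 and instead offsets each reference past the main image in position space, each one stacking diagonally after the previous (the "offset" branch is truncated by the hunk). A small self-contained sketch of the "uso" arithmetic, with illustrative shapes:

# The "uso" placement mirrored as a standalone function. h_len/w_len are the
# main image's patch-grid dimensions; ref_shapes are latent-space shapes.
def uso_offsets(ref_shapes, h_len, w_len, patch_size=2):
    h = w = 0
    offsets = []
    for ref_h, ref_w in ref_shapes:
        h_offset = h_len * patch_size + h
        w_offset = w_len * patch_size + w
        h += ref_h
        w += ref_w
        offsets.append((0, h_offset, w_offset))  # (index, h_offset, w_offset)
    return offsets

# Two 64x64 reference latents alongside a 64x64 main latent (h_len = w_len = 32):
print(uso_offsets([(64, 64), (64, 64)], h_len=32, w_len=32))
# [(0, 64, 64), (0, 128, 128)]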


@@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
                 key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
                 key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
                 key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
+        for k in sdk:
+            hidden_size = model.model_config.unet_config.get("hidden_size", 0)
+            if k.endswith(".weight") and ".linear1." in k:
+                key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))

     if isinstance(model, comfy.model_base.GenmoMochi):
         for k in sdk:
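
The added loop registers USO's linear1_qkv LoRA keys. In Flux single blocks, linear1 fuses the qkv projection with the MLP input projection, so the LoRA covers only part of the weight; the (k, (0, 0, hidden_size * 3)) value reads as a slice spec (dim, offset, size) limiting the patch to the qkv rows, an interpretation of comfy's offset format rather than something shown in the diff. A toy illustration:

import torch

# Illustrative shapes only (real Flux uses hidden_size 3072): single-block
# linear1 has weight shape (hidden_size * 3 + mlp_hidden_dim, hidden_size).
hidden_size, mlp_hidden_dim = 8, 32
linear1_weight = torch.randn(hidden_size * 3 + mlp_hidden_dim, hidden_size)

# The (0, 0, hidden_size * 3) spec selects dim 0, offset 0, length
# hidden_size * 3: only the qkv rows that a ".linear1_qkv" lora trains.
qkv_rows = linear1_weight.narrow(0, 0, hidden_size * 3)
print(qkv_rows.shape)  # torch.Size([24, 8])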


@@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
 def convert_lora_wan_fun(sd): #Wan Fun loras
     return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})

+def convert_uso_lora(sd):
+    sd_out = {}
+    for k in sd:
+        tensor = sd[k]
+        k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
+                                           .replace(".up.weight", ".lora_up.weight")
+                                           .replace(".qkv_lora2.", ".txt_attn.qkv.")
+                                           .replace(".qkv_lora1.", ".img_attn.qkv.")
+                                           .replace(".proj_lora1.", ".img_attn.proj.")
+                                           .replace(".proj_lora2.", ".txt_attn.proj.")
+                                           .replace(".qkv_lora.", ".linear1_qkv.")
+                                           .replace(".proj_lora.", ".linear2.")
+                                           .replace(".processor.", ".")
+                                           )
+        sd_out[k_to] = tensor
+    return sd_out
+
 def convert_lora(sd):
     if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
         return convert_lora_bfl_control(sd)
     if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
         return convert_lora_wan_fun(sd)
+    if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
+        return convert_uso_lora(sd)
     return sd
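
Tracing convert_uso_lora on the two detection keys from convert_lora shows the rename chain end to end (tensor values elided with None):

# Assumes convert_uso_lora from the hunk above is in scope.
sample = {
    "single_blocks.37.processor.qkv_lora.up.weight": None,
    "double_blocks.18.processor.qkv_lora2.up.weight": None,
}
print(list(convert_uso_lora(sample).keys()))
# ['diffusion_model.single_blocks.37.linear1_qkv.lora_up.weight',
#  'diffusion_model.double_blocks.18.txt_attn.qkv.lora_up.weight']

The first output is exactly the ".linear1_qkv" form that the new mapping in model_lora_keys_unet resolves to the qkv slice of linear1.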


@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
     def INPUT_TYPES(s):
         return {"required": {
             "conditioning": ("CONDITIONING", ),
-            "reference_latents_method": (("offset", "index"), ),
+            "reference_latents_method": (("offset", "index", "uso"), ),
             }}

     RETURN_TYPES = ("CONDITIONING",)
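
The body of the node sits outside this hunk; presumably it only stamps the selection into the conditioning, where it eventually reaches the model as the ref_latents_method kwarg from the first hunk. A hedged sketch of that method, assuming the node_helpers.conditioning_set_values helper that similar ComfyUI conditioning nodes use:

import node_helpers

# Assumed shape of the class's FUNCTION (not shown in the diff):
def append(self, conditioning, reference_latents_method):
    # Record the placement mode on every conditioning entry; the model reads
    # it back as kwargs.get("ref_latents_method", "offset") in the first hunk.
    c = node_helpers.conditioning_set_values(
        conditioning, {"reference_latents_method": reference_latents_method})
    return (c,)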