mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 04:27:21 +00:00
Support base SDXL and SDXL refiner models.
Large refactor of the model detection and loading code.
This commit is contained in:
@@ -26,10 +26,10 @@ def load_torch_file(ckpt, safe_load=False):
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
|
||||
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
|
||||
"{}.ln_final.weight": "{}.final_layer_norm.weight",
|
||||
"{}.ln_final.bias": "{}.final_layer_norm.bias",
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
"{}token_embedding.weight": "{}embeddings.token_embedding.weight",
|
||||
"{}ln_final.weight": "{}final_layer_norm.weight",
|
||||
"{}ln_final.bias": "{}final_layer_norm.bias",
|
||||
}
|
||||
|
||||
for k in keys_to_replace:
|
||||
@@ -48,19 +48,19 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
for resblock in range(number):
|
||||
for x in resblock_to_replace:
|
||||
for y in ["weight", "bias"]:
|
||||
k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
if k in sd:
|
||||
sd[k_to] = sd.pop(k)
|
||||
|
||||
for y in ["weight", "bias"]:
|
||||
k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
if k_from in sd:
|
||||
weights = sd.pop(k_from)
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||
k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||
sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
return sd
|
||||
|
||||
|
Reference in New Issue
Block a user