Support base SDXL and SDXL refiner models.

Large refactor of the model detection and loading code.
comfyanonymous
2023-06-22 13:03:50 -04:00
parent 9fccf4aa03
commit f87ec10a97
16 changed files with 754 additions and 289 deletions


@@ -26,10 +26,10 @@ def load_torch_file(ckpt, safe_load=False):
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
-        "{}.positional_embedding": "{}.embeddings.position_embedding.weight",
-        "{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
-        "{}.ln_final.weight": "{}.final_layer_norm.weight",
-        "{}.ln_final.bias": "{}.final_layer_norm.bias",
+        "{}positional_embedding": "{}embeddings.position_embedding.weight",
+        "{}token_embedding.weight": "{}embeddings.token_embedding.weight",
+        "{}ln_final.weight": "{}final_layer_norm.weight",
+        "{}ln_final.bias": "{}final_layer_norm.bias",
     }
     for k in keys_to_replace:
@@ -48,19 +48,19 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
     for resblock in range(number):
         for x in resblock_to_replace:
             for y in ["weight", "bias"]:
-                k = "{}.transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
-                k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                 if k in sd:
                     sd[k_to] = sd.pop(k)
         for y in ["weight", "bias"]:
-            k_from = "{}.transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+            k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
             if k_from in sd:
                 weights = sd.pop(k_from)
                 shape_from = weights.shape[0] // 3
                 for x in range(3):
                     p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
-                    k_to = "{}.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
+                    k_to = "{}encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
                     sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd
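
The visible change to transformers_convert drops the hardcoded "." separator from the key-format strings, so callers now supply the separator (or nothing at all) as part of prefix_from / prefix_to. Below is a minimal sketch of how the updated mapping behaves under that convention; the "clip." / "text_model." prefixes, the tensor shapes, and the comfy.utils import path are illustrative assumptions, not keys taken from a real checkpoint or from this commit.

```python
import torch

# Assumed import path; transformers_convert is defined in comfy/utils.py.
from comfy.utils import transformers_convert

# Toy state dict with made-up prefixes and shapes, just to show the mapping.
sd = {
    "clip.positional_embedding": torch.zeros(77, 1024),
    "clip.transformer.resblocks.0.attn.in_proj_weight": torch.zeros(3 * 1024, 1024),
}

# The separator is now part of the prefixes themselves (note the trailing dots),
# which also allows an empty prefix_from for checkpoints with bare CLIP keys.
sd = transformers_convert(sd, "clip.", "text_model.", 1)

# The fused in_proj weight is split into three equal q/k/v projections.
print(sorted(sd.keys()))
# Expected keys under these assumed prefixes:
#   text_model.embeddings.position_embedding.weight
#   text_model.encoder.layers.0.self_attn.k_proj.weight
#   text_model.encoder.layers.0.self_attn.q_proj.weight
#   text_model.encoder.layers.0.self_attn.v_proj.weight
```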