Support loading unet files in diffusers format.

This commit is contained in:
comfyanonymous
2023-07-05 17:34:45 -04:00
parent e57cba4c61
commit af7a49916b
9 changed files with 123 additions and 15 deletions

View File

@@ -117,6 +117,23 @@ UNET_MAP_RESNET = {
"out_layers.0.bias": "norm2.bias",
}
UNET_MAP_BASIC = {
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
"input_blocks.0.0.weight": "conv_in.weight",
"input_blocks.0.0.bias": "conv_in.bias",
"out.0.weight": "conv_norm_out.weight",
"out.0.bias": "conv_norm_out.bias",
"out.2.weight": "conv_out.weight",
"out.2.bias": "conv_out.bias",
"time_embed.0.weight": "time_embedding.linear_1.weight",
"time_embed.0.bias": "time_embedding.linear_1.bias",
"time_embed.2.weight": "time_embedding.linear_2.weight",
"time_embed.2.bias": "time_embedding.linear_2.bias"
}
def unet_to_diffusers(unet_config):
num_res_blocks = unet_config["num_res_blocks"]
attention_resolutions = unet_config["attention_resolutions"]
@@ -185,6 +202,10 @@ def unet_to_diffusers(unet_config):
for k in ["weight", "bias"]:
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
n += 1
for k in UNET_MAP_BASIC:
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
return diffusers_unet_map
def convert_sd_to(state_dict, dtype):