Small refactor.

This commit is contained in:
comfyanonymous
2023-06-06 03:25:49 -04:00
parent a3a713b6c5
commit 0e425603fb
2 changed files with 17 additions and 16 deletions

View File

@@ -24,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False):
return sd
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}.positional_embedding": "{}.embeddings.position_embedding.weight",
"{}.token_embedding.weight": "{}.embeddings.token_embedding.weight",
"{}.ln_final.weight": "{}.final_layer_norm.weight",
"{}.ln_final.bias": "{}.final_layer_norm.bias",
}
for k in keys_to_replace:
x = k.format(prefix_from)
if x in sd:
sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x)
resblock_to_replace = {
"ln_1": "layer_norm1",
"ln_2": "layer_norm2",