mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-09 07:37:14 +00:00
Slightly cleaner code.
This commit is contained in:
parent
0108616b77
commit
73f60740c8
@ -32,11 +32,10 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
|
|||||||
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
|
y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
|
||||||
sd[y] = sd.pop(x)
|
sd[y] = sd.pop(x)
|
||||||
|
|
||||||
try:
|
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
|
||||||
if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in sd:
|
ids = sd['cond_stage_model.transformer.text_model.embeddings.position_ids']
|
||||||
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = sd['cond_stage_model.transformer.text_model.embeddings.position_ids'].round()
|
if ids.dtype == torch.float32:
|
||||||
except:
|
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||||
pass
|
|
||||||
|
|
||||||
for x in load_state_dict_to:
|
for x in load_state_dict_to:
|
||||||
x.load_state_dict(sd, strict=False)
|
x.load_state_dict(sd, strict=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user