IP2P model loading support.

This is the code to load the model and inference it with only a text
prompt. This commit does not contain the nodes to properly use it with an
image input.

This supports both the original SD1 instructpix2pix model and the
diffusers SDXL one.
This commit is contained in:
comfyanonymous
2024-03-31 01:25:16 -04:00
parent 96b4c757cf
commit 575acb69e4
4 changed files with 84 additions and 6 deletions

View File

@@ -16,6 +16,8 @@ class BASE:
"num_head_channels": 64,
}
required_keys = {}
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
@@ -28,10 +30,14 @@ class BASE:
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
def matches(s, unet_config, state_dict=None):
for k in s.unet_config:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
if state_dict is not None:
for k in s.required_keys:
if k not in state_dict:
return False
return True
def model_type(self, state_dict, prefix=""):