Some cleanups to how the text encoders are loaded.

This commit is contained in:
comfyanonymous
2024-02-19 10:29:18 -05:00
parent dbe0979b3f
commit d91f45ef28
3 changed files with 32 additions and 28 deletions

View File

@@ -138,8 +138,11 @@ class CLIP:
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
def load_sd(self, sd):
return self.cond_stage_model.load_sd(sd)
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
return self.cond_stage_model.state_dict()
@@ -494,9 +497,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
load_device = model_management.get_torch_device()
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
@@ -521,14 +521,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae = VAE(sd=vae_sd)
if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target()
if clip_target is not None:
sd = model_config.process_clip_state_dict(sd)
if any(k.startswith('cond_stage_model.') for k in sd):
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
load_model_weights(w, sd)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
print("clip missing:", m)
if len(u) > 0:
print("clip unexpected:", u)
else:
print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")