mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 05:25:23 +00:00
Add ControlNet support.
This commit is contained in:
76
comfy/sd.py
76
comfy/sd.py
@@ -6,6 +6,9 @@ import model_management
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.autoencoder import AutoencoderKL
|
||||
from omegaconf import OmegaConf
|
||||
from .cldm import cldm
|
||||
|
||||
from . import utils
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
@@ -323,6 +326,79 @@ class VAE:
|
||||
samples = samples.cpu()
|
||||
return samples
|
||||
|
||||
class ControlNet:
|
||||
def __init__(self, control_model):
|
||||
self.control_model = control_model
|
||||
self.cond_hint_original = None
|
||||
self.cond_hint = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond_txt):
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
|
||||
print("set cond_hint", self.cond_hint.shape)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
|
||||
return control
|
||||
|
||||
def set_cond_hint(self, cond_hint):
|
||||
self.cond_hint_original = cond_hint
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model)
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
return c
|
||||
|
||||
def load_controlnet(ckpt_path):
|
||||
controlnet_data = load_torch_file(ckpt_path)
|
||||
pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
pth = False
|
||||
sd2 = False
|
||||
key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
|
||||
if pth_key in controlnet_data:
|
||||
pth = True
|
||||
key = pth_key
|
||||
elif key in controlnet_data:
|
||||
pass
|
||||
else:
|
||||
print("error checkpoint does not contain controlnet data", ckpt_path)
|
||||
return None
|
||||
|
||||
context_dim = controlnet_data[key].shape[1]
|
||||
control_model = cldm.ControlNet(image_size=32,
|
||||
in_channels=4,
|
||||
hint_channels=3,
|
||||
model_channels=320,
|
||||
attention_resolutions=[ 4, 2, 1 ],
|
||||
num_res_blocks=2,
|
||||
channel_mult=[ 1, 2, 4, 4 ],
|
||||
num_heads=8,
|
||||
use_spatial_transformer=True,
|
||||
transformer_depth=1,
|
||||
context_dim=context_dim,
|
||||
use_checkpoint=True,
|
||||
legacy=False)
|
||||
|
||||
if pth:
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
control_model.load_state_dict(controlnet_data, strict=False)
|
||||
|
||||
control = ControlNet(control_model)
|
||||
return control
|
||||
|
||||
|
||||
def load_clip(ckpt_path, embedding_directory=None):
|
||||
clip_data = load_torch_file(ckpt_path)
|
||||
config = {}
|
||||
|
Reference in New Issue
Block a user