Add ControlNet support.

comfyanonymous
2023-02-16 10:38:08 -05:00
parent bc69fb5245
commit 4efa67fa12
9 changed files with 580 additions and 63 deletions

@@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
 import comfy.samplers
 import comfy.sd
 import comfy.utils
+import model_management

-supported_ckpt_extensions = ['.ckpt']
-supported_pt_extensions = ['.ckpt', '.pt', '.bin']
+supported_ckpt_extensions = ['.ckpt', '.pth']
+supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
 try:
     import safetensors.torch
     supported_ckpt_extensions += ['.safetensors']
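ControlNet checkpoints are commonly distributed as .pth files, which is why both extension whitelists gain that suffix; the new model_management import is used further down to move control models on and off the GPU. As a rough sketch of what the whitelist filtering amounts to (filter_by_extension is an illustrative stand-in, not the repo's helper):

    import os

    def filter_by_extension(files, extensions):
        # Keep only paths whose suffix appears in the whitelist (case-insensitive).
        return sorted(f for f in files if os.path.splitext(f)[1].lower() in extensions)

    print(filter_by_extension(["canny.pth", "notes.txt", "sd.ckpt"],
                              ['.ckpt', '.pt', '.bin', '.pth']))
    # ['canny.pth', 'sd.ckpt']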
@@ -77,12 +79,14 @@ class ConditioningSetArea:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
-        c = copy.deepcopy(conditioning)
-        for t in c:
-            t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            t[1]['strength'] = strength
-            t[1]['min_sigma'] = min_sigma
-            t[1]['max_sigma'] = max_sigma
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
+            n[1]['strength'] = strength
+            n[1]['min_sigma'] = min_sigma
+            n[1]['max_sigma'] = max_sigma
+            c.append(n)
         return (c, )

 class VAEDecode:
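The ConditioningSetArea rewrite replaces copy.deepcopy of the whole conditioning list, which would also clone the (potentially large) conditioning tensors, with a per-entry shallow copy: the tensor is shared and only the options dict is duplicated. A minimal sketch of the pattern, with a toy conditioning list:

    import torch

    cond = [[torch.zeros(1, 77, 768), {"strength": 1.0}]]  # toy CONDITIONING entry

    out = []
    for t in cond:
        n = [t[0], t[1].copy()]          # share the tensor, copy only the dict
        n[1]["area"] = (64, 64, 0, 0)
        out.append(n)

    assert out[0][0] is cond[0][0]       # tensor shared, no duplication
    assert out[0][1] is not cond[0][1]   # per-entry options stay independent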
@@ -134,7 +138,6 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"

     def encode(self, vae, pixels, mask):
-        print(pixels.shape, mask.shape)
         x = (pixels.shape[1] // 64) * 64
         y = (pixels.shape[2] // 64) * 64
         if pixels.shape[1] != x or pixels.shape[2] != y:
@@ -144,7 +147,6 @@ class VAEEncodeForInpaint:
         #shave off a few pixels to keep things seamless
         kernel_tensor = torch.ones((1, 1, 6, 6))
         mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
-        print(mask_erosion.shape, pixels.shape)
         for i in range(3):
             pixels[:,:,:,i] -= 0.5
             pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
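The erosion above works by dilating the inverted (keep) mask with an all-ones convolution kernel and clamping, which shaves a few pixels off the inpaint region so the seam blends. A toy demonstration of the same trick with a smaller kernel:

    import torch
    import torch.nn.functional as F

    # Toy binary mask: 1 marks the region to inpaint, 0 the region to keep.
    mask = torch.zeros(8, 8)
    mask[2:6, 2:6] = 1.0

    # Dilating the keep-region shrinks (erodes) the masked region by a few
    # pixels, the same trick as in the hunk above.
    kernel = torch.ones(1, 1, 3, 3)
    keep = 1.0 - mask
    eroded = 1.0 - torch.clamp(F.conv2d(keep[None, None], kernel, padding=1), 0, 1)

    print(mask.sum().item(), eroded.sum().item())  # 16.0 -> 4.0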
@@ -211,6 +213,44 @@ class VAELoader:
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)

+class ControlNetLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    controlnet_dir = os.path.join(models_dir, "controlnet")
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet"
+
+    CATEGORY = "loaders"
+
+    def load_controlnet(self, control_net_name):
+        controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
+        controlnet = comfy.sd.load_controlnet(controlnet_path)
+        return (controlnet,)
+
+class ControlNetApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, conditioning, control_net, image):
+        c = []
+        control_hint = image.movedim(-1,1)
+        print(control_hint.shape)
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
+            c.append(n)
+        return (c, )
+
 class CLIPLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     clip_dir = os.path.join(models_dir, "clip")
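Outside the graph UI, the two new nodes chain as follows. This is a hypothetical direct invocation rather than runnable code, since it needs real model files; the "canny.pth" filename and the conditioning/image inputs are placeholders:

    loader = ControlNetLoader()
    (control_net,) = loader.load_controlnet("canny.pth")  # hypothetical file in models/controlnet

    apply_node = ControlNetApply()
    (conditioned,) = apply_node.apply_controlnet(conditioning, control_net, image)
    # Each conditioning entry now carries a 'control' key holding a per-use
    # copy of the ControlNet primed with the hint image ([B, H, W, C] -> [B, C, H, W]).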
@@ -248,22 +288,7 @@ class EmptyLatentImage:
         latent = torch.zeros([batch_size, 4, height // 8, width // 8])
         return ({"samples":latent}, )

-def common_upscale(samples, width, height, upscale_method, crop):
-    if crop == "center":
-        old_width = samples.shape[3]
-        old_height = samples.shape[2]
-        old_aspect = old_width / old_height
-        new_aspect = width / height
-        x = 0
-        y = 0
-        if old_aspect > new_aspect:
-            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
-        elif old_aspect < new_aspect:
-            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-        s = samples[:,:,y:old_height-y,x:old_width-x]
-    else:
-        s = samples
-    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
-
 class LatentUpscale:
     upscale_methods = ["nearest-exact", "bilinear", "area"]
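common_upscale moves to comfy.utils so the latent and image scalers below can share it. Its center-crop arithmetic is worth a worked example: with a 96-wide, 64-high tensor going to square output, the width offset comes out to (96 - 96 * (1.0 / 1.5)) / 2 = 16 columns trimmed from each side before interpolation:

    import torch

    samples = torch.rand(1, 4, 64, 96)
    width, height = 64, 64

    old_h, old_w = samples.shape[2], samples.shape[3]
    old_aspect, new_aspect = old_w / old_h, width / height      # 1.5 vs 1.0
    x = round((old_w - old_w * (new_aspect / old_aspect)) / 2)  # (96 - 64) / 2 = 16
    s = samples[:, :, :, x:old_w - x]                           # -> 1x4x64x64
    out = torch.nn.functional.interpolate(s, size=(height, width), mode="bilinear")
    print(out.shape)  # torch.Size([1, 4, 64, 64])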
@@ -282,7 +307,7 @@ class LatentUpscale:
     def upscale(self, samples, upscale_method, width, height, crop):
         s = samples.copy()
-        s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
+        s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
         return (s,)

 class LatentRotate:
@@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     positive_copy = []
     negative_copy = []

+    control_nets = []
     for p in positive:
         t = p[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
+        if 'control' in p[1]:
+            control_nets += [p[1]['control']]
         positive_copy += [[t] + p[1:]]
     for n in negative:
         t = n[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
+        if 'control' in n[1]:
+            control_nets += [n[1]['control']]
         negative_copy += [[t] + n[1:]]

+    model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
+
     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
         sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
     else:
@@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
     samples = samples.cpu()

+    for c in control_nets:
+        c.cleanup()
+
     out = latent.copy()
     out["samples"] = samples
     return (out, )
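Taken together, the sampler hunks give each ControlNet a simple lifecycle: collect the 'control' entries from both prompts, push their models to the GPU via model_management.load_controlnet_gpu, sample, then release the hint buffers. A minimal sketch with a stand-in object (MockControl is illustrative; load_controlnet_gpu and cleanup are the actual hooks):

    class MockControl:
        control_model = object()             # what load_controlnet_gpu receives
        def cleanup(self):
            print("hint buffers released")

    positive = [[None, {'control': MockControl()}]]
    negative = [[None, {}]]

    control_nets = [c[1]['control'] for c in positive + negative if 'control' in c[1]]
    # model_management.load_controlnet_gpu([c.control_model for c in control_nets])
    # ... sampling runs with the hints applied ...
    for c in control_nets:
        c.cleanup()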
@@ -676,7 +711,7 @@ class ImageScale:
     def upscale(self, image, upscale_method, width, height, crop):
         samples = image.movedim(-1,1)
-        s = common_upscale(samples, width, height, upscale_method, crop)
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
         s = s.movedim(1,-1)
         return (s,)
@@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "ControlNetApply": ControlNetApply,
+    "ControlNetLoader": ControlNetLoader,
 }
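Registering the classes in NODE_CLASS_MAPPINGS is what exposes them to the frontend: the executor resolves each prompt node's class_type through this dict. Assuming the module above is importable, the lookup is just:

    node_cls = NODE_CLASS_MAPPINGS["ControlNetLoader"]
    node = node_cls()
    print(node.RETURN_TYPES, node.FUNCTION)   # ('CONTROL_NET',) load_controlnet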