mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
convert Stable Cascade nodes to V3 schema (#9373)
This commit is contained in:
@@ -17,55 +17,61 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import nodes
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class StableCascade_EmptyLatentImage:
|
class StableCascade_EmptyLatentImage(io.ComfyNode):
|
||||||
def __init__(self, device="cpu"):
|
@classmethod
|
||||||
self.device = device
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="StableCascade_EmptyLatentImage",
|
||||||
|
category="latent/stable_cascade",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
|
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
|
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="stage_c"),
|
||||||
|
io.Latent.Output(display_name="stage_b"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, width, height, compression, batch_size=1):
|
||||||
return {"required": {
|
|
||||||
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
|
||||||
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
|
||||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("LATENT", "LATENT")
|
|
||||||
RETURN_NAMES = ("stage_c", "stage_b")
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/stable_cascade"
|
|
||||||
|
|
||||||
def generate(self, width, height, compression, batch_size=1):
|
|
||||||
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
||||||
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
||||||
return ({
|
return io.NodeOutput({
|
||||||
"samples": c_latent,
|
"samples": c_latent,
|
||||||
}, {
|
}, {
|
||||||
"samples": b_latent,
|
"samples": b_latent,
|
||||||
})
|
})
|
||||||
|
|
||||||
class StableCascade_StageC_VAEEncode:
|
|
||||||
def __init__(self, device="cpu"):
|
class StableCascade_StageC_VAEEncode(io.ComfyNode):
|
||||||
self.device = device
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="StableCascade_StageC_VAEEncode",
|
||||||
|
category="latent/stable_cascade",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="stage_c"),
|
||||||
|
io.Latent.Output(display_name="stage_b"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, image, vae, compression):
|
||||||
return {"required": {
|
|
||||||
"image": ("IMAGE",),
|
|
||||||
"vae": ("VAE", ),
|
|
||||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("LATENT", "LATENT")
|
|
||||||
RETURN_NAMES = ("stage_c", "stage_b")
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/stable_cascade"
|
|
||||||
|
|
||||||
def generate(self, image, vae, compression):
|
|
||||||
width = image.shape[-2]
|
width = image.shape[-2]
|
||||||
height = image.shape[-3]
|
height = image.shape[-3]
|
||||||
out_width = (width // compression) * vae.downscale_ratio
|
out_width = (width // compression) * vae.downscale_ratio
|
||||||
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
|
|||||||
|
|
||||||
c_latent = vae.encode(s[:,:,:,:3])
|
c_latent = vae.encode(s[:,:,:,:3])
|
||||||
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
|
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
|
||||||
return ({
|
return io.NodeOutput({
|
||||||
"samples": c_latent,
|
"samples": c_latent,
|
||||||
}, {
|
}, {
|
||||||
"samples": b_latent,
|
"samples": b_latent,
|
||||||
})
|
})
|
||||||
|
|
||||||
class StableCascade_StageB_Conditioning:
|
|
||||||
|
class StableCascade_StageB_Conditioning(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "conditioning": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"stage_c": ("LATENT",),
|
node_id="StableCascade_StageB_Conditioning",
|
||||||
}}
|
category="conditioning/stable_cascade",
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Latent.Input("stage_c"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "set_prior"
|
@classmethod
|
||||||
|
def execute(cls, conditioning, stage_c):
|
||||||
CATEGORY = "conditioning/stable_cascade"
|
|
||||||
|
|
||||||
def set_prior(self, conditioning, stage_c):
|
|
||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
d = t[1].copy()
|
d = t[1].copy()
|
||||||
d['stable_cascade_prior'] = stage_c['samples']
|
d["stable_cascade_prior"] = stage_c["samples"]
|
||||||
n = [t[0], d]
|
n = [t[0], d]
|
||||||
c.append(n)
|
c.append(n)
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
class StableCascade_SuperResolutionControlnet:
|
|
||||||
def __init__(self, device="cpu"):
|
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||||
self.device = device
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="StableCascade_SuperResolutionControlnet",
|
||||||
|
category="_for_testing/stable_cascade",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(display_name="controlnet_input"),
|
||||||
|
io.Latent.Output(display_name="stage_c"),
|
||||||
|
io.Latent.Output(display_name="stage_b"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, image, vae):
|
||||||
return {"required": {
|
|
||||||
"image": ("IMAGE",),
|
|
||||||
"vae": ("VAE", ),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
|
|
||||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
CATEGORY = "_for_testing/stable_cascade"
|
|
||||||
|
|
||||||
def generate(self, image, vae):
|
|
||||||
width = image.shape[-2]
|
width = image.shape[-2]
|
||||||
height = image.shape[-3]
|
height = image.shape[-3]
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
|
|||||||
|
|
||||||
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
|
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
|
||||||
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
|
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
|
||||||
return (controlnet_input, {
|
return io.NodeOutput(controlnet_input, {
|
||||||
"samples": c_latent,
|
"samples": c_latent,
|
||||||
}, {
|
}, {
|
||||||
"samples": b_latent,
|
"samples": b_latent,
|
||||||
})
|
})
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
class StableCascadeExtension(ComfyExtension):
|
||||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
@override
|
||||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
|
return [
|
||||||
}
|
StableCascade_EmptyLatentImage,
|
||||||
|
StableCascade_StageB_Conditioning,
|
||||||
|
StableCascade_StageC_VAEEncode,
|
||||||
|
StableCascade_SuperResolutionControlnet,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> StableCascadeExtension:
|
||||||
|
return StableCascadeExtension()
|
||||||
|
Reference in New Issue
Block a user