convert Stable Cascade nodes to V3 schema (#9373)

This commit is contained in:
Alexander Piskun
2025-08-31 06:19:21 +03:00
committed by GitHub
parent 4449e14769
commit f949094b3c

View File

@@ -17,55 +17,61 @@
""" """
import torch import torch
import nodes from typing_extensions import override
import comfy.utils import comfy.utils
import nodes
from comfy_api.latest import ComfyExtension, io
class StableCascade_EmptyLatentImage: class StableCascade_EmptyLatentImage(io.ComfyNode):
def __init__(self, device="cpu"): @classmethod
self.device = device def define_schema(cls):
return io.Schema(
node_id="StableCascade_EmptyLatentImage",
category="latent/stable_cascade",
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, width, height, compression, batch_size=1):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({ return io.NodeOutput({
"samples": c_latent, "samples": c_latent,
}, { }, {
"samples": b_latent, "samples": b_latent,
}) })
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"): class StableCascade_StageC_VAEEncode(io.ComfyNode):
self.device = device @classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageC_VAEEncode",
category="latent/stable_cascade",
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, image, vae, compression):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
width = image.shape[-2] width = image.shape[-2]
height = image.shape[-3] height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio out_width = (width // compression) * vae.downscale_ratio
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
c_latent = vae.encode(s[:,:,:,:3]) c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({ return io.NodeOutput({
"samples": c_latent, "samples": c_latent,
}, { }, {
"samples": b_latent, "samples": b_latent,
}) })
class StableCascade_StageB_Conditioning:
class StableCascade_StageB_Conditioning(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "conditioning": ("CONDITIONING",), return io.Schema(
"stage_c": ("LATENT",), node_id="StableCascade_StageB_Conditioning",
}} category="conditioning/stable_cascade",
RETURN_TYPES = ("CONDITIONING",) inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(),
],
)
FUNCTION = "set_prior" @classmethod
def execute(cls, conditioning, stage_c):
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
c = [] c = []
for t in conditioning: for t in conditioning:
d = t[1].copy() d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples'] d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
return (c, ) return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"): class StableCascade_SuperResolutionControlnet(io.ComfyNode):
self.device = device @classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(display_name="controlnet_input"),
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, image, vae):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
width = image.shape[-2] width = image.shape[-2]
height = image.shape[-3] height = image.shape[-3]
batch_size = image.shape[0] batch_size = image.shape[0]
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, { return io.NodeOutput(controlnet_input, {
"samples": c_latent, "samples": c_latent,
}, { }, {
"samples": b_latent, "samples": b_latent,
}) })
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, class StableCascadeExtension(ComfyExtension):
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, @override
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, return [
} StableCascade_EmptyLatentImage,
StableCascade_StageB_Conditioning,
StableCascade_StageC_VAEEncode,
StableCascade_SuperResolutionControlnet,
]
async def comfy_entrypoint() -> StableCascadeExtension:
return StableCascadeExtension()