mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-09 19:17:44 +00:00
convert Stable Cascade nodes to V3 schema (#9373)
This commit is contained in:
@@ -17,55 +17,61 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.utils
|
||||
import nodes
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class StableCascade_EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
class StableCascade_EmptyLatentImage(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_EmptyLatentImage",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, width, height, compression, batch_size=1):
|
||||
def execute(cls, width, height, compression, batch_size=1):
|
||||
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageC_VAEEncode:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_StageC_VAEEncode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageC_VAEEncode",
|
||||
category="latent/stable_cascade",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("compression", default=42, min=4, max=128, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT", "LATENT")
|
||||
RETURN_NAMES = ("stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/stable_cascade"
|
||||
|
||||
def generate(self, image, vae, compression):
|
||||
def execute(cls, image, vae, compression):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
out_width = (width // compression) * vae.downscale_ratio
|
||||
@@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
|
||||
|
||||
c_latent = vae.encode(s[:,:,:,:3])
|
||||
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
|
||||
return ({
|
||||
return io.NodeOutput({
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
class StableCascade_StageB_Conditioning:
|
||||
|
||||
class StableCascade_StageB_Conditioning(io.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "conditioning": ("CONDITIONING",),
|
||||
"stage_c": ("LATENT",),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_StageB_Conditioning",
|
||||
category="conditioning/stable_cascade",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("stage_c"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
FUNCTION = "set_prior"
|
||||
|
||||
CATEGORY = "conditioning/stable_cascade"
|
||||
|
||||
def set_prior(self, conditioning, stage_c):
|
||||
@classmethod
|
||||
def execute(cls, conditioning, stage_c):
|
||||
c = []
|
||||
for t in conditioning:
|
||||
d = t[1].copy()
|
||||
d['stable_cascade_prior'] = stage_c['samples']
|
||||
d["stable_cascade_prior"] = stage_c["samples"]
|
||||
n = [t[0], d]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
return io.NodeOutput(c)
|
||||
|
||||
class StableCascade_SuperResolutionControlnet:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
|
||||
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StableCascade_SuperResolutionControlnet",
|
||||
category="_for_testing/stable_cascade",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="controlnet_input"),
|
||||
io.Latent.Output(display_name="stage_c"),
|
||||
io.Latent.Output(display_name="stage_b"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"image": ("IMAGE",),
|
||||
"vae": ("VAE", ),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
|
||||
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
|
||||
FUNCTION = "generate"
|
||||
|
||||
EXPERIMENTAL = True
|
||||
CATEGORY = "_for_testing/stable_cascade"
|
||||
|
||||
def generate(self, image, vae):
|
||||
def execute(cls, image, vae):
|
||||
width = image.shape[-2]
|
||||
height = image.shape[-3]
|
||||
batch_size = image.shape[0]
|
||||
@@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
|
||||
|
||||
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
|
||||
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
|
||||
return (controlnet_input, {
|
||||
return io.NodeOutput(controlnet_input, {
|
||||
"samples": c_latent,
|
||||
}, {
|
||||
"samples": b_latent,
|
||||
})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
|
||||
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
|
||||
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
|
||||
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
|
||||
}
|
||||
|
||||
class StableCascadeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
StableCascade_EmptyLatentImage,
|
||||
StableCascade_StageB_Conditioning,
|
||||
StableCascade_StageC_VAEEncode,
|
||||
StableCascade_SuperResolutionControlnet,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> StableCascadeExtension:
|
||||
return StableCascadeExtension()
|
||||
|
Reference in New Issue
Block a user