V3 StableCascade nodes: use io.NodeOutput; adjust code style

This commit is contained in:
bigcat88 2025-07-12 10:33:02 +03:00
parent 0be2ab610a
commit c09213ebc1
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -30,42 +30,14 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3):
node_id="StableCascade_EmptyLatentImage_V3",
category="latent/stable_cascade",
inputs=[
io.Int.Input(
"width",
default=1024,
min=256,
max=nodes.MAX_RESOLUTION, step=8,
),
io.Int.Input(
"height",
default=1024,
min=256,
max=nodes.MAX_RESOLUTION,
step=8,
),
io.Int.Input(
"compression",
default=42,
min=4,
max=128,
step=1,
),
io.Int.Input(
"batch_size",
default=1,
min=1,
max=4096,
),
io.Int.Input("width", default=1024,min=256,max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(
"stage_c",
display_name="stage_c",
),
io.Latent.Output(
"stage_b",
display_name="stage_b",
),
io.Latent.Output("stage_c", display_name="stage_c"),
io.Latent.Output("stage_b", display_name="stage_b"),
],
)
@ -73,11 +45,7 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3):
def execute(cls, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
return io.NodeOutput({"samples": c_latent}, {"samples": b_latent})
class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
@ -87,29 +55,13 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
node_id="StableCascade_StageC_VAEEncode_V3",
category="latent/stable_cascade",
inputs=[
io.Image.Input(
"image",
),
io.Vae.Input(
"vae",
),
io.Int.Input(
"compression",
default=42,
min=4,
max=128,
step=1,
),
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(
"stage_c",
display_name="stage_c",
),
io.Latent.Output(
"stage_b",
display_name="stage_b",
),
io.Latent.Output("stage_c", display_name="stage_c"),
io.Latent.Output("stage_b", display_name="stage_b"),
],
)
@ -124,11 +76,7 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3):
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
"samples": c_latent,
}, {
"samples": b_latent,
})
return io.NodeOutput({"samples": c_latent}, {"samples": b_latent})
class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
@ -138,17 +86,11 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
node_id="StableCascade_StageB_Conditioning_V3",
category="conditioning/stable_cascade",
inputs=[
io.Conditioning.Input(
"conditioning",
),
io.Latent.Input(
"stage_c",
),
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(
"CONDITIONING",
),
io.Conditioning.Output(),
],
)
@ -160,7 +102,7 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3):
d['stable_cascade_prior'] = stage_c['samples']
n = [t[0], d]
c.append(n)
return (c, )
return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
@ -171,26 +113,13 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input(
"image",
),
io.Vae.Input(
"vae",
),
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(
"controlnet_input",
display_name="controlnet_input",
),
io.Latent.Output(
"stage_c",
display_name="stage_c",
),
io.Latent.Output(
"stage_b",
display_name="stage_b",
),
io.Image.Output("controlnet_input", display_name="controlnet_input"),
io.Latent.Output("stage_c", display_name="stage_c"),
io.Latent.Output("stage_b", display_name="stage_b"),
],
)
@ -203,11 +132,7 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3):
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
return io.NodeOutput(controlnet_input, {"samples": c_latent}, {"samples": b_latent})
NODES_LIST: list[type[io.ComfyNodeV3]] = [