diff --git a/comfy_extras/v3/nodes_stable_cascade.py b/comfy_extras/v3/nodes_stable_cascade.py index af2893641..36d7e3321 100644 --- a/comfy_extras/v3/nodes_stable_cascade.py +++ b/comfy_extras/v3/nodes_stable_cascade.py @@ -30,42 +30,14 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3): node_id="StableCascade_EmptyLatentImage_V3", category="latent/stable_cascade", inputs=[ - io.Int.Input( - "width", - default=1024, - min=256, - max=nodes.MAX_RESOLUTION, step=8, - ), - io.Int.Input( - "height", - default=1024, - min=256, - max=nodes.MAX_RESOLUTION, - step=8, - ), - io.Int.Input( - "compression", - default=42, - min=4, - max=128, - step=1, - ), - io.Int.Input( - "batch_size", - default=1, - min=1, - max=4096, - ), + io.Int.Input("width", default=1024,min=256,max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), ], outputs=[ - io.Latent.Output( - "stage_c", - display_name="stage_c", - ), - io.Latent.Output( - "stage_b", - display_name="stage_b", - ), + io.Latent.Output("stage_c", display_name="stage_c"), + io.Latent.Output("stage_b", display_name="stage_b"), ], ) @@ -73,11 +45,7 @@ class StableCascade_EmptyLatentImage_V3(io.ComfyNodeV3): def execute(cls, width, height, compression, batch_size=1): c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) - return ({ - "samples": c_latent, - }, { - "samples": b_latent, - }) + return io.NodeOutput({"samples": c_latent}, {"samples": b_latent}) class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3): @@ -87,29 +55,13 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3): node_id="StableCascade_StageC_VAEEncode_V3", category="latent/stable_cascade", inputs=[ - io.Image.Input( - "image", - ), - io.Vae.Input( - "vae", - ), - io.Int.Input( - "compression", - default=42, - min=4, - max=128, - step=1, - ), + io.Image.Input("image"), + io.Vae.Input("vae"), + io.Int.Input("compression", default=42, min=4, max=128, step=1), ], outputs=[ - io.Latent.Output( - "stage_c", - display_name="stage_c", - ), - io.Latent.Output( - "stage_b", - display_name="stage_b", - ), + io.Latent.Output("stage_c", display_name="stage_c"), + io.Latent.Output("stage_b", display_name="stage_b"), ], ) @@ -124,11 +76,7 @@ class StableCascade_StageC_VAEEncode_V3(io.ComfyNodeV3): c_latent = vae.encode(s[:,:,:,:3]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) - return ({ - "samples": c_latent, - }, { - "samples": b_latent, - }) + return io.NodeOutput({"samples": c_latent}, {"samples": b_latent}) class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3): @@ -138,17 +86,11 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3): node_id="StableCascade_StageB_Conditioning_V3", category="conditioning/stable_cascade", inputs=[ - io.Conditioning.Input( - "conditioning", - ), - io.Latent.Input( - "stage_c", - ), + io.Conditioning.Input("conditioning"), + io.Latent.Input("stage_c"), ], outputs=[ - io.Conditioning.Output( - "CONDITIONING", - ), + io.Conditioning.Output(), ], ) @@ -160,7 +102,7 @@ class StableCascade_StageB_Conditioning_V3(io.ComfyNodeV3): d['stable_cascade_prior'] = stage_c['samples'] n = [t[0], d] c.append(n) - return (c, ) + return io.NodeOutput(c) class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3): @@ -171,26 +113,13 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3): category="_for_testing/stable_cascade", is_experimental=True, inputs=[ - io.Image.Input( - "image", - ), - io.Vae.Input( - "vae", - ), + io.Image.Input("image"), + io.Vae.Input("vae"), ], outputs=[ - io.Image.Output( - "controlnet_input", - display_name="controlnet_input", - ), - io.Latent.Output( - "stage_c", - display_name="stage_c", - ), - io.Latent.Output( - "stage_b", - display_name="stage_b", - ), + io.Image.Output("controlnet_input", display_name="controlnet_input"), + io.Latent.Output("stage_c", display_name="stage_c"), + io.Latent.Output("stage_b", display_name="stage_b"), ], ) @@ -203,11 +132,7 @@ class StableCascade_SuperResolutionControlnet_V3(io.ComfyNodeV3): c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) - return (controlnet_input, { - "samples": c_latent, - }, { - "samples": b_latent, - }) + return io.NodeOutput(controlnet_input, {"samples": c_latent}, {"samples": b_latent}) NODES_LIST: list[type[io.ComfyNodeV3]] = [