V3 ControlNet nodes: use io.NodeOutput; adjust code style

This commit is contained in:
bigcat88 2025-07-12 11:19:52 +03:00
parent 21c9d7b289
commit 535faa84f6
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -11,67 +11,25 @@ class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
display_name="Apply ControlNet _V3", display_name="Apply ControlNet _V3",
category="conditioning/controlnet", category="conditioning/controlnet",
inputs=[ inputs=[
io.Conditioning.Input( io.Conditioning.Input("positive"),
"positive", io.Conditioning.Input("negative"),
display_name="positive", io.ControlNet.Input("control_net"),
), io.Image.Input("image"),
io.Conditioning.Input( io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
"negative", io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
display_name="negative", io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
), io.Vae.Input("vae", optional=True),
io.ControlNet.Input(
"control_net",
display_name="control_net",
),
io.Image.Input(
"image",
display_name="image",
),
io.Float.Input(
"strength",
display_name="strength",
default=1.0,
min=0.0,
max=10.0,
step=0.01,
),
io.Float.Input(
"start_percent",
display_name="start percent",
default=0.0,
min=0.0,
max=1.0,
step=0.001,
),
io.Float.Input(
"end_percent",
display_name="end percent",
default=1.0,
min=0.0,
max=1.0,
step=0.001,
),
io.Vae.Input(
"vae",
optional=True,
),
], ],
outputs=[ outputs=[
io.Conditioning.Output( io.Conditioning.Output("positive_out", display_name="positive"),
"positive_out", io.Conditioning.Output("negative_out", display_name="negative"),
display_name="positive",
),
io.Conditioning.Output(
"negative_out",
display_name="negative",
),
], ],
) )
@classmethod @classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]) -> io.NodeOutput:
if strength == 0: if strength == 0:
return (positive, negative) return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1,1) control_hint = image.movedim(-1,1)
cnets = {} cnets = {}
@ -95,7 +53,7 @@ class ControlNetApplyAdvanced_V3(io.ComfyNodeV3):
n = [t[0], d] n = [t[0], d]
c.append(n) c.append(n)
out.append(c) out.append(c)
return (out[0], out[1]) return io.NodeOutput(out[0], out[1])
class SetUnionControlNetType_V3(io.ComfyNodeV3): class SetUnionControlNetType_V3(io.ComfyNodeV3):
@ -105,25 +63,16 @@ class SetUnionControlNetType_V3(io.ComfyNodeV3):
node_id="SetUnionControlNetType_V3", node_id="SetUnionControlNetType_V3",
category="conditioning/controlnet", category="conditioning/controlnet",
inputs=[ inputs=[
io.ControlNet.Input( io.ControlNet.Input("control_net"),
"control_net", io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
display_name="control_net",
),
io.Combo.Input(
"type",
options=["auto"] + list(UNION_CONTROLNET_TYPES.keys()),
),
], ],
outputs=[ outputs=[
io.ControlNet.Output( io.ControlNet.Output("control_net_out"),
"control_net_out",
display_name="control_net",
),
], ],
) )
@classmethod @classmethod
def execute(cls, control_net, type): def execute(cls, control_net, type) -> io.NodeOutput:
control_net = control_net.copy() control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES.get(type, -1) type_number = UNION_CONTROLNET_TYPES.get(type, -1)
if type_number >= 0: if type_number >= 0:
@ -131,7 +80,7 @@ class SetUnionControlNetType_V3(io.ComfyNodeV3):
else: else:
control_net.set_extra_arg("control_type", []) control_net.set_extra_arg("control_type", [])
return (control_net,) return io.NodeOutput(control_net)
class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3): class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
@ -141,70 +90,24 @@ class ControlNetInpaintingAliMamaApply_V3(ControlNetApplyAdvanced_V3):
node_id="ControlNetInpaintingAliMamaApply_V3", node_id="ControlNetInpaintingAliMamaApply_V3",
category="conditioning/controlnet", category="conditioning/controlnet",
inputs=[ inputs=[
io.Conditioning.Input( io.Conditioning.Input("positive"),
"positive", io.Conditioning.Input("negative"),
display_name="positive", io.ControlNet.Input("control_net"),
), io.Vae.Input("vae"),
io.Conditioning.Input( io.Image.Input("image"),
"negative", io.Mask.Input("mask"),
display_name="negative", io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
), io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.ControlNet.Input( io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
"control_net",
display_name="control_net",
),
io.Vae.Input(
"vae",
display_name="vae",
),
io.Image.Input(
"image",
display_name="image",
),
io.Mask.Input(
"mask",
display_name="mask",
),
io.Float.Input(
"strength",
display_name="strength",
default=1.0,
min=0.0,
max=10.0,
step=0.01,
),
io.Float.Input(
"start_percent",
display_name="start percent",
default=0.0,
min=0.0,
max=1.0,
step=0.001,
),
io.Float.Input(
"end_percent",
display_name="end percent",
default=1.0,
min=0.0,
max=1.0,
step=0.001,
),
], ],
outputs=[ outputs=[
io.Conditioning.Output( io.Conditioning.Output("positive_out", display_name="positive"),
"positive_out", io.Conditioning.Output("negative_out", display_name="negative"),
display_name="positive",
),
io.Conditioning.Output(
"negative_out",
display_name="negative",
),
], ],
) )
@classmethod @classmethod
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent): def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
extra_concat = [] extra_concat = []
if control_net.concat_mask: if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))