mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
converted ImageRebatch, LatentRebatch, DifferentialDiffusion
This commit is contained in:
parent
18ed598fa1
commit
2a7793394f
50
comfy_extras/v3/nodes_differential_diffusion.py
Normal file
50
comfy_extras/v3/nodes_differential_diffusion.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.v3 import io
|
||||||
|
|
||||||
|
|
||||||
|
class DifferentialDiffusion(io.ComfyNodeV3):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.SchemaV3(
|
||||||
|
node_id="DifferentialDiffusion_V3",
|
||||||
|
display_name="Differential Diffusion _V3",
|
||||||
|
category="_for_testing",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input(id="model"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model):
|
||||||
|
model = model.clone()
|
||||||
|
model.set_model_denoise_mask_function(cls.forward)
|
||||||
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
|
||||||
|
model = extra_options["model"]
|
||||||
|
step_sigmas = extra_options["sigmas"]
|
||||||
|
sigma_to = model.inner_model.model_sampling.sigma_min
|
||||||
|
if step_sigmas[-1] > sigma_to:
|
||||||
|
sigma_to = step_sigmas[-1]
|
||||||
|
sigma_from = step_sigmas[0]
|
||||||
|
|
||||||
|
ts_from = model.inner_model.model_sampling.timestep(sigma_from)
|
||||||
|
ts_to = model.inner_model.model_sampling.timestep(sigma_to)
|
||||||
|
current_ts = model.inner_model.model_sampling.timestep(sigma[0])
|
||||||
|
|
||||||
|
threshold = (current_ts - ts_to) / (ts_from - ts_to)
|
||||||
|
|
||||||
|
return (denoise_mask >= threshold).to(denoise_mask.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
NODES_LIST = [
|
||||||
|
DifferentialDiffusion,
|
||||||
|
]
|
148
comfy_extras/v3/nodes_rebatch.py
Normal file
148
comfy_extras/v3/nodes_rebatch.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.v3 import io
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRebatch(io.ComfyNodeV3):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.SchemaV3(
|
||||||
|
node_id="RebatchImages_V3",
|
||||||
|
display_name="Rebatch Images _V3",
|
||||||
|
category="image/batch",
|
||||||
|
is_input_list=True,
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("images"),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output("IMAGE", display_name="IMAGE", is_output_list=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, images, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
all_images = []
|
||||||
|
for img in images:
|
||||||
|
for i in range(img.shape[0]):
|
||||||
|
all_images.append(img[i:i+1])
|
||||||
|
|
||||||
|
for i in range(0, len(all_images), batch_size):
|
||||||
|
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
||||||
|
|
||||||
|
return io.NodeOutput(output_list)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentRebatch(io.ComfyNodeV3):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.SchemaV3(
|
||||||
|
node_id="RebatchLatents_V3",
|
||||||
|
display_name="Rebatch Latents _V3",
|
||||||
|
category="latent/batch",
|
||||||
|
is_input_list=True,
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("latents"),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(is_output_list=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_batch(latents, list_ind, offset):
|
||||||
|
"""prepare a batch out of the list of latents"""
|
||||||
|
samples = latents[list_ind]['samples']
|
||||||
|
shape = samples.shape
|
||||||
|
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
|
||||||
|
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
|
||||||
|
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
|
||||||
|
if mask.shape[0] < samples.shape[0]:
|
||||||
|
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
|
||||||
|
if 'batch_index' in latents[list_ind]:
|
||||||
|
batch_inds = latents[list_ind]['batch_index']
|
||||||
|
else:
|
||||||
|
batch_inds = [x+offset for x in range(shape[0])]
|
||||||
|
return samples, mask, batch_inds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_slices(indexable, num, batch_size):
|
||||||
|
"""divides an indexable object into num slices of length batch_size, and a remainder"""
|
||||||
|
slices = []
|
||||||
|
for i in range(num):
|
||||||
|
slices.append(indexable[i*batch_size:(i+1)*batch_size])
|
||||||
|
if num * batch_size < len(indexable):
|
||||||
|
return slices, indexable[num * batch_size:]
|
||||||
|
else:
|
||||||
|
return slices, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def slice_batch(batch, num, batch_size):
|
||||||
|
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
|
||||||
|
return list(zip(*result))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cat_batch(batch1, batch2):
|
||||||
|
if batch1[0] is None:
|
||||||
|
return batch2
|
||||||
|
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, latents, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
current_batch = (None, None, None)
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
for i in range(len(latents)):
|
||||||
|
# fetch new entry of list
|
||||||
|
#samples, masks, indices = self.get_batch(latents, i)
|
||||||
|
next_batch = cls.get_batch(latents, i, processed)
|
||||||
|
processed += len(next_batch[2])
|
||||||
|
# set to current if current is None
|
||||||
|
if current_batch[0] is None:
|
||||||
|
current_batch = next_batch
|
||||||
|
# add previous to list if dimensions do not match
|
||||||
|
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
|
||||||
|
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
current_batch = next_batch
|
||||||
|
# cat if everything checks out
|
||||||
|
else:
|
||||||
|
current_batch = cls.cat_batch(current_batch, next_batch)
|
||||||
|
|
||||||
|
# add to list if dimensions gone above target batch size
|
||||||
|
if current_batch[0].shape[0] > batch_size:
|
||||||
|
num = current_batch[0].shape[0] // batch_size
|
||||||
|
sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
|
||||||
|
|
||||||
|
for i in range(num):
|
||||||
|
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
|
||||||
|
|
||||||
|
current_batch = remainder
|
||||||
|
|
||||||
|
#add remainder
|
||||||
|
if current_batch[0] is not None:
|
||||||
|
sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
|
||||||
|
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
|
||||||
|
|
||||||
|
#get rid of empty masks
|
||||||
|
for s in output_list:
|
||||||
|
if s['noise_mask'].mean() == 1.0:
|
||||||
|
del s['noise_mask']
|
||||||
|
|
||||||
|
return io.NodeOutput(output_list)
|
||||||
|
|
||||||
|
|
||||||
|
NODES_LIST = [
|
||||||
|
ImageRebatch,
|
||||||
|
LatentRebatch,
|
||||||
|
]
|
2
nodes.py
2
nodes.py
@ -2313,6 +2313,7 @@ def init_builtin_extra_nodes():
|
|||||||
"v3/nodes_cond.py",
|
"v3/nodes_cond.py",
|
||||||
"v3/nodes_controlnet.py",
|
"v3/nodes_controlnet.py",
|
||||||
"v3/nodes_cosmos.py",
|
"v3/nodes_cosmos.py",
|
||||||
|
"v3/nodes_differential_diffusion.py",
|
||||||
"v3/nodes_flux.py",
|
"v3/nodes_flux.py",
|
||||||
"v3/nodes_freelunch.py",
|
"v3/nodes_freelunch.py",
|
||||||
"v3/nodes_fresca.py",
|
"v3/nodes_fresca.py",
|
||||||
@ -2321,6 +2322,7 @@ def init_builtin_extra_nodes():
|
|||||||
"v3/nodes_mask.py",
|
"v3/nodes_mask.py",
|
||||||
"v3/nodes_preview_any.py",
|
"v3/nodes_preview_any.py",
|
||||||
"v3/nodes_primitive.py",
|
"v3/nodes_primitive.py",
|
||||||
|
"v3/nodes_rebatch.py",
|
||||||
"v3/nodes_stable_cascade.py",
|
"v3/nodes_stable_cascade.py",
|
||||||
"v3/nodes_webcam.py",
|
"v3/nodes_webcam.py",
|
||||||
]
|
]
|
||||||
|
@ -23,7 +23,7 @@ lint.select = [
|
|||||||
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
|
||||||
"F",
|
"F",
|
||||||
]
|
]
|
||||||
ignore = ["E501"] # disable line-length checking
|
lint.ignore = ["E501"] # disable line-length checking
|
||||||
exclude = ["*.ipynb"]
|
exclude = ["*.ipynb"]
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user