diff --git a/comfy_extras/v3/nodes_differential_diffusion.py b/comfy_extras/v3/nodes_differential_diffusion.py
new file mode 100644
index 000000000..85725138b
--- /dev/null
+++ b/comfy_extras/v3/nodes_differential_diffusion.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import torch
+
+from comfy_api.v3 import io
+
+
+class DifferentialDiffusion(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="DifferentialDiffusion_V3",
+            display_name="Differential Diffusion _V3",
+            category="_for_testing",
+            inputs=[
+                io.Model.Input(id="model"),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, model):
+        model = model.clone()
+        model.set_model_denoise_mask_function(cls.forward)
+        return io.NodeOutput(model)
+
+    @classmethod
+    def forward(cls, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict):
+        model = extra_options["model"]
+        step_sigmas = extra_options["sigmas"]
+        sigma_to = model.inner_model.model_sampling.sigma_min
+        if step_sigmas[-1] > sigma_to:
+            sigma_to = step_sigmas[-1]
+        sigma_from = step_sigmas[0]
+
+        ts_from = model.inner_model.model_sampling.timestep(sigma_from)
+        ts_to = model.inner_model.model_sampling.timestep(sigma_to)
+        current_ts = model.inner_model.model_sampling.timestep(sigma[0])
+
+        # the threshold falls from 1 to 0 over the sampling run, so regions with a
+        # higher mask value start being denoised earlier than regions with a lower one
+        threshold = (current_ts - ts_to) / (ts_from - ts_to)
+
+        return (denoise_mask >= threshold).to(denoise_mask.dtype)
+
+
+NODES_LIST = [
+    DifferentialDiffusion,
+]
diff --git a/comfy_extras/v3/nodes_rebatch.py b/comfy_extras/v3/nodes_rebatch.py
new file mode 100644
index 000000000..0cf1898bb
--- /dev/null
+++ b/comfy_extras/v3/nodes_rebatch.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+import torch
+
+from comfy_api.v3 import io
+
+
+class ImageRebatch(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="RebatchImages_V3",
+            display_name="Rebatch Images _V3",
+            category="image/batch",
+            is_input_list=True,
+            inputs=[
+                io.Image.Input("images"),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Image.Output("IMAGE", display_name="IMAGE", is_output_list=True),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, images, batch_size):
+        batch_size = batch_size[0]
+
+        output_list = []
+        all_images = []
+        for img in images:
+            for i in range(img.shape[0]):
+                all_images.append(img[i:i+1])
+
+        for i in range(0, len(all_images), batch_size):
+            output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
+
+        return io.NodeOutput(output_list)
+
+
+class LatentRebatch(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="RebatchLatents_V3",
+            display_name="Rebatch Latents _V3",
+            category="latent/batch",
+            is_input_list=True,
+            inputs=[
+                io.Latent.Input("latents"),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(is_output_list=True),
+            ],
+        )
+
+    @staticmethod
+    def get_batch(latents, list_ind, offset):
+        """prepare a batch out of the list of latents"""
+        samples = latents[list_ind]['samples']
+        shape = samples.shape
+        mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
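+        # bring the mask up to pixel resolution and tile it along the batch dimension so it covers every sample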
+        if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8:
+            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
+        if mask.shape[0] < samples.shape[0]:
+            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
+        if 'batch_index' in latents[list_ind]:
+            batch_inds = latents[list_ind]['batch_index']
+        else:
+            batch_inds = [x+offset for x in range(shape[0])]
+        return samples, mask, batch_inds
+
+    @staticmethod
+    def get_slices(indexable, num, batch_size):
+        """divides an indexable object into num slices of length batch_size, and a remainder"""
+        slices = []
+        for i in range(num):
+            slices.append(indexable[i*batch_size:(i+1)*batch_size])
+        if num * batch_size < len(indexable):
+            return slices, indexable[num * batch_size:]
+        else:
+            return slices, None
+
+    @staticmethod
+    def slice_batch(batch, num, batch_size):
+        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
+        return list(zip(*result))
+
+    @staticmethod
+    def cat_batch(batch1, batch2):
+        if batch1[0] is None:
+            return batch2
+        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
+        return result
+
+    @classmethod
+    def execute(cls, latents, batch_size):
+        batch_size = batch_size[0]
+
+        output_list = []
+        current_batch = (None, None, None)
+        processed = 0
+
+        for i in range(len(latents)):
+            # fetch new entry of list
+            #samples, masks, indices = self.get_batch(latents, i)
+            next_batch = cls.get_batch(latents, i, processed)
+            processed += len(next_batch[2])
+            # set to current if current is None
+            if current_batch[0] is None:
+                current_batch = next_batch
+            # add previous to list if dimensions do not match
+            elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
+                sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
+                output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+                current_batch = next_batch
+            # cat if everything checks out
+            else:
+                current_batch = cls.cat_batch(current_batch, next_batch)
+
+            # flush full batches once the accumulated batch has grown past the target batch size
+            if current_batch[0].shape[0] > batch_size:
+                num = current_batch[0].shape[0] // batch_size
+                sliced, remainder = cls.slice_batch(current_batch, num, batch_size)
+
+                for i in range(num):
+                    output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
+
+                current_batch = remainder
+
+        # add remainder
+        if current_batch[0] is not None:
+            sliced, _ = cls.slice_batch(current_batch, 1, batch_size)
+            output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
+
+        # drop all-ones masks since they have no effect
+        for s in output_list:
+            if s['noise_mask'].mean() == 1.0:
+                del s['noise_mask']
+
+        return io.NodeOutput(output_list)
+
+
+NODES_LIST = [
+    ImageRebatch,
+    LatentRebatch,
+]
diff --git a/nodes.py b/nodes.py
index bbe5167f8..50956edb8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2313,6 +2313,7 @@ def init_builtin_extra_nodes():
         "v3/nodes_cond.py",
         "v3/nodes_controlnet.py",
         "v3/nodes_cosmos.py",
+        "v3/nodes_differential_diffusion.py",
         "v3/nodes_flux.py",
         "v3/nodes_freelunch.py",
         "v3/nodes_fresca.py",
@@ -2321,6 +2322,7 @@ def init_builtin_extra_nodes():
         "v3/nodes_mask.py",
         "v3/nodes_preview_any.py",
         "v3/nodes_primitive.py",
+        "v3/nodes_rebatch.py",
         "v3/nodes_stable_cascade.py",
         "v3/nodes_webcam.py",
     ]
diff --git a/pyproject.toml b/pyproject.toml
index 55bc29518..f3bdadbce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ lint.select = [
   # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
   "F",
 ]
-ignore = ["E501"] # disable line-length checking
+lint.ignore = ["E501"] # disable line-length checking
["E501"] # disable line-length checking exclude = ["*.ipynb"] [tool.ruff.lint.per-file-ignores]