from __future__ import annotations import torch from comfy_api.v3 import io class ImageRebatch(io.ComfyNodeV3): @classmethod def define_schema(cls): return io.Schema( node_id="RebatchImages_V3", display_name="Rebatch Images _V3", category="image/batch", is_input_list=True, inputs=[ io.Image.Input("images"), io.Int.Input("batch_size", default=1, min=1, max=4096), ], outputs=[ io.Image.Output(display_name="IMAGE", is_output_list=True), ], ) @classmethod def execute(cls, images, batch_size): batch_size = batch_size[0] output_list = [] all_images = [] for img in images: for i in range(img.shape[0]): all_images.append(img[i:i+1]) for i in range(0, len(all_images), batch_size): output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) return io.NodeOutput(output_list) class LatentRebatch(io.ComfyNodeV3): @classmethod def define_schema(cls): return io.Schema( node_id="RebatchLatents_V3", display_name="Rebatch Latents _V3", category="latent/batch", is_input_list=True, inputs=[ io.Latent.Input("latents"), io.Int.Input("batch_size", default=1, min=1, max=4096), ], outputs=[ io.Latent.Output(is_output_list=True), ], ) @staticmethod def get_batch(latents, list_ind, offset): """prepare a batch out of the list of latents""" samples = latents[list_ind]['samples'] shape = samples.shape mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]: torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") if mask.shape[0] < samples.shape[0]: mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] if 'batch_index' in latents[list_ind]: batch_inds = latents[list_ind]['batch_index'] else: batch_inds = [x+offset for x in range(shape[0])] return samples, mask, batch_inds @staticmethod def get_slices(indexable, num, batch_size): """divides an indexable object into num slices of length batch_size, and a remainder""" slices = [] for i in range(num): slices.append(indexable[i*batch_size:(i+1)*batch_size]) if num * batch_size < len(indexable): return slices, indexable[num * batch_size:] else: return slices, None @staticmethod def slice_batch(batch, num, batch_size): result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] return list(zip(*result)) @staticmethod def cat_batch(batch1, batch2): if batch1[0] is None: return batch2 result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] return result @classmethod def execute(cls, latents, batch_size): batch_size = batch_size[0] output_list = [] current_batch = (None, None, None) processed = 0 for i in range(len(latents)): # fetch new entry of list #samples, masks, indices = self.get_batch(latents, i) next_batch = cls.get_batch(latents, i, processed) processed += len(next_batch[2]) # set to current if current is None if current_batch[0] is None: current_batch = next_batch # add previous to list if dimensions do not match elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) current_batch = next_batch # cat if everything checks out else: current_batch = cls.cat_batch(current_batch, next_batch) # add to list if dimensions gone above target batch size if current_batch[0].shape[0] > batch_size: num = current_batch[0].shape[0] // batch_size sliced, remainder = cls.slice_batch(current_batch, num, batch_size) for i in range(num): output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) current_batch = remainder #add remainder if current_batch[0] is not None: sliced, _ = cls.slice_batch(current_batch, 1, batch_size) output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) #get rid of empty masks for s in output_list: if s['noise_mask'].mean() == 1.0: del s['noise_mask'] return io.NodeOutput(output_list) NODES_LIST = [ ImageRebatch, LatentRebatch, ]