WIP way to support multi multi dimensional latents. (#10456)

This commit is contained in:
comfyanonymous
2025-10-23 18:21:14 -07:00
committed by GitHub
parent a1864c01f2
commit 1bcda6df98
5 changed files with 158 additions and 15 deletions

View File

@@ -197,8 +197,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep

91
comfy/nested_tensor.py Normal file
View File

@@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@@ -4,13 +4,9 @@ import comfy.samplers
import comfy.utils
import numpy as np
import logging
import comfy.nested_tensor
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -22,9 +18,28 @@ def prepare_noise(latent_image, seed, noise_inds=None):
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = comfy.nested_tensor.NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
return noises
def fix_empty_latent_channels(model, latent_image):
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)

View File

@@ -782,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
#make sure each cond area has an opposite one with the same area
for k in conds:
@@ -962,11 +962,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -980,7 +980,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@@ -994,7 +994,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
@@ -1007,6 +1007,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@@ -1026,7 +1032,7 @@ class CFGGuider:
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options
@@ -1034,6 +1040,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output

View File

@@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out
def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes
def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors