From 3aaabb12d422eb35cd0314a09582c0a47d505a37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Jan 2025 05:14:10 -0500 Subject: [PATCH 01/11] Implement Cosmos Image/Video to World (Video) diffusion models. Use CosmosImageToVideoLatent to set the input image/video. --- comfy/model_base.py | 29 ++++++++++++++++++++---- comfy/model_detection.py | 5 ++-- comfy/samplers.py | 2 +- comfy/sd.py | 2 +- comfy/supported_models.py | 14 ++++++++++-- comfy_extras/nodes_cosmos.py | 44 +++++++++++++++++++++++++++++++++++- 6 files changed, 84 insertions(+), 12 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index a67504cbb..7625b7126 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -189,9 +189,10 @@ class BaseModel(torch.nn.Module): if denoise_mask is not None: if len(denoise_mask.shape) == len(noise.shape): - denoise_mask = denoise_mask[:,:1] + denoise_mask = denoise_mask[:, :1] - denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + num_dim = noise.ndim - 2 + denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:])) if denoise_mask.shape[-2:] != noise.shape[-2:]: denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) @@ -201,12 +202,16 @@ class BaseModel(torch.nn.Module): if ck == "mask": cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space + elif ck == "mask_inverted": + cond_concat.append(1.0 - denoise_mask.to(device)) else: if ck == "mask": - cond_concat.append(torch.ones_like(noise)[:,:1]) + cond_concat.append(torch.ones_like(noise)[:, :1]) elif ck == "masked_image": cond_concat.append(self.blank_inpaint_image_like(noise)) + elif ck == "mask_inverted": + cond_concat.append(torch.zeros_like(noise)[:, :1]) data = torch.cat(cond_concat, dim=1) return data return None @@ -294,6 +299,9 @@ class BaseModel(torch.nn.Module): return blank_image self.blank_inpaint_image_like = blank_inpaint_image_like + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image) + def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): dtype = self.get_dtype() @@ -859,8 +867,11 @@ class HunyuanVideo(BaseModel): return out class CosmosVideo(BaseModel): - def __init__(self, model_config, model_type=ModelType.EDM, device=None): + def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -873,3 +884,11 @@ class CosmosVideo(BaseModel): out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) return out + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + sigma_noise_augmentation = 0 #TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 20cd6bb86..ba96ebe85 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -245,13 +245,14 @@ def detect_unet_config(state_dict, key_prefix): dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 - dit_config["in_channels"] = 16 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0] dit_config["block_config"] = "FA-CA-MLP" - dit_config["concat_padding_mask"] = True + dit_config["concat_padding_mask"] = concat_padding_mask dit_config["pos_emb_cls"] = "rope3d" dit_config["pos_emb_learnable"] = False dit_config["pos_emb_interpolation"] = "crop" diff --git a/comfy/samplers.py b/comfy/samplers.py index fa176c6de..8f25935d7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -376,7 +376,7 @@ class KSamplerX0Inpaint: if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask - x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask + x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask diff --git a/comfy/sd.py b/comfy/sd.py index 7db1c2d60..6ba6af474 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -534,7 +534,7 @@ class VAE: def encode(self, pixel_samples): pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) - if self.latent_dim == 3: + if self.latent_dim == 3 and pixel_samples.ndim < 5: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 31de1ae9e..ff3f14329 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -824,9 +824,10 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) -class Cosmos(supported_models_base.BASE): +class CosmosT2V(supported_models_base.BASE): unet_config = { "image_model": "cosmos", + "in_channels": 16, } sampling_settings = { @@ -854,7 +855,16 @@ class Cosmos(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class CosmosI2V(CosmosT2V): + unet_config = { + "image_model": "cosmos", + "in_channels": 17, + } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos] + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosVideo(self, image_to_video=True, device=device) + return out + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index d88773e25..5fbabb9a7 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -1,6 +1,8 @@ import nodes import torch import comfy.model_management +import comfy.utils + class EmptyCosmosLatentVideo: @classmethod @@ -16,8 +18,48 @@ class EmptyCosmosLatentVideo: def generate(self, width, height, length, batch_size=1): latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return ({"samples":latent}, ) + return ({"samples": latent}, ) + + +class CosmosImageToVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE", ), + "image": ("IMAGE", ), + "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, image, width, height, length, batch_size): + pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + pixel_len = min(pixels.shape[0], length) + padded_length = min(length, (((pixel_len - 1) // 8) + 2) * 8 - 7) + padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5 + padded_pixels[:pixel_len] = pixels[:pixel_len] + + latent_temp = vae.encode(padded_pixels) + + latent = torch.zeros([1, latent_temp.shape[1], ((length - 1) // 8) + 1, latent_temp.shape[-2], latent_temp.shape[-1]], device=comfy.model_management.intermediate_device()) + latent_len = ((pixel_len - 1) // 8) + 1 + latent[:, :, :latent_len] = latent_temp[:, :, :latent_len] + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + mask[:, :, :latent_len] *= 0.0 + + out_latent = {} + out_latent["samples"] = latent + out_latent["noise_mask"] = mask + return (out_latent,) + NODE_CLASS_MAPPINGS = { "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, + "CosmosImageToVideoLatent": CosmosImageToVideoLatent, } From c78a45685d2664e03927b8b57bc2f950c47d6ad3 Mon Sep 17 00:00:00 2001 From: Pam <42671363+pamparamm@users.noreply.github.com> Date: Wed, 15 Jan 2025 04:20:06 +0500 Subject: [PATCH 02/11] Rewrite res_multistep sampler and implement res_multistep_cfg_pp sampler. (#6462) --- comfy/k_diffusion/res.py | 258 ---------------------------------- comfy/k_diffusion/sampling.py | 71 ++++++++-- comfy/samplers.py | 2 +- 3 files changed, 63 insertions(+), 268 deletions(-) delete mode 100644 comfy/k_diffusion/res.py diff --git a/comfy/k_diffusion/res.py b/comfy/k_diffusion/res.py deleted file mode 100644 index 6caedec39..000000000 --- a/comfy/k_diffusion/res.py +++ /dev/null @@ -1,258 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copied from Nvidia Cosmos code. - -import torch -from torch import Tensor -from typing import Callable, List, Tuple, Optional, Any -import math -from tqdm.auto import trange - - -def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: - ndims1 = x.ndim - ndims2 = y.ndim - - if ndims1 < ndims2: - x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) - elif ndims2 < ndims1: - y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) - - return x, y - - -def batch_mul(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x * y - - -def phi1(t: torch.Tensor) -> torch.Tensor: - """ - Compute the first order phi function: (exp(t) - 1) / t. - - Args: - t: Input tensor. - - Returns: - Tensor: Result of phi1 function. - """ - input_dtype = t.dtype - t = t.to(dtype=torch.float32) - return (torch.expm1(t) / t).to(dtype=input_dtype) - - -def phi2(t: torch.Tensor) -> torch.Tensor: - """ - Compute the second order phi function: (phi1(t) - 1) / t. - - Args: - t: Input tensor. - - Returns: - Tensor: Result of phi2 function. - """ - input_dtype = t.dtype - t = t.to(dtype=torch.float32) - return ((phi1(t) - 1.0) / t).to(dtype=input_dtype) - - -def res_x0_rk2_step( - x_s: torch.Tensor, - t: torch.Tensor, - s: torch.Tensor, - x0_s: torch.Tensor, - s1: torch.Tensor, - x0_s1: torch.Tensor, -) -> torch.Tensor: - """ - Perform a residual-based 2nd order Runge-Kutta step. - - Args: - x_s: Current state tensor. - t: Target time tensor. - s: Current time tensor. - x0_s: Prediction at current time. - s1: Intermediate time tensor. - x0_s1: Prediction at intermediate time. - - Returns: - Tensor: Updated state tensor. - - Raises: - AssertionError: If step size is too small. - """ - s = -torch.log(s) - t = -torch.log(t) - m = -torch.log(s1) - - dt = t - s - assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" - assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small" - - c2 = (m - s) / dt - phi1_val, phi2_val = phi1(-dt), phi2(-dt) - - # Handle edge case where t = s = m - b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) - b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) - - return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1)) - - -def reg_x0_euler_step( - x_s: torch.Tensor, - s: torch.Tensor, - t: torch.Tensor, - x0_s: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Perform a regularized Euler step based on x0 prediction. - - Args: - x_s: Current state tensor. - s: Current time tensor. - t: Target time tensor. - x0_s: Prediction at current time. - - Returns: - Tuple[Tensor, Tensor]: Updated state tensor and current prediction. - """ - coef_x0 = (s - t) / s - coef_xs = t / s - return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s - - -def order2_fn( - x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor -) -> Tuple[torch.Tensor, List[torch.Tensor]]: - """ - impl the second order multistep method in https://arxiv.org/pdf/2308.02157 - Adams Bashforth approach! - """ - if x0_preds: - x0_s1, s1 = x0_preds[0] - x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1) - else: - x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0] - return x_t, [(x0_s, s)] - - -class SolverConfig: - is_multi: bool = True - rk: str = "2mid" - multistep: str = "2ab" - s_churn: float = 0.0 - s_t_max: float = float("inf") - s_t_min: float = 0.0 - s_noise: float = 1.0 - - -def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any: - """ - Implements a for loop with a function. - - Args: - lower: Lower bound of the loop (inclusive). - upper: Upper bound of the loop (exclusive). - body_fun: Function to be applied in each iteration. - init_val: Initial value for the loop. - - Returns: - The final result after all iterations. - """ - val = init_val - for i in trange(lower, upper, disable=disable): - val = body_fun(i, val) - return val - - -def differential_equation_solver( - x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - sigmas_L: torch.Tensor, - solver_cfg: SolverConfig, - noise_sampler, - callback=None, - disable=None, -) -> Callable[[torch.Tensor], torch.Tensor]: - """ - Creates a differential equation solver function. - - Args: - x0_fn: Function to compute x0 prediction. - sigmas_L: Tensor of sigma values with shape [L,]. - solver_cfg: Configuration for the solver. - - Returns: - A function that solves the differential equation. - """ - num_step = len(sigmas_L) - 1 - - # if solver_cfg.is_multi: - # update_step_fn = get_multi_step_fn(solver_cfg.multistep) - # else: - # update_step_fn = get_runge_kutta_fn(solver_cfg.rk) - update_step_fn = order2_fn - - eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1) - - def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor: - """ - Samples from the differential equation. - - Args: - input_xT_B_StateShape: Input tensor with shape [B, StateShape]. - - Returns: - Output tensor with shape [B, StateShape]. - """ - ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32) - - def step_fn( - i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]] - ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: - input_x_B_StateShape, x0_preds = state - sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1] - - if sigma_next_0 == 0: - output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) - else: - # algorithm 2: line 4-6 - if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0: - hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0 - input_x_B_StateShape = input_x_B_StateShape + ( - hat_sigma_cur_0**2 - sigma_cur_0**2 - ).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0) # torch.randn_like(input_x_B_StateShape) - sigma_cur_0 = hat_sigma_cur_0 - - if solver_cfg.is_multi: - x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B) - output_x_B_StateShape, x0_preds = update_step_fn( - input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds - ) - else: - output_x_B_StateShape, x0_preds = update_step_fn( - input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn - ) - - if callback is not None: - callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_cur_0, 'denoised': x0_pred_B_StateShape}) - - return output_x_B_StateShape, x0_preds - - x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable) - return x_at_eps - - return sample_fn diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 3a98e6a7c..13ae272fd 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -8,7 +8,6 @@ from tqdm.auto import trange, tqdm from . import utils from . import deis -from . import res import comfy.model_patcher import comfy.model_sampling @@ -1268,18 +1267,72 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis return x @torch.no_grad() -def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): +def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None, cfg_pp=False): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + phi1_fn = lambda t: torch.expm1(t) / t + phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t - x0_func = lambda x, sigma: model(x, sigma, **extra_args) + old_denoised = None + uncond_denoised = None + def post_cfg_function(args): + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] + return args["denoised"] - solver_cfg = res.SolverConfig() - solver_cfg.s_churn = s_churn - solver_cfg.s_t_max = s_tmax - solver_cfg.s_t_min = s_tmin - solver_cfg.s_noise = s_noise + if cfg_pp: + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) - x = res.differential_equation_solver(x0_func, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable)(x) + for i in trange(len(sigmas) - 1, disable=disable): + if s_churn > 0: + gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + sigma_hat = sigmas[i] * (gamma + 1) + else: + gamma = 0 + sigma_hat = sigmas[i] + + if gamma > 0: + eps = torch.randn_like(x) * s_noise + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + if callback is not None: + callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}) + if sigmas[i + 1] == 0 or old_denoised is None: + # Euler method + if cfg_pp: + d = to_d(x, sigma_hat, uncond_denoised) + x = denoised + d * sigmas[i + 1] + else: + d = to_d(x, sigma_hat, denoised) + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + else: + # Second order multistep method in https://arxiv.org/pdf/2308.02157 + t, t_next, t_prev = t_fn(sigmas[i]), t_fn(sigmas[i + 1]), t_fn(sigmas[i - 1]) + h = t_next - t + c2 = (t_prev - t) / h + + phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h) + b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0) + b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0) + + if cfg_pp: + x = x + (denoised - uncond_denoised) + + x = (sigma_fn(t_next) / sigma_fn(t)) * x + h * (b1 * denoised + b2 * old_denoised) + + old_denoised = denoised return x + +@torch.no_grad() +def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): + return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=False) + +@torch.no_grad() +def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None): + return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise, noise_sampler=noise_sampler, cfg_pp=True) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8f25935d7..c508a3a41 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -687,7 +687,7 @@ class Sampler: KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", - "ipndm", "ipndm_v", "deis", "res_multistep"] + "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): From 2cdbaf5169a631b126542f41432f2484c4f0a608 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 14 Jan 2025 19:05:45 -0500 Subject: [PATCH 03/11] Add SetFirstSigma node (#6459) Useful for models utilizing ztSNR. See: https://arxiv.org/abs/2409.15997 --- comfy_extras/nodes_custom_sampler.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c7ff9a4d8..576fc3b2c 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -231,6 +231,24 @@ class FlipSigmas: sigmas[0] = 0.0001 return (sigmas,) +class SetFirstSigma: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "set_first_sigma" + + def set_first_sigma(self, sigmas, sigma): + sigmas = sigmas.clone() + sigmas[0] = sigma + return (sigmas, ) + class KSamplerSelect: @classmethod def INPUT_TYPES(s): @@ -710,6 +728,7 @@ NODE_CLASS_MAPPINGS = { "SplitSigmas": SplitSigmas, "SplitSigmasDenoise": SplitSigmasDenoise, "FlipSigmas": FlipSigmas, + "SetFirstSigma": SetFirstSigma, "CFGGuider": CFGGuider, "DualCFGGuider": DualCFGGuider, From 5b657f8c15a2cc437049dcbfc10eb268fb0194d4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 00:41:35 -0500 Subject: [PATCH 04/11] Allow setting start and end image in CosmosImageToVideoLatent. --- comfy_extras/nodes_cosmos.py | 47 ++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 5fbabb9a7..b76ff950b 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -21,37 +21,54 @@ class EmptyCosmosLatentVideo: return ({"samples": latent}, ) +def vae_encode_with_padding(vae, image, width, height, length, padding=0): + pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + pixel_len = min(pixels.shape[0], length) + padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7) + padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5 + padded_pixels[:pixel_len] = pixels[:pixel_len] + latent_len = ((pixel_len - 1) // 8) + 1 + latent_temp = vae.encode(padded_pixels) + return latent_temp[:, :, :latent_len] + + class CosmosImageToVideoLatent: @classmethod def INPUT_TYPES(s): return {"required": {"vae": ("VAE", ), - "image": ("IMAGE", ), "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} + }, + "optional": {"start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "conditioning/inpaint" - def encode(self, vae, image, width, height, length, batch_size): - pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - pixel_len = min(pixels.shape[0], length) - padded_length = min(length, (((pixel_len - 1) // 8) + 2) * 8 - 7) - padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5 - padded_pixels[:pixel_len] = pixels[:pixel_len] - - latent_temp = vae.encode(padded_pixels) - - latent = torch.zeros([1, latent_temp.shape[1], ((length - 1) // 8) + 1, latent_temp.shape[-2], latent_temp.shape[-1]], device=comfy.model_management.intermediate_device()) - latent_len = ((pixel_len - 1) // 8) + 1 - latent[:, :, :latent_len] = latent_temp[:, :, :latent_len] + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) - mask[:, :, :latent_len] *= 0.0 + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 out_latent = {} out_latent["samples"] = latent From 2feb8d0b77ce80a471ecab84b92d5bbcaa37f8fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 03:50:27 -0500 Subject: [PATCH 05/11] Force safe loading of files in torch format on pytorch 2.4+ If this breaks something for you make an issue. --- comfy/utils.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index b486b2deb..bcefa1804 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,17 +29,29 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange +ALWAYS_SAFE_LOAD = False +if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated + class ModelCheckpoint: + pass + ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" + + from numpy.core.multiarray import scalar + from numpy import dtype + from numpy.dtypes import Float64DType + from _codecs import encode + + torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) + ALWAYS_SAFE_LOAD = True + logging.info("Checkpoint files will always be loaded safely.") + + def load_torch_file(ckpt, safe_load=False, device=None): if device is None: device = torch.device("cpu") if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): sd = safetensors.torch.load_file(ckpt, device=device.type) else: - if safe_load: - if not 'weights_only' in torch.load.__code__.co_varnames: - logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") - safe_load = False - if safe_load: + if safe_load or ALWAYS_SAFE_LOAD: pl_sd = torch.load(ckpt, map_location=device, weights_only=True) else: pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) From cba58fff0bfebfc81fbe678bb80491890a3df14a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 04:32:23 -0500 Subject: [PATCH 06/11] Remove unsafe embedding load for very old pytorch. --- comfy/sd1_clip.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 95d41c30f..85518afd9 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -388,13 +388,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No import safetensors.torch embed = safetensors.torch.load_file(embed_path, device="cpu") else: - if 'weights_only' in torch.load.__code__.co_varnames: - try: - embed = torch.load(embed_path, weights_only=True, map_location="cpu") - except: - embed_out = safe_load_embed_zip(embed_path) - else: - embed = torch.load(embed_path, map_location="cpu") + try: + embed = torch.load(embed_path, weights_only=True, map_location="cpu") + except: + embed_out = safe_load_embed_zip(embed_path) except Exception: logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name)) return None From 1709a8441e7ad88ead87285b802e429b5ab7aebb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 14:50:40 -0500 Subject: [PATCH 07/11] Use latest python 3.12.8 the portable release. --- .github/workflows/stable-release.yml | 2 +- .github/workflows/windows_release_dependencies.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 0bdd5a3bd..4a5ba58f6 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -22,7 +22,7 @@ on: description: 'Python patch version' required: true type: string - default: "7" + default: "8" jobs: diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 85e6a52fd..6c7937ae2 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "7" + default: "8" # push: # branches: # - master diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 11e724ba7..24f928ee0 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "7" + default: "8" # push: # branches: # - master From 3baf92d120a91842c84a9907883eda64882e90d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 17:19:59 -0500 Subject: [PATCH 08/11] CosmosImageToVideoLatent batch_size now does something. --- comfy_extras/nodes_cosmos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index b76ff950b..bd35ddb06 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -71,8 +71,8 @@ class CosmosImageToVideoLatent: mask[:, :, -latent_temp.shape[-3]:] *= 0.0 out_latent = {} - out_latent["samples"] = latent - out_latent["noise_mask"] = mask + out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) return (out_latent,) From 2e20e399ea6d9fad5f0e40f987d96088f052b74c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:19:56 -0500 Subject: [PATCH 09/11] Add minimum numpy version to requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 4c2c0b2b2..3bc945a1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch torchsde torchvision torchaudio +numpy>=1.25.0 einops transformers>=4.28.1 tokenizers>=0.13.3 From 55ade36d01fd4bf3c1ba7238a06a5fa386597124 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:24:55 -0500 Subject: [PATCH 10/11] Remove python 3.8 from test-build workflow. --- .github/workflows/test-build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 444d6b254..419873ad8 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -28,4 +28,4 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt \ No newline at end of file + pip install -r requirements.txt From bfd5dfd6111d4133b305b8174c71b224a780b6e3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:32:44 -0500 Subject: [PATCH 11/11] 3.13 doesn't work yet. --- .github/workflows/test-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..865e1ec25 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }}