diff --git a/.ci/nightly/update_windows/update.py b/.ci/nightly/update_windows/update.py deleted file mode 100755 index c09f29a80..000000000 --- a/.ci/nightly/update_windows/update.py +++ /dev/null @@ -1,65 +0,0 @@ -import pygit2 -from datetime import datetime -import sys - -def pull(repo, remote_name='origin', branch='master'): - for remote in repo.remotes: - if remote.name == remote_name: - remote.fetch() - remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target - merge_result, _ = repo.merge_analysis(remote_master_id) - # Up to date, do nothing - if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: - return - # We can just fastforward - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: - repo.checkout_tree(repo.get(remote_master_id)) - try: - master_ref = repo.lookup_reference('refs/heads/%s' % (branch)) - master_ref.set_target(remote_master_id) - except KeyError: - repo.create_branch(branch, repo.get(remote_master_id)) - repo.head.set_target(remote_master_id) - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL: - repo.merge(remote_master_id) - - if repo.index.conflicts is not None: - for conflict in repo.index.conflicts: - print('Conflicts found in:', conflict[0].path) - raise AssertionError('Conflicts, ahhhhh!!') - - user = repo.default_signature - tree = repo.index.write_tree() - commit = repo.create_commit('HEAD', - user, - user, - 'Merge!', - tree, - [repo.head.target, remote_master_id]) - # We need to do this or git CLI will think we are still merging. - repo.state_cleanup() - else: - raise AssertionError('Unknown merge analysis result') - - -repo = pygit2.Repository(str(sys.argv[1])) -ident = pygit2.Signature('comfyui', 'comfy@ui') -try: - print("stashing current changes") - repo.stash(ident) -except KeyError: - print("nothing to stash") -backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) -print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) - -print("checking out master branch") -branch = repo.lookup_branch('master') -ref = repo.lookup_reference(branch.name) -repo.checkout(ref) - -print("pulling latest changes") -pull(repo) - -print("Done!") - diff --git a/.ci/nightly/update_windows/update_comfyui.bat b/.ci/nightly/update_windows/update_comfyui.bat deleted file mode 100755 index 60d1e694f..000000000 --- a/.ci/nightly/update_windows/update_comfyui.bat +++ /dev/null @@ -1,2 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c5e0c6be7..b4989534f 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt deleted file mode 100755 index 656b9db43..000000000 --- a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt +++ /dev/null @@ -1,27 +0,0 @@ -HOW TO RUN: - -if you have a NVIDIA gpu: - -run_nvidia_gpu.bat - - - -To run it in slow CPU mode: - -run_cpu.bat - - - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt - - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - - -To update ComfyUI with the python dependencies: -update\update_comfyui_and_python_dependencies.bat diff --git a/.ci/nightly/windows_base_files/run_cpu.bat b/.ci/nightly/windows_base_files/run_cpu.bat deleted file mode 100755 index c3ba41721..000000000 --- a/.ci/nightly/windows_base_files/run_cpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build -pause diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml index a88449527..42adee9e7 100644 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -17,7 +17,7 @@ jobs: - shell: bash run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 291d754e3..f23cae6d5 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -46,6 +46,8 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ + cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ cd .. diff --git a/README.md b/README.md index bf16006bf..3b3824714 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) +### [Installing ComfyUI](#installing) + ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Fully supports SD1.x and SD2.x @@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/) - Loading full workflows (with seeds) from generated PNG files. - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b24054ce0..764427165 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index e96cfc93a..2952be62d 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, + atol=0.0078, rtol=0.05, corrector=False, callback=None ): t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start @@ -766,6 +766,8 @@ class UniPC: if model_x is None: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x + if callback is not None: + callback(step_index, model_prev_list[-1], x) else: raise NotImplementedError() if denoise_to_zero: @@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True) + x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 98dbda635..ce7180d91 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads query = self.to_q(x) context = default(context, x) key = self.to_k(context) - value = self.to_v(context) + if value is not None: + value = self.to_v(value) + else: + value = self.to_v(context) + del context, x query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) @@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) - v_in = self.to_v(context) + if value is not None: + v_in = self.to_v(value) + del value + else: + v_in = self.to_v(context) del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) @@ -350,13 +358,17 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -447,19 +463,19 @@ class CrossAttentionPytorch(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), + lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), (q, k, v), ) @@ -468,10 +484,7 @@ class CrossAttentionPytorch(nn.Module): if exists(mask): raise NotImplementedError out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) + out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) ) return self.to_out(out) @@ -519,11 +532,25 @@ class BasicTransformerBlock(nn.Module): transformer_patches = {} n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: - n = self.attn1(n, context=context if self.disable_self_attn else None) + n = self.attn1(n, context=context_attn1, value=value_attn1) x += n if "middle_patch" in transformer_patches: @@ -532,7 +559,16 @@ class BasicTransformerBlock(nn.Module): x = p(current_index, x) n = self.norm2(x) - n = self.attn2(n, context=context) + + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + + n = self.attn2(n, context=context_attn2, value=value_attn2) x += n x = self.ff(self.norm3(x)) + x diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d1313d2..9497ae7af 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -20,6 +20,18 @@ total_vram_available_mb = -1 accelerate_enabled = False xpu_available = False +directml_enabled = False +if args.directml is not None: + import torch_directml + directml_enabled = True + device_index = args.directml + if device_index < 0: + directml_device = torch_directml.device() + else: + directml_device = torch_directml.device(device_index) + print("Using directml with device:", torch_directml.device_name(device_index)) + # torch_directml.disable_tiled_resources(True) + try: import torch try: @@ -133,6 +145,7 @@ def unload_model(): #never unload models from GPU on high vram if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() + current_loaded_model.model_patches_to("cpu") current_loaded_model.unpatch_model() current_loaded_model = None @@ -156,6 +169,8 @@ def load_model_gpu(model): except Exception as e: model.unpatch_model() raise e + + model.model_patches_to(get_torch_device()) current_loaded_model = model if vram_state == VRAMState.CPU: pass @@ -214,6 +229,10 @@ def unload_if_low_vram(model): def get_torch_device(): global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: @@ -231,8 +250,14 @@ def get_autocast_device(dev): def xformers_enabled(): + global xpu_available + global directml_enabled if vram_state == VRAMState.CPU: return False + if xpu_available: + return False + if directml_enabled: + return False return XFORMERS_IS_AVAILABLE @@ -248,6 +273,7 @@ def pytorch_attention_enabled(): def get_free_memory(dev=None, torch_free_too=False): global xpu_available + global directml_enabled if dev is None: dev = get_torch_device() @@ -255,7 +281,10 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - if xpu_available: + if directml_enabled: + mem_free_total = 1024 * 1024 * 1024 #TODO + mem_free_torch = mem_free_total + elif xpu_available: mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_torch = mem_free_total else: @@ -290,9 +319,14 @@ def mps_mode(): def should_use_fp16(): global xpu_available + global directml_enabled + if FORCE_FP32: return False + if directml_enabled: + return False + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? diff --git a/comfy/sample.py b/comfy/sample.py new file mode 100644 index 000000000..f4132bbed --- /dev/null +++ b/comfy/sample.py @@ -0,0 +1,83 @@ +import torch +import comfy.model_management +import comfy.samplers +import math + +def prepare_noise(latent_image, seed, skip=0): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.manual_seed(seed) + for _ in range(skip): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + return noise + +def prepare_mask(noise_mask, shape, device): + """ensures noise mask is of proper dimensions""" + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + if noise_mask.shape[0] < shape[0]: + noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]] + noise_mask = noise_mask.to(device) + return noise_mask + +def broadcast_cond(cond, batch, device): + """broadcasts conditioning to the batch size""" + copy = [] + for p in cond: + t = p[0] + if t.shape[0] < batch: + t = torch.cat([t] * batch) + t = t.to(device) + copy += [[t] + p[1:]] + return copy + +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c[1]: + models += [c[1][model_type]] + return models + +def load_additional_models(positive, negative): + """loads additional models in positive and negative conditioning""" + control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") + gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") + gligen = [x[1] for x in gligen] + models = control_nets + gligen + comfy.model_management.load_controlnet_gpu(models) + return models + +def cleanup_additional_models(models): + """cleanup additional models that were loaded""" + for m in models: + m.cleanup() + +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None): + device = comfy.model_management.get_torch_device() + + if noise_mask is not None: + noise_mask = prepare_mask(noise_mask, noise.shape, device) + + real_model = None + comfy.model_management.load_model_gpu(model) + real_model = model.model + + noise = noise.to(device) + latent_image = latent_image.to(device) + + positive_copy = broadcast_cond(positive, noise.shape[0], device) + negative_copy = broadcast_cond(negative, noise.shape[0], device) + + models = load_additional_models(positive, negative) + + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) + + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) + samples = samples.cpu() + + cleanup_additional_models(models) + return samples diff --git a/comfy/samplers.py b/comfy/samplers.py index 19ebc97d9..fc19ddcfc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,23 +7,6 @@ from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - else: - cond = self.inner_model(x, sigma, cond=cond) - uncond = self.inner_model(x, sigma, cond=uncond) - return uncond + (cond - uncond) * cond_scale - - #The main sampling function shared by all the samplers #Returns predicted noise def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): @@ -214,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con transformer_options = model_options['transformer_options'].copy() if patches is not None: - transformer_options["patches"] = patches + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches c['transformer_options'] = transformer_options @@ -438,7 +429,7 @@ class KSampler: self.denoise = denoise self.model_options = model_options - def _calculate_sigmas(self, steps): + def calculate_sigmas(self, steps): sigmas = None discard_penultimate_sigma = False @@ -447,13 +438,13 @@ class KSampler: discard_penultimate_sigma = True if self.scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": - sigmas = self.model_wrap.get_sigmas(steps).to(self.device) + sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": - sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) + sigmas = simple_scheduler(self.model_wrap, steps) elif self.scheduler == "ddim_uniform": - sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) + sigmas = ddim_scheduler(self.model_wrap, steps) else: print("error invalid scheduler", self.scheduler) @@ -464,15 +455,16 @@ class KSampler: def set_steps(self, steps, denoise=None): self.steps = steps if denoise is None or denoise > 0.9999: - self.sigmas = self._calculate_sigmas(steps) + self.sigmas = self.calculate_sigmas(steps).to(self.device) else: new_steps = int(steps/denoise) - sigmas = self._calculate_sigmas(new_steps) + sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): - sigmas = self.sigmas + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): + if sigmas is None: + sigmas = self.sigmas sigma_min = self.sigma_min if last_step is not None and last_step < (len(sigmas) - 1): @@ -535,9 +527,9 @@ class KSampler: with precision_scope(model_management.get_autocast_device(self.device)): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2') elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): @@ -545,6 +537,11 @@ class KSampler: noise_mask = None if denoise_mask is not None: noise_mask = 1.0 - denoise_mask + + ddim_callback = None + if callback is not None: + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) @@ -558,6 +555,7 @@ class KSampler: eta=0.0, x_T=z_enc, x0=latent_image, + img_callback=ddim_callback, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, @@ -571,13 +569,17 @@ class KSampler: noise = noise * sigmas[0] + k_callback = None + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback) elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args) + samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) return samples.to(torch.float32) diff --git a/comfy/sd.py b/comfy/sd.py index 211acd70e..92dbb931d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -254,6 +254,29 @@ class ModelPatcher: def set_model_sampler_cfg_function(self, sampler_cfg_function): self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + def model_dtype(self): return self.model.diffusion_model.dtype diff --git a/comfy/utils.py b/comfy/utils.py index 0380b91dd..68f93403c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,14 @@ import torch -def load_torch_file(ckpt): +def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: - pl_sd = torch.load(ckpt, map_location="cpu") + if safe_load: + pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 1abe1ed8f..214642cc4 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -4,7 +4,10 @@ from __future__ import annotations from collections import OrderedDict -from typing import Literal +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 000000000..0c7250e43 --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,109 @@ +import comfy.utils +import folder_paths +import torch + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + valid_activation = { + "linear": torch.nn.Identity, + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, + } + + if activation_func not in valid_activation: + print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = list(map(lambda a: a[:-len(".weight")], linears)) + layers = [] + + for i in range(len(linears)): + lin_name = linears[i] + last_layer = (i == (len(linears) - 1)) + penultimate_layer = (i == (len(linears) - 2)) + + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) + layers.append(layer) + if activation_func != "linear": + if (not last_layer) or (activate_output): + layers.append(valid_activation[activation_func]()) + if is_layer_norm: + layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + if use_dropout: + if (not last_layer) and (not penultimate_layer or last_layer_dropout): + layers.append(torch.nn.Dropout(p=0.3)) + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, current_index, q, k, v): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "loaders" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/execution.py b/execution.py index 73be6db03..c19c10bc6 100644 --- a/execution.py +++ b/execution.py @@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = unique_id return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data={}): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return [] - - executed = [] + return for x in inputs: input_data = inputs[x] @@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) if "result" in outputs[unique_id]: outputs[unique_id] = outputs[unique_id]["result"] - return executed + [unique_id] + executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item @@ -99,40 +97,44 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item is_changed_old = '' is_changed = '' + to_delete = False if hasattr(class_def, 'IS_CHANGED'): if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: - is_changed = class_def.IS_CHANGED(**input_data_all) - prompt[unique_id]['is_changed'] = is_changed + try: + is_changed = class_def.IS_CHANGED(**input_data_all) + prompt[unique_id]['is_changed'] = is_changed + except: + to_delete = True else: is_changed = prompt[unique_id]['is_changed'] if unique_id not in outputs: return True - to_delete = False - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] + if not to_delete: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: + to_delete = True + elif inputs == old_prompt[unique_id]['inputs']: + for x in inputs: + input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id in outputs: + to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) + else: + to_delete = True + if to_delete: + break + else: + to_delete = True if to_delete: d = outputs.pop(unique_id) @@ -154,11 +156,20 @@ class PromptExecutor: self.server.client_id = None with torch.inference_mode(): + #delete cached outputs if nodes don't exist for them + to_delete = [] + for o in self.outputs: + if o not in prompt: + to_delete += [o] + for o in to_delete: + d = self.outputs.pop(o) + del d + for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) - executed = [] + executed = set() try: to_execute = [] for x in prompt: @@ -181,12 +192,12 @@ class PromptExecutor: except: valid = False if valid: - executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: print(traceback.format_exc()) to_delete = [] for o in self.outputs: - if o not in current_outputs: + if (o not in current_outputs) and (o not in executed): to_delete += [o] if o in self.old_prompt: d = self.old_prompt.pop(o) @@ -194,11 +205,9 @@ class PromptExecutor: for o in to_delete: d = self.outputs.pop(o) del d - else: - executed = set(executed) + finally: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) - finally: self.server.last_node_id = None if self.server.client_id is not None: self.server.send_sync("executing", { "node": None }, self.server.client_id) @@ -249,9 +258,15 @@ def validate_inputs(prompt, item): if "max" in info[1] and val > info[1]["max"]: return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + ret = obj_class.VALIDATE_INPUTS(**input_data_all) + if ret != True: + return (False, "{}, {}".format(class_type, ret)) + else: + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): @@ -273,7 +288,8 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o) valid = m[0] reason = m[1] - except: + except Exception as e: + print(traceback.format_exc()) valid = False reason = "Parsing error" diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index ac1ffe9d2..fa5418a68 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -13,6 +13,7 @@ a111: models/ESRGAN models/SwinIR embeddings: embeddings + hypernetworks: models/hypernetworks controlnet: models/ControlNet #other_ui: diff --git a/folder_paths.py b/folder_paths.py index 3c4ad3711..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) +folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @@ -68,6 +69,46 @@ def get_directory_by_type(type_name): return None +# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format +# otherwise use default_path as base_dir +def annotated_filepath(name): + if name.endswith("[output]"): + base_dir = get_output_directory() + name = name[:-9] + elif name.endswith("[input]"): + base_dir = get_input_directory() + name = name[:-8] + elif name.endswith("[temp]"): + base_dir = get_temp_directory() + name = name[:-7] + else: + return name, None + + return name, base_dir + + +def get_annotated_filepath(name, default_dir=None): + name, base_dir = annotated_filepath(name) + + if base_dir is None: + if default_dir is not None: + base_dir = default_dir + else: + base_dir = get_input_directory() # fallback path + + return os.path.join(base_dir, name) + + +def exists_annotated_filepath(name): + name, base_dir = annotated_filepath(name) + + if base_dir is None: + base_dir = get_input_directory() # fallback path + + filepath = os.path.join(base_dir, name) + return os.path.exists(filepath) + + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths if folder_name in folder_names_and_paths: diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 4039a88cf..8e6825292 100644 --- a/nodes.py +++ b/nodes.py @@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co import comfy.diffusers_convert import comfy.samplers +import comfy.sample import comfy.sd import comfy.utils @@ -203,24 +204,24 @@ class VAEEncodeForInpaint: def encode(self, vae, pixels, mask): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 - mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - mask = mask[:x,:y] + mask = mask[:,:,:x,:y] #grow mask by a few pixels to keep things seamless in latent space kernel_tensor = torch.ones((1, 1, 6, 6)) - mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) - m = (1.0 - mask.round()) + mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) + m = (1.0 - mask.round()).squeeze(1) for i in range(3): pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] *= m pixels[:,:,:,i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) + return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) class CheckpointLoader: @classmethod @@ -771,79 +772,23 @@ class SetLatentNoiseMask: s["noise_mask"] = mask return (s,) - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): - latent_image = latent["samples"] - noise_mask = None device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - - generator = torch.manual_seed(seed) - for i in range(batch_index): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + skip = latent["batch_index"] if "batch_index" in latent else 0 + noise = comfy.sample.prepare_noise(latent_image, seed, skip) + noise_mask = None if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - noise_mask = noise_mask.to(device) - - real_model = None - comfy.model_management.load_model_gpu(model) - real_model = model.model - - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = [] - negative_copy = [] - - control_nets = [] - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - for p in positive: - t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - positive_copy += [[t] + p[1:]] - for n in negative: - t = n[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - negative_copy += [[t] + n[1:]] - - models = get_models(positive) + get_models(negative) - comfy.model_management.load_controlnet_gpu(models) - - if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - else: - #other samplers - pass - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) - samples = samples.cpu() - for m in models: - m.cleanup() + noise_mask = latent["noise_mask"] + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask) out = latent.copy() out["samples"] = samples return (out, ) @@ -1006,8 +951,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -1021,20 +965,27 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + return True + class LoadImageMask: + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() return {"required": {"image": (sorted(os.listdir(input_dir)), ), - "channel": (["alpha", "red", "green", "blue"], ),} + "channel": (s._color_channels, ),} } CATEGORY = "mask" @@ -1042,8 +993,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1060,13 +1010,22 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image, channel): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + if channel not in s._color_channels: + return "Invalid color channel: {}".format(channel) + + return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -1302,6 +1261,7 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c1982d8be..fecfa6707 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" ] }, { diff --git a/server.py b/server.py index b5403670f..1c5c17916 100644 --- a/server.py +++ b/server.py @@ -112,13 +112,20 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = folder_paths.get_input_directory() + post = await request.post() + image = post.get("image") + + if post.get("type") is None: + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "input": + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "temp": + upload_dir = folder_paths.get_temp_directory() + elif post.get("type") == "output": + upload_dir = folder_paths.get_output_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) - - post = await request.post() - image = post.get("image") if image and image.file: filename = image.filename diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index bebc80b12..b937bb103 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -89,24 +89,17 @@ app.registerExtension({ end = nearestEnclosure.end; selectedText = inputField.value.substring(start, end); } else { - // Select the current word, find the start and end of the word (first space before and after) - const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1; - const wordEnd = inputField.value.substring(end).indexOf(" "); - // If there is no space after the word, select to the end of the string - if (wordEnd === -1) { - end = inputField.value.length; - } else { - end += wordEnd; + // Select the current word, find the start and end of the word + const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t"; + + while (!delimiters.includes(inputField.value[start - 1]) && start > 0) { + start--; + } + + while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) { + end++; } - start = wordStart; - // Remove all punctuation at the end and beginning of the word - while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - start++; - } - while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - end--; - } selectedText = inputField.value.substring(start, end); if (!selectedText) return; } @@ -135,8 +128,13 @@ app.registerExtension({ // Increment the weight const weightDelta = event.key === "ArrowUp" ? delta : -delta; - const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { - return prefix + incrementWeight(weight, weightDelta) + suffix; + const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => { + weight = incrementWeight(weight, weightDelta); + if (weight == 1) { + return text; + } else { + return `(${text}:${weight})`; + } }); inputField.setRangeText(updatedText, start, end, "select"); diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 0b6a0a150..3ec605900 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -1,21 +1,72 @@ import { app } from "/scripts/app.js"; - +import { ComfyWidgets } from "/scripts/widgets.js"; // Adds defaults for quickly adding nodes with middle click on the input/output app.registerExtension({ name: "Comfy.SlotDefaults", + suggestionsNumber: null, init() { LiteGraph.middle_click_slot_add_default_node = true; - LiteGraph.slot_types_default_in = { - MODEL: "CheckpointLoaderSimple", - LATENT: "EmptyLatentImage", - VAE: "VAELoader", - }; - - LiteGraph.slot_types_default_out = { - LATENT: "VAEDecode", - IMAGE: "SaveImage", - CLIP: "CLIPTextEncode", - }; + this.suggestionsNumber = app.ui.settings.addSetting({ + id: "Comfy.NodeSuggestions.number", + name: "number of nodes suggestions", + type: "slider", + attrs: { + min: 1, + max: 100, + step: 1, + }, + defaultValue: 5, + onChange: (newVal, oldVal) => { + this.setDefaults(newVal); + } + }); }, + slot_types_default_out: {}, + slot_types_default_in: {}, + async beforeRegisterNodeDef(nodeType, nodeData, app) { + var nodeId = nodeData.name; + var inputs = []; + inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs + for (const inputKey in inputs) { + var input = (inputs[inputKey]); + if (typeof input[0] !== "string") continue; + + var type = input[0] + if (type in ComfyWidgets) { + var customProperties = input[1] + if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input + } + + if (!(type in this.slot_types_default_out)) { + this.slot_types_default_out[type] = ["Reroute"]; + } + if (this.slot_types_default_out[type].includes(nodeId)) continue; + this.slot_types_default_out[type].push(nodeId); + } + + var outputs = nodeData["output"]; + for (const key in outputs) { + var type = outputs[key]; + if (!(type in this.slot_types_default_in)) { + this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() + } + + this.slot_types_default_in[type].push(nodeId); + } + var maxNum = this.suggestionsNumber.value; + this.setDefaults(maxNum); + }, + setDefaults(maxNum) { + + LiteGraph.slot_types_default_out = {}; + LiteGraph.slot_types_default_in = {}; + + for (const type in this.slot_types_default_out) { + LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum); + } + for (const type in this.slot_types_default_in) { + LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum); + } + } }); diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4189a48c0..20ec35476 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9953,11 +9953,11 @@ LGraphNode.prototype.executeAction = function(action) } break; case "slider": - var range = w.options.max - w.options.min; + var old_value = w.value; var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); if(w.options.read_only) break; w.value = w.options.min + (w.options.max - w.options.min) * nvalue; - if (w.callback) { + if (old_value != w.value) { setTimeout(function() { inner_value_change(w, w.value); }, 20); @@ -10044,7 +10044,7 @@ LGraphNode.prototype.executeAction = function(action) if (event.click_time < 200 && delta == 0) { this.prompt("Value",w.value,function(v) { // check if v is a valid equation or a number - if (/^[0-9+\-*/()\s]+$/.test(v)) { + if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) { try {//solve the equation if possible v = eval(v); } catch (e) { } diff --git a/web/scripts/api.js b/web/scripts/api.js index 2b90c2abc..d29faa5ba 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -35,7 +35,7 @@ class ComfyApi extends EventTarget { } let opened = false; - let existingSession = sessionStorage["Comfy.SessionId"] || ""; + let existingSession = window.name; if (existingSession) { existingSession = "?clientId=" + existingSession; } @@ -75,7 +75,7 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - sessionStorage["Comfy.SessionId"] = this.clientId; + window.name = this.clientId; } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; diff --git a/web/scripts/app.js b/web/scripts/app.js index f158f3457..a161bf40e 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -20,6 +20,12 @@ export class ComfyApp { */ #processingQueue = false; + /** + * Content Clipboard + * @type {serialized node object} + */ + static clipspace = null; + constructor() { this.ui = new ComfyUI(this); @@ -130,6 +136,83 @@ export class ComfyApp { ); } } + + options.push( + { + content: "Copy (Clipspace)", + callback: (obj) => { + var widgets = null; + if(this.widgets) { + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + let img = new Image(); + var imgs = undefined; + if(this.imgs != undefined) { + img.src = this.imgs[0].src; + imgs = [img]; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': imgs, + 'images': this.images + }; + } + }); + + if(ComfyApp.clipspace != null) { + options.push( + { + content: "Paste (Clipspace)", + callback: () => { + if(ComfyApp.clipspace != null) { + if(ComfyApp.clipspace.widgets != null && this.widgets != null) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop) { + prop.callback(value); + } + }); + } + + // image paste + if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + var filename = ""; + if(this.images && ComfyApp.clipspace.images) { + this.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.images != undefined) { + const clip_image = ComfyApp.clipspace.images[0]; + if(clip_image.subfolder != '') + filename = `${clip_image.subfolder}/`; + filename += `${clip_image.filename} [${clip_image.type}]`; + } + else if(ComfyApp.clipspace.widgets != undefined) { + const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + if(index_in_clip >= 0) { + filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + } + } + + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { + this.imgs = ComfyApp.clipspace.imgs; + + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } + } + } + this.trigger('changed'); + } + } + } + ); + } }; } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 2acc5f2c0..c0e73ffa1 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -136,9 +136,11 @@ function addMultilineWidget(node, name, opts, app) { left: `${t.a * margin + t.e}px`, top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, + background: (!node.color)?'':node.color, height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, position: "absolute", - zIndex: 1, + color: (!node.color)?'':'white', + zIndex: app.graph._nodes.indexOf(node), fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; @@ -270,6 +272,9 @@ export const ComfyWidgets = { app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; + if ((node.size[1] - node.imageOffset) < 100) { + node.size[1] = 250 + node.imageOffset; + } } // Add our own callback to the combo widget to render an image when it changes