Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-14 13:35:05 +00:00)
Merge branch 'comfyanonymous:master' into weightedConditionCombine
@@ -1,65 +0,0 @@
import pygit2
from datetime import datetime
import sys

def pull(repo, remote_name='origin', branch='master'):
    for remote in repo.remotes:
        if remote.name == remote_name:
            remote.fetch()
            remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
            merge_result, _ = repo.merge_analysis(remote_master_id)
            # Up to date, do nothing
            if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
                return
            # We can just fastforward
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
                repo.checkout_tree(repo.get(remote_master_id))
                try:
                    master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
                    master_ref.set_target(remote_master_id)
                except KeyError:
                    repo.create_branch(branch, repo.get(remote_master_id))
                repo.head.set_target(remote_master_id)
            elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
                repo.merge(remote_master_id)

                if repo.index.conflicts is not None:
                    for conflict in repo.index.conflicts:
                        print('Conflicts found in:', conflict[0].path)
                    raise AssertionError('Conflicts, ahhhhh!!')

                user = repo.default_signature
                tree = repo.index.write_tree()
                commit = repo.create_commit('HEAD',
                                            user,
                                            user,
                                            'Merge!',
                                            tree,
                                            [repo.head.target, remote_master_id])
                # We need to do this or git CLI will think we are still merging.
                repo.state_cleanup()
            else:
                raise AssertionError('Unknown merge analysis result')


repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
    print("stashing current changes")
    repo.stash(ident)
except KeyError:
    print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())

print("checking out master branch")
branch = repo.lookup_branch('master')
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)

print("pulling latest changes")
pull(repo)

print("Done!")
@@ -1,2 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause
@@ -1,3 +1,3 @@
 ..\python_embeded\python.exe .\update.py ..\ComfyUI\
-..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2
+..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2
 pause
@@ -1,27 +0,0 @@
HOW TO RUN:

if you have a NVIDIA gpu:

run_nvidia_gpu.bat



To run it in slow CPU mode:

run_cpu.bat



IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt



RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat



To update ComfyUI with the python dependencies:
update\update_comfyui_and_python_dependencies.bat
@@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause
@@ -17,7 +17,7 @@ jobs:
     - shell: bash
       run: |
-        python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
+        python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
         python -m pip install --no-cache-dir ./temp_wheel_dir/*
         echo installed basic
         ls -lah temp_wheel_dir
@@ -46,6 +46,8 @@ jobs:
         mkdir update
         cp -r ComfyUI/.ci/update_windows/* ./update/
         cp -r ComfyUI/.ci/windows_base_files/* ./
+        cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
+        cp -r ComfyUI/.ci/nightly/windows_base_files/* ./

         cd ..
@@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
 This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
 ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)

+### [Installing ComfyUI](#installing)
+
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x and SD2.x
@@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
 - Loading full workflows (with seeds) from generated PNG files.
 - Saving/Loading workflows as Json files.
 - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
@@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
 parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
 parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
 parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
@@ -712,7 +712,7 @@ class UniPC:
     def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
-        atol=0.0078, rtol=0.05, corrector=False,
+        atol=0.0078, rtol=0.05, corrector=False, callback=None
     ):
         t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
         t_T = self.noise_schedule.T if t_start is None else t_start
@@ -766,6 +766,8 @@ class UniPC:
                 if model_x is None:
                     model_x = self.model_fn(x, vec_t)
                 model_prev_list[-1] = model_x
+                if callback is not None:
+                    callback(step_index, model_prev_list[-1], x)
         else:
             raise NotImplementedError()
         if denoise_to_zero:
@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
-    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True)
+    x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback)
     if not to_zero:
         x /= ns.marginal_alpha(timesteps[-1])
     return x
@@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
-        value = self.to_v(context)
+        if value is not None:
+            value = self.to_v(value)
+        else:
+            value = self.to_v(context)

         del context, x

         query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
@@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         q_in = self.to_q(x)
         context = default(context, x)
         k_in = self.to_k(context)
-        v_in = self.to_v(context)
+        if value is not None:
+            v_in = self.to_v(value)
+            del value
+        else:
+            v_in = self.to_v(context)
         del context, x

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads

         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

@@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         b, _, _ = q.shape
         q, k, v = map(
@@ -447,19 +463,19 @@ class CrossAttentionPytorch(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)

         b, _, _ = q.shape
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(b, t.shape[1], self.heads, self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b * self.heads, t.shape[1], self.dim_head)
-            .contiguous(),
+            lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
             (q, k, v),
         )

@@ -468,10 +484,7 @@ class CrossAttentionPytorch(nn.Module):
         if exists(mask):
             raise NotImplementedError
         out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
+            out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
         )

         return self.to_out(out)
@@ -519,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
         transformer_patches = {}

         n = self.norm1(x)
+        if self.disable_self_attn:
+            context_attn1 = context
+        else:
+            context_attn1 = None
+        value_attn1 = None
+
+        if "attn1_patch" in transformer_patches:
+            patch = transformer_patches["attn1_patch"]
+            if context_attn1 is None:
+                context_attn1 = n
+            value_attn1 = context_attn1
+            for p in patch:
+                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)

         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
-            n = self.attn1(n, context=context if self.disable_self_attn else None)
+            n = self.attn1(n, context=context_attn1, value=value_attn1)

         x += n
         if "middle_patch" in transformer_patches:
@@ -532,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
                 x = p(current_index, x)

         n = self.norm2(x)
-        n = self.attn2(n, context=context)
+        context_attn2 = context
+        value_attn2 = None
+        if "attn2_patch" in transformer_patches:
+            patch = transformer_patches["attn2_patch"]
+            value_attn2 = context_attn2
+            for p in patch:
+                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+
+        n = self.attn2(n, context=context_attn2, value=value_attn2)

         x += n
         x = self.ff(self.norm3(x)) + x
@@ -20,6 +20,18 @@ total_vram_available_mb = -1
 accelerate_enabled = False
 xpu_available = False

+directml_enabled = False
+if args.directml is not None:
+    import torch_directml
+    directml_enabled = True
+    device_index = args.directml
+    if device_index < 0:
+        directml_device = torch_directml.device()
+    else:
+        directml_device = torch_directml.device(device_index)
+    print("Using directml with device:", torch_directml.device_name(device_index))
+    # torch_directml.disable_tiled_resources(True)
+
 try:
     import torch
     try:
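Usage note: the --directml flag added in cli_args.py is consumed here at import time, and it assumes the optional torch-directml package is installed. Giving the flag with no value (const=-1) selects the default DirectML device, for example `python main.py --directml`; passing an integer such as `python main.py --directml 1` pins a specific adapter index.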
@@ -133,6 +145,7 @@ def unload_model():
     #never unload models from GPU on high vram
     if vram_state != VRAMState.HIGH_VRAM:
         current_loaded_model.model.cpu()
+        current_loaded_model.model_patches_to("cpu")
     current_loaded_model.unpatch_model()
     current_loaded_model = None
@@ -156,6 +169,8 @@ def load_model_gpu(model):
     except Exception as e:
         model.unpatch_model()
         raise e

+    model.model_patches_to(get_torch_device())
     current_loaded_model = model
     if vram_state == VRAMState.CPU:
         pass
@@ -214,6 +229,10 @@ def unload_if_low_vram(model):

 def get_torch_device():
     global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
     if vram_state == VRAMState.MPS:
         return torch.device("mps")
     if vram_state == VRAMState.CPU:
@@ -231,8 +250,14 @@ def get_autocast_device(dev):


 def xformers_enabled():
+    global xpu_available
+    global directml_enabled
     if vram_state == VRAMState.CPU:
         return False
+    if xpu_available:
+        return False
+    if directml_enabled:
+        return False
     return XFORMERS_IS_AVAILABLE

@@ -248,6 +273,7 @@ def pytorch_attention_enabled():

 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
+    global directml_enabled
     if dev is None:
         dev = get_torch_device()

@@ -255,7 +281,10 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if xpu_available:
+        if directml_enabled:
+            mem_free_total = 1024 * 1024 * 1024 #TODO
+            mem_free_torch = mem_free_total
+        elif xpu_available:
             mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
             mem_free_torch = mem_free_total
         else:
@@ -290,9 +319,14 @@ def mps_mode():

 def should_use_fp16():
     global xpu_available
+    global directml_enabled
+
     if FORCE_FP32:
         return False

+    if directml_enabled:
+        return False
+
     if cpu_mode() or mps_mode() or xpu_available:
         return False #TODO ?

comfy/sample.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import torch
+import comfy.model_management
+import comfy.samplers
+import math
+
+def prepare_noise(latent_image, seed, skip=0):
+    """
+    creates random noise given a latent image and a seed.
+    optional arg skip can be used to skip and discard x number of noise generations for a given seed
+    """
+    generator = torch.manual_seed(seed)
+    for _ in range(skip):
+        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+    noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+    return noise
+
+def prepare_mask(noise_mask, shape, device):
+    """ensures noise mask is of proper dimensions"""
+    noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
+    noise_mask = noise_mask.round()
+    noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
+    if noise_mask.shape[0] < shape[0]:
+        noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
+    noise_mask = noise_mask.to(device)
+    return noise_mask
+
+def broadcast_cond(cond, batch, device):
+    """broadcasts conditioning to the batch size"""
+    copy = []
+    for p in cond:
+        t = p[0]
+        if t.shape[0] < batch:
+            t = torch.cat([t] * batch)
+        t = t.to(device)
+        copy += [[t] + p[1:]]
+    return copy
+
+def get_models_from_cond(cond, model_type):
+    models = []
+    for c in cond:
+        if model_type in c[1]:
+            models += [c[1][model_type]]
+    return models
+
+def load_additional_models(positive, negative):
+    """loads additional models in positive and negative conditioning"""
+    control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
+    gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
+    gligen = [x[1] for x in gligen]
+    models = control_nets + gligen
+    comfy.model_management.load_controlnet_gpu(models)
+    return models
+
+def cleanup_additional_models(models):
+    """cleanup additional models that were loaded"""
+    for m in models:
+        m.cleanup()
+
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None):
+    device = comfy.model_management.get_torch_device()
+
+    if noise_mask is not None:
+        noise_mask = prepare_mask(noise_mask, noise.shape, device)
+
+    real_model = None
+    comfy.model_management.load_model_gpu(model)
+    real_model = model.model
+
+    noise = noise.to(device)
+    latent_image = latent_image.to(device)
+
+    positive_copy = broadcast_cond(positive, noise.shape[0], device)
+    negative_copy = broadcast_cond(negative, noise.shape[0], device)
+
+    models = load_additional_models(positive, negative)
+
+    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
+
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback)
+    samples = samples.cpu()
+
+    cleanup_additional_models(models)
+    return samples
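comfy/sample.py becomes the shared entry point nodes can call instead of duplicating device, mask and conditioning handling. A rough usage sketch; `model`, `positive` and `negative` stand in for outputs of checkpoint-loader and text-encode nodes and are not defined here:

import torch
import comfy.sample

latent = torch.zeros([1, 4, 64, 64])             # one 512x512 image in latent space
noise = comfy.sample.prepare_noise(latent, seed=1234)

def on_step(step, denoised, x):
    # invoked once per sampler step; denoised is the current clean-latent estimate
    print("step", step)

samples = comfy.sample.sample(model, noise, steps=20, cfg=8.0,
                              sampler_name="euler", scheduler="normal",
                              positive=positive, negative=negative,
                              latent_image=latent, callback=on_step)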
@@ -7,23 +7,6 @@ from comfy import model_management
 from .ldm.models.diffusion.ddim import DDIMSampler
 from .ldm.modules.diffusionmodules.util import make_ddim_timesteps

-class CFGDenoiser(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-
-    def forward(self, x, sigma, uncond, cond, cond_scale):
-        if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead
-            x_in = torch.cat([x] * 2)
-            sigma_in = torch.cat([sigma] * 2)
-            cond_in = torch.cat([uncond, cond])
-            uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-        else:
-            cond = self.inner_model(x, sigma, cond=cond)
-            uncond = self.inner_model(x, sigma, cond=uncond)
-        return uncond + (cond - uncond) * cond_scale
-
-
 #The main sampling function shared by all the samplers
 #Returns predicted noise
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
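Context note: the deleted CFGDenoiser duplicated classifier-free guidance that sampling_function already performs, so it could be dropped. As a quick reference, the guidance step it implemented combines the two model outputs like this (a sketch, not code from this commit):

# cond_scale is the "cfg" value from the UI; cond_pred/uncond_pred are the model
# outputs for the prompt and for the negative/empty prompt respectively
guided = uncond_pred + (cond_pred - uncond_pred) * cond_scale
# cond_scale = 1.0 returns cond_pred unchanged; larger values push further away from uncond_pred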
@@ -214,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         transformer_options = model_options['transformer_options'].copy()

         if patches is not None:
-            transformer_options["patches"] = patches
+            if "patches" in transformer_options:
+                cur_patches = transformer_options["patches"].copy()
+                for p in patches:
+                    if p in cur_patches:
+                        cur_patches[p] = cur_patches[p] + patches[p]
+                    else:
+                        cur_patches[p] = patches[p]
+            else:
+                transformer_options["patches"] = patches

         c['transformer_options'] = transformer_options

@@ -438,7 +429,7 @@ class KSampler:
         self.denoise = denoise
         self.model_options = model_options

-    def _calculate_sigmas(self, steps):
+    def calculate_sigmas(self, steps):
         sigmas = None

         discard_penultimate_sigma = False
@@ -447,13 +438,13 @@ class KSampler:
             discard_penultimate_sigma = True

         if self.scheduler == "karras":
-            sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device)
+            sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
         elif self.scheduler == "normal":
-            sigmas = self.model_wrap.get_sigmas(steps).to(self.device)
+            sigmas = self.model_wrap.get_sigmas(steps)
         elif self.scheduler == "simple":
-            sigmas = simple_scheduler(self.model_wrap, steps).to(self.device)
+            sigmas = simple_scheduler(self.model_wrap, steps)
         elif self.scheduler == "ddim_uniform":
-            sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device)
+            sigmas = ddim_scheduler(self.model_wrap, steps)
         else:
             print("error invalid scheduler", self.scheduler)

@@ -464,15 +455,16 @@ class KSampler:
     def set_steps(self, steps, denoise=None):
         self.steps = steps
         if denoise is None or denoise > 0.9999:
-            self.sigmas = self._calculate_sigmas(steps)
+            self.sigmas = self.calculate_sigmas(steps).to(self.device)
         else:
             new_steps = int(steps/denoise)
-            sigmas = self._calculate_sigmas(new_steps)
+            sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]


-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
-        sigmas = self.sigmas
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None):
+        if sigmas is None:
+            sigmas = self.sigmas
         sigma_min = self.sigma_min

         if last_step is not None and last_step < (len(sigmas) - 1):
@@ -535,9 +527,9 @@ class KSampler:

         with precision_scope(model_management.get_autocast_device(self.device)):
             if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback)
             elif self.sampler == "uni_pc_bh2":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2')
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2')
             elif self.sampler == "ddim":
                 timesteps = []
                 for s in range(sigmas.shape[0]):
@@ -545,6 +537,11 @@ class KSampler:
                 noise_mask = None
                 if denoise_mask is not None:
                     noise_mask = 1.0 - denoise_mask

+                ddim_callback = None
+                if callback is not None:
+                    ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
+
                 sampler = DDIMSampler(self.model, device=self.device)
                 sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
                 z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
@@ -558,6 +555,7 @@ class KSampler:
                     eta=0.0,
                     x_T=z_enc,
                     x0=latent_image,
+                    img_callback=ddim_callback,
                     denoise_function=sampling_function,
                     extra_args=extra_args,
                     mask=noise_mask,
@@ -571,13 +569,17 @@ class KSampler:

                 noise = noise * sigmas[0]

+                k_callback = None
+                if callback is not None:
+                    k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
+
                 if latent_image is not None:
                     noise += latent_image
                 if self.sampler == "dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
+                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback)
                 elif self.sampler == "dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
+                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback)
                 else:
-                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args)
+                    samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback)

         return samples.to(torch.float32)
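Whichever backend runs (uni_pc, ddim, or the k-diffusion samplers), the per-step hook is normalized to callback(step, denoised_latent, noisy_latent), with the last argument None on the DDIM path. A hedged sketch of a callback that decodes an occasional preview; the every-N throttling and the use of vae.decode here are illustrative choices, not part of this commit:

def make_preview_callback(vae, every=5):
    def cb(step, denoised, x):
        if denoised is not None and step % every == 0:
            preview = vae.decode(denoised)   # decode the current estimate to image space
            print("step", step, "preview", tuple(preview.shape))
    return cb

# passed through as KSampler.sample(..., callback=make_preview_callback(vae))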
comfy/sd.py
@@ -254,6 +254,29 @@ class ModelPatcher:
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         self.model_options["sampler_cfg_function"] = sampler_cfg_function

+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+
     def model_dtype(self):
         return self.model.diffusion_model.dtype

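These helpers are how the hypernetwork loader below attaches itself, and custom nodes can use the same hooks on a cloned model. A hedged sketch of registering a hand-written patch; the 10% damping is an invented example, and the (current_index, q, k, v) signature mirrors the hypernetwork_patch class added in this commit:

def dampen_cross_attention(current_index, q, k, v):
    # q is the normed hidden state, k/v are the tensors about to go through to_k/to_v
    return q, k * 0.9, v * 0.9

patched_model = model.clone()                    # model is a ModelPatcher from a loader node
patched_model.set_model_attn2_patch(dampen_cross_attention)
patched_model.model_patches_to("cpu")            # no-op for plain functions: they have no .to()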
@@ -1,11 +1,14 @@
 import torch

-def load_torch_file(ckpt):
+def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
+        if safe_load:
+            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+        else:
+            pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
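safe_load=True switches torch.load to weights_only=True, which refuses to unpickle arbitrary Python objects; the hypernetwork loader added in this commit uses it because hypernetworks are shared as plain pickled checkpoints. A small sketch (the path is hypothetical):

import comfy.utils

sd = comfy.utils.load_torch_file("models/hypernetworks/example.pt", safe_load=True)
print(list(sd.keys())[:5])    # inspect the tensors without executing pickled code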
@@ -4,7 +4,10 @@
 from __future__ import annotations

 from collections import OrderedDict
-from typing import Literal
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal

 import torch
 import torch.nn as nn
comfy_extras/nodes_hypernetwork.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+import comfy.utils
+import folder_paths
+import torch
+
+def load_hypernetwork_patch(path, strength):
+    sd = comfy.utils.load_torch_file(path, safe_load=True)
+    activation_func = sd.get('activation_func', 'linear')
+    is_layer_norm = sd.get('is_layer_norm', False)
+    use_dropout = sd.get('use_dropout', False)
+    activate_output = sd.get('activate_output', False)
+    last_layer_dropout = sd.get('last_layer_dropout', False)
+
+    valid_activation = {
+        "linear": torch.nn.Identity,
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
+    }
+
+    if activation_func not in valid_activation:
+        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
+        return None
+
+    out = {}
+
+    for d in sd:
+        try:
+            dim = int(d)
+        except:
+            continue
+
+        output = []
+        for index in [0, 1]:
+            attn_weights = sd[dim][index]
+            keys = attn_weights.keys()
+
+            linears = filter(lambda a: a.endswith(".weight"), keys)
+            linears = list(map(lambda a: a[:-len(".weight")], linears))
+            layers = []
+
+            for i in range(len(linears)):
+                lin_name = linears[i]
+                last_layer = (i == (len(linears) - 1))
+                penultimate_layer = (i == (len(linears) - 2))
+
+                lin_weight = attn_weights['{}.weight'.format(lin_name)]
+                lin_bias = attn_weights['{}.bias'.format(lin_name)]
+                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
+                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
+                layers.append(layer)
+                if activation_func != "linear":
+                    if (not last_layer) or (activate_output):
+                        layers.append(valid_activation[activation_func]())
+                if is_layer_norm:
+                    layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
+                if use_dropout:
+                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
+                        layers.append(torch.nn.Dropout(p=0.3))
+
+            output.append(torch.nn.Sequential(*layers))
+        out[dim] = torch.nn.ModuleList(output)
+
+    class hypernetwork_patch:
+        def __init__(self, hypernet, strength):
+            self.hypernet = hypernet
+            self.strength = strength
+        def __call__(self, current_index, q, k, v):
+            dim = k.shape[-1]
+            if dim in self.hypernet:
+                hn = self.hypernet[dim]
+                k = k + hn[0](k) * self.strength
+                v = v + hn[1](v) * self.strength
+
+            return q, k, v
+
+        def to(self, device):
+            for d in self.hypernet.keys():
+                self.hypernet[d] = self.hypernet[d].to(device)
+            return self
+
+    return hypernetwork_patch(out, strength)
+
+class HypernetworkLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
+                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_hypernetwork"
+
+    CATEGORY = "loaders"
+
+    def load_hypernetwork(self, model, hypernetwork_name, strength):
+        hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
+        model_hypernetwork = model.clone()
+        patch = load_hypernetwork_patch(hypernetwork_path, strength)
+        if patch is not None:
+            model_hypernetwork.set_model_attn1_patch(patch)
+            model_hypernetwork.set_model_attn2_patch(patch)
+        return (model_hypernetwork,)
+
+NODE_CLASS_MAPPINGS = {
+    "HypernetworkLoader": HypernetworkLoader
+}
execution.py
@@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
         input_data_all[x] = unique_id
     return input_data_all

-def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     if unique_id in outputs:
-        return []
+        return

-    executed = []
-
     for x in inputs:
         input_data = inputs[x]
@@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
+                recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)

     input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
     if server.client_id is not None:
@@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
         server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
     if "result" in outputs[unique_id]:
         outputs[unique_id] = outputs[unique_id]["result"]
-    return executed + [unique_id]
+    executed.add(unique_id)

 def recursive_will_execute(prompt, outputs, current_item):
     unique_id = current_item
@@ -99,40 +97,44 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item

     is_changed_old = ''
     is_changed = ''
+    to_delete = False
     if hasattr(class_def, 'IS_CHANGED'):
         if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
             is_changed_old = old_prompt[unique_id]['is_changed']
         if 'is_changed' not in prompt[unique_id]:
             input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
             if input_data_all is not None:
-                is_changed = class_def.IS_CHANGED(**input_data_all)
-                prompt[unique_id]['is_changed'] = is_changed
+                try:
+                    is_changed = class_def.IS_CHANGED(**input_data_all)
+                    prompt[unique_id]['is_changed'] = is_changed
+                except:
+                    to_delete = True
         else:
             is_changed = prompt[unique_id]['is_changed']

     if unique_id not in outputs:
         return True

-    to_delete = False
+    if not to_delete:
         if is_changed != is_changed_old:
             to_delete = True
         elif unique_id not in old_prompt:
             to_delete = True
         elif inputs == old_prompt[unique_id]['inputs']:
             for x in inputs:
                 input_data = inputs[x]

                 if isinstance(input_data, list):
                     input_unique_id = input_data[0]
                     output_index = input_data[1]
                     if input_unique_id in outputs:
                         to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
                     else:
                         to_delete = True
                     if to_delete:
                         break
         else:
             to_delete = True

     if to_delete:
         d = outputs.pop(unique_id)
@@ -154,11 +156,20 @@ class PromptExecutor:
         self.server.client_id = None

         with torch.inference_mode():
+            #delete cached outputs if nodes don't exist for them
+            to_delete = []
+            for o in self.outputs:
+                if o not in prompt:
+                    to_delete += [o]
+            for o in to_delete:
+                d = self.outputs.pop(o)
+                del d
+
             for x in prompt:
                 recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)

             current_outputs = set(self.outputs.keys())
-            executed = []
+            executed = set()
             try:
                 to_execute = []
                 for x in prompt:
@@ -181,12 +192,12 @@ class PromptExecutor:
                     except:
                         valid = False
                     if valid:
-                        executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
+                        recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
             except Exception as e:
                 print(traceback.format_exc())
                 to_delete = []
                 for o in self.outputs:
-                    if o not in current_outputs:
+                    if (o not in current_outputs) and (o not in executed):
                         to_delete += [o]
                         if o in self.old_prompt:
                             d = self.old_prompt.pop(o)
@@ -194,11 +205,9 @@ class PromptExecutor:
             for o in to_delete:
                 d = self.outputs.pop(o)
                 del d
-            else:
-                executed = set(executed)
+            finally:
                 for x in executed:
                     self.old_prompt[x] = copy.deepcopy(prompt[x])
-            finally:
                 self.server.last_node_id = None
                 if self.server.client_id is not None:
                     self.server.send_sync("executing", { "node": None }, self.server.client_id)
@@ -249,9 +258,15 @@ def validate_inputs(prompt, item):
             if "max" in info[1] and val > info[1]["max"]:
                 return (False, "Value bigger than max. {}, {}".format(class_type, x))

-            if isinstance(type_input, list):
-                if val not in type_input:
-                    return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
+            if hasattr(obj_class, "VALIDATE_INPUTS"):
+                input_data_all = get_input_data(inputs, obj_class, unique_id)
+                ret = obj_class.VALIDATE_INPUTS(**input_data_all)
+                if ret != True:
+                    return (False, "{}, {}".format(class_type, ret))
+            else:
+                if isinstance(type_input, list):
+                    if val not in type_input:
+                        return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
     return (True, "")

 def validate_prompt(prompt):
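Nodes can now veto a queued prompt before execution by exposing a VALIDATE_INPUTS classmethod: it is called with the node's resolved inputs, returning True accepts them, and any other return value is reported back as the failure reason. A hedged sketch of a hypothetical node (not part of this commit) using the hook:

class ScaleBy:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 8.0})}}
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "run"
    CATEGORY = "example"

    @classmethod
    def VALIDATE_INPUTS(s, scale):
        # reject a value the min/max range check alone cannot catch
        if scale == 0.0:
            return "scale must be non-zero"
        return True

    def run(self, scale):
        return (scale,)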
@@ -273,7 +288,8 @@ def validate_prompt(prompt):
             m = validate_inputs(prompt, o)
             valid = m[0]
             reason = m[1]
-        except:
+        except Exception as e:
+            print(traceback.format_exc())
             valid = False
             reason = "Parsing error"

@@ -13,6 +13,7 @@ a111:
     models/ESRGAN
     models/SwinIR
     embeddings: embeddings
+    hypernetworks: models/hypernetworks
     controlnet: models/ControlNet

 #other_ui:
@@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m

 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])

+folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
+
 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||||
|
# otherwise use default_path as base_dir
|
||||||
|
def annotated_filepath(name):
|
||||||
|
if name.endswith("[output]"):
|
||||||
|
base_dir = get_output_directory()
|
||||||
|
name = name[:-9]
|
||||||
|
elif name.endswith("[input]"):
|
||||||
|
base_dir = get_input_directory()
|
||||||
|
name = name[:-8]
|
||||||
|
elif name.endswith("[temp]"):
|
||||||
|
base_dir = get_temp_directory()
|
||||||
|
name = name[:-7]
|
||||||
|
else:
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
return name, base_dir
|
||||||
|
|
||||||
|
|
||||||
|
def get_annotated_filepath(name, default_dir=None):
|
||||||
|
name, base_dir = annotated_filepath(name)
|
||||||
|
|
||||||
|
if base_dir is None:
|
||||||
|
if default_dir is not None:
|
||||||
|
base_dir = default_dir
|
||||||
|
else:
|
||||||
|
base_dir = get_input_directory() # fallback path
|
||||||
|
|
||||||
|
return os.path.join(base_dir, name)
|
||||||
|
|
||||||
|
|
||||||
|
def exists_annotated_filepath(name):
|
||||||
|
name, base_dir = annotated_filepath(name)
|
||||||
|
|
||||||
|
if base_dir is None:
|
||||||
|
base_dir = get_input_directory() # fallback path
|
||||||
|
|
||||||
|
filepath = os.path.join(base_dir, name)
|
||||||
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name, full_folder_path):
|
def add_model_folder_path(folder_name, full_folder_path):
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
if folder_name in folder_names_and_paths:
|
if folder_name in folder_names_and_paths:
|
||||||
|
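The new helpers resolve the " [output]" / " [input]" / " [temp]" suffix convention into a base directory. A quick sketch of how a caller might use them, assuming it runs inside a ComfyUI checkout where folder_paths is importable; the filenames are illustrative.

# Illustrative use of the annotated-filepath helpers added above.
import folder_paths

name, base_dir = folder_paths.annotated_filepath("example.png [output]")
# name == "example.png", base_dir == folder_paths.get_output_directory()

# Without an annotation the caller's default directory (or the input dir) is used.
path = folder_paths.get_annotated_filepath("example.png", default_dir="/tmp")
# -> "/tmp/example.png"

if folder_paths.exists_annotated_filepath("mask.png [temp]"):
    print("found it in the temp directory")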
models/hypernetworks/put_hypernetworks_here (new, empty file)

nodes.py (116):
@@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co
 
 import comfy.diffusers_convert
 import comfy.samplers
+import comfy.sample
 import comfy.sd
 import comfy.utils
 
@@ -203,24 +204,24 @@ class VAEEncodeForInpaint:
     def encode(self, vae, pixels, mask):
         x = (pixels.shape[1] // 64) * 64
         y = (pixels.shape[2] // 64) * 64
-        mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0]
+        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
             pixels = pixels[:,:x,:y,:]
-            mask = mask[:x,:y]
+            mask = mask[:,:,:x,:y]
 
         #grow mask by a few pixels to keep things seamless in latent space
         kernel_tensor = torch.ones((1, 1, 6, 6))
-        mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1)
-        m = (1.0 - mask.round())
+        mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1)
+        m = (1.0 - mask.round()).squeeze(1)
         for i in range(3):
             pixels[:,:,:,i] -= 0.5
             pixels[:,:,:,i] *= m
             pixels[:,:,:,i] += 0.5
         t = vae.encode(pixels)
 
-        return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, )
+        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
 
 class CheckpointLoader:
     @classmethod
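The interpolate change above switches from a single HxW mask to batched masks laid out as NCHW. A small torch sketch of the same reshape follows; the tensor shapes are illustrative, not taken from the diff.

# Sketch of the NCHW mask handling used above; shapes are illustrative.
import torch

pixels = torch.rand(2, 512, 512, 3)   # ComfyUI images are NHWC
mask = torch.rand(2, 200, 200)        # a batch of masks, one per image

# mask[None, None, ...] only handled a single mask; reshaping to
# (N, 1, H, W) lets interpolate process the whole batch at once.
mask = torch.nn.functional.interpolate(
    mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
    size=(pixels.shape[1], pixels.shape[2]),
    mode="bilinear",
)
print(mask.shape)  # torch.Size([2, 1, 512, 512])

# The (1.0 - mask) multiplier then needs squeeze(1) to broadcast over NHWC pixels.
m = (1.0 - mask.round()).squeeze(1)   # (2, 512, 512)
pixels[:, :, :, 0] *= m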
@@ -771,79 +772,23 @@ class SetLatentNoiseMask:
         s["noise_mask"] = mask
         return (s,)
 
 
 def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
-    latent_image = latent["samples"]
-    noise_mask = None
     device = comfy.model_management.get_torch_device()
+    latent_image = latent["samples"]
 
     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
     else:
-        batch_index = 0
-        if "batch_index" in latent:
-            batch_index = latent["batch_index"]
-
-        generator = torch.manual_seed(seed)
-        for i in range(batch_index):
-            noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
-        noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
+        skip = latent["batch_index"] if "batch_index" in latent else 0
+        noise = comfy.sample.prepare_noise(latent_image, seed, skip)
 
+    noise_mask = None
     if "noise_mask" in latent:
-        noise_mask = latent['noise_mask']
-        noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
-        noise_mask = noise_mask.round()
-        noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
-        noise_mask = torch.cat([noise_mask] * noise.shape[0])
-        noise_mask = noise_mask.to(device)
-
-    real_model = None
-    comfy.model_management.load_model_gpu(model)
-    real_model = model.model
-
-    noise = noise.to(device)
-    latent_image = latent_image.to(device)
-
-    positive_copy = []
-    negative_copy = []
-
-    control_nets = []
-    def get_models(cond):
-        models = []
-        for c in cond:
-            if 'control' in c[1]:
-                models += [c[1]['control']]
-            if 'gligen' in c[1]:
-                models += [c[1]['gligen'][1]]
-        return models
-
-    for p in positive:
-        t = p[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        positive_copy += [[t] + p[1:]]
-    for n in negative:
-        t = n[0]
-        if t.shape[0] < noise.shape[0]:
-            t = torch.cat([t] * noise.shape[0])
-        t = t.to(device)
-        negative_copy += [[t] + n[1:]]
-
-    models = get_models(positive) + get_models(negative)
-    comfy.model_management.load_controlnet_gpu(models)
-
-    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
-        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
-    else:
-        #other samplers
-        pass
-
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
-    samples = samples.cpu()
-    for m in models:
-        m.cleanup()
-
+        noise_mask = latent["noise_mask"]
+
+    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
+                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
+                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask)
     out = latent.copy()
     out["samples"] = samples
     return (out, )
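After this refactor the node-side sampler is a thin wrapper: per-batch noise generation, mask resizing, conditioning copies and ControlNet loading move behind the comfy.sample entry points the diff calls. A minimal sketch of the noise half, assuming a ComfyUI checkout on the path; the latent tensor and the skip value are placeholders.

# Sketch of the noise preparation the node used to do inline; assumes comfy is importable.
import torch
import comfy.sample

latent_image = torch.zeros([4, 4, 64, 64])   # placeholder latent batch (N, C, H, W)
seed = 42
skip = 2                                     # taken from latent["batch_index"] in the node code above

noise = comfy.sample.prepare_noise(latent_image, seed, skip)
print(noise.shape)                           # same shape as the latent batch, here [4, 4, 64, 64]

# The rest of the old body (device moves, conditioning copies, ControlNet loading,
# KSampler construction) is what the comfy.sample.sample(...) call above now wraps.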
@@ -1006,8 +951,7 @@ class LoadImage:
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"
     def load_image(self, image):
-        input_dir = folder_paths.get_input_directory()
-        image_path = os.path.join(input_dir, image)
+        image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
         image = i.convert("RGB")
         image = np.array(image).astype(np.float32) / 255.0
@@ -1021,20 +965,27 @@ class LoadImage:
 
     @classmethod
     def IS_CHANGED(s, image):
-        input_dir = folder_paths.get_input_directory()
-        image_path = os.path.join(input_dir, image)
+        image_path = folder_paths.get_annotated_filepath(image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()
 
+    @classmethod
+    def VALIDATE_INPUTS(s, image):
+        if not folder_paths.exists_annotated_filepath(image):
+            return "Invalid image file: {}".format(image)
+
+        return True
+
 class LoadImageMask:
+    _color_channels = ["alpha", "red", "green", "blue"]
     @classmethod
     def INPUT_TYPES(s):
         input_dir = folder_paths.get_input_directory()
         return {"required":
                     {"image": (sorted(os.listdir(input_dir)), ),
-                     "channel": (["alpha", "red", "green", "blue"], ),}
+                     "channel": (s._color_channels, ),}
                 }
 
     CATEGORY = "mask"
@@ -1042,8 +993,7 @@ class LoadImageMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "load_image"
     def load_image(self, image, channel):
-        input_dir = folder_paths.get_input_directory()
-        image_path = os.path.join(input_dir, image)
+        image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
         if i.getbands() != ("R", "G", "B", "A"):
             i = i.convert("RGBA")
@@ -1060,13 +1010,22 @@ class LoadImageMask:
 
     @classmethod
     def IS_CHANGED(s, image, channel):
-        input_dir = folder_paths.get_input_directory()
-        image_path = os.path.join(input_dir, image)
+        image_path = folder_paths.get_annotated_filepath(image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()
 
+    @classmethod
+    def VALIDATE_INPUTS(s, image, channel):
+        if not folder_paths.exists_annotated_filepath(image):
+            return "Invalid image file: {}".format(image)
+
+        if channel not in s._color_channels:
+            return "Invalid color channel: {}".format(channel)
+
+        return True
+
 class ImageScale:
     upscale_methods = ["nearest-exact", "bilinear", "area"]
     crop_methods = ["disabled", "center"]
@@ -1302,6 +1261,7 @@ def load_custom_nodes():
 
 def init_custom_nodes():
     load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
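Custom nodes that load files can follow the same pattern LoadImage adopts above: resolve the annotated name once, hash the file for IS_CHANGED, and reject missing files in VALIDATE_INPUTS. A condensed sketch follows; the node name, category and input widget are made up for illustration.

# Condensed sketch of the LoadImage pattern above; "MyFileNode" is a made-up example node.
import hashlib
import folder_paths

class MyFileNode:
    CATEGORY = "example"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("STRING", {"default": "example.png"})}}

    @classmethod
    def IS_CHANGED(s, image):
        # Hash the resolved file so the graph re-executes when its contents change.
        image_path = folder_paths.get_annotated_filepath(image)
        m = hashlib.sha256()
        with open(image_path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()

    @classmethod
    def VALIDATE_INPUTS(s, image):
        if not folder_paths.exists_annotated_filepath(image):
            return "Invalid image file: {}".format(image)
        return True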
@@ -47,7 +47,7 @@
     " !git pull\n",
     "\n",
     "!echo -= Install dependencies =-\n",
-    "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118"
+    "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
    ]
   },
   {
server.py (15):
@@ -112,14 +112,21 @@ class PromptServer():
 
         @routes.post("/upload/image")
         async def upload_image(request):
-            upload_dir = folder_paths.get_input_directory()
+            post = await request.post()
+            image = post.get("image")
+
+            if post.get("type") is None:
+                upload_dir = folder_paths.get_input_directory()
+            elif post.get("type") == "input":
+                upload_dir = folder_paths.get_input_directory()
+            elif post.get("type") == "temp":
+                upload_dir = folder_paths.get_temp_directory()
+            elif post.get("type") == "output":
+                upload_dir = folder_paths.get_output_directory()
 
             if not os.path.exists(upload_dir):
                 os.makedirs(upload_dir)
 
-            post = await request.post()
-            image = post.get("image")
-
             if image and image.file:
                 filename = image.filename
                 if not filename:
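With the new "type" field, a caller can route an upload to the input, temp, or output directory. A hedged client-side sketch using the requests library; the host, port and filename are placeholders, only the route and field names come from the diff.

# Hypothetical client call against the /upload/image route shown above.
import requests

with open("mask.png", "rb") as f:
    r = requests.post(
        "http://127.0.0.1:8188/upload/image",
        files={"image": f},        # multipart field read via post.get("image")
        data={"type": "temp"},     # routed to folder_paths.get_temp_directory()
    )
print(r.status_code)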
@@ -89,24 +89,17 @@ app.registerExtension({
 				end = nearestEnclosure.end;
 				selectedText = inputField.value.substring(start, end);
 			} else {
-				// Select the current word, find the start and end of the word (first space before and after)
-				const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1;
-				const wordEnd = inputField.value.substring(end).indexOf(" ");
-				// If there is no space after the word, select to the end of the string
-				if (wordEnd === -1) {
-					end = inputField.value.length;
-				} else {
-					end += wordEnd;
-				}
-				start = wordStart;
-
-				// Remove all punctuation at the end and beginning of the word
-				while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) {
-					start++;
+				// Select the current word, find the start and end of the word
+				const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t";
+
+				while (!delimiters.includes(inputField.value[start - 1]) && start > 0) {
+					start--;
 				}
-				while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) {
-					end--;
+				while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) {
+					end++;
 				}
 
 				selectedText = inputField.value.substring(start, end);
 				if (!selectedText) return;
 			}
@@ -135,8 +128,13 @@ app.registerExtension({
 
 			// Increment the weight
 			const weightDelta = event.key === "ArrowUp" ? delta : -delta;
-			const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => {
-				return prefix + incrementWeight(weight, weightDelta) + suffix;
+			const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => {
+				weight = incrementWeight(weight, weightDelta);
+				if (weight == 1) {
+					return text;
+				} else {
+					return `(${text}:${weight})`;
+				}
 			});
 
 			inputField.setRangeText(updatedText, start, end, "select");
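The new replacement only matches a full "(text:weight)" group and drops the parentheses when the weight lands back on 1.0. The same behaviour, sketched in Python for clarity; the regex and helper are transcribed from the JavaScript above, and the prompt strings are illustrative.

# Python transcription of the (text:weight) replacement above; the prompts are illustrative.
import re

def increment_weight(weight, delta):
    return round(float(weight) + delta, 2)

def bump(selected_text, delta):
    def repl(match):
        text, weight = match.group(1), match.group(2)
        weight = increment_weight(weight, delta)
        if weight == 1:
            return text                      # drop the parentheses at neutral weight
        return "({}:{})".format(text, weight)
    return re.sub(r"\((.*):(\d+(?:\.\d+)?)\)", repl, selected_text)

print(bump("(masterpiece:1.2)", -0.1))       # (masterpiece:1.1)
print(bump("(masterpiece:1.1)", -0.1))       # masterpiece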
@@ -1,21 +1,72 @@
 import { app } from "/scripts/app.js";
+import { ComfyWidgets } from "/scripts/widgets.js";
 // Adds defaults for quickly adding nodes with middle click on the input/output
 
 app.registerExtension({
 	name: "Comfy.SlotDefaults",
+	suggestionsNumber: null,
 	init() {
 		LiteGraph.middle_click_slot_add_default_node = true;
-		LiteGraph.slot_types_default_in = {
-			MODEL: "CheckpointLoaderSimple",
-			LATENT: "EmptyLatentImage",
-			VAE: "VAELoader",
-		};
-
-		LiteGraph.slot_types_default_out = {
-			LATENT: "VAEDecode",
-			IMAGE: "SaveImage",
-			CLIP: "CLIPTextEncode",
-		};
+		this.suggestionsNumber = app.ui.settings.addSetting({
+			id: "Comfy.NodeSuggestions.number",
+			name: "number of nodes suggestions",
+			type: "slider",
+			attrs: {
+				min: 1,
+				max: 100,
+				step: 1,
+			},
+			defaultValue: 5,
+			onChange: (newVal, oldVal) => {
+				this.setDefaults(newVal);
+			}
+		});
 	},
+	slot_types_default_out: {},
+	slot_types_default_in: {},
+	async beforeRegisterNodeDef(nodeType, nodeData, app) {
+		var nodeId = nodeData.name;
+		var inputs = [];
+		inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs
+		for (const inputKey in inputs) {
+			var input = (inputs[inputKey]);
+			if (typeof input[0] !== "string") continue;
+
+			var type = input[0]
+			if (type in ComfyWidgets) {
+				var customProperties = input[1]
+				if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input
+			}
+
+			if (!(type in this.slot_types_default_out)) {
+				this.slot_types_default_out[type] = ["Reroute"];
+			}
+			if (this.slot_types_default_out[type].includes(nodeId)) continue;
+			this.slot_types_default_out[type].push(nodeId);
+		}
+
+		var outputs = nodeData["output"];
+		for (const key in outputs) {
+			var type = outputs[key];
+			if (!(type in this.slot_types_default_in)) {
+				this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'()
+			}
+
+			this.slot_types_default_in[type].push(nodeId);
+		}
+		var maxNum = this.suggestionsNumber.value;
+		this.setDefaults(maxNum);
+	},
+	setDefaults(maxNum) {
+
+		LiteGraph.slot_types_default_out = {};
+		LiteGraph.slot_types_default_in = {};
+
+		for (const type in this.slot_types_default_out) {
+			LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum);
+		}
+		for (const type in this.slot_types_default_in) {
+			LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum);
+		}
+	}
 });
@@ -9953,11 +9953,11 @@ LGraphNode.prototype.executeAction = function(action)
                     }
                     break;
                 case "slider":
-                    var range = w.options.max - w.options.min;
+                    var old_value = w.value;
                     var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1);
                     if(w.options.read_only) break;
                     w.value = w.options.min + (w.options.max - w.options.min) * nvalue;
-                    if (w.callback) {
+                    if (old_value != w.value) {
                         setTimeout(function() {
                             inner_value_change(w, w.value);
                         }, 20);
@@ -10044,7 +10044,7 @@ LGraphNode.prototype.executeAction = function(action)
                     if (event.click_time < 200 && delta == 0) {
                         this.prompt("Value",w.value,function(v) {
                             // check if v is a valid equation or a number
-                            if (/^[0-9+\-*/()\s]+$/.test(v)) {
+                            if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) {
                                 try {//solve the equation if possible
                                     v = eval(v);
                                 } catch (e) { }
@@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
 		}
 
 		let opened = false;
-		let existingSession = sessionStorage["Comfy.SessionId"] || "";
+		let existingSession = window.name;
 		if (existingSession) {
 			existingSession = "?clientId=" + existingSession;
 		}
@@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
 				case "status":
 					if (msg.data.sid) {
 						this.clientId = msg.data.sid;
-						sessionStorage["Comfy.SessionId"] = this.clientId;
+						window.name = this.clientId;
 					}
 					this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
 					break;
@@ -20,6 +20,12 @@ export class ComfyApp {
 	 */
 	#processingQueue = false;
 
+	/**
+	 * Content Clipboard
+	 * @type {serialized node object}
+	 */
+	static clipspace = null;
+
 	constructor() {
 		this.ui = new ComfyUI(this);
 
@@ -130,6 +136,83 @@ export class ComfyApp {
 					);
 				}
 			}
+
+			options.push(
+				{
+					content: "Copy (Clipspace)",
+					callback: (obj) => {
+						var widgets = null;
+						if(this.widgets) {
+							widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
+						}
+
+						let img = new Image();
+						var imgs = undefined;
+						if(this.imgs != undefined) {
+							img.src = this.imgs[0].src;
+							imgs = [img];
+						}
+
+						ComfyApp.clipspace = {
+							'widgets': widgets,
+							'imgs': imgs,
+							'original_imgs': imgs,
+							'images': this.images
+						};
+					}
+				});
+
+			if(ComfyApp.clipspace != null) {
+				options.push(
+					{
+						content: "Paste (Clipspace)",
+						callback: () => {
+							if(ComfyApp.clipspace != null) {
+								if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
+									ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
+										const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
+										if (prop) {
+											prop.callback(value);
+										}
+									});
+								}
+
+								// image paste
+								if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
+									var filename = "";
+									if(this.images && ComfyApp.clipspace.images) {
+										this.images = ComfyApp.clipspace.images;
+									}
+
+									if(ComfyApp.clipspace.images != undefined) {
+										const clip_image = ComfyApp.clipspace.images[0];
+										if(clip_image.subfolder != '')
+											filename = `${clip_image.subfolder}/`;
+										filename += `${clip_image.filename} [${clip_image.type}]`;
+									}
+									else if(ComfyApp.clipspace.widgets != undefined) {
+										const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
+										if(index_in_clip >= 0) {
+											filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
+										}
+									}
+
+									const index = this.widgets.findIndex(obj => obj.name === 'image');
+									if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
+										this.imgs = ComfyApp.clipspace.imgs;
+
+										this.widgets[index].value = filename;
+										if(this.widgets_values != undefined) {
+											this.widgets_values[index] = filename;
+										}
+									}
+								}
+								this.trigger('changed');
+							}
+						}
+					}
+				);
+			}
 		};
 	}
 
|
@@ -136,9 +136,11 @@ function addMultilineWidget(node, name, opts, app) {
|
|||||||
left: `${t.a * margin + t.e}px`,
|
left: `${t.a * margin + t.e}px`,
|
||||||
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`,
|
||||||
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
width: `${(widgetWidth - margin * 2 - 3) * t.a}px`,
|
||||||
|
background: (!node.color)?'':node.color,
|
||||||
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`,
|
||||||
position: "absolute",
|
position: "absolute",
|
||||||
zIndex: 1,
|
color: (!node.color)?'':'white',
|
||||||
|
zIndex: app.graph._nodes.indexOf(node),
|
||||||
fontSize: `${t.d * 10.0}px`,
|
fontSize: `${t.d * 10.0}px`,
|
||||||
});
|
});
|
||||||
this.inputEl.hidden = !visible;
|
this.inputEl.hidden = !visible;
|
||||||
@@ -270,6 +272,9 @@ export const ComfyWidgets = {
 					app.graph.setDirtyCanvas(true);
 				};
 				img.src = `/view?filename=${name}&type=input`;
+				if ((node.size[1] - node.imageOffset) < 100) {
+					node.size[1] = 250 + node.imageOffset;
+				}
 			}
 
 			// Add our own callback to the combo widget to render an image when it changes