mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 16:26:39 +00:00
Merge pull request #8968 from bigcat88/v3/nodes/latent-and-lt
[V3] nodes_lt.py and nodes_latent.py
This commit is contained in:
commit
87e72fc04c
340
comfy_extras/v3/nodes_latent.py
Normal file
340
comfy_extras/v3/nodes_latent.py
Normal file
@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
import comfy_extras.nodes_post_processing
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
latent = comfy.utils.common_upscale(
|
||||
latent, target_shape[-1], target_shape[-2], "bilinear", "center"
|
||||
)
|
||||
if repeat_batch:
|
||||
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
||||
return latent
|
||||
|
||||
|
||||
class LatentAdd(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentAdd_V3",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples1"),
|
||||
io.Latent.Input(id="samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
samples_out["samples"] = s1 + s2
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentApplyOperation(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentApplyOperation_V3",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples"),
|
||||
io.LatentOperation.Input(id="operation"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, operation):
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = operation(latent=s1)
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentApplyOperationCFG(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentApplyOperationCFG_V3",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.LatentOperation.Input(id="operation"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, operation):
|
||||
m = model.clone()
|
||||
|
||||
def pre_cfg_function(args):
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) == 2:
|
||||
conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
|
||||
else:
|
||||
conds_out[0] = operation(latent=conds_out[0])
|
||||
return conds_out
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class LatentBatch(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentBatch_V3",
|
||||
category="latent/batch",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples1"),
|
||||
io.Latent.Input(id="samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2):
|
||||
samples_out = samples1.copy()
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
|
||||
s = torch.cat((s1, s2), dim=0)
|
||||
samples_out["samples"] = s
|
||||
samples_out["batch_index"] = (samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) +
|
||||
samples2.get("batch_index", [x for x in range(0, s2.shape[0])]))
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentBatchSeedBehavior(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentBatchSeedBehavior_V3",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples"),
|
||||
io.Combo.Input(id="seed_behavior", options=["random", "fixed"], default="fixed"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, seed_behavior):
|
||||
samples_out = samples.copy()
|
||||
latent = samples["samples"]
|
||||
if seed_behavior == "random":
|
||||
if 'batch_index' in samples_out:
|
||||
samples_out.pop('batch_index')
|
||||
elif seed_behavior == "fixed":
|
||||
batch_number = samples_out.get("batch_index", [0])[0]
|
||||
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
||||
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentInterpolate(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentInterpolate_V3",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples1"),
|
||||
io.Latent.Input(id="samples2"),
|
||||
io.Float.Input(id="ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2, ratio):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
|
||||
m1 = torch.linalg.vector_norm(s1, dim=(1))
|
||||
m2 = torch.linalg.vector_norm(s2, dim=(1))
|
||||
|
||||
s1 = torch.nan_to_num(s1 / m1)
|
||||
s2 = torch.nan_to_num(s2 / m2)
|
||||
|
||||
t = (s1 * ratio + s2 * (1.0 - ratio))
|
||||
mt = torch.linalg.vector_norm(t, dim=(1))
|
||||
st = torch.nan_to_num(t / mt)
|
||||
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentMultiply(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentMultiply_V3",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples"),
|
||||
io.Float.Input(id="multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, multiplier):
|
||||
samples_out = samples.copy()
|
||||
|
||||
s1 = samples["samples"]
|
||||
samples_out["samples"] = s1 * multiplier
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
class LatentOperationSharpen(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentOperationSharpen_V3",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Int.Input(id="sharpen_radius", default=9, min=1, max=31, step=1),
|
||||
io.Float.Input(id="sigma", default=1.0, min=0.1, max=10.0, step=0.1),
|
||||
io.Float.Input(id="alpha", default=0.1, min=0.0, max=5.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.LatentOperation.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, sharpen_radius, sigma, alpha):
|
||||
def sharpen(latent, **kwargs):
|
||||
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||
normalized_latent = latent / luminance
|
||||
channels = latent.shape[1]
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = comfy_extras.nodes_post_processing.gaussian_kernel(kernel_size, sigma, device=luminance.device)
|
||||
center = kernel_size // 2
|
||||
|
||||
kernel *= alpha * -10
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
|
||||
padded_image = torch.nn.functional.pad(
|
||||
normalized_latent, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), "reflect"
|
||||
)
|
||||
sharpened = torch.nn.functional.conv2d(
|
||||
padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels
|
||||
)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||
|
||||
return luminance * sharpened
|
||||
return io.NodeOutput(sharpen)
|
||||
|
||||
|
||||
class LatentOperationTonemapReinhard(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentOperationTonemapReinhard_V3",
|
||||
category="latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Float.Input(id="multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.LatentOperation.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, multiplier):
|
||||
def tonemap_reinhard(latent, **kwargs):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
#reinhard
|
||||
latent_vector_magnitude *= (1.0 / top)
|
||||
new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
|
||||
new_magnitude *= top
|
||||
|
||||
return normalized_latent * new_magnitude
|
||||
return io.NodeOutput(tonemap_reinhard)
|
||||
|
||||
|
||||
class LatentSubtract(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LatentSubtract_V3",
|
||||
category="latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input(id="samples1"),
|
||||
io.Latent.Input(id="samples2"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples1, samples2):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
samples_out["samples"] = s1 - s2
|
||||
return io.NodeOutput(samples_out)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
LatentAdd,
|
||||
LatentApplyOperation,
|
||||
LatentApplyOperationCFG,
|
||||
LatentBatch,
|
||||
LatentBatchSeedBehavior,
|
||||
LatentInterpolate,
|
||||
LatentMultiply,
|
||||
LatentOperationSharpen,
|
||||
LatentOperationTonemapReinhard,
|
||||
LatentSubtract,
|
||||
]
|
528
comfy_extras/v3/nodes_lt.py
Normal file
528
comfy_extras/v3/nodes_lt.py
Normal file
@ -0,0 +1,528 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.model_sampling
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import nodes
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import (
|
||||
SymmetricPatchifier,
|
||||
latent_to_pixel_coords,
|
||||
)
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def conditioning_get_any_value(conditioning, key, default=None):
|
||||
for t in conditioning:
|
||||
if key in t[1]:
|
||||
return t[1][key]
|
||||
return default
|
||||
|
||||
|
||||
def get_noise_mask(latent):
|
||||
noise_mask = latent.get("noise_mask", None)
|
||||
latent_image = latent["samples"]
|
||||
if noise_mask is None:
|
||||
batch_size, _, latent_length, _, _ = latent_image.shape
|
||||
return torch.ones(
|
||||
(batch_size, 1, latent_length, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=latent_image.device,
|
||||
)
|
||||
return noise_mask.clone()
|
||||
|
||||
|
||||
def get_keyframe_idxs(cond):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
return keyframe_idxs, torch.unique(keyframe_idxs[:, 0]).shape[0]
|
||||
|
||||
|
||||
def encode_single_frame(output_file, image_array: np.ndarray, crf):
|
||||
container = av.open(output_file, "w", format="mp4")
|
||||
try:
|
||||
stream = container.add_stream(
|
||||
"libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
|
||||
)
|
||||
stream.height = image_array.shape[0]
|
||||
stream.width = image_array.shape[1]
|
||||
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
|
||||
format="yuv420p"
|
||||
)
|
||||
container.mux(stream.encode(av_frame))
|
||||
container.mux(stream.encode())
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
|
||||
def decode_single_frame(video_file):
|
||||
container = av.open(video_file)
|
||||
try:
|
||||
stream = next(s for s in container.streams if s.type == "video")
|
||||
frame = next(container.decode(stream))
|
||||
finally:
|
||||
container.close()
|
||||
return frame.to_ndarray(format="rgb24")
|
||||
|
||||
|
||||
def preprocess(image: torch.Tensor, crf=29):
|
||||
if crf == 0:
|
||||
return image
|
||||
|
||||
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
|
||||
with sys.modules['io'].BytesIO() as output_file:
|
||||
encode_single_frame(output_file, image_array, crf)
|
||||
video_bytes = output_file.getvalue()
|
||||
with sys.modules['io'].BytesIO(video_bytes) as video_file:
|
||||
image_array = decode_single_frame(video_file)
|
||||
return torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
|
||||
|
||||
|
||||
class EmptyLTXVLatentVideo(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="EmptyLTXVLatentVideo_V3",
|
||||
category="latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input(id="length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input(id="batch_size", default=1, min=1, max=4096),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size):
|
||||
latent = torch.zeros(
|
||||
[batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
class LTXVAddGuide(io.ComfyNodeV3):
|
||||
NUM_PREFIX_FRAMES = 2
|
||||
PATCHIFIER = SymmetricPatchifier(1)
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVAddGuide_V3",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="positive"),
|
||||
io.Conditioning.Input(id="negative"),
|
||||
io.Vae.Input(id="vae"),
|
||||
io.Latent.Input(id="latent"),
|
||||
io.Image.Input(
|
||||
id="image",
|
||||
tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
|
||||
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
|
||||
),
|
||||
io.Int.Input(
|
||||
id="frame_idx",
|
||||
default=0,
|
||||
min=-9999,
|
||||
max=9999,
|
||||
tooltip="Frame index to start the conditioning at. "
|
||||
"For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
|
||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||
),
|
||||
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(id="positive_out", display_name="positive"),
|
||||
io.Conditioning.Output(id="negative_out", display_name="negative"),
|
||||
io.Latent.Output(id="latent_out", display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength):
|
||||
scale_factors = vae.downscale_index_formula
|
||||
latent_image = latent["samples"]
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, _, latent_length, latent_height, latent_width = latent_image.shape
|
||||
image, t = cls._encode(vae, latent_width, latent_height, image, scale_factors)
|
||||
|
||||
frame_idx, latent_idx = cls._get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls._append_keyframe(
|
||||
positive,
|
||||
negative,
|
||||
frame_idx,
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t[:, :, :num_prefix_frames],
|
||||
strength,
|
||||
scale_factors,
|
||||
)
|
||||
|
||||
latent_idx += num_prefix_frames
|
||||
|
||||
t = t[:, :, num_prefix_frames:]
|
||||
if t.shape[2] == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image, noise_mask = cls._replace_latent_frames(
|
||||
latent_image,
|
||||
noise_mask,
|
||||
t,
|
||||
latent_idx,
|
||||
strength,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
@classmethod
|
||||
def _encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(
|
||||
images.movedim(-1, 1),
|
||||
latent_width * width_scale_factor,
|
||||
latent_height * height_scale_factor,
|
||||
"bilinear",
|
||||
crop="disabled",
|
||||
).movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
@classmethod
|
||||
def _get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1
|
||||
return frame_idx, (frame_idx + time_scale_factor - 1) // time_scale_factor
|
||||
|
||||
@classmethod
|
||||
def _add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
|
||||
keyframe_idxs, _ = get_keyframe_idxs(cond)
|
||||
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
|
||||
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)
|
||||
pixel_coords[:, 0] += frame_idx
|
||||
if keyframe_idxs is None:
|
||||
keyframe_idxs = pixel_coords
|
||||
else:
|
||||
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
|
||||
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
|
||||
|
||||
@classmethod
|
||||
def _append_keyframe(
|
||||
cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors
|
||||
):
|
||||
_, latent_idx = cls._get_latent_index(
|
||||
cond=positive,
|
||||
latent_length=latent_image.shape[2],
|
||||
guide_length=guiding_latent.shape[2],
|
||||
frame_idx=frame_idx,
|
||||
scale_factors=scale_factors,
|
||||
)
|
||||
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
|
||||
|
||||
positive = cls._add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
|
||||
negative = cls._add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
|
||||
return positive, negative, latent_image, torch.cat([noise_mask, mask], dim=2)
|
||||
|
||||
@classmethod
|
||||
def _replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
|
||||
cond_length = guiding_latent.shape[2]
|
||||
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
mask = torch.full(
|
||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||
1.0 - strength,
|
||||
dtype=noise_mask.dtype,
|
||||
device=noise_mask.device,
|
||||
)
|
||||
|
||||
latent_image = latent_image.clone()
|
||||
noise_mask = noise_mask.clone()
|
||||
|
||||
latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
|
||||
noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
|
||||
|
||||
return latent_image, noise_mask
|
||||
|
||||
|
||||
class LTXVConditioning(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVConditioning_V3",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="positive"),
|
||||
io.Conditioning.Input(id="negative"),
|
||||
io.Float.Input(id="frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(id="positive_out", display_name="positive"),
|
||||
io.Conditioning.Output(id="negative_out", display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, frame_rate):
|
||||
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
|
||||
return io.NodeOutput(positive, negative)
|
||||
|
||||
|
||||
class LTXVCropGuides(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVCropGuides_V3",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="positive"),
|
||||
io.Conditioning.Input(id="negative"),
|
||||
io.Latent.Input(id="latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(id="positive_out", display_name="positive"),
|
||||
io.Conditioning.Output(id="negative_out", display_name="negative"),
|
||||
io.Latent.Output(id="latent_out", display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, latent):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
if num_keyframes == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
latent_image = latent_image[:, :, :-num_keyframes]
|
||||
noise_mask = noise_mask[:, :, :-num_keyframes]
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
|
||||
class LTXVImgToVideo(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVImgToVideo_V3",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input(id="positive"),
|
||||
io.Conditioning.Input(id="negative"),
|
||||
io.Vae.Input(id="vae"),
|
||||
io.Image.Input(id="image"),
|
||||
io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input(id="length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input(id="batch_size", default=1, min=1, max=4096),
|
||||
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(id="positive_out", display_name="positive"),
|
||||
io.Conditioning.Output(id="negative_out", display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength):
|
||||
pixels = comfy.utils.common_upscale(
|
||||
image.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
|
||||
latent = torch.zeros(
|
||||
[batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
latent[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch_size, 1, latent.shape[2], 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=latent.device,
|
||||
)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
|
||||
|
||||
|
||||
class LTXVPreprocess(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVPreprocess_V3",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(id="image"),
|
||||
io.Int.Input(
|
||||
id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image."
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(id="output_image", display_name="output_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, img_compression):
|
||||
output_images = []
|
||||
for i in range(image.shape[0]):
|
||||
output_images.append(preprocess(image[i], img_compression))
|
||||
return io.NodeOutput(torch.stack(output_images))
|
||||
|
||||
|
||||
class LTXVScheduler(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="LTXVScheduler_V3",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input(id="steps", default=20, min=1, max=10000),
|
||||
io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
|
||||
io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
|
||||
io.Boolean.Input(
|
||||
id="stretch",
|
||||
default=True,
|
||||
tooltip="Stretch the sigmas to be in the range [terminal, 1].",
|
||||
),
|
||||
io.Float.Input(
|
||||
id="terminal",
|
||||
default=0.1,
|
||||
min=0.0,
|
||||
max=0.99,
|
||||
step=0.01,
|
||||
tooltip="The terminal value of the sigmas after stretching.",
|
||||
),
|
||||
io.Latent.Input(id="latent", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Sigmas.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None):
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1)
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = (tokens) * mm + b
|
||||
|
||||
power = 1
|
||||
sigmas = torch.where(
|
||||
sigmas != 0,
|
||||
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
||||
0,
|
||||
)
|
||||
|
||||
if stretch:
|
||||
non_zero_mask = sigmas != 0
|
||||
non_zero_sigmas = sigmas[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
sigmas[non_zero_mask] = stretched
|
||||
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
|
||||
class ModelSamplingLTXV(io.ComfyNodeV3):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.SchemaV3(
|
||||
node_id="ModelSamplingLTXV_V3",
|
||||
category="advanced/model",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
|
||||
io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
|
||||
io.Latent.Input(id="latent", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, max_shift, base_shift, latent=None):
|
||||
m = model.clone()
|
||||
|
||||
if latent is None:
|
||||
tokens = 4096
|
||||
else:
|
||||
tokens = math.prod(latent["samples"].shape[2:])
|
||||
|
||||
x1 = 1024
|
||||
x2 = 4096
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
shift = (tokens) * mm + b
|
||||
|
||||
sampling_base = comfy.model_sampling.ModelSamplingFlux
|
||||
sampling_type = comfy.model_sampling.CONST
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_parameters(shift=shift)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
EmptyLTXVLatentVideo,
|
||||
LTXVAddGuide,
|
||||
LTXVConditioning,
|
||||
LTXVCropGuides,
|
||||
LTXVImgToVideo,
|
||||
LTXVPreprocess,
|
||||
LTXVScheduler,
|
||||
ModelSamplingLTXV,
|
||||
]
|
Loading…
x
Reference in New Issue
Block a user