diff --git a/comfy_extras/v3/nodes_lt.py b/comfy_extras/v3/nodes_lt.py
new file mode 100644
index 000000000..10c60ed46
--- /dev/null
+++ b/comfy_extras/v3/nodes_lt.py
@@ -0,0 +1,528 @@
+from __future__ import annotations
+
+import math
+import sys
+
+import av
+import numpy as np
+import torch
+
+import comfy.model_management
+import comfy.model_sampling
+import comfy.utils
+import node_helpers
+import nodes
+from comfy.ldm.lightricks.symmetric_patchifier import (
+    SymmetricPatchifier,
+    latent_to_pixel_coords,
+)
+from comfy_api.v3 import io
+
+
+def conditioning_get_any_value(conditioning, key, default=None):
+    for t in conditioning:
+        if key in t[1]:
+            return t[1][key]
+    return default
+
+
+def get_noise_mask(latent):
+    noise_mask = latent.get("noise_mask", None)
+    latent_image = latent["samples"]
+    if noise_mask is None:
+        batch_size, _, latent_length, _, _ = latent_image.shape
+        return torch.ones(
+            (batch_size, 1, latent_length, 1, 1),
+            dtype=torch.float32,
+            device=latent_image.device,
+        )
+    return noise_mask.clone()
+
+
+def get_keyframe_idxs(cond):
+    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
+    if keyframe_idxs is None:
+        return None, 0
+    return keyframe_idxs, torch.unique(keyframe_idxs[:, 0]).shape[0]
+
+
+def encode_single_frame(output_file, image_array: np.ndarray, crf):
+    container = av.open(output_file, "w", format="mp4")
+    try:
+        stream = container.add_stream(
+            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
+        )
+        stream.height = image_array.shape[0]
+        stream.width = image_array.shape[1]
+        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
+            format="yuv420p"
+        )
+        container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+    finally:
+        container.close()
+
+
+def decode_single_frame(video_file):
+    container = av.open(video_file)
+    try:
+        stream = next(s for s in container.streams if s.type == "video")
+        frame = next(container.decode(stream))
+    finally:
+        container.close()
+    return frame.to_ndarray(format="rgb24")
+
+
+def preprocess(image: torch.Tensor, crf=29):
+    if crf == 0:
+        return image
+
+    image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
+    with sys.modules['io'].BytesIO() as output_file:
+        encode_single_frame(output_file, image_array, crf)
+        video_bytes = output_file.getvalue()
+    with sys.modules['io'].BytesIO(video_bytes) as video_file:
+        image_array = decode_single_frame(video_file)
+    return torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
+
+
+class EmptyLTXVLatentVideo(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="EmptyLTXVLatentVideo_V3",
+            category="latent/video/ltxv",
+            inputs=[
+                io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
+                io.Int.Input(id="length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
+                io.Int.Input(id="batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, width, height, length, batch_size):
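+        # The LTXV VAE uses 8x temporal and 32x spatial compression into
+        # 128-channel latents; the +1 is the causal first latent frame, which
+        # decodes to a single pixel frame (e.g. length=97 -> 13 latent frames).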
+        latent = torch.zeros(
+            [batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32],
+            device=comfy.model_management.intermediate_device(),
+        )
+        return io.NodeOutput({"samples": latent})
+
+
+class LTXVAddGuide(io.ComfyNodeV3):
+    NUM_PREFIX_FRAMES = 2
+    PATCHIFIER = SymmetricPatchifier(1)
+
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LTXVAddGuide_V3",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input(id="positive"),
+                io.Conditioning.Input(id="negative"),
+                io.Vae.Input(id="vae"),
+                io.Latent.Input(id="latent"),
+                io.Image.Input(
+                    id="image",
+                    tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
+                    "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
+                ),
+                io.Int.Input(
+                    id="frame_idx",
+                    default=0,
+                    min=-9999,
+                    max=9999,
+                    tooltip="Frame index to start the conditioning at. "
+                    "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
+                    "For videos with 9+ frames, frame_idx must be 0 or of the form 8*n + 1, otherwise it will be "
+                    "rounded down to the nearest 8*n + 1. Negative values are counted from the end of the video.",
+                ),
+                io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01),
+            ],
+            outputs=[
+                io.Conditioning.Output(id="positive_out", display_name="positive"),
+                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Latent.Output(id="latent_out", display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength):
+        scale_factors = vae.downscale_index_formula
+        latent_image = latent["samples"]
+        noise_mask = get_noise_mask(latent)
+
+        _, _, latent_length, latent_height, latent_width = latent_image.shape
+        image, t = cls._encode(vae, latent_width, latent_height, image, scale_factors)
+
+        frame_idx, latent_idx = cls._get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
+        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
+
+        num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
+
+        positive, negative, latent_image, noise_mask = cls._append_keyframe(
+            positive,
+            negative,
+            frame_idx,
+            latent_image,
+            noise_mask,
+            t[:, :, :num_prefix_frames],
+            strength,
+            scale_factors,
+        )
+
+        latent_idx += num_prefix_frames
+
+        t = t[:, :, num_prefix_frames:]
+        if t.shape[2] == 0:
+            return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
+
+        latent_image, noise_mask = cls._replace_latent_frames(
+            latent_image,
+            noise_mask,
+            t,
+            latent_idx,
+            strength,
+        )
+
+        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
+
+    @classmethod
+    def _encode(cls, vae, latent_width, latent_height, images, scale_factors):
+        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
+        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
+        pixels = comfy.utils.common_upscale(
+            images.movedim(-1, 1),
+            latent_width * width_scale_factor,
+            latent_height * height_scale_factor,
+            "bilinear",
+            crop="disabled",
+        ).movedim(1, -1)
+        encode_pixels = pixels[:, :, :, :3]
+        t = vae.encode(encode_pixels)
+        return encode_pixels, t
+
+    @classmethod
+    def _get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
+        time_scale_factor, _, _ = scale_factors
+        _, num_keyframes = get_keyframe_idxs(cond)
+        latent_count = latent_length - num_keyframes
+        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
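+        # Multi-frame guides must start on a latent-frame boundary, so a
+        # non-zero frame_idx is rounded down to the previous 8*n + 1 position.
+        # The returned latent index is a ceil division: latent frame 0 is
+        # causal (one pixel frame), each later one covers time_scale_factor frames.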
+        if guide_length > 1 and frame_idx != 0:
+            frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1
+        return frame_idx, (frame_idx + time_scale_factor - 1) // time_scale_factor
+
+    @classmethod
+    def _add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
+        keyframe_idxs, _ = get_keyframe_idxs(cond)
+        _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
+        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)
+        pixel_coords[:, 0] += frame_idx
+        if keyframe_idxs is None:
+            keyframe_idxs = pixel_coords
+        else:
+            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
+        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
+
+    @classmethod
+    def _append_keyframe(
+        cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors
+    ):
+        _, latent_idx = cls._get_latent_index(
+            cond=positive,
+            latent_length=latent_image.shape[2],
+            guide_length=guiding_latent.shape[2],
+            frame_idx=frame_idx,
+            scale_factors=scale_factors,
+        )
+        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
+
+        positive = cls._add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
+        negative = cls._add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
+        return positive, negative, latent_image, torch.cat([noise_mask, mask], dim=2)
+
+    @classmethod
+    def _replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
+        cond_length = guiding_latent.shape[2]
+        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
+
+        mask = torch.full(
+            (noise_mask.shape[0], 1, cond_length, 1, 1),
+            1.0 - strength,
+            dtype=noise_mask.dtype,
+            device=noise_mask.device,
+        )
+
+        latent_image = latent_image.clone()
+        noise_mask = noise_mask.clone()
+
+        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
+        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
+
+        return latent_image, noise_mask
+
+
+class LTXVConditioning(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LTXVConditioning_V3",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input(id="positive"),
+                io.Conditioning.Input(id="negative"),
+                io.Float.Input(id="frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
+            ],
+            outputs=[
+                io.Conditioning.Output(id="positive_out", display_name="positive"),
+                io.Conditioning.Output(id="negative_out", display_name="negative"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, frame_rate):
+        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
+        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
+        return io.NodeOutput(positive, negative)
+
+
+class LTXVCropGuides(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LTXVCropGuides_V3",
+            category="conditioning/video_models",
+            inputs=[
+                io.Conditioning.Input(id="positive"),
+                io.Conditioning.Input(id="negative"),
+                io.Latent.Input(id="latent"),
+            ],
+            outputs=[
+                io.Conditioning.Output(id="positive_out", display_name="positive"),
+                io.Conditioning.Output(id="negative_out", display_name="negative"),
+                io.Latent.Output(id="latent_out", display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, positive, negative, latent):
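+        # LTXVAddGuide appends guide latents (and matching noise-mask entries)
+        # past the original sequence and tracks their true position via
+        # "keyframe_idxs"; strip them here so the latent is plain video again.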
latent["samples"].clone() + noise_mask = get_noise_mask(latent) + + _, num_keyframes = get_keyframe_idxs(positive) + if num_keyframes == 0: + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + latent_image = latent_image[:, :, :-num_keyframes] + noise_mask = noise_mask[:, :, :-num_keyframes] + + positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) + negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + + return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) + + +class LTXVImgToVideo(io.ComfyNodeV3): + @classmethod + def define_schema(cls): + return io.SchemaV3( + node_id="LTXVImgToVideo_V3", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input(id="positive"), + io.Conditioning.Input(id="negative"), + io.Vae.Input(id="vae"), + io.Image.Input(id="image"), + io.Int.Input(id="width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input(id="height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input(id="length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input(id="batch_size", default=1, min=1, max=4096), + io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0), + ], + outputs=[ + io.Conditioning.Output(id="positive_out", display_name="positive"), + io.Conditioning.Output(id="negative_out", display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength): + pixels = comfy.utils.common_upscale( + image.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + latent = torch.zeros( + [batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], + device=comfy.model_management.intermediate_device(), + ) + latent[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch_size, 1, latent.shape[2], 1, 1), + dtype=torch.float32, + device=latent.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask}) + + +class LTXVPreprocess(io.ComfyNodeV3): + @classmethod + def define_schema(cls): + return io.SchemaV3( + node_id="LTXVPreprocess_V3", + category="image", + inputs=[ + io.Image.Input(id="image"), + io.Int.Input( + id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply on image." 
+        conditioning_latent_frames_mask = torch.ones(
+            (batch_size, 1, latent.shape[2], 1, 1),
+            dtype=torch.float32,
+            device=latent.device,
+        )
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
+
+        return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})
+
+
+class LTXVPreprocess(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LTXVPreprocess_V3",
+            category="image",
+            inputs=[
+                io.Image.Input(id="image"),
+                io.Int.Input(
+                    id="img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply to the image."
+                ),
+            ],
+            outputs=[
+                io.Image.Output(id="output_image", display_name="output_image"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image, img_compression):
+        output_images = []
+        for i in range(image.shape[0]):
+            output_images.append(preprocess(image[i], img_compression))
+        return io.NodeOutput(torch.stack(output_images))
+
+
+class LTXVScheduler(io.ComfyNodeV3):
+    @classmethod
+    def define_schema(cls):
+        return io.SchemaV3(
+            node_id="LTXVScheduler_V3",
+            category="sampling/custom_sampling/schedulers",
+            inputs=[
+                io.Int.Input(id="steps", default=20, min=1, max=10000),
+                io.Float.Input(id="max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
+                io.Float.Input(id="base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
+                io.Boolean.Input(
+                    id="stretch",
+                    default=True,
+                    tooltip="Stretch the sigmas to be in the range [terminal, 1].",
+                ),
+                io.Float.Input(
+                    id="terminal",
+                    default=0.1,
+                    min=0.0,
+                    max=0.99,
+                    step=0.01,
+                    tooltip="The terminal value of the sigmas after stretching.",
+                ),
+                io.Latent.Input(id="latent", optional=True),
+            ],
+            outputs=[
+                io.Sigmas.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None):
+        if latent is None:
+            tokens = 4096
+        else:
+            tokens = math.prod(latent["samples"].shape[2:])
+
+        sigmas = torch.linspace(1.0, 0.0, steps + 1)
+
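+        # Resolution-dependent sigma shift: interpolate linearly from
+        # base_shift at 1024 latent tokens to max_shift at 4096 tokens
+        # (extrapolated outside that range), mirroring ModelSamplingLTXV below.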
"v3/nodes_primitive.py",