diff --git a/README.md b/README.md index d004364ee..a148623cd 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,13 @@ For models compatible with Cambricon Extension for PyTorch (torch_mlu). Here's a 2. Next, install the PyTorch(torch_mlu) following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) 3. Launch ComfyUI by running `python main.py` +#### Iluvatar Corex + +For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step guide tailored to your platform and installation method: + +1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) +2. Launch ComfyUI by running `python main.py` + # Running ```python main.py``` diff --git a/app/frontend_management.py b/app/frontend_management.py index 001ebbecb..0bee73685 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -29,18 +29,48 @@ def frontend_install_warning_message(): This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. """.strip() +def parse_version(version: str) -> tuple[int, int, int]: + return tuple(map(int, version.split("."))) + +def is_valid_version(version: str) -> bool: + """Validate if a string is a valid semantic version (X.Y.Z format).""" + pattern = r"^(\d+)\.(\d+)\.(\d+)$" + return bool(re.match(pattern, version)) + +def get_installed_frontend_version(): + """Get the currently installed frontend package version.""" + frontend_version_str = version("comfyui-frontend-package") + return frontend_version_str + +def get_required_frontend_version(): + """Get the required frontend version from requirements.txt.""" + try: + with open(requirements_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line.startswith("comfyui-frontend-package=="): + version_str = line.split("==")[-1] + if not is_valid_version(version_str): + logging.error(f"Invalid version format in requirements.txt: {version_str}") + return None + return version_str + logging.error("comfyui-frontend-package not found in requirements.txt") + return None + except FileNotFoundError: + logging.error("requirements.txt not found. Cannot determine required frontend version.") + return None + except Exception as e: + logging.error(f"Error reading requirements.txt: {e}") + return None def check_frontend_version(): """Check if the frontend version is up to date.""" - def parse_version(version: str) -> tuple[int, int, int]: - return tuple(map(int, version.split("."))) - try: - frontend_version_str = version("comfyui-frontend-package") + frontend_version_str = get_installed_frontend_version() frontend_version = parse_version(frontend_version_str) - with open(requirements_path, "r", encoding="utf-8") as f: - required_frontend = parse_version(f.readline().split("=")[-1]) + required_frontend_str = get_required_frontend_version() + required_frontend = parse_version(required_frontend_str) if frontend_version < required_frontend: app.logger.log_startup_warning( f""" @@ -168,6 +198,11 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions") + @classmethod + def get_required_frontend_version(cls) -> str: + """Get the required frontend package version.""" + return get_required_frontend_version() + @classmethod def default_frontend_path(cls) -> str: try: diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2ed415b1f..a2bc492fd 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1210,39 +1210,21 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, return x_next -@torch.no_grad() -def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): - extra_args = {} if extra_args is None else extra_args - - temp = [0] - def post_cfg_function(args): - temp[0] = args["uncond_denoised"] - return args["denoised"] - - model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) - - s_in = x.new_ones([x.shape[0]]) - for i in trange(len(sigmas) - 1, disable=disable): - sigma_hat = sigmas[i] - denoised = model(x, sigma_hat * s_in, **extra_args) - d = to_d(x, sigma_hat, temp[0]) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) - # Euler method - x = denoised + d * sigmas[i + 1] - return x - @torch.no_grad() def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - """Ancestral sampling with Euler method steps.""" + """Ancestral sampling with Euler method steps (CFG++).""" extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - temp = [0] + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + + uncond_denoised = None + def post_cfg_function(args): - temp[0] = args["uncond_denoised"] + nonlocal uncond_denoised + uncond_denoised = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() @@ -1251,15 +1233,33 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - d = to_d(x, sigmas[i], temp[0]) - # Euler method - x = denoised + d * sigma_down - if sigmas[i + 1] > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp() + alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp() + d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise + + # DDIM stochastic sampling + sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta) + sigma_down = alpha_t * sigma_down + + # Euler method + x = alpha_t * denoised + sigma_down * d + if eta > 0 and s_noise > 0: + x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + + +@torch.no_grad() +def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Euler method steps (CFG++).""" + return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None) + + @torch.no_grad() def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" diff --git a/comfy/model_management.py b/comfy/model_management.py index e8b9b5c81..9add54ceb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -128,6 +128,11 @@ try: except: mlu_available = False +try: + ixuca_available = hasattr(torch, "corex") +except: + ixuca_available = False + if args.cpu: cpu_state = CPUState.CPU @@ -151,6 +156,12 @@ def is_mlu(): return True return False +def is_ixuca(): + global ixuca_available + if ixuca_available: + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -289,7 +300,7 @@ try: if torch_version_numeric[0] >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if is_intel_xpu() or is_ascend_npu() or is_mlu(): + if is_intel_xpu() or is_ascend_npu() or is_mlu() or is_ixuca(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: @@ -1045,6 +1056,8 @@ def xformers_enabled(): return False if is_mlu(): return False + if is_ixuca(): + return False if directml_enabled: return False return XFORMERS_IS_AVAILABLE @@ -1080,6 +1093,8 @@ def pytorch_attention_flash_attention(): return True if is_amd(): return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention + if is_ixuca(): + return True return False def force_upcast_attention_dtype(): @@ -1205,6 +1220,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_mlu(): return True + if is_ixuca(): + return True + if torch.version.hip: return True @@ -1268,6 +1286,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_ascend_npu(): return True + if is_ixuca(): + return True + if is_amd(): arch = torch.cuda.get_device_properties(device).gcnArchName if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index 560b82be3..b40f920e4 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] +adapter_maps: dict[str, type[WeightAdapterBase]] = { + "LoRA": LoRAAdapter, + "LoHa": LoHaAdapter, + "LoKr": LoKrAdapter, + "OFT": OFTAdapter, + ## We disable not implemented algo for now + # "GLoRA": GLoRAAdapter, + # "BOFT": BOFTAdapter, +} + __all__ = [ "WeightAdapterBase", "WeightAdapterTrainBase", - "adapters" + "adapters", + "adapter_maps", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index b5c7db423..43644b106 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid): def tucker_weight(wa, wb, t): temp = torch.einsum("i j ..., j r -> i r ...", t, wb) return torch.einsum("i j ..., i r -> r j ...", temp, wa) + + +def factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + """ + + if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index ce79abad5..55c97a3af 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -3,7 +3,120 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(w1d, w1u, w2d, w2u, scale) + diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2u @ w2d) + grad_w1u = temp @ w1d.T + grad_w1d = w1u.T @ temp + + temp = grad_out * (w1u @ w1d) + grad_w2u = temp @ w2d.T + grad_w2d = w2u.T @ temp + + del temp + return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None + + +class HadaWeightTucker(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T) + del grad_w, temp + + grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T) + del grad_temp + + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T) + del grad_w, temp + + grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T) + del grad_temp + return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None + + +class LohaDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights + + # Create trainable parameters + self.hada_w1_a = torch.nn.Parameter(w1a) + self.hada_w1_b = torch.nn.Parameter(w1b) + self.hada_w2_a = torch.nn.Parameter(w2a) + self.hada_w2_b = torch.nn.Parameter(w2b) + + self.use_tucker = False + if t1 is not None and t2 is not None: + self.use_tucker = True + self.hada_t1 = torch.nn.Parameter(t1) + self.hada_t2 = torch.nn.Parameter(t2) + else: + # Keep the attributes for consistent access + self.hada_t1 = None + self.hada_t2 = None + + # Store rank and non-trainable alpha + self.rank = w1b.shape[0] + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + + scale = self.alpha / self.rank + if self.use_tucker: + diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + else: + diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + + # Add the scaled difference to the original weight + weight = w.to(diff_weight) + diff_weight.reshape(w.shape) + + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoHaAdapter(WeightAdapterBase): @@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.normal_(mat1, 0.1) + torch.nn.init.constant_(mat2, 0.0) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.normal_(mat3, 0.1) + torch.nn.init.normal_(mat4, 0.01) + return LohaDiff( + (mat1, mat2, alpha, mat3, mat4, None, None, None) + ) + + def to_train(self): + return LohaDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 51233db2d..49b0be55f 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -3,7 +3,77 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) + + +class LokrDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + self.use_tucker = False + if lokr_w1_a is not None: + _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] + rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1] + self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a) + self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b) + self.w1_rebuild = True + self.ranka = rank_a + + if lokr_w2_a is not None: + _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1] + rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1] + self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a) + self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b) + if lokr_t2 is not None: + self.use_tucker = True + self.lokr_t2 = torch.nn.Parameter(lokr_t2) + self.w2_rebuild = True + self.rankb = rank_b + + if lokr_w1 is not None: + self.lokr_w1 = torch.nn.Parameter(lokr_w1) + self.w1_rebuild = False + + if lokr_w2 is not None: + self.lokr_w2 = torch.nn.Parameter(lokr_w2) + self.w2_rebuild = False + + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + @property + def w1(self): + if self.w1_rebuild: + return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka) + else: + return self.lokr_w1 + + @property + def w2(self): + if self.w2_rebuild: + if self.use_tucker: + w2 = torch.einsum( + 'i j k l, j r, i p -> p r k l', + self.lokr_t2, + self.lokr_w2_b, + self.lokr_w2_a + ) + else: + w2 = self.lokr_w2_a @ self.lokr_w2_b + return w2 * (self.alpha / self.rankb) + else: + return self.lokr_w2 + + def __call__(self, w): + diff = torch.kron(self.w1, self.w2) + return w + diff.reshape(w.shape).to(w) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoKrAdapter(WeightAdapterBase): @@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + out1, out2 = factorization(out_dim, rank) + in1, in2 = factorization(in_dim, rank) + mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) + torch.nn.init.constant_(mat1, 0.0) + return LokrDiff( + (mat1, mat2, alpha, None, None, None, None, None, None) + ) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 25009eca3..9d4982083 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,7 +3,58 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization + + +class OFTDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + blocks, rescale, alpha, _ = weights + + # Create trainable parameters + self.oft_blocks = torch.nn.Parameter(blocks) + if rescale is not None: + self.rescale = torch.nn.Parameter(rescale) + self.rescaled = True + else: + self.rescaled = False + self.block_num, self.block_size, _ = blocks.shape + self.constraint = float(alpha) + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + I = torch.eye(self.block_size, device=self.oft_blocks.device) + + ## generate r + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + normed_q = q + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + + ## Apply chunked matmul on weight + _, *shape = w.shape + org_weight = w.to(dtype=r.dtype) + org_weight = org_weight.unflatten(0, (self.block_num, self.block_size)) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + r, + org_weight, + ).flatten(0, 1) + if self.rescaled: + weight = self.rescale * weight + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class OFTAdapter(WeightAdapterBase): @@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + block_size, block_num = factorization(out_dim, rank) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) + return OFTDiff( + (block, None, alpha, None) + ) + + def to_train(self): + return OFTDiff(self.weights) + @classmethod def load( cls, @@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase): blocks = v[0] rescale = v[1] alpha = v[2] + if alpha is None: + alpha = 0 dora_scale = v[3] blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 3d05fdab5..c3aaaee9b 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -20,7 +20,7 @@ import folder_paths import node_helpers from comfy.cli_args import args from comfy.comfy_types.node_typing import IO -from comfy.weight_adapter import adapters +from comfy.weight_adapter import adapters, adapter_maps def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None): class TrainSampler(comfy.samplers.Sampler): - - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback self.batch_size = batch_size self.total_steps = total_steps + self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype @@ -92,8 +92,9 @@ class TrainSampler(comfy.samplers.Sampler): self.loss_callback(loss.item()) pbar.set_postfix({"loss": f"{loss.item():.4f}"}) - self.optimizer.step() - self.optimizer.zero_grad() + if (i+1) % self.grad_acc == 0: + self.optimizer.step() + self.optimizer.zero_grad() torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -419,6 +420,16 @@ class TrainLoraNode: "tooltip": "The batch size to use for training.", }, ), + "grad_accumulation_steps": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 1024, + "step": 1, + "tooltip": "The number of gradient accumulation steps to use for training.", + } + ), "steps": ( IO.INT, { @@ -478,6 +489,17 @@ class TrainLoraNode: ["bf16", "fp32"], {"default": "bf16", "tooltip": "The dtype to use for lora."}, ), + "algorithm": ( + list(adapter_maps.keys()), + {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + ), + "gradient_checkpointing": ( + IO.BOOLEAN, + { + "default": True, + "tooltip": "Use gradient checkpointing for training.", + } + ), "existing_lora": ( folder_paths.get_filename_list("loras") + ["[None]"], { @@ -501,6 +523,7 @@ class TrainLoraNode: positive, batch_size, steps, + grad_accumulation_steps, learning_rate, rank, optimizer, @@ -508,6 +531,8 @@ class TrainLoraNode: seed, training_dtype, lora_dtype, + algorithm, + gradient_checkpointing, existing_lora, ): mp = model.clone() @@ -558,10 +583,8 @@ class TrainLoraNode: if existing_adapter is not None: break else: - # If no existing adapter found, use LoRA - # We will add algo option in the future existing_adapter = None - adapter_cls = adapters[0] + adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: train_adapter = existing_adapter.to_train().to(lora_dtype) @@ -615,8 +638,9 @@ class TrainLoraNode: criterion = torch.nn.SmoothL1Loss() # setup models - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): - patch(m) + if gradient_checkpointing: + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) mp.model.requires_grad_(False) comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) @@ -629,7 +653,8 @@ class TrainLoraNode: optimizer, loss_callback=loss_callback, batch_size=batch_size, - total_steps=steps, + grad_acc=grad_accumulation_steps, + total_steps=steps*grad_accumulation_steps, seed=seed, training_dtype=dtype ) diff --git a/cuda_malloc.py b/cuda_malloc.py index eb2857c5f..c1d9ae3ca 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -74,7 +74,8 @@ if not args.cuda_malloc: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) version = module.__version__ - if int(version[0]) >= 2: #enable by default for torch version 2.0 and up + + if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch args.cuda_malloc = cuda_malloc_supported() except: pass diff --git a/server.py b/server.py index 6f801d66b..db6b11ad6 100644 --- a/server.py +++ b/server.py @@ -554,6 +554,7 @@ class PromptServer(): ram_free = comfy.model_management.get_free_memory(cpu_device) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + required_frontend_version = FrontendManager.get_required_frontend_version() system_stats = { "system": { @@ -561,6 +562,7 @@ class PromptServer(): "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, + "required_frontend_version": required_frontend_version, "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py index ce67df6c6..ce43ac564 100644 --- a/tests-unit/app_test/frontend_manager_test.py +++ b/tests-unit/app_test/frontend_manager_test.py @@ -1,7 +1,7 @@ import argparse import pytest from requests.exceptions import HTTPError -from unittest.mock import patch +from unittest.mock import patch, mock_open from app.frontend_management import ( FrontendManager, @@ -172,3 +172,36 @@ def test_init_frontend_fallback_on_error(): # Assert assert frontend_path == "/default/path" mock_check.assert_called_once() + + +def test_get_frontend_version(): + # Arrange + expected_version = "1.25.0" + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.25.0 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_frontend_version() + + # Assert + assert version == expected_version + + +def test_get_frontend_version_invalid_semver(): + # Arrange + mock_requirements_content = """torch +torchsde +comfyui-frontend-package==1.29.3.75 +other-package==1.0.0 +numpy""" + + # Act + with patch("builtins.open", mock_open(read_data=mock_requirements_content)): + version = FrontendManager.get_required_frontend_version() + + # Assert + assert version is None