diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index 560b82be3..b40f920e4 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] +adapter_maps: dict[str, type[WeightAdapterBase]] = { + "LoRA": LoRAAdapter, + "LoHa": LoHaAdapter, + "LoKr": LoKrAdapter, + "OFT": OFTAdapter, + ## We disable not implemented algo for now + # "GLoRA": GLoRAAdapter, + # "BOFT": BOFTAdapter, +} + __all__ = [ "WeightAdapterBase", "WeightAdapterTrainBase", - "adapters" + "adapters", + "adapter_maps", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index b5c7db423..43644b106 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid): def tucker_weight(wa, wb, t): temp = torch.einsum("i j ..., j r -> i r ...", t, wb) return torch.einsum("i j ..., i r -> r j ...", temp, wa) + + +def factorization(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + """ + + if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index ce79abad5..55c97a3af 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -3,7 +3,120 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose + + +class HadaWeight(torch.autograd.Function): + @staticmethod + def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(w1d, w1u, w2d, w2u, scale) + diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2u @ w2d) + grad_w1u = temp @ w1d.T + grad_w1d = w1u.T @ temp + + temp = grad_out * (w1u @ w1d) + grad_w2u = temp @ w2d.T + grad_w2d = w2u.T @ temp + + del temp + return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None + + +class HadaWeightTucker(torch.autograd.Function): + @staticmethod + def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)): + ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T) + del grad_w, temp + + grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T) + del grad_temp + + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T) + del grad_w, temp + + grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T) + del grad_temp + return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None + + +class LohaDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights + + # Create trainable parameters + self.hada_w1_a = torch.nn.Parameter(w1a) + self.hada_w1_b = torch.nn.Parameter(w1b) + self.hada_w2_a = torch.nn.Parameter(w2a) + self.hada_w2_b = torch.nn.Parameter(w2b) + + self.use_tucker = False + if t1 is not None and t2 is not None: + self.use_tucker = True + self.hada_t1 = torch.nn.Parameter(t1) + self.hada_t2 = torch.nn.Parameter(t2) + else: + # Keep the attributes for consistent access + self.hada_t1 = None + self.hada_t2 = None + + # Store rank and non-trainable alpha + self.rank = w1b.shape[0] + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + + scale = self.alpha / self.rank + if self.use_tucker: + diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + else: + diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + + # Add the scaled difference to the original weight + weight = w.to(diff_weight) + diff_weight.reshape(w.shape) + + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoHaAdapter(WeightAdapterBase): @@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.normal_(mat1, 0.1) + torch.nn.init.constant_(mat2, 0.0) + mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.normal_(mat3, 0.1) + torch.nn.init.normal_(mat4, 0.01) + return LohaDiff( + (mat1, mat2, alpha, mat3, mat4, None, None, None) + ) + + def to_train(self): + return LohaDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 51233db2d..49b0be55f 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -3,7 +3,77 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) + + +class LokrDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + self.use_tucker = False + if lokr_w1_a is not None: + _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] + rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1] + self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a) + self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b) + self.w1_rebuild = True + self.ranka = rank_a + + if lokr_w2_a is not None: + _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1] + rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1] + self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a) + self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b) + if lokr_t2 is not None: + self.use_tucker = True + self.lokr_t2 = torch.nn.Parameter(lokr_t2) + self.w2_rebuild = True + self.rankb = rank_b + + if lokr_w1 is not None: + self.lokr_w1 = torch.nn.Parameter(lokr_w1) + self.w1_rebuild = False + + if lokr_w2 is not None: + self.lokr_w2 = torch.nn.Parameter(lokr_w2) + self.w2_rebuild = False + + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + @property + def w1(self): + if self.w1_rebuild: + return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka) + else: + return self.lokr_w1 + + @property + def w2(self): + if self.w2_rebuild: + if self.use_tucker: + w2 = torch.einsum( + 'i j k l, j r, i p -> p r k l', + self.lokr_t2, + self.lokr_w2_b, + self.lokr_w2_a + ) + else: + w2 = self.lokr_w2_a @ self.lokr_w2_b + return w2 * (self.alpha / self.rankb) + else: + return self.lokr_w2 + + def __call__(self, w): + diff = torch.kron(self.w1, self.w2) + return w + diff.reshape(w.shape).to(w) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) class LoKrAdapter(WeightAdapterBase): @@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + out1, out2 = factorization(out_dim, rank) + in1, in2 = factorization(in_dim, rank) + mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) + torch.nn.init.constant_(mat1, 0.0) + return LokrDiff( + (mat1, mat2, alpha, None, None, None, None, None, None) + ) + @classmethod def load( cls, diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 25009eca3..9d4982083 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,7 +3,58 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, weight_decompose +from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization + + +class OFTDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + # Unpack weights tuple from LoHaAdapter + blocks, rescale, alpha, _ = weights + + # Create trainable parameters + self.oft_blocks = torch.nn.Parameter(blocks) + if rescale is not None: + self.rescale = torch.nn.Parameter(rescale) + self.rescaled = True + else: + self.rescaled = False + self.block_num, self.block_size, _ = blocks.shape + self.constraint = float(alpha) + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + I = torch.eye(self.block_size, device=self.oft_blocks.device) + + ## generate r + # for Q = -Q^T + q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + normed_q = q + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + # use float() to prevent unsupported type + r = (I + normed_q) @ (I - normed_q).float().inverse() + + ## Apply chunked matmul on weight + _, *shape = w.shape + org_weight = w.to(dtype=r.dtype) + org_weight = org_weight.unflatten(0, (self.block_num, self.block_size)) + # Init R=0, so add I on it to ensure the output of step0 is original model output + weight = torch.einsum( + "k n m, k n ... -> k m ...", + r, + org_weight, + ).flatten(0, 1) + if self.rescaled: + weight = self.rescale * weight + return weight.to(org_dtype) + + def passive_memory_usage(self): + """Calculates memory usage of the trainable parameters.""" + return sum(param.numel() * param.element_size() for param in self.parameters()) class OFTAdapter(WeightAdapterBase): @@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + block_size, block_num = factorization(out_dim, rank) + block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype) + return OFTDiff( + (block, None, alpha, None) + ) + + def to_train(self): + return OFTDiff(self.weights) + @classmethod def load( cls, @@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase): blocks = v[0] rescale = v[1] alpha = v[2] + if alpha is None: + alpha = 0 dora_scale = v[3] blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 3d05fdab5..c3aaaee9b 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -20,7 +20,7 @@ import folder_paths import node_helpers from comfy.cli_args import args from comfy.comfy_types.node_typing import IO -from comfy.weight_adapter import adapters +from comfy.weight_adapter import adapters, adapter_maps def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None): class TrainSampler(comfy.samplers.Sampler): - - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback self.batch_size = batch_size self.total_steps = total_steps + self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype @@ -92,8 +92,9 @@ class TrainSampler(comfy.samplers.Sampler): self.loss_callback(loss.item()) pbar.set_postfix({"loss": f"{loss.item():.4f}"}) - self.optimizer.step() - self.optimizer.zero_grad() + if (i+1) % self.grad_acc == 0: + self.optimizer.step() + self.optimizer.zero_grad() torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -419,6 +420,16 @@ class TrainLoraNode: "tooltip": "The batch size to use for training.", }, ), + "grad_accumulation_steps": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 1024, + "step": 1, + "tooltip": "The number of gradient accumulation steps to use for training.", + } + ), "steps": ( IO.INT, { @@ -478,6 +489,17 @@ class TrainLoraNode: ["bf16", "fp32"], {"default": "bf16", "tooltip": "The dtype to use for lora."}, ), + "algorithm": ( + list(adapter_maps.keys()), + {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + ), + "gradient_checkpointing": ( + IO.BOOLEAN, + { + "default": True, + "tooltip": "Use gradient checkpointing for training.", + } + ), "existing_lora": ( folder_paths.get_filename_list("loras") + ["[None]"], { @@ -501,6 +523,7 @@ class TrainLoraNode: positive, batch_size, steps, + grad_accumulation_steps, learning_rate, rank, optimizer, @@ -508,6 +531,8 @@ class TrainLoraNode: seed, training_dtype, lora_dtype, + algorithm, + gradient_checkpointing, existing_lora, ): mp = model.clone() @@ -558,10 +583,8 @@ class TrainLoraNode: if existing_adapter is not None: break else: - # If no existing adapter found, use LoRA - # We will add algo option in the future existing_adapter = None - adapter_cls = adapters[0] + adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: train_adapter = existing_adapter.to_train().to(lora_dtype) @@ -615,8 +638,9 @@ class TrainLoraNode: criterion = torch.nn.SmoothL1Loss() # setup models - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): - patch(m) + if gradient_checkpointing: + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) mp.model.requires_grad_(False) comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) @@ -629,7 +653,8 @@ class TrainLoraNode: optimizer, loss_callback=loss_callback, batch_size=batch_size, - total_steps=steps, + grad_acc=grad_accumulation_steps, + total_steps=steps*grad_accumulation_steps, seed=seed, training_dtype=dtype )