mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 03:58:22 +00:00
Support for Control Loras.
Control loras are controlnets where some of the weights are stored in "lora" format: an up and a down low rank matrice that when multiplied together and added to the unet weight give the controlnet weight. This allows a much smaller memory footprint depending on the rank of the matrices. These controlnets are used just like regular ones.
This commit is contained in:
110
comfy/sd.py
110
comfy/sd.py
@@ -844,9 +844,119 @@ class ControlNet(ControlBase):
|
||||
out.append(self.control_model_wrapped)
|
||||
return out
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = None
|
||||
self.up = None
|
||||
self.down = None
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.transposed = False
|
||||
self.output_padding = 0
|
||||
self.groups = groups
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.up = None
|
||||
self.down = None
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def conv_nd(self, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return self.Conv2d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, device=None):
|
||||
ControlBase.__init__(self, device)
|
||||
self.control_weights = control_weights
|
||||
self.global_average_pooling = global_average_pooling
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
controlnet_config = model.model_config.unet_config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
controlnet_config["operations"] = ControlLoraOps()
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
if model_management.should_use_fp16():
|
||||
self.control_model.half()
|
||||
self.control_model.to(model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
try:
|
||||
set_attr(self.control_model, k, sd[k])
|
||||
except:
|
||||
pass
|
||||
|
||||
for k in self.control_weights:
|
||||
if k not in {"lora_controlnet"}:
|
||||
set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
|
||||
|
||||
def copy(self):
|
||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def cleanup(self):
|
||||
del self.control_model
|
||||
self.control_model = None
|
||||
super().cleanup()
|
||||
|
||||
def get_models(self):
|
||||
out = ControlBase.get_models(self)
|
||||
return out
|
||||
|
||||
def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||
if "lora_controlnet" in controlnet_data:
|
||||
return ControlLora(controlnet_data)
|
||||
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
|
Reference in New Issue
Block a user