Refactor comfy.ops

comfy.ops -> comfy.ops.disable_weight_init

This should make it more clear what they actually do.

Some unused code has also been removed.
This commit is contained in:
comfyanonymous
2023-12-11 23:27:13 -05:00
parent b0aab1e4ea
commit 77755ab8db
10 changed files with 94 additions and 170 deletions

View File

@@ -208,9 +208,9 @@ class ControlLoraOps:
def forward(self, input):
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)
class Conv2d(torch.nn.Module):
def __init__(
@@ -247,24 +247,9 @@ class ControlLoraOps:
def forward(self, input):
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class Conv3d(comfy.ops.Conv3d):
pass
class GroupNorm(comfy.ops.GroupNorm):
pass
class LayerNorm(comfy.ops.LayerNorm):
pass
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
@@ -278,7 +263,9 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
pass
controlnet_config["operations"] = control_lora_ops
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype()
self.control_model.to(dtype)