mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 20:48:22 +00:00
Refactor comfy.ops
comfy.ops -> comfy.ops.disable_weight_init This should make it more clear what they actually do. Some unused code has also been removed.
This commit is contained in:
@@ -208,9 +208,9 @@ class ControlLoraOps:
|
||||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
||||
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -247,24 +247,9 @@ class ControlLoraOps:
|
||||
|
||||
def forward(self, input):
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def conv_nd(self, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return self.Conv2d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
class Conv3d(comfy.ops.Conv3d):
|
||||
pass
|
||||
|
||||
class GroupNorm(comfy.ops.GroupNorm):
|
||||
pass
|
||||
|
||||
class LayerNorm(comfy.ops.LayerNorm):
|
||||
pass
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
@@ -278,7 +263,9 @@ class ControlLora(ControlNet):
|
||||
controlnet_config = model.model_config.unet_config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
controlnet_config["operations"] = ControlLoraOps()
|
||||
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
|
||||
pass
|
||||
controlnet_config["operations"] = control_lora_ops
|
||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
dtype = model.get_dtype()
|
||||
self.control_model.to(dtype)
|
||||
|
Reference in New Issue
Block a user