Implement Linear hypernetworks.

Add a HypernetworkLoader node to use hypernetworks.
This commit is contained in:
comfyanonymous
2023-04-23 12:35:25 -04:00
parent 6908f9c949
commit 5282f56434
9 changed files with 185 additions and 16 deletions

View File

@@ -254,6 +254,29 @@ class ModelPatcher:
def set_model_sampler_cfg_function(self, sampler_cfg_function):
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_patch(self, patch, name):
to = self.model_options["transformer_options"]
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
def model_dtype(self):
return self.model.diffusion_model.dtype