Merge remote-tracking branch 'origin/master' into tiled-progress

This commit is contained in:
pythongosssss
2023-05-03 17:33:42 +01:00
32 changed files with 900 additions and 356 deletions

View File

@@ -4,7 +4,10 @@
from __future__ import annotations
from collections import OrderedDict
from typing import Literal
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn

View File

@@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength):
activate_output = sd.get('activate_output', False)
last_layer_dropout = sd.get('last_layer_dropout', False)
if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False:
valid_activation = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
return None
@@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength):
keys = attn_weights.keys()
linears = filter(lambda a: a.endswith(".weight"), keys)
linears = sorted(list(map(lambda a: a[:-len(".weight")], linears)))
linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = []
for lin_name in linears:
for i in range(len(linears)):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
lin_weight = attn_weights['{}.weight'.format(lin_name)]
lin_bias = attn_weights['{}.bias'.format(lin_name)]
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
layers += [layer]
layers.append(layer)
if activation_func != "linear":
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output)
@@ -71,7 +93,7 @@ class HypernetworkLoader:
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork"
CATEGORY = "_for_testing"
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)