Add option to use in place weight updating in ModelPatcher.

This commit is contained in:
comfyanonymous
2023-11-11 01:03:39 -05:00
parent 412d3ff57d
commit 4a8a839b40
2 changed files with 24 additions and 5 deletions

View File

@@ -261,6 +261,14 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], torch.nn.Parameter(value))
del prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs: