diff --git a/nodes.py b/nodes.py index abb60675b..e48d3c5bc 100644 --- a/nodes.py +++ b/nodes.py @@ -1,3 +1,5 @@ +from queue import Empty +from turtle import clone import torch import os @@ -5,9 +7,11 @@ import sys import json import hashlib import traceback +import copy from PIL import Image from PIL.PngImagePlugin import PngInfo +from copy import deepcopy import numpy as np @@ -74,12 +78,15 @@ class ConditioningAverage : conditioning_to_strength = (1-conditioning_from_strength) conditioning_from_tensor = conditioning_from[0][0] conditioning_to_tensor = conditioning_to[0][0] - output = conditioning_from - if conditioning_from_tensor.shape[0] > conditioning_to_tensor.shape[1]: - conditioning_to_tensor = torch.cat((conditioning_to_tensor, torch.zeros((1, conditioning_from_tensor.shape[1] - conditioning_to_tensor.shape[1], conditioning_from_tensor.shape[1].value))), dim=1) + + output = copy.deepcopy(conditioning_from) + + if conditioning_from_tensor.shape[1] > conditioning_to_tensor.shape[1]: + conditioning_to_tensor = torch.cat((conditioning_to_tensor, torch.zeros((1, conditioning_from_tensor.shape[1] - conditioning_to_tensor.shape[1], conditioning_from_tensor.shape[1].value))), dim=1) elif conditioning_to_tensor.shape[1] > conditioning_from_tensor.shape[1]: conditioning_from_tensor = torch.cat((conditioning_from_tensor, torch.zeros((1, conditioning_to_tensor.shape[1] - conditioning_from_tensor.shape[1], conditioning_to_tensor.shape[1].value))), dim=1) output[0][0] = ((conditioning_from_tensor * conditioning_from_strength) + (conditioning_to_tensor * conditioning_to_strength)) + return (output, ) class ConditioningSetArea: