diff --git a/nodes.py b/nodes.py index 8e6825292..abb60675b 100644 --- a/nodes.py +++ b/nodes.py @@ -63,31 +63,22 @@ class ConditioningAverage : @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), - "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}), - #"conditioning_to_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}) + "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" CATEGORY = "conditioning" - #def applyConditions(self, conditioning_from, conditioning_to, conditioning_from_strength, conditioning_to_strength): - # c = [] - # for t in conditioning_from: - # averaged = self.averageConditioning(t[0], conditioning_to, conditioning_from_strength, conditioning_to_strength) - # n = [averaged, t[1].clone()] - # c.append(n) - # return (c, ) - def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): conditioning_to_strength = (1-conditioning_from_strength) conditioning_from_tensor = conditioning_from[0][0] conditioning_to_tensor = conditioning_to[0][0] output = conditioning_from - if conditioning_from_tensor.shape[1] > conditioning_to_tensor.shape[1]: - conditioning_to_tensor = torch.cat((conditioning_to_tensor, torch.zeros((1,conditioning_from_tensor.shape[1] - conditioning_to_tensor.shape[1],768))), dim=1) + if conditioning_from_tensor.shape[0] > conditioning_to_tensor.shape[1]: + conditioning_to_tensor = torch.cat((conditioning_to_tensor, torch.zeros((1, conditioning_from_tensor.shape[1] - conditioning_to_tensor.shape[1], conditioning_from_tensor.shape[1].value))), dim=1) elif conditioning_to_tensor.shape[1] > conditioning_from_tensor.shape[1]: - conditioning_from_tensor = torch.cat((conditioning_from_tensor, torch.zeros((conditioning_to_tensor.shape[1].value,conditioning_to_tensor.shape[1] - conditioning_from_tensor.shape[1],conditioning_from_tensor.shape[1].value))), dim=1) + conditioning_from_tensor = torch.cat((conditioning_from_tensor, torch.zeros((1, conditioning_to_tensor.shape[1] - conditioning_from_tensor.shape[1], conditioning_to_tensor.shape[1].value))), dim=1) output[0][0] = ((conditioning_from_tensor * conditioning_from_strength) + (conditioning_to_tensor * conditioning_to_strength)) return (output, )