Remove some trailing white space.

This commit is contained in:
comfyanonymous
2024-12-27 18:02:21 -05:00
parent 9cfd185676
commit d170292594
15 changed files with 37 additions and 38 deletions

View File

@@ -305,7 +305,7 @@ class FeatherMask:
output[:, -y, :] *= feather_rate
return (output,)
class GrowMask:
@classmethod
def INPUT_TYPES(cls):
@@ -316,7 +316,7 @@ class GrowMask:
"tapered_corners": ("BOOLEAN", {"default": True}),
},
}
CATEGORY = "mask"
RETURN_TYPES = ("MASK",)

View File

@@ -64,7 +64,7 @@ class Guider_PerpNeg(comfy.samplers.CFGGuider):
def predict_noise(self, x, timestep, model_options={}, seed=None):
# in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg
# but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly
positive_cond = self.conds.get("positive", None)
negative_cond = self.conds.get("negative", None)
empty_cond = self.conds.get("empty_negative_prompt", None)

View File

@@ -40,7 +40,7 @@ class LatentRebatch:
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
@@ -81,7 +81,7 @@ class LatentRebatch:
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})

View File

@@ -40,9 +40,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
return do_nothing, do_nothing
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
with torch.no_grad():
hsy, wsx = h // sy, w // sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
@@ -50,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
@@ -99,7 +98,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)