Make sure cond_concat is on the right device.

This commit is contained in:
comfyanonymous
2023-10-19 01:10:41 -04:00
parent 45c972aba8
commit e6962120c6
2 changed files with 7 additions and 5 deletions

View File

@@ -79,6 +79,7 @@ class BaseModel(torch.nn.Module):
denoise_mask = kwargs.get("denoise_mask", None)
latent_image = kwargs.get("latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
@@ -92,9 +93,9 @@ class BaseModel(torch.nn.Module):
for ck in concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
cond_concat.append(denoise_mask[:,:1].to(device))
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])