From 974254218ab873e8b9642b6a467d56842cd228c4 Mon Sep 17 00:00:00 2001 From: josephrocca <1167575+josephrocca@users.noreply.github.com> Date: Wed, 9 Jul 2025 03:56:59 +0800 Subject: [PATCH] Un-hardcode chroma patch_size (#8840) --- comfy/ldm/chroma/model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index c75023a31..06021d4f2 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -254,13 +254,12 @@ class Chroma(nn.Module): def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape - patch_size = 2 - x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) - img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) - h_len = ((h + (patch_size // 2)) // patch_size) - w_len = ((w + (patch_size // 2)) // patch_size) + h_len = ((h + (self.patch_size // 2)) // self.patch_size) + w_len = ((w + (self.patch_size // 2)) // self.patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) @@ -268,4 +267,4 @@ class Chroma(nn.Module): txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h,:w]