From 66d35c07ce44b07011314ad7a28b2bdbcbb4e4cc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 Jul 2024 20:27:40 -0400 Subject: [PATCH] Improve artifacts on hydit, auraflow and SD3 on specific resolutions. This breaks seeds for resolutions that are not a multiple of 16 in pixel resolution by using circular padding instead of reflection padding but should lower the amount of artifacts when doing img2img at those resolutions. --- comfy/ldm/aura/mmdit.py | 2 +- comfy/ldm/modules/diffusionmodules/mmdit.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index c465619b..2564166a 100644 --- a/comfy/ldm/aura/mmdit.py +++ b/comfy/ldm/aura/mmdit.py @@ -409,7 +409,7 @@ class MMDiT(nn.Module): pad_h = (self.patch_size - H % self.patch_size) % self.patch_size pad_w = (self.patch_size - W % self.patch_size) % self.patch_size - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular') x = x.view( B, C, diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index f37f7ff7..aac48a7f 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -69,12 +69,14 @@ class PatchEmbed(nn.Module): bias: bool = True, strict_img_size: bool = True, dynamic_img_pad: bool = True, + padding_mode='circular', dtype=None, device=None, operations=None, ): super().__init__() self.patch_size = (patch_size, patch_size) + self.padding_mode = padding_mode if img_size is not None: self.img_size = (img_size, img_size) self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) @@ -110,7 +112,7 @@ class PatchEmbed(nn.Module): if self.dynamic_img_pad: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode) x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # NCHW -> NLC