Merge branch 'master' into v3-definition-wip

This commit is contained in:
Jedrzej Kosinski 2025-06-27 11:30:15 -07:00
commit 533090465c

View File

@ -11,7 +11,7 @@ from comfy.ldm.modules.ema import LitEma
import comfy.ops import comfy.ops
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True): def __init__(self, sample: bool = False):
super().__init__() super().__init__()
self.sample = sample self.sample = sample
@ -19,16 +19,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
yield from () yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z) posterior = DiagonalGaussianDistribution(z)
if self.sample: if self.sample:
z = posterior.sample() z = posterior.sample()
else: else:
z = posterior.mode() z = posterior.mode()
kl_loss = posterior.kl() return z, None
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module): class AbstractAutoencoder(torch.nn.Module):