Native LotusD Implementation (#7125)

* draft pass at a native comfy implementation of Lotus-D depth and normal est

* fix model_sampling kludges

* fix ruff

---------

Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
This commit is contained in:
thot experiment
2025-03-21 11:04:15 -07:00
committed by GitHub
parent 0cf2274699
commit 83e839a89b
6 changed files with 74 additions and 3 deletions

View File

@@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
class LotusD(SD20):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"use_temporal_attention": False,
"adm_in_channels": 4,
"in_channels": 4,
}
unet_extra_config = {
"num_classes": 'sequential'
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lotus(self, device=device)
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
@@ -997,6 +1013,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]