mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-14 21:45:06 +00:00
Basic hunyuan dit implementation. (#4102)
* Let tokenizers return weights to be stored in the saved checkpoint. * Basic hunyuan dit implementation. * Fix some resolutions not working. * Support hydit checkpoint save. * Init with right dtype. * Switch to optimized attention in pooler. * Fix black images on hunyuan dit.
This commit is contained in:
@@ -7,6 +7,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
|
||||
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.model_management
|
||||
@@ -648,3 +649,35 @@ class StableAudio1(BaseModel):
|
||||
for l in s:
|
||||
sd["{}{}".format(k, l)] = s[l]
|
||||
return sd
|
||||
|
||||
class HunyuanDiT(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['text_embedding_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
|
||||
if conditioning_mt5xl is not None:
|
||||
out['encoder_hidden_states_t5'] = comfy.conds.CONDRegular(conditioning_mt5xl)
|
||||
|
||||
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
|
||||
if attention_mask_mt5xl is not None:
|
||||
out['text_embedding_mask_t5'] = comfy.conds.CONDRegular(attention_mask_mt5xl)
|
||||
|
||||
width = kwargs.get("width", 768)
|
||||
height = kwargs.get("height", 768)
|
||||
crop_w = kwargs.get("crop_w", 0)
|
||||
crop_h = kwargs.get("crop_h", 0)
|
||||
target_width = kwargs.get("target_width", width)
|
||||
target_height = kwargs.get("target_height", height)
|
||||
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
Reference in New Issue
Block a user