Basic Flux Schnell and Flux Dev model implementation.

This commit is contained in:
comfyanonymous
2024-08-01 04:03:59 -04:00
parent 7ad574bffd
commit 1589b58d3e
12 changed files with 624 additions and 3 deletions

View File

@@ -9,6 +9,7 @@ import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
from . import supported_models_base
from . import latent_formats
@@ -619,7 +620,45 @@ class HunyuanDiT1(HunyuanDiT):
"linear_end" : 0.03,
}
class Flux(supported_models_base.BASE):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1]
sampling_settings = {
}
unet_extra_config = {}
latent_format = latent_formats.Flux
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.FluxClipModel)
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": False,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
models += [SVD_img2vid]