SD3 Support.

This commit is contained in:
comfyanonymous
2024-06-10 13:26:25 -04:00
parent a5e6a632f9
commit 8c4a9befa7
17 changed files with 132182 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ from . import utils
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import supported_models_base
from . import latent_formats
@@ -488,6 +489,28 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
"pos_embed_scaling_factor": None,
}
sampling_settings = {
"shift": 3.0,
}
unet_extra_config = {}
latent_format = latent_formats.SD3
text_encoder_key_prefix = ["text_encoders."] #TODO?
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD3(self, device=device)
return out
def clip_target(self):
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.SD3ClipModel) #TODO?
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3]
models += [SVD_img2vid]