Support SVD img2vid model.

This commit is contained in:
comfyanonymous
2023-11-23 19:41:33 -05:00
parent 022033a0e7
commit 871cc20e13
11 changed files with 1030 additions and 100 deletions

View File

@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@@ -203,8 +209,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]