Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-09-13 21:16:09 +00:00

Support SVD img2vid model.
@@ -10,17 +10,22 @@ from . import utils

 class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
+    V_PREDICTION_EDM = 3


-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


 def model_sampling(model_config, model_type):
+    s = ModelSamplingDiscrete
+
     if model_type == ModelType.EPS:
         c = EPS
     elif model_type == ModelType.V_PREDICTION:
         c = V_PREDICTION
-        s = ModelSamplingDiscrete
+    elif model_type == ModelType.V_PREDICTION_EDM:
+        c = V_PREDICTION
+        s = ModelSamplingContinuousEDM

     class ModelSampling(s, c):
         pass
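
A note on the `model_sampling` change above: it now picks a schedule class (`s`, either `ModelSamplingDiscrete` or `ModelSamplingContinuousEDM`) and a prediction class (`c`, `EPS` or `V_PREDICTION`) and fuses them into one object through multiple inheritance. A minimal, self-contained sketch of that composition pattern, using hypothetical stand-in classes rather than ComfyUI's real ones:

class FakeDiscreteSchedule:
    # Hypothetical stand-in for ModelSamplingDiscrete: maps a timestep to a sigma.
    def sigma(self, timestep):
        return 1.0 - timestep / 1000.0


class FakeEpsPrediction:
    # Hypothetical stand-in for EPS: turns an eps prediction into a denoised sample.
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_input - model_output * sigma


def make_model_sampling(schedule_cls, prediction_cls):
    # Same trick as model_sampling() in the hunk: compose the two roles into
    # one class at runtime via multiple inheritance.
    class ModelSampling(schedule_cls, prediction_cls):
        pass

    return ModelSampling()


sampling = make_model_sampling(FakeDiscreteSchedule, FakeEpsPrediction)
print(sampling.sigma(500), sampling.calculate_denoised(0.5, 0.1, 0.2))

Keeping the schedule and the prediction rule as independent pieces is what lets this commit reuse the `V_PREDICTION` math on top of the continuous EDM schedule that the new `V_PREDICTION_EDM` type selects.
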
@@ -262,3 +267,48 @@ class SDXL(BaseModel):
         out.append(self.embedder(torch.Tensor([target_width])))
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
+
+class SVD_img2vid(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.embedder = Timestep(256)
+
+    def encode_adm(self, **kwargs):
+        fps_id = kwargs.get("fps", 6) - 1
+        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
+        augmentation = kwargs.get("augmentation_level", 0)
+
+        out = []
+        out.append(self.embedder(torch.Tensor([fps_id])))
+        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
+        out.append(self.embedder(torch.Tensor([augmentation])))
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
+        return flat
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        adm = self.encode_adm(**kwargs)
+        if adm is not None:
+            out['y'] = comfy.conds.CONDRegular(adm)
+
+        latent_image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if latent_image is None:
+            latent_image = torch.zeros_like(noise)
+
+        if latent_image.shape[1:] != noise.shape[1:]:
+            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
+
+        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
+
+        if "time_conditioning" in kwargs:
+            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
+
+        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
+        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
+        return out
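
For orientation, `SVD_img2vid.encode_adm` above embeds three scalars (fps minus one, the motion bucket id, and the augmentation level) with a 256-dimensional `Timestep` embedder and concatenates them into a single conditioning vector that reaches the UNet as `y`. A rough, self-contained sketch of that idea, assuming a standard sinusoidal embedding in place of ComfyUI's `Timestep` module (the shapes are the point, not the exact numerics):

import math

import torch


def sinusoidal_embedding(value, dim=256, max_period=10000):
    # Standard sinusoidal embedding of a scalar, as used for diffusion timesteps.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = torch.tensor([float(value)])[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # shape (1, dim)


# Hypothetical inputs mirroring the defaults read from kwargs above.
fps_id = 6 - 1
motion_bucket_id = 127
augmentation_level = 0

parts = [sinusoidal_embedding(v) for v in (fps_id, motion_bucket_id, augmentation_level)]
adm = torch.flatten(torch.cat(parts)).unsqueeze(dim=0)
print(adm.shape)  # torch.Size([1, 768]); one vector per sampled video, like 'y' above
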
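`extra_conds` also prepares the latent that is channel-concatenated with the noise (`c_concat`): with no `concat_latent_image` it falls back to zeros, otherwise the image latent is rescaled to the noise's spatial size and repeated across the frame batch. A small sketch of that shape-matching step in plain torch, with assumed stand-ins for `utils.common_upscale` and `utils.repeat_to_batch_size` (their exact behavior is not reproduced here):

import torch
import torch.nn.functional as F


def match_concat_latent(latent_image, noise):
    # No init-image latent supplied: condition on zeros, as extra_conds does.
    if latent_image is None:
        return torch.zeros_like(noise)
    # Assumed stand-in for utils.common_upscale(..., "bilinear", "center"):
    # bring the latent to the noise's spatial resolution.
    if latent_image.shape[1:] != noise.shape[1:]:
        latent_image = F.interpolate(latent_image, size=noise.shape[-2:], mode="bilinear", align_corners=False)
    # Assumed stand-in for utils.repeat_to_batch_size: tile the single image
    # latent across every video frame in the noise batch.
    if latent_image.shape[0] < noise.shape[0]:
        reps = -(-noise.shape[0] // latent_image.shape[0])  # ceiling division
        latent_image = latent_image.repeat(reps, 1, 1, 1)[:noise.shape[0]]
    return latent_image


noise = torch.randn(14, 4, 72, 128)       # 14 frames of video latents (hypothetical shape)
image_latent = torch.randn(1, 4, 64, 64)  # one encoded init image
print(match_concat_latent(image_latent, noise).shape)  # torch.Size([14, 4, 72, 128])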