mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Initialize the unet directly on the target device.
This commit is contained in:
@@ -12,14 +12,14 @@ class ModelType(Enum):
|
||||
V_PREDICTION = 2
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__()
|
||||
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.diffusion_model = UNetModel(**unet_config)
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SD21UNCLIP(BaseModel):
|
||||
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
|
||||
super().__init__(model_config, model_type)
|
||||
def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
|
||||
return adm_out
|
||||
|
||||
class SDInpaint(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS):
|
||||
super().__init__(model_config, model_type)
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.concat_keys = ("mask", "masked_image")
|
||||
|
||||
class SDXLRefiner(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS):
|
||||
super().__init__(model_config, model_type)
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
@@ -174,8 +174,8 @@ class SDXLRefiner(BaseModel):
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
class SDXL(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS):
|
||||
super().__init__(model_config, model_type)
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.embedder = Timestep(256)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
|
Reference in New Issue
Block a user