Initialize the unet directly on the target device.

This commit is contained in:
comfyanonymous
2023-07-29 14:51:56 -04:00
parent ad5866b02b
commit 4b957a0010
6 changed files with 110 additions and 103 deletions

View File

@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
latent_format = latent_formats.SDXL
def get_model(self, state_dict, prefix=""):
return model_base.SDXLRefiner(self)
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}
@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
else:
return model_base.ModelType.EPS
def get_model(self, state_dict, prefix=""):
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
def process_clip_state_dict(self, state_dict):
keys_to_replace = {}