Clean up percent start/end and make controlnets work with sigmas.

This commit is contained in:
comfyanonymous
2023-10-31 22:14:32 -04:00
parent a268a574fa
commit 7c0f255de1
3 changed files with 26 additions and 9 deletions

View File

@@ -132,6 +132,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -159,7 +160,10 @@ class ControlNet(ControlBase):
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
@@ -172,6 +176,14 @@ class ControlNet(ControlBase):
out.append(self.control_model_wrapped)
return out
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling
def cleanup(self):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,