mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 19:46:38 +00:00
Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling (#9528)
* Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling * Fix missing LazyCache check_metadata method Ensure LazyCache reset method resets all the tensor state values
This commit is contained in:
@@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
|||||||
input_change = None
|
input_change = None
|
||||||
do_easycache = easycache.should_do_easycache(sigmas)
|
do_easycache = easycache.should_do_easycache(sigmas)
|
||||||
if do_easycache:
|
if do_easycache:
|
||||||
|
easycache.check_metadata(x)
|
||||||
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
# if first cond marked this step for skipping, skip it and use appropriate cached values
|
||||||
if easycache.skip_current_step:
|
if easycache.skip_current_step:
|
||||||
if easycache.verbose:
|
if easycache.verbose:
|
||||||
@@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
|||||||
input_change = None
|
input_change = None
|
||||||
do_easycache = easycache.should_do_easycache(timestep)
|
do_easycache = easycache.should_do_easycache(timestep)
|
||||||
if do_easycache:
|
if do_easycache:
|
||||||
|
easycache.check_metadata(x)
|
||||||
if easycache.has_x_prev_subsampled():
|
if easycache.has_x_prev_subsampled():
|
||||||
if easycache.has_x_prev_subsampled():
|
if easycache.has_x_prev_subsampled():
|
||||||
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
|
||||||
@@ -194,6 +196,7 @@ class EasyCacheHolder:
|
|||||||
# how to deal with mismatched dims
|
# how to deal with mismatched dims
|
||||||
self.allow_mismatch = True
|
self.allow_mismatch = True
|
||||||
self.cut_from_start = True
|
self.cut_from_start = True
|
||||||
|
self.state_metadata = None
|
||||||
|
|
||||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||||
return not (timestep[0] > self.end_t).item()
|
return not (timestep[0] > self.end_t).item()
|
||||||
@@ -283,6 +286,17 @@ class EasyCacheHolder:
|
|||||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
||||||
return self.first_cond_uuid in uuids
|
return self.first_cond_uuid in uuids
|
||||||
|
|
||||||
|
def check_metadata(self, x: torch.Tensor) -> bool:
|
||||||
|
metadata = (x.device, x.dtype, x.shape[1:])
|
||||||
|
if self.state_metadata is None:
|
||||||
|
self.state_metadata = metadata
|
||||||
|
return True
|
||||||
|
if metadata == self.state_metadata:
|
||||||
|
return True
|
||||||
|
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||||
|
self.reset()
|
||||||
|
return False
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.relative_transformation_rate = 0.0
|
self.relative_transformation_rate = 0.0
|
||||||
self.cumulative_change_rate = 0.0
|
self.cumulative_change_rate = 0.0
|
||||||
@@ -299,6 +313,7 @@ class EasyCacheHolder:
|
|||||||
del self.uuid_cache_diffs
|
del self.uuid_cache_diffs
|
||||||
self.uuid_cache_diffs = {}
|
self.uuid_cache_diffs = {}
|
||||||
self.total_steps_skipped = 0
|
self.total_steps_skipped = 0
|
||||||
|
self.state_metadata = None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
@@ -360,6 +375,7 @@ class LazyCacheHolder:
|
|||||||
self.output_change_rates = []
|
self.output_change_rates = []
|
||||||
self.approx_output_change_rates = []
|
self.approx_output_change_rates = []
|
||||||
self.total_steps_skipped = 0
|
self.total_steps_skipped = 0
|
||||||
|
self.state_metadata = None
|
||||||
|
|
||||||
def has_cache_diff(self) -> bool:
|
def has_cache_diff(self) -> bool:
|
||||||
return self.cache_diff is not None
|
return self.cache_diff is not None
|
||||||
@@ -404,6 +420,17 @@ class LazyCacheHolder:
|
|||||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
|
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
|
||||||
self.cache_diff = output - x
|
self.cache_diff = output - x
|
||||||
|
|
||||||
|
def check_metadata(self, x: torch.Tensor) -> bool:
|
||||||
|
metadata = (x.device, x.dtype, x.shape)
|
||||||
|
if self.state_metadata is None:
|
||||||
|
self.state_metadata = metadata
|
||||||
|
return True
|
||||||
|
if metadata == self.state_metadata:
|
||||||
|
return True
|
||||||
|
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
|
||||||
|
self.reset()
|
||||||
|
return False
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.relative_transformation_rate = 0.0
|
self.relative_transformation_rate = 0.0
|
||||||
self.cumulative_change_rate = 0.0
|
self.cumulative_change_rate = 0.0
|
||||||
@@ -412,7 +439,14 @@ class LazyCacheHolder:
|
|||||||
self.approx_output_change_rates = []
|
self.approx_output_change_rates = []
|
||||||
del self.cache_diff
|
del self.cache_diff
|
||||||
self.cache_diff = None
|
self.cache_diff = None
|
||||||
|
del self.x_prev_subsampled
|
||||||
|
self.x_prev_subsampled = None
|
||||||
|
del self.output_prev_subsampled
|
||||||
|
self.output_prev_subsampled = None
|
||||||
|
del self.output_prev_norm
|
||||||
|
self.output_prev_norm = None
|
||||||
self.total_steps_skipped = 0
|
self.total_steps_skipped = 0
|
||||||
|
self.state_metadata = None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
|
Reference in New Issue
Block a user