Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-09-14 05:25:23 +00:00)
Remove attention logging code
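What is being removed: module-level debug scaffolding in comfy/ldm/modules/attention.py that, on every wrapped attention call, walked the interpreter stack to record which callers passed transformer_options, plus the CFGGuider code in comfy/samplers.py that installed a default optimized_attention_override and dumped the collected call paths as timestamped JSON under output/attn_logs. For reference, an override with the same shape as the removed identity default can still be registered through model_options; a hedged usage sketch (the model_patcher object and the override body are assumptions, not part of this diff):

# Hedged usage sketch, not code from this commit: `model_patcher` is an assumed
# comfy.model_patcher.ModelPatcher instance; the override mirrors the identity
# default that the diff below removes from CFGGuider.
def my_attention_override(func, *args, **kwargs):
    # Delegate unchanged; a real override could time the call or swap kernels.
    return func(*args, **kwargs)

transformer_options = model_patcher.model_options.setdefault("transformer_options", {})
transformer_options["optimized_attention_override"] = my_attention_override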
@@ -132,114 +132,12 @@ class FeedForward(nn.Module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-import inspect
-LOG_ATTN_CALLS = False
-LOG_CONTENTS = {}
-
-def save_log_contents():
-    import folder_paths
-    output_dir = folder_paths.get_output_directory()
-
-    # Create attn_logs directory if it doesn't exist
-    attn_logs_dir = os.path.join(output_dir, "attn_logs")
-    os.makedirs(attn_logs_dir, exist_ok=True)
-
-    # Generate timestamp filename (down to second)
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    filename = f"{timestamp}.json"
-    filepath = os.path.join(attn_logs_dir, filename)
-
-    # Save LOG_CONTENTS as JSON file
-    try:
-        with open(filepath, 'w', encoding='utf-8') as f:
-            json.dump(list(LOG_CONTENTS.values()), f, indent=2, ensure_ascii=False)
-        logging.info(f"Saved attention log contents to {filepath}")
-    except Exception as e:
-        logging.error(f"Failed to save attention log contents: {e}")
-
-def get_class_from_frame(frame):
-    # Check for 'self' (instance method) or 'cls' (classmethod)
-    if 'self' in frame.f_locals:
-        return frame.f_locals['self'].__class__.__name__
-    elif 'cls' in frame.f_locals:
-        return frame.f_locals['cls'].__name__
-    return None
-
-def has_transformer_options_passed(frame):
-    if 'transformer_options' in frame.f_locals.keys():
-        if frame.f_locals['transformer_options']:
-            return True
-    return False
 
 def wrap_attn(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         remove_attn_wrapper_key = False
         try:
-            if LOG_ATTN_CALLS:
-                continue_to_add = True
-                to_add = 1000
-                logged_stack = []
-                logged_stack_to_index = -1
-
-                frame = inspect.currentframe()
-                try:
-                    # skip wrapper, start at actual wrapped function
-                    frame = frame.f_back
-
-                    while frame and continue_to_add and to_add > 0:
-                        code = frame.f_code
-                        filename = code.co_filename
-                        function = code.co_name
-                        lineno = frame.f_lineno
-
-                        if function == "_calc_cond_batch_outer":
-                            break
-                        if 'venv' in filename:
-                            frame = frame.f_back
-                            continue
-                        elif 'ComfyUI' not in filename:
-                            frame = frame.f_back
-                            continue
-                        elif 'execution.py' in filename:
-                            frame = frame.f_back
-                            continue
-                        elif 'patcher_extension.py' in filename:
-                            frame = frame.f_back
-                            continue
-
-                        to_add -= 1
-                        cls_name = get_class_from_frame(frame)
-                        log_string = f"{filename}:{lineno}"
-                        if cls_name:
-                            log_string += f":{cls_name}.{function}"
-                        else:
-                            log_string += f":{function}"
-
-                        if has_transformer_options_passed(frame):
-                            log_string += ":✅"
-                            if logged_stack_to_index == -1:
-                                logged_stack_to_index = len(logged_stack)
-                        else:
-                            log_string += ":❌"
-
-                        logged_stack.append(log_string)
-
-                        # move up the stack
-                        frame = frame.f_back
-
-                    # check if we get what we want from transformer_options
-                    t_check = "❌❌❌"
-                    transformer_options = kwargs.get("transformer_options", None)
-                    if transformer_options is not None:
-                        if "optimized_attention_override" in transformer_options:
-                            t_check = "✅✅✅"
-
-                    LOG_CONTENTS["|".join(logged_stack)] = (t_check, logged_stack_to_index, logged_stack)
-
-                finally:
-                    # Important: break ref cycles so tensors aren't pinned
-                    del frame
             if "_inside_attn_wrapper" not in kwargs:
                 transformer_options = kwargs.get("transformer_options", None)
                 remove_attn_wrapper_key = True
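Two details of the removed hunk above are worth keeping in mind if similar instrumentation is ever reintroduced. The walk is driven by inspect.currentframe() and f_back, reading each frame's f_locals, and the trailing "finally: del frame" follows the pattern recommended in the CPython inspect documentation: holding a frame object creates a reference cycle that can keep every local in the captured stack (including large tensors on this code path) alive until the cycle collector runs. A self-contained distillation of the same technique (illustrative only, not the removed code verbatim):

import inspect

def snapshot_callers(arg_name="transformer_options", limit=50):
    """Walk up the call stack and note which frames received `arg_name`."""
    entries = []
    frame = inspect.currentframe()
    try:
        frame = frame.f_back  # skip this helper itself
        while frame is not None and limit > 0:
            code = frame.f_code
            mark = "✅" if frame.f_locals.get(arg_name) else "❌"
            entries.append(f"{code.co_filename}:{frame.f_lineno}:{code.co_name}:{mark}")
            frame = frame.f_back
            limit -= 1
    finally:
        del frame  # break the frame reference cycle so locals are not pinned
    return entries

Calling snapshot_callers() inside an attention function reproduces the kind of per-call record the removed code accumulated in LOG_CONTENTS.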
@@ -1019,7 +1019,6 @@ class CFGGuider:
             self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
         preprocess_conds_hooks(self.conds)
 
-        import comfy.ldm.modules.attention #TODO: Remove this $$$$$
         try:
             orig_model_options = self.model_options
             self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
@@ -1034,23 +1033,12 @@ class CFGGuider:
                 self,
                 comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
             )
-            comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$
-            comfy.ldm.modules.attention.LOG_CONTENTS = {}
-            if "optimized_attention_override" not in self.model_options["transformer_options"]:
-                def optimized_attention_override(func, *args, **kwargs):
-                    return func(*args, **kwargs)
-                self.model_options["transformer_options"]["optimized_attention_override"] = optimized_attention_override
-
             output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
             cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
             self.model_options = orig_model_options
             self.model_patcher.hook_mode = orig_hook_mode
             self.model_patcher.restore_hook_patches()
-            if comfy.ldm.modules.attention.LOG_ATTN_CALLS:
-                comfy.ldm.modules.attention.save_log_contents()
-            comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$
-            comfy.ldm.modules.attention.LOG_CONTENTS = {}
 
         del self.conds
         return output
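One pattern from the removed CFGGuider block may still be worth noting: installing an identity optimized_attention_override when none is present is a defensive default, letting downstream code invoke the override unconditionally instead of branching on whether one was registered. A generic sketch of that guard (the helper name and parameters are illustrative, not part of this diff):

# Generic sketch of the "install an identity default" guard seen in the removed
# block above; `options` and `key` are illustrative names.
def ensure_passthrough(options: dict, key: str):
    # After this call, options[key] can be invoked without checking for None.
    if key not in options:
        options[key] = lambda func, *args, **kwargs: func(*args, **kwargs)
    return options[key]

# e.g.: ensure_passthrough(model_options["transformer_options"], "optimized_attention_override")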