Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added

Jedrzej Kosinski
2025-08-27 17:13:33 -07:00
parent b58db6934c
commit 68b00e9c60
2 changed files with 88 additions and 3 deletions

comfy/ldm/modules/attention.py

@@ -1,5 +1,8 @@
import math
import sys
import json
import os
from datetime import datetime
import torch
import torch.nn.functional as F
@@ -92,13 +95,89 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

import inspect
LOG_ATTN_CALLS = False
LOG_CONTENTS = {}

def save_log_contents():
    import folder_paths
    output_dir = folder_paths.get_output_directory()

    # Create attn_logs directory if it doesn't exist
    attn_logs_dir = os.path.join(output_dir, "attn_logs")
    os.makedirs(attn_logs_dir, exist_ok=True)

    # Generate timestamp filename (down to second)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{timestamp}.json"
    filepath = os.path.join(attn_logs_dir, filename)

    # Save LOG_CONTENTS as JSON file
    try:
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(list(LOG_CONTENTS.values()), f, indent=2, ensure_ascii=False)
        logging.info(f"Saved attention log contents to {filepath}")
    except Exception as e:
        logging.error(f"Failed to save attention log contents: {e}")

def get_class_from_frame(frame):
    # Check for 'self' (instance method) or 'cls' (classmethod)
    if 'self' in frame.f_locals:
        return frame.f_locals['self'].__class__.__name__
    elif 'cls' in frame.f_locals:
        return frame.f_locals['cls'].__name__
    return None

def has_transformer_options_passed(frame):
    if 'transformer_options' in frame.f_locals.keys():
        if frame.f_locals['transformer_options']:
            return True
    return False
def wrap_attn(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if LOG_ATTN_CALLS:
            continue_to_add = True
            to_add = 1000
            logged_stack = []
            logged_stack_to_index = -1
            # Walk up the call stack (excluding this wrapper), recording ComfyUI frames
            for frame_info in inspect.stack()[1:]:
                if not continue_to_add:
                    break
                if to_add == 0:
                    break
                if frame_info.function == "_calc_cond_batch_outer":
                    break
                # Skip frames that are not ComfyUI model code (venv, non-ComfyUI files,
                # and generic execution/patcher plumbing)
                if 'venv' in frame_info.filename:
                    continue
                elif 'ComfyUI' not in frame_info.filename:
                    continue
                elif 'execution.py' in frame_info.filename:
                    continue
                elif 'patcher_extension.py' in frame_info.filename:
                    continue
                to_add -= 1
                cls_name = get_class_from_frame(frame_info.frame)
                log_string = f"{frame_info.filename}:{frame_info.lineno}"
                if cls_name:
                    log_string += f":{cls_name}.{frame_info.function}"
                else:
                    log_string += f":{frame_info.function}"
                # Mark whether a non-empty transformer_options is visible in this frame
                if has_transformer_options_passed(frame_info.frame):
                    log_string += ":✅"
                    if logged_stack_to_index == -1:
                        logged_stack_to_index = len(logged_stack)
                else:
                    log_string += ":❌"
                logged_stack.append(log_string)
            # logging.info(f"Attn call stack: {logged_stack}")
            # logging.info(f"Logged stack to index: {logged_stack[:logged_stack_to_index+1]}")
            # Key on the joined stack so identical call paths are only stored once
            LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack)

        # Hand off to an attention override when one is supplied via transformer_options
        transformer_options = kwargs.pop("transformer_options", None)
        if transformer_options is not None:
            if "optimized_attention_override" in transformer_options:
-               return transformer_options["optimized_attention_override"](*args, **kwargs)
+               return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs)
        return func(*args, **kwargs)
    return wrapper
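
With the changed return call above, an `optimized_attention_override` placed in `transformer_options` is now invoked as `override(func, transformer_options, *args, **kwargs)` rather than `override(*args, **kwargs)`. A minimal pass-through sketch of an override under that convention (the function name and the commented-out registration line are illustrative, not part of this commit):

def my_attention_override(func, transformer_options, *args, **kwargs):
    # `func` is the wrapped attention implementation; a pass-through override
    # simply defers to it with the original arguments.
    return func(*args, **kwargs)

# Illustrative registration: wrap_attn's wrapper looks the override up by this key.
# transformer_options["optimized_attention_override"] = my_attention_override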
@@ -760,7 +839,7 @@ class BasicTransformerBlock(nn.Module):
            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
            n = self.attn1.to_out(n)
        else:
-           n = self.attn1(n, context=context_attn1, value=value_attn1)
+           n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)

        if "attn1_output_patch" in transformer_patches:
            patch = transformer_patches["attn1_output_patch"]
@@ -800,7 +879,7 @@ class BasicTransformerBlock(nn.Module):
            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
            n = self.attn2.to_out(n)
        else:
-           n = self.attn2(n, context=context_attn2, value=value_attn2)
+           n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)

        if "attn2_output_patch" in transformer_patches:
            patch = transformer_patches["attn2_output_patch"]
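
Each value that `save_log_contents()` writes is a `(logged_stack_to_index, logged_stack)` pair serialized as a two-element JSON array, with every frame string suffixed `:✅` or `:❌` depending on whether a non-empty `transformer_options` was visible there. A small sketch, assuming the `attn_logs` layout created above, for listing call sites that still lack `transformer_options`:

import glob
import json
import os

def report_missing_transformer_options(attn_logs_dir):
    # Timestamped filenames sort chronologically, so the last one is the newest log.
    files = sorted(glob.glob(os.path.join(attn_logs_dir, "*.json")))
    if not files:
        return
    with open(files[-1], "r", encoding="utf-8") as f:
        entries = json.load(f)  # list of [logged_stack_to_index, logged_stack]
    for _to_index, stack in entries:
        for frame in stack:
            # Frames marked ❌ are the call paths where transformer_options
            # still needs to be threaded through.
            if frame.endswith(":❌"):
                print(frame)

# Example call; the exact output directory is an assumption:
# report_missing_transformer_options(os.path.join("output", "attn_logs"))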

comfy/samplers.py

@@ -1019,6 +1019,7 @@ class CFGGuider:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
        preprocess_conds_hooks(self.conds)

        import comfy.ldm.modules.attention #TODO: Remove this $$$$$
        try:
            orig_model_options = self.model_options
            self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
@@ -1033,12 +1034,17 @@ class CFGGuider:
                self,
                comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
            )
            comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$
            comfy.ldm.modules.attention.LOG_CONTENTS = {}
            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
        finally:
            cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
            self.model_options = orig_model_options
            self.model_patcher.hook_mode = orig_hook_mode
            self.model_patcher.restore_hook_patches()
            comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$
            comfy.ldm.modules.attention.save_log_contents()
            comfy.ldm.modules.attention.LOG_CONTENTS = {}

        del self.conds
        return output
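
For ad-hoc debugging outside CFGGuider, the same enable/save/reset pattern can be wrapped around any callable while this branch is checked out; a minimal sketch (the helper name is hypothetical):

import comfy.ldm.modules.attention as attention

def run_with_attn_logging(fn, *args, **kwargs):
    # Mirror the CFGGuider changes: enable logging, run, then always save and reset.
    attention.LOG_ATTN_CALLS = True
    attention.LOG_CONTENTS = {}
    try:
        return fn(*args, **kwargs)
    finally:
        attention.LOG_ATTN_CALLS = False
        attention.save_log_contents()
        attention.LOG_CONTENTS = {}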