diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 0ab5cf16d..6a8ffe10b 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,5 +1,8 @@ import math import sys +import json +import os +from datetime import datetime import torch import torch.nn.functional as F @@ -92,13 +95,89 @@ class FeedForward(nn.Module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) +import inspect +LOG_ATTN_CALLS = False +LOG_CONTENTS = {} + +def save_log_contents(): + import folder_paths + output_dir = folder_paths.get_output_directory() + + # Create attn_logs directory if it doesn't exist + attn_logs_dir = os.path.join(output_dir, "attn_logs") + os.makedirs(attn_logs_dir, exist_ok=True) + + # Generate timestamp filename (down to second) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{timestamp}.json" + filepath = os.path.join(attn_logs_dir, filename) + + # Save LOG_CONTENTS as JSON file + try: + with open(filepath, 'w', encoding='utf-8') as f: + json.dump(list(LOG_CONTENTS.values()), f, indent=2, ensure_ascii=False) + logging.info(f"Saved attention log contents to {filepath}") + except Exception as e: + logging.error(f"Failed to save attention log contents: {e}") + +def get_class_from_frame(frame): + # Check for 'self' (instance method) or 'cls' (classmethod) + if 'self' in frame.f_locals: + return frame.f_locals['self'].__class__.__name__ + elif 'cls' in frame.f_locals: + return frame.f_locals['cls'].__name__ + return None + +def has_transformer_options_passed(frame): + if 'transformer_options' in frame.f_locals.keys(): + if frame.f_locals['transformer_options']: + return True + return False + def wrap_attn(func): @functools.wraps(func) def wrapper(*args, **kwargs): + if LOG_ATTN_CALLS: + continue_to_add = True + to_add = 1000 + logged_stack = [] + logged_stack_to_index = -1 + for frame_info in inspect.stack()[1:]: + if not continue_to_add: + break + if to_add == 0: + break + if frame_info.function == "_calc_cond_batch_outer": + break + if 'venv' in frame_info.filename: + continue + elif 'ComfyUI' not in frame_info.filename: + continue + elif 'execution.py' in frame_info.filename: + continue + elif 'patcher_extension.py' in frame_info.filename: + continue + to_add -= 1 + cls_name = get_class_from_frame(frame_info.frame) + log_string = f"{frame_info.filename}:{frame_info.lineno}" + if cls_name: + log_string += f":{cls_name}.{frame_info.function}" + else: + log_string += f":{frame_info.function}" + if has_transformer_options_passed(frame_info.frame): + log_string += ":✅" + if logged_stack_to_index == -1: + logged_stack_to_index = len(logged_stack) + else: + log_string += ":❌" + logged_stack.append(log_string) + # logging.info(f"Attn call stack: {logged_stack}") + # logging.info(f"Logged stack to index: {logged_stack[:logged_stack_to_index+1]}") + LOG_CONTENTS["|".join(logged_stack)] = (logged_stack_to_index, logged_stack) transformer_options = kwargs.pop("transformer_options", None) if transformer_options is not None: if "optimized_attention_override" in transformer_options: - return transformer_options["optimized_attention_override"](*args, **kwargs) + return transformer_options["optimized_attention_override"](func, transformer_options, *args, **kwargs) return func(*args, **kwargs) return wrapper @@ -760,7 +839,7 @@ class BasicTransformerBlock(nn.Module): n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) n = self.attn1.to_out(n) else: - n = self.attn1(n, context=context_attn1, value=value_attn1) + n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options) if "attn1_output_patch" in transformer_patches: patch = transformer_patches["attn1_output_patch"] @@ -800,7 +879,7 @@ class BasicTransformerBlock(nn.Module): n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) n = self.attn2.to_out(n) else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] diff --git a/comfy/samplers.py b/comfy/samplers.py index c7dfef4ea..acf86c03d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1019,6 +1019,7 @@ class CFGGuider: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) preprocess_conds_hooks(self.conds) + import comfy.ldm.modules.attention #TODO: Remove this $$$$$ try: orig_model_options = self.model_options self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) @@ -1033,12 +1034,17 @@ class CFGGuider: self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) + comfy.ldm.modules.attention.LOG_ATTN_CALLS = True #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.LOG_CONTENTS = {} output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() + comfy.ldm.modules.attention.LOG_ATTN_CALLS = False #TODO: Remove this $$$$$ + comfy.ldm.modules.attention.save_log_contents() + comfy.ldm.modules.attention.LOG_CONTENTS = {} del self.conds return output