Made StableCascade work with optimized_attention_override

Author: Jedrzej Kosinski
Date:   2025-08-28 21:42:08 -07:00
Parent: 09c84b31a2
Commit: 034d6c12e6

3 changed files with 20 additions and 20 deletions
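
In short: StageB.forward and StageC.forward gain a transformer_options keyword argument and thread it through _down_encode and _up_decode into AttnBlock, Attention2D, and OptimizedAttention, so the dict (and any optimized_attention_override stored in it) reaches the optimized_attention call. Before this change these modules had no way to receive the dict, so an override registered in transformer_options never affected StableCascade's attention path. Illustrative sketches of the dispatch, the keyword threading, and a possible call site follow each file's diff below.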

File 1 of 3

@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
         self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
 
-    def forward(self, q, k, v):
+    def forward(self, q, k, v, transformer_options={}):
         q = self.to_q(q)
         k = self.to_k(k)
         v = self.to_v(v)
 
-        out = optimized_attention(q, k, v, self.heads)
+        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
 
         return self.out_proj(out)

@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
         self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
         # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
 
-    def forward(self, x, kv, self_attn=False):
+    def forward(self, x, kv, self_attn=False, transformer_options={}):
         orig_shape = x.shape
         x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
         if self_attn:
             kv = torch.cat([x, kv], dim=1)
         # x = self.attn(x, kv, kv, need_weights=False)[0]
-        x = self.attn(x, kv, kv)
+        x = self.attn(x, kv, kv, transformer_options=transformer_options)
         x = x.permute(0, 2, 1).view(*orig_shape)
         return x

@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
             operations.Linear(c_cond, c, dtype=dtype, device=device)
         )
 
-    def forward(self, x, kv):
+    def forward(self, x, kv, transformer_options={}):
         kv = self.kv_mapper(kv)
-        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
         return x
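
OptimizedAttention.forward now hands transformer_options to optimized_attention. The helper itself is not part of this diff; presumably it looks up the override in the dict and otherwise runs its usual kernel. A minimal sketch of that dispatch pattern, with illustrative names and an assumed override signature (this is not ComfyUI's actual optimized_attention):

    import torch

    def attention_with_override(q, k, v, heads, transformer_options={}):
        # Sketch only: the key name comes from the commit title; the override's
        # call signature here is an assumption, not ComfyUI's documented contract.
        override = transformer_options.get("optimized_attention_override")
        if override is not None:
            # a caller-supplied callable takes over the whole attention computation
            return override(q, k, v, heads, transformer_options=transformer_options)
        # fallback: plain multi-head scaled-dot-product attention on (B, L, C) tensors
        b, l, c = q.shape
        d = c // heads
        q, k, v = (t.view(b, l, heads, d).transpose(1, 2) for t in (q, k, v))
        attn = torch.softmax(q @ k.transpose(-2, -1) * d ** -0.5, dim=-1)
        return (attn @ v).transpose(1, 2).reshape(b, l, c)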

File 2 of 3

@@ -173,7 +173,7 @@ class StageB(nn.Module):
         clip = self.clip_norm(clip)
         return clip
 
-    def _down_encode(self, x, r_embed, clip):
+    def _down_encode(self, x, r_embed, clip, transformer_options={}):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
         for down_block, downscaler, repmap in block_group:

@@ -187,7 +187,7 @@ class StageB(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):

@@ -199,7 +199,7 @@ class StageB(nn.Module):
             level_outputs.insert(0, x)
         return level_outputs
 
-    def _up_decode(self, level_outputs, r_embed, clip):
+    def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
         for i, (up_block, upscaler, repmap) in enumerate(block_group):

@@ -216,7 +216,7 @@ class StageB(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):

@@ -228,7 +228,7 @@ class StageB(nn.Module):
             x = upscaler(x)
         return x
 
-    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+    def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
         if pixels is None:
             pixels = x.new_zeros(x.size(0), 3, 8, 8)

@@ -245,8 +245,8 @@ class StageB(nn.Module):
             nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
         x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                           align_corners=True)
-        level_outputs = self._down_encode(x, r_embed, clip)
-        x = self._up_decode(level_outputs, r_embed, clip)
+        level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
+        x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
         return self.clf(x)
 
     def update_weights_ema(self, src_model, beta=0.999):
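
StageB's part of the change is pure plumbing: every method between forward and the attention block grows a transformer_options={} keyword and passes it along unchanged. A stripped-down, self-contained illustration of that threading pattern (toy modules, not the real StageB; the override is simplified to a single-tensor callable here):

    import torch
    import torch.nn as nn

    class LeafBlock(nn.Module):
        # stands in for AttnBlock: the only place the dict is actually consumed
        def forward(self, x, transformer_options={}):
            override = transformer_options.get("optimized_attention_override")
            return override(x) if override is not None else x

    class ToyStage(nn.Module):
        # stands in for StageB/StageC: forward and its helpers only pass the dict through
        def __init__(self):
            super().__init__()
            self.block = LeafBlock()

        def _down_encode(self, x, transformer_options={}):
            return self.block(x, transformer_options=transformer_options)

        def forward(self, x, transformer_options={}):
            return self._down_encode(x, transformer_options=transformer_options)

    # Dropping the keyword anywhere along the chain silently disables the override,
    # which is the failure mode this commit removes for StableCascade.
    stage = ToyStage()
    out = stage(torch.randn(1, 4), transformer_options={"optimized_attention_override": lambda t: t * 2})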

File 3 of 3

@@ -182,7 +182,7 @@ class StageC(nn.Module):
         clip = self.clip_norm(clip)
         return clip
 
-    def _down_encode(self, x, r_embed, clip, cnet=None):
+    def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
         for down_block, downscaler, repmap in block_group:

@@ -201,7 +201,7 @@ class StageC(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):

@@ -213,7 +213,7 @@ class StageC(nn.Module):
             level_outputs.insert(0, x)
         return level_outputs
 
-    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
+    def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
         for i, (up_block, upscaler, repmap) in enumerate(block_group):

@@ -235,7 +235,7 @@ class StageC(nn.Module):
                     elif isinstance(block, AttnBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   AttnBlock)):
-                        x = block(x, clip)
+                        x = block(x, clip, transformer_options=transformer_options)
                     elif isinstance(block, TimestepBlock) or (
                             hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                                   TimestepBlock)):

@@ -247,7 +247,7 @@ class StageC(nn.Module):
             x = upscaler(x)
         return x
 
-    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
+    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
         # Process the conditioning embeddings
         r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
         for c in self.t_conds:

@@ -262,8 +262,8 @@ class StageC(nn.Module):
         # Model Blocks
         x = self.embedding(x)
-        level_outputs = self._down_encode(x, r_embed, clip, cnet)
-        x = self._up_decode(level_outputs, r_embed, clip, cnet)
+        level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
+        x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
         return self.clf(x)
 
     def update_weights_ema(self, src_model, beta=0.999):
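
With the StageC signature above, a caller can now hand an override straight to the model. A hypothetical usage sketch: the override body and its signature are assumptions, and stage_c is assumed to be an already-constructed StageC instance; only the transformer_options keyword and the "optimized_attention_override" key name come from this commit.

    import torch
    import torch.nn.functional as F

    def force_fp32_attention(q, k, v, heads, **kwargs):
        # example override: run attention in float32, then cast back to the input dtype
        b, l, c = q.shape
        dtype = q.dtype
        q, k, v = (t.float().view(b, l, heads, c // heads).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2).reshape(b, l, c).to(dtype)

    transformer_options = {"optimized_attention_override": force_fp32_attention}
    # out = stage_c(x, r, clip_text, clip_text_pooled, clip_img,
    #               transformer_options=transformer_options)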