mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-13 04:55:53 +00:00
Made StableCascade work with optimized_attention_override
This commit is contained in:
@@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
|
|||||||
|
|
||||||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, q, k, v):
|
def forward(self, q, k, v, transformer_options={}):
|
||||||
q = self.to_q(q)
|
q = self.to_q(q)
|
||||||
k = self.to_k(k)
|
k = self.to_k(k)
|
||||||
v = self.to_v(v)
|
v = self.to_v(v)
|
||||||
|
|
||||||
out = optimized_attention(q, k, v, self.heads)
|
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
|
||||||
|
|
||||||
return self.out_proj(out)
|
return self.out_proj(out)
|
||||||
|
|
||||||
@@ -47,13 +47,13 @@ class Attention2D(nn.Module):
|
|||||||
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
|
||||||
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x, kv, self_attn=False):
|
def forward(self, x, kv, self_attn=False, transformer_options={}):
|
||||||
orig_shape = x.shape
|
orig_shape = x.shape
|
||||||
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
|
||||||
if self_attn:
|
if self_attn:
|
||||||
kv = torch.cat([x, kv], dim=1)
|
kv = torch.cat([x, kv], dim=1)
|
||||||
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
# x = self.attn(x, kv, kv, need_weights=False)[0]
|
||||||
x = self.attn(x, kv, kv)
|
x = self.attn(x, kv, kv, transformer_options=transformer_options)
|
||||||
x = x.permute(0, 2, 1).view(*orig_shape)
|
x = x.permute(0, 2, 1).view(*orig_shape)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
|
|||||||
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
operations.Linear(c_cond, c, dtype=dtype, device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, kv):
|
def forward(self, x, kv, transformer_options={}):
|
||||||
kv = self.kv_mapper(kv)
|
kv = self.kv_mapper(kv)
|
||||||
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
|
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@@ -173,7 +173,7 @@ class StageB(nn.Module):
|
|||||||
clip = self.clip_norm(clip)
|
clip = self.clip_norm(clip)
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def _down_encode(self, x, r_embed, clip):
|
def _down_encode(self, x, r_embed, clip, transformer_options={}):
|
||||||
level_outputs = []
|
level_outputs = []
|
||||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||||
for down_block, downscaler, repmap in block_group:
|
for down_block, downscaler, repmap in block_group:
|
||||||
@@ -187,7 +187,7 @@ class StageB(nn.Module):
|
|||||||
elif isinstance(block, AttnBlock) or (
|
elif isinstance(block, AttnBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
AttnBlock)):
|
AttnBlock)):
|
||||||
x = block(x, clip)
|
x = block(x, clip, transformer_options=transformer_options)
|
||||||
elif isinstance(block, TimestepBlock) or (
|
elif isinstance(block, TimestepBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
TimestepBlock)):
|
TimestepBlock)):
|
||||||
@@ -199,7 +199,7 @@ class StageB(nn.Module):
|
|||||||
level_outputs.insert(0, x)
|
level_outputs.insert(0, x)
|
||||||
return level_outputs
|
return level_outputs
|
||||||
|
|
||||||
def _up_decode(self, level_outputs, r_embed, clip):
|
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
|
||||||
x = level_outputs[0]
|
x = level_outputs[0]
|
||||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||||
@@ -216,7 +216,7 @@ class StageB(nn.Module):
|
|||||||
elif isinstance(block, AttnBlock) or (
|
elif isinstance(block, AttnBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
AttnBlock)):
|
AttnBlock)):
|
||||||
x = block(x, clip)
|
x = block(x, clip, transformer_options=transformer_options)
|
||||||
elif isinstance(block, TimestepBlock) or (
|
elif isinstance(block, TimestepBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
TimestepBlock)):
|
TimestepBlock)):
|
||||||
@@ -228,7 +228,7 @@ class StageB(nn.Module):
|
|||||||
x = upscaler(x)
|
x = upscaler(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
|
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
|
||||||
if pixels is None:
|
if pixels is None:
|
||||||
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
pixels = x.new_zeros(x.size(0), 3, 8, 8)
|
||||||
|
|
||||||
@@ -245,8 +245,8 @@ class StageB(nn.Module):
|
|||||||
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
|
||||||
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
|
||||||
align_corners=True)
|
align_corners=True)
|
||||||
level_outputs = self._down_encode(x, r_embed, clip)
|
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
|
||||||
x = self._up_decode(level_outputs, r_embed, clip)
|
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
|
||||||
return self.clf(x)
|
return self.clf(x)
|
||||||
|
|
||||||
def update_weights_ema(self, src_model, beta=0.999):
|
def update_weights_ema(self, src_model, beta=0.999):
|
||||||
|
@@ -182,7 +182,7 @@ class StageC(nn.Module):
|
|||||||
clip = self.clip_norm(clip)
|
clip = self.clip_norm(clip)
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def _down_encode(self, x, r_embed, clip, cnet=None):
|
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
|
||||||
level_outputs = []
|
level_outputs = []
|
||||||
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
|
||||||
for down_block, downscaler, repmap in block_group:
|
for down_block, downscaler, repmap in block_group:
|
||||||
@@ -201,7 +201,7 @@ class StageC(nn.Module):
|
|||||||
elif isinstance(block, AttnBlock) or (
|
elif isinstance(block, AttnBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
AttnBlock)):
|
AttnBlock)):
|
||||||
x = block(x, clip)
|
x = block(x, clip, transformer_options=transformer_options)
|
||||||
elif isinstance(block, TimestepBlock) or (
|
elif isinstance(block, TimestepBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
TimestepBlock)):
|
TimestepBlock)):
|
||||||
@@ -213,7 +213,7 @@ class StageC(nn.Module):
|
|||||||
level_outputs.insert(0, x)
|
level_outputs.insert(0, x)
|
||||||
return level_outputs
|
return level_outputs
|
||||||
|
|
||||||
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
|
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
|
||||||
x = level_outputs[0]
|
x = level_outputs[0]
|
||||||
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
|
||||||
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
for i, (up_block, upscaler, repmap) in enumerate(block_group):
|
||||||
@@ -235,7 +235,7 @@ class StageC(nn.Module):
|
|||||||
elif isinstance(block, AttnBlock) or (
|
elif isinstance(block, AttnBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
AttnBlock)):
|
AttnBlock)):
|
||||||
x = block(x, clip)
|
x = block(x, clip, transformer_options=transformer_options)
|
||||||
elif isinstance(block, TimestepBlock) or (
|
elif isinstance(block, TimestepBlock) or (
|
||||||
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
|
||||||
TimestepBlock)):
|
TimestepBlock)):
|
||||||
@@ -247,7 +247,7 @@ class StageC(nn.Module):
|
|||||||
x = upscaler(x)
|
x = upscaler(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
|
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
|
||||||
# Process the conditioning embeddings
|
# Process the conditioning embeddings
|
||||||
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
|
||||||
for c in self.t_conds:
|
for c in self.t_conds:
|
||||||
@@ -262,8 +262,8 @@ class StageC(nn.Module):
|
|||||||
|
|
||||||
# Model Blocks
|
# Model Blocks
|
||||||
x = self.embedding(x)
|
x = self.embedding(x)
|
||||||
level_outputs = self._down_encode(x, r_embed, clip, cnet)
|
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||||
x = self._up_decode(level_outputs, r_embed, clip, cnet)
|
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
|
||||||
return self.clf(x)
|
return self.clf(x)
|
||||||
|
|
||||||
def update_weights_ema(self, src_model, beta=0.999):
|
def update_weights_ema(self, src_model, beta=0.999):
|
||||||
|
Reference in New Issue
Block a user