mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 04:27:21 +00:00
Fix SAG.
This commit is contained in:
@@ -420,6 +420,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
inner_dim = dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
@@ -427,7 +428,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
@@ -441,7 +442,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
context_dim_attn2 = context_dim
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
@@ -471,6 +472,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
extra_options["n_heads"] = self.n_heads
|
||||
extra_options["dim_head"] = self.d_head
|
||||
extra_options["attn_precision"] = self.attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
|
Reference in New Issue
Block a user