comfyanonymous
2024-05-14 18:02:27 -04:00
parent bb4940d837
commit ec6f16adb6
2 changed files with 9 additions and 7 deletions

@@ -420,6 +420,7 @@ class BasicTransformerBlock(nn.Module):
             inner_dim = dim

         self.is_res = inner_dim == dim
+        self.attn_precision = attn_precision

         if self.ff_in:
             self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
@@ -427,7 +428,7 @@ class BasicTransformerBlock(nn.Module):

         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

         if disable_temporal_crossattention:
@@ -441,7 +442,7 @@ class BasicTransformerBlock(nn.Module):
                 context_dim_attn2 = context_dim

             self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
             self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@@ -471,6 +472,7 @@ class BasicTransformerBlock(nn.Module):

         extra_options["n_heads"] = self.n_heads
         extra_options["dim_head"] = self.d_head
+        extra_options["attn_precision"] = self.attn_precision

         if self.ff_in:
             x_skip = x
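The forward() hunk mirrors the stored value into extra_options alongside n_heads and dim_head, so anything that receives that dictionary can see which precision the block was configured with. A minimal sketch of that pattern follows; TinyBlock and logging_patch are invented names for illustration and do not reproduce ComfyUI's patch API.

import torch

class TinyBlock:
    # Stand-in for the pattern above: remember the override in __init__ and
    # republish it through extra_options in forward().
    def __init__(self, n_heads, d_head, attn_precision=None):
        self.n_heads = n_heads
        self.d_head = d_head
        self.attn_precision = attn_precision

    def forward(self, x, patch=None):
        extra_options = {
            "n_heads": self.n_heads,
            "dim_head": self.d_head,
            "attn_precision": self.attn_precision,  # now visible to consumers of extra_options
        }
        if patch is not None:
            x = patch(x, extra_options)
        return x

def logging_patch(x, extra_options):
    # Hypothetical consumer: reads the advertised precision before acting on x.
    print("attention precision requested:", extra_options["attn_precision"])
    return x

block = TinyBlock(n_heads=8, d_head=64, attn_precision=torch.float32)
out = block.forward(torch.randn(1, 16, 512), patch=logging_patch)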