import torch
import math

from .model import QwenImageTransformer2DModel


class QwenImageControlNetModel(QwenImageTransformer2DModel):
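    """ControlNet variant of QwenImageTransformer2DModel.

    The base transformer is constructed without its final output layer; this class
    adds one Linear projection per transformer block plus an embedder for the
    patchified control hint (including any extra condition channels).  Instead of a
    denoised latent, forward() returns per-block residual samples for the main model.
    """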
    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
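        # Presumably the number of double-stream blocks in the full Qwen Image model;
        # the residuals returned by forward() are capped to this count.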
        self.main_model_double = 60

        # controlnet_blocks: one Linear projection per transformer block, producing
        # that block's control residual.
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
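        # Projects the patchified hint (plus any extra condition channels) into the
        # transformer's hidden size so it can be added to the image embedding.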
        self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options={},
        **kwargs
    ):
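        """Run the control branch and collect per-block residual samples.

        x is the noisy latent, hint is the (latent) control image, and context is
        the text conditioning.  Returns a dict whose "input" tuple holds one
        residual per main-model block, capped at main_model_double entries.
        """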
        timestep = timesteps
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

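        # Patchify both the noisy latent and the control hint into token sequences.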
        hidden_states, img_ids, orig_shape = self.process_img(x)
        hint, _, _ = self.process_img(hint)

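        # Text tokens get rotary position ids starting just past the image patch
        # coordinates, so text and image positions do not overlap.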
        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
        del ids, txt_ids, img_ids

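        # The control signal is injected by adding the embedded hint directly to the
        # image token embedding before any transformer block runs.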
        hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

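        # Guidance, when provided, is rescaled by 1000, presumably to match the
        # range the guidance embedder expects.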
        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

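        # The control branch may have fewer blocks than the main model, so each
        # residual is repeated enough times to cover all main_model_double blocks.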
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))

        controlnet_block_samples = ()
        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

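            # Project this block's hidden states into a control residual; repeating it
            # lets consecutive main-model blocks reuse the same sample.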
            controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat

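        # Return exactly main_model_double residuals under the "input" key, which the
        # main model consumes as additive per-block inputs.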
        return {"input": controlnet_block_samples[:self.main_model_double]}