Commit to ComfyUI (mirror of https://github.com/comfyanonymous/ComfyUI.git):
Qwen rotary embeddings should now match reference code. (#9437)
@@ -349,8 +349,8 @@ class QwenImageTransformer2DModel(nn.Module):
 
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
         return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
 
     def forward(
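The first hunk centers the per-axis image position ids on zero: subtracting h_len // 2 and w_len // 2 shifts each linspace ramp so the patch grid straddles the origin instead of starting at the offset, which is what matching the reference Qwen rotary-embedding layout requires. A minimal sketch of the before/after values for the height axis, using hypothetical stand-ins for the real h_len / h_offset arguments:

    import torch

    h_len, h_offset = 4, 0  # hypothetical grid height and offset

    # Old: row ids run from h_offset to h_offset + h_len - 1.
    old_rows = torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len)
    # New: the same ramp shifted so it is roughly centered on zero.
    new_rows = old_rows - (h_len // 2)

    print(old_rows)  # tensor([0., 1., 2., 3.])
    print(new_rows)  # tensor([-2., -1., 0., 1.])

The width axis gets the same shift with w_len // 2; the index channel (img_ids[:, :, 0]) is unchanged.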
@@ -396,7 +396,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
 
-        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
+        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
         image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
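The second hunk halves txt_start to stay consistent with the newly centered image ids: each patch-grid extent is divided by 2 before taking the max, so text token positions begin just past the largest centered image coordinate rather than past the full grid extent. A rough check under assumed values (hypothetical latent size, patch_size = 2; h_len / w_len are just names for the two max() operands):

    patch_size = 2
    h_pixels, w_pixels = 64, 48  # hypothetical x.shape[-2], x.shape[-1]

    h_len = (h_pixels + (patch_size // 2)) // patch_size  # 32
    w_len = (w_pixels + (patch_size // 2)) // patch_size  # 24

    old_txt_start = round(max(h_len, w_len))            # 32: past the whole grid
    new_txt_start = round(max(h_len // 2, w_len // 2))  # 16: past the centered ids

    # With centered ids the largest height coordinate is
    # h_len - 1 - (h_len // 2) = 15, so text ids starting at 16
    # still sit strictly after every image position.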