Speed up hunyuan dit inference a bit.

This commit is contained in:
comfyanonymous
2024-08-10 07:36:27 -04:00
parent 1b5b8ca81a
commit ae197f651b
2 changed files with 4 additions and 4 deletions

View File

@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
rope = (rope[0].to(x), rope[1].to(x))
return rope