Support llama hunyuan video text encoder in scaled fp8 format.

This commit is contained in:
comfyanonymous
2024-12-17 04:19:22 -05:00
parent f4cdedea62
commit d6656b0c0c
3 changed files with 25 additions and 4 deletions

View File

@@ -6,6 +6,19 @@ import torch
import os
def llama_detect(state_dict, prefix=""):
out = {}
t5_key = "{}model.norm.weight".format(prefix)
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")