mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 15:17:14 +00:00
More API nodes: Gemini/Open AI Chat, Tripo, Rodin, Runway Image (#8295)
* Add Ideogram generate node. * Add staging api. * Add API_NODE and common error for missing auth token (#5) * Add Minimax Video Generation + Async Task queue polling example (#6) * [Minimax] Show video preview and embed workflow in ouput (#7) * Remove uv.lock * Remove polling operations. * Revert "Remove polling operations." This reverts commit 8415404ce8fbc0262b7de54fc700c5c8854a34fc. * Update stubs. * Added Ideogram and Minimax back in. * Added initial BFL Flux 1.1 [pro] Ultra node (#11) * Manually add BFL polling status response schema (#15) * Add function for uploading files. (#18) * Add Luma nodes (#16) Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * Refactor util functions (#20) * Add rest of Luma node functionality (#19) Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * Fix image_luma_ref not working (#28) Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * [Bug] Remove duplicated option T2V-01 in MinimaxTextToVideoNode (#31) * add veo2, bump av req (#32) * Add Recraft nodes (#29) * Add Kling Nodes (#12) * Add Camera Concepts (luma_concepts) to Luma Video nodes (#33) Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * Add Runway nodes (#17) * Convert Minimax node to use VIDEO output type (#34) * Standard `CATEGORY` system for api nodes (#35) * Set `Content-Type` header when uploading files (#36) * add better error propagation to veo2 (#37) * Add Realistic Image and Logo Raster styles for Recraft v3 (#38) * Fix runway image upload and progress polling (#39) * Fix image upload for Luma: only include `Content-Type` header field if it's set explicitly (#40) * Moved Luma nodes to nodes_luma.py (#47) * Moved Recraft nodes to nodes_recraft.py (#48) * Move and fix BFL nodes to node_bfl.py (#49) * Move and edit Minimax node to nodes_minimax.py (#50) * Add Recraft Text to Vector node, add Save SVG node to handle its output (#53) * Added pixverse_template support to Pixverse Text to Video node (#54) * Added Recraft Controls + Recraft Color RGB nodes (#57) * split remaining nodes out of nodes_api, make utility lib, refactor ideogram (#61) * Set request type explicitly (#66) * Add `control_after_generate` to all seed inputs (#69) * Fix bug: deleting `Content-Type` when property does not exist (#73) * Add Pixverse and updated Kling types (#75) * Added Recraft Style - Infinite Style Library node (#82) * add ideogram v3 (#83) * [Kling] Split Camera Control config to its own node (#81) * Add Pika i2v and t2v nodes (#52) * Remove Runway nodes (#88) * Fix: Prompt text can't be validated in Kling nodes when using primitive nodes (#90) * Update Pika Duration and Resolution options (#94) * Removed Infinite Style Library until later (#99) * fix multi image return (#101) close #96 * Serve SVG files directly (#107) * Add a bunch of nodes, 3 ready to use, the rest waiting for endpoint support (#108) * Revert "Serve SVG files directly" (#111) * Expose 4 remaining Recraft nodes (#112) * [Kling] Add `Duration` and `Video ID` outputs (#105) * Add Kling nodes: camera control, start-end frame, lip-sync, video extend (#115) * Fix error for Recraft ImageToImage error for nonexistent random_seed param (#118) * Add remaining Pika nodes (#119) * Make controls input work for Recraft Image to Image node (#120) * Fix: Nested `AnyUrl` in request model cannot be serialized (Kling, Runway) (#129) * Show errors and API output URLs to the user (change log levels) (#131) * Apply small fixes and most prompt validation (if needed to avoid API error) (#135) * Node name/category modifications (#140) * Add back Recraft Style - Infinite Style Library node (#141) * [Kling] Fix: Correct/verify supported subset of input combos in Kling nodes (#149) * Remove pixverse_template from PixVerse Transition Video node (#155) * Use 3.9 compat syntax (#164) * Handle Comfy API key based authorizaton (#167) Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com> * [BFL] Print download URL of successful task result directly on nodes (#175) * Show output URL and progress text on Pika nodes (#168) * [Ideogram] Print download URL of successful task result directly on nodes (#176) * [Kling] Print download URL of successful task result directly on nodes (#181) * Merge upstream may 14 25 (#186) Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com> Co-authored-by: AustinMroz <austinmroz@utexas.edu> Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Co-authored-by: Benjamin Lu <benceruleanlu@proton.me> Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com> Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com> Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com> Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com> Co-authored-by: guill <guill@users.noreply.github.com> Co-authored-by: Chenlei Hu <hcl@comfy.org> Co-authored-by: Terry Jia <terryjia88@gmail.com> Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com> Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com> Co-authored-by: liesen <liesen.dev@gmail.com> Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com> Co-authored-by: Robin Huang <robin.j.huang@gmail.com> Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com> Co-authored-by: blepping <157360029+blepping@users.noreply.github.com> * Update instructions on how to develop API Nodes. (#171) * Add Runway FLF and I2V nodes (#187) * Add OpenAI chat node (#188) * Update README. * Add Google Gemini API node (#191) * Add Runway Gen 4 Text to Image Node (#193) * [Runway, Gemini] Update node display names and attributes (#194) * Update path from "image-to-video" to "image_to_video" (#197) * [Runway] Split I2V nodes into separate gen3 and gen4 nodes (#198) * Update runway i2v ratio enum (#201) * Rodin3D: implement Rodin3D API Nodes (#190) Co-authored-by: WhiteGiven <c15838568211@163.com> Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * Add Tripo Nodes. (#189) Co-authored-by: Robin Huang <robin.j.huang@gmail.com> * Change casing of categories "3D" => "3d" (#208) * [tripo] fix negtive_prompt and mv2model (#212) * [tripo] set default param to None (#215) * Add description and tooltip to Tripo Refine model. (#218) * Update. * Fix rebase errors. * Fix rebase errors. * Update templates. * Bump frontend. * Add file type info for file inputs. --------- Co-authored-by: Christian Byrne <cbyrne@comfy.org> Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com> Co-authored-by: Chenlei Hu <hcl@comfy.org> Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com> Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com> Co-authored-by: AustinMroz <austinmroz@utexas.edu> Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Co-authored-by: Benjamin Lu <benceruleanlu@proton.me> Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com> Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com> Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com> Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com> Co-authored-by: guill <guill@users.noreply.github.com> Co-authored-by: Terry Jia <terryjia88@gmail.com> Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com> Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com> Co-authored-by: liesen <liesen.dev@gmail.com> Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Co-authored-by: blepping <157360029+blepping@users.noreply.github.com> Co-authored-by: Changrz <51637999+WhiteGiven@users.noreply.github.com> Co-authored-by: WhiteGiven <c15838568211@163.com> Co-authored-by: seed93 <liangding1990@163.com>
This commit is contained in:
parent
3a10b9641c
commit
f58f0f5696
@ -18,6 +18,8 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
|
|||||||
python run main.py --comfy-api-base https://stagingapi.comfy.org
|
python run main.py --comfy-api-base https://stagingapi.comfy.org
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging.
|
||||||
|
|
||||||
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
|
||||||
|
|
||||||
### Redocly Instructions
|
### Redocly Instructions
|
||||||
@ -28,7 +30,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo
|
|||||||
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Download the OpenAPI file from prod server.
|
# Download the OpenAPI file from staging server.
|
||||||
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
|
||||||
|
|
||||||
# Filter out unneeded API definitions.
|
# Filter out unneeded API definitions.
|
||||||
@ -39,3 +41,25 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
|
|||||||
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# Merging to Master
|
||||||
|
|
||||||
|
Before merging to comfyanonymous/ComfyUI master, follow these steps:
|
||||||
|
|
||||||
|
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
|
||||||
|
1. Make sure the ComfyUI API is deployed to prod with your changes.
|
||||||
|
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Download the OpenAPI file from prod server.
|
||||||
|
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||||
|
|
||||||
|
# Filter out unneeded API definitions.
|
||||||
|
npm install -g @redocly/cli
|
||||||
|
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||||
|
|
||||||
|
# Generate the pydantic datamodels for validation.
|
||||||
|
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
|
||||||
|
|
||||||
|
```
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from comfy.utils import common_upscale
|
from comfy.utils import common_upscale
|
||||||
from comfy_api.input_impl import VideoFromFile
|
from comfy_api.input_impl import VideoFromFile
|
||||||
@ -214,6 +215,7 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
|||||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
image_bytesio = download_url_to_bytesio(url, timeout)
|
||||||
return bytesio_to_image_tensor(image_bytesio)
|
return bytesio_to_image_tensor(image_bytesio)
|
||||||
|
|
||||||
|
|
||||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
def process_image_response(response: requests.Response) -> torch.Tensor:
|
||||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
return bytesio_to_image_tensor(BytesIO(response.content))
|
||||||
@ -318,11 +320,27 @@ def tensor_to_data_uri(
|
|||||||
return f"data:{mime_type};base64,{base64_string}"
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a base64 string."""
|
||||||
|
with open(filepath, "rb") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
return base64.b64encode(file_content).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a data URI."""
|
||||||
|
base64_string = text_filepath_to_base64_string(filepath)
|
||||||
|
mime_type, _ = mimetypes.guess_type(filepath)
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type = "application/octet-stream"
|
||||||
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|
||||||
|
|
||||||
def upload_file_to_comfyapi(
|
def upload_file_to_comfyapi(
|
||||||
file_bytes_io: BytesIO,
|
file_bytes_io: BytesIO,
|
||||||
filename: str,
|
filename: str,
|
||||||
upload_mime_type: str,
|
upload_mime_type: str,
|
||||||
auth_kwargs: Optional[dict[str,str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Uploads a single file to ComfyUI API and returns its download URL.
|
Uploads a single file to ComfyUI API and returns its download URL.
|
||||||
@ -357,9 +375,33 @@ def upload_file_to_comfyapi(
|
|||||||
return response.download_url
|
return response.download_url
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_base64_string(
|
||||||
|
video: VideoInput,
|
||||||
|
container_format: VideoContainer = None,
|
||||||
|
codec: VideoCodec = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Converts a video input to a base64 string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: The video input to convert
|
||||||
|
container_format: Optional container format to use (defaults to video.container if available)
|
||||||
|
codec: Optional codec to use (defaults to video.codec if available)
|
||||||
|
"""
|
||||||
|
video_bytes_io = io.BytesIO()
|
||||||
|
|
||||||
|
# Use provided format/codec if specified, otherwise use video's own if available
|
||||||
|
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||||
|
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||||
|
|
||||||
|
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def upload_video_to_comfyapi(
|
def upload_video_to_comfyapi(
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
auth_kwargs: Optional[dict[str,str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
container: VideoContainer = VideoContainer.MP4,
|
container: VideoContainer = VideoContainer.MP4,
|
||||||
codec: VideoCodec = VideoCodec.H264,
|
codec: VideoCodec = VideoCodec.H264,
|
||||||
max_duration: Optional[int] = None,
|
max_duration: Optional[int] = None,
|
||||||
@ -461,7 +503,7 @@ def audio_ndarray_to_bytesio(
|
|||||||
|
|
||||||
def upload_audio_to_comfyapi(
|
def upload_audio_to_comfyapi(
|
||||||
audio: AudioInput,
|
audio: AudioInput,
|
||||||
auth_kwargs: Optional[dict[str,str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
container_format: str = "mp4",
|
container_format: str = "mp4",
|
||||||
codec_name: str = "aac",
|
codec_name: str = "aac",
|
||||||
mime_type: str = "audio/mp4",
|
mime_type: str = "audio/mp4",
|
||||||
@ -488,8 +530,25 @@ def upload_audio_to_comfyapi(
|
|||||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def audio_to_base64_string(
|
||||||
|
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||||
|
) -> str:
|
||||||
|
"""Converts an audio input to a base64 string."""
|
||||||
|
sample_rate: int = audio["sample_rate"]
|
||||||
|
waveform: torch.Tensor = audio["waveform"]
|
||||||
|
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||||
|
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||||
|
audio_data_np, sample_rate, container_format, codec_name
|
||||||
|
)
|
||||||
|
audio_bytes = audio_bytes_io.getvalue()
|
||||||
|
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def upload_images_to_comfyapi(
|
def upload_images_to_comfyapi(
|
||||||
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
|
image: torch.Tensor,
|
||||||
|
max_images=8,
|
||||||
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Uploads images to ComfyUI API and returns download URLs.
|
Uploads images to ComfyUI API and returns download URLs.
|
||||||
@ -554,17 +613,24 @@ def upload_images_to_comfyapi(
|
|||||||
return download_urls
|
return download_urls
|
||||||
|
|
||||||
|
|
||||||
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
def resize_mask_to_image(
|
||||||
upscale_method="nearest-exact", crop="disabled",
|
mask: torch.Tensor,
|
||||||
allow_gradient=True, add_channel_dim=False):
|
image: torch.Tensor,
|
||||||
|
upscale_method="nearest-exact",
|
||||||
|
crop="disabled",
|
||||||
|
allow_gradient=True,
|
||||||
|
add_channel_dim=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||||
"""
|
"""
|
||||||
_, H, W, _ = image.shape
|
_, H, W, _ = image.shape
|
||||||
mask = mask.unsqueeze(-1)
|
mask = mask.unsqueeze(-1)
|
||||||
mask = mask.movedim(-1,1)
|
mask = mask.movedim(-1, 1)
|
||||||
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
|
mask = common_upscale(
|
||||||
mask = mask.movedim(1,-1)
|
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
||||||
|
)
|
||||||
|
mask = mask.movedim(1, -1)
|
||||||
if not add_channel_dim:
|
if not add_channel_dim:
|
||||||
mask = mask.squeeze(-1)
|
mask = mask.squeeze(-1)
|
||||||
if not allow_gradient:
|
if not allow_gradient:
|
||||||
@ -572,12 +638,41 @@ def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
|
def validate_string(
|
||||||
|
string: str,
|
||||||
|
strip_whitespace=True,
|
||||||
|
field_name="prompt",
|
||||||
|
min_length=None,
|
||||||
|
max_length=None,
|
||||||
|
):
|
||||||
|
if string is None:
|
||||||
|
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||||
if strip_whitespace:
|
if strip_whitespace:
|
||||||
string = string.strip()
|
string = string.strip()
|
||||||
if min_length and len(string) < min_length:
|
if min_length and len(string) < min_length:
|
||||||
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
|
raise Exception(
|
||||||
|
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||||
|
)
|
||||||
if max_length and len(string) > max_length:
|
if max_length and len(string) > max_length:
|
||||||
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
|
raise Exception(
|
||||||
if not string:
|
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def image_tensor_pair_to_batch(
|
||||||
|
image1: torch.Tensor, image2: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts a pair of image tensors to a batch tensor.
|
||||||
|
If the images are not the same size, the smaller image is resized to
|
||||||
|
match the larger image.
|
||||||
|
"""
|
||||||
|
if image1.shape[1:] != image2.shape[1:]:
|
||||||
|
image2 = common_upscale(
|
||||||
|
image2.movedim(-1, 1),
|
||||||
|
image1.shape[2],
|
||||||
|
image1.shape[1],
|
||||||
|
"bilinear",
|
||||||
|
"center",
|
||||||
|
).movedim(1, -1)
|
||||||
|
return torch.cat((image1, image2), dim=0)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -139,7 +139,7 @@ class EmptyRequest(BaseModel):
|
|||||||
|
|
||||||
class UploadRequest(BaseModel):
|
class UploadRequest(BaseModel):
|
||||||
file_name: str = Field(..., description="Filename to upload")
|
file_name: str = Field(..., description="Filename to upload")
|
||||||
content_type: str | None = Field(
|
content_type: Optional[str] = Field(
|
||||||
None,
|
None,
|
||||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||||
)
|
)
|
||||||
|
57
comfy_api_nodes/apis/rodin_api.py
Normal file
57
comfy_api_nodes/apis/rodin_api.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Rodin3DGenerateRequest(BaseModel):
|
||||||
|
seed: int = Field(..., description="seed_")
|
||||||
|
tier: str = Field(..., description="Tier of generation.")
|
||||||
|
material: str = Field(..., description="The material type.")
|
||||||
|
quality: str = Field(..., description="The generation quality of the mesh.")
|
||||||
|
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
|
||||||
|
|
||||||
|
class GenerateJobsData(BaseModel):
|
||||||
|
uuids: List[str] = Field(..., description="str LIST")
|
||||||
|
subscription_key: str = Field(..., description="subscription key")
|
||||||
|
|
||||||
|
class Rodin3DGenerateResponse(BaseModel):
|
||||||
|
message: Optional[str] = Field(None, description="Return message.")
|
||||||
|
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
|
||||||
|
submit_time: Optional[str] = Field(None, description="Submit Time")
|
||||||
|
uuid: Optional[str] = Field(None, description="Task str")
|
||||||
|
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
|
||||||
|
|
||||||
|
class JobStatus(str, Enum):
|
||||||
|
"""
|
||||||
|
Status for jobs
|
||||||
|
"""
|
||||||
|
Done = "Done"
|
||||||
|
Failed = "Failed"
|
||||||
|
Generating = "Generating"
|
||||||
|
Waiting = "Waiting"
|
||||||
|
|
||||||
|
class Rodin3DCheckStatusRequest(BaseModel):
|
||||||
|
subscription_key: str = Field(..., description="subscription from generate endpoint")
|
||||||
|
|
||||||
|
class JobItem(BaseModel):
|
||||||
|
uuid: str = Field(..., description="uuid")
|
||||||
|
status: JobStatus = Field(...,description="Status Currently")
|
||||||
|
|
||||||
|
class Rodin3DCheckStatusResponse(BaseModel):
|
||||||
|
jobs: List[JobItem] = Field(..., description="Job status List")
|
||||||
|
|
||||||
|
class Rodin3DDownloadRequest(BaseModel):
|
||||||
|
task_uuid: str = Field(..., description="Task str")
|
||||||
|
|
||||||
|
class RodinResourceItem(BaseModel):
|
||||||
|
url: str = Field(..., description="Download Url")
|
||||||
|
name: str = Field(..., description="File name with ext")
|
||||||
|
|
||||||
|
class Rodin3DDownloadResponse(BaseModel):
|
||||||
|
list: List[RodinResourceItem] = Field(..., description="Source List")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
275
comfy_api_nodes/apis/tripo_api.py
Normal file
275
comfy_api_nodes/apis/tripo_api.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
TripoModelVersion,
|
||||||
|
TripoTextureQuality,
|
||||||
|
)
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, List, Dict, Any, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
|
class TripoStyle(str, Enum):
|
||||||
|
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||||
|
ANIMAL_VENOM = "animal:venom"
|
||||||
|
OBJECT_CLAY = "object:clay"
|
||||||
|
OBJECT_STEAMPUNK = "object:steampunk"
|
||||||
|
OBJECT_CHRISTMAS = "object:christmas"
|
||||||
|
OBJECT_BARBIE = "object:barbie"
|
||||||
|
GOLD = "gold"
|
||||||
|
ANCIENT_BRONZE = "ancient_bronze"
|
||||||
|
NONE = "None"
|
||||||
|
|
||||||
|
class TripoTaskType(str, Enum):
|
||||||
|
TEXT_TO_MODEL = "text_to_model"
|
||||||
|
IMAGE_TO_MODEL = "image_to_model"
|
||||||
|
MULTIVIEW_TO_MODEL = "multiview_to_model"
|
||||||
|
TEXTURE_MODEL = "texture_model"
|
||||||
|
REFINE_MODEL = "refine_model"
|
||||||
|
ANIMATE_PRERIGCHECK = "animate_prerigcheck"
|
||||||
|
ANIMATE_RIG = "animate_rig"
|
||||||
|
ANIMATE_RETARGET = "animate_retarget"
|
||||||
|
STYLIZE_MODEL = "stylize_model"
|
||||||
|
CONVERT_MODEL = "convert_model"
|
||||||
|
|
||||||
|
class TripoTextureAlignment(str, Enum):
|
||||||
|
ORIGINAL_IMAGE = "original_image"
|
||||||
|
GEOMETRY = "geometry"
|
||||||
|
|
||||||
|
class TripoOrientation(str, Enum):
|
||||||
|
ALIGN_IMAGE = "align_image"
|
||||||
|
DEFAULT = "default"
|
||||||
|
|
||||||
|
class TripoOutFormat(str, Enum):
|
||||||
|
GLB = "glb"
|
||||||
|
FBX = "fbx"
|
||||||
|
|
||||||
|
class TripoTopology(str, Enum):
|
||||||
|
BIP = "bip"
|
||||||
|
QUAD = "quad"
|
||||||
|
|
||||||
|
class TripoSpec(str, Enum):
|
||||||
|
MIXAMO = "mixamo"
|
||||||
|
TRIPO = "tripo"
|
||||||
|
|
||||||
|
class TripoAnimation(str, Enum):
|
||||||
|
IDLE = "preset:idle"
|
||||||
|
WALK = "preset:walk"
|
||||||
|
CLIMB = "preset:climb"
|
||||||
|
JUMP = "preset:jump"
|
||||||
|
RUN = "preset:run"
|
||||||
|
SLASH = "preset:slash"
|
||||||
|
SHOOT = "preset:shoot"
|
||||||
|
HURT = "preset:hurt"
|
||||||
|
FALL = "preset:fall"
|
||||||
|
TURN = "preset:turn"
|
||||||
|
|
||||||
|
class TripoStylizeStyle(str, Enum):
|
||||||
|
LEGO = "lego"
|
||||||
|
VOXEL = "voxel"
|
||||||
|
VORONOI = "voronoi"
|
||||||
|
MINECRAFT = "minecraft"
|
||||||
|
|
||||||
|
class TripoConvertFormat(str, Enum):
|
||||||
|
GLTF = "GLTF"
|
||||||
|
USDZ = "USDZ"
|
||||||
|
FBX = "FBX"
|
||||||
|
OBJ = "OBJ"
|
||||||
|
STL = "STL"
|
||||||
|
_3MF = "3MF"
|
||||||
|
|
||||||
|
class TripoTextureFormat(str, Enum):
|
||||||
|
BMP = "BMP"
|
||||||
|
DPX = "DPX"
|
||||||
|
HDR = "HDR"
|
||||||
|
JPEG = "JPEG"
|
||||||
|
OPEN_EXR = "OPEN_EXR"
|
||||||
|
PNG = "PNG"
|
||||||
|
TARGA = "TARGA"
|
||||||
|
TIFF = "TIFF"
|
||||||
|
WEBP = "WEBP"
|
||||||
|
|
||||||
|
class TripoTaskStatus(str, Enum):
|
||||||
|
QUEUED = "queued"
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
BANNED = "banned"
|
||||||
|
EXPIRED = "expired"
|
||||||
|
|
||||||
|
class TripoFileTokenReference(BaseModel):
|
||||||
|
type: Optional[str] = Field(None, description='The type of the reference')
|
||||||
|
file_token: str
|
||||||
|
|
||||||
|
class TripoUrlReference(BaseModel):
|
||||||
|
type: Optional[str] = Field(None, description='The type of the reference')
|
||||||
|
url: str
|
||||||
|
|
||||||
|
class TripoObjectStorage(BaseModel):
|
||||||
|
bucket: str
|
||||||
|
key: str
|
||||||
|
|
||||||
|
class TripoObjectReference(BaseModel):
|
||||||
|
type: str
|
||||||
|
object: TripoObjectStorage
|
||||||
|
|
||||||
|
class TripoFileEmptyReference(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TripoFileReference(RootModel):
|
||||||
|
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
|
||||||
|
|
||||||
|
class TripoGetStsTokenRequest(BaseModel):
|
||||||
|
format: str = Field(..., description='The format of the image')
|
||||||
|
|
||||||
|
class TripoTextToModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
|
||||||
|
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
|
||||||
|
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
|
||||||
|
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
|
||||||
|
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||||
|
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||||
|
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||||
|
image_seed: Optional[int] = Field(None, description='The seed for the text')
|
||||||
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
style: Optional[TripoStyle] = None
|
||||||
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
|
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||||
|
|
||||||
|
class TripoImageToModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
|
||||||
|
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
|
||||||
|
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||||
|
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||||
|
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||||
|
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||||
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||||
|
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||||
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
|
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
|
||||||
|
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||||
|
|
||||||
|
class TripoMultiviewToModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
|
||||||
|
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
|
||||||
|
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||||
|
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
|
||||||
|
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||||
|
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||||
|
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||||
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||||
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
|
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||||
|
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||||
|
|
||||||
|
class TripoTextureModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
|
||||||
|
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
|
||||||
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
|
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
|
||||||
|
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||||
|
|
||||||
|
class TripoRefineModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
|
||||||
|
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
|
||||||
|
|
||||||
|
class TripoAnimatePrerigcheckRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
|
||||||
|
class TripoAnimateRigRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||||
|
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
|
||||||
|
|
||||||
|
class TripoAnimateRetargetRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
animation: TripoAnimation = Field(..., description='The animation to apply')
|
||||||
|
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||||
|
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
|
||||||
|
|
||||||
|
class TripoStylizeModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
|
||||||
|
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
block_size: Optional[int] = Field(80, description='The block size for stylization')
|
||||||
|
|
||||||
|
class TripoConvertModelRequest(BaseModel):
|
||||||
|
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||||
|
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||||
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
|
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
|
||||||
|
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
|
||||||
|
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
|
||||||
|
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
|
||||||
|
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
|
||||||
|
texture_size: Optional[int] = Field(4096, description='The size of the texture')
|
||||||
|
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||||
|
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
|
||||||
|
|
||||||
|
class TripoTaskRequest(RootModel):
|
||||||
|
root: Union[
|
||||||
|
TripoTextToModelRequest,
|
||||||
|
TripoImageToModelRequest,
|
||||||
|
TripoMultiviewToModelRequest,
|
||||||
|
TripoTextureModelRequest,
|
||||||
|
TripoRefineModelRequest,
|
||||||
|
TripoAnimatePrerigcheckRequest,
|
||||||
|
TripoAnimateRigRequest,
|
||||||
|
TripoAnimateRetargetRequest,
|
||||||
|
TripoStylizeModelRequest,
|
||||||
|
TripoConvertModelRequest
|
||||||
|
]
|
||||||
|
|
||||||
|
class TripoTaskOutput(BaseModel):
|
||||||
|
model: Optional[str] = Field(None, description='URL to the model')
|
||||||
|
base_model: Optional[str] = Field(None, description='URL to the base model')
|
||||||
|
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
|
||||||
|
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
|
||||||
|
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
|
||||||
|
|
||||||
|
class TripoTask(BaseModel):
|
||||||
|
task_id: str = Field(..., description='The task ID')
|
||||||
|
type: Optional[str] = Field(None, description='The type of task')
|
||||||
|
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
|
||||||
|
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
|
||||||
|
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
|
||||||
|
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
|
||||||
|
create_time: Optional[int] = Field(None, description='The creation time of the task')
|
||||||
|
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
|
||||||
|
queue_position: Optional[int] = Field(None, description='The position in the queue')
|
||||||
|
|
||||||
|
class TripoTaskResponse(BaseModel):
|
||||||
|
code: int = Field(0, description='The response code')
|
||||||
|
data: TripoTask = Field(..., description='The task data')
|
||||||
|
|
||||||
|
class TripoGeneralResponse(BaseModel):
|
||||||
|
code: int = Field(0, description='The response code')
|
||||||
|
data: Dict[str, str] = Field(..., description='The task ID data')
|
||||||
|
|
||||||
|
class TripoBalanceData(BaseModel):
|
||||||
|
balance: float = Field(..., description='The account balance')
|
||||||
|
frozen: float = Field(..., description='The frozen balance')
|
||||||
|
|
||||||
|
class TripoBalanceResponse(BaseModel):
|
||||||
|
code: int = Field(0, description='The response code')
|
||||||
|
data: TripoBalanceData = Field(..., description='The balance data')
|
||||||
|
|
||||||
|
class TripoErrorResponse(BaseModel):
|
||||||
|
code: int = Field(..., description='The error code')
|
||||||
|
message: str = Field(..., description='The error message')
|
||||||
|
suggestion: str = Field(..., description='The suggestion for fixing the error')
|
446
comfy_api_nodes/nodes_gemini.py
Normal file
446
comfy_api_nodes/nodes_gemini.py
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
"""
|
||||||
|
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||||
|
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
|
from server import PromptServer
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
GeminiContent,
|
||||||
|
GeminiGenerateContentRequest,
|
||||||
|
GeminiGenerateContentResponse,
|
||||||
|
GeminiInlineData,
|
||||||
|
GeminiPart,
|
||||||
|
GeminiMimeType,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
validate_string,
|
||||||
|
audio_to_base64_string,
|
||||||
|
video_to_base64_string,
|
||||||
|
tensor_to_base64_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||||
|
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModel(str, Enum):
|
||||||
|
"""
|
||||||
|
Gemini Model Names allowed by comfy-api
|
||||||
|
"""
|
||||||
|
|
||||||
|
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
||||||
|
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||||
|
|
||||||
|
|
||||||
|
def get_gemini_endpoint(
|
||||||
|
model: GeminiModel,
|
||||||
|
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||||
|
"""
|
||||||
|
Get the API endpoint for a given Gemini model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The Gemini model to use, either as enum or string value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiEndpoint configured for the specific Gemini model.
|
||||||
|
"""
|
||||||
|
if isinstance(model, str):
|
||||||
|
model = GeminiModel(model)
|
||||||
|
return ApiEndpoint(
|
||||||
|
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=GeminiGenerateContentRequest,
|
||||||
|
response_model=GeminiGenerateContentResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiNode(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Node to generate text responses from a Gemini model.
|
||||||
|
|
||||||
|
This node allows users to interact with Google's Gemini AI models, providing
|
||||||
|
multimodal inputs (text, images, audio, video, files) to generate coherent
|
||||||
|
text responses. The node works with the latest Gemini models, handling the
|
||||||
|
API communication and response parsing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"tooltip": "The Gemini model to use for generating responses.",
|
||||||
|
"options": [model.value for model in GeminiModel],
|
||||||
|
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default": 42,
|
||||||
|
"min": 0,
|
||||||
|
"max": 0xFFFFFFFFFFFFFFFF,
|
||||||
|
"control_after_generate": True,
|
||||||
|
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"images": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"audio": (
|
||||||
|
IO.AUDIO,
|
||||||
|
{
|
||||||
|
"tooltip": "Optional audio to use as context for the model.",
|
||||||
|
"default": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"video": (
|
||||||
|
IO.VIDEO,
|
||||||
|
{
|
||||||
|
"tooltip": "Optional video to use as context for the model.",
|
||||||
|
"default": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"files": (
|
||||||
|
"GEMINI_INPUT_FILES",
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/text/Gemini"
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def get_parts_from_response(
|
||||||
|
self, response: GeminiGenerateContentResponse
|
||||||
|
) -> list[GeminiPart]:
|
||||||
|
"""
|
||||||
|
Extract all parts from the Gemini API response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The API response from Gemini.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of response parts from the first candidate.
|
||||||
|
"""
|
||||||
|
return response.candidates[0].content.parts
|
||||||
|
|
||||||
|
def get_parts_by_type(
|
||||||
|
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||||
|
) -> list[GeminiPart]:
|
||||||
|
"""
|
||||||
|
Filter response parts by their type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The API response from Gemini.
|
||||||
|
part_type: Type of parts to extract ("text" or a MIME type).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of response parts matching the requested type.
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
for part in self.get_parts_from_response(response):
|
||||||
|
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||||
|
parts.append(part)
|
||||||
|
elif (
|
||||||
|
hasattr(part, "inlineData")
|
||||||
|
and part.inlineData
|
||||||
|
and part.inlineData.mimeType == part_type
|
||||||
|
):
|
||||||
|
parts.append(part)
|
||||||
|
# Skip parts that don't match the requested type
|
||||||
|
return parts
|
||||||
|
|
||||||
|
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
|
||||||
|
"""
|
||||||
|
Extract and concatenate all text parts from the response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The API response from Gemini.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined text from all text parts in the response.
|
||||||
|
"""
|
||||||
|
parts = self.get_parts_by_type(response, "text")
|
||||||
|
return "\n".join([part.text for part in parts])
|
||||||
|
|
||||||
|
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||||
|
"""
|
||||||
|
Convert video input to Gemini API compatible parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_input: Video tensor from ComfyUI.
|
||||||
|
**kwargs: Additional arguments to pass to the conversion function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of GeminiPart objects containing the encoded video.
|
||||||
|
"""
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec
|
||||||
|
base_64_string = video_to_base64_string(
|
||||||
|
video_input,
|
||||||
|
container_format=VideoContainer.MP4,
|
||||||
|
codec=VideoCodec.H264
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=GeminiMimeType.video_mp4,
|
||||||
|
data=base_64_string,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
|
||||||
|
"""
|
||||||
|
Convert audio input to Gemini API compatible parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of GeminiPart objects containing the encoded audio.
|
||||||
|
"""
|
||||||
|
audio_parts: list[GeminiPart] = []
|
||||||
|
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||||
|
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||||
|
audio_at_index = {
|
||||||
|
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
|
||||||
|
"sample_rate": audio_input["sample_rate"],
|
||||||
|
}
|
||||||
|
# Convert to MP3 format for compatibility with Gemini API
|
||||||
|
audio_bytes = audio_to_base64_string(
|
||||||
|
audio_at_index,
|
||||||
|
container_format="mp3",
|
||||||
|
codec_name="libmp3lame",
|
||||||
|
)
|
||||||
|
audio_parts.append(
|
||||||
|
GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=GeminiMimeType.audio_mp3,
|
||||||
|
data=audio_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return audio_parts
|
||||||
|
|
||||||
|
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
|
||||||
|
"""
|
||||||
|
Convert image tensor input to Gemini API compatible parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_input: Batch of image tensors from ComfyUI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of GeminiPart objects containing the encoded images.
|
||||||
|
"""
|
||||||
|
image_parts: list[GeminiPart] = []
|
||||||
|
for image_index in range(image_input.shape[0]):
|
||||||
|
image_as_b64 = tensor_to_base64_string(
|
||||||
|
image_input[image_index].unsqueeze(0)
|
||||||
|
)
|
||||||
|
image_parts.append(
|
||||||
|
GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=GeminiMimeType.image_png,
|
||||||
|
data=image_as_b64,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return image_parts
|
||||||
|
|
||||||
|
def create_text_part(self, text: str) -> GeminiPart:
|
||||||
|
"""
|
||||||
|
Create a text part for the Gemini API request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text content to include in the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A GeminiPart object with the text content.
|
||||||
|
"""
|
||||||
|
return GeminiPart(text=text)
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model: GeminiModel,
|
||||||
|
images: Optional[IO.IMAGE] = None,
|
||||||
|
audio: Optional[IO.AUDIO] = None,
|
||||||
|
video: Optional[IO.VIDEO] = None,
|
||||||
|
files: Optional[list[GeminiPart]] = None,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
|
||||||
|
# Create parts list with text prompt as the first part
|
||||||
|
parts: list[GeminiPart] = [self.create_text_part(prompt)]
|
||||||
|
|
||||||
|
# Add other modal parts
|
||||||
|
if images is not None:
|
||||||
|
image_parts = self.create_image_parts(images)
|
||||||
|
parts.extend(image_parts)
|
||||||
|
if audio is not None:
|
||||||
|
parts.extend(self.create_audio_parts(audio))
|
||||||
|
if video is not None:
|
||||||
|
parts.extend(self.create_video_parts(video))
|
||||||
|
if files is not None:
|
||||||
|
parts.extend(files)
|
||||||
|
|
||||||
|
# Create response
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=get_gemini_endpoint(model),
|
||||||
|
request=GeminiGenerateContentRequest(
|
||||||
|
contents=[
|
||||||
|
GeminiContent(
|
||||||
|
role="user",
|
||||||
|
parts=parts,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
# Get result output
|
||||||
|
output_text = self.get_text_from_response(response)
|
||||||
|
if unique_id and output_text:
|
||||||
|
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
||||||
|
|
||||||
|
return (output_text or "Empty response from Gemini model...",)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiInputFiles(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Loads and formats input files for use with the Gemini API.
|
||||||
|
|
||||||
|
This node allows users to include text (.txt) and PDF (.pdf) files as input
|
||||||
|
context for the Gemini model. Files are converted to the appropriate format
|
||||||
|
required by the API and can be chained together to include multiple files
|
||||||
|
in a single request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
"""
|
||||||
|
For details about the supported file input types, see:
|
||||||
|
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
|
"""
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
input_files = [
|
||||||
|
f
|
||||||
|
for f in os.scandir(input_dir)
|
||||||
|
if f.is_file()
|
||||||
|
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
|
||||||
|
and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
|
||||||
|
]
|
||||||
|
input_files = sorted(input_files, key=lambda x: x.name)
|
||||||
|
input_files = [f.name for f in input_files]
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"file": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||||
|
"options": input_files,
|
||||||
|
"default": input_files[0] if input_files else None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"GEMINI_INPUT_FILES": (
|
||||||
|
"GEMINI_INPUT_FILES",
|
||||||
|
{
|
||||||
|
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||||
|
"default": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
|
||||||
|
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
|
||||||
|
FUNCTION = "prepare_files"
|
||||||
|
CATEGORY = "api node/text/Gemini"
|
||||||
|
|
||||||
|
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||||
|
mime_type = (
|
||||||
|
GeminiMimeType.pdf
|
||||||
|
if file_path.endswith(".pdf")
|
||||||
|
else GeminiMimeType.text_plain
|
||||||
|
)
|
||||||
|
# Use base64 string directly, not the data URI
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
import base64
|
||||||
|
base64_str = base64.b64encode(file_content).decode("utf-8")
|
||||||
|
|
||||||
|
return GeminiPart(
|
||||||
|
inlineData=GeminiInlineData(
|
||||||
|
mimeType=mime_type,
|
||||||
|
data=base64_str,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_files(
|
||||||
|
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
|
||||||
|
) -> tuple[list[GeminiPart]]:
|
||||||
|
"""
|
||||||
|
Loads and formats input files for Gemini API.
|
||||||
|
"""
|
||||||
|
file_path = folder_paths.get_annotated_filepath(file)
|
||||||
|
input_file_content = self.create_file_part(file_path)
|
||||||
|
files = [input_file_content] + GEMINI_INPUT_FILES
|
||||||
|
return (files,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"GeminiNode": GeminiNode,
|
||||||
|
"GeminiInputFiles": GeminiInputFiles,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"GeminiNode": "Google Gemini",
|
||||||
|
"GeminiInputFiles": "Gemini Input Files",
|
||||||
|
}
|
@ -1,29 +1,86 @@
|
|||||||
import io
|
import io
|
||||||
|
from typing import TypedDict, Optional
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||||
|
from server import PromptServer
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
OpenAIImageGenerationRequest,
|
OpenAIImageGenerationRequest,
|
||||||
OpenAIImageEditRequest,
|
OpenAIImageEditRequest,
|
||||||
OpenAIImageGenerationResponse,
|
OpenAIImageGenerationResponse,
|
||||||
|
OpenAICreateResponse,
|
||||||
|
OpenAIResponse,
|
||||||
|
CreateModelResponseProperties,
|
||||||
|
Item,
|
||||||
|
Includable,
|
||||||
|
OutputContent,
|
||||||
|
InputImageContent,
|
||||||
|
Detail,
|
||||||
|
InputTextContent,
|
||||||
|
InputMessage,
|
||||||
|
InputMessageContentList,
|
||||||
|
InputContent,
|
||||||
|
InputFileContent,
|
||||||
)
|
)
|
||||||
|
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
HttpMethod,
|
||||||
SynchronousOperation,
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api_nodes.apinode_utils import (
|
||||||
downscale_image_tensor,
|
downscale_image_tensor,
|
||||||
validate_and_cast_response,
|
validate_and_cast_response,
|
||||||
validate_string,
|
validate_string,
|
||||||
|
tensor_to_base64_string,
|
||||||
|
text_filepath_to_data_uri,
|
||||||
)
|
)
|
||||||
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
|
|
||||||
|
|
||||||
|
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
|
||||||
|
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryEntry(TypedDict):
|
||||||
|
"""Type definition for a single history entry in the chat."""
|
||||||
|
|
||||||
|
prompt: str
|
||||||
|
response: str
|
||||||
|
response_id: str
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHistory(TypedDict):
|
||||||
|
"""Type definition for the chat history dictionary."""
|
||||||
|
|
||||||
|
__annotations__: dict[str, list[HistoryEntry]]
|
||||||
|
|
||||||
|
|
||||||
|
class SupportedOpenAIModel(str, Enum):
|
||||||
|
o4_mini = "o4-mini"
|
||||||
|
o1 = "o1"
|
||||||
|
o3 = "o3"
|
||||||
|
o1_pro = "o1-pro"
|
||||||
|
gpt_4o = "gpt-4o"
|
||||||
|
gpt_4_1 = "gpt-4.1"
|
||||||
|
gpt_4_1_mini = "gpt-4.1-mini"
|
||||||
|
gpt_4_1_nano = "gpt-4.1-nano"
|
||||||
|
|
||||||
|
|
||||||
class OpenAIDalle2(ComfyNodeABC):
|
class OpenAIDalle2(ComfyNodeABC):
|
||||||
"""
|
"""
|
||||||
@ -115,7 +172,7 @@ class OpenAIDalle2(ComfyNodeABC):
|
|||||||
n=1,
|
n=1,
|
||||||
size="1024x1024",
|
size="1024x1024",
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
model = "dall-e-2"
|
model = "dall-e-2"
|
||||||
@ -262,7 +319,7 @@ class OpenAIDalle3(ComfyNodeABC):
|
|||||||
quality="standard",
|
quality="standard",
|
||||||
size="1024x1024",
|
size="1024x1024",
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
model = "dall-e-3"
|
model = "dall-e-3"
|
||||||
@ -400,12 +457,12 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
n=1,
|
n=1,
|
||||||
size="1024x1024",
|
size="1024x1024",
|
||||||
unique_id=None,
|
unique_id=None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
model = "gpt-image-1"
|
model = "gpt-image-1"
|
||||||
path = "/proxy/openai/images/generations"
|
path = "/proxy/openai/images/generations"
|
||||||
content_type="application/json"
|
content_type = "application/json"
|
||||||
request_class = OpenAIImageGenerationRequest
|
request_class = OpenAIImageGenerationRequest
|
||||||
img_binaries = []
|
img_binaries = []
|
||||||
mask_binary = None
|
mask_binary = None
|
||||||
@ -414,7 +471,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
if image is not None:
|
if image is not None:
|
||||||
path = "/proxy/openai/images/edits"
|
path = "/proxy/openai/images/edits"
|
||||||
request_class = OpenAIImageEditRequest
|
request_class = OpenAIImageEditRequest
|
||||||
content_type ="multipart/form-data"
|
content_type = "multipart/form-data"
|
||||||
|
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
|
|
||||||
@ -486,17 +543,466 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
return (img_tensor,)
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
# A dictionary that contains all nodes you want to export with their names
|
class OpenAITextNode(ComfyNodeABC):
|
||||||
# NOTE: names should be globally unique
|
"""
|
||||||
|
Base class for OpenAI text generation nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.STRING,)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/text/OpenAI"
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatNode(OpenAITextNode):
|
||||||
|
"""
|
||||||
|
Node to generate text responses from an OpenAI model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the chat node with a new session ID and empty history."""
|
||||||
|
self.current_session_id: str = str(uuid.uuid4())
|
||||||
|
self.history: dict[str, list[HistoryEntry]] = {}
|
||||||
|
self.previous_response_id: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": (
|
||||||
|
IO.STRING,
|
||||||
|
{
|
||||||
|
"multiline": True,
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Text inputs to the model, used to generate a response.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"persist_context": (
|
||||||
|
IO.BOOLEAN,
|
||||||
|
{
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Persist chat context between calls (multi-turn conversation)",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"model": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
OpenAICreateResponse,
|
||||||
|
"model",
|
||||||
|
enum_type=SupportedOpenAIModel,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"images": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"files": (
|
||||||
|
"OPENAI_INPUT_FILES",
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"advanced_options": (
|
||||||
|
"OPENAI_CHAT_CONFIG",
|
||||||
|
{
|
||||||
|
"default": None,
|
||||||
|
"tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate text responses from an OpenAI model."
|
||||||
|
|
||||||
|
def get_result_response(
|
||||||
|
self,
|
||||||
|
response_id: str,
|
||||||
|
include: Optional[list[Includable]] = None,
|
||||||
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
|
) -> OpenAIResponse:
|
||||||
|
"""
|
||||||
|
Retrieve a model response with the given ID from the OpenAI API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_id (str): The ID of the response to retrieve.
|
||||||
|
include (Optional[List[Includable]]): Additional fields to include
|
||||||
|
in the response. See the `include` parameter for Response
|
||||||
|
creation above for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=OpenAIResponse,
|
||||||
|
query_params={"include": include},
|
||||||
|
),
|
||||||
|
completed_statuses=["completed"],
|
||||||
|
failed_statuses=["failed"],
|
||||||
|
status_extractor=lambda response: response.status,
|
||||||
|
auth_kwargs=auth_kwargs,
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
def get_message_content_from_response(
|
||||||
|
self, response: OpenAIResponse
|
||||||
|
) -> list[OutputContent]:
|
||||||
|
"""Extract message content from the API response."""
|
||||||
|
for output in response.output:
|
||||||
|
if output.root.type == "message":
|
||||||
|
return output.root.content
|
||||||
|
raise TypeError("No output message found in response")
|
||||||
|
|
||||||
|
def get_text_from_message_content(
|
||||||
|
self, message_content: list[OutputContent]
|
||||||
|
) -> str:
|
||||||
|
"""Extract text content from message content."""
|
||||||
|
for content_item in message_content:
|
||||||
|
if content_item.root.type == "output_text":
|
||||||
|
return str(content_item.root.text)
|
||||||
|
return "No text output found in response"
|
||||||
|
|
||||||
|
def get_history_text(self, session_id: str) -> str:
|
||||||
|
"""Convert the entire history for a given session to JSON string."""
|
||||||
|
return json.dumps(self.history[session_id])
|
||||||
|
|
||||||
|
def display_history_on_node(self, session_id: str, node_id: str) -> None:
|
||||||
|
"""Display formatted chat history on the node UI."""
|
||||||
|
render_spec = {
|
||||||
|
"node_id": node_id,
|
||||||
|
"component": "ChatHistoryWidget",
|
||||||
|
"props": {
|
||||||
|
"history": self.get_history_text(session_id),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
PromptServer.instance.send_sync(
|
||||||
|
"display_component",
|
||||||
|
render_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_to_history(
|
||||||
|
self, session_id: str, prompt: str, output_text: str, response_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Add a new entry to the chat history."""
|
||||||
|
if session_id not in self.history:
|
||||||
|
self.history[session_id] = []
|
||||||
|
self.history[session_id].append(
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"response": output_text,
|
||||||
|
"response_id": response_id,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
|
||||||
|
"""Extract text output from the API response."""
|
||||||
|
message_contents = self.get_message_content_from_response(response)
|
||||||
|
return self.get_text_from_message_content(message_contents)
|
||||||
|
|
||||||
|
def generate_new_session_id(self) -> str:
|
||||||
|
"""Generate a new unique session ID."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
def get_session_id(self, persist_context: bool) -> str:
|
||||||
|
"""Get the current or generate a new session ID based on context persistence."""
|
||||||
|
return (
|
||||||
|
self.current_session_id
|
||||||
|
if persist_context
|
||||||
|
else self.generate_new_session_id()
|
||||||
|
)
|
||||||
|
|
||||||
|
def tensor_to_input_image_content(
|
||||||
|
self, image: torch.Tensor, detail_level: Detail = "auto"
|
||||||
|
) -> InputImageContent:
|
||||||
|
"""Convert a tensor to an input image content object."""
|
||||||
|
return InputImageContent(
|
||||||
|
detail=detail_level,
|
||||||
|
image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
|
||||||
|
type="input_image",
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_input_message_contents(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[torch.Tensor] = None,
|
||||||
|
files: Optional[list[InputFileContent]] = None,
|
||||||
|
) -> InputMessageContentList:
|
||||||
|
"""Create a list of input message contents from prompt and optional image."""
|
||||||
|
content_list: list[InputContent] = [
|
||||||
|
InputTextContent(text=prompt, type="input_text"),
|
||||||
|
]
|
||||||
|
if image is not None:
|
||||||
|
for i in range(image.shape[0]):
|
||||||
|
content_list.append(
|
||||||
|
self.tensor_to_input_image_content(image[i].unsqueeze(0))
|
||||||
|
)
|
||||||
|
if files is not None:
|
||||||
|
content_list.extend(files)
|
||||||
|
|
||||||
|
return InputMessageContentList(
|
||||||
|
root=content_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
|
||||||
|
"""Extract response ID from prompt if it exists."""
|
||||||
|
parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
|
||||||
|
return parsed_id.group(1) if parsed_id else None
|
||||||
|
|
||||||
|
def strip_response_tag_from_prompt(self, prompt: str) -> str:
|
||||||
|
"""Remove the response ID tag from the prompt."""
|
||||||
|
return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
|
||||||
|
|
||||||
|
def delete_history_after_response_id(
|
||||||
|
self, new_start_id: str, session_id: str
|
||||||
|
) -> None:
|
||||||
|
"""Delete history entries after a specific response ID."""
|
||||||
|
if session_id not in self.history:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_history = []
|
||||||
|
i = 0
|
||||||
|
while (
|
||||||
|
i < len(self.history[session_id])
|
||||||
|
and self.history[session_id][i]["response_id"] != new_start_id
|
||||||
|
):
|
||||||
|
new_history.append(self.history[session_id][i])
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Since it's the new starting point (not the response being edited), we include it as well
|
||||||
|
if i < len(self.history[session_id]):
|
||||||
|
new_history.append(self.history[session_id][i])
|
||||||
|
|
||||||
|
self.history[session_id] = new_history
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
persist_context: bool,
|
||||||
|
model: SupportedOpenAIModel,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
images: Optional[torch.Tensor] = None,
|
||||||
|
files: Optional[list[InputFileContent]] = None,
|
||||||
|
advanced_options: Optional[CreateModelResponseProperties] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[str]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
|
||||||
|
session_id = self.get_session_id(persist_context)
|
||||||
|
response_id_override = self.parse_response_id_from_prompt(prompt)
|
||||||
|
if response_id_override:
|
||||||
|
is_starting_from_beginning = response_id_override == "start"
|
||||||
|
if is_starting_from_beginning:
|
||||||
|
self.history[session_id] = []
|
||||||
|
previous_response_id = None
|
||||||
|
else:
|
||||||
|
previous_response_id = response_id_override
|
||||||
|
self.delete_history_after_response_id(response_id_override, session_id)
|
||||||
|
prompt = self.strip_response_tag_from_prompt(prompt)
|
||||||
|
elif persist_context:
|
||||||
|
previous_response_id = self.previous_response_id
|
||||||
|
else:
|
||||||
|
previous_response_id = None
|
||||||
|
|
||||||
|
# Create response
|
||||||
|
create_response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=RESPONSES_ENDPOINT,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=OpenAICreateResponse,
|
||||||
|
response_model=OpenAIResponse,
|
||||||
|
),
|
||||||
|
request=OpenAICreateResponse(
|
||||||
|
input=[
|
||||||
|
Item(
|
||||||
|
root=InputMessage(
|
||||||
|
content=self.create_input_message_contents(
|
||||||
|
prompt, images, files
|
||||||
|
),
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
store=True,
|
||||||
|
stream=False,
|
||||||
|
model=model,
|
||||||
|
previous_response_id=previous_response_id,
|
||||||
|
**(
|
||||||
|
advanced_options.model_dump(exclude_none=True)
|
||||||
|
if advanced_options
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
response_id = create_response.id
|
||||||
|
|
||||||
|
# Get result output
|
||||||
|
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||||
|
output_text = self.parse_output_text_from_response(result_response)
|
||||||
|
|
||||||
|
# Update history
|
||||||
|
self.add_to_history(session_id, prompt, output_text, response_id)
|
||||||
|
self.display_history_on_node(session_id, unique_id)
|
||||||
|
self.previous_response_id = response_id
|
||||||
|
|
||||||
|
return (output_text,)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIInputFiles(ComfyNodeABC):
|
||||||
|
"""
|
||||||
|
Loads and formats input files for OpenAI API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
"""
|
||||||
|
For details about the supported file input types, see:
|
||||||
|
https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
|
||||||
|
"""
|
||||||
|
input_dir = folder_paths.get_input_directory()
|
||||||
|
input_files = [
|
||||||
|
f
|
||||||
|
for f in os.scandir(input_dir)
|
||||||
|
if f.is_file()
|
||||||
|
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
|
||||||
|
and f.stat().st_size < 32 * 1024 * 1024
|
||||||
|
]
|
||||||
|
input_files = sorted(input_files, key=lambda x: x.name)
|
||||||
|
input_files = [f.name for f in input_files]
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"file": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||||
|
"options": input_files,
|
||||||
|
"default": input_files[0] if input_files else None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"OPENAI_INPUT_FILES": (
|
||||||
|
"OPENAI_INPUT_FILES",
|
||||||
|
{
|
||||||
|
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||||
|
"default": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
|
||||||
|
RETURN_TYPES = ("OPENAI_INPUT_FILES",)
|
||||||
|
FUNCTION = "prepare_files"
|
||||||
|
CATEGORY = "api node/text/OpenAI"
|
||||||
|
|
||||||
|
def create_input_file_content(self, file_path: str) -> InputFileContent:
|
||||||
|
return InputFileContent(
|
||||||
|
file_data=text_filepath_to_data_uri(file_path),
|
||||||
|
filename=os.path.basename(file_path),
|
||||||
|
type="input_file",
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_files(
|
||||||
|
self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
|
||||||
|
) -> tuple[list[InputFileContent]]:
|
||||||
|
"""
|
||||||
|
Loads and formats input files for OpenAI API.
|
||||||
|
"""
|
||||||
|
file_path = folder_paths.get_annotated_filepath(file)
|
||||||
|
input_file_content = self.create_input_file_content(file_path)
|
||||||
|
files = [input_file_content] + OPENAI_INPUT_FILES
|
||||||
|
return (files,)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatConfig(ComfyNodeABC):
|
||||||
|
"""Allows setting additional configuration for the OpenAI Chat Node."""
|
||||||
|
|
||||||
|
RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
|
||||||
|
FUNCTION = "configure"
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Allows specifying advanced configuration options for the OpenAI Chat Nodes."
|
||||||
|
)
|
||||||
|
CATEGORY = "api node/text/OpenAI"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"truncation": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["auto", "disabled"],
|
||||||
|
"default": "auto",
|
||||||
|
"tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"max_output_tokens": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
OpenAICreateResponse,
|
||||||
|
"max_output_tokens",
|
||||||
|
min=16,
|
||||||
|
default=4096,
|
||||||
|
max=16384,
|
||||||
|
tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
|
||||||
|
),
|
||||||
|
"instructions": model_field_to_node_input(
|
||||||
|
IO.STRING, OpenAICreateResponse, "instructions", multiline=True
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
truncation: bool,
|
||||||
|
instructions: Optional[str] = None,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
) -> tuple[CreateModelResponseProperties]:
|
||||||
|
"""
|
||||||
|
Configure advanced options for the OpenAI Chat Node.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
While `top_p` and `temperature` are listed as properties in the
|
||||||
|
spec, they are not supported for all models (e.g., o4-mini).
|
||||||
|
They are not exposed as inputs at all to avoid having to manually
|
||||||
|
remove depending on model choice.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
CreateModelResponseProperties(
|
||||||
|
instructions=instructions,
|
||||||
|
truncation=truncation,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"OpenAIDalle2": OpenAIDalle2,
|
"OpenAIDalle2": OpenAIDalle2,
|
||||||
"OpenAIDalle3": OpenAIDalle3,
|
"OpenAIDalle3": OpenAIDalle3,
|
||||||
"OpenAIGPTImage1": OpenAIGPTImage1,
|
"OpenAIGPTImage1": OpenAIGPTImage1,
|
||||||
|
"OpenAIChatNode": OpenAIChatNode,
|
||||||
|
"OpenAIInputFiles": OpenAIInputFiles,
|
||||||
|
"OpenAIChatConfig": OpenAIChatConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||||
|
"OpenAIChatNode": "OpenAI Chat",
|
||||||
|
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
||||||
|
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
||||||
}
|
}
|
||||||
|
462
comfy_api_nodes/nodes_rodin.py
Normal file
462
comfy_api_nodes/nodes_rodin.py
Normal file
@ -0,0 +1,462 @@
|
|||||||
|
"""
|
||||||
|
ComfyUI X Rodin3D(Deemos) API Nodes
|
||||||
|
|
||||||
|
Rodin API docs: https://developer.hyper3d.ai/
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from inspect import cleandoc
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
import folder_paths as comfy_paths
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from PIL import Image
|
||||||
|
from comfy_api_nodes.apis.rodin_api import (
|
||||||
|
Rodin3DGenerateRequest,
|
||||||
|
Rodin3DGenerateResponse,
|
||||||
|
Rodin3DCheckStatusRequest,
|
||||||
|
Rodin3DCheckStatusResponse,
|
||||||
|
Rodin3DDownloadRequest,
|
||||||
|
Rodin3DDownloadResponse,
|
||||||
|
JobStatus,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
COMMON_PARAMETERS = {
|
||||||
|
"Seed": (
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default":0,
|
||||||
|
"min":0,
|
||||||
|
"max":65535,
|
||||||
|
"display":"number"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"Material_Type": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["PBR", "Shaded"],
|
||||||
|
"default": "PBR"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"Polygon_count": (
|
||||||
|
IO.COMBO,
|
||||||
|
{
|
||||||
|
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
|
||||||
|
"default": "18K-Quad"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_task_error(response: Rodin3DGenerateResponse):
|
||||||
|
"""Check if the response has error"""
|
||||||
|
return hasattr(response, "error")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Rodin3DAPI:
|
||||||
|
"""
|
||||||
|
Generate 3D Assets using Rodin API
|
||||||
|
"""
|
||||||
|
RETURN_TYPES = (IO.STRING,)
|
||||||
|
RETURN_NAMES = ("3D Model Path",)
|
||||||
|
CATEGORY = "api node/3d/Rodin"
|
||||||
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
|
||||||
|
"""
|
||||||
|
Converts a PyTorch tensor to a file-like object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
|
||||||
|
where C is the number of channels (3 for RGB), H is height, and W is width.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- io.BytesIO: A file-like object containing the image data.
|
||||||
|
"""
|
||||||
|
array = tensor.cpu().numpy()
|
||||||
|
array = (array * 255).astype('uint8')
|
||||||
|
image = Image.fromarray(array, 'RGB')
|
||||||
|
|
||||||
|
original_width, original_height = image.size
|
||||||
|
original_pixels = original_width * original_height
|
||||||
|
if original_pixels > max_pixels:
|
||||||
|
scale = math.sqrt(max_pixels / original_pixels)
|
||||||
|
new_width = int(original_width * scale)
|
||||||
|
new_height = int(original_height * scale)
|
||||||
|
else:
|
||||||
|
new_width, new_height = original_width, original_height
|
||||||
|
|
||||||
|
if new_width != original_width or new_height != original_height:
|
||||||
|
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
return img_byte_arr
|
||||||
|
|
||||||
|
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
|
||||||
|
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
|
||||||
|
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
||||||
|
status_list = [str(job.status) for job in response.jobs]
|
||||||
|
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
|
||||||
|
if has_failed:
|
||||||
|
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
|
||||||
|
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
||||||
|
elif all_done:
|
||||||
|
return "DONE"
|
||||||
|
else:
|
||||||
|
return "Generating"
|
||||||
|
|
||||||
|
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||||
|
if images == None:
|
||||||
|
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||||
|
if len(images) >= 5:
|
||||||
|
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||||
|
|
||||||
|
path = "/proxy/rodin/api/v2/rodin"
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=Rodin3DGenerateRequest,
|
||||||
|
response_model=Rodin3DGenerateResponse,
|
||||||
|
),
|
||||||
|
request=Rodin3DGenerateRequest(
|
||||||
|
seed=seed,
|
||||||
|
tier=tier,
|
||||||
|
material=material,
|
||||||
|
quality=quality,
|
||||||
|
mesh_mode=mesh_mode
|
||||||
|
),
|
||||||
|
files=[
|
||||||
|
(
|
||||||
|
"images",
|
||||||
|
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
|
||||||
|
)
|
||||||
|
for image in images if image is not None
|
||||||
|
],
|
||||||
|
content_type = "multipart/form-data",
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = operation.execute()
|
||||||
|
|
||||||
|
if create_task_error(response):
|
||||||
|
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||||
|
logging.error(error_message)
|
||||||
|
raise Exception(error_message)
|
||||||
|
|
||||||
|
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
||||||
|
subscription_key = response.jobs.subscription_key
|
||||||
|
task_uuid = response.uuid
|
||||||
|
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||||
|
return task_uuid, subscription_key
|
||||||
|
|
||||||
|
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||||
|
|
||||||
|
path = "/proxy/rodin/api/v2/status"
|
||||||
|
|
||||||
|
poll_operation = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path = path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=Rodin3DCheckStatusRequest,
|
||||||
|
response_model=Rodin3DCheckStatusResponse,
|
||||||
|
),
|
||||||
|
request=Rodin3DCheckStatusRequest(
|
||||||
|
subscription_key = subscription_key
|
||||||
|
),
|
||||||
|
completed_statuses=["DONE"],
|
||||||
|
failed_statuses=["FAILED"],
|
||||||
|
status_extractor=self.check_rodin_status,
|
||||||
|
poll_interval=3.0,
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||||
|
|
||||||
|
return poll_operation.execute()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||||
|
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||||
|
|
||||||
|
path = "/proxy/rodin/api/v2/download"
|
||||||
|
operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=path,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=Rodin3DDownloadRequest,
|
||||||
|
response_model=Rodin3DDownloadResponse,
|
||||||
|
),
|
||||||
|
request=Rodin3DDownloadRequest(
|
||||||
|
task_uuid=uuid
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return operation.execute()
|
||||||
|
|
||||||
|
def GetQualityAndMode(self, PolyCount):
|
||||||
|
if PolyCount == "200K-Triangle":
|
||||||
|
mesh_mode = "Raw"
|
||||||
|
quality = "medium"
|
||||||
|
else:
|
||||||
|
mesh_mode = "Quad"
|
||||||
|
if PolyCount == "4K-Quad":
|
||||||
|
quality = "extra-low"
|
||||||
|
elif PolyCount == "8K-Quad":
|
||||||
|
quality = "low"
|
||||||
|
elif PolyCount == "18K-Quad":
|
||||||
|
quality = "medium"
|
||||||
|
elif PolyCount == "50K-Quad":
|
||||||
|
quality = "high"
|
||||||
|
else:
|
||||||
|
quality = "medium"
|
||||||
|
|
||||||
|
return mesh_mode, quality
|
||||||
|
|
||||||
|
def DownLoadFiles(self, Url_List):
|
||||||
|
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||||
|
os.makedirs(Save_path, exist_ok=True)
|
||||||
|
model_file_path = None
|
||||||
|
for Item in Url_List.list:
|
||||||
|
url = Item.url
|
||||||
|
file_name = Item.name
|
||||||
|
file_path = os.path.join(Save_path, file_name)
|
||||||
|
if file_path.endswith(".glb"):
|
||||||
|
model_file_path = file_path
|
||||||
|
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||||
|
max_retries = 5
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
with requests.get(url, stream=True) as r:
|
||||||
|
r.raise_for_status()
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
shutil.copyfileobj(r.raw, f)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
logging.info("Retrying...")
|
||||||
|
time.sleep(2)
|
||||||
|
else:
|
||||||
|
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
|
||||||
|
|
||||||
|
return model_file_path
|
||||||
|
|
||||||
|
|
||||||
|
class Rodin3D_Regular(Rodin3DAPI):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"Images":
|
||||||
|
(
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"forceInput":True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
**COMMON_PARAMETERS
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
Images,
|
||||||
|
Seed,
|
||||||
|
Material_Type,
|
||||||
|
Polygon_count,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
tier = "Regular"
|
||||||
|
num_images = Images.shape[0]
|
||||||
|
m_images = []
|
||||||
|
for i in range(num_images):
|
||||||
|
m_images.append(Images[i])
|
||||||
|
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||||
|
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||||
|
self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||||
|
model = self.DownLoadFiles(Download_List)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
class Rodin3D_Detail(Rodin3DAPI):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"Images":
|
||||||
|
(
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"forceInput":True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
**COMMON_PARAMETERS
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
Images,
|
||||||
|
Seed,
|
||||||
|
Material_Type,
|
||||||
|
Polygon_count,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
tier = "Detail"
|
||||||
|
num_images = Images.shape[0]
|
||||||
|
m_images = []
|
||||||
|
for i in range(num_images):
|
||||||
|
m_images.append(Images[i])
|
||||||
|
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||||
|
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||||
|
self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||||
|
model = self.DownLoadFiles(Download_List)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
class Rodin3D_Smooth(Rodin3DAPI):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"Images":
|
||||||
|
(
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"forceInput":True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
**COMMON_PARAMETERS
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
Images,
|
||||||
|
Seed,
|
||||||
|
Material_Type,
|
||||||
|
Polygon_count,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
tier = "Smooth"
|
||||||
|
num_images = Images.shape[0]
|
||||||
|
m_images = []
|
||||||
|
for i in range(num_images):
|
||||||
|
m_images.append(Images[i])
|
||||||
|
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
||||||
|
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||||
|
self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||||
|
model = self.DownLoadFiles(Download_List)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
class Rodin3D_Sketch(Rodin3DAPI):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"Images":
|
||||||
|
(
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"forceInput":True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"Seed":
|
||||||
|
(
|
||||||
|
IO.INT,
|
||||||
|
{
|
||||||
|
"default":0,
|
||||||
|
"min":0,
|
||||||
|
"max":65535,
|
||||||
|
"display":"number"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
Images,
|
||||||
|
Seed,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
tier = "Sketch"
|
||||||
|
num_images = Images.shape[0]
|
||||||
|
m_images = []
|
||||||
|
for i in range(num_images):
|
||||||
|
m_images.append(Images[i])
|
||||||
|
material_type = "PBR"
|
||||||
|
quality = "medium"
|
||||||
|
mesh_mode = "Quad"
|
||||||
|
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
||||||
|
self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
||||||
|
model = self.DownLoadFiles(Download_List)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
|
# NOTE: names should be globally unique
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"Rodin3D_Regular": Rodin3D_Regular,
|
||||||
|
"Rodin3D_Detail": Rodin3D_Detail,
|
||||||
|
"Rodin3D_Smooth": Rodin3D_Smooth,
|
||||||
|
"Rodin3D_Sketch": Rodin3D_Sketch,
|
||||||
|
}
|
||||||
|
|
||||||
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
|
||||||
|
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
|
||||||
|
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
|
||||||
|
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
|
||||||
|
}
|
635
comfy_api_nodes/nodes_runway.py
Normal file
635
comfy_api_nodes/nodes_runway.py
Normal file
@ -0,0 +1,635 @@
|
|||||||
|
"""Runway API Nodes
|
||||||
|
|
||||||
|
API Docs:
|
||||||
|
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
|
||||||
|
|
||||||
|
User Guides:
|
||||||
|
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
|
||||||
|
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
|
||||||
|
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
|
||||||
|
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Union, Optional, Any
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
RunwayImageToVideoResponse,
|
||||||
|
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||||
|
RunwayTaskStatusEnum as TaskStatus,
|
||||||
|
RunwayModelEnum as Model,
|
||||||
|
RunwayDurationEnum as Duration,
|
||||||
|
RunwayAspectRatioEnum as AspectRatio,
|
||||||
|
RunwayPromptImageObject,
|
||||||
|
RunwayPromptImageDetailedObject,
|
||||||
|
RunwayTextToImageRequest,
|
||||||
|
RunwayTextToImageResponse,
|
||||||
|
Model4,
|
||||||
|
ReferenceImage,
|
||||||
|
RunwayTextToImageAspectRatioEnum,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
download_url_to_video_output,
|
||||||
|
image_tensor_pair_to_batch,
|
||||||
|
validate_string,
|
||||||
|
download_url_to_image_tensor,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
||||||
|
|
||||||
|
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||||
|
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||||
|
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||||
|
|
||||||
|
AVERAGE_DURATION_I2V_SECONDS = 64
|
||||||
|
AVERAGE_DURATION_FLF_SECONDS = 256
|
||||||
|
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayApiError(Exception):
|
||||||
|
"""Base exception for Runway API errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||||
|
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||||
|
|
||||||
|
field_1280_720 = "1280:720"
|
||||||
|
field_720_1280 = "720:1280"
|
||||||
|
field_1104_832 = "1104:832"
|
||||||
|
field_832_1104 = "832:1104"
|
||||||
|
field_960_960 = "960:960"
|
||||||
|
field_1584_672 = "1584:672"
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayGen3aAspectRatio(str, Enum):
|
||||||
|
"""Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
|
||||||
|
|
||||||
|
field_768_1280 = "768:1280"
|
||||||
|
field_1280_768 = "1280:768"
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||||
|
"""Returns the video URL from the task status response if it exists."""
|
||||||
|
if response.output and len(response.output) > 0:
|
||||||
|
return response.output[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: replace with updated image validation utils (upstream)
|
||||||
|
def validate_input_image(image: torch.Tensor) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the input image is within the size limits for the Runway API.
|
||||||
|
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
|
||||||
|
"""
|
||||||
|
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||||
|
|
||||||
|
|
||||||
|
def poll_until_finished(
|
||||||
|
auth_kwargs: dict[str, str],
|
||||||
|
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||||
|
estimated_duration: Optional[int] = None,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
) -> TaskStatusResponse:
|
||||||
|
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||||
|
return PollingOperation(
|
||||||
|
poll_endpoint=api_endpoint,
|
||||||
|
completed_statuses=[
|
||||||
|
TaskStatus.SUCCEEDED.value,
|
||||||
|
],
|
||||||
|
failed_statuses=[
|
||||||
|
TaskStatus.FAILED.value,
|
||||||
|
TaskStatus.CANCELLED.value,
|
||||||
|
],
|
||||||
|
status_extractor=lambda response: (response.status.value),
|
||||||
|
auth_kwargs=auth_kwargs,
|
||||||
|
result_url_extractor=get_video_url_from_task_status,
|
||||||
|
estimated_duration=estimated_duration,
|
||||||
|
node_id=node_id,
|
||||||
|
progress_extractor=extract_progress_from_task_status,
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_progress_from_task_status(
|
||||||
|
response: TaskStatusResponse,
|
||||||
|
) -> Union[float, None]:
|
||||||
|
if hasattr(response, "progress") and response.progress is not None:
|
||||||
|
return response.progress * 100
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||||
|
"""Returns the image URL from the task status response if it exists."""
|
||||||
|
if response.output and len(response.output) > 0:
|
||||||
|
return response.output[0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayVideoGenNode(ComfyNodeABC):
|
||||||
|
"""Runway Video Node Base."""
|
||||||
|
|
||||||
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/video/Runway"
|
||||||
|
API_NODE = True
|
||||||
|
|
||||||
|
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the task creation response from the Runway API matches
|
||||||
|
expected format.
|
||||||
|
"""
|
||||||
|
if not bool(response.id):
|
||||||
|
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the successful task status response from the Runway API
|
||||||
|
matches expected format.
|
||||||
|
"""
|
||||||
|
if not response.output or len(response.output) == 0:
|
||||||
|
raise RunwayApiError(
|
||||||
|
"Runway task succeeded but no video data found in response."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_response(
|
||||||
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
|
) -> RunwayImageToVideoResponse:
|
||||||
|
"""Poll the task status until it is finished then get the response."""
|
||||||
|
return poll_until_finished(
|
||||||
|
auth_kwargs,
|
||||||
|
ApiEndpoint(
|
||||||
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
),
|
||||||
|
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||||
|
node_id=node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_video(
|
||||||
|
self,
|
||||||
|
request: RunwayImageToVideoRequest,
|
||||||
|
auth_kwargs: dict[str, str],
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_IMAGE_TO_VIDEO,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=RunwayImageToVideoRequest,
|
||||||
|
response_model=RunwayImageToVideoResponse,
|
||||||
|
),
|
||||||
|
request=request,
|
||||||
|
auth_kwargs=auth_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_response = initial_operation.execute()
|
||||||
|
self.validate_task_created(initial_response)
|
||||||
|
task_id = initial_response.id
|
||||||
|
|
||||||
|
final_response = self.get_response(task_id, auth_kwargs, node_id)
|
||||||
|
self.validate_response(final_response)
|
||||||
|
|
||||||
|
video_url = get_video_url_from_task_status(final_response)
|
||||||
|
return (download_url_to_video_output(video_url),)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||||
|
"""Runway Image to Video Node using Gen3a Turbo model."""
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": model_field_to_node_input(
|
||||||
|
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||||
|
),
|
||||||
|
"start_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Start frame to be used for the video"},
|
||||||
|
),
|
||||||
|
"duration": model_field_to_node_input(
|
||||||
|
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||||
|
),
|
||||||
|
"ratio": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"ratio",
|
||||||
|
enum_type=RunwayGen3aAspectRatio,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"seed",
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
start_frame: torch.Tensor,
|
||||||
|
duration: str,
|
||||||
|
ratio: str,
|
||||||
|
seed: int,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, min_length=1)
|
||||||
|
validate_input_image(start_frame)
|
||||||
|
|
||||||
|
# Upload image
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
start_frame,
|
||||||
|
max_images=1,
|
||||||
|
mime_type="image/png",
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
if len(download_urls) != 1:
|
||||||
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
|
return self.generate_video(
|
||||||
|
RunwayImageToVideoRequest(
|
||||||
|
promptText=prompt,
|
||||||
|
seed=seed,
|
||||||
|
model=Model("gen3a_turbo"),
|
||||||
|
duration=Duration(duration),
|
||||||
|
ratio=AspectRatio(ratio),
|
||||||
|
promptImage=RunwayPromptImageObject(
|
||||||
|
root=[
|
||||||
|
RunwayPromptImageDetailedObject(
|
||||||
|
uri=str(download_urls[0]), position="first"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
node_id=unique_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
||||||
|
"""Runway Image to Video Node using Gen4 Turbo model."""
|
||||||
|
|
||||||
|
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": model_field_to_node_input(
|
||||||
|
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||||
|
),
|
||||||
|
"start_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Start frame to be used for the video"},
|
||||||
|
),
|
||||||
|
"duration": model_field_to_node_input(
|
||||||
|
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||||
|
),
|
||||||
|
"ratio": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"ratio",
|
||||||
|
enum_type=RunwayGen4TurboAspectRatio,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"seed",
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
start_frame: torch.Tensor,
|
||||||
|
duration: str,
|
||||||
|
ratio: str,
|
||||||
|
seed: int,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, min_length=1)
|
||||||
|
validate_input_image(start_frame)
|
||||||
|
|
||||||
|
# Upload image
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
start_frame,
|
||||||
|
max_images=1,
|
||||||
|
mime_type="image/png",
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
if len(download_urls) != 1:
|
||||||
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
|
return self.generate_video(
|
||||||
|
RunwayImageToVideoRequest(
|
||||||
|
promptText=prompt,
|
||||||
|
seed=seed,
|
||||||
|
model=Model("gen4_turbo"),
|
||||||
|
duration=Duration(duration),
|
||||||
|
ratio=AspectRatio(ratio),
|
||||||
|
promptImage=RunwayPromptImageObject(
|
||||||
|
root=[
|
||||||
|
RunwayPromptImageDetailedObject(
|
||||||
|
uri=str(download_urls[0]), position="first"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
node_id=unique_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
||||||
|
"""Runway First-Last Frame Node."""
|
||||||
|
|
||||||
|
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||||
|
|
||||||
|
def get_response(
|
||||||
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
|
) -> RunwayImageToVideoResponse:
|
||||||
|
return poll_until_finished(
|
||||||
|
auth_kwargs,
|
||||||
|
ApiEndpoint(
|
||||||
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
),
|
||||||
|
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||||
|
node_id=node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": model_field_to_node_input(
|
||||||
|
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
|
||||||
|
),
|
||||||
|
"start_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Start frame to be used for the video"},
|
||||||
|
),
|
||||||
|
"end_frame": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{
|
||||||
|
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"duration": model_field_to_node_input(
|
||||||
|
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
|
||||||
|
),
|
||||||
|
"ratio": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"ratio",
|
||||||
|
enum_type=RunwayGen3aAspectRatio,
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
RunwayImageToVideoRequest,
|
||||||
|
"seed",
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
start_frame: torch.Tensor,
|
||||||
|
end_frame: torch.Tensor,
|
||||||
|
duration: str,
|
||||||
|
ratio: str,
|
||||||
|
seed: int,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[VideoFromFile]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, min_length=1)
|
||||||
|
validate_input_image(start_frame)
|
||||||
|
validate_input_image(end_frame)
|
||||||
|
|
||||||
|
# Upload images
|
||||||
|
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
stacked_input_images,
|
||||||
|
max_images=2,
|
||||||
|
mime_type="image/png",
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
if len(download_urls) != 2:
|
||||||
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
|
return self.generate_video(
|
||||||
|
RunwayImageToVideoRequest(
|
||||||
|
promptText=prompt,
|
||||||
|
seed=seed,
|
||||||
|
model=Model("gen3a_turbo"),
|
||||||
|
duration=Duration(duration),
|
||||||
|
ratio=AspectRatio(ratio),
|
||||||
|
promptImage=RunwayPromptImageObject(
|
||||||
|
root=[
|
||||||
|
RunwayPromptImageDetailedObject(
|
||||||
|
uri=str(download_urls[0]), position="first"
|
||||||
|
),
|
||||||
|
RunwayPromptImageDetailedObject(
|
||||||
|
uri=str(download_urls[1]), position="last"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
node_id=unique_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RunwayTextToImageNode(ComfyNodeABC):
|
||||||
|
"""Runway Text to Image Node."""
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "api_call"
|
||||||
|
CATEGORY = "api node/image/Runway"
|
||||||
|
API_NODE = True
|
||||||
|
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": model_field_to_node_input(
|
||||||
|
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
|
||||||
|
),
|
||||||
|
"ratio": model_field_to_node_input(
|
||||||
|
IO.COMBO,
|
||||||
|
RunwayTextToImageRequest,
|
||||||
|
"ratio",
|
||||||
|
enum_type=RunwayTextToImageAspectRatioEnum,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"reference_image": (
|
||||||
|
IO.IMAGE,
|
||||||
|
{"tooltip": "Optional reference image to guide the generation"},
|
||||||
|
)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the task creation response from the Runway API matches
|
||||||
|
expected format.
|
||||||
|
"""
|
||||||
|
if not bool(response.id):
|
||||||
|
raise RunwayApiError("Invalid initial response from Runway API.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_response(self, response: TaskStatusResponse) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the successful task status response from the Runway API
|
||||||
|
matches expected format.
|
||||||
|
"""
|
||||||
|
if not response.output or len(response.output) == 0:
|
||||||
|
raise RunwayApiError(
|
||||||
|
"Runway task succeeded but no image data found in response."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_response(
|
||||||
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
|
) -> TaskStatusResponse:
|
||||||
|
"""Poll the task status until it is finished then get the response."""
|
||||||
|
return poll_until_finished(
|
||||||
|
auth_kwargs,
|
||||||
|
ApiEndpoint(
|
||||||
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=TaskStatusResponse,
|
||||||
|
),
|
||||||
|
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||||
|
node_id=node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def api_call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
ratio: str,
|
||||||
|
reference_image: Optional[torch.Tensor] = None,
|
||||||
|
unique_id: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor]:
|
||||||
|
# Validate inputs
|
||||||
|
validate_string(prompt, min_length=1)
|
||||||
|
|
||||||
|
# Prepare reference images if provided
|
||||||
|
reference_images = None
|
||||||
|
if reference_image is not None:
|
||||||
|
validate_input_image(reference_image)
|
||||||
|
download_urls = upload_images_to_comfyapi(
|
||||||
|
reference_image,
|
||||||
|
max_images=1,
|
||||||
|
mime_type="image/png",
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
if len(download_urls) != 1:
|
||||||
|
raise RunwayApiError("Failed to upload reference image to comfy api.")
|
||||||
|
|
||||||
|
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||||
|
|
||||||
|
# Create request
|
||||||
|
request = RunwayTextToImageRequest(
|
||||||
|
promptText=prompt,
|
||||||
|
model=Model4.gen4_image,
|
||||||
|
ratio=ratio,
|
||||||
|
referenceImages=reference_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute initial request
|
||||||
|
initial_operation = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path=PATH_TEXT_TO_IMAGE,
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=RunwayTextToImageRequest,
|
||||||
|
response_model=RunwayTextToImageResponse,
|
||||||
|
),
|
||||||
|
request=request,
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_response = initial_operation.execute()
|
||||||
|
self.validate_task_created(initial_response)
|
||||||
|
task_id = initial_response.id
|
||||||
|
|
||||||
|
# Poll for completion
|
||||||
|
final_response = self.get_response(
|
||||||
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
|
)
|
||||||
|
self.validate_response(final_response)
|
||||||
|
|
||||||
|
# Download and return image
|
||||||
|
image_url = get_image_url_from_task_status(final_response)
|
||||||
|
return (download_url_to_image_tensor(image_url),)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
|
||||||
|
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
|
||||||
|
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
|
||||||
|
"RunwayTextToImageNode": RunwayTextToImageNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
|
||||||
|
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
|
||||||
|
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
|
||||||
|
"RunwayTextToImageNode": "Runway Text to Image",
|
||||||
|
}
|
574
comfy_api_nodes/nodes_tripo.py
Normal file
574
comfy_api_nodes/nodes_tripo.py
Normal file
@ -0,0 +1,574 @@
|
|||||||
|
import os
|
||||||
|
from folder_paths import get_output_directory
|
||||||
|
from comfy_api_nodes.mapper_utils import model_field_to_node_input
|
||||||
|
from comfy.comfy_types.node_typing import IO
|
||||||
|
from comfy_api_nodes.apis import (
|
||||||
|
TripoOrientation,
|
||||||
|
TripoModelVersion,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.tripo_api import (
|
||||||
|
TripoTaskType,
|
||||||
|
TripoStyle,
|
||||||
|
TripoFileReference,
|
||||||
|
TripoFileEmptyReference,
|
||||||
|
TripoUrlReference,
|
||||||
|
TripoTaskResponse,
|
||||||
|
TripoTaskStatus,
|
||||||
|
TripoTextToModelRequest,
|
||||||
|
TripoImageToModelRequest,
|
||||||
|
TripoMultiviewToModelRequest,
|
||||||
|
TripoTextureModelRequest,
|
||||||
|
TripoRefineModelRequest,
|
||||||
|
TripoAnimateRigRequest,
|
||||||
|
TripoAnimateRetargetRequest,
|
||||||
|
TripoConvertModelRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import (
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
download_url_to_bytesio,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_image_to_tripo(image, **kwargs):
|
||||||
|
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||||
|
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
||||||
|
|
||||||
|
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||||
|
if response.data is not None:
|
||||||
|
for key in ["pbr_model", "model", "base_model"]:
|
||||||
|
if getattr(response.data.output, key, None) is not None:
|
||||||
|
return getattr(response.data.output, key)
|
||||||
|
raise RuntimeError(f"Failed to get model url from response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
def poll_until_finished(
|
||||||
|
kwargs: dict[str, str],
|
||||||
|
response: TripoTaskResponse,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
|
||||||
|
if response.code != 0:
|
||||||
|
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
||||||
|
task_id = response.data.task_id
|
||||||
|
response_poll = PollingOperation(
|
||||||
|
poll_endpoint=ApiEndpoint(
|
||||||
|
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
||||||
|
method=HttpMethod.GET,
|
||||||
|
request_model=EmptyRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
completed_statuses=[TripoTaskStatus.SUCCESS],
|
||||||
|
failed_statuses=[
|
||||||
|
TripoTaskStatus.FAILED,
|
||||||
|
TripoTaskStatus.CANCELLED,
|
||||||
|
TripoTaskStatus.UNKNOWN,
|
||||||
|
TripoTaskStatus.BANNED,
|
||||||
|
TripoTaskStatus.EXPIRED,
|
||||||
|
],
|
||||||
|
status_extractor=lambda x: x.data.status,
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
node_id=kwargs["unique_id"],
|
||||||
|
result_url_extractor=get_model_url_from_response,
|
||||||
|
progress_extractor=lambda x: x.data.progress,
|
||||||
|
).execute()
|
||||||
|
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||||
|
url = get_model_url_from_response(response_poll)
|
||||||
|
bytesio = download_url_to_bytesio(url)
|
||||||
|
# Save the downloaded model file
|
||||||
|
model_file = f"tripo_model_{task_id}.glb"
|
||||||
|
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||||
|
f.write(bytesio.getvalue())
|
||||||
|
return model_file, task_id
|
||||||
|
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||||
|
|
||||||
|
class TripoTextToModelNode:
|
||||||
|
"""
|
||||||
|
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
||||||
|
"""
|
||||||
|
AVERAGE_DURATION = 80
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"prompt": ("STRING", {"multiline": True}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"negative_prompt": ("STRING", {"multiline": True}),
|
||||||
|
"model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||||
|
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
|
||||||
|
"texture": ("BOOLEAN", {"default": True}),
|
||||||
|
"pbr": ("BOOLEAN", {"default": True}),
|
||||||
|
"image_seed": ("INT", {"default": 42}),
|
||||||
|
"model_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||||
|
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||||
|
"quad": ("BOOLEAN", {"default": False})
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||||
|
RETURN_NAMES = ("model_file", "model task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||||
|
style_enum = None if style == "None" else style
|
||||||
|
if not prompt:
|
||||||
|
raise RuntimeError("Prompt is required")
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoTextToModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoTextToModelRequest(
|
||||||
|
type=TripoTaskType.TEXT_TO_MODEL,
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
model_version=model_version,
|
||||||
|
style=style_enum,
|
||||||
|
texture=texture,
|
||||||
|
pbr=pbr,
|
||||||
|
image_seed=image_seed,
|
||||||
|
model_seed=model_seed,
|
||||||
|
texture_seed=texture_seed,
|
||||||
|
texture_quality=texture_quality,
|
||||||
|
face_limit=face_limit,
|
||||||
|
auto_size=True,
|
||||||
|
quad=quad
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
class TripoImageToModelNode:
|
||||||
|
"""
|
||||||
|
Generates 3D models synchronously based on a single image using Tripo's API.
|
||||||
|
"""
|
||||||
|
AVERAGE_DURATION = 80
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||||
|
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
|
||||||
|
"texture": ("BOOLEAN", {"default": True}),
|
||||||
|
"pbr": ("BOOLEAN", {"default": True}),
|
||||||
|
"model_seed": ("INT", {"default": 42}),
|
||||||
|
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||||
|
"texture_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||||
|
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||||
|
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||||
|
"quad": ("BOOLEAN", {"default": False})
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||||
|
RETURN_NAMES = ("model_file", "model task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||||
|
style_enum = None if style == "None" else style
|
||||||
|
if image is None:
|
||||||
|
raise RuntimeError("Image is required")
|
||||||
|
tripo_file = upload_image_to_tripo(image, **kwargs)
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoImageToModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoImageToModelRequest(
|
||||||
|
type=TripoTaskType.IMAGE_TO_MODEL,
|
||||||
|
file=tripo_file,
|
||||||
|
model_version=model_version,
|
||||||
|
style=style_enum,
|
||||||
|
texture=texture,
|
||||||
|
pbr=pbr,
|
||||||
|
model_seed=model_seed,
|
||||||
|
orientation=orientation,
|
||||||
|
texture_alignment=texture_alignment,
|
||||||
|
texture_seed=texture_seed,
|
||||||
|
texture_quality=texture_quality,
|
||||||
|
face_limit=face_limit,
|
||||||
|
auto_size=True,
|
||||||
|
quad=quad
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
class TripoMultiviewToModelNode:
|
||||||
|
"""
|
||||||
|
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
|
||||||
|
"""
|
||||||
|
AVERAGE_DURATION = 80
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"image_left": ("IMAGE",),
|
||||||
|
"image_back": ("IMAGE",),
|
||||||
|
"image_right": ("IMAGE",),
|
||||||
|
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
|
||||||
|
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
|
||||||
|
"texture": ("BOOLEAN", {"default": True}),
|
||||||
|
"pbr": ("BOOLEAN", {"default": True}),
|
||||||
|
"model_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||||
|
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||||
|
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||||
|
"quad": ("BOOLEAN", {"default": False})
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||||
|
RETURN_NAMES = ("model_file", "model task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||||
|
if image is None:
|
||||||
|
raise RuntimeError("front image for multiview is required")
|
||||||
|
images = []
|
||||||
|
image_dict = {
|
||||||
|
"image": image,
|
||||||
|
"image_left": image_left,
|
||||||
|
"image_back": image_back,
|
||||||
|
"image_right": image_right
|
||||||
|
}
|
||||||
|
if image_left is None and image_back is None and image_right is None:
|
||||||
|
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
|
||||||
|
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||||
|
image_ = image_dict[image_name]
|
||||||
|
if image_ is not None:
|
||||||
|
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
||||||
|
images.append(tripo_file)
|
||||||
|
else:
|
||||||
|
images.append(TripoFileEmptyReference())
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoMultiviewToModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoMultiviewToModelRequest(
|
||||||
|
type=TripoTaskType.MULTIVIEW_TO_MODEL,
|
||||||
|
files=images,
|
||||||
|
model_version=model_version,
|
||||||
|
orientation=orientation,
|
||||||
|
texture=texture,
|
||||||
|
pbr=pbr,
|
||||||
|
model_seed=model_seed,
|
||||||
|
texture_seed=texture_seed,
|
||||||
|
texture_quality=texture_quality,
|
||||||
|
texture_alignment=texture_alignment,
|
||||||
|
face_limit=face_limit,
|
||||||
|
quad=quad,
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
class TripoTextureNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model_task_id": ("MODEL_TASK_ID",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"texture": ("BOOLEAN", {"default": True}),
|
||||||
|
"pbr": ("BOOLEAN", {"default": True}),
|
||||||
|
"texture_seed": ("INT", {"default": 42}),
|
||||||
|
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
|
||||||
|
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||||
|
RETURN_NAMES = ("model_file", "model task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
AVERAGE_DURATION = 80
|
||||||
|
|
||||||
|
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoTextureModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoTextureModelRequest(
|
||||||
|
original_model_task_id=model_task_id,
|
||||||
|
texture=texture,
|
||||||
|
pbr=pbr,
|
||||||
|
texture_seed=texture_seed,
|
||||||
|
texture_quality=texture_quality,
|
||||||
|
texture_alignment=texture_alignment
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
|
class TripoRefineNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model_task_id": ("MODEL_TASK_ID", {
|
||||||
|
"tooltip": "Must be a v1.4 Tripo model"
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
|
||||||
|
RETURN_NAMES = ("model_file", "model task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
AVERAGE_DURATION = 240
|
||||||
|
|
||||||
|
def generate_mesh(self, model_task_id, **kwargs):
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoRefineModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoRefineModelRequest(
|
||||||
|
draft_model_task_id=model_task_id
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
|
class TripoRigNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"original_model_task_id": ("MODEL_TASK_ID",),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
|
||||||
|
RETURN_NAMES = ("model_file", "rig task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
AVERAGE_DURATION = 180
|
||||||
|
|
||||||
|
def generate_mesh(self, original_model_task_id, **kwargs):
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoAnimateRigRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoAnimateRigRequest(
|
||||||
|
original_model_task_id=original_model_task_id,
|
||||||
|
out_format="glb",
|
||||||
|
spec="tripo"
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
class TripoRetargetNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"original_model_task_id": ("RIG_TASK_ID",),
|
||||||
|
"animation": ([
|
||||||
|
"preset:idle",
|
||||||
|
"preset:walk",
|
||||||
|
"preset:climb",
|
||||||
|
"preset:jump",
|
||||||
|
"preset:slash",
|
||||||
|
"preset:shoot",
|
||||||
|
"preset:hurt",
|
||||||
|
"preset:fall",
|
||||||
|
"preset:turn",
|
||||||
|
],),
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
|
||||||
|
RETURN_NAMES = ("model_file", "retarget task_id")
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
AVERAGE_DURATION = 30
|
||||||
|
|
||||||
|
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoAnimateRetargetRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoAnimateRetargetRequest(
|
||||||
|
original_model_task_id=original_model_task_id,
|
||||||
|
animation=animation,
|
||||||
|
out_format="glb",
|
||||||
|
bake_animation=True
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
class TripoConversionNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
|
||||||
|
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"quad": ("BOOLEAN", {"default": False}),
|
||||||
|
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
|
||||||
|
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
|
||||||
|
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
|
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, input_types):
|
||||||
|
# The min and max of input1 and input2 are still validated because
|
||||||
|
# we didn't take `input1` or `input2` as arguments
|
||||||
|
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
|
||||||
|
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
|
||||||
|
return True
|
||||||
|
|
||||||
|
RETURN_TYPES = ()
|
||||||
|
FUNCTION = "generate_mesh"
|
||||||
|
CATEGORY = "api node/3d/Tripo"
|
||||||
|
API_NODE = True
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
AVERAGE_DURATION = 30
|
||||||
|
|
||||||
|
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||||
|
if not original_model_task_id:
|
||||||
|
raise RuntimeError("original_model_task_id is required")
|
||||||
|
response = SynchronousOperation(
|
||||||
|
endpoint=ApiEndpoint(
|
||||||
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
|
method=HttpMethod.POST,
|
||||||
|
request_model=TripoConvertModelRequest,
|
||||||
|
response_model=TripoTaskResponse,
|
||||||
|
),
|
||||||
|
request=TripoConvertModelRequest(
|
||||||
|
original_model_task_id=original_model_task_id,
|
||||||
|
format=format,
|
||||||
|
quad=quad if quad else None,
|
||||||
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
|
texture_size=texture_size if texture_size != 4096 else None,
|
||||||
|
texture_format=texture_format if texture_format != "JPEG" else None
|
||||||
|
),
|
||||||
|
auth_kwargs=kwargs,
|
||||||
|
).execute()
|
||||||
|
return poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TripoTextToModelNode": TripoTextToModelNode,
|
||||||
|
"TripoImageToModelNode": TripoImageToModelNode,
|
||||||
|
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
|
||||||
|
"TripoTextureNode": TripoTextureNode,
|
||||||
|
"TripoRefineNode": TripoRefineNode,
|
||||||
|
"TripoRigNode": TripoRigNode,
|
||||||
|
"TripoRetargetNode": TripoRetargetNode,
|
||||||
|
"TripoConversionNode": TripoConversionNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"TripoTextToModelNode": "Tripo: Text to Model",
|
||||||
|
"TripoImageToModelNode": "Tripo: Image to Model",
|
||||||
|
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
|
||||||
|
"TripoTextureNode": "Tripo: Texture model",
|
||||||
|
"TripoRefineNode": "Tripo: Refine Draft model",
|
||||||
|
"TripoRigNode": "Tripo: Rig model",
|
||||||
|
"TripoRetargetNode": "Tripo: Retarget rigged model",
|
||||||
|
"TripoConversionNode": "Tripo: Convert model",
|
||||||
|
}
|
4
nodes.py
4
nodes.py
@ -2281,6 +2281,10 @@ def init_builtin_api_nodes():
|
|||||||
"nodes_pixverse.py",
|
"nodes_pixverse.py",
|
||||||
"nodes_stability.py",
|
"nodes_stability.py",
|
||||||
"nodes_pika.py",
|
"nodes_pika.py",
|
||||||
|
"nodes_runway.py",
|
||||||
|
"nodes_tripo.py",
|
||||||
|
"nodes_rodin.py",
|
||||||
|
"nodes_gemini.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.20.6
|
comfyui-frontend-package==1.20.6
|
||||||
comfyui-workflow-templates==0.1.18
|
comfyui-workflow-templates==0.1.20
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
Loading…
x
Reference in New Issue
Block a user