mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-10 11:35:40 +00:00
Adding Google Gemini Image API node (#9566)
* bigcat88's progress on adding Google Gemini Image node * Made Google Gemini Image node functional * Bump frontend version to get static pricing badge on Gemini Image node
This commit is contained in:
19
comfy_api_nodes/apis/gemini_api.py
Normal file
19
comfy_api_nodes/apis/gemini_api.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: Optional[List[str]] = None
|
||||
|
||||
|
||||
class GeminiImageGenerateContentRequest(BaseModel):
|
||||
contents: List[GeminiContent]
|
||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
@@ -4,11 +4,12 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
|
||||
@@ -25,6 +26,7 @@ from comfy_api_nodes.apis import (
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
@@ -35,6 +37,7 @@ from comfy_api_nodes.apinode_utils import (
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,6 +56,14 @@ class GeminiModel(str, Enum):
|
||||
gemini_2_5_flash = "gemini-2.5-flash"
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
"""
|
||||
Gemini Image Model Names allowed by comfy-api
|
||||
"""
|
||||
|
||||
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
model: GeminiModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
@@ -75,6 +86,135 @@ def get_gemini_endpoint(
|
||||
)
|
||||
|
||||
|
||||
def get_gemini_image_endpoint(
|
||||
model: GeminiImageModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiImageModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiImageGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
def create_text_part(text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
|
||||
def get_parts_from_response(
|
||||
response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
|
||||
def get_parts_by_type(
|
||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
|
||||
def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1,1024,1024,4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@@ -159,59 +299,6 @@ class GeminiNode(ComfyNodeABC):
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def get_parts_from_response(
|
||||
self, response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
def get_parts_by_type(
|
||||
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
part_type: Type of parts to extract ("text" or a MIME type).
|
||||
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in self.get_parts_from_response(response):
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
|
||||
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
|
||||
"""
|
||||
Extract and concatenate all text parts from the response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
Combined text from all text parts in the response.
|
||||
"""
|
||||
parts = self.get_parts_by_type(response, "text")
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
@@ -271,43 +358,6 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
def create_text_part(self, text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -323,11 +373,11 @@ class GeminiNode(ComfyNodeABC):
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [self.create_text_part(prompt)]
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = self.create_image_parts(images)
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
@@ -351,7 +401,7 @@ class GeminiNode(ComfyNodeABC):
|
||||
).execute()
|
||||
|
||||
# Get result output
|
||||
output_text = self.get_text_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
@@ -462,12 +512,162 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
return (files,)
|
||||
|
||||
|
||||
class GeminiImage(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text and image responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, files) to generate coherent
|
||||
text and image responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiImageModel],
|
||||
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
# TODO: later we can add this parameter later
|
||||
# "n": (
|
||||
# IO.INT,
|
||||
# {
|
||||
# "default": 1,
|
||||
# "min": 1,
|
||||
# "max": 8,
|
||||
# "step": 1,
|
||||
# "display": "number",
|
||||
# "tooltip": "How many images to generate",
|
||||
# },
|
||||
# ),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Gemini"
|
||||
DESCRIPTION = "Edit images synchronously via Google API."
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
prompt: str,
|
||||
model: GeminiImageModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
n=1,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Validate inputs
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_image_endpoint(model),
|
||||
request=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT","IMAGE"]
|
||||
)
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return (output_image, output_text,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiImageNode": GeminiImage,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiImageNode": "Google Gemini Image",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.25.10
|
||||
comfyui-frontend-package==1.25.11
|
||||
comfyui-workflow-templates==0.1.66
|
||||
comfyui-embedded-docs==0.2.6
|
||||
torch
|
||||
|
Reference in New Issue
Block a user