mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-08 15:17:14 +00:00
fix multiple image return from api nodes (#7772)
This commit is contained in:
parent
e2eed9eb9b
commit
5c80da31db
@ -31,35 +31,43 @@ def downscale_input(image):
|
|||||||
s = s.movedim(1,-1)
|
s = s.movedim(1,-1)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def validate_and_cast_response (response):
|
def validate_and_cast_response(response):
|
||||||
# validate raw JSON response
|
# validate raw JSON response
|
||||||
data = response.data
|
data = response.data
|
||||||
if not data or len(data) == 0:
|
if not data or len(data) == 0:
|
||||||
raise Exception("No images returned from API endpoint")
|
raise Exception("No images returned from API endpoint")
|
||||||
|
|
||||||
# Get base64 image data
|
# Initialize list to store image tensors
|
||||||
image_url = data[0].url
|
image_tensors = []
|
||||||
b64_data = data[0].b64_json
|
|
||||||
if not image_url and not b64_data:
|
|
||||||
raise Exception("No image was generated in the response")
|
|
||||||
|
|
||||||
if b64_data:
|
# Process each image in the data array
|
||||||
img_data = base64.b64decode(b64_data)
|
for image_data in data:
|
||||||
img = Image.open(io.BytesIO(img_data))
|
image_url = image_data.url
|
||||||
|
b64_data = image_data.b64_json
|
||||||
|
|
||||||
elif image_url:
|
if not image_url and not b64_data:
|
||||||
img_response = requests.get(image_url)
|
raise Exception("No image was generated in the response")
|
||||||
if img_response.status_code != 200:
|
|
||||||
raise Exception("Failed to download the image")
|
|
||||||
img = Image.open(io.BytesIO(img_response.content))
|
|
||||||
|
|
||||||
img = img.convert("RGBA")
|
if b64_data:
|
||||||
|
img_data = base64.b64decode(b64_data)
|
||||||
|
img = Image.open(io.BytesIO(img_data))
|
||||||
|
|
||||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
elif image_url:
|
||||||
img_array = np.array(img).astype(np.float32) / 255.0
|
img_response = requests.get(image_url)
|
||||||
|
if img_response.status_code != 200:
|
||||||
|
raise Exception("Failed to download the image")
|
||||||
|
img = Image.open(io.BytesIO(img_response.content))
|
||||||
|
|
||||||
# Convert to torch tensor and add batch dimension
|
img = img.convert("RGBA")
|
||||||
return torch.from_numpy(img_array)[None,]
|
|
||||||
|
# Convert to numpy array, normalize to float32 between 0 and 1
|
||||||
|
img_array = np.array(img).astype(np.float32) / 255.0
|
||||||
|
img_tensor = torch.from_numpy(img_array)
|
||||||
|
|
||||||
|
# Add to list of tensors
|
||||||
|
image_tensors.append(img_tensor)
|
||||||
|
|
||||||
|
return torch.stack(image_tensors, dim=0)
|
||||||
|
|
||||||
class OpenAIDalle2(ComfyNodeABC):
|
class OpenAIDalle2(ComfyNodeABC):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user