Use common function to reshape batch to.

This commit is contained in:
comfyanonymous
2023-09-02 03:42:49 -04:00
parent 36ea8784a8
commit 77a176f9e0
2 changed files with 10 additions and 5 deletions

View File

@@ -223,6 +223,13 @@ def unet_to_diffusers(unet_config):
return diffusers_unet_map
def repeat_to_batch_size(tensor, batch_size):
if tensor.shape[0] > batch_size:
return tensor[:batch_size]
elif tensor.shape[0] < batch_size:
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys: